[TFLite] 6(1). 프레임 워크를 활용한 이미지 분류 앱 개발 (텐서플로 라이트 서포트 라이브러리)

9 minute read


텐서플로 라이트 서포트 라이브러리

텐서플로 라이트 서포트 라이브러리는 안드로이드에서 텐서플로 라이트를 편리하게 이용할 수 있도록 다양한 기능을 제공하는 라이브러리입니다.

텐서플로 라이트 서포트 라이브러리를 이용하면 지난 세 번의 [텐서플로 라이트를 이용한 안드로이드 앱 개발] 포스팅에서의 과정을 보다 편리하게 구현할 수 있습니다.


텐서플로 라이트 서포트 라이브러리의 구성


활발하게 개발이 진행 중인 텐서플로 라이트 서포트 라이브러리는 5개의 패키지로 이루어져 있습니다.


텐서플로 라이트 서포트 라이브러리의 패키지

  • common: 파일 읽기, 공통 인터페이스 등 공통 기능
  • image: 이미지 변환, 전처리, 저장 등
  • label: 레이블 관리, 매핑 등
  • model: 모델 객체화, GPU 위임 등
  • tensorbuffer: ByteBuffer 관리

이제 텐서플로 라이트 서포트 라이브러리를 이용하여 앱을 개발하는 과정을 실제 코드와 함께 살펴보겠습니다.

기기에서 이미지를 불러와 모델로 추론하는 앱을 개발해보겠습니다.

텐서플로 라이트 서포트 라이브러리의 최신 버전은 아래 주소에서 확인할 수 있습니다.

https://mvnrepository.com/artifact/org.tensorflow/tensorflow-lite-support



1. 프로젝트 생성 및 의존성 추가


텐서플로 라이트 서포트 라이브러리는 텐서플로 라이트 라이브러리에 포함되어 있지 않기 때문에 안드로이드 프로젝트에 별도로 의존성을 추가해야 합니다.

dependencies {

    implementation "org.jetbrains.kotlin:kotlin-stdlib:$kotlin_version"
    implementation 'androidx.core:core-ktx:1.6.0'
    implementation 'androidx.appcompat:appcompat:1.3.1'
    implementation 'com.google.android.material:material:1.4.0'
    implementation 'androidx.constraintlayout:constraintlayout:2.1.0'

    // 1. 텐서플로 라이트와 텐서플로 라이트 서포트 라이브러리 의존성 추가
    implementation 'org.tensorflow:tensorflow-lite:2.4.0'
    implementation 'org.tensorflow:tensorflow-lite-support:0.1.0'

    testImplementation 'junit:junit:4.+'
    androidTestImplementation 'androidx.test.ext:junit:1.1.3'
    androidTestImplementation 'androidx.test.espresso:espresso-core:3.4.0'
}



2. 모델 로드


이번에 만들 클래스 명은 ClassifierWithSupport로 하겠습니다.

/*텐서플로 라이트 서포트 라이브러리를 이용한 Classifier 클래스 구현*/
class ClassifierWithSuppport(context: Context) {
  
}


1. tflite 파일 추가

모델이 저장된 tflite 파일을 프로젝트의 assets 폴더에 추가합니다.

이번 프로젝트에서는 ImageNet 데이터로 학습된 케라스 애플리케이션의 MobileNetV2 모델을 사용하겠습니다.

import tensorflow as tf

mobilenet_imagenet_model = tf.keras.applications.MobileNetV2(weights="imagenet")

converter = tf.lite.TFLiteConverter.from_keras_model(mobilenet_imagenet_model)
tflite_model = converter.convert()

with open('./mobilenet_imagenet_model.tflite', 'wb') as f:
    f.write(tflite_model)

image-20210817114350329


2. tflite 파일 로드

우선 모델 파일명을 상수로 선언합니다. 클래스 내에서 상수를 선언할 때는 companion object 블록을 사용합니다.

다음으로 init 블록 내에서 모델을 ByteBuffer 형으로 불러온 다음 Interpreter를 생성합니다. 텐서플로 라이트 서포트 라이브러리를 사용하면 FileUtil.loadMappedFile 메서드를 이용하여 간단히 모델 파일을 불러와 ByteBuffer를 생성할 수 있습니다.

그리고 생성한 interpreter 인스턴스는 클래스의 프로퍼티로 선언해줍니다.

    /* 상수 선언 */
    companion object{
        // 2-1. 모델 로드: tflite 모델을 assets 디렉터리에 추가
        // 2-2. 모델 로드: 모델 파일명을 상수로 선언
        private const val MODEL_NAME = "mobilenet_imagenet_model.tflite"
    }
    /* 프로퍼티 선언 */
    var context: Context = context
    // 2-4. 모델 로드: interpreter 프로퍼티 선언
    var interpreter: Interpreter


    init{
        // 2-3. 모델 로드: tflite 파일 로드
        val model: ByteBuffer? = FileUtil.loadMappedFile(context, MODEL_NAME)
        model?.order(ByteOrder.nativeOrder())?:throw IOException()
        interpreter = Interpreter(model)
    }



3. 입력 이미지 전처리


1. 입력 이미지 관련 프로퍼티 선언

입력 이미지 관련 프로퍼티에는 다음과 같은 것들이 있습니다.

    /* 프로퍼티 선언 */
    var context: Context = context
    var interpreter: Interpreter
    // 3-1. 입력 이미지 전처리: 모델의 입력 이미지를 저장할 프로퍼티 선언
    lateinit var inputImage: TensorImage
    // 3-2. 입력 이미지 전처리: 모델의 입력 형상 프로퍼티 선언
    var modelInputChannel: Int = 0
    var modelInputWidth: Int = 0
    var modelInputHeight: Int = 0


2. 입력 이미지 데이터를 얻어오는 메서드 구현

입력 이미지 전처리 메서드를 구현합니다. 텐서플로 라이트 서포트 라이브러리는 모델에 입력할 이미지를 담을 수 있는 TensorImage 클래스를 제공합니다. 이를 이용하면 Bitmap 포맷의 이미지를 모델에 바로 입력되는 ByteBuffer 포맷으로 변환할 수 있습니다.

또한 이미지 크기, 데이터 타입 등을 쉽게 얻을 수 있으며, 이미 구현된 다양한 이미지 처리 알고리즘을 ImageProcessor, ImageOperator 클래스로 간단히 적용할 수 있습니다.

이 입력 이미지 데이터를 얻어오는 메서드는 클래스의 init 블록 내에서 호출됩니다.

    init{
        val model: ByteBuffer? = FileUtil.loadMappedFile(context, MODEL_NAME)
        model?.order(ByteOrder.nativeOrder())?:throw IOException()
        interpreter = Interpreter(model)
        // 3-4. 입력 이미지 전처리: 메서드 호출
        initModelShape()
    }

    // 3-3. 입력 이미지 전처리: 메서드 정의
    // 모델의 입력 형상과 데이터 타입을 프로퍼티에 저장
    private fun initModelShape(){
        val inputTensor = interpreter.getInputTensor(0)
        val shape = inputTensor.shape()
        modelInputChannel = shape[0]
        modelInputWidth = shape[1]
        modelInputHeight = shape[2]
        // 모델의 입력값을 저장할 TensorImage 생성
        inputImage = TensorImage(inputTensor.dataType())
    }


3. 입력 이미지를 전처리하는 메서드 구현

다음으로 앞에서 얻은 입력 이미지의 데이터들을 이용해 입력 이미지를 전처리하는 메서드를 구현합니다.

이 입력 이미지 전처리 메서드는 모델 추론 시(classify 메서드, 뒤에서 구현) 호출됩니다.

    // 3-5. 입력 이미지 전처리: TensorImage에 bitmap 이미지 입력 및 이미지 전처리 로직 정의
    // Bitmap 이미지를 입력 받아 전처리하고 이를 TensorImage 형태로 반환
    private fun loadImage(bitmap: Bitmap?): TensorImage{
        // TensorImage에 이미지 데이터 저장
        inputImage.load(bitmap)
        // 전처리 ImageProcessor 정의
        val imageProcessor =
            ImageProcessor.Builder()                            // Builder 생성
                .add(ResizeOp(modelInputWidth,modelInputHeight, // 이미지 크기 변환
                     ResizeOp.ResizeMethod.NEAREST_NEIGHBOR))
                .add(NormalizeOp(0.0f, 255.0f))                 // 이미지 정규화
                .build()                                        // ImageProcessor 생성
        // 이미지를 전처리하여 TensorImage 형태로 반환
        return imageProcessor.process(inputImage)
    }



4. 추론


모델 로드, 입력 이미지 전처리까지 완료했으면 모델이 추론을 할 차례입니다.


1. initModelShape( ) 메서드에 모델의 출력 값을 저장할 TensorBuffer를 생성하는 코드 추가

모델이 추론을 완료하면 그 추론 결괏값을 TensorBuffer에 담아야 합니다.

initModelShape 메서드에 해당 코드를 추가하고, 출력 값을 저장할 TensorBuffer는 프로퍼티로 선언합니다.

    /* 프로퍼티 선언 */
    var context: Context = context

    var interpreter: Interpreter

    lateinit var inputImage: TensorImage
    var modelInputChannel: Int = 0
    var modelInputWidth: Int = 0
    var modelInputHeight: Int = 0
    // 4-1. 추론: 모델의 추론된 출력 값을 저장할 프로퍼티 선언
    lateinit var outputBuffer: TensorBuffer

    ...

    private fun initModelShape(){
        val inputTensor = interpreter.getInputTensor(0)
        val shape = inputTensor.shape()
        modelInputChannel = shape[0]
        modelInputWidth = shape[1]
        modelInputHeight = shape[2]

        inputImage = TensorImage(inputTensor.dataType())

        // 4-2. 추론: 모델의 출력값을 저장할 TensorBuffer 생성
        val outputTensor = interpreter.getOutputTensor(0)
        outputBuffer = TensorBuffer.createFixedSize(outputTensor.shape(), outputTensor.dataType())

    }


2. 추론 메서드 정의

추론을 수행할 classify 메서드를 정의합니다. Interpreter.run() 메서드를 사용합니다.

    // 4-3. 추론: 추론 메서드 정의
    fun classify(image: Bitmap?): Pair<String, Float>{
        inputImage = loadImage(image)
        interpreter.run(inputImage.buffer, outputBuffer.buffer.rewind())

    }

반환문은 뒤의 ‘추론 결과 해석’ 부분에서 추가합니다.



5. 추론 결과 해석


추론까지 수행했으면 이제 TensorBuffer에 담겨있는 결괏값을 해석해야 합니다.

앞선 포스팅에서의 MNIST 데이터셋은 운이 좋게도 0~9의 인덱스 값이 각 숫자에 대응했지만, 대부분의 경우에는 인덱스 값을 각 클래스(레이블)에 매핑하는 과정이 필요합니다.

그러려면 먼저 레이블을 모두 포함하는 텍스트 파일을 프로젝트의 assets 폴더에 추가해야 합니다.


1. 레이블 파일 추가

ImageNet의 레이블 목록은 아래 주소에서 볼 수 있습니다.

https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt

위 레이블 목록을 텍스트 파일로 만들어, 프로젝트의 assets 폴더에 추가합니다.

image-20210817121812604


2. 레이블 파일 로드

레이블 파일을 추가했으니 이제 로드합니다.

과정은 tflite 파일을 로드할 때와 마찬가지로, 라벨 파일명을 선언하고 이 라벨들을 저장할 List< String > 타입의 리스트를 프로퍼티로 선언해준 뒤 init 블록 내에서 파일을 로드하면 됩니다.

    /* 상수 선언 */
    companion object{
        private const val MODEL_NAME = "mobilenet_imagenet_model.tflite"
        // 5-1. 추론 결과 해석: 분류 클래스 라벨을 포함하는 txt 파일을 assets 디렉터리에 추가
        // 5-2. 추론 결과 해석: 라벨 파일명을 상수로 선언
        private const val LABEL_FILE = "imagenet_labels.txt"
    }
    /* 프로퍼티 선언 */
    var context: Context = context

    var interpreter: Interpreter

    lateinit var inputImage: TensorImage
    var modelInputChannel: Int = 0
    var modelInputWidth: Int = 0
    var modelInputHeight: Int = 0

    lateinit var outputBuffer: TensorBuffer
    // 5-3. 추론 결과 해석: 라벨 목록을 저장하는 프로퍼티 선언
    private lateinit var labels: List<String>


    init{
        val model: ByteBuffer? = FileUtil.loadMappedFile(context, MODEL_NAME)
        model?.order(ByteOrder.nativeOrder())?:throw IOException()
        interpreter = Interpreter(model)

        initModelShape()
        // 5-4. 추론 결과 해석: 라벨 파일 로드
        labels = FileUtil.loadLabels(context, LABEL_FILE)
    }


3. 추론 결과 매핑

이제 앞에서 구현한 classify 메서드 안에서 모델의 추론 결괏값을 레이블에 매핑하여 반환하는 코드를 추가합니다.

이 때 상위 몇 개의 클래스를 반환할 지는 개발자가 정하면 됩니다. 여기서는 상위 1개의 클래스를 반환하도록 argmax 메서드를 구현하겠습니다.

    fun classify(image: Bitmap?): Pair<String, Float>{
        inputImage = loadImage(image)
        interpreter.run(inputImage.buffer, outputBuffer.buffer.rewind())

        // 5-5. 추론 결과 해석: 모델 출력값을 라벨에 매핑하여 반환
        val output = TensorLabel(labels, outputBuffer).getMapWithFloatValue() // Map<String, Float>

        return argmax(output)

    }

    // 5-6. 추론 결과 해석: Map에서 확률이 가장 높은 클래스명과 확률 쌍을 찾아서 반환하는 메서드 정의
    private fun argmax(map: Map<String, Float>): Pair<String, Float>{
        var maxKey = ""
        var maxVal = -1.0f

        for(entry in map.entries){
            var f = entry.value
            if(f > maxVal){
                maxKey = entry.key
                maxVal = f
            }
        }

        return Pair(maxKey, maxVal)
    }



6. 자원 해제


모델의 사용이 끝났으면 Interpreter 인스턴스에 할당된 자원을 해제해주어야 합니다.

    // 6. 자원 해제: 자원 해제 메서드 정의
    fun finish(){
        if(interpreter != null)
            interpreter.close()
    }



Model 클래스


지금까지 텐서플로 라이트 서포트 라이브러리를 이용하여 코드를 간략하게 수정하였습니다. 이제 마지막으로 Interpreter를 비롯해 딥러닝 모델 전체를 객체화한 Model 클래스를 사용한 코드를 이전 코드와 비교하여 보겠습니다.

Model 클래스는 텐서플로 라이트 서포트 라이브러리의 model 패키지에 구현된 클래스입니다. tflite 파일 로드, Interpreter를 이용한 추론 등 딥러닝 모델이 직접 수행하는 동작을 한 데 모아 객체화한 Model 클래스를 이용하면 앞서 구현한 모델 활용 코드를 더욱 간소화할 수 있습니다.

아래 코드는 전체 코드입니다.

/*텐서플로 라이트 서포트 라이브러리의 model.Model 클래스를 이용한 Classifier 클래스 구현*/
class ClassifierWithModel(context: Context) {
    /* 상수 선언 */
    companion object{
        // 2-1. 모델 로드: tflite 모델을 assets 디렉터리에 추가
        // 2-2. 모델 로드: 모델 파일명을 상수로 선언
        private const val MODEL_NAME = "mobilenet_imagenet_model.tflite"
        // 5-1. 추론 결과 해석: 분류 클래스 라벨을 포함하는 txt 파일을 assets 디렉터리에 추가
        // 5-2. 추론 결과 해석: 라벨 파일명을 상수로 선언
        private const val LABEL_FILE = "imagenet_labels.txt"
    }
    /* 프로퍼티 선언 */
    var context: Context = context

    // ===========================================================================================
    // 2-4. 모델 로드: interpreter 프로퍼티 선언
    // lateinit var interpreter: Interpreter
    // Model 클래스 사용 시 Interpreter를 직접 생성할 필요가 없음
    var model: Model
    // ============================================================================================

    // 3-1. 입력 이미지 전처리: 모델의 입력 이미지를 저장할 프로퍼티 선언
    lateinit var inputImage: TensorImage
    // 3-2. 입력 이미지 전처리: 모델의 입력 형상 프로퍼티 선언
    var modelInputChannel: Int = 0
    var modelInputWidth: Int = 0
    var modelInputHeight: Int = 0
    // 4-1. 추론: 모델의 추론된 출력 값을 저장할 프로퍼티 선언
    lateinit var outputBuffer: TensorBuffer
    // 5-3. 추론 결과 해석: 라벨 목록을 저장하는 프로퍼티 선언
    private lateinit var labels: List<String>


    init{
        // ========================================================================================
        // 2-3. 모델 로드: tflite 파일 로드
        // val model: ByteBuffer? = FileUtil.loadMappedFile(context, MODEL_NAME)
        // model?.order(ByteOrder.nativeOrder())?:throw IOException()
        // interpreter = Interpreter(model)
        // Model 클래스가 tflite 파일 로드부터 추론까지 모두 수행
        model = Model.createModel(context, MODEL_NAME)
        // ========================================================================================

        // 3-4. 입력 이미지 전처리: 메서드 호출
        initModelShape()
        // 5-4. 추론 결과 해석: 라벨 파일 로드
        labels = FileUtil.loadLabels(context, LABEL_FILE)
    }

    // 3-3. 입력 이미지 전처리: 메서드 정의
    // 모델의 입력 형상과 데이터 타입을 프로퍼티에 저장
    private fun initModelShape(){
        // ========================================================================================
        // val inputTensor = interpreter.getInputTensor(0)
        val inputTensor = model.getInputTensor(0)
        // ========================================================================================
        val shape = inputTensor.shape()
        modelInputChannel = shape[0]
        modelInputWidth = shape[1]
        modelInputHeight = shape[2]
        // 모델의 입력값을 저장할 TensorImage 생성
        inputImage = TensorImage(inputTensor.dataType())

        // 4-2. 추론: 모델의 출력값을 저장할 TensorBuffer 생성
        // ========================================================================================
        // val outputTensor = interpreter.getOutputTensor(0)
        val outputTensor = model.getOutputTensor(0)
        // ========================================================================================
        outputBuffer = TensorBuffer.createFixedSize(outputTensor.shape(), outputTensor.dataType())

    }

    // 3-4. 입력 이미지 전처리: TensorImage에 bitmap 이미지 입력 및 이미지 전처리 로직 정의
    // Bitmap 이미지를 입력 받아 전처리하고 이를 TensorImage 형태로 반환
    private fun loadImage(bitmap: Bitmap?): TensorImage{
        // TensorImage에 이미지 데이터 저장
        inputImage.load(bitmap)
        // 전처리 ImageProcessor 정의
        val imageProcessor =
            ImageProcessor.Builder()                            // Builder 생성
                .add(ResizeOp(modelInputWidth,modelInputHeight, // 이미지 크기 변환
                    ResizeOp.ResizeMethod.NEAREST_NEIGHBOR))
                .add(NormalizeOp(0.0f, 255.0f))    // 이미지 정규화
                .build()                                       // ImageProcessor 생성
        // 이미지를 전처리하여 TensorImage 형태로 반환
        return imageProcessor.process(inputImage)
    }

    // 4-3. 추론: 추론 메서드 정의
    fun classify(image: Bitmap?): Pair<String, Float>{
        inputImage = loadImage(image)
        // ========================================================================================
        // interpreter.run(inputImage.buffer, outputBuffer.buffer.rewind())
        // Model 클래스의 파라미터는 각각 Object의 배열과 Object의 Map을 요구
        val inputs = arrayOf<Object>(inputImage.buffer as Object)
        val outputs = mutableMapOf<Int, Object>()
        outputs.put(0, outputBuffer.buffer.rewind() as Object)
        model.run(inputs, outputs as @NonNull Map<Int, Any>)
        // ========================================================================================

        // 5-5. 추론 결과 해석: 모델 출력값을 라벨에 매핑하여 반환
        val output = TensorLabel(labels, outputBuffer).getMapWithFloatValue() // Map<String, Float>

        return argmax(output)

    }

    // 5-6. 추론 결과 해석: Map에서 확률이 가장 높은 클래스명과 확률 쌍을 찾아서 반환하는 메서드 정의
    private fun argmax(map: Map<String, Float>): Pair<String, Float>{
        var maxKey = ""
        var maxVal = -1.0f

        for(entry in map.entries){
            var f = entry.value
            if(f > maxVal){
                maxKey = entry.key
                maxVal = f
            }
        }

        return Pair(maxKey, maxVal)
    }

    // 6. 자원 해제: 자원 해제 메서드 정의
    fun finish(){
        // ========================================================================================
        // if(interpreter != null)
        //     interpreter.close()
        if(model != null)
            model.close()
        // ========================================================================================
    }

}



정리


  • 텐서플로 라이트를 이용한 앱 개발 워크 플로는 tflite 파일 로드 ➡ Model(Interpreter) 생성 ➡ 입력 이미지 전처리 ➡ 추론 ➡ 추론 결과 해석 ➡ 자원 해제 의 순으로 진행됩니다.
  • 텐서플로 라이트 서포트 라이브러리를 이용하면 간략한 코드 작성이 가능해집니다.
  • 텐서플로 라이트 서포트 라이브러리의 model.Model 클래스를 이용하면 tflite 파일을 로드하고 Interpreter를 생성하는 과정을 더 간략하게 줄일 수 있습니다.

Categories:

Updated:

Leave a comment