[TFLite] 5(3). 텐서플로 라이트 모델을 이용한 안드로이드 앱 개발 (추론 및 결과 해석)

5 minute read


추론 및 결과 해석

지난 두 번의 포스팅에서 TFLite 모델 로드 과정과 입력 이미지 전처리 과정을 알아보았습니다.

이번에는 마지막으로 모델의 추론 및 결과 해석 과정에 대해 살펴보고, 결과 화면까지 확인해보겠습니다.


1. 추론


입력 이미지 변환을 완료했다면 모델에 데이터를 입력하여 추론하고 그 결과를 해석합니다.

추론에는 Interpreter의 run( ) 메서드를 이용합니다.

run( ) 메서드의 파라미터

  • 파라미터 1: 추론에 사용할 입력 데이터
  • 파라미터 2: 추론 결과를 담을 버퍼

입력 데이터는 drawView에서 Bitmap 형태로 받아온 이미지에 앞선 포스팅에서 본 resizeBitmap( ) 메서드와 convertBitmapToGrayByteBuffer( ) 메서드를 적용하여 ByteBuffer 타입으로 변한된 이미지 데이터를 사용합니다.

출력 데이터를 받기 위해 버퍼를 만들어 전달해야 하는데, 버퍼를 생성하려면 먼저 출력 텐서의 형태를 알아야 합니다. 우리가 만들고 있는 손글씨 분류 앱의 경우 입력 이미지를 0~9의 10개 클래스로 분류하기 때문에 출력 데이터에는 10개의 출력 클래스가 있습니다.


1. 모델 출력 클래스 수를 담을 멤버 변수 선언

    /*전역변수 선언*/
    var context: Context = context

    var interpreter: Interpreter

    var modelInputWidth: Int = 0
    var modelInputHeight: Int = 0
    var modelInputChannel: Int = 0

    // 5. 추론 및 결과 해석 - 1. 추론
    var modelOutputClasses: Int = 0


2. 모델 출력 클래스 수 계산

입력 이미지 형상을 계산할 때 정의했던 initModelShape( ) 메서드에 출력 텐서의 형태를 이용해 출력 클래스 수를 가져오는 코드를 추가합니다.

    private fun initModelShape(){

        val inputTensor: Tensor = interpreter.getInputTensor(0)
        val inputShape = inputTensor.shape()
        modelInputChannel = inputShape[0]
        modelInputWidth = inputShape[1]
        modelInputHeight = inputShape[2]

        // 5. 추론 및 결과 해석 - 1. 추론(cont.)
        // 출력 텐서의 형태를 이용하여 출력 클래스 수 가져오기
        val outputTensor = interpreter.getOutputTensor(0)
        val outputShape = outputTensor.shape()
        modelOutputClasses = outputShape[1]
    }


3. 손글씨 분류 모델의 추론

입력 이미지와 출력 이미지를 담을 버퍼를 Interpreter 인스턴스의 run( ) 메서드에 전달하여 모델의 추론을 시작합니다.

    // 5. 추론 및 결과 해석 - 1. 추론(cont.)
    // 출력 클래스 수를 이용하여 출력 값을 담을 배열을 생성하고 interpreter의 run() 메서드에 전달하여 추론을 수행
    public fun classify(image: Bitmap){
        // 전처리된 입력 이미지
        val buffer = convertBitmapToGrayByteBuffer(resizeBitmap(image))
        // 추론 결과를 담을 이차원 배열
        val result = Array(1) { FloatArray(modelOutputClasses) { 0.0f } }
        // 추론 수행
        interpreter.run(buffer, result)

    }

아직 반환 값이 없는데, 이는 바로 이어서 추론 결과 해석 파트에서 추가합니다.



2. 추론 결과 해석


추론 결과를 얻었으면 그 결과를 해석해야 합니다. 추론 결과는 분류 가능한 클래스의 개수만큼(여기서는 0~9의 10개) 전달됩니다. 각 배열 값에는 해당 클래스에 속할 확률이 들어있습니다.

이 중 하나의 예측 클래스를 반환하기 위해 확률이 가장 높은 클래스를 찾아내는 로직을 구현합니다.


4. 추론 결과 해석

확률이 가장 높은 클래스를 찾아내는 로직을 구현합니다.

    // 5. 추론 및 결과 해석 - 2. 추론 결과 해석
    // 추론 결과값을 확인하여 확률이 가장 높은 클래스를 반환
    private fun argmax(array: FloatArray): Pair<Int, Float>{
        var argmax: Int = 0
        var max: Float = array[0]

        for(i in 1 until array.size){
            val f = array[i]
            if(f > max){
                argmax = i
                max = f
            }
        }

        return Pair(argmax, max)
    }

확률 값이 들어있는 FloatArray를 argmax( ) 메서드의 파라미터로 전달하여 호출하면 가장 높은 확률의 클래스의 인덱스와 확률 값을 반환합니다.

우리는 0~9의 숫자를 분류하기 때문에 인덱스 값이 곧 숫자 값입니다.


5. classify 메서드 수정

앞에서 구현했던 classify 메서드에서 추론 결과를 해석하여 가장 확률이 높은 클래스의 인덱스와 확률 값을 반환하도록 코드를 수정합니다.

    // 5. 추론 및 결과 해석 - 2. 추론 결과 해석(cont.)
    // 추론 결과에서 확률이 가장 높은 클래스와 그 확률을 반환
    public fun classify(image: Bitmap): Pair<Int, Float>{

        val buffer = convertBitmapToGrayByteBuffer(resizeBitmap(image))
        val result = Array(1) { FloatArray(modelOutputClasses) { 0.0f } }
        interpreter.run(buffer, result)

        // 5. 추론 및 결과 해석 - 2. 추론 결과 해석(cont.)
        // 확률이 가장 높은 클래스와 확률을 반환
        return argmax(result[0])

    }


3. 자원 해제


이제 모델이 추론하기 위한 모든 코드 작성을 마쳤습니다.

여기서 주의할 것은, Interpreter는 리소스를 가지고 있기 때문에 DrawActiivty 액티비티(DrawView를 포함하는 액티비티) 종료 시 자원 해제 메서드를 호출하여 Interpreter의 자원을 해제할 수 있도록 하는 메서드를 정의해야 합니다.

6. Interpreter 자원 해제

    // 6. 자원 해제
    // interpreter 자원 정리
    public fun finish(){
        if(interpreter != null)
            interpreter.close()
    }


DrawActivity 코드 수정하기


Classify 클래스에서의 코드 작성은 끝마쳤습니다.

이제 마지막으로 DrawActivity.kt 소스파일에서 [Classify] 버튼을 누르면 모델이 추론한 결과값을 텍스트 뷰에 출력하고, 액티비티 종료 시 모델의 자원 해제 메서드를 호출하도록 하는 코드를 추가합니다.

class DrawActivity : AppCompatActivity() {

    val binding by lazy {ActivityDrawBinding.inflate(layoutInflater)}

    lateinit var classifier: Classifier

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(binding.root)

        var drawView = binding.drawView
        drawView.setStrokeWidth(100.0f)
        drawView.setBackgroundColor(Color.BLACK)
        drawView.setColor(Color.WHITE)

        var resultView = binding.resultView

        // 버튼 클릭 리스너 설정
        binding.classifyBtn.setOnClickListener {
            var image = drawView.getBitmap()

            // 추론 메서드를 호출하고 결과를 전달받아 resultView에 출력
            val res = classifier.classify(image)
            val outStr = String.format(Locale.ENGLISH, "%d: %.0f%%",
                res.first, res.second*100.0f)
            resultView.text = outStr
        }
        binding.clearBtn.setOnClickListener {
            drawView.clearCanvas()
        }

        try {
            classifier = Classifier(this)
        }catch(ioe: IOException){
            Log.d("DigitClassifier", "failed to init Classifier", ioe)
        }

    }
  
    // 액티비티 종료 시 호출되는 onDestroy 메서드 오버라이드
    override fun onDestroy() {
        // Classifier의 finish() 메서드를 호출하여 액티비티 종료 시 자원 해제제
        classifier.finish()

        super.onDestroy()
    }
}



출력 결과


이제 모든 코드 작성을 마쳤습니다!

에뮬레이터를 실행하여 결과를 보도록 하겠습니다.

image-20210815234310330


[부록] 개선된 모델 적용


지금 사용하는 모델은 아주 단순한 다층 퍼셉트론 구조의 신경망이기 때문에, 추론 결과가 좋지 않을 수 있습니다. 이를테면, 아래와 같이 말이죠.

image-20210815234513165

더 좋은 모델을 개발하여 새로운 모델을 사용하고 싶다면, 간단히 프로젝트의 assets 폴더에 tflite 파일을 추가하고, Classify 클래스의 MODEL_NAME에 지정되어 있는 파일명을 바꾸기만 하면 됩니다.

단순한 이층 퍼셉트론 신경망 모델을 합성곱 신경망 기반의 모델로 바꿔봅시다.


모델 개발

import tensorflow as tf

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train, x_test = x_train / 255.0, x_test / 255.0

x_train_4d = x_train.reshape(-1, 28, 28, 1)
x_test_4d = x_test.reshape(-1, 28, 28, 1)

cnn_model = tf.keras.models.Sequential([tf.keras.layers.Conv2D(32, (3,3), activation='relu', input_shape=(28, 28, 1)), 
                                        tf.keras.layers.MaxPooling2D((2,2)), 
                                        tf.keras.layers.Conv2D(64, (3,3), activation='relu'), 
                                        tf.keras.layers.MaxPooling2D((2,2)), 
                                        tf.keras.layers.Conv2D(64, (3,3), activation='relu'), 
                                        tf.keras.layers.Flatten(), 
                                        tf.keras.layers.Dense(64, activation='relu'), 
                                        tf.keras.layers.Dense(10, activation='softmax')])

cnn_model.compile(optimizer='adam', 
                  loss = 'sparse_categorical_crossentropy', 
                  metrics=['accuracy'])
cnn_model.fit(x_train_4d, y_train, epochs=5)

converter = tf.lite.TFLiteConverter.from_keras_model(cnn_model)
tflite_model = converter.convert()

with open('./keras_model_cnn.tflite', 'wb') as f:
    f.write(tflite_model)


assets 폴더에 tflite 파일 추가

image-20210815234842274


MODEL_NAME 변경

// const val MODEL_NAME = "keras_model.tflite" // tflite 파일명
const val MODEL_NAME = "keras_model_cnn.tflite"



정리


이상으로 3회의 포스팅에 걸쳐 텐서플로 라이트 모델을 이용하여 손글씨 분류 앱을 개발해보았습니다.

모델 추론 및 추론 해석 과정은 다음과 같습니다.

  • 모델 출력 클래스 수 계산
  • Interpreter.run( ) 메서드에 전처리된 입력 이미지와 출력 결과를 담을 버퍼를 전달
  • 추론 결과 해석 메서드 구현
  • 해석된 추론 결과 반환

마지막으로 액티비티 종료 시 Interpreter에 할당된 자원을 해제해주는 메서드를 구현하는 것도 잊지 말아야 합니다.


다음 포스팅에서는 지금까지 구현한 기능을 보다 편리하게 이용할 수 있도록 다양한 기능을 제공하는 텐서플로 라이트 서포트 라이브러리에 대해 알아보겠습니다.

Categories:

Updated:

Leave a comment