[TFLite] 5(1). 텐서플로 라이트 모델을 이용한 안드로이드 앱 개발 (UI 구성과 TFLite 모델 로드)

4 minute read


UI 구성과 TFLite 모델 로드

텐서플로 라이트를 이용한 앱 개발 프로세스는 크게 다음의 3단계로 구성됩니다.

  • TFLite 모델 로드
  • 입력 이미지 전처리
  • 추론 결과 해석

이번 포스팅에서는 먼저 앱 개발 프로세스를 확인하기 위한 간단한 UI 구성을 살펴보고, TFLite 모델을 로드하는 방법에 대해 알아보겠습니다.

UI 구성


먼저 TFLite 모델을 활용할 기본적인 UI를 구성합니다.

여기서는 MNIST 데이터로 학습된 모델을 이용하여 손글씨 분류 앱을 만들 것입니다.

UI 구성

1. activity_main.xml

메인 액티비티입니다. DRAWVIEW 버튼을 클릭하면 손글씨를 분류할 액티비티로 전환됩니다.

image-20210815205257559

2. activity_draw.xml

손글씨를 분류할 액티비티입니다. 위 쪽에 DrawView를 배치하여 사용자가 손글씨를 입력하면 그 이미지를 이용해 추론을 합니다.

2개의 버튼과 1개의 텍스트 뷰를 배치하는데, 버튼은 각각 추론과 초기화를 수행하고 텍스트 뷰는 모델의 추론 결과를 출력합니다.

image-20210815205437837

UI에 대한 소개는 이정도만 하고 넘어갑니다. 안드로이드에서 DrawView를 사용하는 방법이 궁금하신 분은 [Android] 안드로이드에서 DrawView 위젯 사용하기 에서 확인해주세요.



TFLite 모델 로드


모델을 사용할 UI 구성이 완료되었으면 이제 본격적으로 TFLite 모델을 이용한 앱 개발을 시작합니다.

먼저 tflite 파일을 로드해야 합니다.

코드를 작성하기 전에 텐서플로 라이트 라이브러리 의존성을 추가하고, tflite 파일을 프로젝트에 추가해야 합니다.

TFLite 모델을 불러올 때는 먼저 ByteBuffer 클래스로 불러오고, 여기에 옵션을 추가하여 Interpreter 객체를 생성합니다. 최종적으로 모델을 사용할 수 있는 형태는 Interpreter 객체입니다.


1. 텐서플로 라이트 라이브러리 의존성 추가

모듈의 build.gradle 파일에 다음 의존성을 추가합니다.

...
dependencies {

    implementation "org.jetbrains.kotlin:kotlin-stdlib:$kotlin_version"
    implementation 'androidx.core:core-ktx:1.6.0'
    implementation 'androidx.appcompat:appcompat:1.3.1'
    implementation 'com.google.android.material:material:1.4.0'
    implementation 'androidx.constraintlayout:constraintlayout:2.1.0'

    implementation 'com.github.divyanshub024:AndroidDraw:v0.1'
    // 텐서플로 라이트 라이브러리 의존성 추가
    // Interpreter, Tensor, DataType 등의 텐서플로 라이트 클래스 사용 가능
    implementation 'org.tensorflow:tensorflow-lite:2.4.0'

    testImplementation 'junit:junit:4.+'
    androidTestImplementation 'androidx.test.ext:junit:1.1.3'
    androidTestImplementation 'androidx.test.espresso:espresso-core:3.4.0'
}


2. tflite 파일 추가

[app] 우클릭 - [New] - [Folder] - [Assets Folder] 에서 assets 폴더를 만들고 그 안에 tflite 파일을 추가합니다.

image-20210815210643968


3. Classifier 클래스 생성

모델과 관련된 작업을 담당할 Classifier 클래스를 생성합니다. Classifier는 모델 파일을 로드하고 이미지를 입력하면 추론하여 결과 값을 해석하는 일련의 동작을 모두 수행할 클래스입니다.

모델을 불러올 때 assets 폴더를 참조하는데, 이때 앱 컨텍스트가 필요합니다. Classifier의 생성자 파라미터로 Context를 추가하고, 이를 클래스 모든 곳에서 사용할 수 있도록 전역변수로 선언합니다.

// 1. 모델과 관련된 작업을 할 클래스 생성
// 모델 파일을 로드하고 이미지를 입력하면 추론하여 결과 값을 해석
class Classifier(context: Context) {
    /*전역변수 선언*/
    var context: Context = context
}


4. tflite 파일 로드

다음으로 assets 폴더에서 tflite 파일을 읽어오는 메서드를 구현합니다. 이 메서드는 tflite 파일명을 입력받아 ByteBuffer 클래스로 모델을 반환합니다.

    // 2. assets 폴더에서 tflite 파일을 읽어오는 함수 정의
    // tflite 파일명을 입력받아 ByteBuffer 클래스로 모델을 반환
    private fun loadModelFile(modelName: String): ByteBuffer? {

        // AssetManager는 assets 폴더에 저장된 리소스에 접근하기 위한 기능을 제공
        val am = this.context.assets // AssetManager
        // AssetManager.openRd(파일명): AssetFileDescriptor를 반환
        val afd: AssetFileDescriptor? = am.openFd(modelName)// modelName 에 해당하는 파일이 없을 경우 null
        if (afd == null) {
            throw IOException() // 자신을 호출한 쪽에서 예외처리 요구
            return null
        }
        // AssetFileDescriptor.fileDescriptor: 파일의 FileDescriptor 반환 -> 해당 파일의 읽기/쓰기 가능
        // FileInputStream의 생성자에 FileDescriptor를 해당 파일의 입력 스트림 반환
        val fis = FileInputStream(afd.fileDescriptor) // FileInputStream
        // fis.read()로 읽을 수도 있지만 성능을 위해 스트림의 FileChannel 이용
        val fc = fis.channel // FileChannel

        // 파일디스크립터 오프셋과 길이
        val startOffset = afd.startOffset // long
        val declaredLength = afd.declaredLength // long

        // FileChannel.map() 메서드로 ByteBuffer 클래스를 상속한 MappedByteBuffer 인스턴스 생성
        // 파라미터: 참조모드, 오프셋, 길이
        // 최종적으로 tflite 파일을 ByteBuffer 형으로 읽어오는데 성공!
        return fc.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)

    }

메서드를 호출할 때 modelName을 파라미터로 전달해야 하는데, 이 파라미터는 Classifier 클래스 밖에 상수로 선언합니다.

const val MODEL_NAME = "keras_model.tflite" // tflite 파일명

다른 모델을 사용한다면 이 상수값만 바꿔주면 됩니다.


그리고 마지막으로 init 초기화 블록 안에서 loadModelFile( ) 메서드를 호출합니다. init 블록 안에서 모델의 초기화와 관련된 코드들을 수행할 것입니다.

모델을 ByteBuffer? 타입으로 불러온 뒤, 모델의 byte order를 설정해야 합니다. 여기서는 시스템의 byte order와 동일하게 설정하도록 ByteOrder 오브젝트 클래스의 nativeOrder( ) 메서드를 사용합니다.

class Classifier(context: Context) {
    /*전역변수 선언*/
    var context: Context = context
  
		init{
        // 2(cont). 모델 초기화
        val model: ByteBuffer? = loadModelFile(MODEL_NAME) // ByteBuffer 인스턴스
        // 시스템의 byteOrder와 동일하게 동작
        // DrawActivity에서 Classifier 인스턴스를 생성할 때 예외처리
        model?.order(ByteOrder.nativeOrder())?:throw IOException()
    }
    ...
}

MODEL_NAME에 해당하는 tflite 파일이 없다면 model 변수에는 null 값이 저장됩니다. 따라서 ‘ ?. ‘(Safety Call)과 ‘ ?: ‘(Elvis Operator) 를 사용하여 null 값에 대한 처리를 해줍니다. 만약 model이 null이라면, 생성자를 호출한 쪽에서 예외 처리를 하도록 합니다.


5. Interpreter 생성

앞에서 TFLite 모델은 ByteBuffer 타입으로 저장한 후에 최종적으로 Interpreter 형으로 사용해야 한다고 했습니다.

init 초기화 블록에 Interpreter를 생성하는 코드를 추가합니다.

    init{
        val model: ByteBuffer? = loadModelFile(MODEL_NAME) 
        model?.order(ByteOrder.nativeOrder())?:throw IOException()

        // 3. Interpreter 생성
        // Interpreter는 모델에 데이터를 입력하고 추론 결과를 전달받을 수 있는 클래스
        interpreter = Interpreter(model)
    }



DrawActivity 코드 수정


Classifier 클래스에서 모델을 불러오는 데까지 성공했습니다. 입력 데이터를 만들고 추론을 수행하기 전에, DrawActivity의 코드를 먼저 수정합니다.

바인딩을 연결하고, 위젯들을 연결 및 설정합니다. 또한 Classifier 인스턴스 생성 시 예외처리 코드도 추가합니다.

class DrawActivity : AppCompatActivity() {

    // 바인딩 연결
    val binding by lazy {ActivityDrawBinding.inflate(layoutInflater)}
    // classifier 선언
    lateinit var classifier: Classifier

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(binding.root)

        // drawView 초기화
        var drawView = binding.drawView
        drawView.setStrokeWidth(100.0f)
        drawView.setBackgroundColor(Color.BLACK)
        drawView.setColor(Color.WHITE)

        // 결과를 보여줄 텍스트 뷰
        var resultView = binding.resultView

        // 버튼 클릭 리스너 설정
        binding.classifyBtn.setOnClickListener {
            var image = drawView.getBitmap()
        }
        binding.clearBtn.setOnClickListener {
            drawView.clearCanvas()
        }

        // Classifier 인스턴스를 생성하고 예외 처리까지
        // 이제 DrawActivity가 생성될 때 Classifier도 생성되고 초기화 됨
        try {
            classifier = Classifier(this)
        }catch(ioe: IOException){
            Log.d("DigitClassifier", "failed to init Classifier", ioe)
        }

    }



정리


이번 포스팅에서는 기본 UI 구성과 tflite 모델을 불러오는 방법에 대해 알아보았습니다.

모델을 로드하는 과정은 먼저 모델 로드 메서드를 이용해 모델을 ByteBuffer? 타입으로 불러온 뒤, 그 모델을 Interpreter 클래스의 생성자 파라미터로 전달하여 최종적으로 우리가 조작할 수 있는 형태인 Interpreter 타입으로 생성합니다.

다음 포스팅에서는 입력 이미지 전처리 과정에 대해 살펴보겠습니다.

Categories:

Updated:

Leave a comment