SW 개발

[Android + Keras] .h5파일을 .pb와 .tflite 파일로 변환하기(Tensorflow Lite)

minkyung 2020. 5. 6. 16:13

 

텐서플로우 라이트, 텐서플로2.0도 Support

 

TensorFlow 2.0을 사용하면 ML 응용 프로그램을 훨씬 쉽게 개발할 수 있습니다. Keras를 TensorFlow에 긴밀하게 통합하고 기본적으로 열악한 실행 및 Pythonic 함수 실행을 통해 TensorFlow 2.0은 Python 개발자에게 친숙한 응용 프로그램 개발 경험을 제공합니다. ML의 경계를 넓히는 연구원을 위해 우리는 TensorFlow의 저수준 API에 많은 투자를했습니다. 이제 내부적으로 사용되는 모든 op를 내보내고 변수 및 검사 점과 같은 중요한 개념에 대한 상속 가능한 인터페이스를 제공합니다. 이를 통해 TensorFlow를 다시 빌드하지 않고도 TensorFlow의 내부를 구축 할 수 있습니다.

 

 

 

텐서플로 라이트 공식 홈페이지에 들어가면 머신러닝 모델을 모바일 기기 및 IoT기기에 배포하는 가이드를 제공하고 있습니다. 또, 이미지 분류, 객체 감지, 스마트 답장 등 머신러닝 모바일앱 예제를 깃허브를 통해 제공하고 있습니다.




 

 

텐서플로우 모델을 안드로이드에 탑재해보자

 

공식 문서

 

1. 훈련을 마친 .h5 확장자 모델을 .pb파일로 변환

 

1
2
3
4
5
from tensorflow import keras
model = keras.models.load_model(trained_model.h5, compile=False)
 
export_path = 'saved_model.pb가 저장될 디렉토리'
model.save(export_path, save_format="tf")
cs

 

 

export_path는 비어있어야 하며(중요), 실행 후 다음과 같은 파일과 디렉토리가 생성된다.

 

 

 

 

2. .pb파일을 다시 .tflite 파일로 변환

 

1
2
3
4
5
6
7
8
import tensorflow as tf
 
saved_model_dir = 'saved_model.pb가 들어있는 경로'
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
                                       tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()
open('저장할 경로/converted_model.tflite''wb').write(tflite_model)
cs

 

실행 후 지정한 경로 밑에 converted_model.tflite 파일이 생성됩니다.

 

 

3. build.gradle(Module: app)에 ndk 및 dependencies 추가

 

 

 

 

4. assets 디렉토리 밑에 .tflite파일 넣기

 

assets 디렉토리 없을 경우, File -> New -> Folder -> Assets Folder

 

 

 

 

 

 

5. .tflite 로드 및 초기화

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
/* ModelClient.java */
 
Interpreter tflite = getTfliteInterpreter("converted_model.tflite");
 
private Interpreter getTfliteInterpreter(String modelPath) {
    try {
        return new Interpreter(loadModelFile(MainActivity.this, modelPath));
    }
    catch (Exception e) {
        e.printStackTrace();
    }
    return null;
}
 
 
/** Load TF Lite model from assets. */
private MappedByteBuffer loadModelFile(Activity activity, String modelPath) throws IOException {
    AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(modelPath);
    FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
    FileChannel fileChannel = inputStream.getChannel();
    long startOffset = fileDescriptor.getStartOffset();
    long declaredLength = fileDescriptor.getDeclaredLength();
    return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
 
cs

 

 

중요한 것은 tflite를 지원하지 않는 레이어들이 있다는 것이다.

원래는 CNN+LSTM 레이어의 조합으로 모델을 생성했다가 CNN+Flatten 레이어로 변경하였다.

 

 

 

 

References

 

텐서플로우 라이트 TextClassification 공식 예제