새소식

SW 개발

[Android + Keras] 케라스 모델을 안드로이드에서 사용하려면(Tensorflow Mobile,Lite)

 

 

케라스나 파이토치에서 만든 AI모델을 CoreML(ios), Tensorflow mobile or lite(android)를 통해 적용할 수 있다.

 

Tensorflow Mobile

  • 구글이 제공하는 머신러닝 프레임워크, .pb파일 이용
  • Feedforward, convolution, recurrent neural network 지원(지금은 full RNN, LSTM도 지원)
  • C++ 기반의 API, Java Wrapper 제공

 

 

Tensorflow Lite

  • 구글이 제공하는 머신러닝 프레임워크, .tflite파일 이용
  • 제공하는 모델이 좀 더 한정적이라고 합니다.
  • C++ 기반의 API, Java Wrapper 제공
  • Tf Mobile보다 가볍게 최적화 되어있음

 

 

어쨌든 keras 모델을 ios나 안드로이드에 삽입하여 사용 가능 하다는 것.(.h5자체를 넣으면 더 좋겠지만..)

(안드로이드 + tf lite 공개 샘플 사용해보기)

 

 

첫 번째 단계로, .h5파일이나 .ckpt파일을 .pb파일로 변환해야 한다.

그 전에 .ckpt, .pb, .h5가 뭔지 어떻게 변환하는지 살펴보자.

 

 


 

○ 모델 포맷

 

1. .ckpt : 모델의 변수(가중치)인 체크포인트 파일(텐서플로우)

2. .pb : 모델의 변수(가중치) + 모델 구조(전체 그래프)로 이루어진 바이너리 파일(텐서플로우)

3. .pbtxt : pb파일을 읽을 수 있는 텍스트 파일, 모델 구조 파악 가능(텐서플로우)

4. .h5 : Keras에서는 모델 및 가중치 모두 가지고 있는 HDF형식 파일

 

 

1. ckpt

모델에 대한 메타 정보를 가지고 있어 재학습이 가능하며, 모델 구조를 제외한 모델의 가중치만 담은 파일이다. 하지만 predict를 할 때 필요한 정보 외에도 담겨있기 때문에 파일의 크기가 무겁다.

 

.ckpt.data : 모델(graph)을 제외한 모든 변수를 포함한 파일 (.ckpt와 동일), 모델 복원시 meta파일과 data파일을 이용

.ckpt.index : meta파일과 data파일을 매핑하기 위해 내부적으로 필요한 인덱스 파일

.ckpt.meta : 모델(graph)만 있는 파일, 변수를 제외한 그래프의 구조를 담은 파일.

 

1
2
3
4
5
6
7
with tf.Session(graph=g) as sess :
 
    saver = tf.train.Saver()
 
    sess.run(tf.global_variables_initializer())
 
    ckpt_path = saver.save(sess, "saved/train1")
cs

 

ckpt 파일을 생성하고 저장하면

train1.ckpt.data-00000-of-00001, train1.ckpt.index, train1.ckpt.meta 이런 식으로 지정한 경로에 파일이 생긴다.

 


 

2. .pb

모델(graph)와 가중치(weight) 모두 저장된 파일. freeze_graph.py를 통해 만들 수 있다. 그래프를 freezing한다는 뜻은 결국 모델을 .pb파일로 저장 시킨다는 것. 재학습 불가능

 

주로 .ckpt파일이나 .h5파일을 .pb파일로 변환하여 모바일이나 c++기반 프로그램에 사용한다.

.ckpt파일을 .pb파일로 변환하는 freeze_graph.py 모듈은 텐서플로우 공식홈페이지에서 다운가능하다.

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def freeze_graph(input_graph, 
                input_saver, 
                input_binary, 
                input_checkpoint, 
                output_node_names, 
                restore_op_name, 
                filename_tensor_name, 
                output_graph, 
                clear_devices, 
                initializer_nodes, 
                variable_names_whitelist=""
                variable_names_blacklist=""
                input_meta_graph=None, 
                input_saved_model_dir=None, 
                saved_model_tags=tag_constants.SERVING, 
                checkpoint_version=saver_pb2.SaverDef.V2):
 
cs

 

중요한 부분은 output_node_names이다. 마지막 노드의 이름을 알아야 한다. ('dense_2' 이런식으로?)

 

 

1
2
3
4
5
6
7
8
9
10
11
12
import tensorflow as tf
from tensorflow.python.tools import freeze_graph
 
def main():
    freeze_graph.freeze_graph('/graph.pbtxt'"", False,
                                '/checkpoint.ckpt''output_node_name',
                                "save/restore_all""save/Const",
                                'frozen.pb', True, "")
 
if __name__ == '__main__'
    main()
 
cs

 

ckpt + pbtxt 파일을 이용하여 .pb파일 만드는 예제

 


 

3. pbtxt

.pb는 바이너리 파일이고 .pbtxt는 텍스트 파일. 따라서 .pb파일이 더 가벼우나 .pbtxt는 사람이 읽을 수 있다는 장점이 있다. 


 

4. .h5

Hierarchical Data Format(HDF)형식으로 저장되는 데이터. 케라스에서 모델 및 가중치를 모두 가지고 있으며, keras.models.load_model()을 통해 불러와 사용이 가능

 

 

 

다음 시간에는 케라스에서 모델을 저장하고 Conversion하는 부분에 대해 포스팅하겠습니다.

 

 

 

[References]

 

Getting Started with Keras and Adroid Studio(Youtube)

E2E tf.Keras to TFLite to Android

pb파일 TensorBoard에 띄우기

Hierarchical Data Format(wikipedia)

User Keras model on Android(Youtube)

Contents

포스팅 주소를 복사했습니다

이 글이 도움이 되었다면 공감 부탁드립니다.