케라스나 파이토치에서 만든 AI모델을 CoreML(ios), Tensorflow mobile or lite(android)를 통해 적용할 수 있다.
Tensorflow Mobile
- 구글이 제공하는 머신러닝 프레임워크, .pb파일 이용
- Feedforward, convolution, recurrent neural network 지원(지금은 full RNN, LSTM도 지원)
- C++ 기반의 API, Java Wrapper 제공
Tensorflow Lite
- 구글이 제공하는 머신러닝 프레임워크, .tflite파일 이용
- 제공하는 모델이 좀 더 한정적이라고 합니다.
- C++ 기반의 API, Java Wrapper 제공
- Tf Mobile보다 가볍게 최적화 되어있음
어쨌든 keras 모델을 ios나 안드로이드에 삽입하여 사용 가능 하다는 것.(.h5자체를 넣으면 더 좋겠지만..)
(안드로이드 + tf lite 공개 샘플 사용해보기)
첫 번째 단계로, .h5파일이나 .ckpt파일을 .pb파일로 변환해야 한다.
그 전에 .ckpt, .pb, .h5가 뭔지 어떻게 변환하는지 살펴보자.
○ 모델 포맷
1. .ckpt : 모델의 변수(가중치)인 체크포인트 파일(텐서플로우)
2. .pb : 모델의 변수(가중치) + 모델 구조(전체 그래프)로 이루어진 바이너리 파일(텐서플로우)
3. .pbtxt : pb파일을 읽을 수 있는 텍스트 파일, 모델 구조 파악 가능(텐서플로우)
4. .h5 : Keras에서는 모델 및 가중치 모두 가지고 있는 HDF형식 파일
1. ckpt
모델에 대한 메타 정보를 가지고 있어 재학습이 가능하며, 모델 구조를 제외한 모델의 가중치만 담은 파일이다. 하지만 predict를 할 때 필요한 정보 외에도 담겨있기 때문에 파일의 크기가 무겁다.
.ckpt.data : 모델(graph)을 제외한 모든 변수를 포함한 파일 (.ckpt와 동일), 모델 복원시 meta파일과 data파일을 이용
.ckpt.index : meta파일과 data파일을 매핑하기 위해 내부적으로 필요한 인덱스 파일
.ckpt.meta : 모델(graph)만 있는 파일, 변수를 제외한 그래프의 구조를 담은 파일.
|
with tf.Session(graph=g) as sess :
saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())
ckpt_path = saver.save(sess, "saved/train1")
|
cs |
ckpt 파일을 생성하고 저장하면
train1.ckpt.data-00000-of-00001, train1.ckpt.index, train1.ckpt.meta 이런 식으로 지정한 경로에 파일이 생긴다.
2. .pb
모델(graph)와 가중치(weight) 모두 저장된 파일. freeze_graph.py를 통해 만들 수 있다. 그래프를 freezing한다는 뜻은 결국 모델을 .pb파일로 저장 시킨다는 것. 재학습 불가능
주로 .ckpt파일이나 .h5파일을 .pb파일로 변환하여 모바일이나 c++기반 프로그램에 사용한다.
.ckpt파일을 .pb파일로 변환하는 freeze_graph.py 모듈은 텐서플로우 공식홈페이지에서 다운가능하다.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
|
def freeze_graph(input_graph,
input_saver,
input_binary,
input_checkpoint,
output_node_names,
restore_op_name,
filename_tensor_name,
output_graph,
clear_devices,
initializer_nodes,
variable_names_whitelist="",
variable_names_blacklist="",
input_meta_graph=None,
input_saved_model_dir=None,
saved_model_tags=tag_constants.SERVING,
checkpoint_version=saver_pb2.SaverDef.V2):
|
cs |
중요한 부분은 output_node_names이다. 마지막 노드의 이름을 알아야 한다. ('dense_2' 이런식으로?)
1
2
3
4
5
6
7
8
9
10
11
12
|
import tensorflow as tf
from tensorflow.python.tools import freeze_graph
def main():
freeze_graph.freeze_graph('/graph.pbtxt', "", False,
'/checkpoint.ckpt', 'output_node_name',
"save/restore_all", "save/Const",
'frozen.pb', True, "")
if __name__ == '__main__':
main()
|
cs |
ckpt + pbtxt 파일을 이용하여 .pb파일 만드는 예제
3. pbtxt
.pb는 바이너리 파일이고 .pbtxt는 텍스트 파일. 따라서 .pb파일이 더 가벼우나 .pbtxt는 사람이 읽을 수 있다는 장점이 있다.
4. .h5
Hierarchical Data Format(HDF)형식으로 저장되는 데이터. 케라스에서 모델 및 가중치를 모두 가지고 있으며, keras.models.load_model()을 통해 불러와 사용이 가능
다음 시간에는 케라스에서 모델을 저장하고 Conversion하는 부분에 대해 포스팅하겠습니다.
[References]
Getting Started with Keras and Adroid Studio(Youtube)
E2E tf.Keras to TFLite to Android
pb파일 TensorBoard에 띄우기
Hierarchical Data Format(wikipedia)
User Keras model on Android(Youtube)