71 lines
2.1 KiB
Python
71 lines
2.1 KiB
Python
import multiprocessing
|
|
|
|
# Source Keras model (HDF5) that keras2onnx() loads.
keras_path = '../model_without_preprocess_finetuned.h5'
# Intermediate ONNX file: written by keras2onnx(), read back by onnx2rt().
onnx_path = 'model.onnx'
|
|
|
|
# convert to onnx
|
|
def keras2onnx():
    """Convert the Keras model at ``keras_path`` to ONNX and save it to ``onnx_path``.

    The leading (batch) dimension of the first two graph inputs and of the
    first graph output is fixed to 1, the converted model is validated with
    ``onnx.checker.check_model`` and then written to disk.

    Heavy third-party packages are imported inside the function so that they
    are only loaded in the process that actually runs the conversion.

    NOTE(review): indexing ``graph.input[1]`` assumes the model has at least
    two inputs and one output -- confirm against the model being converted.
    """
    import os

    # Set before importing TensorFlow: the native runtime picks this variable
    # up when it is loaded, so assigning it afterwards does not silence the
    # C++ log output.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    import onnx
    import onnxmltools
    import tensorflow as tf
    from tensorflow import keras

    print('[*] Converting Keras Model to onnx')

    # Silence the Python-side TensorFlow logger as well.
    tf.get_logger().setLevel('ERROR')

    keras_model = keras.models.load_model(keras_path)
    onnx_model = onnxmltools.convert_keras(keras_model)

    # one data at a time: pin the batch dimension of exactly these tensors to 1
    graph = onnx_model.graph
    for tensor in (graph.input[0], graph.input[1], graph.output[0]):
        tensor.type.tensor_type.shape.dim[0].dim_value = 1

    # Validate before saving so a malformed graph never reaches disk.
    onnx.checker.check_model(onnx_model)
    onnx.save(onnx_model, onnx_path)

    print(f'[*] onnx file saved as {onnx_path}')
|
|
|
|
def onnx2rt():
    """Build a serialized TensorRT engine from the ONNX file at ``onnx_path``.

    The ONNX file is parsed into an explicit-batch TensorRT network, a builder
    config with one (empty) optimization profile is created, and the
    serialized engine is written to ``../model.trt``.

    Raises:
        RuntimeError: if the ONNX file cannot be parsed, or if TensorRT fails
            to build the serialized network.
    """
    # convert to tensorrt
    import tensorrt as trt

    TRT_LOGGER = trt.Logger(trt.Logger.ERROR)
    EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

    trt_file = '../model.trt'

    print('[*] Converting onnx to tensorrt')

    with trt.Builder(TRT_LOGGER) as builder, \
            builder.create_network(EXPLICIT_BATCH) as network, \
            trt.OnnxParser(network, TRT_LOGGER) as parser:
        with open(onnx_path, 'rb') as f:
            parsed = parser.parse(f.read())

        if not parsed:
            print('ERROR: Failed to parse the ONNX file.')
            for index in range(parser.num_errors):
                print(parser.get_error(index))
            # Stop here: building from a network that failed to parse would
            # at best fail later and at worst save a broken engine.
            raise RuntimeError(f'Failed to parse the ONNX file {onnx_path!r}')

        config = builder.create_builder_config()
        profile = builder.create_optimization_profile()

        config.add_optimization_profile(profile)
        # config.flags = 1 << (int)(trt.BuilderFlag.DEBUG)

        # Build before opening the output file so that a failed build does not
        # leave an empty/truncated engine file behind.
        serialized_engine = builder.build_serialized_network(network, config)
        if serialized_engine is None:
            raise RuntimeError('Failed to build the serialized TensorRT network')

        with open(trt_file, "wb") as f:
            f.write(serialized_engine)

    print(f'[*] tensorrt file saved as {trt_file}')
|
|
|
|
if __name__ == "__main__":
    # Run the two conversion stages one after the other, each in its own
    # child process. A stage is only started once the previous one has
    # finished successfully; otherwise onnx2rt would consume a missing or
    # stale ONNX file and the script would still exit with status 0.
    for stage in (keras2onnx, onnx2rt):
        worker = multiprocessing.Process(target=stage)
        worker.start()
        worker.join()
        if worker.exitcode != 0:
            raise SystemExit(
                f'[!] {stage.__name__} failed with exit code {worker.exitcode}'
            )
|