# 2023-01-18 16:01:11 +08:00
#
# 71 lines
# 2.1 KiB
# Python

import multiprocessing
# Source Keras model loaded by keras2onnx(). Relative paths are resolved
# against the current working directory, not this script's location.
keras_path: str = '../model_without_preprocess_finetuned.h5'
# Intermediate ONNX file: written by keras2onnx(), read back by onnx2rt().
onnx_path: str = 'model.onnx'
# convert to onnx
def keras2onnx():
    """Convert the Keras model at ``keras_path`` to ONNX and save it to ``onnx_path``.

    The leading (batch) dimension of every real graph input and of every
    graph output is pinned to 1 so the exported model processes one sample
    at a time. The result is validated with ``onnx.checker`` before saving.

    Reads the module-level ``keras_path`` and ``onnx_path`` globals.
    """
    import os

    # TensorFlow's native logger reads TF_CPP_MIN_LOG_LEVEL when it first
    # logs, which typically happens during `import tensorflow`; the variable
    # must therefore be set *before* TensorFlow is imported to take effect.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    import onnx
    import onnxmltools
    import tensorflow as tf
    from tensorflow import keras

    print('[*] Converting Keras Model to onnx')
    tf.get_logger().setLevel('ERROR')
    keras_model = keras.models.load_model(keras_path)
    onnx_model = onnxmltools.convert_keras(keras_model)

    # one data at a time: fix dim 0 of every graph input/output to 1 rather
    # than hard-coding input[0], input[1], output[0] (which breaks on models
    # with a different number of inputs/outputs).
    # ONNX IR < 4 also lists weight initializers in graph.input; skip those
    # so only the model's real data inputs are reshaped.
    graph = onnx_model.graph
    initializer_names = {init.name for init in graph.initializer}
    data_inputs = [vi for vi in graph.input if vi.name not in initializer_names]
    for value_info in (*data_inputs, *graph.output):
        dims = value_info.type.tensor_type.shape.dim
        if dims:  # a rank-0 tensor has no batch dimension to pin
            dims[0].dim_value = 1

    onnx.checker.check_model(onnx_model)
    onnx.save(onnx_model, onnx_path)
    print(f'[*] onnx file saved as {onnx_path}')
def onnx2rt():
    """Build a serialized TensorRT engine from ``onnx_path`` and write it to ``../model.trt``.

    Reads the module-level ``onnx_path`` global.

    Raises:
        RuntimeError: if the ONNX file cannot be parsed or TensorRT fails to
            build the engine. The output file is not created or truncated in
            either case.
    """
    # convert to tensorrt
    import tensorrt as trt

    TRT_LOGGER = trt.Logger(trt.Logger.ERROR)
    # The ONNX parser requires an explicit-batch network definition.
    EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    trt_file = '../model.trt'
    print('[*] Converting onnx to tensorrt')
    with trt.Builder(TRT_LOGGER) as builder, \
            builder.create_network(EXPLICIT_BATCH) as network, \
            trt.OnnxParser(network, TRT_LOGGER) as parser:
        with open(onnx_path, 'rb') as f:
            parsed = parser.parse(f.read())
        if not parsed:
            # Report every parser error, then stop: building from a partially
            # parsed network would fail later with a less useful error.
            errors = [str(parser.get_error(i)) for i in range(parser.num_errors)]
            print('ERROR: Failed to parse the ONNX file.')
            for error in errors:
                print(error)
            raise RuntimeError(
                f'Failed to parse ONNX file {onnx_path!r}: ' + '; '.join(errors))

        config = builder.create_builder_config()
        # NOTE(review): this profile sets no input shapes; presumably fine
        # because keras2onnx pins the batch dim to 1 (static shapes) — confirm.
        profile = builder.create_optimization_profile()
        config.add_optimization_profile(profile)
        # config.flags = 1 << (int)(trt.BuilderFlag.DEBUG)

        serialized = builder.build_serialized_network(network, config)
        # build_serialized_network returns None on failure; check before
        # opening trt_file so a failed build doesn't leave an empty engine.
        if serialized is None:
            raise RuntimeError('TensorRT failed to build the engine')
        with open(trt_file, "wb") as f:
            f.write(serialized)
    print(f'[*] tensorrt file saved as {trt_file}')
if __name__ == "__main__":
    # Each stage runs in its own child process. Presumably this keeps
    # TensorFlow and TensorRT out of the same process / GPU context — confirm.
    onnx_proc = multiprocessing.Process(target=keras2onnx)
    onnx_proc.start()
    onnx_proc.join()
    # Don't build a TensorRT engine from a missing or stale model.onnx.
    if onnx_proc.exitcode != 0:
        raise SystemExit(
            f'[!] Keras -> onnx conversion failed (exit code {onnx_proc.exitcode})')

    rt_proc = multiprocessing.Process(target=onnx2rt)
    rt_proc.start()
    rt_proc.join()
    if rt_proc.exitcode != 0:
        raise SystemExit(
            f'[!] onnx -> tensorrt conversion failed (exit code {rt_proc.exitcode})')