"""Convert a fine-tuned Keras model to ONNX, then to a serialized TensorRT engine.

Each conversion stage is run in its own child process by ``main``; the
TensorRT stage only runs if the ONNX stage exited successfully.
"""
import multiprocessing

keras_path = '../model_without_preprocess_finetuned.h5'
onnx_path = 'model.onnx'
trt_path = '../model.trt'


def keras2onnx():
    """Load the Keras model at ``keras_path`` and save it as ONNX to ``onnx_path``.

    The leading dimension of the first two graph inputs and of the first graph
    output is pinned to 1, so the exported model processes one sample at a time.

    Raises:
        onnx.checker.ValidationError: if the exported graph fails validation.
    """
    import os

    # TensorFlow reads this variable when its native library is loaded, so it
    # must be set *before* tensorflow (or onnxmltools, which may pull it in) is
    # imported; setting it afterwards does not silence the C++ logging.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    import onnx
    import onnxmltools
    import tensorflow as tf
    from tensorflow import keras

    print('[*] Converting Keras Model to onnx')
    tf.get_logger().setLevel('ERROR')

    keras_model = keras.models.load_model(keras_path)
    onnx_model = onnxmltools.convert_keras(keras_model)

    # One data at a time: fix the leading (batch) dimension to 1.
    # NOTE(review): assumes the model has at least two inputs and one output;
    # any additional inputs/outputs are left untouched.
    graph = onnx_model.graph
    for value_info in (graph.input[0], graph.input[1], graph.output[0]):
        value_info.type.tensor_type.shape.dim[0].dim_value = 1

    onnx.checker.check_model(onnx_model)
    onnx.save(onnx_model, onnx_path)
    print(f'[*] onnx file saved as {onnx_path}')


def onnx2rt():
    """Build a TensorRT engine from ``onnx_path`` and write it to ``trt_path``.

    Raises:
        RuntimeError: if the ONNX file cannot be parsed or the engine build
            fails. In either case the engine file is not opened, so an
            existing engine at ``trt_path`` is left intact.
    """
    import tensorrt as trt

    trt_logger = trt.Logger(trt.Logger.ERROR)
    explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

    print('[*] Converting onnx to tensorrt')
    with trt.Builder(trt_logger) as builder, \
            builder.create_network(explicit_batch) as network, \
            trt.OnnxParser(network, trt_logger) as parser:
        with open(onnx_path, 'rb') as f:
            parsed = parser.parse(f.read())
        if not parsed:
            print('ERROR: Failed to parse the ONNX file.')
            for i in range(parser.num_errors):
                print(parser.get_error(i))
            # Building from a partially parsed network would yield a broken
            # (or failing) engine build, so stop here.
            raise RuntimeError(f'failed to parse {onnx_path}')

        config = builder.create_builder_config()
        # NOTE(review): this optimization profile has no shapes set; with the
        # static batch of 1 fixed during ONNX export it is presumably a no-op.
        profile = builder.create_optimization_profile()
        config.add_optimization_profile(profile)
        # For verbose build diagnostics:
        # config.flags = 1 << int(trt.BuilderFlag.DEBUG)

        serialized = builder.build_serialized_network(network, config)
        # build_serialized_network returns None on failure; check before the
        # output file is opened so an existing engine is not truncated.
        if serialized is None:
            raise RuntimeError('TensorRT engine build failed')

        with open(trt_path, 'wb') as f:
            f.write(serialized)
    print(f'[*] tensorrt file saved as {trt_path}')


def run_stage(target):
    """Run ``target`` in a fresh child process, wait for it, and return its exit code."""
    proc = multiprocessing.Process(target=target)
    proc.start()
    proc.join()
    return proc.exitcode


def main():
    """Run Keras->ONNX, then ONNX->TensorRT, each in its own process.

    Exits with status 1 as soon as a stage's process exits non-zero, so the
    TensorRT stage never runs against a missing or stale ONNX file.
    """
    # NOTE(review): separate processes presumably keep TensorFlow's and
    # TensorRT's GPU/memory state out of one interpreter — confirm.
    for stage in (keras2onnx, onnx2rt):
        exitcode = run_stage(stage)
        if exitcode != 0:
            raise SystemExit(f'[!] {stage.__name__} failed with exit code {exitcode}')


if __name__ == "__main__":
    main()