feat: TensorRT model & inference
This commit is contained in:
70
TensorRT_Model/convert_model.py
Normal file
70
TensorRT_Model/convert_model.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import multiprocessing
|
||||
|
||||
# Path of the Keras .h5 model file that keras2onnx() loads.
keras_path = '../model_without_preprocess_finetuned.h5'
|
||||
# Path of the intermediate ONNX file: written by keras2onnx(), read by onnx2rt().
onnx_path = 'model.onnx'
|
||||
|
||||
# convert to onnx
|
||||
def keras2onnx():
    """Convert the Keras model at ``keras_path`` to ONNX and save it to ``onnx_path``.

    The batch dimension of the first two graph inputs and of the first graph
    output is set to 1, so the exported model handles one sample at a time.
    The resulting model is validated with ``onnx.checker`` before it is saved.

    Dependencies are imported inside the function, so they are loaded only
    when it is called.
    """
    import os

    # TensorFlow's native logging reads this variable when the library is
    # loaded, so it must be set before the first `import tensorflow`.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    import onnx
    import onnxmltools
    import tensorflow as tf
    from tensorflow import keras

    print('[*] Converting Keras Model to onnx')

    tf.get_logger().setLevel('ERROR')

    keras_model = keras.models.load_model(keras_path)
    onnx_model = onnxmltools.convert_keras(keras_model)

    # One sample at a time: pin the leading (batch) dimension to 1.
    # NOTE(review): this assumes the model has at least two inputs.
    graph = onnx_model.graph
    for tensor in (graph.input[0], graph.input[1], graph.output[0]):
        tensor.type.tensor_type.shape.dim[0].dim_value = 1

    onnx.checker.check_model(onnx_model)
    onnx.save(onnx_model, onnx_path)

    print(f'[*] onnx file saved as {onnx_path}')
|
||||
|
||||
def onnx2rt():
    """Build a serialized TensorRT engine from ``onnx_path`` and write it to ``../model.trt``.

    The ONNX file is parsed into an explicit-batch network, and the engine is
    built with a default builder config plus one (empty) optimization profile.

    Raises:
        RuntimeError: If the ONNX file cannot be parsed or the engine build
            produces no result.
    """
    import tensorrt as trt

    trt_logger = trt.Logger(trt.Logger.ERROR)
    explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    trt_file = '../model.trt'

    print('[*] Converting onnx to tensorrt')

    with trt.Builder(trt_logger) as builder, \
            builder.create_network(explicit_batch) as network, \
            trt.OnnxParser(network, trt_logger) as parser:
        with open(onnx_path, 'rb') as f:
            parsed = parser.parse(f.read())

        if not parsed:
            print('ERROR: Failed to parse the ONNX file.')
            for index in range(parser.num_errors):
                print(parser.get_error(index))
            # Do not try to build an engine from a network that failed to parse.
            raise RuntimeError(f'Failed to parse the ONNX file {onnx_path!r}')

        config = builder.create_builder_config()
        profile = builder.create_optimization_profile()
        config.add_optimization_profile(profile)

        # Build before opening the output file, so a failed build does not
        # leave an empty or truncated engine file behind.
        serialized_engine = builder.build_serialized_network(network, config)
        if serialized_engine is None:
            raise RuntimeError('Failed to build the TensorRT engine.')

        with open(trt_file, 'wb') as f:
            f.write(serialized_engine)

    print(f'[*] tensorrt file saved as {trt_file}')
|
||||
|
||||
if __name__ == "__main__":
    # Each conversion stage runs in its own child process, one after the other.
    onnx_proc = multiprocessing.Process(target=keras2onnx)
    onnx_proc.start()
    onnx_proc.join()
    # Do not start the TensorRT stage when the ONNX stage failed: it would
    # read a missing or stale ONNX file.
    if onnx_proc.exitcode != 0:
        raise SystemExit(
            f'[!] Keras to onnx conversion failed (exit code {onnx_proc.exitcode})'
        )

    rt_proc = multiprocessing.Process(target=onnx2rt)
    rt_proc.start()
    rt_proc.join()
    if rt_proc.exitcode != 0:
        raise SystemExit(
            f'[!] onnx to tensorrt conversion failed (exit code {rt_proc.exitcode})'
        )
|
||||
Reference in New Issue
Block a user