GarbageSegregation/inference_rt.py
2023-01-18 16:01:11 +08:00

46 lines
1.3 KiB
Python

# Class names; pred() picks one of these using the argmax of the model output.
labels = [
    'can',
    'paper_cup',
    'paper_box',
    'paper_milkbox',
    'plastic',
]
# Progress message emitted before the package imports that follow.
print("[*] Importing packages...")
import common
import tensorrt as trt
import os
import cv2
import pandas as pd
import numpy as np
# Logger handed to the TensorRT runtime, constructed at WARNING severity.
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
print("[*] Loading model...")
# Load the serialized TensorRT engine from disk.
trt_path = 'model.trt'
with open(trt_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())
# deserialize_cuda_engine reports failure by returning None rather than
# raising (per the TensorRT Python API reference), so stop here with a clear
# message instead of failing later inside buffer allocation.
if engine is None:
    raise RuntimeError(f"Failed to deserialize TensorRT engine from {trt_path!r}")
# Buffers, bindings and stream from the project's `common` helper; pred()
# fills inputs[i].host and passes all four to common.do_inference_v2.
inputs, outputs, bindings, stream = common.allocate_buffers(engine)
if __name__ == '__main__':
    def pred(f, dirpath):
        """Classify one image and return its label.

        Args:
            f: Image file name; also the key used to look up the sample's
                ``weight`` in the ``name``-indexed table ``df``.
            dirpath: Directory that contains ``f``.

        Returns:
            The entry of ``labels`` selected by the argmax of the first
            model output.

        Raises:
            ValueError: If the image cannot be read by ``cv2.imread``.
            KeyError: If ``f`` has no row in ``df``.
        """
        img_path = os.path.join(dirpath, f)
        img = cv2.imread(img_path)
        # cv2.imread signals failure by returning None rather than raising;
        # fail loudly instead of feeding a non-image into the engine.
        if img is None:
            raise ValueError(f"Could not read image: {img_path}")
        # Single .loc lookup instead of chained indexing.
        weight = df.loc[f, 'weight']
        # Both inputs get a leading axis of size 1 and are cast to float32.
        inputs[0].host = np.expand_dims(img, 0).astype('float32')
        inputs[1].host = np.expand_dims(weight, 0).astype('float32')
        # inference (uses the execution context created once below)
        trt_outputs = common.do_inference_v2(
            context,
            bindings=bindings,
            inputs=inputs,
            outputs=outputs,
            stream=stream,
        )
        result = trt_outputs[0].argmax(-1)
        return labels[result]

    # Per-file weights, indexed by image file name.
    df = pd.read_csv('test_data/weights_test.csv').set_index('name')

    # Create the execution context once and share it across all predictions
    # rather than building a new one for every image.
    with engine.create_execution_context() as context:
        for dirpath, dirnames, filenames in os.walk('test_data'):
            for f in filenames:
                if f.endswith('.jpg'):
                    print(f'{f}: {pred(f, dirpath)}')