46 lines
1.3 KiB
Python
46 lines
1.3 KiB
Python
# Class names, indexed by the argmax of the engine's first output.
# NOTE(review): order must match the class order used when the model was
# trained — confirm against the training code.
labels = [
    'can',
    'paper_cup',
    'paper_box',
    'paper_milkbox',
    'plastic',
]

# Progress message emitted before the heavy imports, presumably because
# importing tensorrt / cv2 can take a noticeable amount of time.
print("[*] Importing packages...")

# Import order is kept as-is on purpose: `common` (assumed to be the
# TensorRT samples' helper module) and tensorrt/cv2 may load CUDA libraries
# at import time, so reordering could change initialisation order.
import common
import tensorrt as trt
import os
import cv2
import pandas as pd
import numpy as np
# Shared TensorRT logger; only messages of WARNING severity or higher are
# reported.
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

print("[*] Loading model...")

# Load the serialized TensorRT engine from disk and deserialize it.
trt_path = 'model.trt'
with open(trt_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())

# deserialize_cuda_engine reports failure (e.g. an engine built for a
# different TensorRT version or GPU) by returning None rather than raising.
# Fail here with a clear message instead of deep inside allocate_buffers.
if engine is None:
    raise RuntimeError(f"failed to deserialize TensorRT engine from {trt_path!r}")

# Host/device buffers, binding pointers and a CUDA stream reused by every
# inference call in this script. `common.allocate_buffers` is assumed to be
# the TensorRT samples' helper — confirm against the local common.py.
inputs, outputs, bindings, stream = common.allocate_buffers(engine)
if __name__ == '__main__':

    def pred(f, dirpath):
        """Classify one image with the loaded TensorRT engine.

        Args:
            f: Image file name. Also used as the row label into ``df``
                (the ``name`` column of the weights CSV, set as the index).
            dirpath: Directory that contains ``f``.

        Returns:
            The entry of ``labels`` at the argmax of the first engine output.

        Raises:
            ValueError: If OpenCV cannot read the image file.
            KeyError: If ``f`` has no row in ``df``.

        Uses the script-level ``df``, ``context``, ``inputs``, ``outputs``,
        ``bindings`` and ``stream``.
        """
        path = os.path.join(dirpath, f)
        img = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising;
        # without this check it surfaces later as an opaque TypeError.
        if img is None:
            raise ValueError(f'could not read image: {path}')
        # One .loc lookup instead of chained indexing df.loc[f]['weight'].
        weight = df.loc[f, 'weight']

        # Prepend a batch dimension of 1 to both inputs.
        # NOTE(review): assumes the engine expects the raw image exactly as
        # OpenCV reads it (no resize / normalisation) — confirm with the model.
        inputs[0].host = np.expand_dims(img, 0).astype('float32')
        inputs[1].host = np.expand_dims(weight, 0).astype('float32')

        # inference
        trt_outputs = common.do_inference_v2(
            context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream
        )

        result = trt_outputs[0].argmax(-1)
        # int() turns the NumPy integer into a plain Python list index.
        return labels[int(result)]

    df = pd.read_csv('test_data/weights_test.csv')
    df = df.set_index('name')

    # Create the execution context once and reuse it for every image:
    # context creation is relatively expensive and does not need to be
    # repeated per prediction.
    with engine.create_execution_context() as context:
        for dirpath, _dirnames, filenames in os.walk('test_data'):
            for f in filenames:
                if f.endswith('.jpg'):
                    print(f'{f}: {pred(f, dirpath)}')