GarbageSegregation/inference.py
2023-01-10 01:52:32 +08:00

48 lines
1.3 KiB
Python

# Class names indexed by model output position: predict() returns
# labels[argmax(prob)]. Presumably this is the class order used in training;
# confirm against the training script before reordering.
labels = ['can', 'paper_cup', 'paper_box', 'paper_milkbox', 'plastic']
print("[*] Importing packages...")
import os

# TensorFlow's native (C++) logger reads TF_CPP_MIN_LOG_LEVEL from the
# environment, and it can already emit messages while `import tensorflow`
# runs. The variable must therefore be set BEFORE that import to take full
# effect. '3' hides INFO, WARNING and ERROR messages from the native side.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf
from tensorflow import keras
import cv2
import pandas as pd
import numpy as np
print("[*] Loading model...")
# Python-side TensorFlow logger: only show ERROR and above.
tf.get_logger().setLevel('ERROR')
# Loaded once at import time. Used by predict() and by the __main__ loop.
model = keras.models.load_model('model_without_preprocess_finetuned.h5')
def predict(img, weight, *, verbose=True):
    """Classify one item from its image and its weight.

    The model is called with a two-input list ``[image, weight]``. Each input
    gets a leading batch axis of size 1, so exactly one sample is scored.

    Args:
        img: Image array as returned by ``cv2.imread``. ``None`` (what
            ``cv2.imread`` returns when it cannot read a file) is rejected.
        weight: Scalar weight fed as the model's second input (units not
            visible here; presumably the value stored in weights_test.csv).
        verbose: When true (the default, matching the original behaviour),
            print the raw class-probability array before returning.

    Returns:
        The entry of the module-level ``labels`` list with the highest
        predicted probability.

    Raises:
        ValueError: If ``img`` is ``None``.
    """
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        # Without this check np.expand_dims would build an object array and
        # the model would fail with an unrelated-looking error.
        raise ValueError("img is None; the image file could not be read")
    prob = model.predict([np.expand_dims(img, 0), np.expand_dims(weight, 0)])
    if verbose:
        print(prob)
    # argmax over the class axis, then take the single sample in the batch.
    return labels[prob.argmax(-1)[0]]
# Warm-up: run one prediction on a bundled sample image at import time and
# print the result. This is presumably done so that one-off model start-up
# cost is paid here rather than on the first real prediction.
print("[*] Warming up model...")
_WARMUP_IMAGE = os.path.join('test_data', 'can_dew_0_preprocessed.jpg')
img = cv2.imread(_WARMUP_IMAGE)
if img is None:
    # cv2.imread returns None rather than raising. Fail here with the
    # offending path instead of with an obscure error from inside the model.
    raise FileNotFoundError(f'warm-up image could not be read: {_WARMUP_IMAGE}')
print(predict(img, 52.69))
print("[*] Done!")
if __name__ == '__main__':
    # Classify every .jpg under test_data/ (recursively) and print
    # "<filename>: <label>" for each one. Weights come from weights_test.csv,
    # indexed by bare file name (column 'name'), with the value in 'weight'.
    df = pd.read_csv(os.path.join('test_data', 'weights_test.csv'))
    df = df.set_index('name')

    def pred(f, dirpath):
        """Return the predicted label for image file ``f`` inside ``dirpath``.

        The weight is looked up in ``df`` by the bare file name ``f``. This
        mirrors ``predict`` but does not print the probability array.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image.
            KeyError: If ``f`` has no row in the weights table.
        """
        path = os.path.join(dirpath, f)
        img = cv2.imread(path)
        if img is None:
            # cv2.imread returns None instead of raising on failure.
            raise FileNotFoundError(f'could not read image: {path}')
        # Single .loc lookup instead of chained df.loc[f]['weight'].
        weight = df.loc[f, 'weight']
        prob = model.predict([np.expand_dims(img, 0), np.expand_dims(weight, 0)])
        return labels[prob.argmax(-1)[0]]

    for dirpath, dirnames, filenames in os.walk('test_data'):
        # Sorting in place makes os.walk descend in a deterministic order.
        dirnames.sort()
        for f in sorted(filenames):
            if not f.endswith('.jpg'):
                continue
            try:
                label = pred(f, dirpath)
            except (FileNotFoundError, KeyError) as err:
                # Report and keep going so one unreadable or unlisted file
                # does not abort the whole evaluation run.
                print(f'{f}: skipped ({err!r})')
                continue
            print(f'{f}: {label}')