This commit is contained in:
Jetson 2023-01-10 01:52:32 +08:00
commit 77ea4bb6a0
17 changed files with 557 additions and 0 deletions

1
.gitignore vendored Normal file
View File

@ -0,0 +1 @@
**__pycache__

195
extract.py Normal file
View File

@ -0,0 +1,195 @@
import numpy as np
import cv2
# Row of the 390x520 working frame that limits the contour search: the
# extractors only look for the object in rows above BELT_Y_THRESHOLD - 10.
# NOTE(review): presumably the row where the conveyor belt begins - confirm.
BELT_Y_THRESHOLD = 390
# Side length, in pixels, of the square image returned by the extractors.
OUTPUT_SIZE = 320
def garbage_extract_no_preprocess(img, showCompare = False):
    """
    Locate the object in a photo and return a square crop of the raw image.

    img: BGR image. A landscape image is rotated to portrait, then the frame
         is resized to 390x520 before any processing.
    showCompare: when True, return the whole 390x520 frame after chroma
         keying and grabCut, with the detected contours and the object
         rectangle drawn on it, instead of the crop.

    Returns (showCompare False) the unmasked crop of the resized frame,
    scaled to fit and padded with black borders to 320x320 (OUTPUT_SIZE).
    """
    # Flip if needed
    if img.shape[0] < img.shape[1]:
        img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
    # Resize
    img = cv2.resize(img, (390, 520), interpolation=cv2.INTER_AREA)
    frame_h, frame_w = img.shape[:2]
    output = img.copy()

    # Chroma keying: blank out every pixel inside the blue HSV range
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    maskKey = cv2.inRange(hsv, (90,100,130),(110,255,255))
    maskKey = cv2.bitwise_not(maskKey)
    output = cv2.bitwise_and(output, output, mask = maskKey)

    # Find the object's borders in the rows above the belt threshold
    gray = cv2.cvtColor(output, cv2.COLOR_BGR2GRAY)
    _, bw = cv2.threshold(gray, 50, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(bw[:BELT_Y_THRESHOLD - 10, :], cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # Filter small dots
    fil = [c for c in contours if cv2.contourArea(c) > 300]
    # Fall back to the whole frame when nothing large enough was found
    x, y, w, h = cv2.boundingRect(np.vstack(fil)) if fil else (0, 0, 390, 520)

    # Object border as [left, top, right, bottom]; the bottom is fixed at row 500
    rect = [x, y, x + w, 500]

    if showCompare:
        # The masking and grabCut below only influence the annotated debug
        # view; the normal result is cut from the untouched frame, so this
        # work is skipped entirely on the normal path.
        cv2.drawContours(output, fil, -1, 255, 3)
        # Delete left & right
        output[:, :x, :] = 0
        output[:, x + w:, :] = 0
        # grabCut
        # NOTE(review): cv2.grabCut documents rect as (x, y, width, height),
        # but corner coordinates are passed here; kept unchanged so the debug
        # view still matches what was used so far - confirm intent.
        mask = np.zeros(img.shape[:2],np.uint8)
        bgdModel = np.zeros((1,65),np.float64)
        fgdModel = np.zeros((1,65),np.float64)
        cv2.grabCut(img, mask, tuple(rect), bgdModel, fgdModel, 5, cv2.GC_INIT_WITH_RECT)
        # Mask values 0 and 2 are (probable) background
        mask2 = np.where((mask==2)|(mask==0),0,1).astype('uint8')
        output = output*mask2[:,:,np.newaxis]
        return cv2.rectangle(output, tuple(rect[:2]), tuple(rect[2:]), (0,255,0), 3)

    # Grow the shorter side symmetrically so the crop is roughly square.
    # x is clamped to the frame width and y to the frame height (the two
    # limits used to be swapped, which cut wide objects off at row 389).
    left, top, right, bottom = rect
    crop_h = bottom - top
    crop_w = right - left
    if crop_h > crop_w:
        grow = (crop_h - crop_w) // 2
        left = max(left - grow, 0)
        right = min(right + grow, frame_w)
    else:
        grow = (crop_w - crop_h) // 2
        top = max(top - grow, 0)
        bottom = min(bottom + grow, frame_h)
    output = img[top:bottom, left:right, :]

    # Resize so the longer side equals OUTPUT_SIZE
    h, w, c = output.shape
    scale = OUTPUT_SIZE/w if w > h else OUTPUT_SIZE/h
    output = cv2.resize(output, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_AREA)
    # Pad the shorter side with black so the result is exactly square
    delta_w = OUTPUT_SIZE - output.shape[1]
    delta_h = OUTPUT_SIZE - output.shape[0]
    pad_top, pad_bottom = delta_h//2, delta_h-(delta_h//2)
    pad_left, pad_right = delta_w//2, delta_w-(delta_w//2)
    output = cv2.copyMakeBorder(output, pad_top, pad_bottom, pad_left, pad_right,
                                cv2.BORDER_CONSTANT, value=[0, 0, 0])
    return output
def garbage_extract(img, showCompare = False):
    """
    Extract the object from a photo with the background blacked out.

    img: BGR image. A landscape image is rotated to portrait, then the frame
         is resized to 390x520 before any processing.
    showCompare: when True, return the whole masked 390x520 frame with the
         object rectangle drawn on it, without clipping.

    Returns (showCompare False) the masked object clipped to its rectangle,
    scaled to fit and padded with black borders to 320x320 (OUTPUT_SIZE).
    """
    # Flip if needed
    if img.shape[0] < img.shape[1]:
        img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
    # Resize
    img = cv2.resize(img, (390, 520), interpolation=cv2.INTER_AREA)
    output = img.copy()

    # Chroma keying: blank out every pixel inside the blue HSV range
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    maskKey = cv2.inRange(hsv, (90,100,130),(110,255,255))
    maskKey = cv2.bitwise_not(maskKey)
    output = cv2.bitwise_and(output, output, mask = maskKey)

    # Find the object's borders in the rows above the belt threshold
    gray = cv2.cvtColor(output, cv2.COLOR_BGR2GRAY)
    _, bw = cv2.threshold(gray, 50, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(bw[:BELT_Y_THRESHOLD - 10, :], cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # Filter small dots
    fil = [c for c in contours if cv2.contourArea(c) > 300]
    if fil:
        x, y, w, h = cv2.boundingRect(np.vstack(fil))
    else:
        # np.vstack raises on an empty list; use the whole frame instead,
        # like garbage_extract_no_preprocess does.
        x, y, w, h = 0, 0, 390, 520

    # Delete left & right
    output[:, :x, :] = 0
    output[:, x + w:, :] = 0
    # Object border as (left, top, right, bottom); the bottom is fixed at row 500
    rect = (x, y, x + w, 500)

    # grabCut
    # NOTE(review): cv2.grabCut documents rect as (x, y, width, height), but
    # corner coordinates are passed here; left unchanged because altering it
    # would change the images a trained model may depend on - confirm intent.
    mask = np.zeros(img.shape[:2],np.uint8)
    bgdModel = np.zeros((1,65),np.float64)
    fgdModel = np.zeros((1,65),np.float64)
    cv2.grabCut(img, mask, rect, bgdModel, fgdModel, 5, cv2.GC_INIT_WITH_RECT)
    # Mask values 0 and 2 are (probable) background
    mask2 = np.where((mask==2)|(mask==0),0,1).astype('uint8')
    output = output*mask2[:,:,np.newaxis]

    if showCompare:
        output = cv2.rectangle(output, rect[:2], rect[2:], (0,255,0), 3)
        return output

    # Clip to the object rectangle
    output = output[rect[1]:rect[3], rect[0]:rect[2], :]
    # Resize so the longer side equals OUTPUT_SIZE
    h, w, c = output.shape
    scale = OUTPUT_SIZE/w if w > h else OUTPUT_SIZE/h
    output = cv2.resize(output, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_AREA)
    # Pad the shorter side with black so the result is exactly square
    delta_w = OUTPUT_SIZE - output.shape[1]
    delta_h = OUTPUT_SIZE - output.shape[0]
    top, bottom = delta_h//2, delta_h-(delta_h//2)
    left, right = delta_w//2, delta_w-(delta_w//2)
    output = cv2.copyMakeBorder(output, top, bottom, left, right, cv2.BORDER_CONSTANT,
                                value=[0, 0, 0])
    return output
if __name__ == "__main__":
    # Manual check: run the extractor on one sample photo and display it.
    sample_path = 'Camera/paper/paper_leafyellow_14.jpg'
    img = cv2.imread(sample_path)
    if img is None:
        # cv2.imread returns None instead of raising when the file is
        # missing or unreadable; fail with a readable message.
        raise SystemExit(f"Cannot read sample image: {sample_path}")
    output = garbage_extract(img)
    # For a side-by-side debug view, call garbage_extract(img, True) and
    # stack it next to the rotated/resized (390x520) input.
    cv2.imshow('Output', output)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

100
frontend/index.html Normal file
View File

@ -0,0 +1,100 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<!-- Minified version -->
<link rel="stylesheet" href="https://cdn.simplecss.org/simple.min.css">
<title>Automated Garbage Segregation</title>
<style>
/* Three-column page grid with a wide centre column (at most 75rem or 90%).
   NOTE(review): presumably overrides the body grid from simple.css - confirm. */
body {
grid-template-columns: 1fr min(75rem, 90%) 1fr;
}
/* Remove outer spacing and let the page take the full available size. */
body, html {
padding: 0;
margin: 0;
height: 100%;
width: 100%;
}
/* App root: two columns, image panel (1fr) next to the status panel (2fr). */
#main {
width: 100%;
height: 100%;
display: grid;
grid-template-columns: 1fr 2fr;
}
/* Centre the content of each panel horizontally and vertically. */
#main > div {
display: flex;
justify-content: center;
align-items: center;
height: 100%;
}
</style>
</head>
<body>
<!-- App root: the mounted/unmounted hooks start and stop the polling timer.
     Image panel: placeholder while state is 'put', loader while 'camera',
     otherwise the captured photo. Status panel: text for the current state. -->
<div id="main" v-scope @mounted="mounted()" @unmounted="unmounted()">
<div id="image">
<img v-if="state =='put'" :src="placeholder" alt="placeholder">
<img v-else-if="state =='camera'" :src="loader" alt="loader">
<img v-else :src="imageUrl" alt="item">
</div>
<div id="status">
<h1>{{ status }}</h1>
</div>
</div>
<script src="https://unpkg.com/petite-vue@0.2.2/dist/petite-vue.iife.js"></script>
<script>
// Reactive state and behaviour of the status page. The page polls /poll and
// mirrors the server-side state machine: put -> camera -> identify -> identified.
let app = PetiteVue.createApp({
  intervalId: -1,          // id returned by setInterval, -1 while not polling
  pollingInterval: 500,    // milliseconds between two /poll requests
  placeholder: "/static/scroll-down.gif",
  loader: "/static/loader.gif",
  imageUrl: "/photo",
  type: "未知",
  state: "put",
  // Text shown in the status panel for the current state.
  get status() {
    let data = {
      "put": "請放置垃圾",
      "camera": "拍照中,請勿移動",
      "identify": "辨識中",
      "identified": this.type,
    }
    return data[this.state];
  },
  // Start polling once the root element is mounted.
  mounted() {
    console.log("mounted");
    if (this.intervalId < 0) this.intervalId = setInterval(() => { this.polling(); }, this.pollingInterval);
  },
  // Stop polling; the id is reset so that a later mount can start it again.
  unmounted() {
    console.log("unmounted");
    if (this.intervalId >= 0) {
      clearInterval(this.intervalId);
      this.intervalId = -1;
    }
  },
  // Apply one /poll response ({state, type?}) to the page.
  update(data) {
    const state = data.state;
    if (state == "identified") this.type = data.type;
    // The photo is visible in both "identify" and "identified". Refresh its
    // cache-busting URL once, when entering either state from a state that
    // does not show it: this avoids re-downloading the image on every poll
    // and still picks up the new photo when "identify" was shorter than one
    // polling interval and therefore never observed.
    const showsPhoto = (s) => s == "identify" || s == "identified";
    if (showsPhoto(state) && !showsPhoto(this.state)) this.imageUrl = `/photo?${Date.now()}`;
    this.state = state;
  },
  // Fetch the current state; failures are logged and the next tick retries.
  async polling() {
    let res, data;
    try {
      res = await fetch("/poll");
      if (!res.ok) throw new Error(`/poll responded with HTTP ${res.status}`);
      data = await res.json();
    } catch (error) {
      console.error(error);
      return;
    }
    this.update(data);
  }
});
app.mount("#main");
</script>
</body>
</html>

72
frontend/server.py Normal file
View File

@ -0,0 +1,72 @@
import os
from flask import Blueprint, Flask, render_template, jsonify, send_file, request
# File served by the /photo route.
image_path = "/tmp/photo.jpg"
# Current UI state; the handlers below set it to "put", "camera", "identify"
# or "identified", and /poll reports it to the page.
state = "put"
# Category text received by /api/result; /poll returns it while the state is
# "identified".
trash_type = ""
# Blueprint holding the state-changing endpoints (registered under /api).
api = Blueprint('api', __name__)
# Garbage has been put down
@api.route("/putdown")
def putdown():
    """Switch the UI state to "camera" and answer "ok"."""
    global state
    state = "camera"
    return "ok"
# Photo has been taken
@api.route("/pic")
def pic():
    """Switch the UI state to "identify" and answer "ok"."""
    global state
    state = "identify"
    return "ok"
# Identification finished
@api.route("/result")
def result():
    """
    Store the identified category and switch the UI state to "identified".

    Reads the category from the "type" query parameter. Answers "ok" on
    success, or "sad" with HTTP 400 when the parameter is missing or empty.
    """
    global state, trash_type
    reported = request.args.get("type")
    if reported:
        trash_type, state = reported, "identified"
        return "ok"
    return "sad", 400
# Ready for the next item
@api.route("/ready")
def ready():
    """Switch the UI state back to "put" and answer "ok"."""
    global state
    state = "put"
    return "ok"
# Application object; the state-changing endpoints are exposed under /api/...
app = Flask(__name__)
app.register_blueprint(api, url_prefix="/api")
@app.route("/")
def index():
    """Serve the single-page frontend (index.html)."""
    return send_file("index.html")
@app.route("/poll")
def poll():
    """
    Report the current UI state as JSON.

    The response is {"state": ...}; while the state is "identified" it also
    carries the identified category under "type".
    """
    current = state
    if current == "identified":
        return jsonify({"state": current, "type": trash_type})
    return jsonify({"state": current})
@app.route("/photo")
def photo():
    """
    Serve the most recent photo stored at image_path.

    Answers HTTP 404 when no photo has been written yet, instead of letting
    send_file fail on the missing file.
    """
    if not os.path.isfile(image_path):
        return "photo not available", 404
    return send_file(image_path)
if __name__ == "__main__":
    # The server listens on every interface, so the interactive debugger
    # (which lets a client execute code) must not be enabled by default.
    # Opt in explicitly with FLASK_DEBUG=1 during development.
    debug = os.environ.get("FLASK_DEBUG", "").lower() in ("1", "true", "yes")
    app.run(host='0.0.0.0', debug=debug)

BIN
frontend/static/loader.gif Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 9.3 KiB

BIN
frontend/static/photo.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 195 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.6 KiB

47
inference.py Normal file
View File

@ -0,0 +1,47 @@
# Class names; predict() maps the argmax of the model output onto this list.
labels = ['can', 'paper_cup', 'paper_box', 'paper_milkbox', 'plastic']
print("[*] Importing packages...")
import os
# Has to be set before TensorFlow is imported: the native runtime reads the
# variable while loading, so setting it afterwards does not silence its logs.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
from tensorflow import keras
import cv2
import pandas as pd
import numpy as np
print("[*] Loading model...")
# Only show errors from the Python-side TensorFlow logger.
tf.get_logger().setLevel('ERROR')
# Loaded once at import time and shared by every predict() call.
model = keras.models.load_model('model_without_preprocess_finetuned.h5')
def predict(img, weight):
    """
    Classify one item from its image and weight.

    img and weight each get a leading batch axis of size 1 and are passed to
    the model as its two inputs. The raw model output is printed, and the
    entry of `labels` with the highest score is returned.
    """
    image_batch = np.expand_dims(img, 0)
    weight_batch = np.expand_dims(weight, 0)
    prob = model.predict([image_batch, weight_batch])
    print(prob)
    best = prob.argmax(-1)[0]
    return labels[best]
# Run one prediction at import time so later predict() calls do not pay the
# first-call cost.
print("[*] Warming up model...")
img = cv2.imread(os.path.join('test_data', 'can_dew_0_preprocessed.jpg'))
if img is None:
    # cv2.imread returns None for a missing/unreadable file; the warm-up is
    # optional, so do not let it break importing this module.
    print("[!] Warm-up image not found, skipping warm-up")
else:
    print(predict(img, 52.69))
print("[*] Done!")
if __name__ == '__main__':
    # Batch check: classify every .jpg below test_data, taking each image's
    # weight from weights_test.csv (indexed by file name).
    df = pd.read_csv('test_data/weights_test.csv')
    df = df.set_index('name')

    def pred(f, dirpath):
        """Return the predicted label for image file f located in dirpath."""
        image = cv2.imread(os.path.join(dirpath, f))
        row = df.loc[f]
        weight = row['weight']
        inputs = [np.expand_dims(image, 0), np.expand_dims(weight, 0)]
        scores = model.predict(inputs)
        return labels[scores.argmax(-1)[0]]

    for dirpath, _dirnames, filenames in os.walk('test_data'):
        for f in (name for name in filenames if name.endswith('.jpg')):
            print(f'{f}: {pred(f, dirpath)}')

135
main.py Normal file
View File

@ -0,0 +1,135 @@
import cv2
import numpy as np
import requests
import urllib.request
import urllib.parse
import serial
import os
from inference import predict, labels
from extract import garbage_extract_no_preprocess
# Address of the camera host; takePhoto() fetches http://<CAM_IP>:8080/photo.jpg
CAM_IP = "192.168.154.28"
# Base URL of the frontend server's state-changing API
API_BASE = "http://localhost:5000/api"
# Where takePhoto() writes the preprocessed photo and getPrediction() reads it
PHOTO_PATH = "/tmp/photo.jpg"
# Display names indexed by the value getPrediction() returns:
# 0 paper, 1 plastic, 2 aluminium can, 3 unknown. Index 3 is what
# getPrediction() returns when prediction fails, so it needs an entry here;
# without it label_name[3] raised IndexError.
label_name = ["紙類", "塑膠類", "鋁罐", "未知"]
def findCOM():
    """
    Return the first existing serial device among /dev/ttyACM0 .. /dev/ttyACM9.

    Returns None when none of the ten device paths exists.
    """
    candidates = (f'/dev/ttyACM{index}' for index in range(0, 10))
    return next((path for path in candidates if os.path.exists(path)), None)
def API(endpoint):
    """
    Notify the frontend server by requesting API_BASE + endpoint.

    endpoint: path (with optional query string) starting with "/".
    Failures are printed and never raised, so a frontend problem cannot
    stop the sorting workflow.
    """
    try:
        # Without a timeout a stalled server would block the caller forever.
        response = requests.get(API_BASE + endpoint, timeout=5)
        # Turn HTTP error statuses (e.g. a 400 from /result) into a reported
        # failure instead of ignoring them.
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        print(f"[!] Request failed: {e}")
def loop():
    """
    Run the sorting workflow against the Arduino on the serial port.

    Per item: wait for the byte b's' (object placed), read one line holding
    the weight, take and preprocess a photo, predict the category, send the
    category index back as a single byte, then wait for the byte b'd'
    (finished). The frontend is notified at every step through API().

    Runs until interrupted or until an error occurs; the reason is printed
    and the serial port is always closed before returning.
    """
    ser = serial.Serial(findCOM(), 9600)
    try:
        while True:
            print("[+] Waiting for arduino...")
            response = ser.read(1)
            print(response)
            if response != b's':
                continue
            print("[+] Object placed!")
            API('/putdown')
            print("[+] Waiting for weight...")
            weight = float(ser.readline().decode())
            print(f"[+] Got weight: {weight}!")
            print("[+] Taking photo...")
            takePhoto()
            API('/pic')
            print("[+] Took photo!")
            print("[+] Predicting...")
            label_idx = getPrediction(weight)
            print(f"[+] Got prediction: {label_name[label_idx]}!")
            print("[+] Sending prediction...")
            ser.write(bytes([label_idx]))
            # The server reads the category from the "type" query parameter;
            # without the "type=" key it rejects the request with HTTP 400.
            API(f'/result?type={urllib.parse.quote_plus(label_name[label_idx])}')
            print("[+] Sent prediction!")
            print("[+] Waiting for finish...")
            while ser.read(1) != b'd':
                pass
            print("[+] Finished!")
            API('/ready')
    except KeyboardInterrupt:
        print("[!] Interrupted, stopping loop")
    except Exception as e:
        # Still return instead of crashing, but say why the loop stopped.
        print(f"[!] Loop aborted: {e}")
    finally:
        ser.close()
def takePhoto():
    """
    Fetch a photo from the camera, preprocess it and save it to PHOTO_PATH.

    The downloaded image is rotated 90 degrees clockwise, passed through
    garbage_extract_no_preprocess and written to PHOTO_PATH. Any failure is
    printed and not raised; in that case PHOTO_PATH is left untouched.
    """
    try:
        URL = f"http://{CAM_IP}:8080/photo.jpg"
        # A timeout keeps an unreachable camera from blocking forever, and
        # the status check avoids decoding an error page as an image.
        response = requests.get(URL, timeout=10)
        response.raise_for_status()
        imgNp = np.array(bytearray(response.content), dtype=np.uint8)
        img = cv2.imdecode(imgNp, -1)
        if img is None:
            # cv2.imdecode returns None when the data is not a valid image
            raise ValueError("camera response is not a decodable image")
        img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
        img = garbage_extract_no_preprocess(img)
        cv2.imwrite(PHOTO_PATH, img)
    except Exception as e:
        print(f"[!] Photo / Preprocess error: {e}")
def getPrediction(weight):
    """
    Predict the category index for the photo stored at PHOTO_PATH.

    weight: the item's weight, passed to the model together with the image.

    Returns 0 for the paper labels, 1 for plastic, 2 for can, and 3 when the
    photo cannot be read, the prediction fails, or the label is not one of
    the known ones (previously that last case returned None).
    """
    try:
        img = cv2.imread(PHOTO_PATH)
        if img is None:
            # cv2.imread returns None instead of raising on a missing file
            raise ValueError(f"cannot read {PHOTO_PATH}")
        label = predict(img, weight)
    except Exception as e:
        print(f"[!] Predict error: {e}")
        return 3
    if label in ['paper_cup', 'paper_box', 'paper_milkbox']:
        return 0
    if label in ['plastic']:
        return 1
    if label in ['can']:
        return 2
    return 3
def test():
    """
    Simulate one pass of the workflow without the Arduino.

    The serial waits are replaced by sleeps and the weight is fixed at 30.0;
    the photo, the prediction and the frontend notifications are real.
    """
    import time

    print("[+] Waiting for arduino...")
    time.sleep(1)
    print("[+] Object placed!")
    API('/putdown')

    print("[+] Waiting for weight...")
    time.sleep(1.5)
    weight = 30.0
    print(f"[+] Got weight: {weight}!")

    print("[+] Taking photo...")
    takePhoto()
    API('/pic')
    print("[+] Took photo!")

    print("[+] Predicting...")
    bin_index = getPrediction(weight)
    category = label_name[bin_index]
    print(f"[+] Got prediction: {category}!")

    print("[+] Sending prediction...")
    query = urllib.parse.quote_plus(category)
    API(f'/result?type={query}')
    print("[+] Sent prediction!")

    print("[+] Waiting for finish...")
    time.sleep(5)
    print("[+] Finished!")
    API('/ready')
    time.sleep(5)
if __name__ == "__main__":
    # Guarded so that importing this module does not start the endless loop.
    # Runs the simulated workflow repeatedly; call loop() here instead to
    # drive the serial-connected hardware.
    while True:
        test()
    # loop()

BIN
photo.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 195 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 36 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 30 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 33 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 32 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 32 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 27 KiB

View File

@ -0,0 +1,7 @@
name,weight
can_dew_0_preprocessed.jpg,52.69
paper_coffee_3_preprocessed.jpg,62.57
paper_eyedrop_3_preprocessed.jpg,38.56
plas_smoothie_7_preprocessed.jpg,60.51
plas_cheese_13_preprocessed.jpg,102.96
paper_leafyellow_11_preprocessed.jpg,108.2
1 name weight
2 can_dew_0_preprocessed.jpg 52.69
3 paper_coffee_3_preprocessed.jpg 62.57
4 paper_eyedrop_3_preprocessed.jpg 38.56
5 plas_smoothie_7_preprocessed.jpg 60.51
6 plas_cheese_13_preprocessed.jpg 102.96
7 paper_leafyellow_11_preprocessed.jpg 108.2