-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
89 lines (63 loc) · 2.6 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import functools
import os
import random
import secrets
import string
from pathlib import Path
from typing import Any

import tensorflow as tf
import cv2 as cv
import numpy as np
from fastapi import FastAPI, HTTPException, File, UploadFile
from pydantic import BaseModel
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications import xception
import uvicorn
# Absolute directory containing this file; the config, model-weights and
# uploads paths used below are all resolved relative to it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Class labels, indexed by position in the model's output vector
# (garbage_response maps output index i -> GARBAGE_MAPPER[i]).
# NOTE(review): 'ewest' and 'nnon_recyclable' look like typos (e-waste /
# non_recyclable?), but they are returned to clients as response keys, so
# renaming them would change the API -- confirm with consumers before fixing.
GARBAGE_MAPPER = [
    'ewest',
    'glass',
    'metal',
    'nnon_recyclable',
    'paper',
    'plastic',
]

app = FastAPI()
class GarbageResponseSchema(BaseModel):
    """Response body returned by the garbage-classification upload endpoint."""

    # Status code echoed in the response body; defaults to 200.
    status_code: int
    # Mapping of label -> model score; the endpoint builds it ordered by
    # descending score.
    results: dict[str, float]
class DefaultSchema(BaseModel):
    """Generic message response (returned by the root endpoint)."""

    # Human-readable message text.
    message: str
    # Status code reported in the response body (no default; must be given).
    status_code: int
@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the classifier from its JSON architecture and load its weights.

    Cached so the architecture and weights are read from disk once per
    process instead of once per request.
    """
    config_path = os.path.join(BASE_DIR, 'config/config.json')
    model_path = os.path.join(BASE_DIR, 'ml-models/garbage_model_weights.h5')
    with open(config_path, 'r', encoding='utf-8') as f:
        model = tf.keras.models.model_from_json(f.read())
    model.load_weights(model_path)
    return model


def predict_result(file_path: str) -> list[float]:
    """Run the garbage classifier on an image file, then delete the file.

    Args:
        file_path: Path to an image on disk. The file is removed after
            prediction, whether or not prediction succeeded.

    Returns:
        The model's output for the image (first row of the batch
        prediction) as a list of floats; the caller maps each position to a
        label.
    """
    try:
        model = _load_model()
        # Resize on load to 320x320 -- presumably the model's input size;
        # TODO confirm against config/config.json.
        input_image = image.load_img(file_path, target_size=(320, 320))
        input_image = image.img_to_array(input_image)
        # Add a leading batch dimension: predict() is called on a batch of one.
        input_image = np.expand_dims(input_image, axis=0)
        # NOTE(review): `xception` is imported at file top but its
        # preprocess_input is not applied here -- confirm the model expects
        # raw pixel values.
        output = model.predict(input_image)
    finally:
        # Always delete the uploaded file, even when loading or prediction
        # raises, so failed requests do not leave files behind.
        if os.path.exists(file_path):
            os.remove(file_path)
    return output[0].tolist()
@app.get('/', response_model=DefaultSchema)
def index() -> DefaultSchema:
    """Root endpoint: return a static message identifying the service."""
    return DefaultSchema(message='Garbage Detector', status_code=200)
@app.post('/garbage-upload', response_model=GarbageResponseSchema)
async def garbage_response(
    image: UploadFile = File(...)
) -> GarbageResponseSchema:
    """Classify an uploaded image and return per-label scores.

    The parameter name ``image`` is the multipart form-field name seen by
    clients; inside this function it shadows the module-level keras
    ``image`` import.

    Returns:
        A GarbageResponseSchema whose ``results`` maps each label to its
        score, ordered by descending score.

    Raises:
        HTTPException: 404 if no upload is provided; 400 if the upload's
            content type is not ``image/*``.
    """
    if not image:
        raise HTTPException(status_code=404, detail="Please provide image")
    # content_type may be None when the client omits the header; treat that
    # as "not an image" instead of crashing with AttributeError.
    content_type = (image.content_type or '').split('/')
    if content_type[0] != 'image':
        raise HTTPException(status_code=400, detail="The uploaded files is not a image")

    uploads_path = os.path.join(BASE_DIR, 'uploads')
    # Create the uploads directory on first use instead of failing with a 500.
    os.makedirs(uploads_path, exist_ok=True)
    # The client-supplied filename is untrusted: keep only its final path
    # component so names like "../../x" cannot escape the uploads directory,
    # and prefix an unpredictable token so concurrent uploads don't collide.
    original_name = Path(image.filename or '').name
    image_name = secrets.token_hex(8) + original_name
    file_path = os.path.join(uploads_path, image_name)
    with open(file_path, "wb") as buffer:
        buffer.write(await image.read())

    # NOTE(review): predict_result is synchronous and is called directly from
    # this async handler, so it blocks the event loop while the model runs.
    result = predict_result(file_path=file_path)
    # Output position i corresponds to label GARBAGE_MAPPER[i].
    output = {GARBAGE_MAPPER[key]: value for key, value in enumerate(result)}
    result_list = dict(sorted(output.items(), key=lambda item: item[1], reverse=True))
    return GarbageResponseSchema(
        status_code=200,
        results=result_list,
    )
if __name__ == "__main__":
    # uvicorn needs the app as an import string ("<module>:app") rather than
    # the app object when reload=True, so derive the module name from this file.
    app_import_string = f"{Path(__file__).stem}:app"
    uvicorn.run(
        app_import_string,
        host='127.0.0.1',
        port=9000,
        reload=True,
    )