-
Notifications
You must be signed in to change notification settings - Fork 21
/
api.py
69 lines (57 loc) · 2.14 KB
/
api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import asyncio
import aiohttp
import async_timeout
import numpy as np
import uvloop
from aiohttp import web
from aiohttp.web import FileField
from aiohttp.web import HTTPBadRequest
from aiohttp.web import HTTPNotFound
from aiohttp.web import HTTPUnsupportedMediaType
from classify_nsfw import caffe_preprocess_and_compute, load_model
# Load the Caffe NSFW model once, at import time. classify() reuses this
# network and its input transformer for every image it scores.
nsfw_net, caffe_transformer = load_model()
def classify(image: bytes) -> np.float64:
    """Score raw image bytes with the module-level Caffe network.

    The image is preprocessed and run through ``nsfw_net`` using
    ``caffe_transformer``; only the ``prob`` output layer is requested.
    Element 1 of that output is returned. The HTTP handler treats it as the
    NSFW probability.
    """
    wanted_layers = ["prob"]
    layer_scores = caffe_preprocess_and_compute(
        image,
        caffe_net=nsfw_net,
        caffe_transformer=caffe_transformer,
        output_layers=wanted_layers,
    )
    nsfw_score = layer_scores[1]
    return nsfw_score
async def fetch(session, url, *, timeout=10):
    """Download ``url`` with ``session`` and return the response body.

    Args:
        session: the aiohttp client session used to issue the GET request.
        url: address of the resource to download.
        timeout: overall limit, in seconds, covering both the request and
            reading the body. Defaults to 10, the previously hard-coded value.

    Returns:
        The raw response body as bytes.

    Raises:
        HTTPNotFound: the remote server answered with status 404.
        asyncio.TimeoutError: the download did not finish within ``timeout``
            seconds (raised by ``async_timeout``).
    """
    # async_timeout's timeout() must be entered with ``async with``. The plain
    # ``with`` form was deprecated in async_timeout 3.0 and removed in 4.0.
    async with async_timeout.timeout(timeout):
        async with session.get(url) as response:
            if response.status == 404:
                raise HTTPNotFound()
            # Other non-2xx statuses still return their body unchanged; the
            # caller's image decoding is what rejects non-image payloads.
            return await response.read()
class API(web.View):
    """HTTP endpoint that scores an image with the NSFW classifier.

    Only ``post`` is implemented. The response body is the value returned by
    ``classify``, rendered as text.
    """

    async def post(self):
        """Classify an image supplied either by URL or as a multipart upload.

        Form fields (at least one must be non-empty; ``url`` wins if both are):
            url:  address of an image to download with ``fetch``.
            file: a multipart file upload containing the image bytes.

        Returns:
            A ``web.Response`` whose text is the classifier's NSFW probability.

        Raises:
            HTTPBadRequest: neither field was given, ``file`` is not a file
                upload, or processing raised ``ValueError``.
            HTTPUnsupportedMediaType: processing raised an ``OSError`` whose
                message contains "cannot identify" (presumably the image
                decoder rejecting the bytes).
            HTTPNotFound: propagated from ``fetch`` when the URL returns 404.
        """
        data = await self.request.post()
        url = data.get("url")
        upload = data.get("file")

        # Validate the request up front and raise (not return) the HTTP
        # error, so it cannot be confused with exceptions raised while
        # fetching or classifying.
        if not url and not upload:
            raise HTTPBadRequest(text="Missing `url` or `file` POST parameter")
        if not url and not isinstance(upload, FileField):
            # A plain text field named "file" is a client error (400). It is
            # not an internal failure.
            raise HTTPBadRequest(text="File is not a valid multipart file upload.")

        try:
            if url:
                # ``session`` is the module-level aiohttp client session.
                image = await fetch(session, url)
            else:
                image = upload.file.read()
            # NOTE(review): classify() is synchronous and blocks the event
            # loop while the model runs; consider an executor once the
            # thread-safety of the Caffe net is confirmed.
            nsfw_prob = classify(image)
        except OSError as e:
            if "cannot identify" in str(e):
                raise HTTPUnsupportedMediaType(text="Invalid image") from e
            # Any other OSError (e.g. a connection failure) is unexpected
            # here. Re-raise it unchanged, keeping the original traceback.
            raise
        except ValueError as e:
            raise HTTPBadRequest(text="Malformed image provided") from e

        return web.Response(text=nsfw_prob.astype(str))
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

# Shared HTTP client read by API.post through the module-level name
# ``session``. It is created on application startup, once the server's event
# loop is running, rather than at import time; there may be no running loop
# then, or it may not be the loop run_app uses.
session = None


async def _open_session(app):
    """Create the shared client session inside the running server loop."""
    global session
    session = aiohttp.ClientSession()


async def _close_session(app):
    """Close the shared client session so its pooled connections are released."""
    if session is not None:
        await session.close()


app = web.Application()
app.on_startup.append(_open_session)
app.on_cleanup.append(_close_session)
app.router.add_route("*", "/", API)
web.run_app(app)