-
Notifications
You must be signed in to change notification settings - Fork 6
/
nsfw_detect.py
47 lines (42 loc) · 1.83 KB
/
nsfw_detect.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from aidb.config.config_types import AIDBListType
from aidb.inference.examples.google_inference_service import GoogleVisionAnnotate
# Database settings for this test configuration.
# DB_NAME is the SQLite database name; DB_URL is a SQLAlchemy-style
# "dialect+driver" URL prefix selecting the async aiosqlite driver.
# NOTE(review): presumably the test harness combines DB_URL and DB_NAME into
# the final connection URL — confirm.
DB_NAME = 'aidb_test_nsfw.sqlite'
DB_URL = 'sqlite+aiosqlite://'
# Google Cloud project used for the Vision API calls. Replace the placeholder
# before running. The request `parent` below is derived from this value, so
# `project_id` and `parent` can no longer disagree.
PROJECT_ID = "your-project-id"

# SafeSearch fields read from each Vision response, in output-column order.
# NOTE(review): presumably this order must line up positionally with the
# images.* entries of `output_col` in `inference_engines` — confirm.
_SAFE_SEARCH_FIELDS = ("adult", "spoof", "medical", "violence", "racy")

# Google Cloud Vision annotate service configured for SAFE_SEARCH_DETECTION.
# It takes one str input column (an image URL) and yields one str output
# column per entry of _SAFE_SEARCH_FIELDS.
nsfw_detect_service = GoogleVisionAnnotate(
    name="nsfw_detect",
    token=None,  # automatically get token from gcloud
    # The single input column is written to requests[].image.source.imageUri.
    # NOTE(review): AIDBListType() appears to mark a JSON list level in these
    # key paths — confirm against aidb.config.config_types.
    columns_to_input_keys=[
        ('requests', AIDBListType(), 'image', 'source', 'imageUri')],
    # One output column per field: responses[].safeSearchAnnotation.<field>.
    response_keys_to_columns=[
        ('responses', AIDBListType(), 'safeSearchAnnotation', field)
        for field in _SAFE_SEARCH_FIELDS],
    input_columns_types=[str],
    # Derived from the same tuple so there is always exactly one type per
    # response column.
    output_columns_types=[str] * len(_SAFE_SEARCH_FIELDS),
    # NOTE(review): the Vision API caps the number of images per annotate
    # request (16 per its quotas page). Confirm the current limit and how aidb
    # uses this value before relying on 128.
    preferred_batch_size=128,
    project_id=PROJECT_ID,
    # Fixed request fields: request only SAFE_SEARCH_DETECTION, and target the
    # call at PROJECT_ID through Vision's `parent` field.
    default_args={
        ('requests', AIDBListType(), 'features', 'type'): 'SAFE_SEARCH_DETECTION',
        'parent': f'projects/{PROJECT_ID}',
    })
# Inference-engine wiring for aidb. The single engine reads columns of
# images_source and writes columns of images using nsfw_detect_service.
inference_engines = [
    dict(
        service=nsfw_detect_service,
        # image_url feeds the service's one input key. image_id is carried
        # alongside it (presumably so the `copy` mapping can use it — confirm).
        input_col=("images_source.image_url", "images_source.image_id"),
        # The five ratings follow the order of the service's
        # response_keys_to_columns. images.image_id is the copy target below.
        output_col=(
            "images.adult",
            "images.spoof",
            "images.medical",
            "images.violence",
            "images.racy",
            "images.image_id",
        ),
        # Source column -> destination column.
        copy={"images_source.image_id": "images.image_id"},
    ),
]
# Blob source table: `blob_table_name` is keyed by `blobs_keys_columns`.
# Its rows presumably come from `blobs_csv_file`, a path relative to the
# working directory (likely the repo root) — confirm against the test harness.
blob_table_name = "images_source"
blobs_keys_columns = ["image_id"]
blobs_csv_file = "tests/data/image_path_data.csv"

# Dictionary of table names to list of column specs.
# `images` has primary key image_id, which refers to images_source.image_id.
# The remaining columns are the str rating columns named in the inference
# engine's output_col.
tables = {
    "images": [
        {
            "name": "image_id",
            "is_primary_key": True,
            "refers_to": ("images_source", "image_id"),
            "dtype": int,
        },
        *(
            {"name": column, "dtype": str}
            for column in ("adult", "spoof", "medical", "violence", "racy")
        ),
    ],
}