-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathapp.py
executable file
·97 lines (85 loc) · 3.12 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import argparse
import json
import os
from typing import List, Dict
import numpy as np
from flask import Flask, jsonify, render_template, request
class MainApp:
    """Flask web app for labeling the ``.jpg`` images of one dataset.

    Images are listed from ``static/img/<dataset_name>/`` and labels are
    persisted as JSON in ``data/<dataset_name>_labels.json``, keyed by the
    image file name.
    """

    images_path: List[str]  # dataset_folder + file name, same order as `images`
    images: List[str]  # .jpg file names, sorted by their integer stem
    labels: Dict[str, List[int]]  # image file name -> saved label values

    def __init__(self, dataset_name: str = "test_dataset", port: int = 5000) -> None:
        """Create the Flask app, load the dataset and register the routes.

        Args:
            dataset_name: Name of the dataset folder under ``static/img/``.
            port: Port passed to the Flask server by :meth:`run`.
        """
        self.dataset_name = dataset_name
        self.dataset_folder = f"static/img/{self.dataset_name}/"
        self.app = Flask(__name__)
        self.app.config["DEBUG"] = True
        self.load_images()
        self.port = port
        # Register routes
        self.app.add_url_rule("/", "home", self.home, methods=["GET"])
        self.app.add_url_rule(
            "/save_labels", "save_labels", self.save_labels, methods=["POST"]
        )
        self.app.add_url_rule("/online", "online", self.online, methods=["GET"])

    @property
    def _labels_path(self) -> str:
        """Path of the JSON file holding this dataset's labels."""
        return f"data/{self.dataset_name}_labels.json"

    def run(self) -> None:
        """Start the Flask server on ``self.port``."""
        self.app.run(port=self.port)

    def load_images(self) -> None:
        """Index the dataset's ``.jpg`` files and load any saved labels.

        File names must have an integer stem (e.g. ``"12.jpg"``) because they
        are sorted numerically; other ``.jpg`` names raise ``ValueError``.
        If the labels file is missing or cannot be opened, ``self.labels``
        starts empty.
        """
        self.images = [f for f in os.listdir(self.dataset_folder) if f.endswith(".jpg")]
        self.images.sort(key=lambda name: int(name.split(".jpg")[0]))
        self.images_path = [self.dataset_folder + im for im in self.images]
        self.labels = {}
        try:
            with open(self._labels_path, encoding="utf-8") as f:
                self.labels = json.load(f)
        except OSError:
            # Deliberate best-effort: no saved labels yet means start empty.
            pass

    def home(self) -> str:
        """Render ``home.html`` for the image selected by the ``idx`` query arg.

        ``idx`` defaults to 0. A non-integer value falls back to 0, and an
        out-of-range value is clamped into ``[0, len(images) - 1]``.
        """
        # `type=int` returns the default instead of raising on malformed
        # input, so a bad URL no longer produces a 500 error.
        idx = request.args.get("idx", 0, type=int)
        idx = min(max(idx, 0), len(self.images) - 1)
        image = self.images_path[idx]
        percent = round(100 * ((idx + 1) / len(self.images)), 0)
        label: List[int] = self.labels.get(self.images[idx], [])
        return render_template(
            "home.html",
            idx=idx,
            total=len(self.images),
            image=image,
            percent=percent,
            label=label,
        )

    def save_labels(self):
        """Update one image's label from a JSON POST body and persist all labels.

        The body is read as ``{"idx": <image index>, "values": [...]}``.
        A ``values`` list containing ``null`` deletes the image's label. A list
        of exactly 3 values replaces it. Anything else leaves it unchanged.
        The labels file is rewritten on every call.

        Returns:
            A JSON response ``{"status": "ok", "label": <request body>}``.
        """
        label = json.loads(request.data.decode("utf8"))
        key = self.images[int(label["idx"])]
        self._apply_label(key, label["values"])
        self._write_labels()
        return jsonify(status="ok", label=label)

    def _apply_label(self, key: str, values: list) -> None:
        """Delete, replace, or keep the label stored for image ``key``."""
        if None in values:
            # Incomplete label: drop any saved one and never store the
            # partial list (previously stored when `key` was not yet labeled).
            self.labels.pop(key, None)
        elif len(values) == 3:
            # Only complete labels (3 points) are saved.
            self.labels[key] = values

    def _write_labels(self) -> None:
        """Write ``self.labels`` to the labels file via a temp file + rename."""
        path = self._labels_path
        os.makedirs(os.path.dirname(path), exist_ok=True)
        tmp_path = path + ".tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(self.labels, f)
        # Rename into place so an interrupted write cannot truncate the
        # existing labels file.
        os.replace(tmp_path, path)

    def online(self) -> str:
        """Render ``online.html`` for the first image with an empty label."""
        return render_template(
            "online.html",
            idx=0,
            total=len(self.images),
            image=self.images_path[0],
            percent=0,
            label=[],
        )
if __name__ == "__main__":
    # CLI entry point: serve the labeling app for one dataset.
    parser = argparse.ArgumentParser(description="Serve the image labeling web app")
    parser.add_argument(
        "-n", "--name", help="Dataset name", type=str, default="test_dataset"
    )
    # Exposes MainApp's existing `port` parameter; default matches MainApp's.
    parser.add_argument(
        "-p", "--port", help="Port to listen on", type=int, default=5000
    )
    args = parser.parse_args()
    app = MainApp(args.name, port=args.port)
    app.run()