Skip to content

Commit daa6c6d

Browse files
dekahaedra, Yu Liu, and nLevin13
authored
vis + gate algo refactor (#34)
* Gate task * With optical flow added * Updated optical flow usage * Modified how dense optical flow is used * updated vis * merge to new format * update to new perception directory * deleted outdated files * Updated optical flow & separated classes * Removed unnecessary code * updated with vis * added init * added init * Refactor Vis + Gate Algos * vis updates merged * tidied up formatting and imports * random code style fixes * fixed some style issues * vis now iterates through data frames without error Co-authored-by: Yu Liu <[email protected]> Co-authored-by: nLevin13 <[email protected]>
1 parent 0f3a6d5 commit daa6c6d

38 files changed

+701
-760
lines changed

.gitattributes

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
*.zip filter=lfs diff=lfs merge=lfs -text

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,5 @@ __pycache__/
1818
# IDE files
1919
.idea
2020
.vs_code/
21+
22+
data/

perception/__init__.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
1-
import perception.tasks.TestTasks.TestAlgo as TestAlgo
2-
import perception.tasks.gate.GateCenter as GateSeg
1+
import perception.vis.TestAlgo as TestAlgo
2+
import perception.tasks.gate.GateCenterAlgo as GateSeg
3+
import perception.tasks.gate.GateSegmentationAlgoA as GateSegA
4+
import perception.tasks.gate.GateSegmentationAlgoB as GateSegB
5+
import perception.tasks.gate.GateSegmentationAlgoC as GateSegC
36
# import perception.tasks as tasks
47

58
ALGOS = {
69
'test': TestAlgo.TestAlgo,
7-
'gateseg': GateSeg.GateCenter
10+
'gateseg': GateSeg.GateCenterAlgo,
11+
'gatesegA': GateSegA.GateSegmentationAlgoA,
12+
'gatesegB': GateSegB.GateSegmentationAlgoB,
13+
'gatesegC': GateSegC.GateSegmentationAlgoC
814
}

perception/tasks/TaskPerceiver.py

+6-10
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
from typing import Any, Dict, Tuple
1+
from typing import Any, Dict
22
import numpy as np
33

4-
class TaskPerceiver:
54

5+
class TaskPerceiver:
66
def __init__(self, **kwargs):
77
"""Initializes the TaskPerceiver.
88
Args:
@@ -11,22 +11,18 @@ def __init__(self, **kwargs):
1111
for the slider which controls this variable, and default_val is
1212
the initial value of the slider.
1313
"""
14-
self.time = 0
1514
self.kwargs = kwargs
1615

1716
def analyze(self, frame: np.ndarray, debug: bool, slider_vals: Dict[str, int]) -> Any:
1817
"""Runs the algorithm and returns the result.
1918
Args:
20-
frame: The frame to analyze
19+
frame: The frame to analyze
2120
debug: Whether or not to display intermediate images for debugging
2221
slider_vals: A list of names of the variables which the user should be
2322
able to control from the Visualizer, mapped to current slider
2423
value for that variable
25-
Returns:
26-
the result of the algorithm
27-
debug frames must each be same size as original input frame. Might change this in the future.
24+
Returns:
25+
the result of the algorithm
26+
debug frames must each be same size as original input frame. Might change this in the future.
2827
"""
2928
raise NotImplementedError("Need to implement with child class.")
30-
31-
def var_info(self) -> Dict[str, Tuple[Tuple[int, int], int]]:
32-
return self.kwargs

perception/tasks/cross/CrossPerceiver.py

-8
This file was deleted.
File renamed without changes.

perception/tasks/cross/cross_detection.py

+12-17
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
#############################################################################
99

1010
sys.path.insert(0, '../background_removal')
11-
from peak_removal_adaptive_thresholding import filter_out_highest_peak_multidim
12-
from combined_filter import combined_filter
11+
from perception.tasks.segmentation.peak_removal_adaptive_thresholding import filter_out_highest_peak_multidim
12+
from perception.tasks.segmentation.combinedFilter import init_combined_filter
1313

14-
ret, frame = True, cv2.imread('../data/cross/cross.png') # https://i.imgur.com/rjv1Vcy.png
14+
ret, frame = True, cv2.imread('../data/cross/cross.png') # https://i.imgur.com/rjv1Vcy.png
1515

1616
# "hsv" = Apply hsv thresholding before trying to find the path marker
1717
# "multidim" = Apply filter_out_highest_peak_multidim
@@ -29,7 +29,7 @@ def find_cross(frame, draw_figs=True):
2929
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
3030

3131
ret, thresh = cv2.threshold(gray, 127, 255,0)
32-
__, contours,hierarchy = cv2.findContours(thresh,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
32+
__, contours,hierarchy = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
3333
contours.sort(key=lambda c: cv2.contourArea(c), reverse=True)
3434

3535
possible_crosses = []
@@ -44,17 +44,16 @@ def find_cross(frame, draw_figs=True):
4444
if defects is not None and len(defects) == 4:
4545
possible_crosses.append(defects)
4646

47-
4847
if draw_figs:
4948
img = frame.copy()
5049
for defects in possible_crosses:
5150
for i in range(defects.shape[0]):
52-
s,e,f,d = defects[i,0]
51+
s, e, f, d = defects[i, 0]
5352
# start = tuple(cnt[s][0])
5453
# end = tuple(cnt[e][0])
5554
far = tuple(cnt[f][0])
5655
# cv2.line(img,start,end,[0,255,0],2)
57-
cv2.circle(img,far,5,[0,0,255],-1)
56+
cv2.circle(img, far, 5, [0, 0, 255], -1)
5857
cv2.imshow('cross at contour number ' + str(i),img)
5958
cv2.imshow('original', frame)
6059

@@ -64,15 +63,15 @@ def find_cross(frame, draw_figs=True):
6463
###########################################
6564
# Main Body
6665
###########################################
67-
66+
# TODO: port to vis
6867
if __name__ == "__main__":
68+
combined_filter = init_combined_filter()
69+
6970
ret_tries = 0
70-
while(1 and ret_tries < 50):
71+
while 1 and ret_tries < 50:
7172
# ret,frame = cap.read()
72-
73-
if ret == True:
73+
if ret:
7474
# frame = cv2.resize(frame, (0,0), fx=0.5, fy=0.5)
75-
7675
if thresholding == "multidim":
7776
votes1, threshed = filter_out_highest_peak_multidim(frame)
7877
threshed = cv2.morphologyEx(threshed, cv2.MORPH_OPEN, np.ones((5,5),np.uint8))
@@ -86,13 +85,9 @@ def find_cross(frame, draw_figs=True):
8685

8786
ret_tries = 0
8887
k = cv2.waitKey(60) & 0xff
89-
if k == 27: # esc
90-
if testing:
91-
print("hsv thresholds:")
92-
print(thresholds_used)
88+
if k == 27: # esc
9389
break
9490
else:
9591
ret_tries += 1
9692

9793
cv2.destroyAllWindows()
98-
cap.release()
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
1-
from perception.tasks.gate.GateSegmentation import GateSegmentationAlgo
1+
from perception.tasks.gate.GateSegmentationAlgoA import GateSegmentationAlgoA
22
from perception.tasks.TaskPerceiver import TaskPerceiver
33
from collections import namedtuple
4-
import sys
54

65
import numpy as np
76
import math
87
import cv2 as cv
9-
import time
108
import statistics
119

1210

13-
class GateCenter(TaskPerceiver):
11+
class GateCenterAlgo(TaskPerceiver):
1412
center_x_locs, center_y_locs = [], []
1513
output_class = namedtuple("GateOutput", ["centerx", "centery"])
1614
output_type = {'centerx': np.int16, 'centery': np.int16}
@@ -20,39 +18,43 @@ def __init__(self):
2018
self.gate_center = self.output_class(250, 250)
2119
self.use_optical_flow = False
2220
self.optical_flow_c = 0.1
23-
self.gate = GateSegmentationAlgo()
21+
self.gate = GateSegmentationAlgoA()
2422
self.prvs = None
2523

24+
# TODO: do input and return typing
2625
def analyze(self, frame, debug, slider_vals):
27-
self.optical_flow_c = slider_vals['optical_flow_c'] / 100
28-
rect1, rect2, debug_filter = self.gate.analyze(frame, True)
26+
self.optical_flow_c = slider_vals['optical_flow_c']/100
27+
rect, debug_filters = self.gate.analyze(frame, True)
28+
debug_filter = debug_filters[-1]
29+
debug_filters = debug_filters[:-1]
30+
2931
if self.prvs is None:
3032
# frame = cv.resize(frame, None, fx=0.3, fy=0.3)
3133
self.prvs = cv.cvtColor(frame, cv.COLOR_BGR2GRAY)
32-
else:
33-
if rect1 and rect2:
34-
self.gate_center = self.get_center(rect1, rect2, frame)
34+
else:
35+
if rect[0] and rect[1]:
36+
self.gate_center = self.get_center(rect[0], rect[1], frame)
3537
if self.use_optical_flow:
36-
cv.circle(debug_filter, self.gate_center, 5, (3, 186, 252), -1)
38+
cv.circle(debug_filter, self.gate_center, 5, (3,186,252), -1)
3739
else:
38-
cv.circle(debug_filter, self.gate_center, 5, (0, 0, 255), -1)
39-
40+
cv.circle(debug_filter, self.gate_center, 5, (0,0,255), -1)
41+
4042
if debug:
41-
return (self.output_class(self.gate_center[0], self.gate_center[1]), [frame, debug_filter])
42-
return self.output_class(self.gate_center[0], self.gate_center[1])
43+
return (self.gate_center[0], self.gate_center[1]), list(debug_filters) + [debug_filter]
44+
return (self.gate_center[0], self.gate_center[1])
4345

4446
def center_without_optical_flow(self, center_x, center_y):
4547
# get starting center location, averaging over the first 2510 frames
4648
if len(self.center_x_locs) == 0:
4749
self.center_x_locs.append(center_x)
4850
self.center_y_locs.append(center_y)
49-
51+
5052
elif len(self.center_x_locs) < 25:
5153
self.center_x_locs.append(center_x)
5254
self.center_y_locs.append(center_y)
5355
center_x = int(statistics.mean(self.center_x_locs))
5456
center_y = int(statistics.mean(self.center_y_locs))
55-
57+
5658
# use new center location only when it is close to the previous valid location
5759
else:
5860
self.center_x_locs.append(center_x)
@@ -61,11 +63,11 @@ def center_without_optical_flow(self, center_x, center_y):
6163
self.center_y_locs.pop(0)
6264
x_temp_avg = int(statistics.mean(self.center_x_locs))
6365
y_temp_avg = int(statistics.mean(self.center_y_locs))
64-
if math.sqrt((center_x - x_temp_avg) ** 2 + (center_y - y_temp_avg) ** 2) > 10:
66+
if math.sqrt((center_x - x_temp_avg)**2 + (center_y - y_temp_avg)**2) > 10:
6567
center_x, center_y = int(x_temp_avg), int(y_temp_avg)
66-
68+
6769
return (center_x, center_y)
68-
70+
6971
def dense_optical_flow(self, frame):
7072
next_frame = cv.cvtColor(frame, cv.COLOR_BGR2GRAY)
7173
flow = cv.calcOpticalFlowFarneback(self.prvs, next_frame, None, 0.5, 3, 15, 3, 5, 1.2, 0)
@@ -82,53 +84,16 @@ def get_center(self, rect1, rect2, frame):
8284
x2, y2, w2, h2 = rect2
8385
center_x, center_y = (x1 + x2) // 2, ((y1 + h1 // 2) + (y2 + h2 // 2)) // 2
8486
self.prvs, mag, ang = self.dense_optical_flow(frame)
85-
# print(np.mean(mag))
86-
if len(self.center_x_locs) < 25 or (np.mean(mag) < 40 and ((not self.use_optical_flow) or \
87-
(self.use_optical_flow and (
88-
center_x - self.gate_center[0]) ** 2 + (
89-
center_y - self.gate_center[
90-
1]) ** 2 < 50))):
87+
88+
if len(self.center_x_locs) < 25 or (np.mean(mag) < 40 and ((not self.use_optical_flow ) or \
89+
(self.use_optical_flow and (center_x - self.gate_center[0])**2 + (center_y - self.gate_center[1])**2 < 50))):
9190
self.use_optical_flow = False
9291
return self.center_without_optical_flow(center_x, center_y)
9392
self.use_optical_flow = True
9493
return (int(self.gate_center[0] + self.optical_flow_c * np.mean(mag * np.cos(ang))), \
9594
(int(self.gate_center[1] + self.optical_flow_c * np.mean(mag * np.sin(ang)))))
9695

9796

98-
# this part is temporary and will be covered by other files in the future
9997
if __name__ == '__main__':
100-
cap = cv.VideoCapture(sys.argv[1])
101-
ret_tries = 0
102-
start_time = time.time()
103-
frame_count = 0
104-
paused = False
105-
speed = 1
106-
ret, frame1 = cap.read()
107-
frame1 = cv.resize(frame1, None, fx=0.3, fy=0.3)
108-
prvs = cv.cvtColor(frame1, cv.COLOR_BGR2GRAY)
109-
hsv = np.zeros_like(frame1)
110-
hsv[..., 1] = 255
111-
gate_center = GateCenter()
112-
while ret_tries < 50:
113-
for _ in range(speed):
114-
ret, frame = cap.read()
115-
if frame_count == 1000:
116-
break
117-
if ret:
118-
frame = cv.resize(frame, None, fx=0.3, fy=0.3)
119-
center, filtered_frame = gate_center.analyze(frame, True)
120-
cv.imshow('original', frame)
121-
cv.imshow('filtered_frame', filtered_frame)
122-
ret_tries = 0
123-
key = cv.waitKey(30)
124-
if key == ord('q') or key == 27:
125-
break
126-
if key == ord('p'):
127-
paused = not paused
128-
if key == ord('i') and speed > 1:
129-
speed -= 1
130-
if key == ord('o'):
131-
speed += 1
132-
else:
133-
ret_tries += 1
134-
frame_count += 1
98+
from perception.vis.vis import run
99+
run(['..\..\..\data\GOPR1142.MP4'], GateCenterAlgo(), False)

0 commit comments

Comments
 (0)