1
1
import math
2
- from typing import Any , Dict , Sequence , Tuple
2
+ from typing import Any , Dict , Tuple
3
3
4
4
import gymnasium as gym
5
5
import numpy as np
6
6
7
7
from smarts .core .agent_interface import RGB
8
8
from smarts .core .colors import Colors , SceneColors
9
+ from smarts .core .utils .observations import points_to_pixels , replace_rgb_image_color
9
10
10
11
11
12
class FilterObs :
@@ -72,19 +73,19 @@ def filter(self, obs: Dict[str, Any]) -> Dict[str, Any]:
72
73
# Get rgb image, remove road, and replace other egos (if any) as background vehicles
73
74
rgb = obs ["top_down_rgb" ]
74
75
h , w , _ = rgb .shape
75
- rgb_noroad = replace_color (rgb = rgb , old_color = [self ._road_color , self ._lane_divider_color , self ._edge_divider_color ], new_color = self ._no_color )
76
- rgb_ego = replace_color (rgb = rgb_noroad , old_color = [self ._ego_color ], new_color = self ._traffic_color , mask = self ._rgb_mask )
76
+ rgb_noroad = replace_rgb_image_color (rgb = rgb , old_color = [self ._road_color , self ._lane_divider_color , self ._edge_divider_color ], new_color = self ._no_color )
77
+ rgb_ego = replace_rgb_image_color (rgb = rgb_noroad , old_color = [self ._ego_color ], new_color = self ._traffic_color , mask = self ._rgb_mask )
77
78
78
79
# Superimpose waypoints onto rgb image
79
80
wps = obs ["waypoint_paths" ]["position" ][0 :11 , 3 :, 0 :3 ]
80
81
for path in wps [:]:
81
82
wps_valid = points_to_pixels (
82
83
points = path ,
83
- ego_pos = ego_pos ,
84
- ego_heading = ego_heading ,
85
- w = w ,
86
- h = h ,
87
- res = self ._res ,
84
+ center_position = ego_pos ,
85
+ heading = ego_heading ,
86
+ width = w ,
87
+ height = h ,
88
+ resolution = self ._res ,
88
89
)
89
90
for point in wps_valid :
90
91
img_x , img_y = point [0 ], point [1 ]
@@ -95,11 +96,11 @@ def filter(self, obs: Dict[str, Any]) -> Dict[str, Any]:
95
96
if not all ((goal := obs ["ego_vehicle_state" ]["mission" ]["goal_position" ]) == np .zeros ((3 ,))):
96
97
goal_pixel = points_to_pixels (
97
98
points = np .expand_dims (goal ,axis = 0 ),
98
- ego_pos = ego_pos ,
99
- ego_heading = ego_heading ,
100
- w = w ,
101
- h = h ,
102
- res = self ._res ,
99
+ center_position = ego_pos ,
100
+ heading = ego_heading ,
101
+ width = w ,
102
+ height = h ,
103
+ resolution = self ._res ,
103
104
)
104
105
if len (goal_pixel ) != 0 :
105
106
img_x , img_y = goal_pixel [0 ][0 ], goal_pixel [0 ][1 ]
@@ -121,100 +122,3 @@ def filter(self, obs: Dict[str, Any]) -> Dict[str, Any]:
121
122
return filtered_obs
122
123
# fmt: on
123
124
124
-
125
- def replace_color (
126
- rgb : np .ndarray ,
127
- old_color : Sequence [np .ndarray ],
128
- new_color : np .ndarray ,
129
- mask : np .ndarray = np .ma .nomask ,
130
- ) -> np .ndarray :
131
- """Convert pixels of value `old_color` to `new_color` within the masked
132
- region in the received RGB image.
133
-
134
- Args:
135
- rgb (np.ndarray): RGB image. Shape = (m,n,3).
136
- old_color (Sequence[np.ndarray]): List of old colors to be removed from the RGB image. Shape = (3,).
137
- new_color (np.ndarray): New color to be added to the RGB image. Shape = (3,).
138
- mask (np.ndarray, optional): Valid regions for color replacement. Shape = (m,n,3).
139
- Defaults to np.ma.nomask .
140
-
141
- Returns:
142
- np.ndarray: RGB image with `old_color` pixels changed to `new_color`
143
- within the masked region. Shape = (m,n,3).
144
- """
145
- # fmt: off
146
- assert all (color .shape == (3 ,) for color in old_color ), (
147
- f"Expected old_color to be of shape (3,), but got { [color .shape for color in old_color ]} ." )
148
- assert new_color .shape == (3 ,), (
149
- f"Expected new_color to be of shape (3,), but got { new_color .shape } ." )
150
-
151
- nc = new_color .reshape ((1 , 1 , 3 ))
152
- nc_array = np .full_like (rgb , nc )
153
- rgb_masked = np .ma .MaskedArray (data = rgb , mask = mask )
154
-
155
- rgb_condition = rgb_masked
156
- result = rgb
157
- for color in old_color :
158
- result = np .ma .where ((rgb_condition == color .reshape ((1 , 1 , 3 ))).all (axis = - 1 )[..., None ], nc_array , result )
159
-
160
- return result
161
- # fmt: on
162
-
163
-
164
- def points_to_pixels (
165
- points : np .ndarray ,
166
- ego_pos : np .ndarray ,
167
- ego_heading : float ,
168
- w : int ,
169
- h : int ,
170
- res : float ,
171
- ) -> np .ndarray :
172
- """Converts points into pixel coordinates in order to superimpose the
173
- points onto the RGB image.
174
-
175
- Args:
176
- points (np.ndarray): Array of points. Shape (n,3).
177
- ego_pos (np.ndarray): Ego position. Shape = (3,).
178
- ego_heading (float): Ego heading in radians.
179
- w (int): Width of RGB image
180
- h (int): Height of RGB image.
181
- res (float): Resolution of RGB image in meters/pixels. Computed as
182
- ground_size/image_size.
183
-
184
- Returns:
185
- np.ndarray: Array of point coordinates on the RGB image. Shape = (m,3).
186
- """
187
- # fmt: off
188
- mask = [False if all (point == np .zeros (3 ,)) else True for point in points ]
189
- points_nonzero = points [mask ]
190
- points_delta = points_nonzero - ego_pos
191
- points_rotated = rotate_axes (points_delta , theta = ego_heading )
192
- points_pixels = points_rotated / np .array ([res , res , res ])
193
- points_overlay = np .array ([w / 2 , h / 2 , 0 ]) + points_pixels * np .array ([1 , - 1 , 1 ])
194
- points_rfloat = np .rint (points_overlay )
195
- points_valid = points_rfloat [(points_rfloat [:,0 ] >= 0 ) & (points_rfloat [:,0 ] < w ) & (points_rfloat [:,1 ] >= 0 ) & (points_rfloat [:,1 ] < h )]
196
- points_rint = points_valid .astype (int )
197
- return points_rint
198
- # fmt: on
199
-
200
-
201
- def rotate_axes (points : np .ndarray , theta : float ) -> np .ndarray :
202
- """A counterclockwise rotation of the x-y axes by an angle theta θ about
203
- the z-axis.
204
-
205
- Args:
206
- points (np.ndarray): x,y,z coordinates in original axes. Shape = (n,3).
207
- theta (np.float): Axes rotation angle in radians.
208
-
209
- Returns:
210
- np.ndarray: x,y,z coordinates in rotated axes. Shape = (n,3).
211
- """
212
- # fmt: off
213
- theta = (theta + np .pi ) % (2 * np .pi ) - np .pi
214
- ct , st = np .cos (theta ), np .sin (theta )
215
- R = np .array ([[ ct , st , 0 ],
216
- [- st , ct , 0 ],
217
- [ 0 , 0 , 1 ]])
218
- rotated_points = (R .dot (points .T )).T
219
- return rotated_points
220
- # fmt: on
0 commit comments