Skip to content

Commit 9fa8e28

Browse files
authored
Use common observation utility functions. (#2141)
1 parent e38fbca commit 9fa8e28

File tree

3 files changed

+23
-117
lines changed

3 files changed

+23
-117
lines changed

examples/e10_drive/inference/contrib_policy/filter_obs.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import math
2-
from typing import Any, Dict, Sequence, Tuple
2+
from typing import Any, Dict, Tuple
33

44
import gymnasium as gym
55
import numpy as np
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import math
2-
from typing import Any, Dict, Sequence, Tuple
2+
from typing import Any, Dict, Tuple
33

44
import gymnasium as gym
55
import numpy as np
66

77
from smarts.core.agent_interface import RGB
88
from smarts.core.colors import Colors, SceneColors
9+
from smarts.core.utils.observations import points_to_pixels, replace_rgb_image_color
910

1011

1112
class FilterObs:
@@ -72,19 +73,19 @@ def filter(self, obs: Dict[str, Any]) -> Dict[str, Any]:
7273
# Get rgb image, remove road, and replace other egos (if any) as background vehicles
7374
rgb = obs["top_down_rgb"]
7475
h, w, _ = rgb.shape
75-
rgb_noroad = replace_color(rgb=rgb, old_color=[self._road_color, self._lane_divider_color, self._edge_divider_color], new_color=self._no_color)
76-
rgb_ego = replace_color(rgb=rgb_noroad, old_color=[self._ego_color], new_color=self._traffic_color, mask=self._rgb_mask)
76+
rgb_noroad = replace_rgb_image_color(rgb=rgb, old_color=[self._road_color, self._lane_divider_color, self._edge_divider_color], new_color=self._no_color)
77+
rgb_ego = replace_rgb_image_color(rgb=rgb_noroad, old_color=[self._ego_color], new_color=self._traffic_color, mask=self._rgb_mask)
7778

7879
# Superimpose waypoints onto rgb image
7980
wps = obs["waypoint_paths"]["position"][0:11, 3:, 0:3]
8081
for path in wps[:]:
8182
wps_valid = points_to_pixels(
8283
points=path,
83-
ego_pos=ego_pos,
84-
ego_heading=ego_heading,
85-
w=w,
86-
h=h,
87-
res=self._res,
84+
center_position=ego_pos,
85+
heading=ego_heading,
86+
width=w,
87+
height=h,
88+
resolution=self._res,
8889
)
8990
for point in wps_valid:
9091
img_x, img_y = point[0], point[1]
@@ -95,11 +96,11 @@ def filter(self, obs: Dict[str, Any]) -> Dict[str, Any]:
9596
if not all((goal:=obs["ego_vehicle_state"]["mission"]["goal_position"]) == np.zeros((3,))):
9697
goal_pixel = points_to_pixels(
9798
points=np.expand_dims(goal,axis=0),
98-
ego_pos=ego_pos,
99-
ego_heading=ego_heading,
100-
w=w,
101-
h=h,
102-
res=self._res,
99+
center_position=ego_pos,
100+
heading=ego_heading,
101+
width=w,
102+
height=h,
103+
resolution=self._res,
103104
)
104105
if len(goal_pixel) != 0:
105106
img_x, img_y = goal_pixel[0][0], goal_pixel[0][1]
@@ -121,100 +122,3 @@ def filter(self, obs: Dict[str, Any]) -> Dict[str, Any]:
121122
return filtered_obs
122123
# fmt: on
123124

124-
125-
def replace_color(
126-
rgb: np.ndarray,
127-
old_color: Sequence[np.ndarray],
128-
new_color: np.ndarray,
129-
mask: np.ndarray = np.ma.nomask,
130-
) -> np.ndarray:
131-
"""Convert pixels of value `old_color` to `new_color` within the masked
132-
region in the received RGB image.
133-
134-
Args:
135-
rgb (np.ndarray): RGB image. Shape = (m,n,3).
136-
old_color (Sequence[np.ndarray]): List of old colors to be removed from the RGB image. Shape = (3,).
137-
new_color (np.ndarray): New color to be added to the RGB image. Shape = (3,).
138-
mask (np.ndarray, optional): Valid regions for color replacement. Shape = (m,n,3).
139-
Defaults to np.ma.nomask .
140-
141-
Returns:
142-
np.ndarray: RGB image with `old_color` pixels changed to `new_color`
143-
within the masked region. Shape = (m,n,3).
144-
"""
145-
# fmt: off
146-
assert all(color.shape == (3,) for color in old_color), (
147-
f"Expected old_color to be of shape (3,), but got {[color.shape for color in old_color]}.")
148-
assert new_color.shape == (3,), (
149-
f"Expected new_color to be of shape (3,), but got {new_color.shape}.")
150-
151-
nc = new_color.reshape((1, 1, 3))
152-
nc_array = np.full_like(rgb, nc)
153-
rgb_masked = np.ma.MaskedArray(data=rgb, mask=mask)
154-
155-
rgb_condition = rgb_masked
156-
result = rgb
157-
for color in old_color:
158-
result = np.ma.where((rgb_condition == color.reshape((1, 1, 3))).all(axis=-1)[..., None], nc_array, result)
159-
160-
return result
161-
# fmt: on
162-
163-
164-
def points_to_pixels(
165-
points: np.ndarray,
166-
ego_pos: np.ndarray,
167-
ego_heading: float,
168-
w: int,
169-
h: int,
170-
res: float,
171-
) -> np.ndarray:
172-
"""Converts points into pixel coordinates in order to superimpose the
173-
points onto the RGB image.
174-
175-
Args:
176-
points (np.ndarray): Array of points. Shape (n,3).
177-
ego_pos (np.ndarray): Ego position. Shape = (3,).
178-
ego_heading (float): Ego heading in radians.
179-
w (int): Width of RGB image
180-
h (int): Height of RGB image.
181-
res (float): Resolution of RGB image in meters/pixels. Computed as
182-
ground_size/image_size.
183-
184-
Returns:
185-
np.ndarray: Array of point coordinates on the RGB image. Shape = (m,3).
186-
"""
187-
# fmt: off
188-
mask = [False if all(point == np.zeros(3,)) else True for point in points]
189-
points_nonzero = points[mask]
190-
points_delta = points_nonzero - ego_pos
191-
points_rotated = rotate_axes(points_delta, theta=ego_heading)
192-
points_pixels = points_rotated / np.array([res, res, res])
193-
points_overlay = np.array([w / 2, h / 2, 0]) + points_pixels * np.array([1, -1, 1])
194-
points_rfloat = np.rint(points_overlay)
195-
points_valid = points_rfloat[(points_rfloat[:,0] >= 0) & (points_rfloat[:,0] < w) & (points_rfloat[:,1] >= 0) & (points_rfloat[:,1] < h)]
196-
points_rint = points_valid.astype(int)
197-
return points_rint
198-
# fmt: on
199-
200-
201-
def rotate_axes(points: np.ndarray, theta: float) -> np.ndarray:
202-
"""A counterclockwise rotation of the x-y axes by an angle theta θ about
203-
the z-axis.
204-
205-
Args:
206-
points (np.ndarray): x,y,z coordinates in original axes. Shape = (n,3).
207-
theta (np.float): Axes rotation angle in radians.
208-
209-
Returns:
210-
np.ndarray: x,y,z coordinates in rotated axes. Shape = (n,3).
211-
"""
212-
# fmt: off
213-
theta = (theta + np.pi) % (2 * np.pi) - np.pi
214-
ct, st = np.cos(theta), np.sin(theta)
215-
R = np.array([[ ct, st, 0],
216-
[-st, ct, 0],
217-
[ 0, 0, 1]])
218-
rotated_points = (R.dot(points.T)).T
219-
return rotated_points
220-
# fmt: on

smarts/core/utils/observations.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,14 @@ def points_to_pixels(
7676
7777
Args:
7878
points (np.ndarray): Array of points. Shape (n,3).
79-
ego_pos (np.ndarray): Ego position. Shape = (3,).
80-
ego_heading (float): Ego heading in radians.
81-
w (int): Width of RGB image
82-
h (int): Height of RGB image.
83-
res (float): Resolution of RGB image in meters/pixels. Computed as
84-
ground_size/image_size.
79+
center_position (np.ndarray): Center position of image. Generally, this
80+
is equivalent to ego position. Shape = (3,).
81+
heading (float): Heading of image in radians. Generally, this is
82+
equivalent to ego heading.
83+
width (int): Width of RGB image
84+
height (int): Height of RGB image.
85+
resolution (float): Resolution of RGB image in meters/pixels. Computed
86+
as ground_size/image_size.
8587
8688
Returns:
8789
np.ndarray: Array of point coordinates on the RGB image. Shape = (m,3).

0 commit comments

Comments
 (0)