Skip to content

Commit

Permalink
GitHub Actions: Format
Browse files Browse the repository at this point in the history
  • Loading branch information
Adaickalavan committed Apr 21, 2023
1 parent 72f9384 commit 6a5db2e
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def params(self) -> Params:
gap_between_vehicles=GapBetweenVehicles(
active=False,
interest="Leader-007",
), # TODO: Activate after implementing gap_between_vehicles cost function.
), # TODO: Activate after implementing gap_between_vehicles cost function.
steps=Steps(
active=False,
),
Expand Down Expand Up @@ -143,9 +143,7 @@ def score(self, records_sum: Dict[str, Dict[str, Record]]) -> Score:


def _humanness(costs: Costs) -> float:
humanness = np.array(
[costs.comfort, costs.lane_center_offset]
)
humanness = np.array([costs.comfort, costs.lane_center_offset])
humanness = np.mean(humanness, dtype=float)
return 1 - humanness

Expand Down
11 changes: 5 additions & 6 deletions smarts/env/gymnasium/wrappers/metric/costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,8 @@
from smarts.core.utils.math import running_mean
from smarts.core.vehicle_index import VehicleIndex
from smarts.env.gymnasium.wrappers.metric.params import Params
from smarts.env.gymnasium.wrappers.metric.utils import SlidingWindow
from smarts.env.gymnasium.wrappers.metric.types import Costs

from smarts.env.gymnasium.wrappers.metric.utils import SlidingWindow

Done = NewType("Done", bool)

Expand All @@ -52,8 +51,8 @@ def func(road_map: RoadMap, done: Done, obs: Observation) -> Costs:

def _comfort() -> Callable[[RoadMap, Done, Observation], Costs]:
jerk_linear_max = np.linalg.norm(np.array([0.9, 0.9, 0])) # Units: m/s^3
acc_linear_max = np.linalg.norm(np.array([2.0,1.47,0])) # Units: m/s^2
T_p = 30 # Penalty time steps = penalty time / delta time step = 3s / 0.1s = 30
acc_linear_max = np.linalg.norm(np.array([2.0, 1.47, 0])) # Units: m/s^2
T_p = 30 # Penalty time steps = penalty time / delta time step = 3s / 0.1s = 30
T_u = 0
step = 0
dyn_window = SlidingWindow(size=T_p)
Expand All @@ -65,12 +64,12 @@ def func(road_map: RoadMap, done: Done, obs: Observation) -> Costs:

jerk_linear = np.linalg.norm(obs.ego_vehicle_state.linear_jerk)
acc_linear = np.linalg.norm(obs.ego_vehicle_state.linear_acceleration)
dyn = max(jerk_linear/jerk_linear_max, acc_linear/acc_linear_max)
dyn = max(jerk_linear / jerk_linear_max, acc_linear / acc_linear_max)

dyn_window.move(dyn)
u_t = 1 if dyn_window.max() > 1 else 0
T_u += u_t

if not done:
return Costs(comfort=-1)
else:
Expand Down
27 changes: 13 additions & 14 deletions smarts/env/gymnasium/wrappers/metric/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,10 @@ def multiply(value: Union[int, float], multiplier: Union[int, float]) -> float:
class SlidingWindow:
"""A sliding window which moves to the right by accepting new elements. The
maximum value within the sliding window can be queried at anytime by calling
the max() method.
the max() method.
"""
def __init__(self, size:int):

def __init__(self, size: int):
"""
Args:
size (int): Size of the sliding window.
Expand All @@ -112,7 +113,7 @@ def __init__(self, size:int):
self._size = size
self._time = -1

def move(self, x:Union[int,float]):
def move(self, x: Union[int, float]):
"""Moves the sliding window one step to the right by appending the new
element x and discarding the oldest element on the left.
Expand All @@ -121,29 +122,27 @@ def move(self, x:Union[int,float]):
"""
self._time += 1

# When values deque is full, remove head element of max_candidates deque
# When values deque is full, remove head element of max_candidates deque
# if it matches head element of values deque.
if len(self._values) == self._size:
if self._values[0][0] == self._max_candidates[0][0]:
self._max_candidates.popleft()
# Append x to values deque.
self._values.append((self._time,x))
self._values.append((self._time, x))

# Remove elements from max_candidates deque's tail which are less than x.
# Remove elements from max_candidates deque's tail which are less than x.
while self._max_candidates and self._max_candidates[-1][1] < x:
self._max_candidates.pop()
# Append x to max_candidates deque.
self._max_candidates.append((self._time,x))
self._max_candidates.append((self._time, x))

def max(self):
""" Returns the maximum element within the sliding window.
"""
"""Returns the maximum element within the sliding window."""
return self._max_candidates[0][1]

def display(self):
"""Print the contents of the sliding window.
"""
"""Print the contents of the sliding window."""
print("[", end="")
for i in self._values:
print(i, end=' ')
print(i, end=" ")
print("]")

0 comments on commit 6a5db2e

Please sign in to comment.