From 6a5db2e8ca0eb2fc2f792863a8c467c625bebb2a Mon Sep 17 00:00:00 2001 From: adai Date: Fri, 21 Apr 2023 23:12:46 +0000 Subject: [PATCH] GitHub Actions: Format --- .../v2023/metric_formula_platoon.py | 6 ++--- smarts/env/gymnasium/wrappers/metric/costs.py | 11 ++++---- smarts/env/gymnasium/wrappers/metric/utils.py | 27 +++++++++---------- 3 files changed, 20 insertions(+), 24 deletions(-) diff --git a/smarts/benchmark/driving_smarts/v2023/metric_formula_platoon.py b/smarts/benchmark/driving_smarts/v2023/metric_formula_platoon.py index 74fe098854..e3e496eb33 100644 --- a/smarts/benchmark/driving_smarts/v2023/metric_formula_platoon.py +++ b/smarts/benchmark/driving_smarts/v2023/metric_formula_platoon.py @@ -66,7 +66,7 @@ def params(self) -> Params: gap_between_vehicles=GapBetweenVehicles( active=False, interest="Leader-007", - ), # TODO: Activate after implementing gap_between_vehicles cost function. + ), # TODO: Activate after implementing gap_between_vehicles cost function. steps=Steps( active=False, ), @@ -143,9 +143,7 @@ def score(self, records_sum: Dict[str, Dict[str, Record]]) -> Score: def _humanness(costs: Costs) -> float: - humanness = np.array( - [costs.comfort, costs.lane_center_offset] - ) + humanness = np.array([costs.comfort, costs.lane_center_offset]) humanness = np.mean(humanness, dtype=float) return 1 - humanness diff --git a/smarts/env/gymnasium/wrappers/metric/costs.py b/smarts/env/gymnasium/wrappers/metric/costs.py index 8997dbea8b..75a94bcea9 100644 --- a/smarts/env/gymnasium/wrappers/metric/costs.py +++ b/smarts/env/gymnasium/wrappers/metric/costs.py @@ -30,9 +30,8 @@ from smarts.core.utils.math import running_mean from smarts.core.vehicle_index import VehicleIndex from smarts.env.gymnasium.wrappers.metric.params import Params -from smarts.env.gymnasium.wrappers.metric.utils import SlidingWindow from smarts.env.gymnasium.wrappers.metric.types import Costs - +from smarts.env.gymnasium.wrappers.metric.utils import SlidingWindow Done = NewType("Done", bool) @@ -52,8 +51,8 @@ def func(road_map: RoadMap, done: Done, obs: Observation) -> Costs: def _comfort() -> Callable[[RoadMap, Done, Observation], Costs]: jerk_linear_max = np.linalg.norm(np.array([0.9, 0.9, 0])) # Units: m/s^3 - acc_linear_max = np.linalg.norm(np.array([2.0,1.47,0])) # Units: m/s^2 - T_p = 30 # Penalty time steps = penalty time / delta time step = 3s / 0.1s = 30 + acc_linear_max = np.linalg.norm(np.array([2.0, 1.47, 0])) # Units: m/s^2 + T_p = 30 # Penalty time steps = penalty time / delta time step = 3s / 0.1s = 30 T_u = 0 step = 0 dyn_window = SlidingWindow(size=T_p) @@ -65,12 +64,12 @@ def func(road_map: RoadMap, done: Done, obs: Observation) -> Costs: jerk_linear = np.linalg.norm(obs.ego_vehicle_state.linear_jerk) acc_linear = np.linalg.norm(obs.ego_vehicle_state.linear_acceleration) - dyn = max(jerk_linear/jerk_linear_max, acc_linear/acc_linear_max) + dyn = max(jerk_linear / jerk_linear_max, acc_linear / acc_linear_max) dyn_window.move(dyn) u_t = 1 if dyn_window.max() > 1 else 0 T_u += u_t - + if not done: return Costs(comfort=-1) else: diff --git a/smarts/env/gymnasium/wrappers/metric/utils.py b/smarts/env/gymnasium/wrappers/metric/utils.py index 2aecca2d1d..93cdc1affe 100644 --- a/smarts/env/gymnasium/wrappers/metric/utils.py +++ b/smarts/env/gymnasium/wrappers/metric/utils.py @@ -100,9 +100,10 @@ def multiply(value: Union[int, float], multiplier: Union[int, float]) -> float: class SlidingWindow: """A sliding window which moves to the right by accepting new elements. The maximum value within the sliding window can be queried at anytime by calling - the max() method. + the max() method. """ - def __init__(self, size:int): + + def __init__(self, size: int): """ Args: size (int): Size of the sliding window. @@ -112,7 +113,7 @@ def __init__(self, size:int): self._size = size self._time = -1 - def move(self, x:Union[int,float]): + def move(self, x: Union[int, float]): """Moves the sliding window one step to the right by appending the new element x and discarding the oldest element on the left. @@ -121,29 +122,27 @@ def move(self, x:Union[int,float]): """ self._time += 1 - # When values deque is full, remove head element of max_candidates deque + # When values deque is full, remove head element of max_candidates deque # if it matches head element of values deque. if len(self._values) == self._size: if self._values[0][0] == self._max_candidates[0][0]: self._max_candidates.popleft() # Append x to values deque. - self._values.append((self._time,x)) + self._values.append((self._time, x)) - # Remove elements from max_candidates deque's tail which are less than x. + # Remove elements from max_candidates deque's tail which are less than x. while self._max_candidates and self._max_candidates[-1][1] < x: self._max_candidates.pop() # Append x to max_candidates deque. - self._max_candidates.append((self._time,x)) - + self._max_candidates.append((self._time, x)) + def max(self): - """ Returns the maximum element within the sliding window. - """ + """Returns the maximum element within the sliding window.""" return self._max_candidates[0][1] - + def display(self): - """Print the contents of the sliding window. - """ + """Print the contents of the sliding window.""" print("[", end="") for i in self._values: - print(i, end=' ') + print(i, end=" ") print("]")