diff --git a/examples/rl/drive/inference/contrib_policy/policy.py b/examples/rl/drive/inference/contrib_policy/policy.py index 8287c40d3a..f63a72e70c 100644 --- a/examples/rl/drive/inference/contrib_policy/policy.py +++ b/examples/rl/drive/inference/contrib_policy/policy.py @@ -45,7 +45,7 @@ def act(self, obs): processed_obs = self._process(obs) action, _ = self.model.predict(observation=processed_obs, deterministic=True) formatted_action = self._format_action.format( - model_action=action, prev_heading=obs["ego_vehicle_state"]["heading"] + action=int(action), prev_heading=obs["ego_vehicle_state"]["heading"] ) return formatted_action diff --git a/examples/rl/platoon/inference/contrib_policy/filter_obs.py b/examples/rl/platoon/inference/contrib_policy/filter_obs.py index 4e521932d5..6bd1f3d259 100644 --- a/examples/rl/platoon/inference/contrib_policy/filter_obs.py +++ b/examples/rl/platoon/inference/contrib_policy/filter_obs.py @@ -21,7 +21,6 @@ def __init__(self, top_down_rgb: RGB): self._no_color = np.zeros((3,)) self._wps_color = np.array(Colors.GreenTransparent.value[0:3]) * 255 - self._leader_color = np.array(SceneColors.SocialAgent.value[0:3]) * 255 self._traffic_color = np.array(SceneColors.SocialVehicle.value[0:3]) * 255 self._road_color = np.array(SceneColors.Road.value[0:3]) * 255 self._lane_divider_color = np.array(SceneColors.LaneDivider.value[0:3]) * 255 diff --git a/examples/rl/platoon/inference/contrib_policy/policy.py b/examples/rl/platoon/inference/contrib_policy/policy.py index 215e23b905..10d81044d4 100644 --- a/examples/rl/platoon/inference/contrib_policy/policy.py +++ b/examples/rl/platoon/inference/contrib_policy/policy.py @@ -44,7 +44,7 @@ def act(self, obs): """Mandatory act function to be implemented by user.""" processed_obs = self._process(obs) action, _ = self.model.predict(observation=processed_obs, deterministic=True) - formatted_action = self._format_action.format(action) + formatted_action = self._format_action.format(action=int(action)) return formatted_action def _process(self, obs): diff --git a/examples/rl/platoon/train/preprocess.py b/examples/rl/platoon/train/preprocess.py index b45ac8d1a2..5d9079a9d9 100644 --- a/examples/rl/platoon/train/preprocess.py +++ b/examples/rl/platoon/train/preprocess.py @@ -24,7 +24,6 @@ def __init__(self, env: gym.Env, agent_interface: AgentInterface): self._format_action = FormatAction(agent_interface.action) self.action_space = self._format_action.action_space - print("Policy initialised.") def _process(self, obs): obs = self._filter_obs.filter(obs)