Update RL examples: Drive and VehicleFollowing (#2061)
Adaickalavan authored Jun 12, 2023
1 parent 0dea7af commit 1ad324b
Showing 28 changed files with 593 additions and 768 deletions.
22 changes: 19 additions & 3 deletions docs/benchmarks/driving_smarts_2023_1.rst
@@ -48,9 +48,24 @@ predictive models, etc, may be used to develop the policy.
Several scenarios are provided for training; their names and tasks are listed below.
The desired execution of each task is illustrated by a GIF of a trained baseline agent.

.. todo::

Provide sample training scenarios and corresponding ``.gif`` images showing a baseline model traversing the map.
**Driving SMARTS 2023.1 scenarios**

+ :scenarios:`cruise_2lane_agents_1 <sumo/straight/cruise_2lane_agents_1>`
+ :scenarios:`cutin_2lane_agents_1 <sumo/straight/cutin_2lane_agents_1>`
+ :scenarios:`merge_exit_sumo_t_agents_1 <sumo/straight/merge_exit_sumo_t_agents_1>`
+ :scenarios:`overtake_2lane_agents_1 <sumo/straight/overtake_2lane_agents_1>`
+ :scenarios:`00a445fb-7293-4be6-adbc-e30c949b6cf7_agents_1 <argoverse/straight/00a445fb-7293-4be6-adbc-e30c949b6cf7_agents_1>`
+ :scenarios:`0a53dd99-2946-4b4d-ab66-c4d6fef97be2_agents_1 <argoverse/straight/0a53dd99-2946-4b4d-ab66-c4d6fef97be2_agents_1>`
+ :scenarios:`0a576bf1-66ae-495a-9c87-236f3fc2aa01_agents_1 <argoverse/straight/0a576bf1-66ae-495a-9c87-236f3fc2aa01_agents_1>`

**Driving SMARTS 2023.2 scenarios**

+ :scenarios:`1_to_3lane_left_turn_sumo_c_agents_1 <sumo/intersections/1_to_3lane_left_turn_sumo_c_agents_1>`
+ :scenarios:`1_to_3lane_left_turn_middle_lane_c_agents_1 <sumo/intersections/1_to_3lane_left_turn_middle_lane_c_agents_1>`
+ :scenarios:`00b15e74-04a8-4bd4-9a78-eb24f0c0a980_agents_1 <argoverse/turn/00b15e74-04a8-4bd4-9a78-eb24f0c0a980_agents_1>`
+ :scenarios:`0a60b442-56b0-46c3-be45-cf166a182b67_agents_1 <argoverse/turn/0a60b442-56b0-46c3-be45-cf166a182b67_agents_1>`
+ :scenarios:`0a764a82-b44e-481e-97e7-05e1f1f925f6_agents_1 <argoverse/turn/0a764a82-b44e-481e-97e7-05e1f1f925f6_agents_1>`
+ :scenarios:`0bf054e3-7698-4b86-9c98-626df2dee9f4_agents_1 <argoverse/turn/0bf054e3-7698-4b86-9c98-626df2dee9f4_agents_1>`

Observation space
-----------------
@@ -264,6 +279,7 @@ Evaluate
$ python3.8 -m venv ./.venv
$ source ./.venv/bin/activate
$ pip install --upgrade pip
$ pip install wheel==0.38.4
$ pip install -e .[camera_obs,argoverse,envision,sumo]
$ scl zoo install examples/rl/drive/inference
# For Driving SMARTS 2023.1
15 changes: 15 additions & 0 deletions docs/benchmarks/driving_smarts_2023_3.rst
@@ -42,6 +42,20 @@ The desired task execution is illustrated in a gif by a trained baseline agent.

.. image:: /_static/driving_smarts_2023/platoon_straight_2lane_agents_1.gif

+ :scenarios:`straight_2lane_sumo_agents_1 <sumo/vehicle_following/straight_2lane_sumo_agents_1>`
+ :scenarios:`straight_2lane_sumo_t_agents_1 <sumo/vehicle_following/straight_2lane_sumo_t_agents_1>`
+ :scenarios:`straight_3lanes_sumo_agents_1 <sumo/vehicle_following/straight_3lanes_sumo_agents_1>`
+ :scenarios:`straight_3lanes_sumo_t_agents_1 <sumo/vehicle_following/straight_3lanes_sumo_t_agents_1>`
+ :scenarios:`straight_3lanes_sumo_t_agents_2 <sumo/vehicle_following/straight_3lanes_sumo_t_agents_2>`
+ :scenarios:`merge_exit_sumo_agents_1 <sumo/vehicle_following/merge_exit_sumo_agents_1>`
+ :scenarios:`merge_exit_sumo_t_agents_1 <sumo/vehicle_following/merge_exit_sumo_t_agents_1>`
+ :scenarios:`merge_exit_sumo_t_agents_2 <sumo/vehicle_following/merge_exit_sumo_t_agents_2>`
+ :scenarios:`ff239c9d-e4ff-4acc-bad5-bd55648c212e_0_agents_1 <argoverse/vehicle_following/ff239c9d-e4ff-4acc-bad5-bd55648c212e_0_agents_1>`
+ :scenarios:`ff239c9d-e4ff-4acc-bad5-bd55648c212e_agents_1 <argoverse/vehicle_following/ff239c9d-e4ff-4acc-bad5-bd55648c212e_agents_1>`
+ :scenarios:`ff6dc43b-dd27-4fe4-94b6-5c1b3940daed_agents_1 <argoverse/vehicle_following/ff6dc43b-dd27-4fe4-94b6-5c1b3940daed_agents_1>`
+ :scenarios:`ff9619b5-b0c0-4942-b5d8-df6a5814f8a2_agents_1 <argoverse/vehicle_following/ff9619b5-b0c0-4942-b5d8-df6a5814f8a2_agents_1>`
+ :scenarios:`ffd10ec2-715b-48af-a89d-b11f79927f63_agents_1 <argoverse/vehicle_following/ffd10ec2-715b-48af-a89d-b11f79927f63_agents_1>`

Observation space
-----------------

@@ -254,6 +268,7 @@ Evaluate
$ python3.8 -m venv ./.venv
$ source ./.venv/bin/activate
$ pip install --upgrade pip
$ pip install wheel==0.38.4
$ pip install -e .[camera_obs,argoverse,envision,sumo]
$ scl zoo install examples/rl/platoon/inference
$ scl benchmark run driving_smarts_2023_3 examples.rl.platoon.inference:contrib-agent-v0 --auto-install
7 changes: 4 additions & 3 deletions examples/rl/drive/inference/__init__.py
@@ -16,16 +16,17 @@ def entry_point(**kwargs):
road_waypoints=False,
signals=False,
top_down_rgb=RGB(
width=128,
height=128,
resolution=80 / 128, # m/pixels
width=200,
height=200,
resolution=80 / 200, # m/pixels
),
)

agent_params = {
"top_down_rgb": interface.top_down_rgb,
"action_space_type": interface.action,
"num_stack": 3, # Number of frames to stack as input to policy network.
"crop": (50, 50, 0, 70), # Crop image from left, right, top, and bottom. Units: pixels.
}

return AgentSpec(
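Note on the new crop parameter above: it is given as (left, right, top, bottom) in pixels. As a minimal sketch (not the repository's code), assuming a standard (height, width, channel) NumPy image, the crop shrinks the frame as follows:

    import numpy as np

    # Sketch: apply crop=(left, right, top, bottom) to an RGB frame.
    crop = (50, 50, 0, 70)                           # pixels, as in agent_params
    image = np.zeros((200, 200, 3), dtype=np.uint8)  # height=200, width=200
    h, w = image.shape[0], image.shape[1]
    cropped = image[crop[2] : h - crop[3], crop[0] : w - crop[1], :]
    print(cropped.shape)                             # (130, 100, 3)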
85 changes: 64 additions & 21 deletions examples/rl/drive/inference/contrib_policy/filter_obs.py
@@ -1,5 +1,5 @@
import math
from typing import Any, Dict, Sequence
from typing import Any, Dict, Sequence, Tuple

import gym
import numpy as np
@@ -11,11 +11,17 @@
class FilterObs:
"""Filter only the selected observation parameters."""

def __init__(self, top_down_rgb: RGB):
def __init__(
self, top_down_rgb: RGB, crop: Tuple[int, int, int, int] = (0, 0, 0, 0)
):
self.observation_space = gym.spaces.Box(
low=0,
high=255,
shape=(3, top_down_rgb.height, top_down_rgb.width),
shape=(
3,
top_down_rgb.height - crop[2] - crop[3],
top_down_rgb.width - crop[0] - crop[1],
),
dtype=np.uint8,
)

@@ -26,7 +32,9 @@ def __init__(self, top_down_rgb: RGB):
self._lane_divider_color = np.array(SceneColors.LaneDivider.value[0:3]) * 255
self._edge_divider_color = np.array(SceneColors.EdgeDivider.value[0:3]) * 255
self._ego_color = np.array(SceneColors.Agent.value[0:3]) * 255
self._goal_color = np.array(Colors.Purple.value[0:3]) * 255

self._blur_radius = 2
self._res = top_down_rgb.resolution
h = top_down_rgb.height
w = top_down_rgb.width
@@ -43,6 +51,14 @@ def __init__(self, top_down_rgb: RGB):
self._rgb_mask = np.zeros(shape=(h, w, 3), dtype=np.uint8)
self._rgb_mask[shape[0][0] : shape[0][1], shape[1][0] : shape[1][1], :] = 1

self._crop = crop
assert (
self._crop[0] < np.floor(w / 2)
and self._crop[1] < np.floor(w / 2)
and self._crop[2] < np.floor(h / 2)
and self._crop[3] < np.floor(h / 2)
), f"Expected crop to be less than half the image size, got crop={self._crop} for an image of size[h,w]=[{h},{w}]."

def filter(self, obs: Dict[str, Any]) -> Dict[str, Any]:
"""Adapts the environment's observation."""
# fmt: off
@@ -62,8 +78,8 @@ def filter(self, obs: Dict[str, Any]) -> Dict[str, Any]:
# Superimpose waypoints onto rgb image
wps = obs["waypoint_paths"]["position"][0:11, 3:, 0:3]
for path in wps[:]:
wps_valid = wps_to_pixels(
wps=path,
wps_valid = points_to_pixels(
points=path,
ego_pos=ego_pos,
ego_heading=ego_heading,
w=w,
Expand All @@ -75,6 +91,28 @@ def filter(self, obs: Dict[str, Any]) -> Dict[str, Any]:
if all(rgb_ego[img_y, img_x, :] == self._no_color):
rgb_ego[img_y, img_x, :] = self._wps_color

# Superimpose goal position onto rgb image
if not all((goal:=obs["ego_vehicle_state"]["mission"]["goal_position"]) == np.zeros((3,))):
goal_pixel = points_to_pixels(
points=np.expand_dims(goal,axis=0),
ego_pos=ego_pos,
ego_heading=ego_heading,
w=w,
h=h,
res=self._res,
)
if len(goal_pixel) != 0:
img_x, img_y = goal_pixel[0][0], goal_pixel[0][1]
if all(rgb_ego[img_y, img_x, :] == self._no_color) or all(rgb_ego[img_y, img_x, :] == self._wps_color):
rgb_ego[
max(img_y-self._blur_radius,0):min(img_y+self._blur_radius,h),
max(img_x-self._blur_radius,0):min(img_x+self._blur_radius,w),
:,
] = self._goal_color

# Crop image
rgb_ego = rgb_ego[self._crop[2]:h-self._crop[3],self._crop[0]:w-self._crop[1],:]

# Channel first rgb
rgb_ego = rgb_ego.transpose(2, 0, 1)

@@ -123,14 +161,19 @@ def replace_color(
# fmt: on


def wps_to_pixels(
wps: np.ndarray, ego_pos: np.ndarray, ego_heading: float, w: int, h: int, res: float
def points_to_pixels(
points: np.ndarray,
ego_pos: np.ndarray,
ego_heading: float,
w: int,
h: int,
res: float,
) -> np.ndarray:
"""Converts waypoints into pixel coordinates in order to superimpose the
waypoints onto the RGB image.
"""Converts points into pixel coordinates in order to superimpose the
points onto the RGB image.
Args:
wps (np.ndarray): Waypoints for a single route. Shape (n,3).
points (np.ndarray): Array of points. Shape (n,3).
ego_pos (np.ndarray): Ego position. Shape = (3,).
ego_heading (float): Ego heading in radians.
w (int): Width of RGB image
Expand All @@ -139,19 +182,19 @@ def wps_to_pixels(
ground_size/image_size.
Returns:
np.ndarray: Array of waypoint coordinates on the RGB image. Shape = (m,3).
np.ndarray: Array of point coordinates on the RGB image. Shape = (m,3).
"""
# fmt: off
mask = [False if all(point == np.zeros(3,)) else True for point in wps]
wps_nonzero = wps[mask]
wps_delta = wps_nonzero - ego_pos
wps_rotated = rotate_axes(wps_delta, theta=ego_heading)
wps_pixels = wps_rotated / np.array([res, res, res])
wps_overlay = np.array([w / 2, h / 2, 0]) + wps_pixels * np.array([1, -1, 1])
wps_rfloat = np.rint(wps_overlay)
wps_valid = wps_rfloat[(wps_rfloat[:,0] >= 0) & (wps_rfloat[:,0] < w) & (wps_rfloat[:,1] >= 0) & (wps_rfloat[:,1] < h)]
wps_rint = wps_valid.astype(int)
return wps_rint
mask = [False if all(point == np.zeros(3,)) else True for point in points]
points_nonzero = points[mask]
points_delta = points_nonzero - ego_pos
points_rotated = rotate_axes(points_delta, theta=ego_heading)
points_pixels = points_rotated / np.array([res, res, res])
points_overlay = np.array([w / 2, h / 2, 0]) + points_pixels * np.array([1, -1, 1])
points_rfloat = np.rint(points_overlay)
points_valid = points_rfloat[(points_rfloat[:,0] >= 0) & (points_rfloat[:,0] < w) & (points_rfloat[:,1] >= 0) & (points_rfloat[:,1] < h)]
points_rint = points_valid.astype(int)
return points_rint
# fmt: on


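As a hypothetical usage of the renamed points_to_pixels above (assuming rotate_axes with theta=0 is the identity), a world point 10 m ahead of an ego at the origin lands in the upper half of a 200x200 image rendered at 80/200 m/pixel:

    import numpy as np

    points = np.array([[0.0, 10.0, 0.0]])  # (n, 3) world coordinates
    pixels = points_to_pixels(
        points=points,
        ego_pos=np.zeros(3),
        ego_heading=0.0,
        w=200,
        h=200,
        res=80 / 200,
    )
    print(pixels)  # [[100  75   0]]: x = w/2, y = h/2 - 10 / (80/200)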
15 changes: 9 additions & 6 deletions examples/rl/drive/inference/contrib_policy/policy.py
@@ -11,22 +11,20 @@ class Policy(Agent):
"""Policy class to be submitted by the user. This class will be loaded
and tested during evaluation."""

def __init__(self, num_stack, top_down_rgb, action_space_type):
def __init__(self, num_stack, top_down_rgb, crop, action_space_type):
"""All policy initialization matters, including loading of model, is
performed here. To be implemented by the user.
"""

import stable_baselines3 as sb3lib
from contrib_policy import network
from contrib_policy.filter_obs import FilterObs
from contrib_policy.format_action import FormatAction
from contrib_policy.frame_stack import FrameStack
from contrib_policy.make_dict import MakeDict

model_path = Path(__file__).resolve().parents[0] / "saved_model"
self.model = sb3lib.PPO.load(model_path)
self._model = self._get_model()

self._filter_obs = FilterObs(top_down_rgb=top_down_rgb)
self._filter_obs = FilterObs(top_down_rgb=top_down_rgb, crop=crop)
self._frame_stack = FrameStack(
input_space=self._filter_obs.observation_space,
num_stack=num_stack,
@@ -43,7 +41,7 @@ def __init__(self, num_stack, top_down_rgb, action_space_type):
def act(self, obs):
"""Mandatory act function to be implemented by user."""
processed_obs = self._process(obs)
action, _ = self.model.predict(observation=processed_obs, deterministic=True)
action, _ = self._model.predict(observation=processed_obs, deterministic=True)
formatted_action = self._format_action.format(
action=int(action), prev_heading=obs["ego_vehicle_state"]["heading"]
)
@@ -57,3 +55,8 @@ def _process(self, obs):
obs = self._frame_stack.stack(obs)
obs = self._make_dict.make(obs)
return obs

def _get_model(self):
import stable_baselines3 as sb3lib

return sb3lib.PPO.load(path=Path(__file__).resolve().parents[0] / "saved_model")
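A hypothetical rollout with the refactored Policy above; the environment construction and the gym 0.21-style step API are assumptions (consistent with the Api021Reversion wrapper used during training), not the benchmark harness:

    # interface and env are assumed to exist, e.g. from entry_point above
    # and from a wrapped SMARTS environment.
    policy = Policy(
        num_stack=3,
        top_down_rgb=interface.top_down_rgb,
        crop=(50, 50, 0, 70),
        action_space_type=interface.action,
    )
    obs = env.reset()
    done = False
    while not done:
        action = policy.act(obs)
        obs, reward, done, info = env.step(action)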
19 changes: 13 additions & 6 deletions examples/rl/drive/train/config.yaml
@@ -5,20 +5,27 @@ smarts:
agent_locator: inference:contrib-agent-v0
env_id: smarts.env:driving-smarts-v2023
scenarios:
- scenarios/sumo/intersections/1_to_1lane_left_turn_c_agents_1
- scenarios/sumo/straight/cruise_2lane_agents_1
- scenarios/sumo/straight/cutin_2lane_agents_1
- scenarios/sumo/straight/merge_exit_sumo_t_agents_1
- scenarios/sumo/straight/overtake_2lane_agents_1

- scenarios/sumo/intersections/1_to_3lane_left_turn_sumo_c_agents_1
- scenarios/sumo/intersections/1_to_3lane_left_turn_middle_lane_c_agents_1

# PPO algorithm
alg:
n_steps: 2048
batch_size: 512
n_steps: 1024
batch_size: 64
n_epochs: 4
target_kl: 0.1
# ent_coef: 0.01 # For exploration. Range = 0 to 0.01

# Training over all scenarios
epochs: 500 # Number of training loops.

# Training per scenario
train_steps: 10_000
checkpoint_freq: 10_000 # Save a model every checkpoint_freq calls to env.step().
eval_freq: 10_000 # Evaluate the trained model every eval_freq steps and save the best model.
train_steps: 4_096
checkpoint_freq: 4_096 # Save a model every checkpoint_freq calls to env.step().
eval_freq: 4_096 # Evaluate the trained model every eval_freq steps and save the best model.
eval_eps: 5 # Number of evaluation episodes.
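A sketch (assumed wiring, not the repository's train script) of how the alg block above maps onto Stable Baselines3 PPO keyword arguments:

    import stable_baselines3 as sb3lib

    alg_kwargs = dict(n_steps=1024, batch_size=64, n_epochs=4, target_kl=0.1)
    model = sb3lib.PPO(
        policy="CnnPolicy",  # assumed; the repo's network module may differ
        env=env,             # a wrapped SMARTS env, e.g. from make_env below
        **alg_kwargs,
    )
    model.learn(total_timesteps=4_096)  # train_steps per scenario

With n_steps=1024 and batch_size=64, each update runs 16 minibatches per epoch over 4 epochs, and target_kl=0.1 ends an update early if the policy moves too far.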
20 changes: 11 additions & 9 deletions examples/rl/drive/train/env.py
@@ -1,32 +1,34 @@
import sys
from pathlib import Path

# To import submission folder
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
# To import train folder
sys.path.insert(0, str(Path(__file__).resolve().parents[0]))

import gymnasium as gym

from smarts.zoo.agent_spec import AgentSpec

def make_env(env_id, scenario, agent_interface, config, seed):

def make_env(env_id, scenario, agent_spec: AgentSpec, config, seed):
from preprocess import Preprocess
from reward import Reward
from stable_baselines3.common.monitor import Monitor
from train.reward import Reward

from smarts.env.gymnasium.wrappers.api_reversion import Api021Reversion
from smarts.env.gymnasium.wrappers.single_agent import SingleAgent

env = gym.make(
env_id,
scenario=scenario,
agent_interface=agent_interface,
agent_interface=agent_spec.interface,
seed=seed,
sumo_headless=not config.sumo_gui, # If False, enables sumo-gui display.
headless=not config.head, # If False, enables Envision display.
)
env = Reward(env)
env = SingleAgent(env)
env = Api021Reversion(env)
env = Preprocess(env, agent_interface)
env = Reward(env=env, crop=agent_spec.agent_params["crop"])
env = SingleAgent(env=env)
env = Api021Reversion(env=env)
env = Preprocess(env=env, agent_spec=agent_spec)
env = Monitor(env)

return env
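A hypothetical call to make_env above; the config object is assumed to expose sumo_gui and head flags, and agent_spec is assumed to come from the registered inference:contrib-agent-v0 locator:

    from types import SimpleNamespace

    config = SimpleNamespace(sumo_gui=False, head=False)
    env = make_env(
        env_id="smarts.env:driving-smarts-v2023",
        scenario="scenarios/sumo/straight/cruise_2lane_agents_1",
        agent_spec=agent_spec,
        config=config,
        seed=42,
    )
    obs = env.reset()  # gym 0.21-style API after Api021Reversion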
15 changes: 10 additions & 5 deletions examples/rl/drive/train/preprocess.py
@@ -4,17 +4,20 @@
from contrib_policy.frame_stack import FrameStack
from contrib_policy.make_dict import MakeDict

from smarts.core.agent_interface import AgentInterface
from smarts.zoo.agent_spec import AgentSpec


class Preprocess(gym.Wrapper):
def __init__(self, env: gym.Env, agent_interface: AgentInterface):
def __init__(self, env: gym.Env, agent_spec: AgentSpec):
super().__init__(env)

self._filter_obs = FilterObs(top_down_rgb=agent_interface.top_down_rgb)
self._filter_obs = FilterObs(
top_down_rgb=agent_spec.interface.top_down_rgb,
crop=agent_spec.agent_params["crop"],
)
self._frame_stack = FrameStack(
input_space=self._filter_obs.observation_space,
num_stack=3,
num_stack=agent_spec.agent_params["num_stack"],
stack_axis=0,
)
self._frame_stack.reset()
@@ -23,7 +26,9 @@ def __init__(self, env: gym.Env, agent_interface: AgentInterface):
self.observation_space = self._make_dict.observation_space

self._prev_heading: float
self._format_action = FormatAction(agent_interface.action)
self._format_action = FormatAction(
action_space_type=agent_spec.interface.action
)
self.action_space = self._format_action.action_space

def _process(self, obs):
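As a sanity check on the wrapper chain above (a sketch, assuming FrameStack concatenates frames along the channel axis, per stack_axis=0): with a 200x200 frame, crop=(50, 50, 0, 70), and num_stack=3, FilterObs emits (3, 130, 100) per frame and the stacked observation is (9, 130, 100):

    channels = 3 * 3        # RGB channels x num_stack
    height = 200 - 0 - 70   # top_down_rgb.height - crop[2] - crop[3]
    width = 200 - 50 - 50   # top_down_rgb.width - crop[0] - crop[1]
    print((channels, height, width))  # (9, 130, 100)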