Upgrade SB3 to v2.0.0 in Drive and VehicleFollowing RL examples #2075

Merged on Jun 30, 2023 (2 commits).
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -19,6 +19,7 @@ Copy and pasting the git commit messages is __NOT__ enough.
- Renamed `rllib/rllib.py` to `rllib/pg_pbt_example.py`.
- Loosened constraint of `gymnasium` from `==0.27.0` to `>=0.26.3`.
- `LaneFollowingController` now uses a different pole placement method to compute lateral/heading gains. Numerical behaviour is unchanged. Performance is slightly faster.
- Upgraded Stable Baselines3 from v1.7.0 to v2.0.0, and switched to Gymnasium backend, in Drive and VehicleFollowing RL examples.
### Deprecated
### Fixed
- Missing neighborhood vehicle state `'lane_id'` is now added to the `hiway-v1` formatted observations.
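The changelog entry above covers the core of this PR: moving the examples from the legacy `gym` step/reset API to Gymnasium's. A minimal sketch of the signature difference (the toy env classes here are illustrative, not from the PR):

```python
# Legacy gym:  step -> (obs, reward, done, info);                    reset -> obs
# Gymnasium:   step -> (obs, reward, terminated, truncated, info);  reset -> (obs, info)

class LegacyToyEnv:
    def reset(self):
        return 0  # observation only

    def step(self, action):
        return 0, 1.0, True, {}  # obs, reward, done, info

class GymnasiumToyEnv:
    def reset(self, *, seed=None, options=None):
        return 0, {}  # (obs, info); reset now accepts seed/options

    def step(self, action):
        # `done` is split into `terminated` (MDP end) and `truncated` (time limit)
        return 0, 1.0, True, False, {}

obs = LegacyToyEnv().reset()                                  # old API
obs, reward, done_old, info = LegacyToyEnv().step(0)

obs, info = GymnasiumToyEnv().reset(seed=7)                   # new API
obs, reward, terminated, truncated, info = GymnasiumToyEnv().step(0)
done = terminated or truncated  # how code recovers the old single flag
```

This split is what drives the `preprocess.py` and `env.py` changes further down in the diff.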
6 changes: 4 additions & 2 deletions README.md
@@ -3,6 +3,8 @@
[![SMARTS CI Format](https://github.com/huawei-noah/SMARTS/actions/workflows/ci-format.yml/badge.svg?branch=master)](https://github.com/huawei-noah/SMARTS/actions/workflows/ci-format.yml?query=branch%3Amaster)
[![Documentation Status](https://readthedocs.org/projects/smarts/badge/?version=latest)](https://smarts.readthedocs.io/en/latest/?badge=latest)
![Code style](https://img.shields.io/badge/code%20style-black-000000.svg)
[![Pyversion](https://img.shields.io/pypi/pyversions/smarts.svg)](https://badge.fury.io/py/smarts)
[![PyPI version](https://badge.fury.io/py/smarts.svg)](https://badge.fury.io/py/smarts)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)

SMARTS (Scalable Multi-Agent Reinforcement Learning Training School) is a simulation platform for multi-agent reinforcement learning (RL) and research on autonomous driving. Its focus is on realistic and diverse interactions. It is part of the [XingTian](https://github.com/huawei-noah/xingtian/) suite of RL platforms from Huawei Noah's Ark Lab.
@@ -44,8 +46,8 @@ Several agent control policies and agent [action types](smarts/core/controllers/
### RL Model
1. [Drive](examples/rl/drive). See [Driving SMARTS 2023.1 & 2023.2](https://smarts.readthedocs.io/en/latest/benchmarks/driving_smarts_2023_1.html) for more info.
1. [VehicleFollowing](examples/rl/platoon). See [Driving SMARTS 2023.3](https://smarts.readthedocs.io/en/latest/benchmarks/driving_smarts_2023_3.html) for more info.
1. [PG](examples/rl/rllib/pg_example.py). See [RLlib](https://smarts.readthedocs.io/en/latest/docs/ecosystem/rllib.html) for more info.
1. [PG Population Based Training](examples/rl/rllib/pg_pbt_example.py). See [RLlib](https://smarts.readthedocs.io/en/latest/docs/ecosystem/rllib.html) for more info.
1. [PG](examples/rl/rllib/pg_example.py). See [RLlib](https://smarts.readthedocs.io/en/latest/ecosystem/rllib.html) for more info.
1. [PG Population Based Training](examples/rl/rllib/pg_pbt_example.py). See [RLlib](https://smarts.readthedocs.io/en/latest/ecosystem/rllib.html) for more info.

### RL Environment
1. [ULTRA](https://github.com/smarts-project/smarts-project.rl/blob/master/ultra) provides a gym-based environment built upon SMARTS to tackle intersection navigation, specifically the unprotected left turn.
2 changes: 1 addition & 1 deletion examples/rl/drive/inference/contrib_policy/filter_obs.py
@@ -1,7 +1,7 @@
import math
from typing import Any, Dict, Sequence, Tuple

import gym
import gymnasium as gym
import numpy as np

from smarts.core.agent_interface import RGB
@@ -1,6 +1,6 @@
from typing import Callable, Tuple

import gym
import gymnasium as gym
import numpy as np

from smarts.core.controllers import ActionSpaceType
2 changes: 1 addition & 1 deletion examples/rl/drive/inference/contrib_policy/frame_stack.py
@@ -1,7 +1,7 @@
import copy
from collections import deque

import gym
import gymnasium as gym
import numpy as np


2 changes: 1 addition & 1 deletion examples/rl/drive/inference/contrib_policy/make_dict.py
@@ -1,6 +1,6 @@
from typing import Dict

import gym
import gymnasium as gym
import numpy as np


2 changes: 1 addition & 1 deletion examples/rl/drive/inference/contrib_policy/network.py
@@ -1,4 +1,4 @@
import gym
import gymnasium as gym
import torch as th
import torch.nn as nn
from stable_baselines3.common.preprocessing import get_flattened_obs_dim
2 changes: 1 addition & 1 deletion examples/rl/drive/inference/setup.cfg
@@ -17,7 +17,7 @@ zip_safe = True
python_requires = == 3.8.*
install_requires =
setuptools==65.4.0
stable-baselines3==1.7.0
stable-baselines3==2.0.0
tensorboard==2.12.0
torch==1.13.1
torchinfo==1.7.2
1 change: 0 additions & 1 deletion examples/rl/drive/train/env.py
@@ -27,7 +27,6 @@ def make_env(env_id, scenario, agent_spec: AgentSpec, config, seed):
)
env = Reward(env=env, crop=agent_spec.agent_params["crop"])
env = SingleAgent(env=env)
env = Api021Reversion(env=env)
env = Preprocess(env=env, agent_spec=agent_spec)
env = Monitor(env)

10 changes: 5 additions & 5 deletions examples/rl/drive/train/preprocess.py
@@ -43,16 +43,16 @@ def step(self, action):
formatted_action = self._format_action.format(
action=action, prev_heading=self._prev_heading
)
obs, reward, done, info = self.env.step(formatted_action)
obs, reward, terminated, truncated, info = self.env.step(formatted_action)
self._prev_heading = obs["ego_vehicle_state"]["heading"]
obs = self._process(obs)
return obs, reward, done, info
return obs, reward, terminated, truncated, info

def reset(self):
def reset(self, *, seed=None, options=None):
"""Uses the :meth:`reset` of the :attr:`env` that can be overwritten to change the returned data."""

self._frame_stack.reset()
obs = self.env.reset()
obs, info = self.env.reset(seed=seed, options=options)
self._prev_heading = obs["ego_vehicle_state"]["heading"]
obs = self._process(obs)
return obs
return obs, info
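The `preprocess.py` hunk above shows the wrapper pattern this PR adopts: `reset` forwards `seed`/`options` and returns `(obs, info)`, while `step` passes the 5-tuple through. A runnable sketch of that pattern (`StubEnv` and `_process` are stand-ins; the real class wraps a SMARTS env):

```python
class StubEnv:
    """Illustrative inner env following Gymnasium signatures."""

    def reset(self, *, seed=None, options=None):
        return {"frame": 0}, {"seed": seed}  # (obs, info)

    def step(self, action):
        return {"frame": 1}, 0.5, False, False, {}  # 5-tuple

class Preprocess:
    def __init__(self, env):
        self.env = env

    def reset(self, *, seed=None, options=None):
        # Gymnasium reset: forward seed/options, return (obs, info).
        obs, info = self.env.reset(seed=seed, options=options)
        return self._process(obs), info

    def step(self, action):
        # Gymnasium step: done is split into terminated and truncated.
        obs, reward, terminated, truncated, info = self.env.step(action)
        return self._process(obs), reward, terminated, truncated, info

    def _process(self, obs):
        return obs["frame"] * 2  # stand-in for the real preprocessing

env = Preprocess(StubEnv())
obs, info = env.reset(seed=42)
obs2, reward, terminated, truncated, info2 = env.step(0)
```

Because both inner env and wrapper speak the new API end to end, the `Api021Reversion` shim removed in `env.py` is no longer needed.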
57 changes: 28 additions & 29 deletions examples/rl/drive/train/requirements.txt
@@ -1,5 +1,5 @@
absl-py==1.4.0
argcomplete==3.0.8
argcomplete==3.1.1
attrs==23.1.0
Automat==22.10.0
av==10.0.0
@@ -11,36 +11,36 @@ click==8.1.3
cloudpickle==1.6.0
colorlog==6.7.0
constantly==15.1.0
contourpy==1.0.7
contourpy==1.1.0
cycler==0.11.0
distlib==0.3.6
eclipse-sumo==1.17.0
filelock==3.12.0
fonttools==4.39.4
eclipse-sumo==1.18.0
Farama-Notifications==0.0.4
filelock==3.12.2
fonttools==4.40.0
future==0.18.3
google-auth==2.19.1
google-auth==2.21.0
google-auth-oauthlib==0.4.6
grpcio==1.55.0
gym==0.21.0
gymnasium==0.27.0
gymnasium-notices==0.0.1
grpcio==1.56.0
gym==0.19.0
gymnasium==0.28.1
hyperlink==21.0.0
idna==3.4
ijson==3.2.0.post0
importlib-metadata==4.13.0
ijson==3.2.2
importlib-metadata==6.7.0
importlib-resources==5.12.0
incremental==22.10.0
jax-jumpy==1.0.0
joblib==1.2.0
joblib==1.3.1
kiwisolver==1.4.4
llvmlite==0.40.0
llvmlite==0.40.1
Markdown==3.4.3
markdown-it-py==2.2.0
markdown-it-py==3.0.0
MarkupSafe==2.1.3
matplotlib==3.7.1
mdurl==0.1.2
nox==2023.4.22
numba==0.57.0
numba==0.57.1
numpy==1.23.5
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-nvrtc-cu11==11.7.99
@@ -52,47 +52,46 @@ packaging==23.1
Panda3D==1.10.9
panda3d-gltf==0.13
panda3d-simplepbr==0.10
pandas==2.0.2
pandas==2.0.3
Pillow==9.5.0
platformdirs==3.5.1
protobuf==3.20.3
platformdirs==3.8.0
protobuf==4.23.3
psutil==5.9.5
pyarrow==12.0.0
pyarrow==12.0.1
pyasn1==0.5.0
pyasn1-modules==0.3.0
pybullet==3.2.5
Pygments==2.15.1
pyparsing==3.0.9
pyparsing==3.1.0
pyproj==3.5.0
python-dateutil==2.8.2
pytz==2023.3
PyYAML==6.0
requests==2.31.0
requests-oauthlib==1.3.1
rich==13.4.1
rich==13.4.2
rsa==4.9
Rtree==1.0.1
scipy==1.10.1
shapely==2.0.1
Shimmy==0.2.1
six==1.16.0
stable-baselines3==1.7.0
stable-baselines3==2.0.0
tableprint==0.9.1
tensorboard==2.12.0
tensorboard-data-server==0.7.0
tensorboard-data-server==0.7.1
tensorboard-plugin-wit==1.8.1
torch==1.13.1
torchinfo==1.7.2
tornado==6.3.2
trimesh==3.9.29
Twisted==22.10.0
typing_extensions==4.6.3
typing_extensions==4.7.0
tzdata==2023.3
urllib3==1.26.16
virtualenv==20.23.0
virtualenv==20.23.1
wcwidth==0.2.6
websocket-client==1.5.2
Werkzeug==2.3.5
websocket-client==1.6.1
Werkzeug==2.3.6
yattag==1.15.1
zipp==3.15.0
zope.interface==6.0
2 changes: 1 addition & 1 deletion examples/rl/drive/train/run.py
@@ -12,7 +12,7 @@
from itertools import cycle, islice
from typing import Any, Dict

import gym
import gymnasium as gym

# Load inference module to register agent
import inference
2 changes: 1 addition & 1 deletion examples/rl/platoon/inference/contrib_policy/filter_obs.py
@@ -1,7 +1,7 @@
import math
from typing import Any, Dict, Sequence, Tuple

import gym
import gymnasium as gym
import numpy as np

from smarts.core.agent_interface import RGB
@@ -1,6 +1,6 @@
from typing import Callable, Tuple

import gym
import gymnasium as gym
import numpy as np

from smarts.core.controllers import ActionSpaceType
@@ -1,7 +1,7 @@
import copy
from collections import deque

import gym
import gymnasium as gym
import numpy as np


2 changes: 1 addition & 1 deletion examples/rl/platoon/inference/contrib_policy/make_dict.py
@@ -1,6 +1,6 @@
from typing import Dict

import gym
import gymnasium as gym
import numpy as np


2 changes: 1 addition & 1 deletion examples/rl/platoon/inference/contrib_policy/network.py
@@ -1,4 +1,4 @@
import gym
import gymnasium as gym
import torch as th
import torch.nn as nn
from stable_baselines3.common.preprocessing import get_flattened_obs_dim
2 changes: 1 addition & 1 deletion examples/rl/platoon/inference/setup.cfg
@@ -17,7 +17,7 @@ zip_safe = True
python_requires = == 3.8.*
install_requires =
setuptools==65.4.0
stable-baselines3==1.7.0
stable-baselines3==2.0.0
tensorboard==2.12.0
torch==1.13.1
torchinfo==1.7.2
1 change: 0 additions & 1 deletion examples/rl/platoon/train/env.py
@@ -27,7 +27,6 @@ def make_env(env_id, scenario, agent_spec: AgentSpec, config, seed):
)
env = Reward(env=env, crop=agent_spec.agent_params["crop"])
env = SingleAgent(env=env)
env = Api021Reversion(env=env)
env = Preprocess(env=env, agent_spec=agent_spec)
env = Monitor(env)

12 changes: 6 additions & 6 deletions examples/rl/platoon/train/preprocess.py
@@ -1,4 +1,4 @@
import gym
import gymnasium as gym
from contrib_policy.filter_obs import FilterObs
from contrib_policy.format_action import FormatAction
from contrib_policy.frame_stack import FrameStack
@@ -40,14 +40,14 @@ def _process(self, obs):
def step(self, action):
"""Uses the :meth:`step` of the :attr:`env` that can be overwritten to change the returned data."""
formatted_action = self._format_action.format(action)
obs, reward, done, info = self.env.step(formatted_action)
obs, reward, terminated, truncated, info = self.env.step(formatted_action)
obs = self._process(obs)
return obs, reward, done, info
return obs, reward, terminated, truncated, info

def reset(self):
def reset(self, *, seed=None, options=None):
"""Uses the :meth:`reset` of the :attr:`env` that can be overwritten to change the returned data."""

self._frame_stack.reset()
obs = self.env.reset()
obs, info = self.env.reset(seed=seed, options=options)
obs = self._process(obs)
return obs
return obs, info
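Downstream of these wrappers, training code must also consume the 5-tuple instead of a single `done` flag. A sketch of a rollout loop under the new API (`ToyEnv` is illustrative; the actual examples drive a wrapped SMARTS env through SB3):

```python
class ToyEnv:
    """Illustrative env that ends by truncation after 3 steps, like a time limit."""

    def __init__(self, limit=3):
        self.limit = limit
        self.t = 0

    def reset(self, *, seed=None, options=None):
        self.t = 0
        return self.t, {}

    def step(self, action):
        self.t += 1
        terminated = False                 # the task itself never ends here
        truncated = self.t >= self.limit   # episode cut off by the step limit
        return self.t, 1.0, terminated, truncated, {}

env = ToyEnv()
obs, info = env.reset(seed=0)
steps, ret = 0, 0.0
while True:
    obs, reward, terminated, truncated, info = env.step(0)
    steps += 1
    ret += reward
    if terminated or truncated:  # replaces the old single `done` check
        break
```

Keeping `terminated` and `truncated` distinct matters for value bootstrapping: a truncated episode is cut short rather than genuinely finished, which algorithms such as SB3's can treat differently.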