diff --git a/CHANGELOG.md b/CHANGELOG.md index 9e79868dca..bdd1eca733 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ Copy and pasting the git commit messages is __NOT__ enough. - Renamed `rllib/rllib.py` to `rllib/pg_pbt_example.py`. - Loosened constraint of `gymnasium` from `==0.27.0` to `>=0.26.3`. - `LaneFollowingController` now uses a different pole placement method to compute lateral/heading gains. Numerical behaviour is unchanged. Performance is slightly faster. +- Upgraded Stable Baselines3 from v1.7.0 to v2.0.0, and switched to Gymnasium backend, in Drive and VehicleFollowing RL examples. ### Deprecated ### Fixed - Missing neighborhood vehicle state `'lane_id'` is now added to the `hiway-v1` formatted observations. diff --git a/README.md b/README.md index 6067a0eabb..de12e49dde 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,8 @@ [![SMARTS CI Format](https://github.com/huawei-noah/SMARTS/actions/workflows/ci-format.yml/badge.svg?branch=master)](https://github.com/huawei-noah/SMARTS/actions/workflows/ci-format.yml?query=branch%3Amaster) [![Documentation Status](https://readthedocs.org/projects/smarts/badge/?version=latest)](https://smarts.readthedocs.io/en/latest/?badge=latest) ![Code style](https://img.shields.io/badge/code%20style-black-000000.svg) +[![Pyversion](https://img.shields.io/pypi/pyversions/smarts.svg)](https://badge.fury.io/py/smarts) +[![PyPI version](https://badge.fury.io/py/smarts.svg)](https://badge.fury.io/py/smarts) [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) SMARTS (Scalable Multi-Agent Reinforcement Learning Training School) is a simulation platform for multi-agent reinforcement learning (RL) and research on autonomous driving. Its focus is on realistic and diverse interactions. It is part of the [XingTian](https://github.com/huawei-noah/xingtian/) suite of RL platforms from Huawei Noah's Ark Lab. @@ -44,8 +46,8 @@ Several agent control policies and agent [action types](smarts/core/controllers/ ### RL Model 1. [Drive](examples/rl/drive). See [Driving SMARTS 2023.1 & 2023.2](https://smarts.readthedocs.io/en/latest/benchmarks/driving_smarts_2023_1.html) for more info. 1. [VehicleFollowing](examples/rl/platoon). See [Driving SMARTS 2023.3](https://smarts.readthedocs.io/en/latest/benchmarks/driving_smarts_2023_3.html) for more info. -1. [PG](examples/rl/rllib/pg_example.py). See [RLlib](https://smarts.readthedocs.io/en/latest/docs/ecosystem/rllib.html) for more info. -1. [PG Population Based Training](examples/rl/rllib/pg_pbt_example.py). See [RLlib](https://smarts.readthedocs.io/en/latest/docs/ecosystem/rllib.html) for more info. +1. [PG](examples/rl/rllib/pg_example.py). See [RLlib](https://smarts.readthedocs.io/en/latest/ecosystem/rllib.html) for more info. +1. [PG Population Based Training](examples/rl/rllib/pg_pbt_example.py). See [RLlib](https://smarts.readthedocs.io/en/latest/ecosystem/rllib.html) for more info. ### RL Environment 1. [ULTRA](https://github.com/smarts-project/smarts-project.rl/blob/master/ultra) provides a gym-based environment built upon SMARTS to tackle intersection navigation, specifically the unprotected left turn. diff --git a/examples/rl/drive/inference/contrib_policy/filter_obs.py b/examples/rl/drive/inference/contrib_policy/filter_obs.py index e366355c16..4be368171e 100644 --- a/examples/rl/drive/inference/contrib_policy/filter_obs.py +++ b/examples/rl/drive/inference/contrib_policy/filter_obs.py @@ -1,7 +1,7 @@ import math from typing import Any, Dict, Sequence, Tuple -import gym +import gymnasium as gym import numpy as np from smarts.core.agent_interface import RGB diff --git a/examples/rl/drive/inference/contrib_policy/format_action.py b/examples/rl/drive/inference/contrib_policy/format_action.py index 646f43f8f8..d1603d0862 100644 --- a/examples/rl/drive/inference/contrib_policy/format_action.py +++ b/examples/rl/drive/inference/contrib_policy/format_action.py @@ -1,6 +1,6 @@ from typing import Callable, Tuple -import gym +import gymnasium as gym import numpy as np from smarts.core.controllers import ActionSpaceType diff --git a/examples/rl/drive/inference/contrib_policy/frame_stack.py b/examples/rl/drive/inference/contrib_policy/frame_stack.py index 9da17cd193..8f7d8d7193 100644 --- a/examples/rl/drive/inference/contrib_policy/frame_stack.py +++ b/examples/rl/drive/inference/contrib_policy/frame_stack.py @@ -1,7 +1,7 @@ import copy from collections import deque -import gym +import gymnasium as gym import numpy as np diff --git a/examples/rl/drive/inference/contrib_policy/make_dict.py b/examples/rl/drive/inference/contrib_policy/make_dict.py index f995c93afc..599a79719f 100644 --- a/examples/rl/drive/inference/contrib_policy/make_dict.py +++ b/examples/rl/drive/inference/contrib_policy/make_dict.py @@ -1,6 +1,6 @@ from typing import Dict -import gym +import gymnasium as gym import numpy as np diff --git a/examples/rl/drive/inference/contrib_policy/network.py b/examples/rl/drive/inference/contrib_policy/network.py index 4ed514d1df..b1ae6bd98d 100644 --- a/examples/rl/drive/inference/contrib_policy/network.py +++ b/examples/rl/drive/inference/contrib_policy/network.py @@ -1,4 +1,4 @@ -import gym +import gymnasium as gym import torch as th import torch.nn as nn from stable_baselines3.common.preprocessing import get_flattened_obs_dim diff --git a/examples/rl/drive/inference/setup.cfg b/examples/rl/drive/inference/setup.cfg index a5962664aa..3a48464cd7 100644 --- a/examples/rl/drive/inference/setup.cfg +++ b/examples/rl/drive/inference/setup.cfg @@ -17,7 +17,7 @@ zip_safe = True python_requires = == 3.8.* install_requires = setuptools==65.4.0 - stable-baselines3==1.7.0 + stable-baselines3==2.0.0 tensorboard==2.12.0 torch==1.13.1 torchinfo==1.7.2 diff --git a/examples/rl/drive/train/env.py b/examples/rl/drive/train/env.py index a4dc0fea12..7f5469d259 100644 --- a/examples/rl/drive/train/env.py +++ b/examples/rl/drive/train/env.py @@ -27,7 +27,6 @@ def make_env(env_id, scenario, agent_spec: AgentSpec, config, seed): ) env = Reward(env=env, crop=agent_spec.agent_params["crop"]) env = SingleAgent(env=env) - env = Api021Reversion(env=env) env = Preprocess(env=env, agent_spec=agent_spec) env = Monitor(env) diff --git a/examples/rl/drive/train/preprocess.py b/examples/rl/drive/train/preprocess.py index ec2f890809..213e72b358 100644 --- a/examples/rl/drive/train/preprocess.py +++ b/examples/rl/drive/train/preprocess.py @@ -43,16 +43,16 @@ def step(self, action): formatted_action = self._format_action.format( action=action, prev_heading=self._prev_heading ) - obs, reward, done, info = self.env.step(formatted_action) + obs, reward, terminated, truncated, info = self.env.step(formatted_action) self._prev_heading = obs["ego_vehicle_state"]["heading"] obs = self._process(obs) - return obs, reward, done, info + return obs, reward, terminated, truncated, info - def reset(self): + def reset(self, *, seed=None, options=None): """Uses the :meth:`reset` of the :attr:`env` that can be overwritten to change the returned data.""" self._frame_stack.reset() - obs = self.env.reset() + obs, info = self.env.reset(seed=seed, options=options) self._prev_heading = obs["ego_vehicle_state"]["heading"] obs = self._process(obs) - return obs + return obs, info diff --git a/examples/rl/drive/train/requirements.txt b/examples/rl/drive/train/requirements.txt index 90d5d558db..e9412269ee 100644 --- a/examples/rl/drive/train/requirements.txt +++ b/examples/rl/drive/train/requirements.txt @@ -1,5 +1,5 @@ absl-py==1.4.0 -argcomplete==3.0.8 +argcomplete==3.1.1 attrs==23.1.0 Automat==22.10.0 av==10.0.0 @@ -11,36 +11,36 @@ click==8.1.3 cloudpickle==1.6.0 colorlog==6.7.0 constantly==15.1.0 -contourpy==1.0.7 +contourpy==1.1.0 cycler==0.11.0 distlib==0.3.6 -eclipse-sumo==1.17.0 -filelock==3.12.0 -fonttools==4.39.4 +eclipse-sumo==1.18.0 +Farama-Notifications==0.0.4 +filelock==3.12.2 +fonttools==4.40.0 future==0.18.3 -google-auth==2.19.1 +google-auth==2.21.0 google-auth-oauthlib==0.4.6 -grpcio==1.55.0 -gym==0.21.0 -gymnasium==0.27.0 -gymnasium-notices==0.0.1 +grpcio==1.56.0 +gym==0.19.0 +gymnasium==0.28.1 hyperlink==21.0.0 idna==3.4 -ijson==3.2.0.post0 -importlib-metadata==4.13.0 +ijson==3.2.2 +importlib-metadata==6.7.0 importlib-resources==5.12.0 incremental==22.10.0 jax-jumpy==1.0.0 -joblib==1.2.0 +joblib==1.3.1 kiwisolver==1.4.4 -llvmlite==0.40.0 +llvmlite==0.40.1 Markdown==3.4.3 -markdown-it-py==2.2.0 +markdown-it-py==3.0.0 MarkupSafe==2.1.3 matplotlib==3.7.1 mdurl==0.1.2 nox==2023.4.22 -numba==0.57.0 +numba==0.57.1 numpy==1.23.5 nvidia-cublas-cu11==11.10.3.66 nvidia-cuda-nvrtc-cu11==11.7.99 @@ -52,47 +52,46 @@ packaging==23.1 Panda3D==1.10.9 panda3d-gltf==0.13 panda3d-simplepbr==0.10 -pandas==2.0.2 +pandas==2.0.3 Pillow==9.5.0 -platformdirs==3.5.1 -protobuf==3.20.3 +platformdirs==3.8.0 +protobuf==4.23.3 psutil==5.9.5 -pyarrow==12.0.0 +pyarrow==12.0.1 pyasn1==0.5.0 pyasn1-modules==0.3.0 pybullet==3.2.5 Pygments==2.15.1 -pyparsing==3.0.9 +pyparsing==3.1.0 pyproj==3.5.0 python-dateutil==2.8.2 pytz==2023.3 PyYAML==6.0 requests==2.31.0 requests-oauthlib==1.3.1 -rich==13.4.1 +rich==13.4.2 rsa==4.9 Rtree==1.0.1 scipy==1.10.1 shapely==2.0.1 -Shimmy==0.2.1 six==1.16.0 -stable-baselines3==1.7.0 +stable-baselines3==2.0.0 tableprint==0.9.1 tensorboard==2.12.0 -tensorboard-data-server==0.7.0 +tensorboard-data-server==0.7.1 tensorboard-plugin-wit==1.8.1 torch==1.13.1 torchinfo==1.7.2 tornado==6.3.2 trimesh==3.9.29 Twisted==22.10.0 -typing_extensions==4.6.3 +typing_extensions==4.7.0 tzdata==2023.3 urllib3==1.26.16 -virtualenv==20.23.0 +virtualenv==20.23.1 wcwidth==0.2.6 -websocket-client==1.5.2 -Werkzeug==2.3.5 +websocket-client==1.6.1 +Werkzeug==2.3.6 yattag==1.15.1 zipp==3.15.0 zope.interface==6.0 diff --git a/examples/rl/drive/train/run.py b/examples/rl/drive/train/run.py index c42b9cab7c..8f52e949b7 100644 --- a/examples/rl/drive/train/run.py +++ b/examples/rl/drive/train/run.py @@ -12,7 +12,7 @@ from itertools import cycle, islice from typing import Any, Dict -import gym +import gymnasium as gym # Load inference module to register agent import inference diff --git a/examples/rl/platoon/inference/contrib_policy/filter_obs.py b/examples/rl/platoon/inference/contrib_policy/filter_obs.py index e366355c16..4be368171e 100644 --- a/examples/rl/platoon/inference/contrib_policy/filter_obs.py +++ b/examples/rl/platoon/inference/contrib_policy/filter_obs.py @@ -1,7 +1,7 @@ import math from typing import Any, Dict, Sequence, Tuple -import gym +import gymnasium as gym import numpy as np from smarts.core.agent_interface import RGB diff --git a/examples/rl/platoon/inference/contrib_policy/format_action.py b/examples/rl/platoon/inference/contrib_policy/format_action.py index 532770b3a1..305cb4bc05 100644 --- a/examples/rl/platoon/inference/contrib_policy/format_action.py +++ b/examples/rl/platoon/inference/contrib_policy/format_action.py @@ -1,6 +1,6 @@ from typing import Callable, Tuple -import gym +import gymnasium as gym import numpy as np from smarts.core.controllers import ActionSpaceType diff --git a/examples/rl/platoon/inference/contrib_policy/frame_stack.py b/examples/rl/platoon/inference/contrib_policy/frame_stack.py index 9da17cd193..8f7d8d7193 100644 --- a/examples/rl/platoon/inference/contrib_policy/frame_stack.py +++ b/examples/rl/platoon/inference/contrib_policy/frame_stack.py @@ -1,7 +1,7 @@ import copy from collections import deque -import gym +import gymnasium as gym import numpy as np diff --git a/examples/rl/platoon/inference/contrib_policy/make_dict.py b/examples/rl/platoon/inference/contrib_policy/make_dict.py index f995c93afc..599a79719f 100644 --- a/examples/rl/platoon/inference/contrib_policy/make_dict.py +++ b/examples/rl/platoon/inference/contrib_policy/make_dict.py @@ -1,6 +1,6 @@ from typing import Dict -import gym +import gymnasium as gym import numpy as np diff --git a/examples/rl/platoon/inference/contrib_policy/network.py b/examples/rl/platoon/inference/contrib_policy/network.py index 4ed514d1df..b1ae6bd98d 100644 --- a/examples/rl/platoon/inference/contrib_policy/network.py +++ b/examples/rl/platoon/inference/contrib_policy/network.py @@ -1,4 +1,4 @@ -import gym +import gymnasium as gym import torch as th import torch.nn as nn from stable_baselines3.common.preprocessing import get_flattened_obs_dim diff --git a/examples/rl/platoon/inference/setup.cfg b/examples/rl/platoon/inference/setup.cfg index a5962664aa..3a48464cd7 100644 --- a/examples/rl/platoon/inference/setup.cfg +++ b/examples/rl/platoon/inference/setup.cfg @@ -17,7 +17,7 @@ zip_safe = True python_requires = == 3.8.* install_requires = setuptools==65.4.0 - stable-baselines3==1.7.0 + stable-baselines3==2.0.0 tensorboard==2.12.0 torch==1.13.1 torchinfo==1.7.2 diff --git a/examples/rl/platoon/train/env.py b/examples/rl/platoon/train/env.py index a4dc0fea12..7f5469d259 100644 --- a/examples/rl/platoon/train/env.py +++ b/examples/rl/platoon/train/env.py @@ -27,7 +27,6 @@ def make_env(env_id, scenario, agent_spec: AgentSpec, config, seed): ) env = Reward(env=env, crop=agent_spec.agent_params["crop"]) env = SingleAgent(env=env) - env = Api021Reversion(env=env) env = Preprocess(env=env, agent_spec=agent_spec) env = Monitor(env) diff --git a/examples/rl/platoon/train/preprocess.py b/examples/rl/platoon/train/preprocess.py index 5a79d78960..7cce9dc48e 100644 --- a/examples/rl/platoon/train/preprocess.py +++ b/examples/rl/platoon/train/preprocess.py @@ -1,4 +1,4 @@ -import gym +import gymnasium as gym from contrib_policy.filter_obs import FilterObs from contrib_policy.format_action import FormatAction from contrib_policy.frame_stack import FrameStack @@ -40,14 +40,14 @@ def _process(self, obs): def step(self, action): """Uses the :meth:`step` of the :attr:`env` that can be overwritten to change the returned data.""" formatted_action = self._format_action.format(action) - obs, reward, done, info = self.env.step(formatted_action) + obs, reward, terminated, truncated, info = self.env.step(formatted_action) obs = self._process(obs) - return obs, reward, done, info + return obs, reward, terminated, truncated, info - def reset(self): + def reset(self, *, seed=None, options=None): """Uses the :meth:`reset` of the :attr:`env` that can be overwritten to change the returned data.""" self._frame_stack.reset() - obs = self.env.reset() + obs, info = self.env.reset(seed=seed, options=options) obs = self._process(obs) - return obs + return obs, info diff --git a/examples/rl/platoon/train/requirements.txt b/examples/rl/platoon/train/requirements.txt index 90d5d558db..e9412269ee 100644 --- a/examples/rl/platoon/train/requirements.txt +++ b/examples/rl/platoon/train/requirements.txt @@ -1,5 +1,5 @@ absl-py==1.4.0 -argcomplete==3.0.8 +argcomplete==3.1.1 attrs==23.1.0 Automat==22.10.0 av==10.0.0 @@ -11,36 +11,36 @@ click==8.1.3 cloudpickle==1.6.0 colorlog==6.7.0 constantly==15.1.0 -contourpy==1.0.7 +contourpy==1.1.0 cycler==0.11.0 distlib==0.3.6 -eclipse-sumo==1.17.0 -filelock==3.12.0 -fonttools==4.39.4 +eclipse-sumo==1.18.0 +Farama-Notifications==0.0.4 +filelock==3.12.2 +fonttools==4.40.0 future==0.18.3 -google-auth==2.19.1 +google-auth==2.21.0 google-auth-oauthlib==0.4.6 -grpcio==1.55.0 -gym==0.21.0 -gymnasium==0.27.0 -gymnasium-notices==0.0.1 +grpcio==1.56.0 +gym==0.19.0 +gymnasium==0.28.1 hyperlink==21.0.0 idna==3.4 -ijson==3.2.0.post0 -importlib-metadata==4.13.0 +ijson==3.2.2 +importlib-metadata==6.7.0 importlib-resources==5.12.0 incremental==22.10.0 jax-jumpy==1.0.0 -joblib==1.2.0 +joblib==1.3.1 kiwisolver==1.4.4 -llvmlite==0.40.0 +llvmlite==0.40.1 Markdown==3.4.3 -markdown-it-py==2.2.0 +markdown-it-py==3.0.0 MarkupSafe==2.1.3 matplotlib==3.7.1 mdurl==0.1.2 nox==2023.4.22 -numba==0.57.0 +numba==0.57.1 numpy==1.23.5 nvidia-cublas-cu11==11.10.3.66 nvidia-cuda-nvrtc-cu11==11.7.99 @@ -52,47 +52,46 @@ packaging==23.1 Panda3D==1.10.9 panda3d-gltf==0.13 panda3d-simplepbr==0.10 -pandas==2.0.2 +pandas==2.0.3 Pillow==9.5.0 -platformdirs==3.5.1 -protobuf==3.20.3 +platformdirs==3.8.0 +protobuf==4.23.3 psutil==5.9.5 -pyarrow==12.0.0 +pyarrow==12.0.1 pyasn1==0.5.0 pyasn1-modules==0.3.0 pybullet==3.2.5 Pygments==2.15.1 -pyparsing==3.0.9 +pyparsing==3.1.0 pyproj==3.5.0 python-dateutil==2.8.2 pytz==2023.3 PyYAML==6.0 requests==2.31.0 requests-oauthlib==1.3.1 -rich==13.4.1 +rich==13.4.2 rsa==4.9 Rtree==1.0.1 scipy==1.10.1 shapely==2.0.1 -Shimmy==0.2.1 six==1.16.0 -stable-baselines3==1.7.0 +stable-baselines3==2.0.0 tableprint==0.9.1 tensorboard==2.12.0 -tensorboard-data-server==0.7.0 +tensorboard-data-server==0.7.1 tensorboard-plugin-wit==1.8.1 torch==1.13.1 torchinfo==1.7.2 tornado==6.3.2 trimesh==3.9.29 Twisted==22.10.0 -typing_extensions==4.6.3 +typing_extensions==4.7.0 tzdata==2023.3 urllib3==1.26.16 -virtualenv==20.23.0 +virtualenv==20.23.1 wcwidth==0.2.6 -websocket-client==1.5.2 -Werkzeug==2.3.5 +websocket-client==1.6.1 +Werkzeug==2.3.6 yattag==1.15.1 zipp==3.15.0 zope.interface==6.0 diff --git a/examples/rl/platoon/train/run.py b/examples/rl/platoon/train/run.py index 6ffcc140c8..fe4ac7a473 100644 --- a/examples/rl/platoon/train/run.py +++ b/examples/rl/platoon/train/run.py @@ -13,7 +13,7 @@ from itertools import cycle, islice from typing import Any, Dict -import gym +import gymnasium as gym # Load inference module to register agent import inference diff --git a/smarts/env/gymnasium/wrappers/single_agent.py b/smarts/env/gymnasium/wrappers/single_agent.py index 37f879dc0e..dc5fdbc355 100644 --- a/smarts/env/gymnasium/wrappers/single_agent.py +++ b/smarts/env/gymnasium/wrappers/single_agent.py @@ -69,11 +69,11 @@ def step(self, action: Any) -> Tuple[Any, float, bool, bool, Any]: info[self._agent_id], ) - def reset(self) -> Tuple[Any, Any]: + def reset(self, *, seed=None, options=None) -> Tuple[Any, Any]: """Resets a single-agent SMARTS environment. Returns: Tuple[Any, Any]: Agent's observation and info """ - obs, info = self.env.reset() + obs, info = self.env.reset(seed=seed, options=options) return obs[self._agent_id], info[self._agent_id]