Final updates.
Gamenot committed Jun 20, 2023
1 parent 9bf091a commit 56301b5
Showing 9 changed files with 47 additions and 44 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -10,7 +10,13 @@ Copy and pasting the git commit messages is __NOT__ enough.

## [Unreleased]
### Added
- Added `rllib/pg_example.py` to demonstrate a simple integration with `RLlib` and `tensorflow` for policy training.
- Added `rllib/pg_pbt_example.py` to demonstrate integration with `ray.RLlib`, `tensorflow`, and `ray.tune` for scheduled policy training.
### Changed
- Updated `smarts[ray]` (`ray==2.2`) and `smarts[rllib]` (`ray[rllib]==1.4`) to use `ray~=2.5`.
- Introduced `tensorflow-probability` to `smarts[rllib]`.
- Updated `RLlibHiWayEnv` to use the `gymnasium` interface.
- Renamed `rllib/rllib.py` to `rllib/pg_pbt_example.py`.
### Deprecated
### Fixed
- Missing neighborhood vehicle state `'lane_id'` is now added to the `hiway-v1` formatted observations.
2 changes: 2 additions & 0 deletions README.md
@@ -44,6 +44,8 @@ Several agent control policies and agent [action types](smarts/core/controllers/
### RL Model
1. [Drive](examples/rl/drive). See [Driving SMARTS 2023.1 & 2023.2](https://smarts.readthedocs.io/en/latest/benchmarks/driving_smarts_2023_1.html) for more info.
1. [VehicleFollowing](examples/rl/platoon). See [Driving SMARTS 2023.3](https://smarts.readthedocs.io/en/latest/benchmarks/driving_smarts_2023_3.html) for more info.
1. [PG](examples/rl/rllib/pg_example.py). See [RLlib](https://smarts.readthedocs.io/en/latest/docs/ecosystem/rllib.html) for more info.
1. [PG Population Based Training](examples/rl/rllib/pg_pbt_example.py). See [RLlib](https://smarts.readthedocs.io/en/latest/docs/ecosystem/rllib.html) for more info.

### RL Environment
1. [ULTRA](https://github.com/smarts-project/smarts-project.rl/blob/master/ultra) provides a gym-based environment built upon SMARTS to tackle intersection navigation, specifically the unprotected left turn.
7 changes: 7 additions & 0 deletions docs/ecosystem/rllib.rst
@@ -8,6 +8,13 @@ RLlib
of applications. ``RLlib`` natively supports ``TensorFlow``, ``TensorFlow Eager``, and ``PyTorch``. Most of its internals are agnostic to such
deep learning frameworks.

SMARTS contains two examples using `Policy Gradients (PG) <https://docs.ray.io/en/latest/rllib-algorithms.html#policy-gradients-pg>`_.

1. rllib/pg_example.py
   This example shows the basics of using RLlib with SMARTS through :class:`~smarts.env.rllib_hiway_env.RLlibHiWayEnv`.
2. rllib/pg_pbt_example.py
   This example combines Policy Gradients with `Population Based Training (PBT) <https://docs.ray.io/en/latest/tune/api/doc/ray.tune.schedulers.PopulationBasedTraining.html>`_ scheduling.
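
In outline, both examples build an RLlib ``PG`` algorithm configuration around :class:`~smarts.env.rllib_hiway_env.RLlibHiWayEnv` and then call ``train()``, either directly (``pg_example.py``) or under a ``ray.tune`` PBT scheduler (``pg_pbt_example.py``). Below is a condensed sketch of the shared setup; the ``env_config`` keys and parameter values are illustrative assumptions, not the scripts' actual construction (see ``examples/rl/rllib/configs.py`` for that).

.. code-block:: python

from ray.rllib.algorithms.pg import PGConfig
from smarts.env.rllib_hiway_env import RLlibHiWayEnv

# Illustrative only: the real examples assemble this in gen_pg_config().
algo_config = (
    PGConfig()
    .environment(
        RLlibHiWayEnv,
        env_config={
            "scenarios": ["scenarios/sumo/loop"],  # assumed scenario path
            "agent_specs": {},  # assumed: mapping of agent id -> AgentSpec
        },
    )
    .rollouts(num_rollout_workers=1, rollout_fragment_length=200)
    .training(train_batch_size=200)
)
algo = algo_config.build()
for _ in range(10):  # a handful of training iterations
    results = algo.train()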

Recommended reads
-----------------

9 changes: 6 additions & 3 deletions docs/sim/env.rst
@@ -9,19 +9,20 @@ Base environments
The SMARTS environment module is defined in the :mod:`~smarts.env` package. Currently, SMARTS provides two kinds of training
environments, namely:

+ ``HiWayEnv`` utilizing ``gym.env`` style interface
+ ``HiWayEnv`` utilizing a ``gymnasium.Env`` interface
+ ``RLlibHiwayEnv`` customized for `RLlib <https://docs.ray.io/en/latest/rllib/index.html>`_ training

.. image:: ../_static/env.png

HiWayEnv
^^^^^^^^

``HiWayEnv`` inherits class ``gym.Env`` and supports gym APIs like ``reset``, ``step``, ``close``. An usage example is shown below.
``HiWayEnv`` inherits from ``gymnasium.Env`` and supports gym APIs like ``reset``, ``step``, and ``close``. A usage example is shown below.
Refer to :class:`~smarts.env.hiway_env.HiWayEnv` for more details.

.. code-block:: python
import gymnasium as gym
# Make env
env = gym.make(
"smarts.env:hiway-v0", # Env entry name.
@@ -53,6 +54,7 @@ exactly matches the `env.observation_space`, and `ObservationOptions.multi_agent

.. code-block:: python
import gymnasium as gym
# Make env
env = gym.make(
"smarts.env:hiway-v1", # Env entry name.
@@ -81,6 +83,7 @@ This can be done with :class:`~smarts.env.gymnasium.wrappers.api_reversion.Api02

.. code-block:: python
import gymnasium as gym
# Make env
env = gym.make(
"smarts.env:hiway-v1", # Env entry name.
@@ -91,7 +94,7 @@ This can be done with :class:`~smarts.env.gymnasium.wrappers.api_reversion.Api02
RLlibHiwayEnv
^^^^^^^^^^^^^

``RLlibHiwayEnv`` inherits class ``MultiAgentEnv``, which is defined in `RLlib <https://docs.ray.io/en/latest/rllib/index.html>`_. It also supports common env APIs like ``reset``,
``RLlibHiwayEnv`` inherits class ``MultiAgentEnv``, which is defined in `RLlib <https://docs.ray.io/en/latest/rllib/index.html>`_. It also supports common environment APIs like ``reset``,
``step``, and ``close``. A usage example is shown below. Refer to :class:`~smarts.env.rllib_hiway_env.RLlibHiWayEnv` for more details.

.. code-block:: python
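# The original example body is elided in this view; what follows is a hedged
# sketch, with the `config` keys as assumptions based on the surrounding docs.
from smarts.env.rllib_hiway_env import RLlibHiWayEnv

env = RLlibHiWayEnv(config={
    "scenarios": ["scenarios/sumo/loop"],  # assumed scenario path
    "agent_specs": {},  # assumed: mapping of agent id -> AgentSpec
})
observations, infos = env.reset()  # assumed gymnasium-style reset
observations, rewards, terminateds, truncateds, infos = env.step({})  # empty action dict for brevity
env.close()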
20 changes: 1 addition & 19 deletions examples/rl/rllib/configs.py
@@ -61,9 +61,7 @@ def gen_pg_config(
return algo_config


def gen_parser(
prog: str, default_result_dir: str, default_save_model_path: str
) -> argparse.ArgumentParser:
def gen_parser(prog: str, default_result_dir: str) -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(prog)
parser.add_argument(
"--scenario",
@@ -76,12 +74,6 @@ def gen_parser(
action="store_true",
help="Run simulation with Envision display.",
)
parser.add_argument(
"--num_samples",
type=int,
default=1,
help="Number of times to sample from hyperparameter space",
)
parser.add_argument(
"--rollout_fragment_length",
type=str,
@@ -133,17 +125,7 @@ def gen_parser(
default="ERROR",
help="Log level (DEBUG|INFO|WARN|ERROR)",
)
parser.add_argument(
"--checkpoint_num", type=int, default=None, help="Checkpoint number"
)
parser.add_argument(
"--checkpoint_freq", type=int, default=3, help="Checkpoint frequency"
)

parser.add_argument(
"--save_model_path",
type=str,
default=default_save_model_path,
help="Destination path of where to copy the model when training is over",
)
return parser
2 changes: 1 addition & 1 deletion examples/rl/rllib/model/README.md
@@ -1,3 +1,3 @@
## Model Binaries

The binaries located in this directory are the components of a trained rllib model. These are related to the `examples/rl/rllib/rllib.py` example script. Results from `examples/rl/rllib/rllib.py` are loaded and written to this directory.
The binaries located in this directory are the components of a trained RLlib model. They relate to the `examples/rl/rllib/pg_pbt_example.py` example script, which loads results from and writes results to this directory.
16 changes: 7 additions & 9 deletions examples/rl/rllib/pg_example.py
@@ -81,15 +81,13 @@ def main(
rollout_fragment_length,
train_batch_size,
seed,
num_samples,
num_agents,
num_workers,
resume_training,
result_dir,
checkpoint_freq: int,
checkpoint_num: Optional[int],
log_level: Literal["DEBUG", "INFO", "WARN", "ERROR"],
save_model_path,
):
agent_values = {
"agent_specs": {
@@ -162,12 +160,14 @@ def get_checkpoint_dir(num):


if __name__ == "__main__":
default_save_model_path = str(
Path(__file__).expanduser().resolve().parent / "pg_model"
)
default_result_dir = str(Path(__file__).resolve().parent / "results" / "pg_results")
parser = gen_parser("rllib-example", default_result_dir, default_save_model_path)

parser = gen_parser("rllib-example", default_result_dir)
parser.add_argument(
"--checkpoint_num",
type=int,
default=None,
help="The checkpoint number to restart from.",
)
args = parser.parse_args()
build_scenario(scenario=args.scenario, clean=False, seed=42)

@@ -178,13 +178,11 @@ def get_checkpoint_dir(num):
rollout_fragment_length=args.rollout_fragment_length,
train_batch_size=args.train_batch_size,
seed=args.seed,
num_samples=args.num_samples,
num_agents=args.num_agents,
num_workers=args.num_workers,
resume_training=args.resume_training,
result_dir=args.result_dir,
checkpoint_freq=max(args.checkpoint_freq, 1),
checkpoint_num=args.checkpoint_num,
log_level=args.log_level,
save_model_path=args.save_model_path,
)
22 changes: 15 additions & 7 deletions examples/rl/rllib/rllib.py → examples/rl/rllib/pg_pbt_example.py
@@ -124,7 +124,6 @@ def main(
resume_training,
result_dir,
checkpoint_freq: int,
checkpoint_num: Optional[int],
log_level: Literal["DEBUG", "INFO", "WARN", "ERROR"],
save_model_path,
):
@@ -240,12 +239,22 @@


if __name__ == "__main__":
default_save_model_path = str(
Path(__file__).expanduser().resolve().parent / "model"
default_result_dir = str(
Path(__file__).resolve().parent / "results" / "tune_pg_results"
)
parser = gen_parser("rllib-example", default_result_dir)
parser.add_argument(
"--num_samples",
type=int,
default=1,
help="Number of times to sample from hyperparameter space",
)
parser.add_argument(
"--save_model_path",
type=str,
default=str(Path(__file__).expanduser().resolve().parent / "model"),
help="Destination path of where to copy the model when training is over",
)
default_result_dir = str(Path(__file__).resolve().parent / "results" / "tune_pg_results")
parser = gen_parser("rllib-example", default_result_dir, default_save_model_path)

args = parser.parse_args()
build_scenario(scenario=args.scenario, clean=False, seed=42)

@@ -262,7 +271,6 @@ def main(
resume_training=args.resume_training,
result_dir=args.result_dir,
checkpoint_freq=max(args.checkpoint_freq, 1),
checkpoint_num=args.checkpoint_num,
log_level=args.log_level,
save_model_path=args.save_model_path,
)
7 changes: 2 additions & 5 deletions examples/tests/test_examples.py
@@ -51,22 +51,20 @@ def test_rllib_pg_example():
rollout_fragment_length=200,
train_batch_size=200,
seed=42,
num_samples=1,
num_agents=2,
num_workers=1,
resume_training=False,
result_dir=result_dir,
checkpoint_num=None,
checkpoint_freq=1,
save_model_path=model_dir,
log_level="WARN",
)


def test_rllib_tune_pg_example():
from examples.rl.rllib import rllib
from examples.rl.rllib import pg_pbt_example

main = rllib.main
main = pg_pbt_example.main
with tempfile.TemporaryDirectory() as result_dir, tempfile.TemporaryDirectory() as model_dir:
main(
scenario="scenarios/sumo/loop",
@@ -80,7 +78,6 @@ def test_rllib_tune_pg_example():
num_workers=1,
resume_training=False,
result_dir=result_dir,
checkpoint_num=None,
checkpoint_freq=1,
save_model_path=model_dir,
log_level="WARN",
