diff --git a/docs/ecosystem/rllib.rst b/docs/ecosystem/rllib.rst
index 2a49341ae8..8e4b031f1c 100644
--- a/docs/ecosystem/rllib.rst
+++ b/docs/ecosystem/rllib.rst
@@ -28,7 +28,7 @@ many docs about ``Ray`` and ``RLlib``. We recommend to read the following pages
 Resume training
 ---------------
 
-With respect to ``SMARTS/examples/rl/rllib`` example, if you want to continue an aborted experiment, you can set ``resume=True`` in ``tune.run``. But note that ``resume=True`` will continue to use the same configuration as was set in the original experiment.
+With respect to ``SMARTS/examples/rl/rllib`` examples, if you want to continue an aborted experiment, you can set ``resume_training=True``. But note that ``resume_training=True`` will continue to use the same configuration as was set in the original experiment.
 To make changes to a started experiment, you can edit the latest experiment file in ``./results``.
-Or if you want to start a new experiment but train from an existing checkpoint, you can set ``restore=checkpoint_path`` in ``tune.run``.
+Or if you want to start a new experiment but train from an existing checkpoint, you will need to look into `How to Save and Load Trial Checkpoints `_.
diff --git a/examples/rl/rllib/configs.py b/examples/rl/rllib/configs.py
new file mode 100644
index 0000000000..841b8f1cdc
--- /dev/null
+++ b/examples/rl/rllib/configs.py
@@ -0,0 +1,149 @@
+import argparse
+import multiprocessing
+from pathlib import Path
+from typing import Any, Dict, Literal, Optional, Union
+
+try:
+    from ray.rllib.algorithms.algorithm import AlgorithmConfig
+    from ray.rllib.algorithms.callbacks import DefaultCallbacks
+    from ray.rllib.algorithms.pg import PGConfig
+    from ray.tune.search.sample import Integer as IntegerDomain
+except Exception as e:
+    from smarts.core.utils.custom_exceptions import RayException
+
+    raise RayException.required_to("rllib.py")
+
+
+def gen_pg_config(
+    scenario,
+    envision,
+    rollout_fragment_length,
+    train_batch_size,
+    num_workers,
+    log_level: Literal["DEBUG", "INFO", "WARN", "ERROR"],
+    seed: Union[int, IntegerDomain],
+    rllib_policies: Dict[str, Any],
+    agent_specs: Dict[str, Any],
+    callbacks: Optional[DefaultCallbacks],
+) -> AlgorithmConfig:
+    assert len(set(rllib_policies.keys()).difference(agent_specs)) == 0
+    algo_config = (
+        PGConfig()
+        .environment(
+            env="rllib_hiway-v0",
+            env_config={
+                "seed": seed,
+                "scenarios": [str(Path(scenario).expanduser().resolve().absolute())],
+                "headless": not envision,
+                "agent_specs": agent_specs,
+                "observation_options": "multi_agent",
+            },
+            disable_env_checking=True,
+        )
+        .framework(framework="tf2", eager_tracing=True)
+        .rollouts(
+            rollout_fragment_length=rollout_fragment_length,
+            num_rollout_workers=num_workers,
+            num_envs_per_worker=1,
+            enable_tf1_exec_eagerly=True,
+        )
+        .training(
+            lr_schedule=[(0, 1e-3), (1e3, 5e-4), (1e5, 1e-4), (1e7, 5e-5), (1e8, 1e-5)],
+            train_batch_size=train_batch_size,
+        )
+        .multi_agent(
+            policies=rllib_policies,
+            policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: f"{agent_id}",
+        )
+        .callbacks(callbacks_class=callbacks)
+        .debugging(log_level=log_level)
+    )
+    return algo_config
+
+
+def gen_parser(
+    prog: str, default_result_dir: str, default_save_model_path: str
+) -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(prog)
+    parser.add_argument(
+        "--scenario",
+        type=str,
+        default=str(Path(__file__).resolve().parents[3] / "scenarios/sumo/loop"),
+        help="Scenario to run (see scenarios/ for some samples you can use)",
+    )
+    parser.add_argument(
+        "--envision",
+        action="store_true",
+        help="Run simulation with Envision display.",
+    )
+    parser.add_argument(
+        "--num_samples",
+        type=int,
+        default=1,
+        help="Number of times to sample from hyperparameter space",
+    )
+    parser.add_argument(
+        "--rollout_fragment_length",
+        type=str,
+        default="auto",
+        help="Episodes are divided into fragments of this many steps for each rollout. In this example this will be ensured to be `1=
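
For reference, here is a minimal usage sketch of the gen_pg_config helper introduced by this patch. It is not part of the diff: the scenario path, agent id, observation/action spaces, and hyperparameter values below are placeholder assumptions, whereas the real example derives them from its argument parser and SMARTS agent interfaces.

    # Sketch only (assumes SMARTS and Ray RLlib are installed).
    import gymnasium as gym
    import numpy as np
    from ray.rllib.algorithms.callbacks import DefaultCallbacks

    from configs import gen_pg_config  # module added by this patch
    from smarts.core.agent_interface import AgentInterface, AgentType
    from smarts.zoo.agent_spec import AgentSpec

    agent_id = "Agent-0"  # illustrative agent id
    agent_specs = {
        agent_id: AgentSpec(
            interface=AgentInterface.from_type(AgentType.Laner, max_episode_steps=500)
        )
    }
    # Placeholder observation/action spaces for the single policy (assumption).
    obs_space = gym.spaces.Box(low=-1e10, high=1e10, shape=(3,), dtype=np.float32)
    act_space = gym.spaces.Discrete(4)
    rllib_policies = {agent_id: (None, obs_space, act_space, {})}

    algo_config = gen_pg_config(
        scenario="scenarios/sumo/loop",  # placeholder scenario path
        envision=False,
        rollout_fragment_length="auto",
        train_batch_size=2000,
        num_workers=2,
        log_level="WARN",
        seed=42,
        rllib_policies=rllib_policies,
        agent_specs=agent_specs,
        callbacks=DefaultCallbacks,
    )
    algo = algo_config.build()  # AlgorithmConfig.build() constructs the PG algorithm
    print(algo.train())         # run one training iteration

The returned AlgorithmConfig can also be handed to Ray Tune instead of being built directly, which is how the accompanying example script drives hyperparameter sampling.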