From 02da98ed6670d556f267a0b2987995fb7d9092dc Mon Sep 17 00:00:00 2001
From: Montgomery Alban
Date: Tue, 20 Jun 2023 15:16:09 +0000
Subject: [PATCH] Fix tests.

---
 docs/ecosystem/rllib.rst            |  4 +-
 examples/rl/rllib/configs.py        | 58 -----------------------------
 examples/rl/rllib/pg_example.py     | 51 ++++++++++++++++++-------
 examples/rl/rllib/pg_pbt_example.py | 50 ++++++++++++++++++-------
 4 files changed, 75 insertions(+), 88 deletions(-)

diff --git a/docs/ecosystem/rllib.rst b/docs/ecosystem/rllib.rst
index 7ae554cb99..0b09585d4a 100644
--- a/docs/ecosystem/rllib.rst
+++ b/docs/ecosystem/rllib.rst
@@ -10,9 +10,9 @@ deep learning frameworks.
 
 SMARTS contains two examples using `Policy Gradients (PG) `_.
 
-1. rllib/pg_example.py
+1. ``rllib/pg_example.py``
     This example shows the basics of using RLlib with SMARTS through :class:`~smarts.env.rllib_hiway_env.RLlibHiWayEnv`.
-1. rllib/pg_pbt_example.py
+1. ``rllib/pg_pbt_example.py``
     This example combines Policy Gradients with `Population Based Training (PBT) `_ scheduling.
 
 Recommended reads
diff --git a/examples/rl/rllib/configs.py b/examples/rl/rllib/configs.py
index 266e39d841..48ab2853de 100644
--- a/examples/rl/rllib/configs.py
+++ b/examples/rl/rllib/configs.py
@@ -1,64 +1,6 @@
 import argparse
 import multiprocessing
 from pathlib import Path
-from typing import Any, Dict, Literal, Optional, Union
-
-try:
-    from ray.rllib.algorithms.algorithm import AlgorithmConfig
-    from ray.rllib.algorithms.callbacks import DefaultCallbacks
-    from ray.rllib.algorithms.pg import PGConfig
-    from ray.tune.search.sample import Integer as IntegerDomain
-except Exception as e:
-    from smarts.core.utils.custom_exceptions import RayException
-
-    raise RayException.required_to("rllib.py")
-
-
-def gen_pg_config(
-    scenario,
-    envision,
-    rollout_fragment_length,
-    train_batch_size,
-    num_workers,
-    log_level: Literal["DEBUG", "INFO", "WARN", "ERROR"],
-    seed: Union[int, IntegerDomain],
-    rllib_policies: Dict[str, Any],
-    agent_specs: Dict[str, Any],
-    callbacks: Optional[DefaultCallbacks],
-) -> AlgorithmConfig:
-    assert len(set(rllib_policies.keys()).difference(agent_specs)) == 0
-    algo_config = (
-        PGConfig()
-        .environment(
-            env="rllib_hiway-v0",
-            env_config={
-                "seed": seed,
-                "scenarios": [str(Path(scenario).expanduser().resolve().absolute())],
-                "headless": not envision,
-                "agent_specs": agent_specs,
-                "observation_options": "multi_agent",
-            },
-            disable_env_checking=True,
-        )
-        .framework(framework="tf2", eager_tracing=True)
-        .rollouts(
-            rollout_fragment_length=rollout_fragment_length,
-            num_rollout_workers=num_workers,
-            num_envs_per_worker=1,
-            enable_tf1_exec_eagerly=True,
-        )
-        .training(
-            lr_schedule=[(0, 1e-3), (1e3, 5e-4), (1e5, 1e-4), (1e7, 5e-5), (1e8, 1e-5)],
-            train_batch_size=train_batch_size,
-        )
-        .multi_agent(
-            policies=rllib_policies,
-            policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: f"{agent_id}",
-        )
-        .callbacks(callbacks_class=callbacks)
-        .debugging(log_level=log_level)
-    )
-    return algo_config
 
 
 def gen_parser(prog: str, default_result_dir: str) -> argparse.ArgumentParser:
diff --git a/examples/rl/rllib/pg_example.py b/examples/rl/rllib/pg_example.py
index ec340bba05..3d82678028 100644
--- a/examples/rl/rllib/pg_example.py
+++ b/examples/rl/rllib/pg_example.py
@@ -1,4 +1,5 @@
 from pathlib import Path
+from pprint import pprint as print
 from typing import Dict, Literal, Optional, Union
 
 import numpy as np
@@ -6,6 +7,7 @@
 try:
     from ray.rllib.algorithms.algorithm import Algorithm, AlgorithmConfig
     from ray.rllib.algorithms.callbacks import DefaultCallbacks
+    from ray.rllib.algorithms.pg import PGConfig
     from ray.rllib.env.base_env import BaseEnv
     from ray.rllib.evaluation.episode import Episode
     from ray.rllib.evaluation.episode_v2 import EpisodeV2
@@ -18,13 +20,14 @@
     raise RayException.required_to("rllib.py")
 
 import smarts
+from smarts.env.rllib_hiway_env import RLlibHiWayEnv
 from smarts.sstudio.scenario_construction import build_scenario
 
 if __name__ == "__main__":
-    from configs import gen_parser, gen_pg_config
+    from configs import gen_parser
     from rllib_agent import TrainingModel, rllib_agent
 else:
-    from .configs import gen_parser, gen_pg_config
+    from .configs import gen_parser
     from .rllib_agent import TrainingModel, rllib_agent
 
 # Add custom metrics to your tensorboard using these callbacks
@@ -107,21 +110,41 @@ def main(
     agent_specs = agent_values["agent_specs"]
 
     smarts.core.seed(seed)
-    algo_config: AlgorithmConfig = gen_pg_config(
-        scenario=scenario,
-        envision=envision,
-        rollout_fragment_length=rollout_fragment_length,
-        train_batch_size=train_batch_size,
-        num_workers=num_workers,
-        seed=seed,
-        log_level=log_level,
-        rllib_policies=rllib_policies,
-        agent_specs=agent_specs,
-        callbacks=Callbacks,
+    assert len(set(rllib_policies.keys()).difference(agent_specs)) == 0
+    algo_config: AlgorithmConfig = (
+        PGConfig()
+        .environment(
+            env=RLlibHiWayEnv,
+            env_config={
+                "seed": seed,
+                "scenarios": [str(Path(scenario).expanduser().resolve().absolute())],
+                "headless": not envision,
+                "agent_specs": agent_specs,
+                "observation_options": "multi_agent",
+            },
+            disable_env_checking=True,
+        )
+        .framework(framework="tf2", eager_tracing=True)
+        .rollouts(
+            rollout_fragment_length=rollout_fragment_length,
+            num_rollout_workers=num_workers,
+            num_envs_per_worker=1,
+            enable_tf1_exec_eagerly=True,
+        )
+        .training(
+            lr_schedule=[(0, 1e-3), (1e3, 5e-4), (1e5, 1e-4), (1e7, 5e-5), (1e8, 1e-5)],
+            train_batch_size=train_batch_size,
+        )
+        .multi_agent(
+            policies=rllib_policies,
+            policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: f"{agent_id}",
+        )
+        .callbacks(callbacks_class=Callbacks)
+        .debugging(log_level=log_level)
     )
 
     def get_checkpoint_dir(num):
-        checkpoint_dir = result_dir / f"checkpoint_{num}" / f"checkpoint-{num}"
+        checkpoint_dir = Path(result_dir) / f"checkpoint_{num}" / f"checkpoint-{num}"
         checkpoint_dir.mkdir(parents=True, exist_ok=True)
         return checkpoint_dir
 
diff --git a/examples/rl/rllib/pg_pbt_example.py b/examples/rl/rllib/pg_pbt_example.py
index 4fa50c53d0..b9b49669e2 100644
--- a/examples/rl/rllib/pg_pbt_example.py
+++ b/examples/rl/rllib/pg_pbt_example.py
@@ -11,7 +11,9 @@
 # whether ray[rllib] was installed by user and raises an Exception warning the user to install it if not so.
 try:
     from ray import tune
-    from ray.rllib.algorithms.callbacks import DefaultCallbacks, make_multi_callbacks
+    from ray.rllib.algorithms.algorithm import AlgorithmConfig
+    from ray.rllib.algorithms.callbacks import DefaultCallbacks
+    from ray.rllib.algorithms.pg import PGConfig
     from ray.rllib.env.base_env import BaseEnv
     from ray.rllib.evaluation.episode import Episode
     from ray.rllib.evaluation.episode_v2 import EpisodeV2
@@ -32,10 +34,10 @@
 from smarts.sstudio.scenario_construction import build_scenario
 
 if __name__ == "__main__":
-    from configs import gen_parser, gen_pg_config
+    from configs import gen_parser
     from rllib_agent import TrainingModel, rllib_agent
 else:
-    from .configs import gen_parser, gen_pg_config
+    from .configs import gen_parser
     from .rllib_agent import TrainingModel, rllib_agent
 
 logging.basicConfig(level=logging.INFO)
@@ -152,17 +154,37 @@ def main(
     agent_specs = agent_values["agent_specs"]
 
     smarts.core.seed(seed)
-    algo_config = gen_pg_config(
-        scenario=scenario,
-        envision=envision,
-        rollout_fragment_length=rollout_fragment_length,
-        train_batch_size=train_batch_size,
-        num_workers=max(num_workers, 1),
-        log_level=log_level,
-        seed=seed,
-        callbacks=make_multi_callbacks([Callbacks]),
-        rllib_policies=rllib_policies,
-        agent_specs=agent_specs,
+    assert len(set(rllib_policies.keys()).difference(agent_specs)) == 0
+    algo_config: AlgorithmConfig = (
+        PGConfig()
+        .environment(
+            env=RLlibHiWayEnv,
+            env_config={
+                "seed": seed,
+                "scenarios": [str(Path(scenario).expanduser().resolve().absolute())],
+                "headless": not envision,
+                "agent_specs": agent_specs,
+                "observation_options": "multi_agent",
+            },
+            disable_env_checking=True,
+        )
+        .framework(framework="tf2", eager_tracing=True)
+        .rollouts(
+            rollout_fragment_length=rollout_fragment_length,
+            num_rollout_workers=num_workers,
+            num_envs_per_worker=1,
+            enable_tf1_exec_eagerly=True,
+        )
+        .training(
+            lr_schedule=[(0, 1e-3), (1e3, 5e-4), (1e5, 1e-4), (1e7, 5e-5), (1e8, 1e-5)],
+            train_batch_size=train_batch_size,
+        )
+        .multi_agent(
+            policies=rllib_policies,
+            policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: f"{agent_id}",
+        )
+        .callbacks(callbacks_class=Callbacks)
+        .debugging(log_level=log_level)
    )
 
     experiment_name = "rllib_example_multi"
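
Note on how such a config is consumed (illustrative only, not part of the patch): once an AlgorithmConfig like the one inlined above is built, RLlib turns it into an Algorithm whose train() method runs one training iteration at a time. The sketch below is a minimal stand-in for the examples' own training code, assuming ray[rllib] 2.x and an algo_config constructed as in pg_example.py; NUM_ITERATIONS and RESULT_DIR are hypothetical placeholders, not names from the patch.

# Illustrative driver loop (not from the patch): consume an AlgorithmConfig built
# as in pg_example.py. Assumes ray[rllib] 2.x; NUM_ITERATIONS and RESULT_DIR are
# hypothetical placeholders.
from pathlib import Path

from ray.rllib.algorithms.algorithm import Algorithm, AlgorithmConfig

NUM_ITERATIONS = 5  # placeholder: number of calls to Algorithm.train()
RESULT_DIR = Path("~/ray_results/pg_example").expanduser()  # placeholder output dir


def train(algo_config: AlgorithmConfig) -> None:
    """Build the algorithm from the config and run a short training loop."""
    algo: Algorithm = algo_config.build()
    try:
        for i in range(NUM_ITERATIONS):
            result = algo.train()  # one RLlib training iteration
            print(f"iteration {i}: episode_reward_mean={result.get('episode_reward_mean')}")
            # Save a checkpoint per iteration, mirroring the example's get_checkpoint_dir().
            checkpoint_dir = RESULT_DIR / f"checkpoint_{i}"
            checkpoint_dir.mkdir(parents=True, exist_ok=True)
            algo.save(str(checkpoint_dir))
    finally:
        algo.stop()  # release rollout workers and the underlying SMARTS simulation

Judging from the imports kept in the patch (from ray import tune and the PBT reference in the docs), pg_pbt_example.py hands the same config to Ray Tune with a Population Based Training scheduler instead of calling train() in a manual loop like the sketch above.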