Commit 02da98e

Fix tests.
Gamenot committed Jun 20, 2023
1 parent 56301b5 commit 02da98e
Showing 4 changed files with 75 additions and 88 deletions.
4 changes: 2 additions & 2 deletions docs/ecosystem/rllib.rst
@@ -10,9 +10,9 @@ deep learning frameworks.

SMARTS contains two examples using `Policy Gradients (PG) <https://docs.ray.io/en/latest/rllib-algorithms.html#policy-gradients-pg>`_.

1. rllib/pg_example.py
1. ``rllib/pg_example.py``
   This example shows the basics of using RLlib with SMARTS through :class:`~smarts.env.rllib_hiway_env.RLlibHiWayEnv`.
1. rllib/pg_pbt_example.py
1. ``rllib/pg_pbt_example.py``
   This example combines Policy Gradients with `Population Based Training (PBT) <https://docs.ray.io/en/latest/tune/api/doc/ray.tune.schedulers.PopulationBasedTraining.html>`_ scheduling.

Recommended reads
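For orientation, a minimal sketch (assuming ray[rllib] and SMARTS are installed; the scenario path and empty agent_specs are placeholders, not from this commit) of the configuration pattern both examples now share, with the RLlibHiWayEnv class passed to RLlib directly instead of a registered string id:

    from pathlib import Path

    from ray.rllib.algorithms.pg import PGConfig

    from smarts.env.rllib_hiway_env import RLlibHiWayEnv

    algo_config = (
        PGConfig()
        .environment(
            env=RLlibHiWayEnv,  # the class itself; no "rllib_hiway-v0" registration needed
            env_config={
                "seed": 42,
                "scenarios": [str(Path("scenarios/sumo/loop").resolve())],  # placeholder scenario
                "headless": True,  # set False to stream to Envision
                "agent_specs": {},  # placeholder; e.g. {"AGENT-0": rllib_agent["agent_spec"]}
                "observation_options": "multi_agent",
            },
        )
        .framework(framework="tf2", eager_tracing=True)
    )
    algo = algo_config.build()  # a PG Algorithm ready for algo.train()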
58 changes: 0 additions & 58 deletions examples/rl/rllib/configs.py
@@ -1,64 +1,6 @@
import argparse
import multiprocessing
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Union

try:
    from ray.rllib.algorithms.algorithm import AlgorithmConfig
    from ray.rllib.algorithms.callbacks import DefaultCallbacks
    from ray.rllib.algorithms.pg import PGConfig
    from ray.tune.search.sample import Integer as IntegerDomain
except Exception as e:
    from smarts.core.utils.custom_exceptions import RayException

    raise RayException.required_to("rllib.py")


def gen_pg_config(
    scenario,
    envision,
    rollout_fragment_length,
    train_batch_size,
    num_workers,
    log_level: Literal["DEBUG", "INFO", "WARN", "ERROR"],
    seed: Union[int, IntegerDomain],
    rllib_policies: Dict[str, Any],
    agent_specs: Dict[str, Any],
    callbacks: Optional[DefaultCallbacks],
) -> AlgorithmConfig:
    assert len(set(rllib_policies.keys()).difference(agent_specs)) == 0
    algo_config = (
        PGConfig()
        .environment(
            env="rllib_hiway-v0",
            env_config={
                "seed": seed,
                "scenarios": [str(Path(scenario).expanduser().resolve().absolute())],
                "headless": not envision,
                "agent_specs": agent_specs,
                "observation_options": "multi_agent",
            },
            disable_env_checking=True,
        )
        .framework(framework="tf2", eager_tracing=True)
        .rollouts(
            rollout_fragment_length=rollout_fragment_length,
            num_rollout_workers=num_workers,
            num_envs_per_worker=1,
            enable_tf1_exec_eagerly=True,
        )
        .training(
            lr_schedule=[(0, 1e-3), (1e3, 5e-4), (1e5, 1e-4), (1e7, 5e-5), (1e8, 1e-5)],
            train_batch_size=train_batch_size,
        )
        .multi_agent(
            policies=rllib_policies,
            policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: f"{agent_id}",
        )
        .callbacks(callbacks_class=callbacks)
        .debugging(log_level=log_level)
    )
    return algo_config


def gen_parser(prog: str, default_result_dir: str) -> argparse.ArgumentParser:
51 changes: 37 additions & 14 deletions examples/rl/rllib/pg_example.py
@@ -1,11 +1,13 @@
from pathlib import Path
from pprint import pprint as print
from typing import Dict, Literal, Optional, Union

import numpy as np

try:
    from ray.rllib.algorithms.algorithm import Algorithm, AlgorithmConfig
    from ray.rllib.algorithms.callbacks import DefaultCallbacks
    from ray.rllib.algorithms.pg import PGConfig
    from ray.rllib.env.base_env import BaseEnv
    from ray.rllib.evaluation.episode import Episode
    from ray.rllib.evaluation.episode_v2 import EpisodeV2
@@ -18,13 +20,14 @@
    raise RayException.required_to("rllib.py")

import smarts
from smarts.env.rllib_hiway_env import RLlibHiWayEnv
from smarts.sstudio.scenario_construction import build_scenario

if __name__ == "__main__":
    from configs import gen_parser, gen_pg_config
    from configs import gen_parser
    from rllib_agent import TrainingModel, rllib_agent
else:
    from .configs import gen_parser, gen_pg_config
    from .configs import gen_parser
    from .rllib_agent import TrainingModel, rllib_agent

# Add custom metrics to your tensorboard using these callbacks
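# A hypothetical illustration of that pattern (an editor's sketch, not part of
# this commit; the file's real Callbacks class is collapsed in this diff): a
# DefaultCallbacks subclass records a per-episode scalar in
# episode.custom_metrics, which RLlib aggregates and logs to TensorBoard.
class ExampleCallbacks(DefaultCallbacks):
    def on_episode_end(self, *, worker, base_env, policies, episode, **kwargs):
        # "episode_steps" is an illustrative metric name; episode.length is
        # the number of env steps taken in the finished episode.
        episode.custom_metrics["episode_steps"] = episode.length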
@@ -107,21 +110,41 @@ def main(
    agent_specs = agent_values["agent_specs"]

    smarts.core.seed(seed)
    algo_config: AlgorithmConfig = gen_pg_config(
        scenario=scenario,
        envision=envision,
        rollout_fragment_length=rollout_fragment_length,
        train_batch_size=train_batch_size,
        num_workers=num_workers,
        seed=seed,
        log_level=log_level,
        rllib_policies=rllib_policies,
        agent_specs=agent_specs,
        callbacks=Callbacks,
    assert len(set(rllib_policies.keys()).difference(agent_specs)) == 0
    algo_config: AlgorithmConfig = (
        PGConfig()
        .environment(
            env=RLlibHiWayEnv,
            env_config={
                "seed": seed,
                "scenarios": [str(Path(scenario).expanduser().resolve().absolute())],
                "headless": not envision,
                "agent_specs": agent_specs,
                "observation_options": "multi_agent",
            },
            disable_env_checking=True,
        )
        .framework(framework="tf2", eager_tracing=True)
        .rollouts(
            rollout_fragment_length=rollout_fragment_length,
            num_rollout_workers=num_workers,
            num_envs_per_worker=1,
            enable_tf1_exec_eagerly=True,
        )
        .training(
            lr_schedule=[(0, 1e-3), (1e3, 5e-4), (1e5, 1e-4), (1e7, 5e-5), (1e8, 1e-5)],
            train_batch_size=train_batch_size,
        )
        .multi_agent(
            policies=rllib_policies,
            policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: f"{agent_id}",
        )
        .callbacks(callbacks_class=Callbacks)
        .debugging(log_level=log_level)
    )

    def get_checkpoint_dir(num):
        checkpoint_dir = result_dir / f"checkpoint_{num}" / f"checkpoint-{num}"
        checkpoint_dir = Path(result_dir) / f"checkpoint_{num}" / f"checkpoint-{num}"
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        return checkpoint_dir

50 changes: 36 additions & 14 deletions examples/rl/rllib/pg_pbt_example.py
@@ -11,7 +11,9 @@
# whether ray[rllib] was installed by user and raises an Exception warning the user to install it if not so.
try:
    from ray import tune
    from ray.rllib.algorithms.callbacks import DefaultCallbacks, make_multi_callbacks
    from ray.rllib.algorithms.algorithm import AlgorithmConfig
    from ray.rllib.algorithms.callbacks import DefaultCallbacks
    from ray.rllib.algorithms.pg import PGConfig
    from ray.rllib.env.base_env import BaseEnv
    from ray.rllib.evaluation.episode import Episode
    from ray.rllib.evaluation.episode_v2 import EpisodeV2
@@ -32,10 +34,10 @@
from smarts.sstudio.scenario_construction import build_scenario

if __name__ == "__main__":
    from configs import gen_parser, gen_pg_config
    from configs import gen_parser
    from rllib_agent import TrainingModel, rllib_agent
else:
    from .configs import gen_parser, gen_pg_config
    from .configs import gen_parser
    from .rllib_agent import TrainingModel, rllib_agent

logging.basicConfig(level=logging.INFO)
@@ -152,17 +154,37 @@ def main(
    agent_specs = agent_values["agent_specs"]

    smarts.core.seed(seed)
    algo_config = gen_pg_config(
        scenario=scenario,
        envision=envision,
        rollout_fragment_length=rollout_fragment_length,
        train_batch_size=train_batch_size,
        num_workers=max(num_workers, 1),
        log_level=log_level,
        seed=seed,
        callbacks=make_multi_callbacks([Callbacks]),
        rllib_policies=rllib_policies,
        agent_specs=agent_specs,
    assert len(set(rllib_policies.keys()).difference(agent_specs)) == 0
    algo_config: AlgorithmConfig = (
        PGConfig()
        .environment(
            env=RLlibHiWayEnv,
            env_config={
                "seed": seed,
                "scenarios": [str(Path(scenario).expanduser().resolve().absolute())],
                "headless": not envision,
                "agent_specs": agent_specs,
                "observation_options": "multi_agent",
            },
            disable_env_checking=True,
        )
        .framework(framework="tf2", eager_tracing=True)
        .rollouts(
            rollout_fragment_length=rollout_fragment_length,
            num_rollout_workers=num_workers,
            num_envs_per_worker=1,
            enable_tf1_exec_eagerly=True,
        )
        .training(
            lr_schedule=[(0, 1e-3), (1e3, 5e-4), (1e5, 1e-4), (1e7, 5e-5), (1e8, 1e-5)],
            train_batch_size=train_batch_size,
        )
        .multi_agent(
            policies=rllib_policies,
            policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: f"{agent_id}",
        )
        .callbacks(callbacks_class=Callbacks)
        .debugging(log_level=log_level)
    )

    experiment_name = "rllib_example_multi"
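To ground the PBT half of this example, a minimal sketch (assuming ray[tune] is installed; the metric name, interval, and mutation values are illustrative, and algo_config is assumed to be the PGConfig built in the diff above) of handing the config to Ray Tune under a Population Based Training scheduler:

    from ray import tune
    from ray.tune.schedulers import PopulationBasedTraining

    # Perturb hyperparameters every few training iterations; "lr" mutates
    # among the listed candidate values.
    pbt = PopulationBasedTraining(
        time_attr="training_iteration",
        perturbation_interval=5,
        hyperparam_mutations={"lr": [1e-3, 5e-4, 1e-4]},
    )

    tuner = tune.Tuner(
        "PG",  # RLlib registers its algorithms as Tune trainables by name
        param_space=algo_config.to_dict(),
        tune_config=tune.TuneConfig(
            scheduler=pbt,
            num_samples=2,  # population size
            metric="episode_reward_mean",
            mode="max",
        ),
    )
    results = tuner.fit()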
