Commit: Update examples.
Gamenot committed Jun 19, 2023
1 parent 27e91c9 commit 9bf091a
Showing 8 changed files with 486 additions and 248 deletions.
4 changes: 2 additions & 2 deletions docs/ecosystem/rllib.rst
@@ -28,7 +28,7 @@ many docs about ``Ray`` and ``RLlib``. We recommend to read the following pages
Resume training
---------------

- With respect to ``SMARTS/examples/rl/rllib`` example, if you want to continue an aborted experiment, you can set ``resume=True`` in ``tune.run``. But note that ``resume=True`` will continue to use the same configuration as was set in the original experiment.
+ With respect to the ``SMARTS/examples/rl/rllib`` examples, if you want to continue an aborted experiment, you can set ``resume_training=True``. But note that ``resume_training=True`` will continue to use the same configuration as was set in the original experiment.
  To make changes to a started experiment, you can edit the latest experiment file in ``./results``.

- Or if you want to start a new experiment but train from an existing checkpoint, you can set ``restore=checkpoint_path`` in ``tune.run``.
+ Or if you want to start a new experiment but train from an existing checkpoint, see `How to Save and Load Trial Checkpoints <https://docs.ray.io/en/latest/tune/tutorials/tune-trial-checkpoints>`_.
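
Note: for the restore-from-checkpoint case, the linked Ray Tune guide revolves around ``ray.tune.Tuner.restore``. A minimal sketch, assuming Ray 2.x and a hypothetical experiment directory (not part of this commit):

    from ray import tune

    # Hypothetical path to a previously run Tune experiment directory.
    experiment_path = "./results/pg_results"

    # Rebuild the Tuner from the saved experiment state; `trainable` must match
    # what the original experiment ran, and errored trials are retried.
    tuner = tune.Tuner.restore(
        path=experiment_path,
        trainable="PG",
        resume_errored=True,
    )
    tuner.fit()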
149 changes: 149 additions & 0 deletions examples/rl/rllib/configs.py
@@ -0,0 +1,149 @@
import argparse
import multiprocessing
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Union

try:
    from ray.rllib.algorithms.algorithm import AlgorithmConfig
    from ray.rllib.algorithms.callbacks import DefaultCallbacks
    from ray.rllib.algorithms.pg import PGConfig
    from ray.tune.search.sample import Integer as IntegerDomain
except Exception as e:
    from smarts.core.utils.custom_exceptions import RayException

    raise RayException.required_to("rllib.py")


def gen_pg_config(
    scenario,
    envision,
    rollout_fragment_length,
    train_batch_size,
    num_workers,
    log_level: Literal["DEBUG", "INFO", "WARN", "ERROR"],
    seed: Union[int, IntegerDomain],
    rllib_policies: Dict[str, Any],
    agent_specs: Dict[str, Any],
    callbacks: Optional[DefaultCallbacks],
) -> AlgorithmConfig:
    """Generate the PG (policy gradient) algorithm configuration used by these examples."""
    assert len(set(rllib_policies.keys()).difference(agent_specs)) == 0
    algo_config = (
        PGConfig()
        .environment(
            env="rllib_hiway-v0",
            env_config={
                "seed": seed,
                "scenarios": [str(Path(scenario).expanduser().resolve().absolute())],
                "headless": not envision,
                "agent_specs": agent_specs,
                "observation_options": "multi_agent",
            },
            disable_env_checking=True,
        )
        .framework(framework="tf2", eager_tracing=True)
        .rollouts(
            rollout_fragment_length=rollout_fragment_length,
            num_rollout_workers=num_workers,
            num_envs_per_worker=1,
            enable_tf1_exec_eagerly=True,
        )
        .training(
            lr_schedule=[(0, 1e-3), (1e3, 5e-4), (1e5, 1e-4), (1e7, 5e-5), (1e8, 1e-5)],
            train_batch_size=train_batch_size,
        )
        .multi_agent(
            policies=rllib_policies,
            policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: f"{agent_id}",
        )
        .callbacks(callbacks_class=callbacks)
        .debugging(log_level=log_level)
    )
    return algo_config


def gen_parser(
    prog: str, default_result_dir: str, default_save_model_path: str
) -> argparse.ArgumentParser:
    """Generate the argument parser shared by these RLlib examples."""
    parser = argparse.ArgumentParser(prog)
    parser.add_argument(
        "--scenario",
        type=str,
        default=str(Path(__file__).resolve().parents[3] / "scenarios/sumo/loop"),
        help="Scenario to run (see scenarios/ for some samples you can use)",
    )
    parser.add_argument(
        "--envision",
        action="store_true",
        help="Run simulation with Envision display.",
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=1,
        help="Number of times to sample from hyperparameter space",
    )
    parser.add_argument(
        "--rollout_fragment_length",
        type=str,
        default="auto",
        help="Episodes are divided into fragments of this many steps for each rollout. In this example this is constrained to `1 <= rollout_fragment_length <= train_batch_size`",
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=2000,
        help="The training batch size. This value must be > 0.",
    )
    parser.add_argument(
        "--time_total_s",
        type=int,
        default=1 * 60 * 60,  # 1 hour
        help="Total time in seconds to run the simulation for. This is a rough end time as it will be checked per training batch.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="The base random seed to use, intended to be mixed with --num_samples",
    )
    parser.add_argument(
        "--num_agents", type=int, default=2, help="Number of agents (one per policy)"
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=(multiprocessing.cpu_count() // 2 + 1),
        help="Number of rollout workers (defaults to roughly half of the system's CPU cores)",
    )
    parser.add_argument(
        "--resume_training",
        default=False,
        action="store_true",
        help="Resume an errored or 'ctrl+c' cancelled training. This does not extend a fully run original experiment.",
    )
    parser.add_argument(
        "--result_dir",
        type=str,
        default=default_result_dir,
        help="Directory containing results",
    )
    parser.add_argument(
        "--log_level",
        type=str,
        default="ERROR",
        help="Log level (DEBUG|INFO|WARN|ERROR)",
    )
    parser.add_argument(
        "--checkpoint_num", type=int, default=None, help="Checkpoint number"
    )
    parser.add_argument(
        "--checkpoint_freq", type=int, default=3, help="Checkpoint frequency"
    )

    parser.add_argument(
        "--save_model_path",
        type=str,
        default=default_save_model_path,
        help="Destination path of where to copy the model when training is over",
    )
    return parser
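
Note: the ``--rollout_fragment_length`` help text promises ``1 <= rollout_fragment_length <= train_batch_size``, but that clamping is not shown in this file. A minimal sketch of how a caller might enforce it (hypothetical helper, not part of this commit):

    def clamp_rollout_fragment_length(rollout_fragment_length, train_batch_size):
        """Pass "auto" through unchanged; otherwise clamp into [1, train_batch_size]."""
        if rollout_fragment_length == "auto":
            return rollout_fragment_length
        return max(1, min(int(rollout_fragment_length), train_batch_size))

    # A value larger than the batch size is pulled back down; "auto" is untouched.
    assert clamp_rollout_fragment_length("5000", 2000) == 2000
    assert clamp_rollout_fragment_length("auto", 2000) == "auto"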
190 changes: 190 additions & 0 deletions examples/rl/rllib/pg_example.py
@@ -0,0 +1,190 @@
from pathlib import Path
from pprint import pprint
from typing import Dict, Literal, Optional, Union

import numpy as np

try:
    from ray.rllib.algorithms.algorithm import Algorithm, AlgorithmConfig
    from ray.rllib.algorithms.callbacks import DefaultCallbacks
    from ray.rllib.env.base_env import BaseEnv
    from ray.rllib.evaluation.episode import Episode
    from ray.rllib.evaluation.episode_v2 import EpisodeV2
    from ray.rllib.evaluation.rollout_worker import RolloutWorker
    from ray.rllib.policy.policy import Policy
    from ray.rllib.utils.typing import PolicyID
except Exception as e:
    from smarts.core.utils.custom_exceptions import RayException

    raise RayException.required_to("rllib.py")

import smarts
from smarts.sstudio.scenario_construction import build_scenario

if __name__ == "__main__":
    from configs import gen_parser, gen_pg_config
    from rllib_agent import TrainingModel, rllib_agent
else:
    from .configs import gen_parser, gen_pg_config
    from .rllib_agent import TrainingModel, rllib_agent


# Add custom metrics to your tensorboard using these callbacks
# See: https://ray.readthedocs.io/en/latest/rllib-training.html#callbacks-and-custom-metrics
class Callbacks(DefaultCallbacks):
    @staticmethod
    def on_episode_start(
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[PolicyID, Policy],
        episode: Union[Episode, EpisodeV2],
        env_index: int,
        **kwargs,
    ):
        episode.user_data["ego_reward"] = []

    @staticmethod
    def on_episode_step(
        worker: RolloutWorker,
        base_env: BaseEnv,
        episode: Union[Episode, EpisodeV2],
        env_index: int,
        **kwargs,
    ):
        single_agent_id = list(episode.get_agents())[0]
        infos = episode._last_infos.get(single_agent_id)
        if infos is not None:
            episode.user_data["ego_reward"].append(infos["reward"])

    @staticmethod
    def on_episode_end(
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[PolicyID, Policy],
        episode: Union[Episode, EpisodeV2],
        env_index: int,
        **kwargs,
    ):
        mean_ego_reward = np.mean(episode.user_data["ego_reward"])
        print(
            f"ep. {episode.episode_id:<12} ended;"
            f" length={episode.length:<6}"
            f" mean_ego_reward={mean_ego_reward:.2f}"
        )
        episode.custom_metrics["mean_ego_reward"] = mean_ego_reward


def main(
    scenario,
    envision,
    time_total_s,
    rollout_fragment_length,
    train_batch_size,
    seed,
    num_samples,
    num_agents,
    num_workers,
    resume_training,
    result_dir,
    checkpoint_freq: int,
    checkpoint_num: Optional[int],
    log_level: Literal["DEBUG", "INFO", "WARN", "ERROR"],
    save_model_path,
):
    agent_values = {
        "agent_specs": {
            f"AGENT-{i}": rllib_agent["agent_spec"] for i in range(num_agents)
        },
        "rllib_policies": {
            f"AGENT-{i}": (
                None,
                rllib_agent["observation_space"],
                rllib_agent["action_space"],
                {"model": {"custom_model": TrainingModel.NAME}},
            )
            for i in range(num_agents)
        },
    }
    rllib_policies = agent_values["rllib_policies"]
    agent_specs = agent_values["agent_specs"]

    smarts.core.seed(seed)
    algo_config: AlgorithmConfig = gen_pg_config(
        scenario=scenario,
        envision=envision,
        rollout_fragment_length=rollout_fragment_length,
        train_batch_size=train_batch_size,
        num_workers=num_workers,
        seed=seed,
        log_level=log_level,
        rllib_policies=rllib_policies,
        agent_specs=agent_specs,
        callbacks=Callbacks,
    )

    def get_checkpoint_dir(num):
        # Checkpoints are laid out as <result_dir>/checkpoint_<num>/checkpoint-<num>.
        checkpoint_dir = Path(result_dir) / f"checkpoint_{num}" / f"checkpoint-{num}"
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        return checkpoint_dir

    if resume_training:
        checkpoint = str(get_checkpoint_dir("latest"))
        if checkpoint_num:
            checkpoint = str(get_checkpoint_dir(checkpoint_num))
    else:
        checkpoint = None

    print(f"======= Checkpointing at {str(result_dir)} =======")

    algo = algo_config.build()
    if checkpoint is not None:
        Algorithm.load_checkpoint(algo, checkpoint=checkpoint)
    result = {}
    current_iteration = 0
    checkpoint_iteration = checkpoint_num or 0

    try:
        while result.get("time_total_s", 0) < time_total_s:
            result = algo.train()
            print(f"======== Iteration {result['training_iteration']} ========")
            # Print a shallow summary of the deeply nested result dict.
            pprint(result, depth=1)

            if current_iteration % checkpoint_freq == 0:
                checkpoint_dir = get_checkpoint_dir(checkpoint_iteration)
                print(f"======= Saving checkpoint {checkpoint_iteration} =======")
                algo.save_checkpoint(checkpoint_dir)
                checkpoint_iteration += 1
            current_iteration += 1
        algo.save_checkpoint(get_checkpoint_dir(checkpoint_iteration))
    finally:
        algo.save_checkpoint(get_checkpoint_dir("latest"))
        algo.stop()


if __name__ == "__main__":
    default_save_model_path = str(
        Path(__file__).expanduser().resolve().parent / "pg_model"
    )
    default_result_dir = str(Path(__file__).resolve().parent / "results" / "pg_results")
    parser = gen_parser("rllib-example", default_result_dir, default_save_model_path)

    args = parser.parse_args()
    build_scenario(scenario=args.scenario, clean=False, seed=42)

    main(
        scenario=args.scenario,
        envision=args.envision,
        time_total_s=args.time_total_s,
        rollout_fragment_length=args.rollout_fragment_length,
        train_batch_size=args.train_batch_size,
        seed=args.seed,
        num_samples=args.num_samples,
        num_agents=args.num_agents,
        num_workers=args.num_workers,
        resume_training=args.resume_training,
        result_dir=args.result_dir,
        checkpoint_freq=max(args.checkpoint_freq, 1),
        checkpoint_num=args.checkpoint_num,
        log_level=args.log_level,
        save_model_path=args.save_model_path,
    )
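
Note: once training has written checkpoints under ``--result_dir``, a trained policy can be queried outside of training. A minimal inference sketch, assuming Ray 2.x; the checkpoint path and observation below are placeholders, not part of this commit:

    from ray.rllib.algorithms.algorithm import Algorithm

    # Hypothetical path to a checkpoint written by pg_example.py.
    checkpoint_path = "results/pg_results/checkpoint_latest/checkpoint-latest"

    # Rebuild the algorithm (and its policies) from the saved checkpoint state.
    algo = Algorithm.from_checkpoint(checkpoint_path)

    # Query one agent's policy; `obs` must match that agent's observation space
    # defined in rllib_agent.
    # obs = ...
    # action = algo.compute_single_action(obs, policy_id="AGENT-0", explore=False)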