Showing 8 changed files with 486 additions and 248 deletions.
@@ -0,0 +1,149 @@
import argparse
import multiprocessing
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Union

try:
    from ray.rllib.algorithms.algorithm import AlgorithmConfig
    from ray.rllib.algorithms.callbacks import DefaultCallbacks
    from ray.rllib.algorithms.pg import PGConfig
    from ray.tune.search.sample import Integer as IntegerDomain
except Exception as e:
    from smarts.core.utils.custom_exceptions import RayException

    raise RayException.required_to("rllib.py") from e


def gen_pg_config(
    scenario,
    envision,
    rollout_fragment_length,
    train_batch_size,
    num_workers,
    log_level: Literal["DEBUG", "INFO", "WARN", "ERROR"],
    seed: Union[int, IntegerDomain],
    rllib_policies: Dict[str, Any],
    agent_specs: Dict[str, Any],
    callbacks: Optional[DefaultCallbacks],
) -> AlgorithmConfig:
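    """Generate an RLlib Policy Gradient (PG) `AlgorithmConfig` for the
    `rllib_hiway-v0` environment, wiring in the scenario, per-agent policies,
    agent specs, and callbacks. `seed` may be a plain int or a Ray Tune
    integer search domain.
    """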
    assert (
        len(set(rllib_policies.keys()).difference(agent_specs)) == 0
    ), "Every RLlib policy must have a matching agent spec."
    algo_config = (
        PGConfig()
        .environment(
            env="rllib_hiway-v0",
            env_config={
                "seed": seed,
                "scenarios": [str(Path(scenario).expanduser().resolve().absolute())],
                "headless": not envision,
                "agent_specs": agent_specs,
                "observation_options": "multi_agent",
            },
            disable_env_checking=True,
        )
        .framework(framework="tf2", eager_tracing=True)
        .rollouts(
            rollout_fragment_length=rollout_fragment_length,
            num_rollout_workers=num_workers,
            num_envs_per_worker=1,
            enable_tf1_exec_eagerly=True,
        )
        .training(
            lr_schedule=[(0, 1e-3), (1e3, 5e-4), (1e5, 1e-4), (1e7, 5e-5), (1e8, 1e-5)],
            train_batch_size=train_batch_size,
        )
        .multi_agent(
            policies=rllib_policies,
            policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: f"{agent_id}",
        )
        .callbacks(callbacks_class=callbacks)
        .debugging(log_level=log_level)
    )
    return algo_config
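
# Note: `seed` may be either a plain int or a Ray Tune integer search domain
# (hence the `IntegerDomain` import). An illustrative call, assuming Ray Tune
# is available:
#
#     from ray import tune
#     config = gen_pg_config(..., seed=tune.randint(0, 10_000), ...)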


def gen_parser(
    prog: str, default_result_dir: str, default_save_model_path: str
) -> argparse.ArgumentParser:
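    """Build the command-line argument parser for this RLlib training example."""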
    parser = argparse.ArgumentParser(prog)
    parser.add_argument(
        "--scenario",
        type=str,
        default=str(Path(__file__).resolve().parents[3] / "scenarios/sumo/loop"),
        help="Scenario to run (see scenarios/ for some samples you can use)",
    )
    parser.add_argument(
        "--envision",
        action="store_true",
        help="Run simulation with Envision display.",
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=1,
        help="Number of times to sample from hyperparameter space",
    )
    parser.add_argument(
        "--rollout_fragment_length",
        type=str,
        default="auto",
        help=(
            "Episodes are divided into fragments of this many steps for each rollout. "
            "In this example it is kept within `1 <= rollout_fragment_length <= train_batch_size`."
        ),
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=2000,
        help="The training batch size. This value must be > 0.",
    )
    parser.add_argument(
        "--time_total_s",
        type=int,
        default=1 * 60 * 60,  # 1 hour
        help=(
            "Total time in seconds to run the simulation for. This is an approximate "
            "end time because it is only checked once per training batch."
        ),
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="The base random seed to use, intended to be mixed with --num_samples",
    )
    parser.add_argument(
        "--num_agents", type=int, default=2, help="Number of agents (one per policy)"
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=(multiprocessing.cpu_count() // 2 + 1),
        help="Number of rollout workers (defaults to roughly half of the system cores)",
    )
    parser.add_argument(
        "--resume_training",
        default=False,
        action="store_true",
        help=(
            "Resume a training run that errored or was cancelled with Ctrl+C. "
            "This does not extend an experiment that already ran to completion."
        ),
    )
    parser.add_argument(
        "--result_dir",
        type=str,
        default=default_result_dir,
        help="Directory containing results",
    )
    parser.add_argument(
        "--log_level",
        type=str,
        default="ERROR",
        help="Log level (DEBUG|INFO|WARN|ERROR)",
    )
    parser.add_argument(
        "--checkpoint_num", type=int, default=None, help="Checkpoint number"
    )
    parser.add_argument(
        "--checkpoint_freq", type=int, default=3, help="Checkpoint frequency"
    )

    parser.add_argument(
        "--save_model_path",
        type=str,
        default=default_save_model_path,
        help="Destination path to copy the model to when training is over",
    )
    return parser
@@ -0,0 +1,190 @@
from pathlib import Path
from pprint import pprint
from typing import Dict, Literal, Optional, Union

import numpy as np

try:
    from ray.rllib.algorithms.algorithm import Algorithm, AlgorithmConfig
    from ray.rllib.algorithms.callbacks import DefaultCallbacks
    from ray.rllib.env.base_env import BaseEnv
    from ray.rllib.evaluation.episode import Episode
    from ray.rllib.evaluation.episode_v2 import EpisodeV2
    from ray.rllib.evaluation.rollout_worker import RolloutWorker
    from ray.rllib.policy.policy import Policy
    from ray.rllib.utils.typing import PolicyID
except Exception as e:
    from smarts.core.utils.custom_exceptions import RayException

    raise RayException.required_to("rllib.py") from e

import smarts
from smarts.sstudio.scenario_construction import build_scenario

if __name__ == "__main__":
    from configs import gen_parser, gen_pg_config
    from rllib_agent import TrainingModel, rllib_agent
else:
    from .configs import gen_parser, gen_pg_config
    from .rllib_agent import TrainingModel, rllib_agent


# Add custom metrics to your tensorboard using these callbacks
# See: https://ray.readthedocs.io/en/latest/rllib-training.html#callbacks-and-custom-metrics
class Callbacks(DefaultCallbacks):
    @staticmethod
    def on_episode_start(
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[PolicyID, Policy],
        episode: Union[Episode, EpisodeV2],
        env_index: int,
        **kwargs,
    ):
        episode.user_data["ego_reward"] = []

    @staticmethod
    def on_episode_step(
        worker: RolloutWorker,
        base_env: BaseEnv,
        episode: Union[Episode, EpisodeV2],
        env_index: int,
        **kwargs,
    ):
        single_agent_id = list(episode.get_agents())[0]
        infos = episode._last_infos.get(single_agent_id)
        if infos is not None:
            episode.user_data["ego_reward"].append(infos["reward"])

    @staticmethod
    def on_episode_end(
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[PolicyID, Policy],
        episode: Union[Episode, EpisodeV2],
        env_index: int,
        **kwargs,
    ):
        mean_ego_reward = np.mean(episode.user_data["ego_reward"])
        print(
            f"ep. {episode.episode_id:<12} ended;"
            f" length={episode.length:<6}"
            f" mean_ego_reward={mean_ego_reward:.2f}"
        )
        episode.custom_metrics["mean_ego_reward"] = mean_ego_reward


def main(
    scenario,
    envision,
    time_total_s,
    rollout_fragment_length,
    train_batch_size,
    seed,
    num_samples,
    num_agents,
    num_workers,
    resume_training,
    result_dir,
    checkpoint_freq: int,
    checkpoint_num: Optional[int],
    log_level: Literal["DEBUG", "INFO", "WARN", "ERROR"],
    save_model_path,
):
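    """Train one PG policy per agent on the given scenario, checkpointing every
    `checkpoint_freq` iterations under `result_dir` until `time_total_s` elapses.
    """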
    agent_values = {
        "agent_specs": {
            f"AGENT-{i}": rllib_agent["agent_spec"] for i in range(num_agents)
        },
        "rllib_policies": {
            f"AGENT-{i}": (
                None,
                rllib_agent["observation_space"],
                rllib_agent["action_space"],
                {"model": {"custom_model": TrainingModel.NAME}},
            )
            for i in range(num_agents)
        },
    }
    rllib_policies = agent_values["rllib_policies"]
    agent_specs = agent_values["agent_specs"]

    smarts.core.seed(seed)
    algo_config: AlgorithmConfig = gen_pg_config(
        scenario=scenario,
        envision=envision,
        rollout_fragment_length=rollout_fragment_length,
        train_batch_size=train_batch_size,
        num_workers=num_workers,
        seed=seed,
        log_level=log_level,
        rllib_policies=rllib_policies,
        agent_specs=agent_specs,
        callbacks=Callbacks,
    )

    def get_checkpoint_dir(num):
        # `result_dir` arrives as a string from argparse, so convert it to a Path.
        checkpoint_dir = Path(result_dir) / f"checkpoint_{num}" / f"checkpoint-{num}"
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        return checkpoint_dir

    if resume_training:
        checkpoint = str(get_checkpoint_dir("latest"))
        if checkpoint_num:
            checkpoint = str(get_checkpoint_dir(checkpoint_num))
    else:
        checkpoint = None

    print(f"======= Checkpointing at {str(result_dir)} =======")

    algo = algo_config.build()
    if checkpoint is not None:
        Algorithm.load_checkpoint(algo, checkpoint=checkpoint)
    result = {}
    current_iteration = 0
    checkpoint_iteration = checkpoint_num or 0

    try:
        while result.get("time_total_s", 0) < time_total_s:
            result = algo.train()
            print(f"======== Iteration {result['training_iteration']} ========")
            # Show only the top level of the (deeply nested) result dict.
            pprint(result, depth=1)

            if current_iteration % checkpoint_freq == 0:
                checkpoint_dir = get_checkpoint_dir(checkpoint_iteration)
                print(f"======= Saving checkpoint {checkpoint_iteration} =======")
                algo.save_checkpoint(checkpoint_dir)
                checkpoint_iteration += 1
            current_iteration += 1
        algo.save_checkpoint(get_checkpoint_dir(checkpoint_iteration))
    finally:
        algo.save_checkpoint(get_checkpoint_dir("latest"))
        algo.stop()


if __name__ == "__main__":
    default_save_model_path = str(
        Path(__file__).expanduser().resolve().parent / "pg_model"
    )
    default_result_dir = str(Path(__file__).resolve().parent / "results" / "pg_results")
    parser = gen_parser("rllib-example", default_result_dir, default_save_model_path)

    args = parser.parse_args()
    build_scenario(scenario=args.scenario, clean=False, seed=42)

    main(
        scenario=args.scenario,
        envision=args.envision,
        time_total_s=args.time_total_s,
        rollout_fragment_length=args.rollout_fragment_length,
        train_batch_size=args.train_batch_size,
        seed=args.seed,
        num_samples=args.num_samples,
        num_agents=args.num_agents,
        num_workers=args.num_workers,
        resume_training=args.resume_training,
        result_dir=args.result_dir,
        checkpoint_freq=max(args.checkpoint_freq, 1),
        checkpoint_num=args.checkpoint_num,
        log_level=args.log_level,
        save_model_path=args.save_model_path,
    )
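
# Example invocation (illustrative; assumes this entry point is saved as
# `rllib.py` and run from its own directory):
#
#     python rllib.py --scenario <path-to-scenario> --time_total_s 3600 --checkpoint_freq 3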