diff --git a/examples/history_vehicles_replacement_for_imitation_learning.py b/examples/history_vehicles_replacement_for_imitation_learning.py
index 55dea60f58..cfae2277ef 100644
--- a/examples/history_vehicles_replacement_for_imitation_learning.py
+++ b/examples/history_vehicles_replacement_for_imitation_learning.py
@@ -20,6 +20,12 @@ def act(self, obs):
 
 def main(scenarios, headless, seed):
     scenarios_iterator = Scenario.scenario_variations(scenarios, [])
+    smarts = SMARTS(
+        agent_interfaces={},
+        traffic_sim=SumoTrafficSimulation(headless=True, auto_start=True),
+        envision=Envision(),
+    )
+
     for _ in scenarios:
         scenario = next(scenarios_iterator)
         agent_missions = scenario.discover_missions_of_traffic_histories()
@@ -33,14 +39,10 @@ def main(scenarios, headless, seed):
                 ),
                 agent_builder=KeepLaneAgent,
             )
 
-            agent = agent_spec.build_agent()
-            smarts = SMARTS(
-                agent_interfaces={agent_id: agent_spec.interface},
-                traffic_sim=SumoTrafficSimulation(headless=True, auto_start=True),
-                envision=Envision(),
-            )
+            smarts.switch_ego_agent({agent_id: agent_spec.interface})
+
             observations = smarts.reset(scenario)
 
             dones = {agent_id: False}
 
@@ -52,7 +54,7 @@ def main(scenarios, headless, seed):
                     {agent_id: agent_action}
                 )
 
-            smarts.destroy()
+    smarts.destroy()
 
 
 if __name__ == "__main__":
diff --git a/examples/rllib.py b/examples/rllib.py
index fe24592304..541a763a25 100644
--- a/examples/rllib.py
+++ b/examples/rllib.py
@@ -1,12 +1,21 @@
 import argparse
+from datetime import timedelta
 import logging
 import multiprocessing
+from os import stat
 import random
 from pathlib import Path
+from typing import Dict
 
 import numpy as np
 from ray import tune
+from ray.rllib.env.base_env import BaseEnv
+from ray.rllib.evaluation.episode import MultiAgentEpisode
+from ray.rllib.evaluation.rollout_worker import RolloutWorker
+from ray.rllib.policy.policy import Policy
+from ray.rllib.utils.typing import PolicyID
 from ray.tune.schedulers import PopulationBasedTraining
+from ray.rllib.agents.callbacks import DefaultCallbacks
 
 import smarts
 from smarts.core.utils.file import copy_tree
@@ -18,28 +27,50 @@
 
 
 # Add custom metrics to your tensorboard using these callbacks
-# see: https://ray.readthedocs.io/en/latest/rllib-training.html#callbacks-and-custom-metrics
-def on_episode_start(info):
-    episode = info["episode"]
-    episode.user_data["ego_speed"] = []
-
-
-def on_episode_step(info):
-    episode = info["episode"]
-    single_agent_id = list(episode._agent_to_last_obs)[0]
-    obs = episode.last_raw_obs_for(single_agent_id)
-    episode.user_data["ego_speed"].append(obs["speed"])
-
-
-def on_episode_end(info):
-    episode = info["episode"]
-    mean_ego_speed = np.mean(episode.user_data["ego_speed"])
-    print(
-        f"ep. {episode.episode_id:<12} ended;"
-        f" length={episode.length:<6}"
-        f" mean_ego_speed={mean_ego_speed:.2f}"
-    )
-    episode.custom_metrics["mean_ego_speed"] = mean_ego_speed
+# See: https://ray.readthedocs.io/en/latest/rllib-training.html#callbacks-and-custom-metrics
+class Callbacks(DefaultCallbacks):
+    @staticmethod
+    def on_episode_start(
+        worker: RolloutWorker,
+        base_env: BaseEnv,
+        policies: Dict[PolicyID, Policy],
+        episode: MultiAgentEpisode,
+        env_index: int,
+        **kwargs,
+    ):
+        episode.user_data["ego_speed"] = []
+
+    @staticmethod
+    def on_episode_step(
+        worker: RolloutWorker,
+        base_env: BaseEnv,
+        episode: MultiAgentEpisode,
+        env_index: int,
+        **kwargs,
+    ):
+        single_agent_id = list(episode._agent_to_last_obs)[0]
+        obs = episode.last_raw_obs_for(single_agent_id)
+        episode.user_data["ego_speed"].append(obs["speed"])
+
+    @staticmethod
+    def on_episode_end(
+        worker: RolloutWorker,
+        base_env: BaseEnv,
+        policies: Dict[PolicyID, Policy],
+        episode: MultiAgentEpisode,
+        env_index: int,
+        **kwargs,
+    ):
+
+        mean_ego_speed = np.mean(episode.user_data["ego_speed"])
+        print(
+            f"ep. {episode.episode_id:<12} ended;"
+            f" length={episode.length:<6}"
+            f" mean_ego_speed={mean_ego_speed:.2f}"
+        )
+        episode.custom_metrics["mean_ego_speed"] = mean_ego_speed
 
 
 def explore(config):
@@ -53,6 +84,8 @@ def main(
     scenario,
     headless,
     time_total_s,
+    rollout_fragment_length,
+    train_batch_size,
     seed,
     num_samples,
     num_agents,
@@ -62,6 +95,10 @@ def main(
     checkpoint_num,
     save_model_path,
 ):
+    assert train_batch_size > 0, f"{train_batch_size.__name__} cannot be less than 1."
+    if rollout_fragment_length > train_batch_size:
+        rollout_fragment_length = train_batch_size
+
     pbt = PopulationBasedTraining(
         time_attr="time_total_s",
         metric="episode_reward_mean",
@@ -69,10 +106,11 @@ def main(
         perturbation_interval=300,
         resample_probability=0.25,
         # Specifies the mutations of these hyperparams
+        # See: `ray.rllib.agents.trainer.COMMON_CONFIG` for common hyperparams
         hyperparam_mutations={
             "lr": [1e-3, 5e-4, 1e-4, 5e-5, 1e-5],
-            "rollout_fragment_length": lambda: random.randint(128, 16384),
-            "train_batch_size": lambda: random.randint(2000, 160000),
+            "rollout_fragment_length": lambda: rollout_fragment_length,
+            "train_batch_size": lambda: train_batch_size,
         },
         # Specifies additional mutations after hyperparam_mutations is applied
         custom_explore_fn=explore,
@@ -104,11 +142,7 @@ def main(
             },
         },
         "multiagent": {"policies": rllib_policies},
-        "callbacks": {
-            "on_episode_start": on_episode_start,
-            "on_episode_step": on_episode_step,
-            "on_episode_end": on_episode_end,
-        },
+        "callbacks": Callbacks,
     }
 
     experiment_name = "rllib_example_multi"
@@ -139,7 +173,7 @@ def main(
 
     print(analysis.dataframe().head())
 
-    best_logdir = Path(analysis.get_best_logdir("episode_reward_max"))
+    best_logdir = Path(analysis.get_best_logdir("episode_reward_max", mode="max"))
     model_path = best_logdir / "model"
     copy_tree(str(model_path), save_model_path, overwrite=True)
 
@@ -165,11 +199,23 @@ def main(
         default=1,
         help="Number of times to sample from hyperparameter space",
     )
+    parser.add_argument(
+        "--rollout_fragment_length",
+        type=int,
+        default=200,
+        help="Episodes are divided into fragments of this many steps for each rollout. In this example this will be ensured to be `1= -