diff --git a/CHANGELOG.md b/CHANGELOG.md
index b4c9f70ce9..ba8268db43 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,7 @@ Copy and pasting the git commit messages is __NOT__ enough.
 ## [Unreleased]
 ### Added
 - Added action formatting option to `hiway-v0`.
+- Introduced `debug: serial: bool` option to driving smarts benchmark config.
 ### Changed
 - Moved action and observation conversions from `smarts.env.gymnasium.utils` to `smarts.env.utils`.
 ### Deprecated
diff --git a/smarts/benchmark/driving_smarts/v0/config.yaml b/smarts/benchmark/driving_smarts/v0/config.yaml
index 7c715c6605..da20af56a2 100644
--- a/smarts/benchmark/driving_smarts/v0/config.yaml
+++ b/smarts/benchmark/driving_smarts/v0/config.yaml
@@ -12,6 +12,8 @@ benchmark:
     - https://codalab.lisn.upsaclay.fr/competitions/6618
     - https://smarts-project.github.io/archive/2022_nips_driving_smarts/
   eval_episodes: 50
+  debug:
+    serial: False
   shared_env_kwargs:
     seed: 42
     headless: true
@@ -41,4 +43,4 @@ benchmark:
   #   kwargs:
   #     scenario_dirs:
   #       - "./scenarios/naturalistic/waymo"
-  #       - "./scenarios/naturalistic/ngsim"
+  #       - "./scenarios/naturalistic/ngsim"
\ No newline at end of file
diff --git a/smarts/benchmark/entrypoints/benchmark_runner_v0.py b/smarts/benchmark/entrypoints/benchmark_runner_v0.py
index 3d1c81d599..4aaa8a7b0f 100644
--- a/smarts/benchmark/entrypoints/benchmark_runner_v0.py
+++ b/smarts/benchmark/entrypoints/benchmark_runner_v0.py
@@ -21,7 +21,7 @@
 # THE SOFTWARE.
 import logging
 import os
-from typing import List, Tuple
+from typing import Callable, Dict, Generator, List, Tuple
 
 import gymnasium as gym
 import psutil
@@ -39,6 +39,10 @@
 
 @ray.remote(num_returns=1)
 def _eval_worker(name, env_config, episodes, agent_config, error_tolerant=False):
+    return _eval_worker_local(name, env_config, episodes, agent_config, error_tolerant)
+
+
+def _eval_worker_local(name, env_config, episodes, agent_config, error_tolerant=False):
     import warnings
 
     warnings.filterwarnings("ignore")
@@ -81,7 +85,7 @@ def _eval_worker(name, env_config, episodes, agent_config, error_tolerant=False)
     return name, score
 
 
-def _task_iterator(env_args, benchmark_args, agent_args, log_workers):
+def _parallel_task_iterator(env_args, benchmark_args, agent_args, log_workers):
     num_cpus = max(1, min(len(os.sched_getaffinity(0)), psutil.cpu_count(False) or 4))
 
     with suppress_output(stdout=True):
@@ -110,6 +114,19 @@ def _task_iterator(env_args, benchmark_args, agent_args, log_workers):
         ray.shutdown()
 
 
+def _serial_task_iterator(env_args, benchmark_args, agent_args, *args, **_):
+    for name, env_config in env_args.items():
+        print(f"Evaluating {name}...")
+        name, score = _eval_worker_local(
+            name=name,
+            env_config=env_config,
+            episodes=benchmark_args["eval_episodes"],
+            agent_config=agent_args,
+            error_tolerant=ERROR_TOLERANT,
+        )
+        yield name, score
+
+
 def benchmark(benchmark_args, agent_args, log_workers=False):
     """Runs the benchmark using the following:
     Args:
@@ -118,6 +135,7 @@ def benchmark(benchmark_args, agent_args, log_workers=False):
         debug_log(bool): Whether the benchmark should log to stdout.
     """
     print(f"Starting `{benchmark_args['name']}` benchmark.")
+    debug = benchmark_args.get("debug", {})
     message = benchmark_args.get("message")
     if message is not None:
         print(message)
@@ -133,7 +151,9 @@ def benchmark(benchmark_args, agent_args, log_workers=False):
     )
 
     named_scores = []
-    for name, score in _task_iterator(
+    iterator = _serial_task_iterator if debug.get("serial") else _parallel_task_iterator
+
+    for name, score in iterator(
         env_args=env_args,
         benchmark_args=benchmark_args,
         agent_args=agent_args,