Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an option for serial run of benchmark. #1831

Merged
merged 2 commits into from
Feb 5, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Copy and pasting the git commit messages is __NOT__ enough.
## [Unreleased]
### Added
- Added action formatting option to `hiway-v0`.
- Introduced `debug: serial: bool` option to driving smarts benchmark config.
### Changed
- Moved action and observation conversions from `smarts.env.gymnasium.utils` to `smarts.env.utils`.
### Deprecated
Expand Down
4 changes: 3 additions & 1 deletion smarts/benchmark/driving_smarts/v0/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ benchmark:
- https://codalab.lisn.upsaclay.fr/competitions/6618
- https://smarts-project.github.io/archive/2022_nips_driving_smarts/
eval_episodes: 50
debug:
serial: False
shared_env_kwargs:
seed: 42
headless: true
Expand Down Expand Up @@ -41,4 +43,4 @@ benchmark:
# kwargs:
# scenario_dirs:
# - "./scenarios/naturalistic/waymo"
# - "./scenarios/naturalistic/ngsim"
# - "./scenarios/naturalistic/ngsim"
28 changes: 25 additions & 3 deletions smarts/benchmark/entrypoints/benchmark_runner_v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
# THE SOFTWARE.
import logging
import os
from typing import List, Tuple
from typing import Generator, List, Tuple

import gymnasium as gym
import psutil
Expand All @@ -39,6 +39,10 @@

@ray.remote(num_returns=1)
def _eval_worker(name, env_config, episodes, agent_config, error_tolerant=False):
    """Ray remote entry point: delegate one benchmark evaluation to the
    in-process implementation ``_eval_worker_local`` and return its result.

    Args:
        name: Identifier of the environment/task being evaluated.
        env_config: Environment construction configuration.
        episodes: Number of evaluation episodes to run.
        agent_config: Agent construction configuration.
        error_tolerant: When True, evaluation errors are tolerated rather
            than propagated.
    """
    return _eval_worker_local(
        name=name,
        env_config=env_config,
        episodes=episodes,
        agent_config=agent_config,
        error_tolerant=error_tolerant,
    )


def _eval_worker_local(name, env_config, episodes, agent_config, error_tolerant=False):
import warnings

warnings.filterwarnings("ignore")
Expand Down Expand Up @@ -81,7 +85,7 @@ def _eval_worker(name, env_config, episodes, agent_config, error_tolerant=False)
return name, score


def _task_iterator(env_args, benchmark_args, agent_args, log_workers):
def _parallel_task_iterator(env_args, benchmark_args, agent_args, log_workers):
num_cpus = max(1, min(len(os.sched_getaffinity(0)), psutil.cpu_count(False) or 4))

with suppress_output(stdout=True):
Expand Down Expand Up @@ -110,6 +114,19 @@ def _task_iterator(env_args, benchmark_args, agent_args, log_workers):
ray.shutdown()


def _serial_task_iterator(env_args, benchmark_args, agent_args, *args, **_):
    """Evaluate each environment one after another in the current process.

    Mirrors the interface of ``_parallel_task_iterator`` (extra positional
    and keyword arguments, e.g. ``log_workers``, are accepted and ignored)
    but runs the evaluations serially, which is useful for debugging.

    Yields:
        ``(name, score)`` tuples, one per environment in ``env_args``.
    """
    episodes = benchmark_args["eval_episodes"]
    for env_name, env_config in env_args.items():
        print(f"Evaluating {env_name}...")
        # _eval_worker_local returns the (name, score) pair directly.
        yield _eval_worker_local(
            name=env_name,
            env_config=env_config,
            episodes=episodes,
            agent_config=agent_args,
            error_tolerant=ERROR_TOLERANT,
        )


def benchmark(benchmark_args, agent_args, log_workers=False):
"""Runs the benchmark using the following:
Args:
Expand All @@ -118,6 +135,7 @@ def benchmark(benchmark_args, agent_args, log_workers=False):
debug_log(bool): Whether the benchmark should log to stdout.
"""
print(f"Starting `{benchmark_args['name']}` benchmark.")
debug = benchmark_args.get("debug", {})
message = benchmark_args.get("message")
if message is not None:
print(message)
Expand All @@ -133,7 +151,11 @@ def benchmark(benchmark_args, agent_args, log_workers=False):
)
named_scores = []

for name, score in _task_iterator(
iterator: Generator[tuple, None, None] = (
_serial_task_iterator if debug.get("serial") else _parallel_task_iterator
)

for name, score in iterator(
env_args=env_args,
benchmark_args=benchmark_args,
agent_args=agent_args,
Expand Down