diff --git a/examples/sensitivity_analysis/action_sensitivity_analysis.py b/examples/sensitivity_analysis/action_sensitivity_analysis.py index d979cee65..c7cc4f433 100644 --- a/examples/sensitivity_analysis/action_sensitivity_analysis.py +++ b/examples/sensitivity_analysis/action_sensitivity_analysis.py @@ -26,7 +26,6 @@ """ import random from concurrent.futures import ThreadPoolExecutor -from multiprocessing import cpu_count from pathlib import Path from typing import List, Optional @@ -36,10 +35,10 @@ SensitivityAnalysisResult, run_sensitivity_analysis, ) - from compiler_gym.envs import CompilerEnv from compiler_gym.util.flags.benchmark_from_flags import benchmark_from_flags from compiler_gym.util.flags.env_from_flags import env_from_flags +from compiler_gym.util.flags import nproc # noqa from compiler_gym.util.gym_type_hints import ActionType from compiler_gym.util.logs import create_logging_dir from compiler_gym.util.timer import Timer @@ -130,7 +129,7 @@ def run_action_sensitivity_analysis( reward_space: str, num_trials: int, max_warmup_steps: int, - nproc: int = cpu_count(), + nproc: int, max_attempts_multiplier: int = 5, ): """Estimate the immediate reward of a given list of actions.""" @@ -181,7 +180,7 @@ def main(argv): rewards_path=rewards_path, runtimes_path=runtimes_path, actions=actions, - reward=FLAGS.reward, + reward_space=FLAGS.reward, num_trials=FLAGS.num_action_sensitivity_trials, max_warmup_steps=FLAGS.max_warmup_steps, nproc=FLAGS.nproc,