From 50790cc5935cb9eafe6ea7b68d6cecaedc6c317e Mon Sep 17 00:00:00 2001 From: Daniel Cohen Date: Wed, 22 Nov 2023 07:19:01 -0800 Subject: [PATCH] Allow batch trial to be constructed with a list of GRs (#1995) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/1995 This makes creating trials from `GS.gen_for_multiple_trial_from_multiple()` models easier in D51307866 and is closer to the way `BatchTrial`s actually work. Reviewed By: lena-kashtelyan Differential Revision: D51211147 fbshipit-source-id: eb57c2c77c06959bdeb3bbaa37bb0c98b8c85699 --- ax/core/batch_trial.py | 14 ++++++++++++- ax/core/experiment.py | 6 ++++++ ax/core/tests/test_experiment.py | 36 ++++++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+), 1 deletion(-) diff --git a/ax/core/batch_trial.py b/ax/core/batch_trial.py index 41d97312c36..a83393f9b33 100644 --- a/ax/core/batch_trial.py +++ b/ax/core/batch_trial.py @@ -35,7 +35,7 @@ TEvaluationOutcome, validate_evaluation_outcome, ) -from ax.exceptions.core import AxError, UserInputError +from ax.exceptions.core import AxError, UnsupportedError, UserInputError from ax.utils.common.base import SortableBase from ax.utils.common.docutils import copy_doc from ax.utils.common.equality import datetime_equals, equality_typechecker @@ -117,6 +117,10 @@ class BatchTrial(BaseTrial): generator_run: GeneratorRun, associated with this trial. This can a also be set later through `add_arm` or `add_generator_run`, but a trial's associated generator run is immutable once set. + generator_runs: GeneratorRuns, associated with this trial. This can a + also be set later through `add_arm` or `add_generator_run`, but a + trial's associated generator run is immutable once set. This cannot + be combined with the `generator_run` argument. trial_type: Type of this trial, if used in MultiTypeExperiment. optimize_for_power: Whether to optimize the weights of arms in this trial such that the experiment's power to detect effects of @@ -140,6 +144,7 @@ def __init__( self, experiment: core.experiment.Experiment, generator_run: Optional[GeneratorRun] = None, + generator_runs: Optional[List[GeneratorRun]] = None, trial_type: Optional[str] = None, optimize_for_power: Optional[bool] = False, ttl_seconds: Optional[int] = None, @@ -158,7 +163,14 @@ def __init__( self._status_quo: Optional[Arm] = None self._status_quo_weight_override: Optional[float] = None if generator_run is not None: + if generator_runs is not None: + raise UnsupportedError( + "Cannot specify both `generator_run` and `generator_runs`." + ) self.add_generator_run(generator_run=generator_run) + elif generator_runs is not None: + for gr in generator_runs: + self.add_generator_run(generator_run=gr) self.optimize_for_power = optimize_for_power status_quo = experiment.status_quo diff --git a/ax/core/experiment.py b/ax/core/experiment.py index ad2122790ce..28adb1c11fe 100644 --- a/ax/core/experiment.py +++ b/ax/core/experiment.py @@ -1059,6 +1059,7 @@ def new_trial( def new_batch_trial( self, generator_run: Optional[GeneratorRun] = None, + generator_runs: Optional[List[GeneratorRun]] = None, trial_type: Optional[str] = None, optimize_for_power: Optional[bool] = False, ttl_seconds: Optional[int] = None, @@ -1070,6 +1071,10 @@ def new_batch_trial( generator_run: GeneratorRun, associated with this trial. This can a also be set later through `add_arm` or `add_generator_run`, but a trial's associated generator run is immutable once set. + generator_runs: GeneratorRuns, associated with this trial. This can a + also be set later through `add_arm` or `add_generator_run`, but a + trial's associated generator run is immutable once set. This cannot + be combined with the `generator_run` argument. trial_type: Type of this trial, if used in MultiTypeExperiment. optimize_for_power: Whether to optimize the weights of arms in this trial such that the experiment's power to detect effects of @@ -1090,6 +1095,7 @@ def new_batch_trial( experiment=self, trial_type=trial_type, generator_run=generator_run, + generator_runs=generator_runs, optimize_for_power=optimize_for_power, ttl_seconds=ttl_seconds, lifecycle_stage=lifecycle_stage, diff --git a/ax/core/tests/test_experiment.py b/ax/core/tests/test_experiment.py index 3d2eff3a852..fb56c53eb48 100644 --- a/ax/core/tests/test_experiment.py +++ b/ax/core/tests/test_experiment.py @@ -50,6 +50,7 @@ get_status_quo, get_test_map_data_experiment, ) +from ax.utils.testing.mock import fast_botorch_optimize DUMMY_RUN_METADATA_KEY = "test_run_metadata_key" DUMMY_RUN_METADATA_VALUE = "test_run_metadata_value" @@ -1254,3 +1255,38 @@ def test_WarmStartMapData(self) -> None: old_df.drop(["arm_name", "trial_index"], axis=1), new_df.drop(["arm_name", "trial_index"], axis=1), ) + + @fast_botorch_optimize + def test_batch_with_multiple_generator_runs(self) -> None: + exp = get_branin_experiment() + sobol = Models.SOBOL(experiment=exp, search_space=exp.search_space) + exp.new_batch_trial(generator_runs=[sobol.gen(n=7)]).run().complete() + + data = exp.fetch_data() + gp = Models.BOTORCH_MODULAR( + experiment=exp, search_space=exp.search_space, data=data + ) + ts = Models.EMPIRICAL_BAYES_THOMPSON( + experiment=exp, search_space=exp.search_space, data=data + ) + exp.new_batch_trial(generator_runs=[gp.gen(n=3), ts.gen(n=1)]).run().complete() + + self.assertEqual(len(exp.trials), 2) + self.assertEqual(len(exp.trials[0].generator_runs), 1) + self.assertEqual(len(exp.trials[0].arms), 7) + self.assertEqual(len(exp.trials[1].generator_runs), 2) + self.assertEqual(len(exp.trials[1].arms), 4) + + def test_it_does_not_take_both_single_and_multiple_gr_ars(self) -> None: + exp = get_branin_experiment() + sobol = Models.SOBOL(experiment=exp, search_space=exp.search_space) + gr1 = sobol.gen(n=7) + gr2 = sobol.gen(n=7) + with self.assertRaisesRegex( + UnsupportedError, + "Cannot specify both `generator_run` and `generator_runs`.", + ): + exp.new_batch_trial( + generator_run=gr1, + generator_runs=[gr2], + )