Skip to content

Commit 013dd90

Browse files
Mia Garrardfacebook-github-bot
Mia Garrard
authored andcommitted
Support GenerationNodes and GenerationSteps in GenerationStrategy and default to GenerationNodes (#2024)
Summary: This diff does the following: Supports GenerationNodes at the level of GenerationStrategy. This is the big hurrah diff! upcoming: (0) Add decorator for functions that are only supported in steps (1) update the storage to include nodes independently (and not just as part of step) (2) delete now unused GenStep functions (3) final pass on all the doc strings and variables -- lots to clean up here (4) add transition criterion to the repr string + some of the other fields that havent made it yet on GeneratinoNode (5) Do a final pass of the generationStrategy/GenerationNode files to see what else can be migrated/condensed (6) rename transiton criterion to action criterion (7) remove conditionals for legacy usecase (8) clean up any lingering todos Reviewed By: lena-kashtelyan Differential Revision: D51120002
1 parent 5dbc654 commit 013dd90

15 files changed

+524
-72
lines changed

Diff for: ax/exceptions/generation_strategy.py

+11
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,14 @@ class GenerationStrategyRepeatedPoints(GenerationStrategyCompleted):
5555
"""
5656

5757
pass
58+
59+
60+
class GenerationStrategyMisconfiguredException(AxError):
61+
"""Special exception indicating that the generation strategy is misconfigured."""
62+
63+
def __init__(self, error_info: Optional[str]) -> None:
64+
super().__init__(
65+
"This GenerationStrategy was unable to be initialized properly. Please "
66+
+ "check the documentation, and adjust the configuration accordingly. "
67+
+ f"{error_info}"
68+
)

Diff for: ax/modelbridge/generation_node.py

+1
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ class GenerationNode:
100100

101101
# Optional specifications
102102
_model_spec_to_gen_from: Optional[ModelSpec] = None
103+
# TODO: @mgarrard should this be a dict criterion_class name -> criterion mapping?
103104
_transition_criteria: Optional[Sequence[TransitionCriterion]]
104105

105106
# [TODO] Handle experiment passing more eloquently by enforcing experiment

Diff for: ax/modelbridge/generation_strategy.py

+217-53
Large diffs are not rendered by default.

Diff for: ax/modelbridge/model_spec.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
CVResult,
2828
)
2929
from ax.modelbridge.registry import ModelRegistryBase
30-
from ax.utils.common.base import Base
30+
from ax.utils.common.base import SortableBase
3131
from ax.utils.common.kwargs import (
3232
consolidate_kwargs,
3333
filter_kwargs,
@@ -48,7 +48,7 @@ def default(self, o: Any) -> str:
4848

4949

5050
@dataclass
51-
class ModelSpec(Base):
51+
class ModelSpec(SortableBase):
5252
model_enum: ModelRegistryBase
5353
# Kwargs to pass into the `Model` + `ModelBridge` constructors in
5454
# `ModelRegistryBase.__call__`.
@@ -288,6 +288,12 @@ def __hash__(self) -> int:
288288
def __eq__(self, other: ModelSpec) -> bool:
289289
return repr(self) == repr(other)
290290

291+
@property
292+
def _unique_id(self) -> str:
293+
"""Returns the unique ID of this model spec"""
294+
# TODO @mgarrard verify that this is unique enough
295+
return str(hash(self))
296+
291297

292298
@dataclass
293299
class FactoryFunctionModelSpec(ModelSpec):

Diff for: ax/modelbridge/tests/test_completion_criterion.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def test_single_criterion(self) -> None:
7171
)
7272
)
7373

74-
self.assertEqual(generation_strategy._curr.model, Models.GPEI)
74+
self.assertEqual(generation_strategy._curr.model_enum, Models.GPEI)
7575

7676
def test_many_criteria(self) -> None:
7777
criteria = [
@@ -145,4 +145,4 @@ def test_many_criteria(self) -> None:
145145
)
146146
)
147147

148-
self.assertEqual(generation_strategy._curr.model, Models.GPEI)
148+
self.assertEqual(generation_strategy._curr.model_enum, Models.GPEI)

Diff for: ax/modelbridge/tests/test_generation_strategy.py

+256-4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the MIT license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import logging
78
from typing import cast, List
89
from unittest import mock
910
from unittest.mock import MagicMock, patch
@@ -21,16 +22,23 @@
2122
from ax.exceptions.core import DataRequiredError, UserInputError
2223
from ax.exceptions.generation_strategy import (
2324
GenerationStrategyCompleted,
25+
GenerationStrategyMisconfiguredException,
2426
GenerationStrategyRepeatedPoints,
2527
MaxParallelismReachedException,
2628
)
2729
from ax.modelbridge.discrete import DiscreteModelBridge
2830
from ax.modelbridge.factory import get_sobol
31+
from ax.modelbridge.generation_node import GenerationNode
2932
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
3033
from ax.modelbridge.model_spec import ModelSpec
3134
from ax.modelbridge.random import RandomModelBridge
3235
from ax.modelbridge.registry import Cont_X_trans, MODEL_KEY_TO_MODEL_SETUP, Models
3336
from ax.modelbridge.torch import TorchModelBridge
37+
from ax.modelbridge.transition_criterion import (
38+
MaxGenerationParallelism,
39+
MaxTrials,
40+
MinTrials,
41+
)
3442
from ax.models.random.sobol import SobolGenerator
3543
from ax.utils.common.equality import same_elements
3644
from ax.utils.common.mock import mock_patch_method_original
@@ -416,8 +424,8 @@ def test_clone_reset(self) -> None:
416424
]
417425
)
418426
ftgs._curr = ftgs._steps[1]
419-
self.assertEqual(ftgs._curr.index, 1)
420-
self.assertEqual(ftgs.clone_reset()._curr.index, 0)
427+
self.assertEqual(ftgs.current_step_index, 1)
428+
self.assertEqual(ftgs.clone_reset().current_step_index, 0)
421429

422430
def test_kwargs_passed(self) -> None:
423431
gs = GenerationStrategy(
@@ -527,10 +535,12 @@ def test_trials_as_df(self) -> None:
527535
# attach necessary trials to fill up the Generation Strategy
528536
trial = exp.new_trial(sobol_generation_strategy.gen(experiment=exp))
529537
self.assertEqual(
530-
sobol_generation_strategy.trials_as_df.head()["Generation Step"][0], 0
538+
sobol_generation_strategy.trials_as_df.head()["Generation Step"][0],
539+
"GenerationStep_0",
531540
)
532541
self.assertEqual(
533-
sobol_generation_strategy.trials_as_df.head()["Generation Step"][2], 1
542+
sobol_generation_strategy.trials_as_df.head()["Generation Step"][2],
543+
"GenerationStep_1",
534544
)
535545

536546
def test_max_parallelism_reached(self) -> None:
@@ -883,6 +893,248 @@ def test_gen_for_multiple_trials_with_multiple_models(self) -> None:
883893
for p in original_pending[m]:
884894
self.assertIn(p, pending[m])
885895

896+
# ---------- Tests for GenerationStrategies composed of GenerationNodes --------
897+
def test_gs_setup_with_nodes(self) -> None:
898+
"""Test GS initalization and validation with nodes"""
899+
sobol_model_spec = ModelSpec(
900+
model_enum=Models.SOBOL,
901+
model_kwargs={},
902+
model_gen_kwargs={"n": 2},
903+
)
904+
node_1_criterion = [
905+
MaxTrials(
906+
threshold=4,
907+
block_gen_if_met=False,
908+
transition_to="node_2",
909+
only_in_statuses=None,
910+
not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED],
911+
),
912+
MinTrials(
913+
only_in_statuses=[TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED],
914+
threshold=2,
915+
transition_to="node_2",
916+
),
917+
MaxGenerationParallelism(
918+
threshold=1,
919+
only_in_statuses=[TrialStatus.RUNNING],
920+
block_gen_if_met=True,
921+
block_transition_if_unmet=False,
922+
),
923+
]
924+
925+
# check error raised if node names are not unique
926+
with self.assertRaisesRegex(
927+
GenerationStrategyMisconfiguredException, "All node names"
928+
):
929+
GenerationStrategy(
930+
nodes=[
931+
GenerationNode(
932+
node_name="node_1",
933+
transition_criteria=node_1_criterion,
934+
model_specs=[sobol_model_spec],
935+
),
936+
GenerationNode(
937+
node_name="node_1",
938+
model_specs=[sobol_model_spec],
939+
),
940+
],
941+
)
942+
# check error raised if transition to arguemnt is not valid
943+
with self.assertRaisesRegex(
944+
GenerationStrategyMisconfiguredException, "`transition_to` argument"
945+
):
946+
GenerationStrategy(
947+
nodes=[
948+
GenerationNode(
949+
node_name="node_1",
950+
transition_criteria=node_1_criterion,
951+
model_specs=[sobol_model_spec],
952+
),
953+
GenerationNode(
954+
node_name="node_3",
955+
model_specs=[sobol_model_spec],
956+
),
957+
],
958+
)
959+
960+
# check error raised if provided both steps and nodes
961+
with self.assertRaisesRegex(
962+
GenerationStrategyMisconfiguredException, "either steps or nodes"
963+
):
964+
GenerationStrategy(
965+
nodes=[
966+
GenerationNode(
967+
node_name="node_1",
968+
transition_criteria=node_1_criterion,
969+
model_specs=[sobol_model_spec],
970+
),
971+
GenerationNode(
972+
node_name="node_3",
973+
model_specs=[sobol_model_spec],
974+
),
975+
],
976+
steps=[
977+
GenerationStep(
978+
model=Models.SOBOL,
979+
num_trials=5,
980+
model_kwargs=self.step_model_kwargs,
981+
),
982+
GenerationStep(
983+
model=Models.GPEI,
984+
num_trials=-1,
985+
model_kwargs=self.step_model_kwargs,
986+
),
987+
],
988+
)
989+
990+
# check error raised if provided both steps and nodes under node list
991+
with self.assertRaisesRegex(
992+
GenerationStrategyMisconfiguredException, "must either be a GenerationStep"
993+
):
994+
GenerationStrategy(
995+
nodes=[
996+
GenerationNode(
997+
node_name="node_1",
998+
transition_criteria=node_1_criterion,
999+
model_specs=[sobol_model_spec],
1000+
),
1001+
GenerationStep(
1002+
model=Models.SOBOL,
1003+
num_trials=5,
1004+
model_kwargs=self.step_model_kwargs,
1005+
),
1006+
GenerationNode(
1007+
node_name="node_2",
1008+
model_specs=[sobol_model_spec],
1009+
),
1010+
],
1011+
)
1012+
# check that warning is logged if no nodes have transition arguments
1013+
with self.assertLogs(GenerationStrategy.__module__, logging.WARNING) as logger:
1014+
warning_msg = (
1015+
"None of the nodes in this GenerationStrategy "
1016+
"contain a `transition_to` argument in their transition_criteria. "
1017+
)
1018+
GenerationStrategy(
1019+
nodes=[
1020+
GenerationNode(
1021+
node_name="node_1",
1022+
model_specs=[sobol_model_spec],
1023+
),
1024+
GenerationNode(
1025+
node_name="node_3",
1026+
model_specs=[sobol_model_spec],
1027+
),
1028+
],
1029+
)
1030+
self.assertTrue(
1031+
any(warning_msg in output for output in logger.output),
1032+
logger.output,
1033+
)
1034+
1035+
def test_gs_with_generation_nodes(self) -> None:
1036+
"Simple test of a SOBOL + GPEI GenerationStrategy composed of GenerationNodes"
1037+
sobol_criterion = [
1038+
MaxTrials(
1039+
threshold=5,
1040+
transition_to="GPEI_node",
1041+
block_gen_if_met=True,
1042+
only_in_statuses=None,
1043+
not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED],
1044+
)
1045+
]
1046+
gpei_criterion = [
1047+
MaxTrials(
1048+
threshold=2,
1049+
transition_to=None,
1050+
block_gen_if_met=True,
1051+
only_in_statuses=None,
1052+
not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED],
1053+
)
1054+
]
1055+
sobol_model_spec = ModelSpec(
1056+
model_enum=Models.SOBOL,
1057+
model_kwargs=self.step_model_kwargs,
1058+
model_gen_kwargs={},
1059+
)
1060+
gpei_model_spec = ModelSpec(
1061+
model_enum=Models.GPEI,
1062+
model_kwargs=self.step_model_kwargs,
1063+
model_gen_kwargs={},
1064+
)
1065+
sobol_node = GenerationNode(
1066+
node_name="sobol_node",
1067+
transition_criteria=sobol_criterion,
1068+
model_specs=[sobol_model_spec],
1069+
gen_unlimited_trials=False,
1070+
)
1071+
gpei_node = GenerationNode(
1072+
node_name="GPEI_node",
1073+
transition_criteria=gpei_criterion,
1074+
model_specs=[gpei_model_spec],
1075+
gen_unlimited_trials=False,
1076+
)
1077+
1078+
sobol_GPEI_GS_nodes = GenerationStrategy(
1079+
name="Sobol+GPEI_Nodes",
1080+
nodes=[sobol_node, gpei_node],
1081+
)
1082+
exp = get_branin_experiment()
1083+
self.assertEqual(sobol_GPEI_GS_nodes.name, "Sobol+GPEI_Nodes")
1084+
self.assertEqual(sobol_GPEI_GS_nodes.model_transitions, [5])
1085+
1086+
for i in range(7):
1087+
g = sobol_GPEI_GS_nodes.gen(exp)
1088+
exp.new_trial(generator_run=g).run()
1089+
self.assertEqual(len(sobol_GPEI_GS_nodes._generator_runs), i + 1)
1090+
if i > 4:
1091+
self.mock_torch_model_bridge.assert_called()
1092+
else:
1093+
self.assertEqual(g._model_key, "Sobol")
1094+
mkw = g._model_kwargs
1095+
self.assertIsNotNone(mkw)
1096+
if i > 0:
1097+
# Generated points are randomized, so checking that they're there.
1098+
self.assertIsNotNone(mkw.get("generated_points"))
1099+
else:
1100+
# This is the first GR, there should be no generated points yet.
1101+
self.assertIsNone(mkw.get("generated_points"))
1102+
# Remove the randomized generated points to compare the rest.
1103+
mkw = mkw.copy()
1104+
del mkw["generated_points"]
1105+
self.assertEqual(
1106+
mkw,
1107+
{
1108+
"seed": None,
1109+
"deduplicate": True,
1110+
"init_position": i,
1111+
"scramble": True,
1112+
"fallback_to_sample_polytope": False,
1113+
},
1114+
)
1115+
self.assertEqual(
1116+
g._bridge_kwargs,
1117+
{
1118+
"optimization_config": None,
1119+
"status_quo_features": None,
1120+
"status_quo_name": None,
1121+
"transform_configs": None,
1122+
"transforms": Cont_X_trans,
1123+
"fit_out_of_design": False,
1124+
"fit_abandoned": False,
1125+
"fit_tracking_metrics": True,
1126+
"fit_on_init": True,
1127+
},
1128+
)
1129+
ms = g._model_state_after_gen
1130+
self.assertIsNotNone(ms)
1131+
# Generated points are randomized, so just checking that they are there.
1132+
self.assertIn("generated_points", ms)
1133+
# Remove the randomized generated points to compare the rest.
1134+
ms = ms.copy()
1135+
del ms["generated_points"]
1136+
self.assertEqual(ms, {"init_position": i + 1})
1137+
8861138
# ------------- Testing helpers (put tests above this line) -------------
8871139

8881140
def _run_GS_for_N_rounds(

Diff for: ax/modelbridge/tests/test_transition_criterion.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def test_minimum_preference_criterion(self) -> None:
7474
raise_data_required_error=False
7575
)
7676
)
77-
self.assertEqual(generation_strategy._curr.model, Models.GPEI)
77+
self.assertEqual(generation_strategy._curr.model_enum, Models.GPEI)
7878

7979
def test_default_step_criterion_setup(self) -> None:
8080
"""This test ensures that the default completion criterion for GenerationSteps

0 commit comments

Comments
 (0)