Skip to content

Commit

Permalink
Merge 013dd90 into 5dbc654
Browse files Browse the repository at this point in the history
  • Loading branch information
mgarrard authored Dec 7, 2023
2 parents 5dbc654 + 013dd90 commit 58b13b2
Show file tree
Hide file tree
Showing 15 changed files with 524 additions and 72 deletions.
11 changes: 11 additions & 0 deletions ax/exceptions/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,14 @@ class GenerationStrategyRepeatedPoints(GenerationStrategyCompleted):
"""

pass


class GenerationStrategyMisconfiguredException(AxError):
"""Special exception indicating that the generation strategy is misconfigured."""

def __init__(self, error_info: Optional[str]) -> None:
super().__init__(
"This GenerationStrategy was unable to be initialized properly. Please "
+ "check the documentation, and adjust the configuration accordingly. "
+ f"{error_info}"
)
1 change: 1 addition & 0 deletions ax/modelbridge/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class GenerationNode:

# Optional specifications
_model_spec_to_gen_from: Optional[ModelSpec] = None
# TODO: @mgarrard should this be a dict criterion_class name -> criterion mapping?
_transition_criteria: Optional[Sequence[TransitionCriterion]]

# [TODO] Handle experiment passing more eloquently by enforcing experiment
Expand Down
270 changes: 217 additions & 53 deletions ax/modelbridge/generation_strategy.py

Large diffs are not rendered by default.

10 changes: 8 additions & 2 deletions ax/modelbridge/model_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
CVResult,
)
from ax.modelbridge.registry import ModelRegistryBase
from ax.utils.common.base import Base
from ax.utils.common.base import SortableBase
from ax.utils.common.kwargs import (
consolidate_kwargs,
filter_kwargs,
Expand All @@ -48,7 +48,7 @@ def default(self, o: Any) -> str:


@dataclass
class ModelSpec(Base):
class ModelSpec(SortableBase):
model_enum: ModelRegistryBase
# Kwargs to pass into the `Model` + `ModelBridge` constructors in
# `ModelRegistryBase.__call__`.
Expand Down Expand Up @@ -288,6 +288,12 @@ def __hash__(self) -> int:
def __eq__(self, other: ModelSpec) -> bool:
return repr(self) == repr(other)

@property
def _unique_id(self) -> str:
"""Returns the unique ID of this model spec"""
# TODO @mgarrard verify that this is unique enough
return str(hash(self))


@dataclass
class FactoryFunctionModelSpec(ModelSpec):
Expand Down
4 changes: 2 additions & 2 deletions ax/modelbridge/tests/test_completion_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_single_criterion(self) -> None:
)
)

self.assertEqual(generation_strategy._curr.model, Models.GPEI)
self.assertEqual(generation_strategy._curr.model_enum, Models.GPEI)

def test_many_criteria(self) -> None:
criteria = [
Expand Down Expand Up @@ -145,4 +145,4 @@ def test_many_criteria(self) -> None:
)
)

self.assertEqual(generation_strategy._curr.model, Models.GPEI)
self.assertEqual(generation_strategy._curr.model_enum, Models.GPEI)
260 changes: 256 additions & 4 deletions ax/modelbridge/tests/test_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import cast, List
from unittest import mock
from unittest.mock import MagicMock, patch
Expand All @@ -21,16 +22,23 @@
from ax.exceptions.core import DataRequiredError, UserInputError
from ax.exceptions.generation_strategy import (
GenerationStrategyCompleted,
GenerationStrategyMisconfiguredException,
GenerationStrategyRepeatedPoints,
MaxParallelismReachedException,
)
from ax.modelbridge.discrete import DiscreteModelBridge
from ax.modelbridge.factory import get_sobol
from ax.modelbridge.generation_node import GenerationNode
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.model_spec import ModelSpec
from ax.modelbridge.random import RandomModelBridge
from ax.modelbridge.registry import Cont_X_trans, MODEL_KEY_TO_MODEL_SETUP, Models
from ax.modelbridge.torch import TorchModelBridge
from ax.modelbridge.transition_criterion import (
MaxGenerationParallelism,
MaxTrials,
MinTrials,
)
from ax.models.random.sobol import SobolGenerator
from ax.utils.common.equality import same_elements
from ax.utils.common.mock import mock_patch_method_original
Expand Down Expand Up @@ -416,8 +424,8 @@ def test_clone_reset(self) -> None:
]
)
ftgs._curr = ftgs._steps[1]
self.assertEqual(ftgs._curr.index, 1)
self.assertEqual(ftgs.clone_reset()._curr.index, 0)
self.assertEqual(ftgs.current_step_index, 1)
self.assertEqual(ftgs.clone_reset().current_step_index, 0)

def test_kwargs_passed(self) -> None:
gs = GenerationStrategy(
Expand Down Expand Up @@ -527,10 +535,12 @@ def test_trials_as_df(self) -> None:
# attach necessary trials to fill up the Generation Strategy
trial = exp.new_trial(sobol_generation_strategy.gen(experiment=exp))
self.assertEqual(
sobol_generation_strategy.trials_as_df.head()["Generation Step"][0], 0
sobol_generation_strategy.trials_as_df.head()["Generation Step"][0],
"GenerationStep_0",
)
self.assertEqual(
sobol_generation_strategy.trials_as_df.head()["Generation Step"][2], 1
sobol_generation_strategy.trials_as_df.head()["Generation Step"][2],
"GenerationStep_1",
)

def test_max_parallelism_reached(self) -> None:
Expand Down Expand Up @@ -883,6 +893,248 @@ def test_gen_for_multiple_trials_with_multiple_models(self) -> None:
for p in original_pending[m]:
self.assertIn(p, pending[m])

# ---------- Tests for GenerationStrategies composed of GenerationNodes --------
def test_gs_setup_with_nodes(self) -> None:
"""Test GS initalization and validation with nodes"""
sobol_model_spec = ModelSpec(
model_enum=Models.SOBOL,
model_kwargs={},
model_gen_kwargs={"n": 2},
)
node_1_criterion = [
MaxTrials(
threshold=4,
block_gen_if_met=False,
transition_to="node_2",
only_in_statuses=None,
not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED],
),
MinTrials(
only_in_statuses=[TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED],
threshold=2,
transition_to="node_2",
),
MaxGenerationParallelism(
threshold=1,
only_in_statuses=[TrialStatus.RUNNING],
block_gen_if_met=True,
block_transition_if_unmet=False,
),
]

# check error raised if node names are not unique
with self.assertRaisesRegex(
GenerationStrategyMisconfiguredException, "All node names"
):
GenerationStrategy(
nodes=[
GenerationNode(
node_name="node_1",
transition_criteria=node_1_criterion,
model_specs=[sobol_model_spec],
),
GenerationNode(
node_name="node_1",
model_specs=[sobol_model_spec],
),
],
)
# check error raised if transition to arguemnt is not valid
with self.assertRaisesRegex(
GenerationStrategyMisconfiguredException, "`transition_to` argument"
):
GenerationStrategy(
nodes=[
GenerationNode(
node_name="node_1",
transition_criteria=node_1_criterion,
model_specs=[sobol_model_spec],
),
GenerationNode(
node_name="node_3",
model_specs=[sobol_model_spec],
),
],
)

# check error raised if provided both steps and nodes
with self.assertRaisesRegex(
GenerationStrategyMisconfiguredException, "either steps or nodes"
):
GenerationStrategy(
nodes=[
GenerationNode(
node_name="node_1",
transition_criteria=node_1_criterion,
model_specs=[sobol_model_spec],
),
GenerationNode(
node_name="node_3",
model_specs=[sobol_model_spec],
),
],
steps=[
GenerationStep(
model=Models.SOBOL,
num_trials=5,
model_kwargs=self.step_model_kwargs,
),
GenerationStep(
model=Models.GPEI,
num_trials=-1,
model_kwargs=self.step_model_kwargs,
),
],
)

# check error raised if provided both steps and nodes under node list
with self.assertRaisesRegex(
GenerationStrategyMisconfiguredException, "must either be a GenerationStep"
):
GenerationStrategy(
nodes=[
GenerationNode(
node_name="node_1",
transition_criteria=node_1_criterion,
model_specs=[sobol_model_spec],
),
GenerationStep(
model=Models.SOBOL,
num_trials=5,
model_kwargs=self.step_model_kwargs,
),
GenerationNode(
node_name="node_2",
model_specs=[sobol_model_spec],
),
],
)
# check that warning is logged if no nodes have transition arguments
with self.assertLogs(GenerationStrategy.__module__, logging.WARNING) as logger:
warning_msg = (
"None of the nodes in this GenerationStrategy "
"contain a `transition_to` argument in their transition_criteria. "
)
GenerationStrategy(
nodes=[
GenerationNode(
node_name="node_1",
model_specs=[sobol_model_spec],
),
GenerationNode(
node_name="node_3",
model_specs=[sobol_model_spec],
),
],
)
self.assertTrue(
any(warning_msg in output for output in logger.output),
logger.output,
)

def test_gs_with_generation_nodes(self) -> None:
"Simple test of a SOBOL + GPEI GenerationStrategy composed of GenerationNodes"
sobol_criterion = [
MaxTrials(
threshold=5,
transition_to="GPEI_node",
block_gen_if_met=True,
only_in_statuses=None,
not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED],
)
]
gpei_criterion = [
MaxTrials(
threshold=2,
transition_to=None,
block_gen_if_met=True,
only_in_statuses=None,
not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED],
)
]
sobol_model_spec = ModelSpec(
model_enum=Models.SOBOL,
model_kwargs=self.step_model_kwargs,
model_gen_kwargs={},
)
gpei_model_spec = ModelSpec(
model_enum=Models.GPEI,
model_kwargs=self.step_model_kwargs,
model_gen_kwargs={},
)
sobol_node = GenerationNode(
node_name="sobol_node",
transition_criteria=sobol_criterion,
model_specs=[sobol_model_spec],
gen_unlimited_trials=False,
)
gpei_node = GenerationNode(
node_name="GPEI_node",
transition_criteria=gpei_criterion,
model_specs=[gpei_model_spec],
gen_unlimited_trials=False,
)

sobol_GPEI_GS_nodes = GenerationStrategy(
name="Sobol+GPEI_Nodes",
nodes=[sobol_node, gpei_node],
)
exp = get_branin_experiment()
self.assertEqual(sobol_GPEI_GS_nodes.name, "Sobol+GPEI_Nodes")
self.assertEqual(sobol_GPEI_GS_nodes.model_transitions, [5])

for i in range(7):
g = sobol_GPEI_GS_nodes.gen(exp)
exp.new_trial(generator_run=g).run()
self.assertEqual(len(sobol_GPEI_GS_nodes._generator_runs), i + 1)
if i > 4:
self.mock_torch_model_bridge.assert_called()
else:
self.assertEqual(g._model_key, "Sobol")
mkw = g._model_kwargs
self.assertIsNotNone(mkw)
if i > 0:
# Generated points are randomized, so checking that they're there.
self.assertIsNotNone(mkw.get("generated_points"))
else:
# This is the first GR, there should be no generated points yet.
self.assertIsNone(mkw.get("generated_points"))
# Remove the randomized generated points to compare the rest.
mkw = mkw.copy()
del mkw["generated_points"]
self.assertEqual(
mkw,
{
"seed": None,
"deduplicate": True,
"init_position": i,
"scramble": True,
"fallback_to_sample_polytope": False,
},
)
self.assertEqual(
g._bridge_kwargs,
{
"optimization_config": None,
"status_quo_features": None,
"status_quo_name": None,
"transform_configs": None,
"transforms": Cont_X_trans,
"fit_out_of_design": False,
"fit_abandoned": False,
"fit_tracking_metrics": True,
"fit_on_init": True,
},
)
ms = g._model_state_after_gen
self.assertIsNotNone(ms)
# Generated points are randomized, so just checking that they are there.
self.assertIn("generated_points", ms)
# Remove the randomized generated points to compare the rest.
ms = ms.copy()
del ms["generated_points"]
self.assertEqual(ms, {"init_position": i + 1})

# ------------- Testing helpers (put tests above this line) -------------

def _run_GS_for_N_rounds(
Expand Down
2 changes: 1 addition & 1 deletion ax/modelbridge/tests/test_transition_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_minimum_preference_criterion(self) -> None:
raise_data_required_error=False
)
)
self.assertEqual(generation_strategy._curr.model, Models.GPEI)
self.assertEqual(generation_strategy._curr.model_enum, Models.GPEI)

def test_default_step_criterion_setup(self) -> None:
"""This test ensures that the default completion criterion for GenerationSteps
Expand Down
Loading

0 comments on commit 58b13b2

Please sign in to comment.