diff --git a/ax/modelbridge/generation_strategy.py b/ax/modelbridge/generation_strategy.py index cd342f573e6..4c2ab7f519e 100644 --- a/ax/modelbridge/generation_strategy.py +++ b/ax/modelbridge/generation_strategy.py @@ -7,9 +7,10 @@ from __future__ import annotations from copy import deepcopy +from functools import wraps from logging import Logger -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar import pandas as pd from ax.core.data import Data @@ -21,7 +22,7 @@ extend_pending_observations, get_pending_observation_features_based_on_trial_status, ) -from ax.exceptions.core import DataRequiredError, UserInputError +from ax.exceptions.core import DataRequiredError, UnsupportedError, UserInputError from ax.exceptions.generation_strategy import ( GenerationStrategyCompleted, GenerationStrategyMisconfiguredException, @@ -45,6 +46,28 @@ "generate a unique parameterization. This indicates that the search space has " "likely been fully explored, or that the sweep has converged." ) +T = TypeVar("T") + + +def step_based_gs_only(f: Callable[..., T]) -> Callable[..., T]: + """ + For use as a decorator on functions only implemented for GenerationStep based + GenerationStrategies. Mainly useful for older GenerationStrategies. + """ + + @wraps(f) + def impl( + self: "GenerationStrategy", *args: List[Any], **kwargs: Dict[str, Any] + ) -> T: + + if self.is_node_based(self._nodes): + raise UnsupportedError( + f"{f.__name__} is not supported for GenerationNode based" + " GenerationStrategies." + ) + return f(self, *args, **kwargs) + + return impl class GenerationStrategy(GenerationStrategyInterface): @@ -98,7 +121,6 @@ def __init__( node_based_strategy = self.is_node_based(nodes=self._nodes) if isinstance(steps, list) and not node_based_strategy: - # pyre-ignore[6] self._validate_and_set_step_sequence(steps=self._nodes) elif isinstance(nodes, list) and node_based_strategy: self._validate_and_set_node_graph(nodes=nodes) @@ -119,6 +141,7 @@ def __init__( ) self._seen_trial_indices_by_status = None + @step_based_gs_only def _validate_and_set_step_sequence(self, steps: List[GenerationStep]) -> None: """Initialize and validate the steps provided to this GenerationStrategy. @@ -133,7 +156,6 @@ def _validate_and_set_step_sequence(self, steps: List[GenerationStep]) -> None: underlying GenerationNode objects. """ for idx, step in enumerate(steps): - assert isinstance(step, GenerationStep) if step.num_trials == -1 and len(step.completion_criteria) < 1: if idx < len(self._steps) - 1: raise UserInputError( @@ -218,11 +240,9 @@ def _validate_and_set_node_graph(self, nodes: List[GenerationNode]) -> None: self._curr = nodes[0] @property + @step_based_gs_only def _steps(self) -> List[GenerationStep]: """List of generation steps.""" - assert all( - isinstance(n, GenerationStep) for n in self._nodes - ), "Attempting to set steps to non-GenerationStep objects." return self._nodes # pyre-ignore[7] def is_node_based(self, nodes: List[GenerationNode]) -> bool: @@ -253,6 +273,7 @@ def name(self, name: str) -> None: self._name = name @property + @step_based_gs_only def model_transitions(self) -> List[int]: """List of trial indices where a transition happened from one model to another.""" @@ -294,14 +315,12 @@ def current_node(self) -> GenerationNode: return self._curr @property + @step_based_gs_only def current_step_index(self) -> int: """Returns the index of the current generation step. This attribute is replaced by node_name in newer GenerationStrategies but surfaced here for backward compatibility. """ - assert isinstance( - self._curr, GenerationStep - ), "current_step_index only works with GenerationStep" node_names_for_all_steps = [step._node_name for step in self._nodes] assert ( self._curr.node_name in node_names_for_all_steps @@ -347,6 +366,7 @@ def uses_non_registered_models(self) -> bool: return not self._uses_registered_models @property + @step_based_gs_only def trials_as_df(self) -> Optional[pd.DataFrame]: """Puts information on individual trials into a data frame for easy viewing. For example: diff --git a/ax/modelbridge/tests/test_generation_strategy.py b/ax/modelbridge/tests/test_generation_strategy.py index 06948642363..6dcfe4b6a80 100644 --- a/ax/modelbridge/tests/test_generation_strategy.py +++ b/ax/modelbridge/tests/test_generation_strategy.py @@ -19,7 +19,7 @@ from ax.core.utils import ( get_pending_observation_features_based_on_trial_status as get_pending, ) -from ax.exceptions.core import DataRequiredError, UserInputError +from ax.exceptions.core import DataRequiredError, UnsupportedError, UserInputError from ax.exceptions.generation_strategy import ( GenerationStrategyCompleted, GenerationStrategyMisconfiguredException, @@ -1081,7 +1081,6 @@ def test_gs_with_generation_nodes(self) -> None: ) exp = get_branin_experiment() self.assertEqual(sobol_GPEI_GS_nodes.name, "Sobol+GPEI_Nodes") - self.assertEqual(sobol_GPEI_GS_nodes.model_transitions, [5]) for i in range(7): g = sobol_GPEI_GS_nodes.gen(exp) @@ -1135,6 +1134,30 @@ def test_gs_with_generation_nodes(self) -> None: del ms["generated_points"] self.assertEqual(ms, {"init_position": i + 1}) + def test_step_based_gs_only(self) -> None: + """Test the step_based_gs_only decorator""" + sobol_model_spec = ModelSpec( + model_enum=Models.SOBOL, + model_kwargs={}, + model_gen_kwargs={"n": 2}, + ) + gs_test = GenerationStrategy( + nodes=[ + GenerationNode( + node_name="node_1", + model_specs=[sobol_model_spec], + ), + GenerationNode( + node_name="node_2", + model_specs=[sobol_model_spec], + ), + ], + ) + with self.assertRaisesRegex( + UnsupportedError, "is not supported for GenerationNode based" + ): + gs_test.current_step_index + # ------------- Testing helpers (put tests above this line) ------------- def _run_GS_for_N_rounds(