From bee22e03ec7571872be1b00d09b0ca6bbd2aafff Mon Sep 17 00:00:00 2001 From: Mia Garrard Date: Wed, 6 Dec 2023 14:28:18 -0800 Subject: [PATCH] Add Decorator for GenerationStep only functions in GenerationStrategy (#2045) Summary: This diff does the following: Add decorator for functions that are only supported in steps -- this prevents some code redundancy upcoming: (1) update the storage to include nodes independently (and not just as part of step) (2) delete now unused GenStep functions (3) final pass on all the doc strings and variables -- lots to clean up here (4) add transition criterion to the repr string + some of the other fields that havent made it yet on GeneratinoNode (5) Do a final pass of the generationStrategy/GenerationNode files to see what else can be migrated/condensed (6) rename transiton criterion to action criterion (7) remove conditionals for legacy usecase ( clean up any lingering todos Reviewed By: lena-kashtelyan Differential Revision: D51816513 --- ax/modelbridge/generation_strategy.py | 40 ++++++++++++++----- .../tests/test_generation_strategy.py | 27 ++++++++++++- 2 files changed, 55 insertions(+), 12 deletions(-) diff --git a/ax/modelbridge/generation_strategy.py b/ax/modelbridge/generation_strategy.py index cd342f573e6..4c2ab7f519e 100644 --- a/ax/modelbridge/generation_strategy.py +++ b/ax/modelbridge/generation_strategy.py @@ -7,9 +7,10 @@ from __future__ import annotations from copy import deepcopy +from functools import wraps from logging import Logger -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar import pandas as pd from ax.core.data import Data @@ -21,7 +22,7 @@ extend_pending_observations, get_pending_observation_features_based_on_trial_status, ) -from ax.exceptions.core import DataRequiredError, UserInputError +from ax.exceptions.core import DataRequiredError, UnsupportedError, UserInputError from ax.exceptions.generation_strategy import ( GenerationStrategyCompleted, GenerationStrategyMisconfiguredException, @@ -45,6 +46,28 @@ "generate a unique parameterization. This indicates that the search space has " "likely been fully explored, or that the sweep has converged." ) +T = TypeVar("T") + + +def step_based_gs_only(f: Callable[..., T]) -> Callable[..., T]: + """ + For use as a decorator on functions only implemented for GenerationStep based + GenerationStrategies. Mainly useful for older GenerationStrategies. + """ + + @wraps(f) + def impl( + self: "GenerationStrategy", *args: List[Any], **kwargs: Dict[str, Any] + ) -> T: + + if self.is_node_based(self._nodes): + raise UnsupportedError( + f"{f.__name__} is not supported for GenerationNode based" + " GenerationStrategies." + ) + return f(self, *args, **kwargs) + + return impl class GenerationStrategy(GenerationStrategyInterface): @@ -98,7 +121,6 @@ def __init__( node_based_strategy = self.is_node_based(nodes=self._nodes) if isinstance(steps, list) and not node_based_strategy: - # pyre-ignore[6] self._validate_and_set_step_sequence(steps=self._nodes) elif isinstance(nodes, list) and node_based_strategy: self._validate_and_set_node_graph(nodes=nodes) @@ -119,6 +141,7 @@ def __init__( ) self._seen_trial_indices_by_status = None + @step_based_gs_only def _validate_and_set_step_sequence(self, steps: List[GenerationStep]) -> None: """Initialize and validate the steps provided to this GenerationStrategy. @@ -133,7 +156,6 @@ def _validate_and_set_step_sequence(self, steps: List[GenerationStep]) -> None: underlying GenerationNode objects. """ for idx, step in enumerate(steps): - assert isinstance(step, GenerationStep) if step.num_trials == -1 and len(step.completion_criteria) < 1: if idx < len(self._steps) - 1: raise UserInputError( @@ -218,11 +240,9 @@ def _validate_and_set_node_graph(self, nodes: List[GenerationNode]) -> None: self._curr = nodes[0] @property + @step_based_gs_only def _steps(self) -> List[GenerationStep]: """List of generation steps.""" - assert all( - isinstance(n, GenerationStep) for n in self._nodes - ), "Attempting to set steps to non-GenerationStep objects." return self._nodes # pyre-ignore[7] def is_node_based(self, nodes: List[GenerationNode]) -> bool: @@ -253,6 +273,7 @@ def name(self, name: str) -> None: self._name = name @property + @step_based_gs_only def model_transitions(self) -> List[int]: """List of trial indices where a transition happened from one model to another.""" @@ -294,14 +315,12 @@ def current_node(self) -> GenerationNode: return self._curr @property + @step_based_gs_only def current_step_index(self) -> int: """Returns the index of the current generation step. This attribute is replaced by node_name in newer GenerationStrategies but surfaced here for backward compatibility. """ - assert isinstance( - self._curr, GenerationStep - ), "current_step_index only works with GenerationStep" node_names_for_all_steps = [step._node_name for step in self._nodes] assert ( self._curr.node_name in node_names_for_all_steps @@ -347,6 +366,7 @@ def uses_non_registered_models(self) -> bool: return not self._uses_registered_models @property + @step_based_gs_only def trials_as_df(self) -> Optional[pd.DataFrame]: """Puts information on individual trials into a data frame for easy viewing. For example: diff --git a/ax/modelbridge/tests/test_generation_strategy.py b/ax/modelbridge/tests/test_generation_strategy.py index 06948642363..6dcfe4b6a80 100644 --- a/ax/modelbridge/tests/test_generation_strategy.py +++ b/ax/modelbridge/tests/test_generation_strategy.py @@ -19,7 +19,7 @@ from ax.core.utils import ( get_pending_observation_features_based_on_trial_status as get_pending, ) -from ax.exceptions.core import DataRequiredError, UserInputError +from ax.exceptions.core import DataRequiredError, UnsupportedError, UserInputError from ax.exceptions.generation_strategy import ( GenerationStrategyCompleted, GenerationStrategyMisconfiguredException, @@ -1081,7 +1081,6 @@ def test_gs_with_generation_nodes(self) -> None: ) exp = get_branin_experiment() self.assertEqual(sobol_GPEI_GS_nodes.name, "Sobol+GPEI_Nodes") - self.assertEqual(sobol_GPEI_GS_nodes.model_transitions, [5]) for i in range(7): g = sobol_GPEI_GS_nodes.gen(exp) @@ -1135,6 +1134,30 @@ def test_gs_with_generation_nodes(self) -> None: del ms["generated_points"] self.assertEqual(ms, {"init_position": i + 1}) + def test_step_based_gs_only(self) -> None: + """Test the step_based_gs_only decorator""" + sobol_model_spec = ModelSpec( + model_enum=Models.SOBOL, + model_kwargs={}, + model_gen_kwargs={"n": 2}, + ) + gs_test = GenerationStrategy( + nodes=[ + GenerationNode( + node_name="node_1", + model_specs=[sobol_model_spec], + ), + GenerationNode( + node_name="node_2", + model_specs=[sobol_model_spec], + ), + ], + ) + with self.assertRaisesRegex( + UnsupportedError, "is not supported for GenerationNode based" + ): + gs_test.current_step_index + # ------------- Testing helpers (put tests above this line) ------------- def _run_GS_for_N_rounds(