Skip to content

Commit

Permalink
Add Decorator for GenerationStep only functions in GenerationStrategy
Browse files Browse the repository at this point in the history
Summary:
This diff does the following:
Add decorator for functions that are only supported in steps -- this prevents some code redundancy


upcoming:
(1) update the storage to include nodes independently (and not just as part of step)
(2) delete now unused GenStep functions
(3) final pass on all the doc strings and variables -- lots to clean up here
(4) add transition criterion to the repr string + some of the other fields that havent made it yet on GeneratinoNode
(5) Do a final pass of the generationStrategy/GenerationNode files to see what else can be migrated/condensed
(6) rename transiton criterion to action criterion
(7) remove conditionals for legacy usecase
( clean up any lingering todos

Differential Revision: D51816513
  • Loading branch information
Mia Garrard authored and facebook-github-bot committed Dec 5, 2023
1 parent a08c76d commit 621f017
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 12 deletions.
40 changes: 30 additions & 10 deletions ax/modelbridge/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
from __future__ import annotations

from copy import deepcopy
from functools import wraps

from logging import Logger
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar

import pandas as pd
from ax.core.data import Data
Expand All @@ -21,7 +22,7 @@
extend_pending_observations,
get_pending_observation_features_based_on_trial_status,
)
from ax.exceptions.core import DataRequiredError, UserInputError
from ax.exceptions.core import DataRequiredError, UnsupportedError, UserInputError
from ax.exceptions.generation_strategy import (
GenerationStrategyCompleted,
GenerationStrategyMisconfiguredException,
Expand All @@ -44,6 +45,28 @@
"generate a unique parameterization. This indicates that the search space has "
"likely been fully explored, or that the sweep has converged."
)
T = TypeVar("T")


def step_based_gs_only(f: Callable[..., T]) -> Callable[..., T]:
"""
For use as a decorator on functions only implemented for GenerationStep based
GenerationStrategies. Mainly useful for older GenerationStrategies.
"""

@wraps(f)
def impl(
self: "GenerationStrategy", *args: List[Any], **kwargs: Dict[str, Any]
) -> T:

if self.node_based_strategy(self._nodes):
raise UnsupportedError(
f"{f.__name__} is not supported for GenerationNode based"
" GenerationStrategies."
)
return f(self, *args, **kwargs)

return impl


class GenerationStrategy(GenerationStrategyInterface):
Expand Down Expand Up @@ -97,7 +120,6 @@ def __init__(
node_based_strategy = self.node_based_strategy(nodes=self._nodes)

if isinstance(steps, list) and not node_based_strategy:
# pyre-ignore[6]
self._validate_and_set_step_sequence(steps=self._nodes)
elif isinstance(nodes, list) and node_based_strategy:
self._validate_and_set_node_graph(nodes=nodes)
Expand All @@ -114,10 +136,10 @@ def __init__(
)
self._seen_trial_indices_by_status = None

@step_based_gs_only
def _validate_and_set_step_sequence(self, steps: List[GenerationStep]) -> None:
"""Initialize and validate that the sequence of steps."""
for idx, step in enumerate(steps):
assert isinstance(step, GenerationStep)
if step.num_trials == -1 and len(step.completion_criteria) < 1:
if idx < len(self._steps) - 1:
raise UserInputError(
Expand Down Expand Up @@ -206,11 +228,9 @@ def _validate_and_set_node_graph(self, nodes: List[GenerationNode]) -> None:
self._curr = nodes[0]

@property
@step_based_gs_only
def _steps(self) -> List[GenerationStep]:
"""List of generation steps."""
assert all(
isinstance(n, GenerationStep) for n in self._nodes
), "Attempting to set steps to non-GenerationStep objects."
return self._nodes # pyre-ignore[7]

def node_based_strategy(self, nodes: List[GenerationNode]) -> bool:
Expand Down Expand Up @@ -241,6 +261,7 @@ def name(self, name: str) -> None:
self._name = name

@property
@step_based_gs_only
def model_transitions(self) -> List[int]:
"""List of trial indices where a transition happened from one model to
another."""
Expand Down Expand Up @@ -280,14 +301,12 @@ def current_node(self) -> GenerationNode:
return self._curr

@property
@step_based_gs_only
def current_step_index(self) -> int:
"""Returns the index of the current generation step. This attribute
is replaced by node_name in newer GenerationStrategies but surfaced here
for backward compatibility.
"""
assert isinstance(
self._curr, GenerationStep
), "current_step_index only works with GenerationStep"
node_names_for_all_steps = [step._node_name for step in self._nodes]
assert (
self._curr.node_name in node_names_for_all_steps
Expand Down Expand Up @@ -333,6 +352,7 @@ def uses_non_registered_models(self) -> bool:
return not self._uses_registered_models

@property
@step_based_gs_only
def trials_as_df(self) -> Optional[pd.DataFrame]:
"""Puts information on individual trials into a data frame for easy
viewing. For example:
Expand Down
27 changes: 25 additions & 2 deletions ax/modelbridge/tests/test_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from ax.core.utils import (
get_pending_observation_features_based_on_trial_status as get_pending,
)
from ax.exceptions.core import DataRequiredError, UserInputError
from ax.exceptions.core import DataRequiredError, UnsupportedError, UserInputError
from ax.exceptions.generation_strategy import (
GenerationStrategyCompleted,
GenerationStrategyMisconfiguredException,
Expand Down Expand Up @@ -1081,7 +1081,6 @@ def test_gs_with_generation_nodes(self) -> None:
)
exp = get_branin_experiment()
self.assertEqual(sobol_GPEI_GS_nodes.name, "Sobol+GPEI_Nodes")
self.assertEqual(sobol_GPEI_GS_nodes.model_transitions, [5])

for i in range(7):
g = sobol_GPEI_GS_nodes.gen(exp)
Expand Down Expand Up @@ -1135,6 +1134,30 @@ def test_gs_with_generation_nodes(self) -> None:
del ms["generated_points"]
self.assertEqual(ms, {"init_position": i + 1})

def test_step_based_gs_only(self) -> None:
"""Test the step_based_gs_only decorator"""
sobol_model_spec = ModelSpec(
model_enum=Models.SOBOL,
model_kwargs={},
model_gen_kwargs={"n": 2},
)
gs_test = GenerationStrategy(
nodes=[
GenerationNode(
node_name="node_1",
model_specs=[sobol_model_spec],
),
GenerationNode(
node_name="node_2",
model_specs=[sobol_model_spec],
),
],
)
with self.assertRaisesRegex(
UnsupportedError, "is not supported for GenerationNode based"
):
gs_test.current_step_index

# ------------- Testing helpers (put tests above this line) -------------

def _run_GS_for_N_rounds(
Expand Down

0 comments on commit 621f017

Please sign in to comment.