Skip to content

Commit

Permalink
Merge d3906b0 into 5dbc654
Browse files Browse the repository at this point in the history
  • Loading branch information
mgarrard authored Dec 7, 2023
2 parents 5dbc654 + d3906b0 commit b3614d4
Show file tree
Hide file tree
Showing 15 changed files with 570 additions and 75 deletions.
11 changes: 11 additions & 0 deletions ax/exceptions/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,14 @@ class GenerationStrategyRepeatedPoints(GenerationStrategyCompleted):
"""

pass


class GenerationStrategyMisconfiguredException(AxError):
"""Special exception indicating that the generation strategy is misconfigured."""

def __init__(self, error_info: Optional[str]) -> None:
super().__init__(
"This GenerationStrategy was unable to be initialized properly. Please "
+ "check the documentation, and adjust the configuration accordingly. "
+ f"{error_info}"
)
1 change: 1 addition & 0 deletions ax/modelbridge/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class GenerationNode:

# Optional specifications
_model_spec_to_gen_from: Optional[ModelSpec] = None
# TODO: @mgarrard should this be a dict criterion_class name -> criterion mapping?
_transition_criteria: Optional[Sequence[TransitionCriterion]]

# [TODO] Handle experiment passing more eloquently by enforcing experiment
Expand Down
294 changes: 239 additions & 55 deletions ax/modelbridge/generation_strategy.py

Large diffs are not rendered by default.

10 changes: 8 additions & 2 deletions ax/modelbridge/model_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
CVResult,
)
from ax.modelbridge.registry import ModelRegistryBase
from ax.utils.common.base import Base
from ax.utils.common.base import SortableBase
from ax.utils.common.kwargs import (
consolidate_kwargs,
filter_kwargs,
Expand All @@ -48,7 +48,7 @@ def default(self, o: Any) -> str:


@dataclass
class ModelSpec(Base):
class ModelSpec(SortableBase):
model_enum: ModelRegistryBase
# Kwargs to pass into the `Model` + `ModelBridge` constructors in
# `ModelRegistryBase.__call__`.
Expand Down Expand Up @@ -288,6 +288,12 @@ def __hash__(self) -> int:
def __eq__(self, other: ModelSpec) -> bool:
return repr(self) == repr(other)

@property
def _unique_id(self) -> str:
"""Returns the unique ID of this model spec"""
# TODO @mgarrard verify that this is unique enough
return str(hash(self))


@dataclass
class FactoryFunctionModelSpec(ModelSpec):
Expand Down
4 changes: 2 additions & 2 deletions ax/modelbridge/tests/test_completion_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_single_criterion(self) -> None:
)
)

self.assertEqual(generation_strategy._curr.model, Models.GPEI)
self.assertEqual(generation_strategy._curr.model_enum, Models.GPEI)

def test_many_criteria(self) -> None:
criteria = [
Expand Down Expand Up @@ -145,4 +145,4 @@ def test_many_criteria(self) -> None:
)
)

self.assertEqual(generation_strategy._curr.model, Models.GPEI)
self.assertEqual(generation_strategy._curr.model_enum, Models.GPEI)
Loading

0 comments on commit b3614d4

Please sign in to comment.