Remove unnecessary special-casing for qEHVI and qMES from `optimizer_argparse`

Summary:
* Remove unnecessary special-casing for qEHVI and qMES from `optimizer_argparse`. These registrations were no-ops, since the base case already supports the `sequential` argument.
* Name repeated default values as module-level constants.
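
A quick sketch of why the special cases were no-ops (written for this note, not part of the diff; it assumes the base-case behavior described above): a formerly special-cased class now falls through to the generic parser, which already threads `sequential` into its output.

from ax.models.torch.botorch_modular.optimizer_argparse import optimizer_argparse
from botorch.acquisition.multi_objective.monte_carlo import (
    qExpectedHypervolumeImprovement,
)

# With the bespoke qEHVI registration gone, dispatch falls back to the base
# case, which handles the `sequential` argument itself.
options = optimizer_argparse(qExpectedHypervolumeImprovement, sequential=False)
assert options["sequential"] is False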

Differential Revision: D59548819
esantorella authored and facebook-github-bot committed Jul 9, 2024
1 parent 7aed6b0 commit a1ee4c4
Showing 2 changed files with 23 additions and 105 deletions.
63 changes: 18 additions & 45 deletions ax/models/torch/botorch_modular/optimizer_argparse.py
@@ -15,16 +15,18 @@
 from ax.utils.common.typeutils import _argparse_type_encoder
 from botorch.acquisition.acquisition import AcquisitionFunction
 from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
-from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy
-from botorch.acquisition.multi_objective.monte_carlo import (
-    qExpectedHypervolumeImprovement,
-)
 from botorch.optim.initializers import gen_one_shot_kg_initial_conditions
 from botorch.utils.dispatcher import Dispatcher
 
 T = TypeVar("T")
 MaybeType = Union[T, Type[T]]  # Annotation for a type or instance thereof
 
+# Acquisition defaults
+NUM_RESTARTS = 20
+RAW_SAMPLES = 1024
+INIT_BATCH_LIMIT = 32
+BATCH_LIMIT = 5
+
 
 optimizer_argparse = Dispatcher(
     name="optimizer_argparse", encoder=_argparse_type_encoder
@@ -35,10 +37,10 @@
 def _argparse_base(
     acqf: MaybeType[AcquisitionFunction],
     sequential: bool = True,
-    num_restarts: int = 20,
-    raw_samples: int = 1024,
-    init_batch_limit: int = 32,
-    batch_limit: int = 5,
+    num_restarts: int = NUM_RESTARTS,
+    raw_samples: int = RAW_SAMPLES,
+    init_batch_limit: int = INIT_BATCH_LIMIT,
+    batch_limit: int = BATCH_LIMIT,
     optimizer_options: Optional[Dict[str, Any]] = None,
     optimizer_is_discrete: bool = False,
     **ignore: Any,
@@ -92,40 +94,24 @@ def _argparse_base(
     }
 
 
-@optimizer_argparse.register(qExpectedHypervolumeImprovement)
-def _argparse_ehvi(
-    acqf: MaybeType[qExpectedHypervolumeImprovement],
-    sequential: bool = True,
-    init_batch_limit: int = 32,
-    batch_limit: int = 5,
-    optimizer_options: Optional[Dict[str, Any]] = None,
-    **kwargs: Any,
-) -> Dict[str, Any]:
-    return {
-        **_argparse_base(
-            acqf=acqf,
-            init_batch_limit=init_batch_limit,
-            batch_limit=batch_limit,
-            optimizer_options=optimizer_options,
-            **kwargs,
-        ),
-        "sequential": sequential,
-    }
-
-
 @optimizer_argparse.register(qKnowledgeGradient)
 def _argparse_kg(
     acqf: qKnowledgeGradient,
     q: int,
     bounds: torch.Tensor,
-    num_restarts: int = 20,
-    raw_samples: int = 1024,
+    num_restarts: int = NUM_RESTARTS,
+    raw_samples: int = RAW_SAMPLES,
     frac_random: float = 0.1,
     optimizer_options: Optional[Dict[str, Any]] = None,
     **kwargs: Any,
 ) -> Dict[str, Any]:
+    """
+    Argument constructor for optimization with qKG, differing from the
+    base case in that it computes and returns initial conditions.
+    To do so, it requires specifying additional arguments `q` and `bounds` and
+    allows for specifying `frac_random`.
+    """
     optimizer_options = optimizer_options or {}
     base_options = _argparse_base(
         acqf,
         num_restarts=num_restarts,
@@ -151,16 +137,3 @@ def _argparse_kg(
         **base_options,
         Keys.BATCH_INIT_CONDITIONS: initial_conditions,
     }
-
-
-@optimizer_argparse.register(qMaxValueEntropy)
-def _argparse_mes(
-    acqf: AcquisitionFunction,
-    sequential: bool = True,
-    optimizer_options: Optional[Dict[str, Any]] = None,
-    **kwargs: Any,
-) -> Dict[str, Any]:
-    return {
-        **_argparse_base(acqf=acqf, optimizer_options=optimizer_options, **kwargs),
-        "sequential": sequential,
-    }
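
For contrast with the base case, a hedged usage sketch of the qKG path (illustrative, not from the commit): `_argparse_kg` additionally requires `q` and `bounds` so it can precompute one-shot initial conditions via `gen_one_shot_kg_initial_conditions`. The toy model, data, and deliberately small `num_restarts`/`raw_samples` are assumptions made for the example.

import torch
from ax.models.torch.botorch_modular.optimizer_argparse import optimizer_argparse
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
from botorch.models import SingleTaskGP

# Toy model and training data, purely illustrative.
train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
acqf = qKnowledgeGradient(model=SingleTaskGP(train_X, train_Y), num_fantasies=4)

bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double)  # 2 x d
options = optimizer_argparse(
    acqf,
    q=2,
    bounds=bounds,
    num_restarts=2,  # small values keep the sketch cheap
    raw_samples=8,
    frac_random=0.5,
)
# Alongside the generic settings, the parsed options carry precomputed batch
# initial conditions for the one-shot qKG optimization problem.
print(sorted(str(k) for k in options))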
65 changes: 5 additions & 60 deletions ax/models/torch/tests/test_optimizer_argparse.py
@@ -14,6 +14,7 @@
 from ax.models.torch.botorch_modular import optimizer_argparse as Argparse
 from ax.models.torch.botorch_modular.optimizer_argparse import (
     _argparse_base,
+    INIT_BATCH_LIMIT,
     MaybeType,
     optimizer_argparse,
 )
@@ -24,14 +25,6 @@
     qKnowledgeGradient,
     qMultiFidelityKnowledgeGradient,
 )
-from botorch.acquisition.max_value_entropy_search import (
-    qMaxValueEntropy,
-    qMultiFidelityMaxValueEntropy,
-)
-from botorch.acquisition.multi_objective.monte_carlo import (
-    qExpectedHypervolumeImprovement,
-    qNoisyExpectedHypervolumeImprovement,
-)
 
 
 class DummyAcquisitionFunction(AcquisitionFunction):
@@ -65,17 +58,11 @@ def _argparse(acqf: MaybeType[DummyAcquisitionFunction]):
         self.assertEqual(optimizer_argparse[DummyAcquisitionFunction], _argparse)
 
     def test_optimizer_options(self) -> None:
-        skipped_funcs = {  # These should all have bespoke tests
-            optimizer_argparse[acqf_class]
-            for acqf_class in (
-                qExpectedHypervolumeImprovement,
-                qKnowledgeGradient,
-                qMaxValueEntropy,
-            )
-        }
+        # This has a bespoke test
+        skipped_func = optimizer_argparse[qKnowledgeGradient]
         user_options = {"foo": "bar", "num_restarts": 13}
         for func in optimizer_argparse.funcs.values():
-            if func in skipped_funcs:
+            if func is skipped_func:
                 continue
 
             parsed_options = func(None, optimizer_options=user_options)
@@ -89,39 +76,9 @@ def test_optimizer_options(self) -> None:
             )
             self.assertEqual(
                 parsed_options["options"],
-                {"batch_limit": 10, "init_batch_limit": 32, "maxiter": 20},
+                {"batch_limit": 10, "init_batch_limit": INIT_BATCH_LIMIT, "maxiter": 20},
             )
 
-    def test_ehvi(self) -> None:
-        user_options = {"foo": "bar", "num_restarts": 651}
-        inner_options = {"init_batch_limit": 23, "batch_limit": 67}
-        generic_options = _argparse_base(None, optimizer_options=user_options)
-        generic_options.pop("options")
-        for acqf in (
-            qExpectedHypervolumeImprovement,
-            qNoisyExpectedHypervolumeImprovement,
-        ):
-            with self.subTest(acqf=acqf):
-                options = optimizer_argparse(
-                    acqf,
-                    optimizer_options=user_options,
-                    **inner_options,
-                )
-                self.assertEqual(options["sequential"], True)
-                self.assertEqual(options.pop("options"), inner_options)
-                self.assertEqual(options, generic_options)
-
-            # Defaults
-            options = optimizer_argparse(
-                acqf,
-                sequential=True,
-                optimizer_options=user_options,
-            )
-            self.assertEqual(
-                options.pop("options"), {"init_batch_limit": 32, "batch_limit": 5}
-            )
-            self.assertEqual(options, generic_options)
-
     def test_kg(self) -> None:
         with patch(
             "botorch.optim.initializers.gen_one_shot_kg_initial_conditions"
@@ -141,15 +98,3 @@ def test_kg(self) -> None:
             )
             self.assertEqual(options.pop(Keys.BATCH_INIT_CONDITIONS), "TEST")
             self.assertEqual(options, generic_options)
-
-    def test_mes(self) -> None:
-        user_options = {"foo": "bar", "num_restarts": 83}
-        generic_options = _argparse_base(None, optimizer_options=user_options)
-        for acqf in (qMaxValueEntropy, qMultiFidelityMaxValueEntropy):
-            with self.subTest(acqf=acqf):
-                options = optimizer_argparse(
-                    acqf,
-                    optimizer_options=user_options,
-                )
-                self.assertEqual(options["sequential"], True)
-                self.assertEqual(options, generic_options)
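
The deleted `test_ehvi` and `test_mes` coverage is subsumed by the generic loop in `test_optimizer_options` above. A sanity-check sketch of the fallback (mine, mirroring the deleted assertions rather than taken from the commit):

from ax.models.torch.botorch_modular.optimizer_argparse import (
    _argparse_base,
    optimizer_argparse,
)
from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy

# With its bespoke registration removed, qMES resolves to the base parser and
# still reports sequential optimization, as the deleted test_mes asserted.
user_options = {"foo": "bar", "num_restarts": 83}
options = optimizer_argparse(qMaxValueEntropy, optimizer_options=user_options)
assert options == _argparse_base(None, optimizer_options=user_options)
assert options["sequential"] is True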
