Bug fix: Make MBM acquisition pass the correct args to each BoTorch optimizer; cleanup (#2571)

Summary:
Pull Request resolved: #2571

# Context
The current flow is
1. MBM `Acquisition.optimize` constructs the arguments to pass to BoTorch optimizers in `optimizer_argparse` without knowing which of the four optimizers will receive them. It has only a boolean flag indicating whether the optimizer is discrete.
2. `Acquisition.optimize` decides which of the four optimizers to use and does optimizer-specific logic (these two parts are not really sequential).
3. `Acquisition.optimize` calls a BoTorch optimizer, passing some arguments that may not apply to the optimizer that actually got used.
4. Prior to pytorch/botorch#2390, the inappropriate arguments would be silently ignored, but now they raise an exception.

MBM can dispatch to four different BoTorch optimizers depending on the search space. Currently, `optimize_acqf_discrete_local_search` is failing because it is passed the inappropriate argument `sequential`.
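
A minimal stand-in illustration of the failure mode (these stubs are not the real BoTorch signatures; only the `sequential` mismatch mirrors the actual bug):

```python
# Stand-ins only: the real BoTorch optimizers take many more arguments.
def optimize_acqf_stub(*, num_restarts, raw_samples, sequential):
    """Mimics `optimize_acqf`, which does accept `sequential`."""
    return "continuous candidates"


def optimize_acqf_discrete_local_search_stub(*, num_restarts, raw_samples):
    """Mimics `optimize_acqf_discrete_local_search`, which has no `sequential` argument."""
    return "discrete candidates"


# Kwargs constructed without knowing which optimizer will ultimately be called:
shared_kwargs = {"num_restarts": 20, "raw_samples": 1024, "sequential": True}

optimize_acqf_stub(**shared_kwargs)  # fine
optimize_acqf_discrete_local_search_stub(**shared_kwargs)
# TypeError: got an unexpected keyword argument 'sequential'
```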

# This diff
* Makes the flow within `Acquisition.optimize` clearer and changes `optimizer_argparse` so that inappropriate arguments such as `sequential` are not passed to optimizers they don't apply to (a condensed sketch of the new flow follows this list):
  1. `Acquisition.optimize` determines which of the four optimizers is appropriate.
  2. `Acquisition.optimize` constructs arguments based on that optimizer, constructing only the arguments it needs.
  3. Then it does any optimizer-specific logic.
  4. Then it calls a BoTorch optimizer; there is no longer an error because only appropriate arguments were passed.
* Extends unit tests for `optimizer_argparse` to check all optimizers.
* Reduces the usage and scope of mocks in `test_acquisition` so that `optimize_acqf` and its variants actually run as much as possible.
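
A condensed, self-contained sketch of the new optimizer-selection step (simplified from the `acquisition.py` diff below; the threshold value and the flattened search-space arguments are illustrative stand-ins for `MAX_CHOICES_ENUMERATE` and the search space digest):

```python
from typing import Dict, List, Sequence

MAX_CHOICES_ENUMERATE = 100_000  # illustrative value; the real constant is defined in Ax


def choose_optimizer(
    feature_names: List[str],
    ordinal_features: List[int],
    categorical_features: List[int],
    discrete_choices: Dict[int, Sequence[float]],
    force_use_optimize_acqf: bool = False,
) -> str:
    """Pick which BoTorch optimizer `Acquisition.optimize` should dispatch to."""
    discrete_features = sorted(ordinal_features + categorical_features)
    if not discrete_features or force_use_optimize_acqf:
        return "optimize_acqf"
    if len(discrete_choices) == len(feature_names):  # fully discrete search space
        total_choices = 1.0
        for choices in discrete_choices.values():
            total_choices *= len(choices)
        if total_choices > MAX_CHOICES_ENUMERATE:
            return "optimize_acqf_discrete_local_search"
        return "optimize_acqf_discrete"
    return "optimize_acqf_mixed"


# Example: one continuous and one categorical feature -> the mixed optimizer.
print(choose_optimizer(
    feature_names=["x0", "x1"],
    ordinal_features=[],
    categorical_features=[1],
    discrete_choices={1: [0.0, 1.0, 2.0]},
))  # optimize_acqf_mixed
```

The chosen name is then passed to `optimizer_argparse(acqf, optimizer=..., ...)`, which builds only the kwargs that optimizer actually accepts.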

Reviewed By: saitcakmak

Differential Revision: D59354709

fbshipit-source-id: 88bdb464b6222cfb98f4263855288d3e0367ccc2
esantorella authored and facebook-github-bot committed Jul 13, 2024
1 parent 3a06169 commit 1a6faa6
Showing 5 changed files with 384 additions and 173 deletions.
67 changes: 41 additions & 26 deletions ax/models/torch/botorch_modular/acquisition.py
@@ -425,19 +425,48 @@ def optimize(
_tensorize = partial(torch.tensor, dtype=self.dtype, device=self.device)
ssd = search_space_digest
bounds = _tensorize(ssd.bounds).t()
discrete_features = sorted(ssd.ordinal_features + ssd.categorical_features)
discrete_choices = mk_discrete_choices(ssd=ssd, fixed_features=fixed_features)
if (
optimizer_options is not None
and "force_use_optimize_acqf" in optimizer_options
):
force_use_optimize_acqf = optimizer_options.pop("force_use_optimize_acqf")
else:
force_use_optimize_acqf = False

if (len(discrete_features) == 0) or force_use_optimize_acqf:
optimizer = "optimize_acqf"
else:
fully_discrete = len(discrete_choices) == len(ssd.feature_names)
if fully_discrete:
total_discrete_choices = reduce(
operator.mul, [float(len(c)) for c in discrete_choices.values()]
)
if total_discrete_choices > MAX_CHOICES_ENUMERATE:
optimizer = "optimize_acqf_discrete_local_search"
else:
optimizer = "optimize_acqf_discrete"
# `raw_samples` is not supported by `optimize_acqf_discrete`.
# TODO[santorella]: Rather than manually removing it, we should
# ensure that it is never passed.
if optimizer_options is not None:
optimizer_options.pop("raw_samples")
else:
optimizer = "optimize_acqf_mixed"

# Prepare arguments for optimizer
optimizer_options_with_defaults = optimizer_argparse(
self.acqf,
bounds=bounds,
q=n,
optimizer_options=optimizer_options,
optimizer=optimizer,
)
post_processing_func = get_post_processing_func(
rounding_func=rounding_func,
optimizer_options=optimizer_options_with_defaults,
)
discrete_features = sorted(ssd.ordinal_features + ssd.categorical_features)
if fixed_features is not None:
for i in fixed_features:
if not 0 <= i < len(ssd.feature_names):
@@ -446,10 +475,7 @@ def optimize(
# customized in subclasses if necessary.
arm_weights = torch.ones(n, dtype=self.dtype)
# 1. Handle the fully continuous search space.
if (
optimizer_options_with_defaults.pop("force_use_optimize_acqf", False)
or not discrete_features
):
if optimizer == "optimize_acqf":
candidates, acqf_values = optimize_acqf(
acq_function=self.acqf,
bounds=bounds,
@@ -462,10 +488,11 @@ def optimize(
return candidates, acqf_values, arm_weights

# 2. Handle search spaces with discrete features.
discrete_choices = mk_discrete_choices(ssd=ssd, fixed_features=fixed_features)

# 2a. Handle the fully discrete search space.
if len(discrete_choices) == len(ssd.feature_names):
if optimizer in (
"optimize_acqf_discrete",
"optimize_acqf_discrete_local_search",
):
X_observed = self.X_observed
if self.X_pending is not None:
if X_observed is None:
@@ -474,10 +501,7 @@ def optimize(
X_observed = torch.cat([X_observed, self.X_pending], dim=0)

# Special handling for search spaces with a large number of choices
total_choices = reduce(
operator.mul, [float(len(c)) for c in discrete_choices.values()]
)
if total_choices > MAX_CHOICES_ENUMERATE:
if optimizer == "optimize_acqf_discrete_local_search":
discrete_choices = [
torch.tensor(c, device=self.device, dtype=self.dtype)
for c in discrete_choices.values()
@@ -492,6 +516,7 @@ def optimize(
)
return candidates, acqf_values, arm_weights

# Else, optimizer is `optimize_acqf_discrete`
# Enumerate all possible choices
all_choices = (discrete_choices[i] for i in range(len(discrete_choices)))
all_choices = _tensorize(tuple(product(*all_choices)))
@@ -530,26 +555,16 @@ def optimize(
)
n = num_choices

# `raw_samples` is not supported by `optimize_acqf_discrete`.
# TODO[santorella]: Rather than manually removing it, we should
# ensure that it is never passed.
if optimizer_options is not None:
optimizer_options.pop("raw_samples")
discrete_opt_options = optimizer_argparse(
self.acqf,
bounds=bounds,
q=n,
optimizer_options=optimizer_options,
optimizer_is_discrete=True,
)
candidates, acqf_values = optimize_acqf_discrete(
acq_function=self.acqf, q=n, choices=all_choices, **discrete_opt_options
acq_function=self.acqf,
q=n,
choices=all_choices,
**optimizer_options_with_defaults,
)
return candidates, acqf_values, arm_weights

# 2b. Handle mixed search spaces that have discrete and continuous features.
# Only sequential optimization is supported for `optimize_acqf_mixed`.
optimizer_options_with_defaults.pop("sequential")
candidates, acqf_values = optimize_acqf_mixed(
acq_function=self.acqf,
bounds=bounds,
71 changes: 51 additions & 20 deletions ax/models/torch/botorch_modular/optimizer_argparse.py
@@ -36,16 +36,17 @@
@optimizer_argparse.register(AcquisitionFunction)
def _argparse_base(
acqf: MaybeType[AcquisitionFunction],
*,
optimizer: str,
sequential: bool = True,
num_restarts: int = NUM_RESTARTS,
raw_samples: int = RAW_SAMPLES,
init_batch_limit: int = INIT_BATCH_LIMIT,
batch_limit: int = BATCH_LIMIT,
optimizer_options: Optional[Dict[str, Any]] = None,
optimizer_is_discrete: bool = False,
**ignore: Any,
) -> Dict[str, Any]:
"""Extract the base optimizer kwargs form the given arguments.
"""Extract the kwargs to be passed to a BoTorch optimizer.
NOTE: Since `optimizer_options` is how the user would typically pass in these
options, it takes precedence over other arguments. E.g., if both `num_restarts`
@@ -54,15 +55,24 @@ def _argparse_base(
Args:
acqf: The acquisition function being optimized.
sequential: Whether we choose one candidate at a time in a sequential manner.
optimizer: one of "optimize_acqf",
"optimize_acqf_discrete_local_search", "optimize_acqf_discrete", or
"optimize_acqf_mixed". This is generally chosen by
`Acquisition.optimize`.
sequential: Whether we choose one candidate at a time in a sequential
manner. Ignored unless the optimizer is `optimize_acqf`.
num_restarts: The number of starting points for multistart acquisition
function optimization.
raw_samples: The number of samples for initialization.
function optimization. Ignored if the optimizer is
`optimize_acqf_discrete`.
raw_samples: The number of samples for initialization. Ignored if the
optimizer is `optimize_acqf_discrete`.
init_batch_limit: The size of mini-batches used to evaluate the `raw_samples`.
This helps reduce peak memory usage.
This helps reduce peak memory usage. Ignored if the optimizer is
`optimize_acqf_discrete` or `optimize_acqf_discrete_local_search`.
batch_limit: The size of mini-batches used while optimizing the `acqf`.
This helps reduce peak memory usage.
optimizer_options: An optional dictionary of optimizer options. This may
This helps reduce peak memory usage. Ignored if the optimizer is
`optimize_acqf_discrete` or `optimize_acqf_discrete_local_search`.
optimizer_options: An optional dictionary of optimizer options. This may
include overrides for the above options (some of these under an `options`
dictionary) or any other option that is accepted by the optimizer. See
the docstrings in `botorch/optim/optimize.py` for supported options.
@@ -75,23 +85,43 @@ def _argparse_base(
>>> },
>>> "retry_on_optimization_warning": False,
>>> }
optimizer_is_discrete: True if the optimizer is `optimizer_acqf_discrete`,
which supports a limited set of arguments.
"""
optimizer_options = optimizer_options or {}
if optimizer_is_discrete:
return optimizer_options
return {
"sequential": sequential,
supported_optimizers = [
"optimize_acqf",
"optimize_acqf_discrete_local_search",
"optimize_acqf_discrete",
"optimize_acqf_mixed",
"optimize_acqf_homotopy",
]
if optimizer not in supported_optimizers:
raise ValueError(
f"optimizer=`{optimizer}` is not supported. Accepted options are "
f"{supported_optimizers}"
)
provided_options = optimizer_options if optimizer_options is not None else {}

# optimize_acqf_discrete only accepts 'choices', 'max_batch_size', 'unique'
if optimizer == "optimize_acqf_discrete":
return provided_options

# construct arguments from options that are not `provided_options`
options = {
"num_restarts": num_restarts,
"raw_samples": raw_samples,
"options": {
}
# if not, 'options' will be silently ignored
if optimizer in ["optimize_acqf", "optimize_acqf_mixed", "optimize_acqf_homotopy"]:
options["options"] = {
"init_batch_limit": init_batch_limit,
"batch_limit": batch_limit,
**optimizer_options.get("options", {}),
},
**{k: v for k, v in optimizer_options.items() if k != "options"},
}
**provided_options.get("options", {}),
}

if optimizer == "optimize_acqf":
options["sequential"] = sequential

options.update(**{k: v for k, v in provided_options.items() if k != "options"})
return options


@optimizer_argparse.register(qKnowledgeGradient)
@@ -117,6 +147,7 @@ def _argparse_kg(
num_restarts=num_restarts,
raw_samples=raw_samples,
optimizer_options=optimizer_options,
optimizer="optimize_acqf",
**kwargs,
)

5 changes: 3 additions & 2 deletions ax/models/torch/botorch_modular/sebo.py
@@ -63,7 +63,7 @@ def __init__(
surrogate = surrogates[Keys.ONLY_SURROGATE]

tkwargs: Dict[str, Any] = {"dtype": surrogate.dtype, "device": surrogate.device}
options = options or {}
options = {} if options is None else options
self.penalty_name: str = options.pop("penalty", "L0_norm")
self.target_point: Tensor = options.pop("target_point", None)
if self.target_point is None:
@@ -296,9 +296,10 @@ def _optimize_with_homotopy(
bounds=bounds,
q=n,
optimizer_options=optimizer_options,
optimizer="optimize_acqf_homotopy",
)

def callback(): # pyre-ignore
def callback() -> None:
if (
self.acqf.cache_pending
): # If true, pending points are concatenated with X_baseline