Bug fix: Make MBM acquisition pass the correct args to each BoTorch optimizer; cleanup (#2571)

Summary:
Pull Request resolved: #2571

Motivation: MBM can dispatch to four different BoTorch optimizers depending on the search space. Currently, `optimize_acqf_discrete_local_search` is failing because it is passed the inappropriate argument `sequential`.

This diff
* Makes the logic clearer (sketched below): first, `Acquisition.optimize` determines which of the four optimizers is appropriate; then it constructs arguments for that optimizer using `optimizer_argparse`; finally, it applies any optimizer-specific logic and calls the optimizer.
* Fixes `optimizer_argparse` so that inappropriate arguments such as `sequential` are not passed to optimizers that don't accept them.
* Removes special-casing for qNEHVI and qMES in `optimizer_argparse` that wasn't actually doing anything.
* Extends unit tests for `optimizer_argparse` to check all optimizers.
* Reduces the usage and scope of mocks in `test_acquisition` so that `optimize_acqf` and its variants are actually run as much as possible.
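
A minimal sketch of the dispatch described in the first bullet, assuming the quantities below are derived from the `SearchSpaceDigest` as in the diff that follows. The function name, argument names, and the numeric value standing in for `MAX_CHOICES_ENUMERATE` are illustrative, not part of the Ax API:

```python
def choose_optimizer(
    num_discrete_features: int,
    num_discrete_choice_dims: int,
    num_features: int,
    total_discrete_choices: float,
    force_use_optimize_acqf: bool = False,
    max_choices_enumerate: int = 10_000,  # placeholder for MAX_CHOICES_ENUMERATE
) -> str:
    """Illustrative restatement of the optimizer dispatch in Acquisition.optimize."""
    if num_discrete_features == 0 or force_use_optimize_acqf:
        # Fully continuous search space, or continuous optimization forced.
        return "optimize_acqf"
    if num_discrete_choice_dims == num_features:
        # Fully discrete search space: enumerate all combinations when there
        # are few enough, otherwise fall back to local search.
        if total_discrete_choices > max_choices_enumerate:
            return "optimize_acqf_discrete_local_search"
        return "optimize_acqf_discrete"
    # Mixed search space with both discrete and continuous features.
    return "optimize_acqf_mixed"
```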

Differential Revision: D59354709
esantorella authored and facebook-github-bot committed Jul 9, 2024
1 parent a1ee4c4 commit 6b0ad55
Showing 4 changed files with 372 additions and 164 deletions.
67 changes: 41 additions & 26 deletions ax/models/torch/botorch_modular/acquisition.py
@@ -425,19 +425,48 @@ def optimize(
_tensorize = partial(torch.tensor, dtype=self.dtype, device=self.device)
ssd = search_space_digest
bounds = _tensorize(ssd.bounds).t()
discrete_features = sorted(ssd.ordinal_features + ssd.categorical_features)
discrete_choices = mk_discrete_choices(ssd=ssd, fixed_features=fixed_features)
if (
optimizer_options is not None
and "force_use_optimize_acqf" in optimizer_options
):
force_use_optimize_acqf = optimizer_options.pop("force_use_optimize_acqf")
else:
force_use_optimize_acqf = False

if (len(discrete_features) == 0) or force_use_optimize_acqf:
optimizer = "optimize_acqf"
else:
fully_discrete = len(discrete_choices) == len(ssd.feature_names)
if fully_discrete:
total_discrete_choices = reduce(
operator.mul, [float(len(c)) for c in discrete_choices.values()]
)
if total_discrete_choices > MAX_CHOICES_ENUMERATE:
optimizer = "optimize_acqf_discrete_local_search"
else:
optimizer = "optimize_acqf_discrete"
# `raw_samples` is not supported by `optimize_acqf_discrete`.
# TODO[santorella]: Rather than manually removing it, we should
# ensure that it is never passed.
if optimizer_options is not None:
optimizer_options.pop("raw_samples")
else:
optimizer = "optimize_acqf_mixed"

# Prepare arguments for optimizer
optimizer_options_with_defaults = optimizer_argparse(
self.acqf,
bounds=bounds,
q=n,
optimizer_options=optimizer_options,
optimizer=optimizer,
)
post_processing_func = get_post_processing_func(
rounding_func=rounding_func,
optimizer_options=optimizer_options_with_defaults,
)
discrete_features = sorted(ssd.ordinal_features + ssd.categorical_features)
if fixed_features is not None:
for i in fixed_features:
if not 0 <= i < len(ssd.feature_names):
@@ -446,10 +475,7 @@ def optimize(
# customized in subclasses if necessary.
arm_weights = torch.ones(n, dtype=self.dtype)
# 1. Handle the fully continuous search space.
if (
optimizer_options_with_defaults.pop("force_use_optimize_acqf", False)
or not discrete_features
):
if optimizer == "optimize_acqf":
candidates, acqf_values = optimize_acqf(
acq_function=self.acqf,
bounds=bounds,
@@ -462,10 +488,11 @@ def optimize(
return candidates, acqf_values, arm_weights

# 2. Handle search spaces with discrete features.
discrete_choices = mk_discrete_choices(ssd=ssd, fixed_features=fixed_features)

# 2a. Handle the fully discrete search space.
if len(discrete_choices) == len(ssd.feature_names):
if optimizer in (
"optimize_acqf_discrete",
"optimize_acqf_discrete_local_search",
):
X_observed = self.X_observed
if self.X_pending is not None:
if X_observed is None:
@@ -474,10 +501,7 @@ def optimize(
X_observed = torch.cat([X_observed, self.X_pending], dim=0)

# Special handling for search spaces with a large number of choices
total_choices = reduce(
operator.mul, [float(len(c)) for c in discrete_choices.values()]
)
if total_choices > MAX_CHOICES_ENUMERATE:
if optimizer == "optimize_acqf_discrete_local_search":
discrete_choices = [
torch.tensor(c, device=self.device, dtype=self.dtype)
for c in discrete_choices.values()
@@ -492,6 +516,7 @@ def optimize(
)
return candidates, acqf_values, arm_weights

# Else, optimizer is `optimize_acqf_discrete`
# Enumerate all possible choices
all_choices = (discrete_choices[i] for i in range(len(discrete_choices)))
all_choices = _tensorize(tuple(product(*all_choices)))
@@ -530,26 +555,16 @@ def optimize(
)
n = num_choices

# `raw_samples` is not supported by `optimize_acqf_discrete`.
# TODO[santorella]: Rather than manually removing it, we should
# ensure that it is never passed.
if optimizer_options is not None:
optimizer_options.pop("raw_samples")
discrete_opt_options = optimizer_argparse(
self.acqf,
bounds=bounds,
q=n,
optimizer_options=optimizer_options,
optimizer_is_discrete=True,
)
candidates, acqf_values = optimize_acqf_discrete(
acq_function=self.acqf, q=n, choices=all_choices, **discrete_opt_options
acq_function=self.acqf,
q=n,
choices=all_choices,
**optimizer_options_with_defaults,
)
return candidates, acqf_values, arm_weights

# 2b. Handle mixed search spaces that have discrete and continuous features.
# Only sequential optimization is supported for `optimize_acqf_mixed`.
optimizer_options_with_defaults.pop("sequential")
candidates, acqf_values = optimize_acqf_mixed(
acq_function=self.acqf,
bounds=bounds,
59 changes: 45 additions & 14 deletions ax/models/torch/botorch_modular/optimizer_argparse.py
@@ -42,16 +42,23 @@ def _argparse_base(
init_batch_limit: int = INIT_BATCH_LIMIT,
batch_limit: int = BATCH_LIMIT,
optimizer_options: Optional[Dict[str, Any]] = None,
optimizer_is_discrete: bool = False,
*,
optimizer: str = "optimize_acqf",
**ignore: Any,
) -> Dict[str, Any]:
"""Extract the base optimizer kwargs form the given arguments.
"""Extract the kwargs to be passed to a BoTorch optimizer.
NOTE: Since `optimizer_options` is how the user would typically pass in these
options, it takes precedence over other arguments. E.g., if both `num_restarts`
and `optimizer_options["num_restarts"]` are provided, this will use
`num_restarts` from `optimizer_options`.
NOTE: Arguments in `**ignore` are ignored; in addition, any arguments
specified in `optimizer_options` that are supported for some optimizers and
not others can be silently ignored. For example, `sequential` will not be
passed when the `optimizer` is `optimize_acqf_discrete`, since
`optimize_acqf_discrete` does not support it.
Args:
acqf: The acquisition function being optimized.
sequential: Whether we choose one candidate at a time in a sequential manner.
@@ -75,23 +82,46 @@ def _argparse_base(
>>> },
>>> "retry_on_optimization_warning": False,
>>> }
optimizer_is_discrete: True if the optimizer is `optimizer_acqf_discrete`,
which supports a limited set of arguments.
optimizer: one of "optimize_acqf",
"optimize_acqf_discrete_local_search", "optimize_acqf_discrete", or
"optimize_acqf_mixed". This is generally chosen by
`Acquisition.optimize`.
"""
optimizer_options = optimizer_options or {}
if optimizer_is_discrete:
return optimizer_options
return {
"sequential": sequential,
supported_optimizers = [
"optimize_acqf",
"optimize_acqf_discrete_local_search",
"optimize_acqf_discrete",
"optimize_acqf_mixed",
]
if optimizer not in supported_optimizers:
raise ValueError(
f"optimizer=`{optimizer}` is not supported. Accepted options are "
f"{supported_optimizers}"
)
provided_options = optimizer_options if optimizer_options is not None else {}

# Construct default arguments from the keyword arguments; `provided_options`
# is merged in below and takes precedence.
options = {
"num_restarts": num_restarts,
"raw_samples": raw_samples,
"options": {
}
# The 'options' sub-dict is only supported by `optimize_acqf` and
# `optimize_acqf_mixed`; for other optimizers it is silently ignored.
if optimizer in ["optimize_acqf", "optimize_acqf_mixed"]:
options["options"] = {
"init_batch_limit": init_batch_limit,
"batch_limit": batch_limit,
**optimizer_options.get("options", {}),
},
**{k: v for k, v in optimizer_options.items() if k != "options"},
}
**provided_options.get("options", {}),
}

if optimizer == "optimize_acqf":
options["sequential"] = sequential

# optimize_acqf_discrete only accepts 'choices', 'max_batch_size', 'unique'
if optimizer == "optimize_acqf_discrete":
return provided_options

options.update(**{k: v for k, v in provided_options.items() if k != "options"})
return options


@optimizer_argparse.register(qKnowledgeGradient)
@@ -117,6 +147,7 @@ def _argparse_kg(
num_restarts=num_restarts,
raw_samples=raw_samples,
optimizer_options=optimizer_options,
optimizer="optimize_acqf",
**kwargs,
)
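
To make the per-optimizer filtering described in the docstring above concrete, here is a hedged, self-contained restatement of the rules this diff introduces in `_argparse_base`. The numeric defaults are placeholders standing in for the module-level constants in `optimizer_argparse.py`; only the filtering and precedence logic mirrors the change.

```python
from typing import Any, Dict, Optional

# Placeholder defaults; the real values are module-level constants in Ax.
NUM_RESTARTS, RAW_SAMPLES = 20, 1024
INIT_BATCH_LIMIT, BATCH_LIMIT = 32, 5


def build_optimizer_kwargs(
    optimizer: str,
    optimizer_options: Optional[Dict[str, Any]] = None,
    sequential: bool = True,
) -> Dict[str, Any]:
    """Illustrative restatement of the per-optimizer filtering in this diff."""
    provided = optimizer_options or {}
    if optimizer == "optimize_acqf_discrete":
        # optimize_acqf_discrete only accepts a small set of arguments
        # ('choices', 'max_batch_size', 'unique'), so only pass through
        # what the caller explicitly provided.
        return dict(provided)
    kwargs: Dict[str, Any] = {"num_restarts": NUM_RESTARTS, "raw_samples": RAW_SAMPLES}
    if optimizer in ("optimize_acqf", "optimize_acqf_mixed"):
        # Only these optimizers accept an inner 'options' dict.
        kwargs["options"] = {
            "init_batch_limit": INIT_BATCH_LIMIT,
            "batch_limit": BATCH_LIMIT,
            **provided.get("options", {}),
        }
    if optimizer == "optimize_acqf":
        # 'sequential' is only valid for optimize_acqf.
        kwargs["sequential"] = sequential
    # Top-level user options take precedence over the defaults above.
    kwargs.update({k: v for k, v in provided.items() if k != "options"})
    return kwargs


# For example, 'sequential' is dropped for the mixed optimizer:
# build_optimizer_kwargs("optimize_acqf_mixed", {"num_restarts": 40})
#   -> {'num_restarts': 40, 'raw_samples': 1024,
#       'options': {'init_batch_limit': 32, 'batch_limit': 5}}
```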
