Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug fix: Make MBM acquisition pass the correct args to each BoTorch optimizer; cleanup #2571

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 41 additions & 26 deletions ax/models/torch/botorch_modular/acquisition.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,19 +425,48 @@ def optimize(
_tensorize = partial(torch.tensor, dtype=self.dtype, device=self.device)
ssd = search_space_digest
bounds = _tensorize(ssd.bounds).t()
discrete_features = sorted(ssd.ordinal_features + ssd.categorical_features)
discrete_choices = mk_discrete_choices(ssd=ssd, fixed_features=fixed_features)
if (
optimizer_options is not None
and "force_use_optimize_acqf" in optimizer_options
):
force_use_optimize_acqf = optimizer_options.pop("force_use_optimize_acqf")
else:
force_use_optimize_acqf = False

if (len(discrete_features) == 0) or force_use_optimize_acqf:
optimizer = "optimize_acqf"
else:
fully_discrete = len(discrete_choices) == len(ssd.feature_names)
if fully_discrete:
total_discrete_choices = reduce(
operator.mul, [float(len(c)) for c in discrete_choices.values()]
)
if total_discrete_choices > MAX_CHOICES_ENUMERATE:
optimizer = "optimize_acqf_discrete_local_search"
else:
optimizer = "optimize_acqf_discrete"
# `raw_samples` is not supported by `optimize_acqf_discrete`.
# TODO[santorella]: Rather than manually removing it, we should
# ensure that it is never passed.
if optimizer_options is not None:
optimizer_options.pop("raw_samples")
else:
optimizer = "optimize_acqf_mixed"

# Prepare arguments for optimizer
optimizer_options_with_defaults = optimizer_argparse(
self.acqf,
bounds=bounds,
q=n,
optimizer_options=optimizer_options,
optimizer=optimizer,
)
post_processing_func = get_post_processing_func(
rounding_func=rounding_func,
optimizer_options=optimizer_options_with_defaults,
)
discrete_features = sorted(ssd.ordinal_features + ssd.categorical_features)
if fixed_features is not None:
for i in fixed_features:
if not 0 <= i < len(ssd.feature_names):
Expand All @@ -446,10 +475,7 @@ def optimize(
# customized in subclasses if necessary.
arm_weights = torch.ones(n, dtype=self.dtype)
# 1. Handle the fully continuous search space.
if (
optimizer_options_with_defaults.pop("force_use_optimize_acqf", False)
or not discrete_features
):
if optimizer == "optimize_acqf":
candidates, acqf_values = optimize_acqf(
acq_function=self.acqf,
bounds=bounds,
Expand All @@ -462,10 +488,11 @@ def optimize(
return candidates, acqf_values, arm_weights

# 2. Handle search spaces with discrete features.
discrete_choices = mk_discrete_choices(ssd=ssd, fixed_features=fixed_features)

# 2a. Handle the fully discrete search space.
if len(discrete_choices) == len(ssd.feature_names):
if optimizer in (
"optimize_acqf_discrete",
"optimize_acqf_discrete_local_search",
):
X_observed = self.X_observed
if self.X_pending is not None:
if X_observed is None:
Expand All @@ -474,10 +501,7 @@ def optimize(
X_observed = torch.cat([X_observed, self.X_pending], dim=0)

# Special handling for search spaces with a large number of choices
total_choices = reduce(
operator.mul, [float(len(c)) for c in discrete_choices.values()]
)
if total_choices > MAX_CHOICES_ENUMERATE:
if optimizer == "optimize_acqf_discrete_local_search":
discrete_choices = [
torch.tensor(c, device=self.device, dtype=self.dtype)
for c in discrete_choices.values()
Expand All @@ -492,6 +516,7 @@ def optimize(
)
return candidates, acqf_values, arm_weights

# Else, optimizer is `optimize_acqf_discrete`
# Enumerate all possible choices
all_choices = (discrete_choices[i] for i in range(len(discrete_choices)))
all_choices = _tensorize(tuple(product(*all_choices)))
Expand Down Expand Up @@ -530,26 +555,16 @@ def optimize(
)
n = num_choices

# `raw_samples` is not supported by `optimize_acqf_discrete`.
# TODO[santorella]: Rather than manually removing it, we should
# ensure that it is never passed.
if optimizer_options is not None:
optimizer_options.pop("raw_samples")
discrete_opt_options = optimizer_argparse(
self.acqf,
bounds=bounds,
q=n,
optimizer_options=optimizer_options,
optimizer_is_discrete=True,
)
candidates, acqf_values = optimize_acqf_discrete(
acq_function=self.acqf, q=n, choices=all_choices, **discrete_opt_options
acq_function=self.acqf,
q=n,
choices=all_choices,
**optimizer_options_with_defaults,
)
return candidates, acqf_values, arm_weights

# 2b. Handle mixed search spaces that have discrete and continuous features.
# Only sequential optimization is supported for `optimize_acqf_mixed`.
optimizer_options_with_defaults.pop("sequential")
candidates, acqf_values = optimize_acqf_mixed(
acq_function=self.acqf,
bounds=bounds,
Expand Down
71 changes: 51 additions & 20 deletions ax/models/torch/botorch_modular/optimizer_argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,17 @@
@optimizer_argparse.register(AcquisitionFunction)
def _argparse_base(
acqf: MaybeType[AcquisitionFunction],
*,
optimizer: str,
sequential: bool = True,
num_restarts: int = NUM_RESTARTS,
raw_samples: int = RAW_SAMPLES,
init_batch_limit: int = INIT_BATCH_LIMIT,
batch_limit: int = BATCH_LIMIT,
optimizer_options: Optional[Dict[str, Any]] = None,
optimizer_is_discrete: bool = False,
**ignore: Any,
) -> Dict[str, Any]:
"""Extract the base optimizer kwargs form the given arguments.
"""Extract the kwargs to be passed to a BoTorch optimizer.

NOTE: Since `optimizer_options` is how the user would typically pass in these
options, it takes precedence over other arguments. E.g., if both `num_restarts`
Expand All @@ -54,15 +55,24 @@ def _argparse_base(

Args:
acqf: The acquisition function being optimized.
sequential: Whether we choose one candidate at a time in a sequential manner.
optimizer: one of "optimize_acqf",
"optimize_acqf_discrete_local_search", "optimize_acqf_discrete", or
"optimize_acqf_mixed". This is generally chosen by
`Acquisition.optimize`.
sequential: Whether we choose one candidate at a time in a sequential
manner. Ignored unless the optimizer is `optimize_acqf`.
num_restarts: The number of starting points for multistart acquisition
function optimization.
raw_samples: The number of samples for initialization.
function optimization. Ignored if the optimizer is
`optimize_acqf_discrete`.
raw_samples: The number of samples for initialization. Ignored if the
optimizer is `optimize_acqf_discrete`.
init_batch_limit: The size of mini-batches used to evaluate the `raw_samples`.
This helps reduce peak memory usage.
This helps reduce peak memory usage. Ignored if the optimizer is
`optimize_acqf_discrete` or `optimize_acqf_discrete_local_search`.
batch_limit: The size of mini-batches used while optimizing the `acqf`.
This helps reduce peak memory usage.
optimizer_options: An optional dictionary of optimizer options. This may
This helps reduce peak memory usage. Ignored if the optimizer is
`optimize_acqf_discrete` or `optimize_acqf_discrete_local_search`.
optimizer_options: An optional dictionary of optimizer options. This may
include overrides for the above options (some of these under an `options`
dictionary) or any other option that is accepted by the optimizer. See
the docstrings in `botorch/optim/optimize.py` for supported options.
Expand All @@ -75,23 +85,43 @@ def _argparse_base(
>>> },
>>> "retry_on_optimization_warning": False,
>>> }
optimizer_is_discrete: True if the optimizer is `optimizer_acqf_discrete`,
which supports a limited set of arguments.
"""
optimizer_options = optimizer_options or {}
if optimizer_is_discrete:
return optimizer_options
return {
"sequential": sequential,
supported_optimizers = [
"optimize_acqf",
"optimize_acqf_discrete_local_search",
"optimize_acqf_discrete",
"optimize_acqf_mixed",
"optimize_acqf_homotopy",
]
if optimizer not in supported_optimizers:
raise ValueError(
f"optimizer=`{optimizer}` is not supported. Accepted options are "
f"{supported_optimizers}"
)
provided_options = optimizer_options if optimizer_options is not None else {}

# optimize_acqf_discrete only accepts 'choices', 'max_batch_size', 'unique'
if optimizer == "optimize_acqf_discrete":
return provided_options

# construct arguments from options that are not `provided_options`
options = {
"num_restarts": num_restarts,
"raw_samples": raw_samples,
"options": {
}
# if not, 'options' will be silently ignored
if optimizer in ["optimize_acqf", "optimize_acqf_mixed", "optimize_acqf_homotopy"]:
options["options"] = {
"init_batch_limit": init_batch_limit,
"batch_limit": batch_limit,
**optimizer_options.get("options", {}),
},
**{k: v for k, v in optimizer_options.items() if k != "options"},
}
**provided_options.get("options", {}),
}

if optimizer == "optimize_acqf":
options["sequential"] = sequential

options.update(**{k: v for k, v in provided_options.items() if k != "options"})
return options


@optimizer_argparse.register(qKnowledgeGradient)
Expand All @@ -117,6 +147,7 @@ def _argparse_kg(
num_restarts=num_restarts,
raw_samples=raw_samples,
optimizer_options=optimizer_options,
optimizer="optimize_acqf",
**kwargs,
)

Expand Down
5 changes: 3 additions & 2 deletions ax/models/torch/botorch_modular/sebo.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(
surrogate = surrogates[Keys.ONLY_SURROGATE]

tkwargs: Dict[str, Any] = {"dtype": surrogate.dtype, "device": surrogate.device}
options = options or {}
options = {} if options is None else options
self.penalty_name: str = options.pop("penalty", "L0_norm")
self.target_point: Tensor = options.pop("target_point", None)
if self.target_point is None:
Expand Down Expand Up @@ -296,9 +296,10 @@ def _optimize_with_homotopy(
bounds=bounds,
q=n,
optimizer_options=optimizer_options,
optimizer="optimize_acqf_homotopy",
)

def callback(): # pyre-ignore
def callback() -> None:
if (
self.acqf.cache_pending
): # If true, pending points are concatenated with X_baseline
Expand Down
Loading