Skip to content

Commit 4a1358f

Browse files
esantorellafacebook-github-bot
authored andcommitted
Introduce ParamBasedTestProblem for benchmarking (#2675)
Summary: Pull Request resolved: #2675 Context: In a future refactor that will enable more flexible and powerful best-point functionality, every BenchmarkProblem's runner will be able to produce an "oracle" value (possibly the ground truth) for any arm, in-sample or not, with a function like `BenchmarkRunner.evaluate_oracle(arm=arm)`, with the problem handling computation and the runner formatting results. However, the current `BenchmarkRunner` and `BenchmarkMetric` setup doesn't cover every benchmark. Consolidating on `BenchmarkRunner` and `BenchmarkMetric` will enable the refactor, make it easier to universalize functionality like handling of constraints, noise, and inference regret, and will also allow for deleting some LOC for more custom problems. Current `BenchmarkRunner`s only handle problems that can consume tensor-valued arguments: BoTorch synthetic problems and surrogate problems. This isn't a good fit for problems like Jenatton that have a hierarchical search space and can have some parameters not passed. Because Ax always passes parameters and only sometimes represents them as tensors, a `TParameterization` is a more natural abstraction to handle parameters than a tensor. This PR: - Introduces `ParamBasedTestProblem`, which is like a BoTorch synthetic test problem but consumes a `TParameterization` rather than a tensor - Adds `ParamBasedTestProblemRunner`, which shares a base class `SyntheticProblemRunner` and most functionality with `BotorchTestProblemRunner` (so it is a `BenchmarkRunner` and supports both observed and unobserved noise). Differential Revision: D60996475
1 parent 631b89c commit 4a1358f

File tree

7 files changed

+341
-118
lines changed

7 files changed

+341
-118
lines changed

ax/benchmark/runners/base.py

+11-6
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55

66
# pyre-strict
77

8-
from abc import ABC, abstractmethod, abstractproperty
8+
from abc import ABC, abstractmethod
9+
from collections.abc import Iterable
910
from math import sqrt
1011
from typing import Any, Union
1112

1213
import torch
1314
from ax.core.arm import Arm
14-
from ax.core.base_trial import BaseTrial
15+
16+
from ax.core.base_trial import BaseTrial, TrialStatus
1517
from ax.core.batch_trial import BatchTrial
1618
from ax.core.runner import Runner
1719
from ax.core.trial import Trial
@@ -39,10 +41,7 @@ class BenchmarkRunner(Runner, ABC):
3941
not over-engineer for that before such a use case arrives.
4042
"""
4143

42-
@abstractproperty
43-
def outcome_names(self) -> list[str]:
44-
"""The names of the outcomes of the problem (in the order of the outcomes)."""
45-
pass # pragma: no cover
44+
outcome_names: list[str]
4645

4746
def get_Y_true(self, arm: Arm) -> Tensor:
4847
"""
@@ -132,3 +131,9 @@ def run(self, trial: BaseTrial) -> dict[str, Any]:
132131
"Ys_true": Ys_true,
133132
}
134133
return run_metadata
134+
135+
# This will need to be updated once asynchronous benchmarks are supported.
136+
def poll_trial_status(
137+
self, trials: Iterable[BaseTrial]
138+
) -> dict[TrialStatus, set[int]]:
139+
return {TrialStatus.COMPLETED: {t.index for t in trials}}

ax/benchmark/runners/botorch_test.py

+187-69
Original file line numberDiff line numberDiff line change
@@ -6,50 +6,85 @@
66
# pyre-strict
77

88
import importlib
9-
from collections.abc import Iterable
9+
from abc import ABC, abstractmethod
10+
from dataclasses import dataclass
1011
from typing import Any, Optional, Union
1112

1213
import torch
1314
from ax.benchmark.runners.base import BenchmarkRunner
1415
from ax.core.arm import Arm
15-
from ax.core.base_trial import BaseTrial, TrialStatus
16+
from ax.core.types import TParameterization
1617
from ax.utils.common.base import Base
1718
from ax.utils.common.equality import equality_typechecker
1819
from ax.utils.common.serialization import TClassDecoderRegistry, TDecoderRegistry
19-
from ax.utils.common.typeutils import checked_cast
20-
from botorch.test_functions.base import BaseTestProblem, ConstrainedBaseTestProblem
21-
from botorch.test_functions.multi_objective import MultiObjectiveTestProblem
20+
from botorch.test_functions.synthetic import (
21+
ConstrainedSyntheticTestFunction,
22+
SyntheticTestFunction,
23+
)
2224
from botorch.utils.transforms import normalize, unnormalize
25+
from pyre_extensions import assert_is_instance
2326
from torch import Tensor
2427

2528

26-
class BotorchTestProblemRunner(BenchmarkRunner):
27-
"""A Runner for evaluating Botorch BaseTestProblems.
29+
@dataclass(kw_only=True)
30+
class ParamBasedTestProblem(ABC):
31+
"""
32+
Similar to a BoTorch test problem, but evaluated using an Ax
33+
TParameterization rather than a tensor.
34+
"""
35+
36+
num_objectives: int
37+
optimal_value: float
38+
# Constraints could easily be supported similar to BoTorch test problems,
39+
# but haven't been hooked up.
40+
_is_constrained: bool = False
41+
constraint_noise_std: Optional[Union[float, list[float]]] = None
42+
noise_std: Optional[Union[float, list[float]]] = None
43+
negate: bool = False
44+
45+
@abstractmethod
46+
def evaluate_true(self, params: TParameterization) -> Tensor: ...
47+
48+
def evaluate_slack_true(self, params: TParameterization) -> Tensor:
49+
raise NotImplementedError(
50+
f"{self.__class__.__name__} does not support constraints."
51+
)
52+
53+
# pyre-fixme: Missing parameter annotation [2]: Parameter `other` must have
54+
# a type other than `Any`.
55+
def __eq__(self, other: Any) -> bool:
56+
if not isinstance(other, type(self)):
57+
return False
58+
return self.__class__.__name__ == other.__class__.__name__
59+
2860

29-
Given a trial the Runner will evaluate the BaseTestProblem.forward method for each
30-
arm in the trial, as well as return some metadata about the underlying Botorch
31-
problem such as the noise_std. We compute the full result on the Runner (as opposed
32-
to the Metric as is typical in synthetic test problems) because the BoTorch problem
33-
computes all metrics in one stacked tensor in the MOO case, and we wish to avoid
34-
recomputation per metric.
61+
class SyntheticProblemRunner(BenchmarkRunner, ABC):
62+
"""A Runner for evaluating synthetic problems, either BoTorch
63+
`SyntheticTestFunction`s or Ax benchmarking `ParamBasedTestProblem`s.
64+
65+
Given a trial, the Runner will evaluate the problem noiselessly for each
66+
arm in the trial, as well as return some metadata about the underlying
67+
problem such as the noise_std.
3568
"""
3669

37-
test_problem: BaseTestProblem
70+
test_problem: Union[SyntheticTestFunction, ParamBasedTestProblem]
3871
_is_constrained: bool
39-
_test_problem_class: type[BaseTestProblem]
72+
_test_problem_class: type[Union[SyntheticTestFunction, ParamBasedTestProblem]]
4073
_test_problem_kwargs: Optional[dict[str, Any]]
4174

4275
def __init__(
4376
self,
44-
test_problem_class: type[BaseTestProblem],
77+
*,
78+
test_problem_class: type[Union[SyntheticTestFunction, ParamBasedTestProblem]],
4579
test_problem_kwargs: dict[str, Any],
4680
outcome_names: list[str],
4781
modified_bounds: Optional[list[tuple[float, float]]] = None,
4882
) -> None:
4983
"""Initialize the test problem runner.
5084
5185
Args:
52-
test_problem_class: The BoTorch test problem class.
86+
test_problem_class: A BoTorch `SyntheticTestFunction` class or Ax
87+
`ParamBasedTestProblem` class.
5388
test_problem_kwargs: The keyword arguments used for initializing the
5489
test problem.
5590
outcome_names: The names of the outcomes returned by the problem.
@@ -63,28 +98,27 @@ def __init__(
6398
If modified bounds are not provided, the test problem will be
6499
evaluated using the raw parameter values.
65100
"""
66-
67101
self._test_problem_class = test_problem_class
68102
self._test_problem_kwargs = test_problem_kwargs
69-
70-
# pyre-fixme [45]: Invalid class instantiation
71-
self.test_problem = test_problem_class(**test_problem_kwargs).to(
72-
dtype=torch.double
103+
self.test_problem = (
104+
# pyre-fixme: Invalid class instantiation [45]: Cannot instantiate
105+
# abstract class with abstract method `evaluate_true`.
106+
test_problem_class(**test_problem_kwargs)
73107
)
108+
if isinstance(self.test_problem, SyntheticTestFunction):
109+
self.test_problem = self.test_problem.to(dtype=torch.double)
110+
# A `ConstrainedSyntheticTestFunction` is a type of `SyntheticTestFunction`; a
111+
# `ParamBasedTestProblem` is never constrained.
74112
self._is_constrained: bool = isinstance(
75-
self.test_problem, ConstrainedBaseTestProblem
113+
self.test_problem, ConstrainedSyntheticTestFunction
76114
)
77-
self._is_moo: bool = isinstance(self.test_problem, MultiObjectiveTestProblem)
78-
self._outcome_names = outcome_names
115+
self._is_moo: bool = self.test_problem.num_objectives > 1
116+
self.outcome_names = outcome_names
79117
self._modified_bounds = modified_bounds
80118

81-
@property
82-
def outcome_names(self) -> list[str]:
83-
return self._outcome_names
84-
85119
@equality_typechecker
86120
def __eq__(self, other: Base) -> bool:
87-
if not isinstance(other, BotorchTestProblemRunner):
121+
if not isinstance(other, type(self)):
88122
return False
89123

90124
return (
@@ -129,12 +163,95 @@ def get_noise_stds(self) -> Union[None, float, dict[str, float]]:
129163

130164
return noise_std_dict
131165

166+
@classmethod
167+
# pyre-fixme [2]: Parameter `obj` must have a type other than `Any`
168+
def serialize_init_args(cls, obj: Any) -> dict[str, Any]:
169+
"""Serialize the properties needed to initialize the runner.
170+
Used for storage.
171+
"""
172+
runner = assert_is_instance(obj, cls)
173+
174+
return {
175+
"test_problem_module": runner._test_problem_class.__module__,
176+
"test_problem_class_name": runner._test_problem_class.__name__,
177+
"test_problem_kwargs": runner._test_problem_kwargs,
178+
"outcome_names": runner.outcome_names,
179+
"modified_bounds": runner._modified_bounds,
180+
}
181+
182+
@classmethod
183+
def deserialize_init_args(
184+
cls,
185+
args: dict[str, Any],
186+
decoder_registry: Optional[TDecoderRegistry] = None,
187+
class_decoder_registry: Optional[TClassDecoderRegistry] = None,
188+
) -> dict[str, Any]:
189+
"""Given a dictionary, deserialize the properties needed to initialize the
190+
runner. Used for storage.
191+
"""
192+
193+
module = importlib.import_module(args["test_problem_module"])
194+
195+
return {
196+
"test_problem_class": getattr(module, args["test_problem_class_name"]),
197+
"test_problem_kwargs": args["test_problem_kwargs"],
198+
"outcome_names": args["outcome_names"],
199+
"modified_bounds": args["modified_bounds"],
200+
}
201+
202+
203+
class BotorchTestProblemRunner(SyntheticProblemRunner):
204+
"""
205+
A `SyntheticProblemRunner` for BoTorch `SyntheticTestFunction`s.
206+
207+
Args:
208+
test_problem_class: A BoTorch `SyntheticTestFunction` class.
209+
test_problem_kwargs: The keyword arguments used for initializing the
210+
test problem.
211+
outcome_names: The names of the outcomes returned by the problem.
212+
modified_bounds: The bounds that are used by the Ax search space
213+
while optimizing the problem. If different from the bounds of the
214+
test problem, we project the parameters into the test problem
215+
bounds before evaluating the test problem.
216+
For example, if the test problem is defined on [0, 1] but the Ax
217+
search space is integers in [0, 10], an Ax parameter value of
218+
5 will correspond to 0.5 while evaluating the test problem.
219+
If modified bounds are not provided, the test problem will be
220+
evaluated using the raw parameter values.
221+
"""
222+
223+
def __init__(
224+
self,
225+
*,
226+
test_problem_class: type[SyntheticTestFunction],
227+
test_problem_kwargs: dict[str, Any],
228+
outcome_names: list[str],
229+
modified_bounds: Optional[list[tuple[float, float]]] = None,
230+
) -> None:
231+
super().__init__(
232+
test_problem_class=test_problem_class,
233+
test_problem_kwargs=test_problem_kwargs,
234+
outcome_names=outcome_names,
235+
modified_bounds=modified_bounds,
236+
)
237+
self.test_problem: SyntheticTestFunction = self.test_problem.to(
238+
dtype=torch.double
239+
)
240+
self._is_constrained: bool = isinstance(
241+
self.test_problem, ConstrainedSyntheticTestFunction
242+
)
243+
132244
def get_Y_true(self, arm: Arm) -> Tensor:
133-
"""Converts X to original bounds -- only if modified bounds were provided --
134-
and evaluates the test problem. See `__init__` docstring for details.
245+
"""
246+
Convert the arm to a tensor and evaluate it on the base test problem.
247+
248+
Convert the tensor to original bounds -- only if modified bounds were
249+
provided -- and evaluate the test problem. See the docstring for
250+
`modified_bounds` in `BotorchTestProblemRunner.__init__` for details.
135251
136252
Args:
137-
X: A `batch_shape x d`-dim tensor of point(s) at which to evaluate the
253+
arm: Arm to evaluate. It will be converted to a
254+
`batch_shape x d`-dim tensor of point(s) at which to evaluate the
138255
test problem.
139256
140257
Returns:
@@ -157,7 +274,7 @@ def get_Y_true(self, arm: Arm) -> Tensor:
157274
X = unnormalize(unit_X, self.test_problem.bounds)
158275

159276
Y_true = self.test_problem.evaluate_true(X).view(-1)
160-
# `BaseTestProblem.evaluate_true()` does not negate the outcome
277+
# `SyntheticTestFunction.evaluate_true()` does not negate the outcome
161278
if self.test_problem.negate:
162279
Y_true = -Y_true
163280

@@ -171,43 +288,44 @@ def get_Y_true(self, arm: Arm) -> Tensor:
171288

172289
return Y_true
173290

174-
def poll_trial_status(
175-
self, trials: Iterable[BaseTrial]
176-
) -> dict[TrialStatus, set[int]]:
177-
return {TrialStatus.COMPLETED: {t.index for t in trials}}
178291

179-
@classmethod
180-
# pyre-fixme [2]: Parameter `obj` must have a type other than `Any``
181-
def serialize_init_args(cls, obj: Any) -> dict[str, Any]:
182-
"""Serialize the properties needed to initialize the runner.
183-
Used for storage.
184-
"""
185-
runner = checked_cast(BotorchTestProblemRunner, obj)
292+
class ParamBasedTestProblemRunner(SyntheticProblemRunner):
293+
"""
294+
A `SyntheticProblemRunner` for `ParamBasedTestProblem`s. See
295+
`SyntheticProblemRunner` for more information.
296+
"""
186297

187-
return {
188-
"test_problem_module": runner._test_problem_class.__module__,
189-
"test_problem_class_name": runner._test_problem_class.__name__,
190-
"test_problem_kwargs": runner._test_problem_kwargs,
191-
"outcome_names": runner._outcome_names,
192-
"modified_bounds": runner._modified_bounds,
193-
}
298+
# This could easily be supported, but hasn't been hooked up
299+
_is_constrained: bool = False
194300

195-
@classmethod
196-
def deserialize_init_args(
197-
cls,
198-
args: dict[str, Any],
199-
decoder_registry: Optional[TDecoderRegistry] = None,
200-
class_decoder_registry: Optional[TClassDecoderRegistry] = None,
201-
) -> dict[str, Any]:
202-
"""Given a dictionary, deserialize the properties needed to initialize the
203-
runner. Used for storage.
204-
"""
301+
def __init__(
302+
self,
303+
*,
304+
test_problem_class: type[ParamBasedTestProblem],
305+
test_problem_kwargs: dict[str, Any],
306+
outcome_names: list[str],
307+
modified_bounds: Optional[list[tuple[float, float]]] = None,
308+
) -> None:
309+
if modified_bounds is not None:
310+
raise NotImplementedError(
311+
f"modified_bounds is not supported for {test_problem_class.__name__}"
312+
)
313+
super().__init__(
314+
test_problem_class=test_problem_class,
315+
test_problem_kwargs=test_problem_kwargs,
316+
outcome_names=outcome_names,
317+
modified_bounds=modified_bounds,
318+
)
319+
self.test_problem: ParamBasedTestProblem = self.test_problem
205320

206-
module = importlib.import_module(args["test_problem_module"])
321+
def get_Y_true(self, arm: Arm) -> Tensor:
322+
"""Evaluates the test problem.
207323
208-
return {
209-
"test_problem_class": getattr(module, args["test_problem_class_name"]),
210-
"test_problem_kwargs": args["test_problem_kwargs"],
211-
"outcome_names": args["outcome_names"],
212-
"modified_bounds": args["modified_bounds"],
213-
}
324+
Returns:
325+
A `batch_shape x m`-dim tensor of ground truth (noiseless) evaluations.
326+
"""
327+
Y_true = self.test_problem.evaluate_true(arm.parameters).view(-1)
328+
# `ParamBasedTestProblem.evaluate_true()` does not negate the outcome
329+
if self.test_problem.negate:
330+
Y_true = -Y_true
331+
return Y_true

0 commit comments

Comments
 (0)