6
6
# pyre-strict
7
7
8
8
import importlib
9
- from collections .abc import Iterable
9
+ from abc import ABC , abstractmethod
10
+ from dataclasses import dataclass
10
11
from typing import Any , Optional , Union
11
12
12
13
import torch
13
14
from ax .benchmark .runners .base import BenchmarkRunner
14
15
from ax .core .arm import Arm
15
- from ax .core .base_trial import BaseTrial , TrialStatus
16
+ from ax .core .types import TParameterization
16
17
from ax .utils .common .base import Base
17
18
from ax .utils .common .equality import equality_typechecker
18
19
from ax .utils .common .serialization import TClassDecoderRegistry , TDecoderRegistry
19
- from ax .utils .common .typeutils import checked_cast
20
- from botorch .test_functions .base import BaseTestProblem , ConstrainedBaseTestProblem
21
- from botorch .test_functions .multi_objective import MultiObjectiveTestProblem
20
+ from botorch .test_functions .synthetic import (
21
+ ConstrainedSyntheticTestFunction ,
22
+ SyntheticTestFunction ,
23
+ )
22
24
from botorch .utils .transforms import normalize , unnormalize
25
+ from pyre_extensions import assert_is_instance
23
26
from torch import Tensor
24
27
25
28
26
- class BotorchTestProblemRunner (BenchmarkRunner ):
27
- """A Runner for evaluating Botorch BaseTestProblems.
29
+ @dataclass (kw_only = True )
30
+ class ParamBasedTestProblem (ABC ):
31
+ """
32
+ Similar to a BoTorch test problem, but evaluated using an Ax
33
+ TParameterization rather than a tensor.
34
+ """
35
+
36
+ num_objectives : int
37
+ optimal_value : float
38
+ # Constraints could easily be supported similar to BoTorch test problems,
39
+ # but haven't been hooked up.
40
+ _is_constrained : bool = False
41
+ constraint_noise_std : Optional [Union [float , list [float ]]] = None
42
+ noise_std : Optional [Union [float , list [float ]]] = None
43
+ negate : bool = False
44
+
45
+ @abstractmethod
46
+ def evaluate_true (self , params : TParameterization ) -> Tensor : ...
47
+
48
+ def evaluate_slack_true (self , params : TParameterization ) -> Tensor :
49
+ raise NotImplementedError (
50
+ f"{ self .__class__ .__name__ } does not support constraints."
51
+ )
52
+
53
+ # pyre-fixme: Missing parameter annotation [2]: Parameter `other` must have
54
+ # a type other than `Any`.
55
+ def __eq__ (self , other : Any ) -> bool :
56
+ if not isinstance (other , type (self )):
57
+ return False
58
+ return self .__class__ .__name__ == other .__class__ .__name__
59
+
28
60
29
- Given a trial the Runner will evaluate the BaseTestProblem.forward method for each
30
- arm in the trial, as well as return some metadata about the underlying Botorch
31
- problem such as the noise_std. We compute the full result on the Runner (as opposed
32
- to the Metric as is typical in synthetic test problems) because the BoTorch problem
33
- computes all metrics in one stacked tensor in the MOO case, and we wish to avoid
34
- recomputation per metric.
61
+ class SyntheticProblemRunner (BenchmarkRunner , ABC ):
62
+ """A Runner for evaluating synthetic problems, either BoTorch
63
+ `SyntheticTestFunction`s or Ax benchmarking `ParamBasedTestProblem`s.
64
+
65
+ Given a trial, the Runner will evaluate the problem noiselessly for each
66
+ arm in the trial, as well as return some metadata about the underlying
67
+ problem such as the noise_std.
35
68
"""
36
69
37
- test_problem : BaseTestProblem
70
+ test_problem : Union [ SyntheticTestFunction , ParamBasedTestProblem ]
38
71
_is_constrained : bool
39
- _test_problem_class : type [BaseTestProblem ]
72
+ _test_problem_class : type [Union [ SyntheticTestFunction , ParamBasedTestProblem ] ]
40
73
_test_problem_kwargs : Optional [dict [str , Any ]]
41
74
42
75
def __init__ (
43
76
self ,
44
- test_problem_class : type [BaseTestProblem ],
77
+ * ,
78
+ test_problem_class : type [Union [SyntheticTestFunction , ParamBasedTestProblem ]],
45
79
test_problem_kwargs : dict [str , Any ],
46
80
outcome_names : list [str ],
47
81
modified_bounds : Optional [list [tuple [float , float ]]] = None ,
48
82
) -> None :
49
83
"""Initialize the test problem runner.
50
84
51
85
Args:
52
- test_problem_class: The BoTorch test problem class.
86
+ test_problem_class: A BoTorch `SyntheticTestFunction` class or Ax
87
+ `ParamBasedTestProblem` class.
53
88
test_problem_kwargs: The keyword arguments used for initializing the
54
89
test problem.
55
90
outcome_names: The names of the outcomes returned by the problem.
@@ -63,28 +98,27 @@ def __init__(
63
98
If modified bounds are not provided, the test problem will be
64
99
evaluated using the raw parameter values.
65
100
"""
66
-
67
101
self ._test_problem_class = test_problem_class
68
102
self ._test_problem_kwargs = test_problem_kwargs
69
-
70
- # pyre-fixme [45] : Invalid class instantiation
71
- self . test_problem = test_problem_class ( ** test_problem_kwargs ). to (
72
- dtype = torch . double
103
+ self . test_problem = (
104
+ # pyre-fixme: Invalid class instantiation [45]: Cannot instantiate
105
+ # abstract class with abstract method `evaluate_true`.
106
+ test_problem_class ( ** test_problem_kwargs )
73
107
)
108
+ if isinstance (self .test_problem , SyntheticTestFunction ):
109
+ self .test_problem = self .test_problem .to (dtype = torch .double )
110
+ # A `ConstrainedSyntheticTestFunction` is a type of `SyntheticTestFunction`; a
111
+ # `ParamBasedTestProblem` is never constrained.
74
112
self ._is_constrained : bool = isinstance (
75
- self .test_problem , ConstrainedBaseTestProblem
113
+ self .test_problem , ConstrainedSyntheticTestFunction
76
114
)
77
- self ._is_moo : bool = isinstance ( self .test_problem , MultiObjectiveTestProblem )
78
- self ._outcome_names = outcome_names
115
+ self ._is_moo : bool = self .test_problem . num_objectives > 1
116
+ self .outcome_names = outcome_names
79
117
self ._modified_bounds = modified_bounds
80
118
81
- @property
82
- def outcome_names (self ) -> list [str ]:
83
- return self ._outcome_names
84
-
85
119
@equality_typechecker
86
120
def __eq__ (self , other : Base ) -> bool :
87
- if not isinstance (other , BotorchTestProblemRunner ):
121
+ if not isinstance (other , type ( self ) ):
88
122
return False
89
123
90
124
return (
@@ -129,12 +163,95 @@ def get_noise_stds(self) -> Union[None, float, dict[str, float]]:
129
163
130
164
return noise_std_dict
131
165
166
+ @classmethod
167
+ # pyre-fixme [2]: Parameter `obj` must have a type other than `Any``
168
+ def serialize_init_args (cls , obj : Any ) -> dict [str , Any ]:
169
+ """Serialize the properties needed to initialize the runner.
170
+ Used for storage.
171
+ """
172
+ runner = assert_is_instance (obj , cls )
173
+
174
+ return {
175
+ "test_problem_module" : runner ._test_problem_class .__module__ ,
176
+ "test_problem_class_name" : runner ._test_problem_class .__name__ ,
177
+ "test_problem_kwargs" : runner ._test_problem_kwargs ,
178
+ "outcome_names" : runner .outcome_names ,
179
+ "modified_bounds" : runner ._modified_bounds ,
180
+ }
181
+
182
+ @classmethod
183
+ def deserialize_init_args (
184
+ cls ,
185
+ args : dict [str , Any ],
186
+ decoder_registry : Optional [TDecoderRegistry ] = None ,
187
+ class_decoder_registry : Optional [TClassDecoderRegistry ] = None ,
188
+ ) -> dict [str , Any ]:
189
+ """Given a dictionary, deserialize the properties needed to initialize the
190
+ runner. Used for storage.
191
+ """
192
+
193
+ module = importlib .import_module (args ["test_problem_module" ])
194
+
195
+ return {
196
+ "test_problem_class" : getattr (module , args ["test_problem_class_name" ]),
197
+ "test_problem_kwargs" : args ["test_problem_kwargs" ],
198
+ "outcome_names" : args ["outcome_names" ],
199
+ "modified_bounds" : args ["modified_bounds" ],
200
+ }
201
+
202
+
203
+ class BotorchTestProblemRunner (SyntheticProblemRunner ):
204
+ """
205
+ A `SyntheticProblemRunner` for BoTorch `SyntheticTestFunction`s.
206
+
207
+ Args:
208
+ test_problem_class: A BoTorch `SyntheticTestFunction` class.
209
+ test_problem_kwargs: The keyword arguments used for initializing the
210
+ test problem.
211
+ outcome_names: The names of the outcomes returned by the problem.
212
+ modified_bounds: The bounds that are used by the Ax search space
213
+ while optimizing the problem. If different from the bounds of the
214
+ test problem, we project the parameters into the test problem
215
+ bounds before evaluating the test problem.
216
+ For example, if the test problem is defined on [0, 1] but the Ax
217
+ search space is integers in [0, 10], an Ax parameter value of
218
+ 5 will correspond to 0.5 while evaluating the test problem.
219
+ If modified bounds are not provided, the test problem will be
220
+ evaluated using the raw parameter values.
221
+ """
222
+
223
+ def __init__ (
224
+ self ,
225
+ * ,
226
+ test_problem_class : type [SyntheticTestFunction ],
227
+ test_problem_kwargs : dict [str , Any ],
228
+ outcome_names : list [str ],
229
+ modified_bounds : Optional [list [tuple [float , float ]]] = None ,
230
+ ) -> None :
231
+ super ().__init__ (
232
+ test_problem_class = test_problem_class ,
233
+ test_problem_kwargs = test_problem_kwargs ,
234
+ outcome_names = outcome_names ,
235
+ modified_bounds = modified_bounds ,
236
+ )
237
+ self .test_problem : SyntheticTestFunction = self .test_problem .to (
238
+ dtype = torch .double
239
+ )
240
+ self ._is_constrained : bool = isinstance (
241
+ self .test_problem , ConstrainedSyntheticTestFunction
242
+ )
243
+
132
244
def get_Y_true (self , arm : Arm ) -> Tensor :
133
- """Converts X to original bounds -- only if modified bounds were provided --
134
- and evaluates the test problem. See `__init__` docstring for details.
245
+ """
246
+ Convert the arm to a tensor and evaluate it on the base test problem.
247
+
248
+ Convert the tensor to original bounds -- only if modified bounds were
249
+ provided -- and evaluates the test problem. See the docstring for
250
+ `modified_bounds` in `BotorchTestProblemRunner.__init__` for details.
135
251
136
252
Args:
137
- X: A `batch_shape x d`-dim tensor of point(s) at which to evaluate the
253
+ arm: Arm to evaluate. It will be converted to a
254
+ `batch_shape x d`-dim tensor of point(s) at which to evaluate the
138
255
test problem.
139
256
140
257
Returns:
@@ -157,7 +274,7 @@ def get_Y_true(self, arm: Arm) -> Tensor:
157
274
X = unnormalize (unit_X , self .test_problem .bounds )
158
275
159
276
Y_true = self .test_problem .evaluate_true (X ).view (- 1 )
160
- # `BaseTestProblem .evaluate_true()` does not negate the outcome
277
+ # `SyntheticTestFunction .evaluate_true()` does not negate the outcome
161
278
if self .test_problem .negate :
162
279
Y_true = - Y_true
163
280
@@ -171,43 +288,44 @@ def get_Y_true(self, arm: Arm) -> Tensor:
171
288
172
289
return Y_true
173
290
174
- def poll_trial_status (
175
- self , trials : Iterable [BaseTrial ]
176
- ) -> dict [TrialStatus , set [int ]]:
177
- return {TrialStatus .COMPLETED : {t .index for t in trials }}
178
291
179
- @classmethod
180
- # pyre-fixme [2]: Parameter `obj` must have a type other than `Any``
181
- def serialize_init_args (cls , obj : Any ) -> dict [str , Any ]:
182
- """Serialize the properties needed to initialize the runner.
183
- Used for storage.
184
- """
185
- runner = checked_cast (BotorchTestProblemRunner , obj )
292
+ class ParamBasedTestProblemRunner (SyntheticProblemRunner ):
293
+ """
294
+ A `SyntheticProblemRunner` for `ParamBasedTestProblem`s. See
295
+ `SyntheticProblemRunner` for more information.
296
+ """
186
297
187
- return {
188
- "test_problem_module" : runner ._test_problem_class .__module__ ,
189
- "test_problem_class_name" : runner ._test_problem_class .__name__ ,
190
- "test_problem_kwargs" : runner ._test_problem_kwargs ,
191
- "outcome_names" : runner ._outcome_names ,
192
- "modified_bounds" : runner ._modified_bounds ,
193
- }
298
+ # This could easily be supported, but hasn't been hooked up
299
+ _is_constrained : bool = False
194
300
195
- @classmethod
196
- def deserialize_init_args (
197
- cls ,
198
- args : dict [str , Any ],
199
- decoder_registry : Optional [TDecoderRegistry ] = None ,
200
- class_decoder_registry : Optional [TClassDecoderRegistry ] = None ,
201
- ) -> dict [str , Any ]:
202
- """Given a dictionary, deserialize the properties needed to initialize the
203
- runner. Used for storage.
204
- """
301
+ def __init__ (
302
+ self ,
303
+ * ,
304
+ test_problem_class : type [ParamBasedTestProblem ],
305
+ test_problem_kwargs : dict [str , Any ],
306
+ outcome_names : list [str ],
307
+ modified_bounds : Optional [list [tuple [float , float ]]] = None ,
308
+ ) -> None :
309
+ if modified_bounds is not None :
310
+ raise NotImplementedError (
311
+ f"modified_bounds is not supported for { test_problem_class .__name__ } "
312
+ )
313
+ super ().__init__ (
314
+ test_problem_class = test_problem_class ,
315
+ test_problem_kwargs = test_problem_kwargs ,
316
+ outcome_names = outcome_names ,
317
+ modified_bounds = modified_bounds ,
318
+ )
319
+ self .test_problem : ParamBasedTestProblem = self .test_problem
205
320
206
- module = importlib .import_module (args ["test_problem_module" ])
321
+ def get_Y_true (self , arm : Arm ) -> Tensor :
322
+ """Evaluates the test problem.
207
323
208
- return {
209
- "test_problem_class" : getattr (module , args ["test_problem_class_name" ]),
210
- "test_problem_kwargs" : args ["test_problem_kwargs" ],
211
- "outcome_names" : args ["outcome_names" ],
212
- "modified_bounds" : args ["modified_bounds" ],
213
- }
324
+ Returns:
325
+ A `batch_shape x m`-dim tensor of ground truth (noiseless) evaluations.
326
+ """
327
+ Y_true = self .test_problem .evaluate_true (arm .parameters ).view (- 1 )
328
+ # `ParamBasedTestProblem.evaluate_true()` does not negate the outcome
329
+ if self .test_problem .negate :
330
+ Y_true = - Y_true
331
+ return Y_true
0 commit comments