6
6
# pyre-strict
7
7
8
8
import importlib
9
- from collections .abc import Iterable
9
+ from abc import ABC , abstractmethod
10
+ from dataclasses import dataclass
10
11
from typing import Any , Optional , Union
11
12
12
13
import torch
13
14
from ax .benchmark .runners .base import BenchmarkRunner
14
15
from ax .core .arm import Arm
15
- from ax .core .base_trial import BaseTrial , TrialStatus
16
+ from ax .core .types import TParameterization
16
17
from ax .utils .common .base import Base
17
18
from ax .utils .common .equality import equality_typechecker
18
19
from ax .utils .common .serialization import TClassDecoderRegistry , TDecoderRegistry
19
- from ax .utils .common .typeutils import checked_cast
20
20
from botorch .test_functions .base import BaseTestProblem , ConstrainedBaseTestProblem
21
- from botorch .test_functions .multi_objective import MultiObjectiveTestProblem
22
21
from botorch .utils .transforms import normalize , unnormalize
22
+ from pyre_extensions import assert_is_instance
23
23
from torch import Tensor
24
24
25
25
26
- class BotorchTestProblemRunner (BenchmarkRunner ):
27
- """A Runner for evaluating Botorch BaseTestProblems.
26
+ @dataclass (kw_only = True )
27
+ class ParamBasedTestProblem (ABC ):
28
+ """
29
+ Similar to a BoTorch test problem, but evaluated using an Ax
30
+ TParameterization rather than a tensor.
31
+ """
32
+
33
+ num_objectives : int
34
+ optimal_value : float
35
+ # Constraints could easily be supported similar to BoTorch test problems,
36
+ # but haven't been hooked up.
37
+ _is_constrained : bool = False
38
+ constraint_noise_std : Optional [Union [float , list [float ]]] = None
39
+ noise_std : Optional [Union [float , list [float ]]] = None
40
+ negate : bool = False
41
+
42
+ @abstractmethod
43
+ def evaluate_true (self , params : TParameterization ) -> Tensor : ...
44
+
45
+ def evaluate_slack_true (self , params : TParameterization ) -> Tensor :
46
+ raise NotImplementedError (
47
+ f"{ self .__class__ .__name__ } does not support constraints."
48
+ )
49
+
50
+ # pyre-fixme: Missing parameter annotation [2]: Parameter `other` must have
51
+ # a type other than `Any`.
52
+ def __eq__ (self , other : Any ) -> bool :
53
+ if not isinstance (other , type (self )):
54
+ return False
55
+ return self .__class__ .__name__ == other .__class__ .__name__
56
+
28
57
29
- Given a trial the Runner will evaluate the BaseTestProblem.forward method for each
30
- arm in the trial, as well as return some metadata about the underlying Botorch
31
- problem such as the noise_std. We compute the full result on the Runner (as opposed
32
- to the Metric as is typical in synthetic test problems) because the BoTorch problem
33
- computes all metrics in one stacked tensor in the MOO case, and we wish to avoid
34
- recomputation per metric.
58
+ class SyntheticProblemRunner (BenchmarkRunner , ABC ):
59
+ """A Runner for evaluating synthetic problems, either BoTorch
60
+ `BaseTestProblem`s or Ax benchmarking `ParamBasedTestProblem`s.
61
+
62
+ Given a trial, the Runner will evaluate the problem noiselessly for each
63
+ arm in the trial, as well as return some metadata about the underlying
64
+ problem such as the noise_std.
35
65
"""
36
66
37
- test_problem : BaseTestProblem
67
+ test_problem : Union [ BaseTestProblem , ParamBasedTestProblem ]
38
68
_is_constrained : bool
39
- _test_problem_class : type [BaseTestProblem ]
69
+ _test_problem_class : type [Union [ BaseTestProblem , ParamBasedTestProblem ] ]
40
70
_test_problem_kwargs : Optional [dict [str , Any ]]
41
71
42
72
def __init__ (
43
73
self ,
44
- test_problem_class : type [BaseTestProblem ],
74
+ * ,
75
+ test_problem_class : type [Union [BaseTestProblem , ParamBasedTestProblem ]],
45
76
test_problem_kwargs : dict [str , Any ],
46
77
outcome_names : list [str ],
47
78
modified_bounds : Optional [list [tuple [float , float ]]] = None ,
48
79
) -> None :
49
80
"""Initialize the test problem runner.
50
81
51
82
Args:
52
- test_problem_class: The BoTorch test problem class.
83
+ test_problem_class: A BoTorch `BaseTestProblem` class or Ax
84
+ `ParamBasedTestProblem` class.
53
85
test_problem_kwargs: The keyword arguments used for initializing the
54
86
test problem.
55
87
outcome_names: The names of the outcomes returned by the problem.
@@ -63,28 +95,27 @@ def __init__(
63
95
If modified bounds are not provided, the test problem will be
64
96
evaluated using the raw parameter values.
65
97
"""
66
-
67
98
self ._test_problem_class = test_problem_class
68
99
self ._test_problem_kwargs = test_problem_kwargs
69
-
70
- # pyre-fixme [45] : Invalid class instantiation
71
- self . test_problem = test_problem_class ( ** test_problem_kwargs ). to (
72
- dtype = torch . double
100
+ self . test_problem = (
101
+ # pyre-fixme: Invalid class instantiation [45]: Cannot instantiate
102
+ # abstract class with abstract method `evaluate_true`.
103
+ test_problem_class ( ** test_problem_kwargs )
73
104
)
105
+ if isinstance (self .test_problem , BaseTestProblem ):
106
+ self .test_problem = self .test_problem .to (dtype = torch .double )
74
107
self ._is_constrained : bool = isinstance (
75
108
self .test_problem , ConstrainedBaseTestProblem
76
109
)
77
- self ._is_moo : bool = isinstance (self .test_problem , MultiObjectiveTestProblem )
78
- self ._outcome_names = outcome_names
110
+ self ._is_moo : bool = self .test_problem .num_objectives > 1
111
+ # A `ConstrainedBaseTestProblem` is a type of `BaseTestProblem`; a
112
+ # `ParamBasedTestProblem` is never constrained.
113
+ self .outcome_names = outcome_names
79
114
self ._modified_bounds = modified_bounds
80
115
81
- @property
82
- def outcome_names (self ) -> list [str ]:
83
- return self ._outcome_names
84
-
85
116
@equality_typechecker
86
117
def __eq__ (self , other : Base ) -> bool :
87
- if not isinstance (other , BotorchTestProblemRunner ):
118
+ if not isinstance (other , type ( self ) ):
88
119
return False
89
120
90
121
return (
@@ -129,12 +160,94 @@ def get_noise_stds(self) -> Union[None, float, dict[str, float]]:
129
160
130
161
return noise_std_dict
131
162
163
+ @classmethod
164
+ # pyre-fixme [2]: Parameter `obj` must have a type other than `Any``
165
+ def serialize_init_args (cls , obj : Any ) -> dict [str , Any ]:
166
+ """Serialize the properties needed to initialize the runner.
167
+ Used for storage.
168
+ """
169
+ runner = assert_is_instance (obj , cls )
170
+
171
+ return {
172
+ "test_problem_module" : runner ._test_problem_class .__module__ ,
173
+ "test_problem_class_name" : runner ._test_problem_class .__name__ ,
174
+ "test_problem_kwargs" : runner ._test_problem_kwargs ,
175
+ "outcome_names" : runner .outcome_names ,
176
+ "modified_bounds" : runner ._modified_bounds ,
177
+ }
178
+
179
+ @classmethod
180
+ def deserialize_init_args (
181
+ cls ,
182
+ args : dict [str , Any ],
183
+ decoder_registry : Optional [TDecoderRegistry ] = None ,
184
+ class_decoder_registry : Optional [TClassDecoderRegistry ] = None ,
185
+ ) -> dict [str , Any ]:
186
+ """Given a dictionary, deserialize the properties needed to initialize the
187
+ runner. Used for storage.
188
+ """
189
+
190
+ module = importlib .import_module (args ["test_problem_module" ])
191
+
192
+ return {
193
+ "test_problem_class" : getattr (module , args ["test_problem_class_name" ]),
194
+ "test_problem_kwargs" : args ["test_problem_kwargs" ],
195
+ "outcome_names" : args ["outcome_names" ],
196
+ "modified_bounds" : args ["modified_bounds" ],
197
+ }
198
+
199
+
200
+ class BotorchTestProblemRunner (SyntheticProblemRunner ):
201
+ """
202
+ A `SyntheticProblemRunner` for BoTorch `BaseTestProblem`s.
203
+
204
+ Args:
205
+ test_problem_class: A BoTorch `BaseTestProblem` class or Ax
206
+ `ParamBasedTestProblem` class.
207
+ test_problem_kwargs: The keyword arguments used for initializing the
208
+ test problem.
209
+ outcome_names: The names of the outcomes returned by the problem.
210
+ modified_bounds: The bounds that are used by the Ax search space
211
+ while optimizing the problem. If different from the bounds of the
212
+ test problem, we project the parameters into the test problem
213
+ bounds before evaluating the test problem.
214
+ For example, if the test problem is defined on [0, 1] but the Ax
215
+ search space is integers in [0, 10], an Ax parameter value of
216
+ 5 will correspond to 0.5 while evaluating the test problem.
217
+ If modified bounds are not provided, the test problem will be
218
+ evaluated using the raw parameter values.
219
+ """
220
+
221
+ def __init__ (
222
+ self ,
223
+ * ,
224
+ test_problem_class : type [BaseTestProblem ],
225
+ test_problem_kwargs : dict [str , Any ],
226
+ outcome_names : list [str ],
227
+ modified_bounds : Optional [list [tuple [float , float ]]] = None ,
228
+ ) -> None :
229
+ super ().__init__ (
230
+ test_problem_class = test_problem_class ,
231
+ test_problem_kwargs = test_problem_kwargs ,
232
+ outcome_names = outcome_names ,
233
+ modified_bounds = modified_bounds ,
234
+ )
235
+ self .test_problem : BaseTestProblem = self .test_problem .to (dtype = torch .double )
236
+ self ._is_constrained : bool = isinstance (
237
+ self .test_problem , ConstrainedBaseTestProblem
238
+ )
239
+
132
240
def get_Y_true (self , arm : Arm ) -> Tensor :
133
- """Converts X to original bounds -- only if modified bounds were provided --
134
- and evaluates the test problem. See `__init__` docstring for details.
241
+ """
242
+ Convert the arm to a tensor and evaluate it on the base test problem.
243
+
244
+ Convert the tensor to original bounds -- only if modified bounds were
245
+ provided -- and evaluates the test problem. See the docstring for
246
+ `modified_bounds` in `BotorchTestProblemRunner.__init__` for details.
135
247
136
248
Args:
137
- X: A `batch_shape x d`-dim tensor of point(s) at which to evaluate the
249
+ arm: Arm to evaluate. It will be converted to a
250
+ `batch_shape x d`-dim tensor of point(s) at which to evaluate the
138
251
test problem.
139
252
140
253
Returns:
@@ -171,43 +284,44 @@ def get_Y_true(self, arm: Arm) -> Tensor:
171
284
172
285
return Y_true
173
286
174
- def poll_trial_status (
175
- self , trials : Iterable [BaseTrial ]
176
- ) -> dict [TrialStatus , set [int ]]:
177
- return {TrialStatus .COMPLETED : {t .index for t in trials }}
178
287
179
- @classmethod
180
- # pyre-fixme [2]: Parameter `obj` must have a type other than `Any``
181
- def serialize_init_args (cls , obj : Any ) -> dict [str , Any ]:
182
- """Serialize the properties needed to initialize the runner.
183
- Used for storage.
184
- """
185
- runner = checked_cast (BotorchTestProblemRunner , obj )
288
+ class ParamBasedTestProblemRunner (SyntheticProblemRunner ):
289
+ """
290
+ A `SyntheticProblemRunner` for `ParamBasedTestProblem`s. See
291
+ `SyntheticProblemRunner` for more information.
292
+ """
186
293
187
- return {
188
- "test_problem_module" : runner ._test_problem_class .__module__ ,
189
- "test_problem_class_name" : runner ._test_problem_class .__name__ ,
190
- "test_problem_kwargs" : runner ._test_problem_kwargs ,
191
- "outcome_names" : runner ._outcome_names ,
192
- "modified_bounds" : runner ._modified_bounds ,
193
- }
294
+ # This could easily be supported, but hasn't been hooked up
295
+ _is_constrained : bool = False
194
296
195
- @classmethod
196
- def deserialize_init_args (
197
- cls ,
198
- args : dict [str , Any ],
199
- decoder_registry : Optional [TDecoderRegistry ] = None ,
200
- class_decoder_registry : Optional [TClassDecoderRegistry ] = None ,
201
- ) -> dict [str , Any ]:
202
- """Given a dictionary, deserialize the properties needed to initialize the
203
- runner. Used for storage.
204
- """
297
+ def __init__ (
298
+ self ,
299
+ * ,
300
+ test_problem_class : type [ParamBasedTestProblem ],
301
+ test_problem_kwargs : dict [str , Any ],
302
+ outcome_names : list [str ],
303
+ modified_bounds : Optional [list [tuple [float , float ]]] = None ,
304
+ ) -> None :
305
+ if modified_bounds is not None :
306
+ raise NotImplementedError (
307
+ f"modified_bounds is not supported for { test_problem_class .__name__ } "
308
+ )
309
+ super ().__init__ (
310
+ test_problem_class = test_problem_class ,
311
+ test_problem_kwargs = test_problem_kwargs ,
312
+ outcome_names = outcome_names ,
313
+ modified_bounds = modified_bounds ,
314
+ )
315
+ self .test_problem : ParamBasedTestProblem = self .test_problem
205
316
206
- module = importlib .import_module (args ["test_problem_module" ])
317
+ def get_Y_true (self , arm : Arm ) -> Tensor :
318
+ """Evaluates the test problem.
207
319
208
- return {
209
- "test_problem_class" : getattr (module , args ["test_problem_class_name" ]),
210
- "test_problem_kwargs" : args ["test_problem_kwargs" ],
211
- "outcome_names" : args ["outcome_names" ],
212
- "modified_bounds" : args ["modified_bounds" ],
213
- }
320
+ Returns:
321
+ A `batch_shape x m`-dim tensor of ground truth (noiseless) evaluations.
322
+ """
323
+ Y_true = self .test_problem .evaluate_true (arm .parameters ).view (- 1 )
324
+ # `BaseTestProblem.evaluate_true()` does not negate the outcome
325
+ if self .test_problem .negate :
326
+ Y_true = - Y_true
327
+ return Y_true
0 commit comments