Skip to content

Commit 06e6be9

Browse files
lebricebouthilx
andauthored
Knowledge Base + Multi-Task Warm-starting (#933)
* Import files from old PR, start fresh Signed-off-by: Fabrice Normandin <[email protected]> * Cleaning up algo_wrapper.py Signed-off-by: Fabrice Normandin <[email protected]> * Minor typing improvements in utils/__init__.py Signed-off-by: Fabrice Normandin <[email protected]> * add transform-related methods to AlgoWrapper Signed-off-by: Fabrice Normandin <[email protected]> * Clean up multi_task_algo.py Signed-off-by: Fabrice Normandin <[email protected]> * [big] move wrappers to folder, simplify wrappers Signed-off-by: Fabrice Normandin <[email protected]> * Add tests for the wrappers Signed-off-by: Fabrice Normandin <[email protected]> * Move algo wrappers to folder, split up tests a bit Signed-off-by: Fabrice Normandin <[email protected]> * Fix cyclical import issue with knowledbe base Signed-off-by: Fabrice Normandin <[email protected]> * Remove accidentally added .pre-commit-config.yaml Signed-off-by: Fabrice Normandin <[email protected]> * "algorithms.algorithm" -> "algorithms.unwrapped" Signed-off-by: Fabrice Normandin <[email protected]> * Have wrappers return the algo's config, not theirs Signed-off-by: Fabrice Normandin <[email protected]> * Remove duplicated code from SpaceTransform Signed-off-by: Fabrice Normandin <[email protected]> * Rename test files for algo wrappers Signed-off-by: Fabrice Normandin <[email protected]> * Fix bug in test_gridsearch.py Signed-off-by: Fabrice Normandin <[email protected]> * Remove more duplicated code from SpaceTransform Signed-off-by: Fabrice Normandin <[email protected]> * Husk of unit tests for multi-task wrapper Signed-off-by: Fabrice Normandin <[email protected]> * Rename multi_task_wrapper.py -> multi_task.py Signed-off-by: Fabrice Normandin <[email protected]> * Start to write body of the test Signed-off-by: Fabrice Normandin <[email protected]> * Remove other duplicate methods of SpaceTransform Signed-off-by: Fabrice Normandin <[email protected]> * Add dataclasses for various objects Signed-off-by: Fabrice Normandin <[email protected]> * Adapting previous changes from config_dataclasses Signed-off-by: Fabrice Normandin <[email protected]> * Adapting previous changes from config_dataclasses Signed-off-by: Fabrice Normandin <[email protected]> * Fix / reformat docstrings Signed-off-by: Fabrice Normandin <[email protected]> * Fix typo in log.debug call Co-authored-by: Xavier Bouthillier <[email protected]> * Fix outdated docstrings Signed-off-by: Fabrice Normandin <[email protected]> * Revert changes to ugly part of testing/__init__.py Signed-off-by: Fabrice Normandin <[email protected]> * Fix isort issues Signed-off-by: Fabrice Normandin <[email protected]> * Fix isort / flake8 issues Signed-off-by: Fabrice Normandin <[email protected]> * Add an intermediate TransformWrapper ABC Signed-off-by: Fabrice Normandin <[email protected]> * Add optional `max_trials` property to algorithm Signed-off-by: Fabrice Normandin <[email protected]> * Remove outdated/misleading comments Signed-off-by: Fabrice Normandin <[email protected]> * Fix various pylint/flake8 errors Signed-off-by: Fabrice Normandin <[email protected]> * Touchups on the types of primary_algo.py Signed-off-by: Fabrice Normandin <[email protected]> * Slight refactoring in serializable.py Signed-off-by: Fabrice Normandin <[email protected]> * Rename AlgoType TypeVar to AlgoT Signed-off-by: Fabrice Normandin <[email protected]> * Start to add tests for ExperimentInfo dataclass Signed-off-by: Fabrice Normandin <[email protected]> * Allow passing Algo type to create_experiment Signed-off-by: Fabrice Normandin <[email protected]> * Minor tweak in AlgoWrapper.__repr__ Signed-off-by: Fabrice Normandin <[email protected]> * Adding (failing) tests for warm-starting. Signed-off-by: Fabrice Normandin <[email protected]> * Add Knowledge Base arg to Experiment and Producer Signed-off-by: Fabrice Normandin <[email protected]> * Add KnowledgeBase as argument to workon functions Signed-off-by: Fabrice Normandin <[email protected]> * Move warm-start method to right wrapper class Signed-off-by: Fabrice Normandin <[email protected]> * Add a simple test for warm starting Signed-off-by: Fabrice Normandin <[email protected]> * Improve __repr__ of Registry Signed-off-by: Fabrice Normandin <[email protected]> * Type the `space` property of ExperimentClient Signed-off-by: Fabrice Normandin <[email protected]> * Remove line that caused bug in experiment_builder Signed-off-by: Fabrice Normandin <[email protected]> * Add tests for KB + Warm-Starting Signed-off-by: Fabrice Normandin <[email protected]> * (big, ugly commit) Add tests, clarify, refactor Signed-off-by: Fabrice Normandin <[email protected]> * Add repr for RegistryMapping Signed-off-by: Fabrice Normandin <[email protected]> * Type the _results and _params attributes of Trial Signed-off-by: Fabrice Normandin <[email protected]> * Move and Rename algo for unit tests Signed-off-by: Fabrice Normandin <[email protected]> * Add test for Multi-Task wrapper collisions Signed-off-by: Fabrice Normandin <[email protected]> * Fix bug in AlgoWrapper (see desc.) Subclasses of AlgoWrapper (in particular the InsistSuggest wrapper) did not register the trials into their registry when suggesting or observing. This caused an issue where the registry mapping wasn't working properly in the MultiTaskWrapper. - Made the `suggest` and `observe` of AlgoWrapper use `self.register` so that the registry mapping and collision detection now works properly. Signed-off-by: Fabrice Normandin <[email protected]> * Add test for suggest to always give task_id=0 Signed-off-by: Fabrice Normandin <[email protected]> * Fix small bug in gridsearch.py Signed-off-by: Fabrice Normandin <[email protected]> * Simplification: kb is only attr of Experiment Signed-off-by: Fabrice Normandin <[email protected]> * Moved "functional" tests to different file Signed-off-by: Fabrice Normandin <[email protected]> * Fix pylint errors Signed-off-by: Fabrice Normandin <[email protected]> * Add unwrap convenience method on AlgoWrapper Signed-off-by: Fabrice Normandin <[email protected]> * Misc changes (test cleanup, copy status) Signed-off-by: Fabrice Normandin <[email protected]> * Remove unused warm_start_mode context manager Signed-off-by: Fabrice Normandin <[email protected]> * Use the new max_trials property on algo Signed-off-by: Fabrice Normandin <[email protected]> * Add tests for setting max_trials and n_observed Signed-off-by: Fabrice Normandin <[email protected]> * Fix renaming of SpaceTransform algo wrapper Signed-off-by: Fabrice Normandin <[email protected]> * Don't register trials from other tasks in algo Signed-off-by: Fabrice Normandin <[email protected]> * Fix randomness of flaky-ish test Signed-off-by: Fabrice Normandin <[email protected]> * Remove dict[k, v] type annotation for python < 3.9 Signed-off-by: Fabrice Normandin <[email protected]> * Move / Simplify the Config classes -> TypedDicts Signed-off-by: Fabrice Normandin <[email protected]> * Type out the Random algo, fix a test Signed-off-by: Fabrice Normandin <[email protected]> * Fixing some broken tests Signed-off-by: Fabrice Normandin <[email protected]> * Fix experiment_builder tests Signed-off-by: Fabrice Normandin <[email protected]> * Fix broken test in test_experiment.py Signed-off-by: Fabrice Normandin <[email protected]> * Add test for ExperimentConfig fields Signed-off-by: Fabrice Normandin <[email protected]> * Minor touchups in experiment_config.py Signed-off-by: Fabrice Normandin <[email protected]> * Removed unused typeddicts in experiment_config.py Signed-off-by: Fabrice Normandin <[email protected]> * Remove outdated todo Signed-off-by: Fabrice Normandin <[email protected]> * [breaking] Knowledge Base implementation Signed-off-by: Fabrice Normandin <[email protected]> * Remove unused knowledge_base argument to workon Signed-off-by: Fabrice Normandin <[email protected]> * Fix instantiation of KB in exp builder Signed-off-by: Fabrice Normandin <[email protected]> * [nit] Fix typing of ExperimentStats fields Signed-off-by: Fabrice Normandin <[email protected]> * [optional] Type out storage/base.py and misc types Signed-off-by: Fabrice Normandin <[email protected]> * Adapting the MultiTask wrapper tests, add stubs Signed-off-by: Fabrice Normandin <[email protected]> * Move test_experiment_config to reflect src Signed-off-by: Fabrice Normandin <[email protected]> * Add test for KB Signed-off-by: Fabrice Normandin <[email protected]> * Add more tests for the KnowledgeBase Signed-off-by: Fabrice Normandin <[email protected]> * Fix incorrect type for experiment id in docstrings Signed-off-by: Fabrice Normandin <[email protected]> * Fix error in fixture, adjust docstrings Signed-off-by: Fabrice Normandin <[email protected]> * Pass experiment config instead of experiment obj Signed-off-by: Fabrice Normandin <[email protected]> * Fix docstrings, use Unpack[] to type **kwargs Signed-off-by: Fabrice Normandin <[email protected]> * Move and update functional tests Signed-off-by: Fabrice Normandin <[email protected]> * Remove unused code block in test Signed-off-by: Fabrice Normandin <[email protected]> * Add note about the Unpack[] annotation Signed-off-by: Fabrice Normandin <[email protected]> * Add PartialExperimentConfig typeddict Signed-off-by: Fabrice Normandin <[email protected]> * Trying to make functional tests pass... Signed-off-by: Fabrice Normandin <[email protected]> * Fix pylint error in experiment.py Signed-off-by: Fabrice Normandin <[email protected]> * Fix bugs in test_knowledge_base Signed-off-by: Fabrice Normandin <[email protected]> * Pass kb to instantiate_algo, minor typing stuffs Signed-off-by: Fabrice Normandin <[email protected]> * Clarify potential issue in register, touchups Signed-off-by: Fabrice Normandin <[email protected]> * Minor typing touchups Signed-off-by: Fabrice Normandin <[email protected]> * Simplify functional tests for warm_starting Signed-off-by: Fabrice Normandin <[email protected]> * Add tests for how to pass the algorithm Signed-off-by: Fabrice Normandin <[email protected]> * Minor typing improvements to experiment_builder.py Signed-off-by: Fabrice Normandin <[email protected]> * Fix assignment of max_trials in exp client Signed-off-by: Fabrice Normandin <[email protected]> * Use KnowledgeBase in functional test, remove todos Signed-off-by: Fabrice Normandin <[email protected]> * Fix pylint error Signed-off-by: Fabrice Normandin <[email protected]> * Fix import error for TypedDict Signed-off-by: Fabrice Normandin <[email protected]> * Re-introduce fix from #964 (?) Signed-off-by: Fabrice Normandin <[email protected]> * Fix bug in algo creation logic Signed-off-by: Fabrice Normandin <[email protected]> * Add missing types in BaseAlgorithm Signed-off-by: Fabrice Normandin <[email protected]> * Add some of the missing types in exp builder Signed-off-by: Fabrice Normandin <[email protected]> * Remove leftover todo in BaseAlgorithm Signed-off-by: Fabrice Normandin <[email protected]> * Fix type annotation on create_experiment Signed-off-by: Fabrice Normandin <[email protected]> * Remove extra type annotation on create_experiment Signed-off-by: Fabrice Normandin <[email protected]> * Fix bug in test_tpe (added wrapper) Signed-off-by: Fabrice Normandin <[email protected]> * Fix value in DumbAlgo Signed-off-by: Fabrice Normandin <[email protected]> * Fix test condition Signed-off-by: Fabrice Normandin <[email protected]> * Remove fixme comment Signed-off-by: Fabrice Normandin <[email protected]> * Fix tests Signed-off-by: Fabrice Normandin <[email protected]> * Fix most PBT tests Signed-off-by: Fabrice Normandin <[email protected]> * Misc changes Signed-off-by: Fabrice Normandin <[email protected]> * Misc typing changes Signed-off-by: Fabrice Normandin <[email protected]> * Debugging the PBT Errors Signed-off-by: Fabrice Normandin <[email protected]> * Fix broken AlgoWrapper tests Signed-off-by: Fabrice Normandin <[email protected]> * Add missing register method on AlgoWrapper Signed-off-by: Fabrice Normandin <[email protected]> * Remove unused `original_space` property Signed-off-by: Fabrice Normandin <[email protected]> * Fix docstring and warning in BaseAlgorithm.get_id Signed-off-by: Fabrice Normandin <[email protected]> * Fix bug introduced previously Signed-off-by: Fabrice Normandin <[email protected]> * Minor improvements to TransformWrapper.suggest Signed-off-by: Fabrice Normandin <[email protected]> * Use self.reverse_transform in get_original_parent Signed-off-by: Fabrice Normandin <[email protected]> * Ugly temporary fix to bug that affected PBT (desc) There is a weird bug that was happening in PBT: - The trials that PBT suggest have a parent id. - When the SpaceTransform wrapper calls observe, the completed trials don't have the same parent as they did when they were suggested. Signed-off-by: Fabrice Normandin <[email protected]> * Remove hacky fix, fix bug source (hopefully) Signed-off-by: Fabrice Normandin <[email protected]> * Make _get_original_parent "static" again Signed-off-by: Fabrice Normandin <[email protected]> * Add todo for later (copying attributes explicitly) Signed-off-by: Fabrice Normandin <[email protected]> * Fix minor bug in DumbAlgo-related test Signed-off-by: Fabrice Normandin <[email protected]> * Greatly reduce number of warnings Signed-off-by: Fabrice Normandin <[email protected]> * Remove Registry from InsistSuggestWrapper Signed-off-by: Fabrice Normandin <[email protected]> * Add tests to increase coverage of Registry class Signed-off-by: Fabrice Normandin <[email protected]> * Add more coverage for _instantiate_knowledge_base Signed-off-by: Fabrice Normandin <[email protected]> * Add a bit more coverage for AlgoWrapper Signed-off-by: Fabrice Normandin <[email protected]> * Fix bug in _instantiate_knowledge_base Signed-off-by: Fabrice Normandin <[email protected]> * Add tests for _instantiate_algo Signed-off-by: Fabrice Normandin <[email protected]> * Add generic test class for AlgoWrappers Signed-off-by: Fabrice Normandin <[email protected]> * Remove redundant fixture in AlgoWrapper tests Signed-off-by: Fabrice Normandin <[email protected]> * Clean / minor changes to test_knowledge_base.py Signed-off-by: Fabrice Normandin <[email protected]> * Use asserts to avoid writing useless tests Signed-off-by: Fabrice Normandin <[email protected]> * Add test case for no compatible trials found Signed-off-by: Fabrice Normandin <[email protected]> * Clean up testing utility function a bit Signed-off-by: Fabrice Normandin <[email protected]> * Add test for not warm-starting twice Signed-off-by: Fabrice Normandin <[email protected]> * Add test for "space already has task_id" error Signed-off-by: Fabrice Normandin <[email protected]> * Add tests for is_warmstarteable function Signed-off-by: Fabrice Normandin <[email protected]> * Standardize the imports of ExperimentConfig Signed-off-by: Fabrice Normandin <[email protected]> * Add more tests for _instantiate_kb Signed-off-by: Fabrice Normandin <[email protected]> * Remove redundante else clause in _instantiate_algo Signed-off-by: Fabrice Normandin <[email protected]> Signed-off-by: Fabrice Normandin <[email protected]> Signed-off-by: Fabrice Normandin <[email protected]> Co-authored-by: Fabrice Normandin <[email protected]> Co-authored-by: Xavier Bouthillier <[email protected]>
1 parent dcb422a commit 06e6be9

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+3394
-833
lines changed

.pre-commit-config.yaml

+4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
---
22
repos:
3+
- repo: https://github.com/pre-commit/pre-commit-hooks
4+
rev: v4.3.0 # Use the ref you want to point at
5+
hooks:
6+
- id: check-merge-conflict
37
- repo: https://github.com/python/black
48
rev: 22.6.0
59
hooks:

src/orion/algo/base.py

+32-29
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@
1818
import inspect
1919
import logging
2020
from abc import abstractmethod
21+
from typing import Any
2122

2223
from orion.algo.registry import Registry
24+
from orion.algo.space import Space
2325
from orion.core.utils import GenericFactory
2426
from orion.core.worker.trial import Trial
2527

@@ -101,7 +103,9 @@ def observe(self, points, results):
101103
requires_shape = None
102104
requires_dist = None
103105

104-
def __init__(self, space, **kwargs):
106+
max_trials: int | None = None
107+
108+
def __init__(self, space: Space, **kwargs):
105109
log.debug(
106110
"Creating Algorithm object of %s type with parameters:\n%s",
107111
type(self).__name__,
@@ -142,18 +146,19 @@ def state_dict(self):
142146
"""Return a state dict that can be used to reset the state of the algorithm."""
143147
return {"registry": self.registry.state_dict}
144148

145-
def set_state(self, state_dict):
149+
def set_state(self, state_dict: dict):
146150
"""Reset the state of the algorithm based on the given state_dict
147151
148152
:param state_dict: Dictionary representing state of an algorithm
149153
"""
150154
self.registry.set_state(state_dict["registry"])
151155

152-
def get_id(self, trial, ignore_fidelity=False, ignore_parent=False):
156+
def get_id(
157+
self, trial: Trial, ignore_fidelity: bool = False, ignore_parent: bool = False
158+
) -> str:
153159
"""Return unique hash for a trials based on params
154160
155-
The trial is assumed to be in the transformed space if the algorithm is working in a
156-
transformed space.
161+
The trial is assumed to be in the optimization space of the algorithm.
157162
158163
Parameters
159164
----------
@@ -170,13 +175,12 @@ def get_id(self, trial, ignore_fidelity=False, ignore_parent=False):
170175
return trial.compute_trial_hash(
171176
trial,
172177
ignore_fidelity=ignore_fidelity,
173-
ignore_experiment=True,
174178
ignore_lie=True,
175179
ignore_parent=ignore_parent,
176180
)
177181

178182
@property
179-
def fidelity_index(self):
183+
def fidelity_index(self) -> str | None:
180184
"""Returns the name of the first fidelity dimension if there is one, otherwise `None`."""
181185
fidelity_dims = [dim for dim in self.space.values() if dim.type == "fidelity"]
182186
if fidelity_dims:
@@ -209,7 +213,7 @@ def suggest(self, num: int) -> list[Trial]:
209213
has suggested/observed, and for the auto-generated unit-tests to pass.
210214
"""
211215

212-
def observe(self, trials):
216+
def observe(self, trials: list[Trial]) -> None:
213217
"""Observe the `results` of the evaluation of the `trials` in the
214218
process defined in user's script.
215219
@@ -223,7 +227,7 @@ def observe(self, trials):
223227
if not self.has_observed(trial):
224228
self.register(trial)
225229

226-
def register(self, trial):
230+
def register(self, trial: Trial) -> None:
227231
"""Save the trial as one suggested or observed by the algorithm.
228232
229233
Parameters
@@ -234,16 +238,16 @@ def register(self, trial):
234238
self.registry.register(trial)
235239

236240
@property
237-
def n_suggested(self):
241+
def n_suggested(self) -> int:
238242
"""Number of trials suggested by the algorithm"""
239243
return len(self.registry)
240244

241245
@property
242-
def n_observed(self):
246+
def n_observed(self) -> int:
243247
"""Number of completed trials observed by the algorithm."""
244248
return sum(self.has_observed(trial) for trial in self.registry)
245249

246-
def has_suggested(self, trial):
250+
def has_suggested(self, trial: Trial) -> bool:
247251
"""Whether the algorithm has suggested a given point.
248252
249253
Parameters
@@ -259,7 +263,7 @@ def has_suggested(self, trial):
259263
"""
260264
return self.registry.has_suggested(trial)
261265

262-
def has_observed(self, trial):
266+
def has_observed(self, trial: Trial) -> bool:
263267
"""Whether the algorithm has observed a given point objective.
264268
265269
This only counts observed completed trials.
@@ -312,15 +316,11 @@ def has_completed_max_trials(self) -> bool:
312316
"""Returns True if the algorithm has a `max_trials` attribute, and has completed more trials
313317
than its value.
314318
"""
315-
if not hasattr(self, "max_trials"):
316-
return False
317-
max_trials = getattr(self, "max_trials")
318-
if max_trials is None:
319+
if self.max_trials is None:
319320
return False
320321

321322
fidelity_index = self.fidelity_index
322323
max_fidelity_value = None
323-
324324
# When a fidelity dimension is present, we only count trials that have the maximum value.
325325
if fidelity_index is not None:
326326
_, max_fidelity_value = self.space[fidelity_index].interval()
@@ -333,10 +333,11 @@ def _is_completed(trial: Trial) -> bool:
333333
and trial.params[fidelity_index] >= max_fidelity_value
334334
)
335335

336-
return sum(map(_is_completed, self.registry)) >= max_trials
336+
return sum(map(_is_completed, self.registry)) >= self.max_trials
337337

338-
def score(self, trial): # pylint:disable=no-self-use,unused-argument
339-
"""Allow algorithm to evaluate `point` based on a prediction about
338+
# pylint:disable=no-self-use,unused-argument
339+
def score(self, trial: Trial) -> float:
340+
"""Allow algorithm to evaluate `trial` based on a prediction about
340341
this parameter set's performance.
341342
342343
By default, return the same score any parameter (no preference).
@@ -353,7 +354,8 @@ def score(self, trial): # pylint:disable=no-self-use,unused-argument
353354
"""
354355
return 0
355356

356-
def judge(self, trial, measurements): # pylint:disable=no-self-use,unused-argument
357+
# pylint:disable=no-self-use,unused-argument
358+
def judge(self, trial: Trial, measurements: Any) -> dict | None:
357359
"""Inform an algorithm about online `measurements` of a running trial.
358360
359361
This method is to be used as a callback in a client-server communication
@@ -381,7 +383,7 @@ def judge(self, trial, measurements): # pylint:disable=no-self-use,unused-argum
381383
"""
382384
return None
383385

384-
def should_suspend(self, trial):
386+
def should_suspend(self, trial: Trial) -> bool:
385387
"""Allow algorithm to decide whether a particular running trial is still
386388
worth to complete its evaluation, based on information provided by the
387389
`judge` method.
@@ -390,10 +392,12 @@ def should_suspend(self, trial):
390392
return False
391393

392394
@property
393-
def configuration(self):
395+
def configuration(self) -> dict[str, Any]:
394396
"""Return tunable elements of this algorithm in a dictionary form
395397
appropriate for saving.
396398
399+
By default, returns a dictionary containing the attributes of `self` which are also
400+
constructor arguments.
397401
"""
398402
dict_form = dict()
399403
for attrname in self._param_names:
@@ -404,14 +408,13 @@ def configuration(self):
404408
return {self.__class__.__name__.lower(): dict_form}
405409

406410
@property
407-
def space(self):
411+
def space(self) -> Space:
408412
"""Domain of problem associated with this algorithm's instance."""
409413
return self._space
410414

411-
@space.setter
412-
def space(self, space):
413-
"""Set space."""
414-
self._space = space
415+
@property
416+
def unwrapped(self):
417+
return self
415418

416419

417420
algo_factory = GenericFactory(BaseAlgorithm)

src/orion/algo/gridsearch.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,11 @@ def __init__(
133133
if not isinstance(n_values, dict)
134134
else n_values
135135
)
136+
max_trials = 10_000 if self.max_trials is None else self.max_trials
136137
self.grid = self.build_grid(
137-
self.space, n_values_dict, getattr(self, "max_trials", 10000)
138+
self.space,
139+
n_values_dict,
140+
max_trials=max_trials,
138141
)
139142
self.index = 0
140143

src/orion/algo/pbt/pb2.py

+2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from orion.algo.pbt.pbt import PBT
1818
from orion.core.utils.flatten import flatten
1919
from orion.core.utils.random_state import RandomState, control_randomness
20+
from orion.core.worker.transformer import ReshapedSpace, TransformedSpace
2021
from orion.core.worker.trial import Trial
2122

2223
logger = logging.getLogger(__name__)
@@ -172,6 +173,7 @@ def _generate_offspring(self, trial):
172173
]
173174

174175
new_trial = trial_to_branch.branch(params=new_params)
176+
assert isinstance(self.space, (TransformedSpace, ReshapedSpace))
175177
new_trial = self.space.transform(self.space.reverse(new_trial))
176178

177179
logger.debug("Attempt %s - Creating new trial %s", attempts, new_trial)

src/orion/algo/pbt/pbt.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ def __init__(
162162
explore: dict | None = None,
163163
fork_timeout: int = 60,
164164
):
165+
super().__init__(space)
165166
if exploit is None:
166167
exploit = {
167168
"of_type": "PipelineExploit",
@@ -217,7 +218,6 @@ def __init__(
217218

218219
self.lineages = Lineages()
219220

220-
super().__init__(space)
221221
self.seed = seed
222222
self.population_size = population_size
223223
self.generations = generations
@@ -318,14 +318,15 @@ def suggest(self, num: int) -> list[Trial]:
318318
A list of trials representing values suggested by the algorithm.
319319
320320
"""
321-
322-
# Sample points until num is met, or population_size
323-
num_random_samples = min(max(self.population_size - self._num_root, 0), num)
321+
assert num > 0
324322
logger.debug(
325323
"PBT has %s pending or completed trials at root, %s broken trials.",
326324
self._num_root,
327325
len(self.lineages) - self._num_root,
328326
)
327+
328+
# Sample points until num is met, or population_size
329+
num_random_samples = min(max(self.population_size - self._num_root, 0), num)
329330
logger.debug("Sampling %s new trials", num_random_samples)
330331
trials = self._sample(num_random_samples)
331332
logger.debug("Sampled %s new trials", len(trials))

src/orion/algo/random.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,14 @@
55
Draw and deliver samples from prior defined in problem's domain.
66
77
"""
8+
from __future__ import annotations
9+
10+
from typing import Sequence
11+
812
import numpy
913

1014
from orion.algo.base import BaseAlgorithm
15+
from orion.algo.space import Space
1116

1217

1318
class Random(BaseAlgorithm):
@@ -23,10 +28,12 @@ class Random(BaseAlgorithm):
2328
2429
"""
2530

26-
def __init__(self, space, seed=None):
27-
super().__init__(space, seed=seed)
31+
def __init__(self, space: Space, seed: int | Sequence[int] | None = None):
32+
super().__init__(space)
33+
self.seed = seed
34+
self.seed_rng(seed)
2835

29-
def seed_rng(self, seed):
36+
def seed_rng(self, seed: int | Sequence[int] | None):
3037
"""Seed the state of the random number generator.
3138
3239
:param seed: Integer seed for the random number generator.

src/orion/algo/registry.py

+21-4
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,27 @@
44
import copy
55
from collections import defaultdict
66
from logging import getLogger as get_logger
7-
from typing import Any, Container, Iterator, Mapping
7+
from typing import Any, Container, Iterable, Iterator, Mapping
88

99
from orion.core.worker.trial import Trial, TrialCM
1010

1111
logger = get_logger(__name__)
1212

1313

1414
class Registry(Container[Trial]):
15-
"""In-memory container for the trials that the algorithm suggests/observes/etc."""
15+
"""In-memory container for the trials that the algorithm suggests/observes/etc.
1616
17-
def __init__(self):
17+
This behaves a bit like a managed dictionary, but the "keys" are trials ids, which
18+
(at the time of writing) can vary depending on how we chose to compute them.
19+
"""
20+
21+
def __init__(self, trials: Iterable[Trial] = ()):
1822
self._trials: dict[str, Trial] = {}
23+
for trial in trials:
24+
self.register(trial)
25+
26+
def __repr__(self) -> str:
27+
return f"{type(self).__qualname__}({list(iter(self))})"
1928

2029
def __contains__(self, trial_or_id: str | Trial | Any) -> bool:
2130
if isinstance(trial_or_id, TrialCM):
@@ -94,7 +103,7 @@ def get_existing(self, trial: Trial) -> Trial:
94103
class RegistryMapping(Mapping[Trial, "list[Trial]"]):
95104
"""A map between the original and transformed registries.
96105
97-
This object is used in the `SpaceTransformAlgoWrapper` to check if a trial in the original space
106+
This object is used in the `SpaceTransform` to check if a trial in the original space
98107
has equivalent trials in the transformed space.
99108
100109
The goal is to make it so the algorithms don't have to care about the transforms/etc.
@@ -123,10 +132,15 @@ def set_state(self, statedict: dict):
123132
self._mapping = copy.deepcopy(statedict["_mapping"])
124133

125134
def __iter__(self) -> Iterator[Trial]:
135+
"""Iterate over the trials in the original registry."""
126136
for trial_id in self._mapping:
127137
yield self.original_registry[trial_id]
128138

129139
def __len__(self) -> int:
140+
"""Give the number of trials in the mapping.
141+
142+
This should be the same as the number of trials in the original registry.
143+
"""
130144
return len(self._mapping)
131145

132146
def __contains__(self, trial: Trial):
@@ -160,6 +174,9 @@ def register(self, original_trial: Trial, transformed_trial: Trial) -> str:
160174
self._mapping[original_trial_id].add(transformed_trial_id)
161175
return original_trial_id
162176

177+
def __repr__(self) -> str:
178+
return f"{type(self).__qualname__}({list((trial, self.get_trials(trial)) for trial in self)})"
179+
163180

164181
def _get_id(trial: Trial) -> str:
165182
"""Returns the unique identifier to be used to store the trial.

0 commit comments

Comments
 (0)