Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PBT implementation #705

Merged
merged 28 commits into from
Jan 26, 2022
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
bbdba37
Set and create exp working dir inside workon
bouthilx Nov 24, 2021
cddda16
Infer trial working dir based on exp.working_dir
bouthilx Nov 24, 2021
842cba8
Add parent attribute to Trial
bouthilx Nov 24, 2021
5ebcf0a
Move tree.py to utils
bouthilx Nov 29, 2021
480f653
Add Tree.node_depth and Tree.get_nodes_at_depth
bouthilx Nov 30, 2021
5c142b0
Add Lineage for PBT and tests
bouthilx Dec 1, 2021
e12cd52
Adding tests for Lineages WIP
bouthilx Dec 1, 2021
357a762
Add TreeNode.leafs
bouthilx Dec 1, 2021
95c1126
Modularize PBT
bouthilx Dec 7, 2021
d0c2870
Add tests for Explore module
bouthilx Dec 7, 2021
12b34ff
Add tests for PBT fidelity budgets
bouthilx Dec 7, 2021
84d2de7
Rename PopulationBasedTraining to PBT
bouthilx Dec 7, 2021
72b2946
Remove old base PBT modules
bouthilx Dec 7, 2021
fadc818
Add logging and some fixes for exploit/explore
bouthilx Dec 8, 2021
6ae63f4
Add documentation for PBT
bouthilx Dec 15, 2021
2297afa
Add generic tests for PBT
bouthilx Dec 15, 2021
61e927b
isort
bouthilx Dec 15, 2021
2ba4946
Handle Trial.parents for previous versions of Oríon
bouthilx Dec 15, 2021
013e8fe
Add PBT rst doc file
bouthilx Dec 15, 2021
b21281b
Add backward.ensure_trial_working_dir
bouthilx Dec 15, 2021
31fae92
Add Trial.__eq__
bouthilx Dec 15, 2021
68eab2b
Fix exploit & explore arg docs
bouthilx Jan 5, 2022
6eed9dc
Add missing SPACE_ERROR
bouthilx Jan 5, 2022
02a2150
Update src/orion/algo/pbt/pbt.py
bouthilx Jan 5, 2022
e65cb4e
Clarify PBT model weights saving in doc
bouthilx Jan 11, 2022
95eef3a
Rename Lineage to LineageNode
bouthilx Jan 18, 2022
4917a50
Adapt Lineage -> LineageNode in docs
bouthilx Jan 19, 2022
4b2a088
Clarify PBT doc on trial.working_dir
bouthilx Jan 26, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions docs/src/code/algo.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
Algorithm modules
*****************

TODO

.. automodule:: orion.algo
:members:

Expand All @@ -17,5 +15,6 @@ TODO
algo/gridsearch
algo/hyperband
algo/asha
algo/pbt
algo/tpe
algo/parallel_strategy
7 changes: 2 additions & 5 deletions docs/src/code/algo/asha.rst
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
Asynchronous Successive Halving Algorithm
=========================================

Can't build documentation because of import order.
Sphinx is loading ``orion.algo.asha`` before ``orion.algo`` and therefore
there is a cycle between the definition of ``BaseAlgorithm`` and
``ASHA`` as the meta-class ``Factory`` is trying to import ``ASHA``.
`PR #135 <https://github.com/Epistimio/orion/pull/135/files>`_ should get rid of this problem.
.. automodule:: orion.algo.asha
:members:
2 changes: 0 additions & 2 deletions docs/src/code/algo/base.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,3 @@ Base definition of algorithms

.. autoclass:: orion.algo.base.BaseAlgorithm
:members:


89 changes: 89 additions & 0 deletions docs/src/code/algo/pbt.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
Population Based Training
=========================

.. contents::
:depth: 3
:local:

.. role:: hidden
:class: hidden-section

Population Based Training
-------------------------

.. autoclass:: orion.algo.pbt.pbt.PBT
:members:

LineageNode
-----------

.. autoclass:: orion.algo.pbt.pbt.LineageNode
:members:

Lineages
--------

.. autoclass:: orion.algo.pbt.pbt.Lineages
:members:

Exploit classes for Population Based Training
---------------------------------------------

BaseExploit
~~~~~~~~~~~

.. autoclass:: orion.algo.pbt.exploit.BaseExploit
:members:


PipelineExploit
~~~~~~~~~~~~~~~

.. autoclass:: orion.algo.pbt.exploit.PipelineExploit
:members:


TruncateExploit
~~~~~~~~~~~~~~~

.. autoclass:: orion.algo.pbt.exploit.TruncateExploit
:members:

BacktrackExploit
~~~~~~~~~~~~~~~~

.. autoclass:: orion.algo.pbt.exploit.BacktrackExploit
:members:

Explore classes for Population Based Training
---------------------------------------------

BaseExplore
~~~~~~~~~~~

.. autoclass:: orion.algo.pbt.explore.BaseExplore
:members:


PipelineExplore
~~~~~~~~~~~~~~~

.. autoclass:: orion.algo.pbt.explore.PipelineExplore
:members:


PerturbExplore
~~~~~~~~~~~~~~

.. autoclass:: orion.algo.pbt.explore.PerturbExplore
:members:

ResampleExplore
~~~~~~~~~~~~~~~

.. autoclass:: orion.algo.pbt.explore.ResampleExplore
:members:




1 change: 0 additions & 1 deletion docs/src/code/core/evc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ Experiment Version Control
:maxdepth: 1
:caption: Modules

evc/tree
evc/experiment
evc/adapters
evc/conflicts
1 change: 1 addition & 0 deletions docs/src/code/core/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Utilities
utils/format_trials
utils/format_terminal
utils/singleton
utils/tree

.. automodule:: orion.core.utils
:members:
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Generic Tree
============

.. automodule:: orion.core.evc.tree
.. automodule:: orion.core.utils.tree
:members:
63 changes: 63 additions & 0 deletions docs/src/user/algorithms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,69 @@ Configuration
executed_times, compute_bracket_idx


.. _PBT:

Population Based Training (PBT)
-------------------------------

Population based training is an evolutionary algorithm that evolve trials
from low fidelity levels to high fidelity levels (ex: number of epochs), reusing
the model's parameters along the way. This has the effect of creating hyperparameter
schedules through the fidelity levels.

See documentation below for more information on the algorithm and how to use it.

.. note::

Current implementation does not support more than one fidelity dimension.

Configuration
~~~~~~~~~~~~~

.. code-block:: yaml

experiment:

strategy: StubParallelStrategy

algorithms:
pbt:
population_size: 50
generations: 10
fork_timeout: 60
exploit:
of_type: PipelineExploit
exploit_configs:
- of_type: BacktrackExploit
min_forking_population: 5
truncation_quantile: 0.9
candidate_pool_ratio: 0.2
- of_type: TruncateExploit
min_forking_population: 5
truncation_quantile: 0.8
candidate_pool_ratio: 0.2
explore:
of_type: PipelineExplore
explore_configs:
- of_type: ResampleExplore
probability: 0.2
- of_type: PerturbExplore
factor: 1.2
volatility: 0.0001



.. note::
Notice the additional ``strategy`` in configuration which is not mandatory for most other
algorithms. See :ref:`StubParallelStrategy` for more information.


.. autoclass:: orion.algo.pbt.pbt.PBT
:noindex:
:exclude-members: space, state_dict, set_state, suggest, observe, is_done, seed_rng,
configuration, requires_type, rng, register



.. _tpe-algorithm:

Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

packages = [ # Packages must be sorted alphabetically to ease maintenance and merges.
"orion.algo",
"orion.algo.pbt",
"orion.analysis",
"orion.benchmark",
"orion.client",
Expand Down Expand Up @@ -53,6 +54,7 @@
"hyperband = orion.algo.hyperband:Hyperband",
"tpe = orion.algo.tpe:TPE",
"EvolutionES = orion.algo.evolution_es:EvolutionES",
"pbt = orion.algo.pbt.pbt:PBT",
],
"Database": [
"ephemeraldb = orion.core.io.database.ephemeraldb:EphemeralDB",
Expand Down
11 changes: 8 additions & 3 deletions src/orion/algo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def format_trial(self, trial):

return trial

def get_id(self, trial, ignore_fidelity=False):
def get_id(self, trial, ignore_fidelity=False, ignore_parent=False):
"""Return unique hash for a trials based on params

The trial is assumed to be in the transformed space if the algorithm is working in a
Expand All @@ -174,6 +174,10 @@ def get_id(self, trial, ignore_fidelity=False):
ignore_fidelity: bool, optional
If True, the fidelity dimension is ignored when computing a unique hash for
the trial. Defaults to False.
ignore_parent: bool, optional
If True, the parent id is ignored when computing a unique hash for
the trial. Defaults to False.

"""

# Apply transforms and reverse to see data as it would come from DB
Expand All @@ -188,6 +192,7 @@ def get_id(self, trial, ignore_fidelity=False):
ignore_fidelity=ignore_fidelity,
ignore_experiment=True,
ignore_lie=True,
ignore_parent=ignore_parent,
)

@property
Expand Down Expand Up @@ -357,8 +362,8 @@ def judge(self, trial, measurements): # pylint:disable=no-self-use,unused-argum
trial: ``orion.core.worker.trial.Trial``
Trial object to retrieve from the database

Notes:
------
Notes
-----

Calling algorithm to `judge` a `point` based on its online `measurements` will effectively
change a state in the algorithm (like a reinforcement learning agent's hidden state or an
Expand Down
27 changes: 18 additions & 9 deletions src/orion/algo/hyperband.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,10 @@ def sample_from_bracket(self, bracket, num):
params={self.fidelity_index: bracket.rungs[0]["resources"]}
)

full_id = self.get_id(trial, ignore_fidelity=False)
id_wo_fidelity = self.get_id(trial, ignore_fidelity=True)
full_id = self.get_id(trial, ignore_fidelity=False, ignore_parent=False)
id_wo_fidelity = self.get_id(
trial, ignore_fidelity=True, ignore_parent=True
)

bracket_id = self.trial_to_brackets.get(id_wo_fidelity, None)
if bracket_id is not None:
Expand Down Expand Up @@ -262,7 +264,7 @@ def set_state(self, state_dict):

def register_samples(self, bracket, samples):
for sample in samples:
full_id = self.get_id(sample, ignore_fidelity=False)
full_id = self.get_id(sample, ignore_fidelity=False, ignore_parent=False)
if self.has_observed(sample):
raise RuntimeError(
"Hyperband resampling a trial that was already completed. "
Expand All @@ -273,9 +275,12 @@ def register_samples(self, bracket, samples):
self.register(sample)
bracket.register(sample)

if self.get_id(sample, ignore_fidelity=True) not in self.trial_to_brackets:
if (
self.get_id(sample, ignore_fidelity=True, ignore_parent=True)
not in self.trial_to_brackets
):
self.trial_to_brackets[
self.get_id(sample, ignore_fidelity=True)
self.get_id(sample, ignore_fidelity=True, ignore_parent=True)
] = self.brackets.index(bracket)

def promote(self, num):
Expand Down Expand Up @@ -384,7 +389,7 @@ def create_brackets(self):

def _get_bracket(self, trial):
"""Get the bracket of a trial"""
_id_wo_fidelity = self.get_id(trial, ignore_fidelity=True)
_id_wo_fidelity = self.get_id(trial, ignore_fidelity=True, ignore_parent=True)
return self.brackets[self.trial_to_brackets[_id_wo_fidelity]]

def observe(self, trials):
Expand Down Expand Up @@ -474,7 +479,9 @@ def is_filled(self):
def get_trial_max_resource(self, trial):
"""Return the max resource value that has been tried for a trial"""
max_resource = 0
_id_wo_fidelity = self.hyperband.get_id(trial, ignore_fidelity=True)
_id_wo_fidelity = self.hyperband.get_id(
trial, ignore_fidelity=True, ignore_parent=True
)
for rung in self.rungs:
if _id_wo_fidelity in rung["results"]:
max_resource = rung["resources"]
Expand Down Expand Up @@ -511,7 +518,9 @@ def sample(self, num):

def register(self, trial):
"""Register a trial in the corresponding rung"""
self._get_results(trial)[self.hyperband.get_id(trial, ignore_fidelity=True)] = (
self._get_results(trial)[
self.hyperband.get_id(trial, ignore_fidelity=True, ignore_parent=True)
] = (
trial.objective.value if trial.objective else None,
copy.deepcopy(trial),
)
Expand Down Expand Up @@ -562,7 +571,7 @@ def get_candidates(self, rung_id):
while len(trials) + len(next_rung) < should_have_n_trials:
objective, trial = rung[i]
assert objective is not None
_id = self.hyperband.get_id(trial, ignore_fidelity=True)
_id = self.hyperband.get_id(trial, ignore_fidelity=True, ignore_parent=True)
if _id not in next_rung:
trials.append(trial)
i += 1
Expand Down
Loading