Skip to content

Commit

Permalink
Make construction of Experiments more explicit
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice committed Jul 11, 2022
1 parent e809cf6 commit 9acbc6e
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 97 deletions.
61 changes: 33 additions & 28 deletions src/orion/core/io/experiment_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@
NoNameError,
RaceCondition,
)
from orion.core.worker.experiment import Experiment
from orion.core.worker.experiment import Experiment, Mode
from orion.core.worker.primary_algo import create_algo
from orion.storage.base import get_storage, setup_storage

Expand Down Expand Up @@ -312,7 +312,7 @@ def build_view(name, version=None):
return load(name, version=version, mode="r")


def load(name, version=None, mode="r"):
def load(name: str, version: int | None = None, mode: Mode = "r"):
"""Load experiment from database
An experiment view provides all reading operations of standard experiment but prevents the
Expand Down Expand Up @@ -351,7 +351,9 @@ def load(name, version=None, mode="r"):
return create_experiment(mode=mode, **db_config)


def create_experiment(name, version, mode, space, **kwargs):
def create_experiment(
name: str, version: int, mode: Mode, space: Space | dict[str, str], **kwargs
) -> Experiment:
"""Instantiate the experiment and its attribute objects
All unspecified arguments will be replaced by system's defaults (orion.core.config.*).
Expand Down Expand Up @@ -382,36 +384,39 @@ def create_experiment(name, version, mode, space, **kwargs):
Configuration of the storage backend.
"""
experiment = Experiment(name=name, version=version, mode=mode)
experiment._id = kwargs.get("_id", None) # pylint:disable=protected-access
experiment.max_trials = kwargs.get(
"max_trials", orion.core.config.experiment.max_trials
)
experiment.max_broken = kwargs.get(
"max_broken", orion.core.config.experiment.max_broken
)
experiment.space = _instantiate_space(space)
experiment.algorithms = _instantiate_algo(
experiment.space,
experiment.max_trials,
kwargs.get("algorithms"),
space = _instantiate_space(space)
_id = kwargs.get("_id", None)
max_trials = kwargs.pop("max_trials", orion.core.config.experiment.max_trials)
max_broken = kwargs.pop("max_broken", orion.core.config.experiment.max_broken)
working_dir = kwargs.pop("working_dir", orion.core.config.experiment.working_dir)
algo_config = kwargs.pop("algorithms", None)
algorithms = _instantiate_algo(
space=space,
max_trials=max_trials,
config=algo_config,
ignore_unavailable=mode != "x",
)
# TODO: Remove for v0.4
_instantiate_strategy(kwargs.get("producer", {}).get("strategy"))
experiment.working_dir = kwargs.get(
"working_dir", orion.core.config.experiment.working_dir
)
experiment.metadata = kwargs.get(
"metadata", {"user": kwargs.get("user", getpass.getuser())}
)
experiment.refers = kwargs.get(
metadata = kwargs.pop("metadata", {"user": kwargs.pop("user", getpass.getuser())})
refers: dict = kwargs.pop(
"refers", {"parent_id": None, "root_id": None, "adapter": []}
)
experiment.refers["adapter"] = _instantiate_adapters(
experiment.refers.get("adapter", [])
refers["adapter"] = _instantiate_adapters(refers.get("adapter", []))
# TODO: Remove for v0.4
strategy_config: dict | None = kwargs.pop("producer", {}).get("strategy")
_instantiate_strategy(strategy_config)
experiment = Experiment(
name=name,
version=version,
mode=mode,
space=space,
_id=_id,
max_trials=max_trials,
max_broken=max_broken,
algorithms=algorithms,
working_dir=working_dir,
metadata=metadata,
refers=refers,
)

log.debug(
"Created experiment with config:\n%s", pprint.pformat(experiment.configuration)
)
Expand Down
47 changes: 35 additions & 12 deletions src/orion/core/worker/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
Manage history of trials corresponding to a black box process.
"""
from __future__ import annotations

import contextlib
import copy
import datetime
Expand All @@ -14,16 +16,21 @@
from dataclasses import dataclass, field

import pandas
from typing_extensions import Literal

from orion.algo.base import BaseAlgorithm
from orion.algo.space import Space
from orion.core.evc.adapters import BaseAdapter
from orion.core.evc.experiment import ExperimentNode
from orion.core.io.database import DuplicateKeyError
from orion.core.io.space_builder import SpaceBuilder
from orion.core.utils.exceptions import UnsupportedOperation
from orion.core.utils.flatten import flatten
from orion.core.utils.singleton import update_singletons
from orion.storage.base import FailedUpdate, get_storage

log = logging.getLogger(__name__)
Mode = Literal["r", "w", "x"]


@dataclass
Expand Down Expand Up @@ -133,22 +140,36 @@ class Experiment:
)
non_branching_attrs = ("max_trials", "max_broken")

def __init__(self, name, version=None, mode="r"):
self._id = None
def __init__(
self,
name: str,
space: Space | dict[str, str],
version: int | None = 1,
mode: Mode = "r",
_id: str | int | None = None,
max_trials: int | None = None,
max_broken: int | None = None,
algorithms: BaseAlgorithm | None = None,
working_dir: str | None = None,
metadata: dict | None = None,
refers: dict | None = None,
):
self._id = _id
self.name = name
self.space: Space = (
space if isinstance(space, Space) else SpaceBuilder().build(space)
)
self.version = version if version else 1
self._mode = mode
self._node = None
self.refers = {}
self.metadata = {}
self.max_trials = None
self.max_broken = None
self.space = None
self.algorithms = None
self.working_dir = None
self.refers = refers or {}
self.metadata = metadata or {}
self.max_trials = max_trials
self.max_broken = max_broken

self._storage = get_storage()
self.algorithms = algorithms
self.working_dir = working_dir

self._storage = get_storage()
self._node = ExperimentNode(self.name, self.version, experiment=self)

def _check_if_writable(self):
Expand Down Expand Up @@ -386,7 +407,9 @@ def register_trial(self, trial, status="new"):
self._storage.register_trial(trial)

@contextlib.contextmanager
def acquire_algorithm_lock(self, timeout=60, retry_interval=1):
def acquire_algorithm_lock(
self, timeout: int | float = 60, retry_interval: int | float = 1
):
"""Acquire lock on algorithm
This method should be called using a ``with``-clause.
Expand Down
Loading

0 comments on commit 9acbc6e

Please sign in to comment.