Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor refactor: Make creation of Experiment objects more explicit #968

Merged
merged 6 commits into from
Jul 15, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 70 additions & 35 deletions src/orion/core/io/experiment_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@
import logging
import pprint
import sys
import warnings
from typing import TypeVar

import orion.core
import orion.core.utils.backward as backward # pylint:disable=consider-using-from-import
Expand All @@ -99,7 +101,7 @@
NoNameError,
RaceCondition,
)
from orion.core.worker.experiment import Experiment
from orion.core.worker.experiment import Experiment, Mode
from orion.core.worker.primary_algo import create_algo
from orion.storage.base import get_storage, setup_storage

Expand All @@ -111,7 +113,9 @@
##


def build(name, version=None, branching=None, **config):
def build(
name: str, version: int | None = None, branching: dict | None = None, **config
) -> Experiment:
"""Build an experiment object

If new, ``space`` argument must be provided, else all arguments are fetched from the database
Expand Down Expand Up @@ -202,6 +206,7 @@ def build(name, version=None, branching=None, **config):
log.debug(f"Experiment {config['name']}-v{config['version']} already existed.")

conflicts = _get_conflicts(experiment, branching)
assert branching is not None
must_branch = len(conflicts.get()) > 1 or branching.get("branch_to")

if must_branch and branching.get("enable", orion.core.config.evc.enable):
Expand All @@ -218,7 +223,7 @@ def build(name, version=None, branching=None, **config):
return experiment


def clean_config(name, config, branching):
def clean_config(name: str, config: dict, branching: dict | None):
"""Clean configuration from hidden fields (ex: ``_id``) and update branching if necessary"""
log.debug("Cleaning config")

Expand Down Expand Up @@ -246,7 +251,7 @@ def clean_config(name, config, branching):
return name, config, branching


def consolidate_config(name, version, config):
def consolidate_config(name: str, version: int | None, config: dict):
"""Merge together given configuration with db configuration matching
for experiment (``name``, ``version``)
"""
Expand Down Expand Up @@ -278,7 +283,7 @@ def consolidate_config(name, version, config):
return config


def merge_algorithm_config(config, new_config):
def merge_algorithm_config(config: dict, new_config: dict) -> None:
"""Merge given algorithm configuration with db config"""
# TODO: Find a better solution
if isinstance(config.get("algorithms"), dict) and len(config["algorithms"]) > 1:
Expand All @@ -289,7 +294,7 @@ def merge_algorithm_config(config, new_config):


# TODO: Remove for v0.4
def merge_producer_config(config, new_config):
def merge_producer_config(config: dict, new_config: dict) -> None:
"""Merge given producer configuration with db config"""
if (
isinstance(config.get("producer", {}).get("strategy"), dict)
Expand All @@ -312,7 +317,7 @@ def build_view(name, version=None):
return load(name, version=version, mode="r")


def load(name, version=None, mode="r"):
def load(name: str, version: int | None = None, mode: Mode = "r") -> Experiment:
"""Load experiment from database

An experiment view provides all reading operations of standard experiment but prevents the
Expand Down Expand Up @@ -351,7 +356,23 @@ def load(name, version=None, mode="r"):
return create_experiment(mode=mode, **db_config)


def create_experiment(name, version, mode, space, **kwargs):
# pylint: disable=too-many-arguments
def create_experiment(
name: str,
version: int,
mode: Mode,
space: Space | dict[str, str],
algorithms: str | dict | None = None,
max_trials: int | None = None,
max_broken: int | None = None,
working_dir: str | None = None,
metadata: dict | None = None,
refers: dict | None = None,
producer: dict | None = None,
user: str | None = None,
_id: int | str | None = None,
**kwargs,
) -> Experiment:
"""Instantiate the experiment and its attribute objects

All unspecified arguments will be replaced by system's defaults (orion.core.config.*).
Expand Down Expand Up @@ -382,39 +403,53 @@ def create_experiment(name, version, mode, space, **kwargs):
Configuration of the storage backend.

"""
experiment = Experiment(name=name, version=version, mode=mode)
experiment._id = kwargs.get("_id", None) # pylint:disable=protected-access
experiment.max_trials = kwargs.get(
"max_trials", orion.core.config.experiment.max_trials
)
experiment.max_broken = kwargs.get(
"max_broken", orion.core.config.experiment.max_broken
)
experiment.space = _instantiate_space(space)
experiment.algorithms = _instantiate_algo(
experiment.space,
experiment.max_trials,
kwargs.get("algorithms"),

T = TypeVar("T")
V = TypeVar("V")

def _default(v: T | None, default: V) -> T | V:
return v if v is not None else default

space = _instantiate_space(space)
max_trials = _default(max_trials, orion.core.config.experiment.max_trials)
instantiated_algorithm = _instantiate_algo(
space=space,
max_trials=max_trials,
config=algorithms,
ignore_unavailable=mode != "x",
)

max_broken = _default(max_broken, orion.core.config.experiment.max_broken)
working_dir = _default(working_dir, orion.core.config.experiment.working_dir)
metadata = _default(metadata, {"user": _default(user, getpass.getuser())})
refers = _default(refers, dict(parent_id=None, root_id=None, adapter=[]))
refers["adapter"] = _instantiate_adapters(refers.get("adapter", [])) # type: ignore

# TODO: Remove for v0.4
_instantiate_strategy(kwargs.get("producer", {}).get("strategy"))
experiment.working_dir = kwargs.get(
"working_dir", orion.core.config.experiment.working_dir
_instantiate_strategy((producer or {}).get("strategy"))

experiment = Experiment(
name=name,
version=version,
mode=mode,
space=space,
_id=_id,
max_trials=max_trials,
algorithms=instantiated_algorithm,
max_broken=max_broken,
working_dir=working_dir,
metadata=metadata,
refers=refers,
)
experiment.metadata = kwargs.get(
"metadata", {"user": kwargs.get("user", getpass.getuser())}
)
experiment.refers = kwargs.get(
"refers", {"parent_id": None, "root_id": None, "adapter": []}
)
experiment.refers["adapter"] = _instantiate_adapters(
experiment.refers.get("adapter", [])
)

log.debug(
"Created experiment with config:\n%s", pprint.pformat(experiment.configuration)
)
if kwargs:
warnings.warn(
UserWarning(
f"create_experiment received some extra unused arguments: {kwargs}"
)
)

return experiment

Expand Down Expand Up @@ -568,7 +603,7 @@ def _register_experiment(experiment):
)


def _update_experiment(experiment):
def _update_experiment(experiment: Experiment) -> None:
"""Update experiment configuration in database"""
log.debug("Updating experiment (name: %s)", experiment.name)
config = experiment.configuration
Expand Down
45 changes: 33 additions & 12 deletions src/orion/core/worker/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
Manage history of trials corresponding to a black box process.

"""
from __future__ import annotations

import contextlib
import copy
import datetime
Expand All @@ -14,7 +16,10 @@
from dataclasses import dataclass, field

import pandas
from typing_extensions import Literal

from orion.algo.base import BaseAlgorithm
from orion.algo.space import Space
from orion.core.evc.adapters import BaseAdapter
from orion.core.evc.experiment import ExperimentNode
from orion.core.io.database import DuplicateKeyError
Expand All @@ -24,6 +29,7 @@
from orion.storage.base import FailedUpdate, get_storage

log = logging.getLogger(__name__)
Mode = Literal["r", "w", "x"]


@dataclass
Expand Down Expand Up @@ -133,22 +139,35 @@ class Experiment:
)
non_branching_attrs = ("max_trials", "max_broken")

def __init__(self, name, version=None, mode="r"):
self._id = None
# pylint: disable=too-many-arguments
def __init__(
self,
name: str,
space: Space,
version: int | None = 1,
mode: Mode = "r",
_id: str | int | None = None,
max_trials: int | None = None,
max_broken: int | None = None,
algorithms: BaseAlgorithm | None = None,
working_dir: str | None = None,
metadata: dict | None = None,
refers: dict | None = None,
):
self._id = _id
self.name = name
self.space: Space = space
self.version = version if version else 1
self._mode = mode
self._node = None
self.refers = {}
self.metadata = {}
self.max_trials = None
self.max_broken = None
self.space = None
self.algorithms = None
self.working_dir = None
self.refers = refers or {}
self.metadata = metadata or {}
self.max_trials = max_trials
self.max_broken = max_broken

self._storage = get_storage()
self.algorithms = algorithms
self.working_dir = working_dir

self._storage = get_storage()
self._node = ExperimentNode(self.name, self.version, experiment=self)

def _check_if_writable(self):
Expand Down Expand Up @@ -386,7 +405,9 @@ def register_trial(self, trial, status="new"):
self._storage.register_trial(trial)

@contextlib.contextmanager
def acquire_algorithm_lock(self, timeout=60, retry_interval=1):
def acquire_algorithm_lock(
self, timeout: int | float = 60, retry_interval: int | float = 1
):
"""Acquire lock on algorithm

This method should be called using a ``with``-clause.
Expand Down
Loading