Skip to content

Commit

Permalink
Make ExperimentClient.observe support single value
Browse files Browse the repository at this point in the history
Why:

Formatting results as dict(name, type, value) is cumbersome and the most
common use case is observing a single result which is the objective.
This use case should be simplified.

How:

Support `client.observe(trial, value, name)` and still support
`client.observe(trial, results)`.
  • Loading branch information
bouthilx committed Jan 11, 2023
1 parent 98bc0ab commit 89a37b0
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 2 deletions.
16 changes: 14 additions & 2 deletions src/orion/client/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import inspect
import logging
import numbers
import typing
from contextlib import contextmanager
from typing import Callable
Expand Down Expand Up @@ -594,7 +595,12 @@ def suggest(self, pool_size=0):
self._maintain_reservation(trial)
return TrialCM(self, trial)

def observe(self, trial, results):
def observe(
self,
trial: Trial,
results: list[dict] | float,
name: str = "objective",
) -> None:
"""Observe trial results
Experiment must be in executable ('x') mode.
Expand All @@ -603,10 +609,13 @@ def observe(self, trial, results):
----------
trial: `orion.core.worker.trial.Trial`
Reserved trial to observe.
results: list
results: list or float
Results to be set for the new trial. Results must have the format
{name: <str>: type: <'objective', 'constraint' or 'gradient'>, value=<float>} otherwise
a ValueError will be raised. If the results are invalid, the trial will not be released.
If `results` is a float, the result type will be 'objective'.
name: str
Name of the result if `results` is a float. Default: 'objective'.
Returns
-------
Expand All @@ -628,6 +637,9 @@ def observe(self, trial, results):
"""
self._check_if_executable()

if isinstance(results, numbers.Number):
results = [dict(value=results, name=name, type="objective")]

trial.results += [Trial.Result(**result) for result in results]
raise_if_unreserved = True
try:
Expand Down
34 changes: 34 additions & 0 deletions tests/unittests/client/test_experiment_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,6 +906,40 @@ def test_observe_under_with(self):

assert trial.status == "completed" # Still completed after __exit__

def test_observe_with_float(self):
with create_experiment(config, base_trial) as (cfg, experiment, client):
trial = Trial(**cfg.trials[1])
client.reserve(trial)

client.observe(trial, 10.0)
assert trial.status == "completed"
assert trial.objective.name == "objective"
assert trial.objective.type == "objective"
assert not client._pacemakers

def test_observe_with_float_and_name(self):
with create_experiment(config, base_trial) as (cfg, experiment, client):
trial = Trial(**cfg.trials[1])
client.reserve(trial)

client.observe(trial, 10.0, name="custom_objective")
assert trial.status == "completed"
assert trial.objective.name == "custom_objective"
assert trial.objective.type == "objective"
assert not client._pacemakers

def test_observe_with_invalid_type(self):
with create_experiment(config, base_trial) as (cfg, experiment, client):
trial = Trial(**cfg.trials[1])
client.reserve(trial)

with pytest.raises(TypeError):
client.observe(trial, "invalid")
assert trial.status == "reserved"
assert trial.objective is None
assert client._pacemakers[trial.id].is_alive()
client._pacemakers.pop(trial.id).stop()


def test_executor_receives_correct_worker_count():
"""Check that the client forwards the current number count to the executor"""
Expand Down

0 comments on commit 89a37b0

Please sign in to comment.