diff --git a/src/orion/client/experiment.py b/src/orion/client/experiment.py index 7996a983e..d2d2c35cf 100644 --- a/src/orion/client/experiment.py +++ b/src/orion/client/experiment.py @@ -9,6 +9,7 @@ import inspect import logging +import numbers import typing from contextlib import contextmanager from typing import Callable @@ -594,7 +595,12 @@ def suggest(self, pool_size=0): self._maintain_reservation(trial) return TrialCM(self, trial) - def observe(self, trial, results): + def observe( + self, + trial: Trial, + results: list[dict] | float, + name: str = "objective", + ) -> None: """Observe trial results Experiment must be in executable ('x') mode. @@ -603,10 +609,13 @@ def observe(self, trial, results): ---------- trial: `orion.core.worker.trial.Trial` Reserved trial to observe. - results: list + results: list or float Results to be set for the new trial. Results must have the format {name: : type: <'objective', 'constraint' or 'gradient'>, value=} otherwise a ValueError will be raised. If the results are invalid, the trial will not be released. + If `results` is a float, the result type will be 'objective'. + name: str + Name of the result if `results` is a float. Default: 'objective'. Returns ------- @@ -628,6 +637,9 @@ def observe(self, trial, results): """ self._check_if_executable() + if isinstance(results, numbers.Number): + results = [dict(value=results, name=name, type="objective")] + trial.results += [Trial.Result(**result) for result in results] raise_if_unreserved = True try: diff --git a/tests/unittests/client/test_experiment_client.py b/tests/unittests/client/test_experiment_client.py index 57e3dc641..77aff7fad 100644 --- a/tests/unittests/client/test_experiment_client.py +++ b/tests/unittests/client/test_experiment_client.py @@ -906,6 +906,40 @@ def test_observe_under_with(self): assert trial.status == "completed" # Still completed after __exit__ + def test_observe_with_float(self): + with create_experiment(config, base_trial) as (cfg, experiment, client): + trial = Trial(**cfg.trials[1]) + client.reserve(trial) + + client.observe(trial, 10.0) + assert trial.status == "completed" + assert trial.objective.name == "objective" + assert trial.objective.type == "objective" + assert not client._pacemakers + + def test_observe_with_float_and_name(self): + with create_experiment(config, base_trial) as (cfg, experiment, client): + trial = Trial(**cfg.trials[1]) + client.reserve(trial) + + client.observe(trial, 10.0, name="custom_objective") + assert trial.status == "completed" + assert trial.objective.name == "custom_objective" + assert trial.objective.type == "objective" + assert not client._pacemakers + + def test_observe_with_invalid_type(self): + with create_experiment(config, base_trial) as (cfg, experiment, client): + trial = Trial(**cfg.trials[1]) + client.reserve(trial) + + with pytest.raises(TypeError): + client.observe(trial, "invalid") + assert trial.status == "reserved" + assert trial.objective is None + assert client._pacemakers[trial.id].is_alive() + client._pacemakers.pop(trial.id).stop() + def test_executor_receives_correct_worker_count(): """Check that the client forwards the current number count to the executor"""