From 90c314fe6013bc11398bc117168df8f8467a0402 Mon Sep 17 00:00:00 2001 From: Matthew Grange Date: Wed, 6 Dec 2023 11:50:49 -0800 Subject: [PATCH] Call "get_improvement_over_baseline" when making scheduler log (#2050) Summary: Integrating SchedulerCompletedRecord.from_scheduler with `scheduler.get_improvement_over_baseline` Reviewed By: mpolson64 Differential Revision: D51858443 --- ax/telemetry/scheduler.py | 11 +++++++++- ax/telemetry/tests/test_scheduler.py | 30 +++++++++++++++++++++++++--- 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/ax/telemetry/scheduler.py b/ax/telemetry/scheduler.py index 4e16ee905e0..59e04557c97 100644 --- a/ax/telemetry/scheduler.py +++ b/ax/telemetry/scheduler.py @@ -134,6 +134,15 @@ def from_scheduler(cls, scheduler: Scheduler) -> SchedulerCompletedRecord: model_fit_quality = float("-inf") model_std_quality = float("-inf") + try: + improvement_over_baseline = scheduler.get_improvement_over_baseline() + except Exception as e: + warn( + "Encountered exception in computing improvement over baseline: " + + str(e) + ) + improvement_over_baseline = float("-inf") + return cls( experiment_completed_record=ExperimentCompletedRecord.from_experiment( experiment=scheduler.experiment @@ -143,7 +152,7 @@ def from_scheduler(cls, scheduler: Scheduler) -> SchedulerCompletedRecord: model_std_quality=model_std_quality, model_fit_generalization=float("-inf"), # TODO by cross_validate_by_trial model_std_generalization=float("-inf"), - improvement_over_baseline=float("-inf"), # TODO extract improvement result + improvement_over_baseline=improvement_over_baseline, num_metric_fetch_e_encountered=scheduler._num_metric_fetch_e_encountered, num_trials_bad_due_to_err=scheduler._num_trials_bad_due_to_err, ) diff --git a/ax/telemetry/tests/test_scheduler.py b/ax/telemetry/tests/test_scheduler.py index e07e1d0a612..fe98d525a58 100644 --- a/ax/telemetry/tests/test_scheduler.py +++ b/ax/telemetry/tests/test_scheduler.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. from typing import cast, Dict +from unittest import mock from ax.core.experiment import Experiment from ax.core.objective import Objective @@ -85,7 +86,10 @@ def test_scheduler_completed_record_from_scheduler(self) -> None: ), ) - record = SchedulerCompletedRecord.from_scheduler(scheduler=scheduler) + with mock.patch.object( + scheduler, "get_improvement_over_baseline", return_value=5.0 + ): + record = SchedulerCompletedRecord.from_scheduler(scheduler=scheduler) expected = SchedulerCompletedRecord( experiment_completed_record=ExperimentCompletedRecord.from_experiment( experiment=scheduler.experiment @@ -95,7 +99,7 @@ def test_scheduler_completed_record_from_scheduler(self) -> None: model_std_quality=float("-inf"), model_fit_generalization=float("-inf"), model_std_generalization=float("-inf"), - improvement_over_baseline=float("-inf"), + improvement_over_baseline=5.0, num_metric_fetch_e_encountered=0, num_trials_bad_due_to_err=0, ) @@ -111,12 +115,32 @@ def test_scheduler_completed_record_from_scheduler(self) -> None: "model_std_quality": float("-inf"), "model_fit_generalization": float("-inf"), "model_std_generalization": float("-inf"), - "improvement_over_baseline": float("-inf"), + "improvement_over_baseline": 5.0, "num_metric_fetch_e_encountered": 0, "num_trials_bad_due_to_err": 0, } self.assertEqual(flat, expected_dict) + def test_scheduler_raise_exceptions(self) -> None: + scheduler = Scheduler( + experiment=get_branin_experiment(), + generation_strategy=get_generation_strategy(), + options=SchedulerOptions( + total_trials=0, + tolerated_trial_failure_rate=0.2, + init_seconds_between_polls=10, + ), + ) + + with mock.patch.object( + scheduler, + "get_improvement_over_baseline", + side_effect=Exception("test_exception"), + ): + record = SchedulerCompletedRecord.from_scheduler(scheduler=scheduler) + flat = record.flatten() + self.assertEqual(flat["improvement_over_baseline"], float("-inf")) + def test_scheduler_model_fit_metrics_logging(self) -> None: # set up for model fit metrics branin_experiment = Experiment(