From 1139ea048d150b2f390a1eee0f1f78bfe62c94b5 Mon Sep 17 00:00:00 2001
From: Daniel Cohen
Date: Thu, 13 Jun 2024 14:15:53 -0700
Subject: [PATCH] Put model fit data in gen_metadata (#2511)

Summary:
Pull Request resolved: https://github.com/facebook/Ax/pull/2511

Reviewed By: saitcakmak

Differential Revision: D58261582

fbshipit-source-id: a29600fb48d3a825d2c648646a12c88702ca94b3
---
 ax/modelbridge/cross_validation.py       | 73 ++++++++++++++++
 ax/modelbridge/model_spec.py             | 17 +++-
 .../tests/test_model_fit_metrics.py      | 78 ++++++++++++++++-
 ax/modelbridge/tests/test_model_spec.py  | 21 +++++
 ax/telemetry/scheduler.py                | 83 ++++---------------
 ax/telemetry/tests/test_scheduler.py     | 24 +++---
 6 files changed, 218 insertions(+), 78 deletions(-)

diff --git a/ax/modelbridge/cross_validation.py b/ax/modelbridge/cross_validation.py
index cf01d97070f..a07ec1475f7 100644
--- a/ax/modelbridge/cross_validation.py
+++ b/ax/modelbridge/cross_validation.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 # pyre-strict
+from __future__ import annotations
 
 import warnings
 from abc import ABC, abstractmethod
@@ -16,6 +17,7 @@
 from logging import Logger
 from numbers import Number
 from typing import Any, Callable, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple
+from warnings import warn
 
 import numpy as np
 from ax.core.observation import Observation, ObservationData, recombine_observations
@@ -492,6 +494,45 @@ def best_diagnostic(self, diagnostics: List[CVDiagnostics]) -> int:
         """
 
 
+def get_fit_and_std_quality_and_generalization_dict(
+    fitted_model_bridge: ModelBridge,
+) -> Dict[str, Optional[float]]:
+    """
+    Get fit quality and generalization stats from a fitted ModelBridge for analytics.
+    """
+    try:
+        model_fit_dict = compute_model_fit_metrics_from_modelbridge(
+            model_bridge=fitted_model_bridge,
+            generalization=False,
+            untransform=False,
+        )
+        # for uncertainty quantification, the std's distance from 1 is what matters
+        std = list(model_fit_dict["std_of_the_standardized_error"].values())
+
+        # generalization metrics
+        model_gen_dict = compute_model_fit_metrics_from_modelbridge(
+            model_bridge=fitted_model_bridge,
+            generalization=True,
+            untransform=False,
+        )
+        gen_std = list(model_gen_dict["std_of_the_standardized_error"].values())
+        return {
+            "model_fit_quality": _model_fit_metric(model_fit_dict),
+            "model_std_quality": _model_std_quality(np.array(std)),
+            "model_fit_generalization": _model_fit_metric(model_gen_dict),
+            "model_std_generalization": _model_std_quality(np.array(gen_std)),
+        }
+
+    except Exception as e:
+        warn("Encountered exception in computing model fit quality: " + str(e))
+        return {
+            "model_fit_quality": None,
+            "model_std_quality": None,
+            "model_fit_generalization": None,
+            "model_std_generalization": None,
+        }
+
+
 def compute_model_fit_metrics_from_modelbridge(
     model_bridge: ModelBridge,
     fit_metrics_dict: Optional[Dict[str, ModelFitMetricProtocol]] = None,
@@ -550,6 +591,38 @@ def compute_model_fit_metrics_from_modelbridge(
     )
 
 
+def _model_fit_metric(metric_dict: Dict[str, Dict[str, float]]) -> float:
+    # We'd ideally log the entire `model_fit_dict`, as a single model fit metric
+    # can't capture the nuances of multiple experimental metrics, but this might
+    # lead to database performance issues. So instead, we take the worst
+    # coefficient of determination as model fit quality and store the full data
+    # in Manifold (TODO).
+    return min(metric_dict["coefficient_of_determination"].values())
+
+
+def _model_std_quality(std: np.ndarray) -> float:
+    """Quantifies quality of the model uncertainty. A value of one means the
+    uncertainty is perfectly predictive of the true standard deviation of the error.
+    Values larger than one indicate over-estimation and values smaller than one
+    indicate under-estimation of the true standard deviation of the error. A value
+    of 2 (resp. 1 / 2) represents an over-estimation (resp. under-estimation) of the
+    true standard deviation of the error by a factor of 2.
+
+    Args:
+        std: The standard deviation of the standardized error.
+
+    Returns:
+        The factor corresponding to the worst over- or under-estimation factor of the
+        standard deviation of the error among all experimentally observed metrics.
+    """
+    max_std, min_std = np.max(std), np.min(std)
+    # comparing worst over-estimation factor with worst under-estimation factor
+    inv_model_std_quality = max_std if max_std > 1 / min_std else min_std
+    # reciprocal so that values greater than one indicate over-estimation and
+    # values smaller than one indicate under-estimation of the uncertainty.
+    return 1 / inv_model_std_quality
+
+
 def _predict_on_training_data(
     model_bridge: ModelBridge,
     untransform: bool = False,
diff --git a/ax/modelbridge/model_spec.py b/ax/modelbridge/model_spec.py
index cca8bde87a3..2e86fcf46f3 100644
--- a/ax/modelbridge/model_spec.py
+++ b/ax/modelbridge/model_spec.py
@@ -27,6 +27,7 @@
     cross_validate,
     CVDiagnostics,
     CVResult,
+    get_fit_and_std_quality_and_generalization_dict,
 )
 from ax.modelbridge.registry import ModelRegistryBase
 from ax.utils.common.base import SortableBase
@@ -218,7 +219,21 @@ def gen(self, **model_gen_kwargs: Any) -> GeneratorRun:
             ],
             keywords=get_function_argument_names(fitted_model.gen),
         )
-        return fitted_model.gen(**model_gen_kwargs)
+        generator_run = fitted_model.gen(
+            **model_gen_kwargs,
+        )
+        fit_and_std_quality_and_generalization_dict = (
+            get_fit_and_std_quality_and_generalization_dict(
+                fitted_model_bridge=self.fitted_model,
+            )
+        )
+        generator_run._gen_metadata = (
+            {} if generator_run.gen_metadata is None else generator_run.gen_metadata
+        )
+        generator_run._gen_metadata.update(
+            **fit_and_std_quality_and_generalization_dict
+        )
+        return generator_run
 
     def copy(self) -> ModelSpec:
         """`ModelSpec` is both a spec and an object that performs actions.
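Note (editorial, not part of the patch): the `_model_std_quality` aggregation added above is easiest to see on a toy input. The sketch below mirrors its arithmetic outside of Ax; the `model_std_quality` name is ours, and the input array stands in for the per-metric stds of the standardized error returned by `compute_model_fit_metrics_from_modelbridge`.

import numpy as np

def model_std_quality(std: np.ndarray) -> float:
    # Same arithmetic as _model_std_quality: pick the worse of the worst
    # over-estimation factor (1 / min_std) and the worst under-estimation
    # factor (max_std), then return the reciprocal so that a result > 1
    # means the model's uncertainty is over-estimated.
    max_std, min_std = np.max(std), np.min(std)
    inv = max_std if max_std > 1 / min_std else min_std
    return 1 / inv

# Metric A has std 0.5 (uncertainty over-estimated by 2x); metric B has 1.25
# (under-estimated by 1.25x). The over-estimation is worse, so we get 2.0.
print(model_std_quality(np.array([0.5, 1.25])))  # -> 2.0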
diff --git a/ax/modelbridge/tests/test_model_fit_metrics.py b/ax/modelbridge/tests/test_model_fit_metrics.py
index 1f286de6e27..98747261c8c 100644
--- a/ax/modelbridge/tests/test_model_fit_metrics.py
+++ b/ax/modelbridge/tests/test_model_fit_metrics.py
@@ -19,6 +19,7 @@
     _predict_on_cross_validation_data,
     _predict_on_training_data,
     compute_model_fit_metrics_from_modelbridge,
+    get_fit_and_std_quality_and_generalization_dict,
 )
 from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
 from ax.modelbridge.registry import Models
@@ -27,7 +28,7 @@
 from ax.utils.common.constants import Keys
 from ax.utils.common.testutils import TestCase
 from ax.utils.stats.model_fit_stats import _entropy_via_kde, entropy_of_observations
-from ax.utils.testing.core_stubs import get_branin_search_space
+from ax.utils.testing.core_stubs import get_branin_experiment, get_branin_search_space
 
 
 NUM_SOBOL = 5
@@ -141,3 +142,78 @@ def test_model_fit_metrics(self) -> None:
         self.assertFalse(
             any("Input data is not standardized" in str(w.message) for w in ws)
         )
+
+
+class TestGetFitAndStdQualityAndGeneralizationDict(TestCase):
+    def setUp(self) -> None:
+        super().setUp()
+        self.experiment = get_branin_experiment()
+        self.sobol = Models.SOBOL(search_space=self.experiment.search_space)
+
+    def test_it_returns_empty_data_for_sobol(self) -> None:
+        results = get_fit_and_std_quality_and_generalization_dict(
+            fitted_model_bridge=self.sobol,
+        )
+        expected = {
+            "model_fit_quality": None,
+            "model_std_quality": None,
+            "model_fit_generalization": None,
+            "model_std_generalization": None,
+        }
+        self.assertDictEqual(results, expected)
+
+    def test_it_returns_float_values_when_fit_can_be_evaluated(self) -> None:
+        # GIVEN we have a model whose CV can be evaluated
+        sobol_run = self.sobol.gen(n=20)
+        self.experiment.new_batch_trial().add_generator_run(
+            sobol_run
+        ).run().mark_completed()
+        data = self.experiment.fetch_data()
+        botorch_modelbridge = Models.BOTORCH_MODULAR(
+            experiment=self.experiment, data=data
+        )
+
+        # WHEN we call get_fit_and_std_quality_and_generalization_dict
+        results = get_fit_and_std_quality_and_generalization_dict(
+            fitted_model_bridge=botorch_modelbridge,
+        )
+
+        # THEN we get expected results
+        # CALCULATE EXPECTED RESULTS
+        fit_metrics = compute_model_fit_metrics_from_modelbridge(
+            model_bridge=botorch_modelbridge,
+            generalization=False,
+            untransform=False,
+        )
+        # checking fit metrics
+        r2 = fit_metrics.get("coefficient_of_determination")
+        r2 = cast(Dict[str, float], r2)
+
+        std = fit_metrics.get("std_of_the_standardized_error")
+        std = cast(Dict[str, float], std)
+        std_branin = std["branin"]
+
+        model_std_quality = 1 / std_branin
+
+        # check generalization metrics
+        gen_metrics = compute_model_fit_metrics_from_modelbridge(
+            model_bridge=botorch_modelbridge,
+            generalization=True,
+            untransform=False,
+        )
+        r2_gen = gen_metrics.get("coefficient_of_determination")
+        r2_gen = cast(Dict[str, float], r2_gen)
+        gen_std = gen_metrics.get("std_of_the_standardized_error")
+        gen_std = cast(Dict[str, float], gen_std)
+        gen_std_branin = gen_std["branin"]
+        model_std_generalization = 1 / gen_std_branin
+
+        expected = {
+            "model_fit_quality": min(r2.values()),
+            "model_std_quality": model_std_quality,
+            "model_fit_generalization": min(r2_gen.values()),
+            "model_std_generalization": model_std_generalization,
+        }
+        # END CALCULATE EXPECTED RESULTS
+
+        self.assertDictsAlmostEqual(results, expected)
diff --git a/ax/modelbridge/tests/test_model_spec.py b/ax/modelbridge/tests/test_model_spec.py
index 90c6703f612..2730c951df8 100644
--- a/ax/modelbridge/tests/test_model_spec.py
+++ b/ax/modelbridge/tests/test_model_spec.py
@@ -17,6 +17,7 @@
 from ax.modelbridge.modelbridge_utils import extract_search_space_digest
 from ax.modelbridge.registry import Models
 from ax.utils.common.testutils import TestCase
+from ax.utils.common.typeutils import not_none
 from ax.utils.testing.core_stubs import get_branin_experiment
 from ax.utils.testing.mock import fast_botorch_optimize
 
@@ -148,6 +149,26 @@ def test_fixed_features(self) -> None:
         # pyre-fixme[16]: Optional type has no attribute `__getitem__`.
         self.assertEqual(ms.model_gen_kwargs["fixed_features"], new_features)
 
+    def test_gen_attaches_empty_model_fit_metadata_if_fit_not_applicable(self) -> None:
+        ms = ModelSpec(model_enum=Models.SOBOL)
+        ms.fit(experiment=self.experiment, data=self.data)
+        gr = ms.gen(n=1)
+        gen_metadata = not_none(gr.gen_metadata)
+        self.assertIsNone(gen_metadata["model_fit_quality"])
+        self.assertIsNone(gen_metadata["model_std_quality"])
+        self.assertIsNone(gen_metadata["model_fit_generalization"])
+        self.assertIsNone(gen_metadata["model_std_generalization"])
+
+    def test_gen_attaches_model_fit_metadata_if_applicable(self) -> None:
+        ms = ModelSpec(model_enum=Models.GPEI)
+        ms.fit(experiment=self.experiment, data=self.data)
+        gr = ms.gen(n=1)
+        gen_metadata = not_none(gr.gen_metadata)
+        self.assertIsInstance(gen_metadata["model_fit_quality"], float)
+        self.assertIsInstance(gen_metadata["model_std_quality"], float)
+        self.assertIsInstance(gen_metadata["model_fit_generalization"], float)
+        self.assertIsInstance(gen_metadata["model_std_generalization"], float)
+
 
 class FactoryFunctionModelSpecTest(BaseModelSpecTest):
     def test_construct(self) -> None:
diff --git a/ax/telemetry/scheduler.py b/ax/telemetry/scheduler.py
index 86e649fe2af..018d38e9af6 100644
--- a/ax/telemetry/scheduler.py
+++ b/ax/telemetry/scheduler.py
@@ -11,8 +11,9 @@
 from typing import Any, Dict, Optional
 from warnings import warn
 
-import numpy as np
-from ax.modelbridge.cross_validation import compute_model_fit_metrics_from_modelbridge
+from ax.modelbridge.cross_validation import (
+    get_fit_and_std_quality_and_generalization_dict,
+)
 from ax.service.scheduler import get_fitted_model_bridge, Scheduler
 from ax.telemetry.common import _get_max_transformed_dimensionality
 
@@ -103,10 +104,10 @@ class SchedulerCompletedRecord:
     experiment_completed_record: ExperimentCompletedRecord
 
     best_point_quality: float
-    model_fit_quality: float
-    model_std_quality: float
-    model_fit_generalization: float
-    model_std_generalization: float
+    model_fit_quality: Optional[float]
+    model_std_quality: Optional[float]
+    model_fit_generalization: Optional[float]
+    model_std_generalization: Optional[float]
 
     improvement_over_baseline: float
 
@@ -117,32 +118,19 @@
     def from_scheduler(cls, scheduler: Scheduler) -> SchedulerCompletedRecord:
         try:
             model_bridge = get_fitted_model_bridge(scheduler)
-            model_fit_dict = compute_model_fit_metrics_from_modelbridge(
-                model_bridge=model_bridge,
-                generalization=False,
-                untransform=False,
-            )
-            model_fit_quality = _model_fit_metric(model_fit_dict)
-            # similar for uncertainty quantification, but distance from 1 matters
-            std = list(model_fit_dict["std_of_the_standardized_error"].values())
-            model_std_quality = _model_std_quality(np.array(std))
-
-            # generalization metrics
-            model_gen_dict = compute_model_fit_metrics_from_modelbridge(
-                model_bridge=model_bridge,
-                generalization=True,
-                untransform=False,
+            quality_and_generalizations_dict = (
+                get_fit_and_std_quality_and_generalization_dict(
+                    fitted_model_bridge=model_bridge,
+                )
             )
-            model_fit_generalization = _model_fit_metric(model_gen_dict)
-            gen_std = list(model_gen_dict["std_of_the_standardized_error"].values())
-            model_std_generalization = _model_std_quality(np.array(gen_std))
-
         except Exception as e:
             warn("Encountered exception in computing model fit quality: " + str(e))
-            model_fit_quality = float("nan")
-            model_std_quality = float("nan")
-            model_fit_generalization = float("nan")
-            model_std_generalization = float("nan")
+            quality_and_generalizations_dict = {
+                "model_fit_quality": None,
+                "model_std_quality": None,
+                "model_fit_generalization": None,
+                "model_std_generalization": None,
+            }
 
         try:
             improvement_over_baseline = scheduler.get_improvement_over_baseline()
@@ -158,13 +146,10 @@ def from_scheduler(cls, scheduler: Scheduler) -> SchedulerCompletedRecord:
                 experiment=scheduler.experiment
             ),
             best_point_quality=float("nan"),  # TODO[T147907632]
-            model_fit_quality=model_fit_quality,
-            model_std_quality=model_std_quality,
-            model_fit_generalization=model_fit_generalization,
-            model_std_generalization=model_std_generalization,
             improvement_over_baseline=improvement_over_baseline,
             num_metric_fetch_e_encountered=scheduler._num_metric_fetch_e_encountered,
             num_trials_bad_due_to_err=scheduler._num_trials_bad_due_to_err,
+            **quality_and_generalizations_dict,
         )
 
     def flatten(self) -> Dict[str, Any]:
@@ -179,35 +164,3 @@ def flatten(self) -> Dict[str, Any]:
             **self_dict,
             **experiment_completed_record_dict,
         }
-
-
-def _model_fit_metric(metric_dict: Dict[str, Dict[str, float]]) -> float:
-    # We'd ideally log the entire `model_fit_dict` as a single model fit metric
-    # can't capture the nuances of multiple experimental metrics, but this might
-    # lead to database performance issues. So instead, we take the worst
-    # coefficient of determination as model fit quality and store the full data
-    # in Manifold (TODO).
-    return min(metric_dict["coefficient_of_determination"].values())
-
-
-def _model_std_quality(std: np.ndarray) -> float:
-    """Quantifies quality of the model uncertainty. A value of one means the
-    uncertainty is perfectly predictive of the true standard deviation of the error.
-    Values larger than one indicate over-estimation and negative values indicate
-    under-estimation of the true standard deviation of the error. In particular, a value
-    of 2 (resp. 1 / 2) represents an over-estimation (resp. under-estimation) of the
-    true standard deviation of the error by a factor of 2.
-
-    Args:
-        std: The standard deviation of the standardized error.
-
-    Returns:
-        The factor corresponding to the worst over- or under-estimation factor of the
-        standard deviation of the error among all experimentally observed metrics.
-    """
-    max_std, min_std = np.max(std), np.min(std)
-    # comparing worst over-estimation factor with worst under-estimation factor
-    inv_model_std_quality = max_std if max_std > 1 / min_std else min_std
-    # reciprocal so that values greater than one indicate over-estimation and
-    # values smaller than indicate underestimation of the uncertainty.
-    return 1 / inv_model_std_quality
diff --git a/ax/telemetry/tests/test_scheduler.py b/ax/telemetry/tests/test_scheduler.py
index 0ff8b31bd56..a808b49e767 100644
--- a/ax/telemetry/tests/test_scheduler.py
+++ b/ax/telemetry/tests/test_scheduler.py
@@ -101,10 +101,10 @@ def test_scheduler_completed_record_from_scheduler(self) -> None:
                 experiment=scheduler.experiment
             ),
             best_point_quality=float("nan"),
-            model_fit_quality=float("nan"),  # nan because no model has been fit
-            model_std_quality=float("nan"),
-            model_fit_generalization=float("nan"),
-            model_std_generalization=float("nan"),
+            model_fit_quality=None,  # None because no model has been fit
+            model_std_quality=None,
+            model_fit_generalization=None,
+            model_std_generalization=None,
             improvement_over_baseline=5.0,
             num_metric_fetch_e_encountered=0,
             num_trials_bad_due_to_err=0,
@@ -117,10 +117,10 @@
                 experiment=scheduler.experiment
             ).__dict__,
             "best_point_quality": float("nan"),
-            "model_fit_quality": float("nan"),
-            "model_std_quality": float("nan"),
-            "model_fit_generalization": float("nan"),
-            "model_std_generalization": float("nan"),
+            "model_fit_quality": None,
+            "model_std_quality": None,
+            "model_fit_generalization": None,
+            "model_std_generalization": None,
             "improvement_over_baseline": 5.0,
             "num_metric_fetch_e_encountered": 0,
             "num_trials_bad_due_to_err": 0,
@@ -272,7 +272,9 @@ def _compare_scheduler_completed_records(
         for field in numeric_fields:
             rec_field = getattr(record, field)
             exp_field = getattr(expected, field)
-            if np.isnan(rec_field):
-                self.assertTrue(np.isnan(exp_field))
+            if rec_field is None:
+                self.assertIsNone(exp_field, msg=field)
+            elif np.isnan(rec_field):
+                self.assertTrue(np.isnan(exp_field), msg=field)
             else:
-                self.assertAlmostEqual(rec_field, exp_field)
+                self.assertAlmostEqual(rec_field, exp_field, msg=field)
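Note (editorial, not part of the patch): after this change, any GeneratorRun produced through `ModelSpec.gen` carries the four fit-quality keys in its `gen_metadata`, so telemetry can read them without refitting. A minimal usage sketch, mirroring the Branin setup in the tests above (the 20-arm Sobol batch is illustrative):

from ax.modelbridge.model_spec import ModelSpec
from ax.modelbridge.registry import Models
from ax.utils.testing.core_stubs import get_branin_experiment

# Seed the experiment with a completed Sobol batch, as in the tests above.
experiment = get_branin_experiment()
sobol = Models.SOBOL(search_space=experiment.search_space)
experiment.new_batch_trial().add_generator_run(sobol.gen(n=20)).run().mark_completed()

# Fit a GP-based ModelSpec and generate; gen() now attaches the fit metadata.
ms = ModelSpec(model_enum=Models.GPEI)
ms.fit(experiment=experiment, data=experiment.fetch_data())
gr = ms.gen(n=1)
for key in (
    "model_fit_quality",
    "model_std_quality",
    "model_fit_generalization",
    "model_std_generalization",
):
    print(key, (gr.gen_metadata or {}).get(key))  # floats here; None under e.g. Sobol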