Put model fit data in gen_metadata (#2511)
Summary: Pull Request resolved: #2511

Differential Revision: D58261582
Daniel Cohen authored and facebook-github-bot committed Jun 12, 2024
1 parent 0dce67c commit 945d83d
Showing 6 changed files with 147 additions and 79 deletions.
14 changes: 12 additions & 2 deletions ax/modelbridge/base.py
@@ -33,7 +33,13 @@
from ax.core.optimization_config import OptimizationConfig
from ax.core.parameter import ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.core.types import TCandidateMetadata, TModelCov, TModelMean, TModelPredict
from ax.core.types import (
TCandidateMetadata,
TGenMetadata,
TModelCov,
TModelMean,
TModelPredict,
)
from ax.exceptions.core import UnsupportedError, UserInputError
from ax.modelbridge.transforms.base import Transform
from ax.modelbridge.transforms.cast import Cast
@@ -762,6 +768,7 @@ def gen(
pending_observations: Optional[Dict[str, List[ObservationFeatures]]] = None,
fixed_features: Optional[ObservationFeatures] = None,
model_gen_options: Optional[TConfig] = None,
extra_gen_metadata: Optional[TGenMetadata] = None,
) -> GeneratorRun:
"""
Generate new points from the underlying model according to
@@ -876,7 +883,10 @@ def gen(
model_key=self._model_key,
model_kwargs=self._model_kwargs,
bridge_kwargs=self._bridge_kwargs,
gen_metadata=gen_results.gen_metadata,
gen_metadata={
**gen_results.gen_metadata,
**(extra_gen_metadata or {}),
},
model_state_after_gen=self._get_serialized_model_state(),
candidate_metadata_by_arm_signature=candidate_metadata,
)
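The merged `gen_metadata` is plain dict unpacking, so on a key collision the caller-supplied `extra_gen_metadata` entry wins because it is unpacked last. A minimal sketch with hypothetical stand-in keys and values (not the real metadata produced by a model):

    gen_results_metadata = {"num_candidates": 3, "model_fit_quality": 0.2}
    extra_gen_metadata = {"model_fit_quality": 0.9}
    merged = {**gen_results_metadata, **(extra_gen_metadata or {})}
    assert merged == {"num_candidates": 3, "model_fit_quality": 0.9}  # extra entry overrides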
73 changes: 73 additions & 0 deletions ax/modelbridge/cross_validation.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

# pyre-strict
from __future__ import annotations

import warnings
from abc import ABC, abstractmethod
@@ -16,6 +17,7 @@
from logging import Logger
from numbers import Number
from typing import Any, Callable, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple
from warnings import warn

import numpy as np
from ax.core.observation import Observation, ObservationData, recombine_observations
@@ -492,6 +494,45 @@ def best_diagnostic(self, diagnostics: List[CVDiagnostics]) -> int:
"""


def get_fit_and_std_quality_and_generalization_dict(
fitted_model_bridge: ModelBridge,
) -> Dict[str, Optional[float]]:
"""
Get model fit quality and generalization statistics from a fitted ModelBridge
for analytics purposes.
"""
try:
model_fit_dict = compute_model_fit_metrics_from_modelbridge(
model_bridge=fitted_model_bridge,
generalization=False,
untransform=False,
)
# for uncertainty quantification, it is the distance from 1 that matters
std = list(model_fit_dict["std_of_the_standardized_error"].values())

# generalization metrics
model_gen_dict = compute_model_fit_metrics_from_modelbridge(
model_bridge=fitted_model_bridge,
generalization=True,
untransform=False,
)
gen_std = list(model_gen_dict["std_of_the_standardized_error"].values())
return {
"model_fit_quality": _model_fit_metric(model_fit_dict),
"model_std_quality": _model_std_quality(np.array(std)),
"model_fit_generalization": _model_fit_metric(model_gen_dict),
"model_std_generalization": _model_std_quality(np.array(gen_std)),
}

except Exception as e:
warn("Encountered exception in computing model fit quality: " + str(e))
return {
"model_fit_quality": None,
"model_std_quality": None,
"model_fit_generalization": None,
"model_std_generalization": None,
}
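A usage sketch, assuming `model_bridge` is a ModelBridge already fit to experiment data; the values in the comment are illustrative only, and non-predicting models (e.g. Sobol) take the exception path and yield all-None entries:

    diagnostics = get_fit_and_std_quality_and_generalization_dict(
        fitted_model_bridge=model_bridge,
    )
    # e.g. {"model_fit_quality": 0.87, "model_std_quality": 0.95,
    #       "model_fit_generalization": 0.71, "model_std_generalization": 1.08}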


def compute_model_fit_metrics_from_modelbridge(
model_bridge: ModelBridge,
fit_metrics_dict: Optional[Dict[str, ModelFitMetricProtocol]] = None,
@@ -550,6 +591,38 @@ def compute_model_fit_metrics_from_modelbridge(
)


def _model_fit_metric(metric_dict: Dict[str, Dict[str, float]]) -> float:
# We'd ideally log the entire `model_fit_dict`, as a single model fit metric
# can't capture the nuances of multiple experimental metrics, but this might
# lead to database performance issues. So instead, we take the worst
# coefficient of determination as model fit quality and store the full data
# in Manifold (TODO).
return min(metric_dict["coefficient_of_determination"].values())
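A worked example with illustrative values: of two experimentally observed metrics, the smaller (worse) coefficient of determination is the one reported:

    metric_dict = {"coefficient_of_determination": {"metric_a": 0.9, "metric_b": 0.4}}
    assert _model_fit_metric(metric_dict) == 0.4  # worst-case R^2 across metrics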


def _model_std_quality(std: np.ndarray) -> float:
"""Quantifies quality of the model uncertainty. A value of one means the
uncertainty is perfectly predictive of the true standard deviation of the error.
Values larger than one indicate over-estimation and negative values indicate
under-estimation of the true standard deviation of the error. In particular, a value
of 2 (resp. 1 / 2) represents an over-estimation (resp. under-estimation) of the
true standard deviation of the error by a factor of 2.
Args:
std: The standard deviation of the standardized error.
Returns:
The factor corresponding to the worst over- or under-estimation factor of the
standard deviation of the error among all experimentally observed metrics.
"""
max_std, min_std = np.max(std), np.min(std)
# comparing worst over-estimation factor with worst under-estimation factor
inv_model_std_quality = max_std if max_std > 1 / min_std else min_std
# reciprocal so that values greater than one indicate over-estimation and
# values smaller than one indicate under-estimation of the uncertainty.
return 1 / inv_model_std_quality
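A worked example with illustrative values, using the `np` already imported in this module: for standardized-error standard deviations of 0.5 and 1.25, the worst over-estimation factor is 1 / min_std = 2.0 and the worst under-estimation factor is max_std = 1.25; the over-estimation dominates, so inv_model_std_quality = 0.5 and the function returns 2.0:

    std = np.array([0.5, 1.25])
    assert _model_std_quality(std) == 2.0  # uncertainty over-estimated by 2x at worst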


def _predict_on_training_data(
model_bridge: ModelBridge,
untransform: bool = False,
11 changes: 10 additions & 1 deletion ax/modelbridge/model_spec.py
@@ -27,6 +27,7 @@
cross_validate,
CVDiagnostics,
CVResult,
get_fit_and_std_quality_and_generalization_dict,
)
from ax.modelbridge.registry import ModelRegistryBase
from ax.utils.common.base import SortableBase
@@ -218,7 +219,15 @@ def gen(self, **model_gen_kwargs: Any) -> GeneratorRun:
],
keywords=get_function_argument_names(fitted_model.gen),
)
return fitted_model.gen(**model_gen_kwargs)
fit_and_std_quality_and_generalization_dict = (
get_fit_and_std_quality_and_generalization_dict(
fitted_model_bridge=self.fitted_model,
)
)
return fitted_model.gen(
**model_gen_kwargs,
extra_gen_metadata=fit_and_std_quality_and_generalization_dict,
)

def copy(self) -> ModelSpec:
"""`ModelSpec` is both a spec and an object that performs actions.
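The net effect, sketched with a fitted `ModelSpec` named `ms` as in the tests below: the four fit statistics ride along in the returned GeneratorRun's `gen_metadata`:

    gr = ms.gen(n=1)
    fit_quality = gr.gen_metadata["model_fit_quality"]
    # float for predictive models (e.g. GPEI); None when fit metrics are unavailable (e.g. Sobol)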
21 changes: 21 additions & 0 deletions ax/modelbridge/tests/test_model_spec.py
@@ -17,6 +17,7 @@
from ax.modelbridge.modelbridge_utils import extract_search_space_digest
from ax.modelbridge.registry import Models
from ax.utils.common.testutils import TestCase
from ax.utils.common.typeutils import not_none
from ax.utils.testing.core_stubs import get_branin_experiment
from ax.utils.testing.mock import fast_botorch_optimize

@@ -148,6 +149,26 @@ def test_fixed_features(self) -> None:
# pyre-fixme[16]: Optional type has no attribute `__getitem__`.
self.assertEqual(ms.model_gen_kwargs["fixed_features"], new_features)

def test_gen_attaches_empty_model_fit_metadata_if_fit_not_applicable(self) -> None:
ms = ModelSpec(model_enum=Models.SOBOL)
ms.fit(experiment=self.experiment, data=self.data)
gr = ms.gen(n=1)
gen_metadata = not_none(gr.gen_metadata)
self.assertEqual(gen_metadata["model_fit_quality"], None)
self.assertEqual(gen_metadata["model_std_quality"], None)
self.assertEqual(gen_metadata["model_fit_generalization"], None)
self.assertEqual(gen_metadata["model_std_generalization"], None)

def test_gen_attaches_model_fit_metadata_if_applicable(self) -> None:
ms = ModelSpec(model_enum=Models.GPEI)
ms.fit(experiment=self.experiment, data=self.data)
gr = ms.gen(n=1)
gen_metadata = not_none(gr.gen_metadata)
self.assertIsInstance(gen_metadata["model_fit_quality"], float)
self.assertIsInstance(gen_metadata["model_std_quality"], float)
self.assertIsInstance(gen_metadata["model_fit_generalization"], float)
self.assertIsInstance(gen_metadata["model_std_generalization"], float)


class FactoryFunctionModelSpecTest(BaseModelSpecTest):
def test_construct(self) -> None:
83 changes: 18 additions & 65 deletions ax/telemetry/scheduler.py
@@ -11,8 +11,9 @@
from typing import Any, Dict, Optional
from warnings import warn

import numpy as np
from ax.modelbridge.cross_validation import compute_model_fit_metrics_from_modelbridge
from ax.modelbridge.cross_validation import (
get_fit_and_std_quality_and_generalization_dict,
)

from ax.service.scheduler import get_fitted_model_bridge, Scheduler
from ax.telemetry.common import _get_max_transformed_dimensionality
@@ -103,10 +104,10 @@ class SchedulerCompletedRecord:
experiment_completed_record: ExperimentCompletedRecord

best_point_quality: float
model_fit_quality: float
model_std_quality: float
model_fit_generalization: float
model_std_generalization: float
model_fit_quality: Optional[float]
model_std_quality: Optional[float]
model_fit_generalization: Optional[float]
model_std_generalization: Optional[float]

improvement_over_baseline: float

@@ -117,32 +118,19 @@
def from_scheduler(cls, scheduler: Scheduler) -> SchedulerCompletedRecord:
try:
model_bridge = get_fitted_model_bridge(scheduler)
model_fit_dict = compute_model_fit_metrics_from_modelbridge(
model_bridge=model_bridge,
generalization=False,
untransform=False,
)
model_fit_quality = _model_fit_metric(model_fit_dict)
# similar for uncertainty quantification, but distance from 1 matters
std = list(model_fit_dict["std_of_the_standardized_error"].values())
model_std_quality = _model_std_quality(np.array(std))

# generalization metrics
model_gen_dict = compute_model_fit_metrics_from_modelbridge(
model_bridge=model_bridge,
generalization=True,
untransform=False,
quality_and_generalizations_dict = (
get_fit_and_std_quality_and_generalization_dict(
fitted_model_bridge=model_bridge,
)
)
model_fit_generalization = _model_fit_metric(model_gen_dict)
gen_std = list(model_gen_dict["std_of_the_standardized_error"].values())
model_std_generalization = _model_std_quality(np.array(gen_std))

except Exception as e:
warn("Encountered exception in computing model fit quality: " + str(e))
model_fit_quality = float("nan")
model_std_quality = float("nan")
model_fit_generalization = float("nan")
model_std_generalization = float("nan")
quality_and_generalizations_dict = {
"model_fit_quality": None,
"model_std_quality": None,
"model_fit_generalization": None,
"model_std_generalization": None,
}

try:
improvement_over_baseline = scheduler.get_improvement_over_baseline()
@@ -158,13 +146,10 @@ def from_scheduler(cls, scheduler: Scheduler) -> SchedulerCompletedRecord:
experiment=scheduler.experiment
),
best_point_quality=float("nan"), # TODO[T147907632]
model_fit_quality=model_fit_quality,
model_std_quality=model_std_quality,
model_fit_generalization=model_fit_generalization,
model_std_generalization=model_std_generalization,
improvement_over_baseline=improvement_over_baseline,
num_metric_fetch_e_encountered=scheduler._num_metric_fetch_e_encountered,
num_trials_bad_due_to_err=scheduler._num_trials_bad_due_to_err,
**quality_and_generalizations_dict,
)

def flatten(self) -> Dict[str, Any]:
Expand All @@ -179,35 +164,3 @@ def flatten(self) -> Dict[str, Any]:
**self_dict,
**experiment_completed_record_dict,
}


def _model_fit_metric(metric_dict: Dict[str, Dict[str, float]]) -> float:
# We'd ideally log the entire `model_fit_dict`, as a single model fit metric
# can't capture the nuances of multiple experimental metrics, but this might
# lead to database performance issues. So instead, we take the worst
# coefficient of determination as model fit quality and store the full data
# in Manifold (TODO).
return min(metric_dict["coefficient_of_determination"].values())


def _model_std_quality(std: np.ndarray) -> float:
"""Quantifies quality of the model uncertainty. A value of one means the
uncertainty is perfectly predictive of the true standard deviation of the error.
Values larger than one indicate over-estimation and negative values indicate
under-estimation of the true standard deviation of the error. In particular, a value
of 2 (resp. 1 / 2) represents an over-estimation (resp. under-estimation) of the
true standard deviation of the error by a factor of 2.
Args:
std: The standard deviation of the standardized error.
Returns:
The factor corresponding to the worst over- or under-estimation factor of the
standard deviation of the error among all experimentally observed metrics.
"""
max_std, min_std = np.max(std), np.min(std)
# comparing worst over-estimation factor with worst under-estimation factor
inv_model_std_quality = max_std if max_std > 1 / min_std else min_std
# reciprocal so that values greater than one indicate over-estimation and
# values smaller than one indicate under-estimation of the uncertainty.
return 1 / inv_model_std_quality
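A sketch of consuming the record downstream, assuming `scheduler` is a completed Scheduler instance:

    record = SchedulerCompletedRecord.from_scheduler(scheduler)
    flat = record.flatten()
    flat["model_std_quality"]  # Optional[float]; None when no predictive model was fit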
24 changes: 13 additions & 11 deletions ax/telemetry/tests/test_scheduler.py
@@ -101,10 +101,10 @@ def test_scheduler_completed_record_from_scheduler(self) -> None:
experiment=scheduler.experiment
),
best_point_quality=float("nan"),
model_fit_quality=float("nan"), # nan because no model has been fit
model_std_quality=float("nan"),
model_fit_generalization=float("nan"),
model_std_generalization=float("nan"),
model_fit_quality=None,  # None because no model has been fit
model_std_quality=None,
model_fit_generalization=None,
model_std_generalization=None,
improvement_over_baseline=5.0,
num_metric_fetch_e_encountered=0,
num_trials_bad_due_to_err=0,
@@ -117,10 +117,10 @@ def test_scheduler_completed_record_from_scheduler(self) -> None:
experiment=scheduler.experiment
).__dict__,
"best_point_quality": float("nan"),
"model_fit_quality": float("nan"),
"model_std_quality": float("nan"),
"model_fit_generalization": float("nan"),
"model_std_generalization": float("nan"),
"model_fit_quality": None,
"model_std_quality": None,
"model_fit_generalization": None,
"model_std_generalization": None,
"improvement_over_baseline": 5.0,
"num_metric_fetch_e_encountered": 0,
"num_trials_bad_due_to_err": 0,
@@ -272,7 +272,9 @@ def _compare_scheduler_completed_records(
for field in numeric_fields:
rec_field = getattr(record, field)
exp_field = getattr(expected, field)
if np.isnan(rec_field):
self.assertTrue(np.isnan(exp_field))
if rec_field is None:
self.assertIsNone(exp_field, msg=field)
elif np.isnan(rec_field):
self.assertTrue(np.isnan(exp_field), msg=field)
else:
self.assertAlmostEqual(rec_field, exp_field)
self.assertAlmostEqual(rec_field, exp_field, msg=field)
