Unify interface and data handling of model training and generalization metrics #2367

Closed
40 changes: 27 additions & 13 deletions ax/modelbridge/cross_validation.py
@@ -19,7 +19,7 @@

 import numpy as np
 from ax.core.experiment import Experiment
-from ax.core.observation import Observation, ObservationData, observations_from_data
+from ax.core.observation import Observation, ObservationData, recombine_observations
 from ax.core.optimization_config import OptimizationConfig
 from ax.modelbridge.base import ModelBridge, unwrap_observation_data
 from ax.utils.common.logger import get_logger
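
The swapped import above is what enables the new `untransform` path in `_predict_on_training_data` below: predictions made in transformed space have to be reassembled into `Observation` objects before the reverse transforms can be applied. A minimal sketch of the semantics implied by the call site in this diff (the real `ax.core.observation.recombine_observations` may differ in validation and arm-name handling):

```python
# Hypothetical sketch of recombine_observations, based only on how it is
# called in this diff: zip transformed ObservationFeatures back together with
# prediction ObservationData into full Observation objects.
from typing import List

from ax.core.observation import Observation, ObservationData, ObservationFeatures


def recombine_observations_sketch(
    observation_features: List[ObservationFeatures],
    observation_data: List[ObservationData],
) -> List[Observation]:
    if len(observation_features) != len(observation_data):
        raise ValueError("Need one ObservationData per ObservationFeatures.")
    return [
        Observation(features=f, data=d)
        for f, d in zip(observation_features, observation_data)
    ]
```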
@@ -515,6 +515,7 @@ def compute_model_fit_metrics_from_modelbridge(
             before calculating the model fit metrics. False by default as models
             are trained in transformed space and model fit should be
             evaluated in transformed space.
+
     Returns:
         A nested dictionary mapping from the *model fit* metric names and the
         *experimental metric* names to the values of the model fit metrics.
@@ -528,12 +529,13 @@ def compute_model_fit_metrics_from_modelbridge(
             `coefficient of determination of the test error predictions`
             ```
     """
-    y_obs, y_pred, se_pred = (
-        _predict_on_cross_validation_data(
-            model_bridge=model_bridge, untransform=untransform
-        )
+    predict_func = (
+        _predict_on_cross_validation_data
         if generalization
-        else _predict_on_training_data(model_bridge=model_bridge, experiment=experiment)
+        else _predict_on_training_data
     )
+    y_obs, y_pred, se_pred = predict_func(
+        model_bridge=model_bridge, untransform=untransform
+    )
     if fit_metrics_dict is None:
         fit_metrics_dict = {
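
This refactor is possible because `_predict_on_training_data` now shares the exact `(model_bridge, untransform)` signature of `_predict_on_cross_validation_data` (see the signature change in the next hunk). Selecting the callable first and keeping a single call site means the two branches' argument lists can no longer drift apart. A standalone sketch of the pattern with toy stand-in functions:

```python
# Minimal illustration of the pattern used above: once both branches share a
# signature, pick the callable first and make a single call.
from typing import Callable, Tuple


def on_training_data(x: float, shift: bool = False) -> Tuple[float, float]:
    # Toy stand-in for _predict_on_training_data: returns (mean, sem).
    return (x + 1.0 if shift else x, 0.1)


def on_cross_validation_data(x: float, shift: bool = False) -> Tuple[float, float]:
    # Toy stand-in for _predict_on_cross_validation_data.
    return (x - 1.0 if shift else x, 0.2)


def evaluate(x: float, generalization: bool, shift: bool = False) -> Tuple[float, float]:
    # Pick the callable once; both must accept identical keyword arguments.
    predict_func: Callable[..., Tuple[float, float]] = (
        on_cross_validation_data if generalization else on_training_data
    )
    # Single call site: the argument list only has to change in one place.
    return predict_func(x, shift=shift)


print(evaluate(3.0, generalization=True, shift=True))   # (2.0, 0.2)
print(evaluate(3.0, generalization=False, shift=True))  # (4.0, 0.1)
```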
@@ -552,7 +554,7 @@

 def _predict_on_training_data(
     model_bridge: ModelBridge,
-    experiment: Experiment,
+    untransform: bool = False,
 ) -> Tuple[
     Dict[str, np.ndarray],
     Dict[str, np.ndarray],
Expand All @@ -566,16 +568,17 @@ def _predict_on_training_data(

Args:
model_bridge: A ModelBridge object with which to make predictions.
experiment: The experiment with whose data to compute the model fit metrics.
untransform: Boolean indicating whether to untransform model predictions.

Returns:
A tuple containing three dictionaries for 1) observed metric values, and the
model's associated 2) predictive means and 3) predictive standard deviations.
"""
data = experiment.lookup_data()
observations = observations_from_data(
experiment=experiment, data=data
) # List[Observation]
observations = model_bridge.get_training_data() # List[Observation]

# NOTE: the following up to the end of the untransform block could be replaced
# with model_bridge's public predict / private _batch_predict method, if the
# latter had a boolean untransform flag.

# Transform observations -- this will transform both obs data and features
for t in model_bridge.transforms.values():
@@ -586,12 +589,23 @@ def _predict_on_training_data(
     # Make predictions in transformed space
     observation_data_pred = model_bridge._predict(observation_features)

+    if untransform:
+        # Apply reverse transforms, in reverse order
+        pred_observations = recombine_observations(
+            observation_features=observation_features,
+            observation_data=observation_data_pred,
+        )
+        for t in reversed(list(model_bridge.transforms.values())):
+            pred_observations = t.untransform_observations(pred_observations)
+
+        observation_data_pred = [obs.data for obs in pred_observations]
+
     mean_predicted, cov_predicted = unwrap_observation_data(observation_data_pred)
     mean_observed = [
         obs.data.means_dict for obs in observations
     ]  # List[Dict[str, float]]

-    metric_names = list(data.metric_names)
+    metric_names = observations[0].data.metric_names
     mean_observed = _list_of_dicts_to_dict_of_lists(
         list_of_dicts=mean_observed, keys=metric_names
    )
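
The added `untransform` block follows the usual transform-stack discipline: transforms are applied in registration order on the way in, so they must be inverted in reverse order on the way out. A toy sketch of that round trip (Ax's real `Transform` API works on `Observation` objects rather than bare floats):

```python
# Toy sketch of the transform/untransform round trip implemented above,
# assuming only that each transform has inverse-consistent forward and
# backward methods.
from typing import List


class Shift:
    def __init__(self, offset: float) -> None:
        self.offset = offset

    def transform(self, values: List[float]) -> List[float]:
        return [v + self.offset for v in values]

    def untransform(self, values: List[float]) -> List[float]:
        return [v - self.offset for v in values]


transforms = [Shift(1.0), Shift(10.0)]

values = [0.0, 2.0]
for t in transforms:  # forward, in registration order
    values = t.transform(values)

# ... a model would predict in this transformed space ...

for t in reversed(transforms):  # backward, in reverse order
    values = t.untransform(values)

print(values)  # [0.0, 2.0] -- the round trip is the identity
```

Judging by its name and call site, `_list_of_dicts_to_dict_of_lists` then pivots the per-observation `means_dict` entries, e.g. `[{"branin": 1.0}, {"branin": 2.0}]` into `{"branin": [1.0, 2.0]}`, keyed by `metric_names`.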
1 change: 0 additions & 1 deletion ax/modelbridge/generation_strategy.py
@@ -717,7 +717,6 @@ def _fit_current_model(self, data: Optional[Data]) -> None:
         # model state from last generator run and pass it to the model
         # being instantiated in this function.
         model_state_on_lgr = self._get_model_state_from_last_generator_run()
-
         if not data.df.empty:
             trial_indices_in_data = sorted(data.df["trial_index"].unique())
             logger.debug(f"Fitting model with data for trials: {trial_indices_in_data}")
95 changes: 54 additions & 41 deletions ax/modelbridge/tests/test_model_fit_metrics.py
@@ -7,6 +7,7 @@
 # pyre-strict

 import warnings
+from itertools import product
 from typing import cast, Dict

 import numpy as np
@@ -17,6 +18,7 @@
 from ax.metrics.branin import BraninMetric
 from ax.modelbridge.cross_validation import (
     _predict_on_cross_validation_data,
+    _predict_on_training_data,
     compute_model_fit_metrics_from_modelbridge,
 )
 from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
@@ -67,7 +69,12 @@ def test_model_fit_metrics(self) -> None:
         )
         # need to run some trials to initialize the ModelBridge
         scheduler.run_n_trials(max_trials=NUM_SOBOL + 1)
+
         model_bridge = get_fitted_model_bridge(scheduler)
+        self.assertEqual(len(model_bridge.get_training_data()), NUM_SOBOL)
+
+        model_bridge = get_fitted_model_bridge(scheduler, force_refit=True)
+        self.assertEqual(len(model_bridge.get_training_data()), NUM_SOBOL + 1)

         # testing compute_model_fit_metrics_from_modelbridge with default metrics
         fit_metrics = compute_model_fit_metrics_from_modelbridge(
@@ -90,45 +97,51 @@
         self.assertIsInstance(std_branin, float)

         # checking non-default model-fit-metric
-        untransform = False
-        fit_metrics = compute_model_fit_metrics_from_modelbridge(
-            model_bridge=model_bridge,
-            experiment=scheduler.experiment,
-            generalization=True,
-            untransform=untransform,
-            fit_metrics_dict={"Entropy": entropy_of_observations},
-        )
-        entropy = fit_metrics.get("Entropy")
-        self.assertIsInstance(entropy, dict)
-        entropy = cast(Dict[str, float], entropy)
-        self.assertTrue("branin" in entropy)
-        entropy_branin = entropy["branin"]
-        self.assertIsInstance(entropy_branin, float)
-
-        y_obs, _, _ = _predict_on_cross_validation_data(
-            model_bridge=model_bridge, untransform=untransform
-        )
-        y_obs_branin = np.array(y_obs["branin"])[:, np.newaxis]
-        entropy_truth = _entropy_via_kde(y_obs_branin)
-        self.assertAlmostEqual(entropy_branin, entropy_truth)
+        for untransform, generalization in product([True, False], [True, False]):
+            with self.subTest(untransform=untransform):
+                fit_metrics = compute_model_fit_metrics_from_modelbridge(
+                    model_bridge=model_bridge,
+                    experiment=scheduler.experiment,
+                    generalization=generalization,
+                    untransform=untransform,
+                    fit_metrics_dict={"Entropy": entropy_of_observations},
+                )
+                entropy = fit_metrics.get("Entropy")
+                self.assertIsInstance(entropy, dict)
+                entropy = cast(Dict[str, float], entropy)
+                self.assertTrue("branin" in entropy)
+                entropy_branin = entropy["branin"]
+                self.assertIsInstance(entropy_branin, float)
+
+                predict = (
+                    _predict_on_cross_validation_data
+                    if generalization
+                    else _predict_on_training_data
+                )
+                y_obs, _, _ = predict(
+                    model_bridge=model_bridge, untransform=untransform
+                )
+                y_obs_branin = np.array(y_obs["branin"])[:, np.newaxis]
+                entropy_truth = _entropy_via_kde(y_obs_branin)
+                self.assertAlmostEqual(entropy_branin, entropy_truth)
+
+                # testing with empty metrics
+                empty_metrics = compute_model_fit_metrics_from_modelbridge(
+                    model_bridge=model_bridge,
+                    experiment=self.branin_experiment,
+                    fit_metrics_dict={},
+                )
+                self.assertIsInstance(empty_metrics, dict)
+                self.assertTrue(len(empty_metrics) == 0)
+
+                # testing log filtering
+                with warnings.catch_warnings(record=True) as ws:
+                    fit_metrics = compute_model_fit_metrics_from_modelbridge(
+                        model_bridge=model_bridge,
+                        experiment=self.branin_experiment,
+                        untransform=untransform,
+                        generalization=generalization,
+                    )
+                self.assertFalse(
+                    any("Input data is not standardized" in str(w.message) for w in ws)
+                )
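
The loop cross-checks the reported `Entropy` metric against `_entropy_via_kde` applied directly to the observed values. A sketch of a standard KDE-based differential entropy estimate of this kind, assuming the common resubstitution estimator H ≈ -mean(log p̂(yᵢ)); Ax's actual helper may differ in bandwidth choice and other details:

```python
# Sketch of a KDE-based differential entropy estimate, in the spirit of the
# _entropy_via_kde helper the test compares against (assumption: the
# resubstitution estimator; not necessarily Ax's exact implementation).
import numpy as np
from scipy.stats import gaussian_kde


def entropy_via_kde_sketch(y: np.ndarray) -> float:
    """Estimate differential entropy of samples y with shape (n, d)."""
    kde = gaussian_kde(y.T)  # gaussian_kde expects shape (d, n)
    log_density = np.log(kde(y.T))  # density evaluated at the sample points
    return float(-np.mean(log_density))


rng = np.random.default_rng(0)
samples = rng.normal(size=(256, 1))
# For a standard normal, the true differential entropy is
# 0.5 * log(2 * pi * e) ~= 1.419; the estimate should land near that.
print(entropy_via_kde_sketch(samples))
```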
9 changes: 6 additions & 3 deletions ax/service/scheduler.py
@@ -2112,19 +2112,22 @@ def _get_failure_rate_exceeded_error(
     )


-def get_fitted_model_bridge(scheduler: Scheduler) -> ModelBridge:
+def get_fitted_model_bridge(
+    scheduler: Scheduler, force_refit: bool = False
+) -> ModelBridge:
     """Returns a fitted ModelBridge object. If the model is fit already, directly
     returns the already fitted model. Otherwise, fits and returns a new one.

     Args:
         scheduler: The scheduler object from which to get the fitted model.
+        force_refit: If True, will force a data lookup and a refit of the model.

     Returns:
         A ModelBridge object fitted to the observations of the scheduler's experiment.
     """
     gs = scheduler.standard_generation_strategy
     model_bridge = gs.model  # Optional[ModelBridge]
-    if model_bridge is None:  # Need to re-fit the model.
-        gs._fit_current_model(data=None)  # Will lookup_data if it none is provided.
+    if model_bridge is None or force_refit:  # Need to re-fit the model.
+        gs._fit_current_model(data=None)  # Will lookup_data if none is provided.
     model_bridge = cast(ModelBridge, gs.model)
     return model_bridge
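
Putting the new pieces together, a usage sketch under the same setup as `test_model_fit_metrics`. The scheduler construction itself is assumed, and the `coefficient_of_determination` key is assumed to be among Ax's default fit metrics; everything else appears in this diff:

```python
# Usage sketch: after a new trial completes, the model cached on the
# generation strategy can be stale, so force_refit=True (added in this PR)
# triggers a fresh data lookup and refit before computing fit metrics.
from ax.modelbridge.cross_validation import (
    compute_model_fit_metrics_from_modelbridge,
)
from ax.service.scheduler import Scheduler, get_fitted_model_bridge


def r_squared_after_refit(scheduler: Scheduler) -> float:
    model_bridge = get_fitted_model_bridge(scheduler, force_refit=True)
    fit_metrics = compute_model_fit_metrics_from_modelbridge(
        model_bridge=model_bridge,
        experiment=scheduler.experiment,
        generalization=False,  # in-sample fit on the training data
        untransform=False,  # stay in the transformed model space
    )
    # Key name assumed from Ax's default fit-metrics dict.
    return fit_metrics["coefficient_of_determination"]["branin"]
```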