Skip to content

Commit d89ff88

Browse files
SebastianAment authored and facebook-github-bot committed
Entropy of observations metric (facebook#2340)
Summary: This commit introduces `entropy_of_observations` as a model fit metric. It quantifies the entropy of the outcomes `y_obs` using a kernel density estimator. This metric can be useful in detecting datasets in which the outcomes are clustered (implying a low entropy), rather than uniformly distributed in the outcome space (high entropy). Differential Revision: D55930954
1 parent c122890 commit d89ff88

File tree

4 files changed

+116
-5
lines changed

4 files changed

+116
-5
lines changed

Diff for: ax/modelbridge/cross_validation.py

-1
Original file line numberDiff line numberDiff line change
@@ -535,7 +535,6 @@ def compute_model_fit_metrics_from_modelbridge(
535535
if generalization
536536
else _predict_on_training_data(model_bridge=model_bridge, experiment=experiment)
537537
)
538-
539538
if fit_metrics_dict is None:
540539
fit_metrics_dict = {
541540
"coefficient_of_determination": coefficient_of_determination,

Diff for: ax/modelbridge/tests/test_model_fit_metrics.py

+30-1
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,23 @@
99
import warnings
1010
from typing import cast, Dict
1111

12+
import numpy as np
13+
1214
from ax.core.experiment import Experiment
1315
from ax.core.objective import Objective
1416
from ax.core.optimization_config import OptimizationConfig
1517
from ax.metrics.branin import BraninMetric
16-
from ax.modelbridge.cross_validation import compute_model_fit_metrics_from_modelbridge
18+
from ax.modelbridge.cross_validation import (
19+
_predict_on_cross_validation_data,
20+
compute_model_fit_metrics_from_modelbridge,
21+
)
1722
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
1823
from ax.modelbridge.registry import Models
1924
from ax.runners.synthetic import SyntheticRunner
2025
from ax.service.scheduler import get_fitted_model_bridge, Scheduler, SchedulerOptions
2126
from ax.utils.common.constants import Keys
2227
from ax.utils.common.testutils import TestCase
28+
from ax.utils.stats.model_fit_stats import _entropy_via_kde, entropy_of_observations
2329
from ax.utils.testing.core_stubs import get_branin_search_space
2430

2531
NUM_SOBOL = 5
@@ -83,6 +89,29 @@ def test_model_fit_metrics(self) -> None:
8389
std_branin = std["branin"]
8490
self.assertIsInstance(std_branin, float)
8591

92+
# checking non-default model-fit-metric
93+
untransform = False
94+
fit_metrics = compute_model_fit_metrics_from_modelbridge(
95+
model_bridge=model_bridge,
96+
experiment=scheduler.experiment,
97+
generalization=True,
98+
untransform=untransform,
99+
fit_metrics_dict={"Entropy": entropy_of_observations},
100+
)
101+
entropy = fit_metrics.get("Entropy")
102+
self.assertIsInstance(entropy, dict)
103+
entropy = cast(Dict[str, float], entropy)
104+
self.assertTrue("branin" in entropy)
105+
entropy_branin = entropy["branin"]
106+
self.assertIsInstance(entropy_branin, float)
107+
108+
y_obs, _, _ = _predict_on_cross_validation_data(
109+
model_bridge=model_bridge, untransform=untransform
110+
)
111+
y_obs_branin = np.array(y_obs["branin"])[:, np.newaxis]
112+
entropy_truth = _entropy_via_kde(y_obs_branin)
113+
self.assertAlmostEqual(entropy_branin, entropy_truth)
114+
86115
# testing with empty metrics
87116
empty_metrics = compute_model_fit_metrics_from_modelbridge(
88117
model_bridge=model_bridge,

Diff for: ax/utils/stats/model_fit_stats.py

+45
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import numpy as np
1111
from scipy.stats import fisher_exact, norm, pearsonr, spearmanr
12+
from sklearn.neighbors import KernelDensity
1213

1314
"""
1415
################################ Model Fit Metrics ###############################
@@ -127,6 +128,50 @@ def std_of_the_standardized_error(
127128
return ((y_obs - y_pred) / se_pred).std()
128129

129130

131+
def entropy_of_observations(
    y_obs: np.ndarray,
    y_pred: np.ndarray,
    se_pred: np.ndarray,
    bandwidth: float = 0.1,
) -> float:
    """Scores how spread out the observations are via a KDE-based entropy.

    A low value indicates that the outcomes in y_obs are clustered, a high value
    that they are spread more evenly. NOTE: y_pred and se_pred are ignored; they
    are only accepted so the signature matches the other model fit metrics.

    Args:
        y_obs: An array of observations for a single metric. A 1D array is
            treated as n one-dimensional samples.
        y_pred: An array of the predicted values corresponding to y_obs (unused).
        se_pred: An array of the standard errors of the predictions (unused).
        bandwidth: The kernel bandwidth. Defaults to 0.1, which is a reasonable
            value for standardized outcomes y_obs. With the bandwidth held fixed,
            the rank ordering of results across data sets is generally not
            sensitive to its value, but the absolute values change significantly
            with it.

    Returns:
        The scalar entropy of the observations.
    """
    # The KDE helper expects an (n x m) array, so a flat vector becomes a column.
    observations = y_obs[:, np.newaxis] if y_obs.ndim == 1 else y_obs
    return _entropy_via_kde(observations, bandwidth=bandwidth)
157+
158+
159+
def _entropy_via_kde(y: np.ndarray, bandwidth: float = 0.1) -> float:
    """Computes an entropy score from a Gaussian kernel density estimate of y.

    The density estimate is fit on y and then evaluated at those same points;
    the result is the negated sum of p * log(p) over the samples, not an
    integral over the whole outcome space.

    Args:
        y: An (n x m) array of observations.
        bandwidth: The kernel bandwidth.

    Returns:
        The scalar entropy of the kernel density estimate.
    """
    density = KernelDensity(kernel="gaussian", bandwidth=bandwidth).fit(y)
    # log-density of the fitted estimate at each of the input points
    log_density = density.score_samples(y)
    weighted_log_density = np.exp(log_density) * log_density
    return -np.sum(weighted_log_density)
173+
174+
130175
def _mean_prediction_ci(
131176
y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray
132177
) -> float:

Diff for: ax/utils/stats/tests/test_model_fit_stats.py

+41-3
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,59 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import numpy as np
8+
from ax.modelbridge.cross_validation import compute_model_fit_metrics_from_modelbridge
9+
from ax.service.scheduler import Scheduler
810
from ax.utils.common.testutils import TestCase
9-
from ax.utils.stats.model_fit_stats import _fisher_exact_test_p
11+
from ax.utils.stats.model_fit_stats import _fisher_exact_test_p, entropy_of_observations
1012
from scipy.stats import fisher_exact
1113

1214

13-
class FisherExactTestTest(TestCase):
15+
class TestModelFitStats(TestCase):
16+
def test_entropy_of_observations(self) -> None:
    """Checks that clustered outcomes get a lower entropy than Gaussian ones."""
    np.random.seed(1234)
    n = 16
    # clustered outcomes: two tight groups around -1 and +1
    yc = np.ones(n)
    yc[: n // 2] = -1
    yc += np.random.randn(n) * 0.05
    # reference outcomes: standard normal draws
    yr = np.random.randn(n)

    # standardize both observations
    yc = yc / yc.std()
    yr = yr / yr.std()

    ones = np.ones(n)

    # compute entropy of observations
    ec = entropy_of_observations(y_obs=yc, y_pred=ones, se_pred=ones, bandwidth=0.1)
    er = entropy_of_observations(y_obs=yr, y_pred=ones, se_pred=ones, bandwidth=0.1)

    # testing that the Gaussian distributed data has a much larger entropy than
    # the clustered distribution; assertGreater reports both values on failure
    self.assertGreater(er - ec, 10.0)

    ec2 = entropy_of_observations(
        y_obs=yc, y_pred=ones, se_pred=ones, bandwidth=0.2
    )
    er2 = entropy_of_observations(
        y_obs=yr, y_pred=ones, se_pred=ones, bandwidth=0.2
    )
    # entropy increases with larger bandwidth
    self.assertGreater(ec2, ec)
    self.assertGreater(er2, er)

    # ordering of entropies stays the same, though the difference is smaller
    self.assertGreater(er2 - ec2, 3)
50+
1451
def test_contingency_table_construction(self) -> None:
1552
# Create a dummy set of observations and predictions
1653
y_obs = np.array([1, 3, 2, 5, 7, 3])
1754
y_pred = np.array([2, 4, 1, 6, 8, 2.5])
55+
se_pred = np.full(len(y_obs), np.nan) # not used for fisher exact
1856

1957
# Compute ground truth contingency table
2058
true_table = np.array([[2, 1], [1, 2]])
2159

2260
scipy_result = fisher_exact(true_table, alternative="greater")[1]
23-
ax_result = _fisher_exact_test_p(y_obs, y_pred, se_pred=None)
61+
ax_result = _fisher_exact_test_p(y_obs, y_pred, se_pred=se_pred)
2462

2563
self.assertEqual(scipy_result, ax_result)

0 commit comments

Comments
 (0)