Skip to content

Commit

Permalink
Merge e425c26 into 8401990
Browse files Browse the repository at this point in the history
  • Loading branch information
SebastianAment authored Apr 12, 2024
2 parents 8401990 + e425c26 commit bf31d3f
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 5 deletions.
31 changes: 26 additions & 5 deletions ax/plot/feature_importances.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import numpy as np
import pandas as pd
import plotly.graph_objs as go
from ax.core.parameter import ChoiceParameter
from ax.exceptions.core import NoDataError
from ax.modelbridge import ModelBridge
from ax.plot.base import AxPlotConfig, AxPlotTypes
Expand Down Expand Up @@ -143,6 +144,14 @@ def plot_feature_importance_by_feature_plotly(
}
traces = []
dropdown = []
categorical_features = []
if model is not None:
categorical_features = [
name
for name, par in model.model_space.parameters.items()
if isinstance(par, ChoiceParameter)
]

for i, metric_name in enumerate(sorted(sensitivity_values.keys())):
importances = sensitivity_values[metric_name]
factor_col = "Factor"
Expand All @@ -157,7 +166,11 @@ def plot_feature_importance_by_feature_plotly(
factor_col: factor,
importance_col: np.asarray(importance)[0],
importance_col_se: np.asarray(importance)[2],
sign_col: np.sign(np.asarray(importance)[0]).astype(int),
sign_col: (
0
if factor in categorical_features
else 2 * (np.asarray(importance)[0] >= 0).astype(int) - 1
),
}
for factor, importance in importances.items()
]
Expand All @@ -172,7 +185,11 @@ def plot_feature_importance_by_feature_plotly(
{
factor_col: factor,
importance_col: importance,
sign_col: np.sign(importance).astype(int),
sign_col: (
0
if factor in categorical_features
else 2 * (importance >= 0).astype(int) - 1
),
}
for factor, importance in importances.items()
]
Expand All @@ -183,9 +200,13 @@ def plot_feature_importance_by_feature_plotly(
if relative:
df[importance_col] = df[importance_col].div(df[importance_col].sum())

colors = {-1: "darkorange", 1: "steelblue"}
names = {-1: "Decreases metric", 1: "Increases metric"}
legend_counter = {-1: 0, 1: 0}
colors = {-1: "darkorange", 0: "gray", 1: "steelblue"}
names = {
-1: "Decreases metric",
0: "Affects metric (discrete choice)",
1: "Increases metric",
}
legend_counter = {-1: 0, 0: 0, 1: 0}
all_positive = all(df[sign_col] >= 0)
for _, row in df.iterrows():
traces.append(
Expand Down
27 changes: 27 additions & 0 deletions ax/utils/sensitivity/sobol_measures.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# pyre-strict

from copy import deepcopy
from functools import partial

from typing import Any, Callable, Dict, List, Optional, Union

Expand Down Expand Up @@ -36,6 +37,7 @@ def __init__(
second_order: bool = False,
num_bootstrap_samples: int = 1,
bootstrap_array: bool = False,
discrete_features: Optional[List[int]] = None,
) -> None:
r"""Computes three types of Sobol indices:
first order indices, total indices and second order indices (if specified).
Expand Down Expand Up @@ -71,6 +73,16 @@ def __init__(
else:
self.A = unnormalize(torch.rand(num_mc_samples, self.dim), bounds=bounds)
self.B = unnormalize(torch.rand(num_mc_samples, self.dim), bounds=bounds)

# uniform integral distribution for discrete features
if discrete_features is not None:
all_low = bounds[0, discrete_features].to(dtype=torch.int).tolist()
all_high = (bounds[1, discrete_features]).to(dtype=torch.int).tolist()
for i, low, high in zip(discrete_features, all_low, all_high):
randint = partial(torch.randint, low=low, high=high + 1)
self.A[:, i] = randint(size=self.A.shape[:-1])
self.B[:, i] = randint(size=self.B.shape[:-1])

# pyre-fixme[4]: Attribute must be annotated.
self.A_B_ABi = self.generate_all_input_matrix().to(torch.double)

Expand Down Expand Up @@ -395,6 +407,7 @@ def __init__(
[torch.Tensor, torch.Tensor], torch.Tensor
] = GaussianLinkMean,
mini_batch_size: int = 128,
discrete_features: Optional[List[int]] = None,
) -> None:
r"""Computes three types of Sobol indices:
first order indices, total indices and second order indices (if specified).
Expand Down Expand Up @@ -438,6 +451,7 @@ def input_function(x: Tensor) -> Tensor:
second_order=self.second_order,
input_qmc=self.input_qmc,
num_bootstrap_samples=self.num_bootstrap_samples,
discrete_features=discrete_features,
)
self.sensitivity.evalute_function()

Expand Down Expand Up @@ -486,6 +500,7 @@ def __init__(
input_qmc: bool = False,
gp_sample_qmc: bool = False,
num_bootstrap_samples: int = 1,
discrete_features: Optional[List[int]] = None,
) -> None:
r"""Computes three types of Sobol indices:
first order indices, total indices and second order indices (if specified).
Expand Down Expand Up @@ -519,6 +534,7 @@ def __init__(
input_qmc=self.input_qmc,
num_bootstrap_samples=self.num_bootstrap_samples,
bootstrap_array=True,
discrete_features=discrete_features,
)
# TODO: Ideally, we would reduce the memory consumption here as well,
# but this is tricky since it uses joint posterior sampling.
Expand Down Expand Up @@ -717,6 +733,7 @@ def compute_sobol_indices_from_model_list(
model_list: List[Model],
bounds: Tensor,
order: str = "first",
discrete_features: Optional[List[int]] = None,
**sobol_kwargs: Any,
) -> Tensor:
"""
Expand All @@ -739,6 +756,7 @@ def compute_sobol_indices_from_model_list(
sens_class = SobolSensitivityGPMean(
model=model,
bounds=bounds,
discrete_features=discrete_features,
**sobol_kwargs,
)
indices.append(method(sens_class))
Expand Down Expand Up @@ -789,6 +807,7 @@ def ax_parameter_sens(
model_list=model_list,
bounds=bounds,
order=order,
discrete_features=digest.categorical_features + digest.ordinal_features,
**sobol_kwargs,
)
if signed:
Expand All @@ -797,6 +816,14 @@ def ax_parameter_sens(
bounds=bounds,
**sobol_kwargs,
)
# categorical features don't have a direction, so we set the derivative to 1.0
# in order not to zero out their sensitivity. We treat categorical features
# separately in the sensitivity analysis plot as well, to make clear that they
# are affecting the metric, but neither increasing nor decreasing. Note that the
# ordinal variables have a well defined direction, so we do not need to treat
# them differently here.
for i in digest.categorical_features:
ind_deriv[:, i] = 1.0
ind *= torch.sign(ind_deriv)
return _array_with_string_indices_to_dict(
rows=metrics, cols=digest.feature_names, A=ind.numpy()
Expand Down

0 comments on commit bf31d3f

Please sign in to comment.