Skip to content

Commit

Permalink
Color-coding Ax tradeoff plots
Browse files Browse the repository at this point in the history
Summary: Color-coding for plot_multiple_metrics and plot_objective_vs_constraints (for in-sample points only).

Reviewed By: Balandat

Differential Revision: D30740172

fbshipit-source-id: dbc084efedc3406d3773d767a13deded4013d477
  • Loading branch information
ItsMrLin authored and facebook-github-bot committed Sep 9, 2021
1 parent c401dca commit 9afd770
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 16 deletions.
1 change: 0 additions & 1 deletion ax/plot/bandit_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ def plot_bandit_rollout(experiment: Experiment) -> AxPlotConfig:
for key in arms.keys():
data.append(arms[key])

# pyre-fixme[6]: Expected `typing.Tuple[...g.Tuple[int, int, int]`.
colors = [rgba(c) for c in MIXED_SCALE]

layout = go.Layout(
Expand Down
10 changes: 8 additions & 2 deletions ax/plot/color.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@
# LICENSE file in the root directory of this source tree.

import enum
from numbers import Real
from typing import List, Tuple

# type aliases
TRGB = Tuple[Real, ...]


class COLORS(enum.Enum):
STEELBLUE = (128, 177, 211)
Expand Down Expand Up @@ -92,13 +96,15 @@ class COLORS(enum.Enum):
]


def rgba(rgb_tuple: Tuple[float], alpha: float = 1) -> str:
def rgba(rgb_tuple: TRGB, alpha: float = 1) -> str:
"""Convert RGB tuple to an RGBA string."""
return "rgba({},{},{},{alpha})".format(*rgb_tuple, alpha=alpha)


def plotly_color_scale(
list_of_rgb_tuples: List[Tuple[float]], reverse: bool = False, alpha: float = 1
list_of_rgb_tuples: List[TRGB],
reverse: bool = False,
alpha: float = 1,
) -> List[Tuple[float, str]]:
"""Convert list of RGB tuples to list of tuples, where each tuple is
break in [0, 1] and stringified RGBA color.
Expand Down
110 changes: 97 additions & 13 deletions ax/plot/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@
PlotOutOfSampleArm,
Z,
)
from ax.plot.color import COLORS, DISCRETE_COLOR_SCALE, rgba
from ax.plot.color import (
COLORS,
DISCRETE_COLOR_SCALE,
BLUE_SCALE,
rgba,
)
from ax.plot.helper import (
TNullableGeneratorRunsDict,
_format_CI,
Expand Down Expand Up @@ -113,6 +118,8 @@ def _error_scatter_trace(
show_arm_details_on_hover: bool = True,
show_context: bool = False,
arm_noun: str = "arm",
color_parameter: Optional[str] = None,
color_metric: Optional[str] = None,
) -> Dict[str, Any]:
"""Plot scatterplot with error bars.
Expand Down Expand Up @@ -147,14 +154,29 @@ def _error_scatter_trace(
show_context: if True and show_arm_details_on_hover,
context will be included in the hover.
arm_noun: noun to use instead of "arm" (e.g. group)
color_parameter: color points according to the specified parameter,
cannot be used together with color_metric.
color_metric: color points according to the specified metric,
cannot be used together with color_parameter.
"""
if color_metric and color_parameter:
raise RuntimeError(
"color_metric and color_parameter cannot be used at the same time!"
)

if (color_metric or color_parameter) and not all(
isinstance(arm, PlotInSampleArm) for arm in arms
):
raise RuntimeError("Color coding currently only works with in-sample arms!")

x, x_se, y, y_se = _error_scatter_data(
arms=arms,
y_axis_var=y_axis_var,
x_axis_var=x_axis_var,
status_quo_arm=status_quo_arm,
)
labels = []
colors = []

arm_names = [a.name for a in arms]

Expand Down Expand Up @@ -191,6 +213,13 @@ def _error_scatter_trace(
else ""
)

if color_parameter:
colors.append(arms[i].parameters[color_parameter])
elif color_metric:
# Must be PlotInSampleArm here if no error raised previously
# pyre-ignore[16]: `PlotOutOfSampleArm` has no attribute `y`
colors.append(arms[i].y[color_metric])

context = (
# Expected `Dict[str, Optional[Union[bool, float, str]]]` for 1st anonymous
# parameter to call `ax.plot.helper._format_dict` but got
Expand All @@ -213,10 +242,22 @@ def _error_scatter_trace(
)
)
i += 1

if color_metric or color_parameter:
rgba_blue_scale = [rgba(c) for c in BLUE_SCALE]
marker = {
"color": colors,
"colorscale": rgba_blue_scale,
"colorbar": {"title": color_metric or color_parameter},
"showscale": True,
}
else:
marker = {"color": rgba(color)}

trace = go.Scatter(
x=x,
y=y,
marker={"color": rgba(color)},
marker=marker,
mode="markers",
name=name,
text=labels,
Expand Down Expand Up @@ -258,6 +299,8 @@ def _multiple_metric_traces(
rel_y: bool,
fixed_features: Optional[ObservationFeatures] = None,
data_selector: Optional[Callable[[Observation], bool]] = None,
color_parameter: Optional[str] = None,
color_metric: Optional[str] = None,
) -> Traces:
"""Plot traces for multiple metrics given a model and metrics.
Expand All @@ -271,11 +314,19 @@ def _multiple_metric_traces(
rel_y: if True, use relative effects on metric_y.
fixed_features: Fixed features to use when making model predictions.
data_selector: Function for selecting observations for plotting.
color_parameter: color points according to the specified parameter,
cannot be used together with color_metric.
color_metric: color points according to the specified metric,
cannot be used together with color_parameter.
"""
metric_names = {metric_x, metric_y}
if color_metric is not None:
metric_names.add(color_metric)

plot_data, _, _ = get_plot_data(
model,
generator_runs_dict if generator_runs_dict is not None else {},
{metric_x, metric_y},
metric_names,
fixed_features=fixed_features,
data_selector=data_selector,
)
Expand All @@ -298,6 +349,8 @@ def _multiple_metric_traces(
y_axis_var=PlotMetric(metric_y, pred=False, rel=rel_y),
status_quo_arm=status_quo_arm,
visible=False,
color_parameter=color_parameter,
color_metric=color_metric,
),
_error_scatter_trace(
# Expected `List[Union[PlotInSampleArm, PlotOutOfSampleArm]]`
Expand All @@ -310,9 +363,12 @@ def _multiple_metric_traces(
y_axis_var=PlotMetric(metric_y, pred=True, rel=rel_y),
status_quo_arm=status_quo_arm,
visible=True,
color_parameter=color_parameter,
color_metric=color_metric,
),
]

# TODO: Figure out if there's a better way to color code out-of-sample points
for i, (generator_run_name, cand_arms) in enumerate(
(plot_data.out_of_sample or {}).items(), start=1
):
Expand All @@ -338,6 +394,8 @@ def plot_multiple_metrics(
rel_y: bool = True,
fixed_features: Optional[ObservationFeatures] = None,
data_selector: Optional[Callable[[Observation], bool]] = None,
color_parameter: Optional[str] = None,
color_metric: Optional[str] = None,
**kwargs: Any,
) -> AxPlotConfig:
"""Plot raw values or predictions of two metrics for arms.
Expand All @@ -354,7 +412,15 @@ def plot_multiple_metrics(
rel_x: if True, use relative effects on metric_x.
rel_y: if True, use relative effects on metric_y.
data_selector: Function for selecting observations for plotting.
color_parameter: color points according to the specified parameter,
cannot be used together with color_metric.
color_metric: color points according to the specified metric,
cannot be used together with color_parameter.
"""
if color_parameter or color_metric:
layout_offset_x = 0.15
else:
layout_offset_x = 0
rel = checked_cast_optional(bool, kwargs.get("rel"))
if rel is not None:
warnings.warn("Use `rel_x` and `rel_y` instead of `rel`.", DeprecationWarning)
Expand All @@ -369,14 +435,16 @@ def plot_multiple_metrics(
rel_y=rel_y,
fixed_features=fixed_features,
data_selector=data_selector,
color_parameter=color_parameter,
color_metric=color_metric,
)
num_cand_traces = len(generator_runs_dict) if generator_runs_dict is not None else 0
layout = go.Layout(
title="Objective Tradeoffs",
hovermode="closest",
updatemenus=[
{
"x": 1.25,
"x": 1.25 + layout_offset_x,
"y": 0.67,
"buttons": [
{
Expand Down Expand Up @@ -408,7 +476,7 @@ def plot_multiple_metrics(
"xanchor": "left",
},
{
"x": 1.25,
"x": 1.25 + layout_offset_x,
"y": 0.57,
"buttons": [
{
Expand All @@ -432,7 +500,7 @@ def plot_multiple_metrics(
],
annotations=[
{
"x": 1.18,
"x": 1.18 + layout_offset_x,
"y": 0.7,
"xref": "paper",
"yref": "paper",
Expand All @@ -441,7 +509,7 @@ def plot_multiple_metrics(
"yanchor": "middle",
},
{
"x": 1.18,
"x": 1.18 + layout_offset_x,
"y": 0.6,
"xref": "paper",
"yref": "paper",
Expand All @@ -463,6 +531,7 @@ def plot_multiple_metrics(
width=800,
height=600,
font={"size": 10},
legend={"x": 1 + layout_offset_x},
)
fig = go.Figure(data=traces, layout=layout)
return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
Expand All @@ -477,6 +546,8 @@ def plot_objective_vs_constraints(
infer_relative_constraints: Optional[bool] = False,
fixed_features: Optional[ObservationFeatures] = None,
data_selector: Optional[Callable[[Observation], bool]] = None,
color_parameter: Optional[str] = None,
color_metric: Optional[str] = None,
) -> AxPlotConfig:
"""Plot the tradeoff between an objetive and all other metrics in a model.
Expand All @@ -501,7 +572,15 @@ def plot_objective_vs_constraints(
Metrics that are not constraints will be relativized.
fixed_features: Fixed features to use when making model predictions.
data_selector: Function for selecting observations for plotting.
color_parameter: color points according to the specified parameter,
cannot be used together with color_metric.
color_metric: color points according to the specified metric,
cannot be used together with color_parameter.
"""
if color_parameter or color_metric:
layout_offset_x = 0.15
else:
layout_offset_x = 0
if subset_metrics is not None:
metrics = subset_metrics
else:
Expand Down Expand Up @@ -533,6 +612,8 @@ def plot_objective_vs_constraints(
rel_y=rels[metrics[0]],
fixed_features=fixed_features,
data_selector=data_selector,
color_parameter=color_parameter,
color_metric=color_metric,
)

for metric in metrics:
Expand All @@ -545,6 +626,8 @@ def plot_objective_vs_constraints(
rel_y=rels[metric],
fixed_features=fixed_features,
data_selector=data_selector,
color_parameter=color_parameter,
color_metric=color_metric,
)

# Current version of Plotly does not allow updating the yaxis label
Expand All @@ -571,7 +654,7 @@ def plot_objective_vs_constraints(
hovermode="closest",
updatemenus=[
{
"x": 1.25,
"x": 1.25 + layout_offset_x,
"y": 0.62,
"buttons": [
{
Expand Down Expand Up @@ -603,7 +686,7 @@ def plot_objective_vs_constraints(
"xanchor": "left",
},
{
"x": 1.25,
"x": 1.25 + layout_offset_x,
"y": 0.52,
"buttons": [
{
Expand All @@ -625,7 +708,7 @@ def plot_objective_vs_constraints(
"xanchor": "left",
},
{
"x": 1.25,
"x": 1.25 + layout_offset_x,
"y": 0.72,
"yanchor": "middle",
"xanchor": "left",
Expand All @@ -634,7 +717,7 @@ def plot_objective_vs_constraints(
],
annotations=[
{
"x": 1.18,
"x": 1.18 + layout_offset_x,
"y": 0.72,
"xref": "paper",
"yref": "paper",
Expand All @@ -643,7 +726,7 @@ def plot_objective_vs_constraints(
"yanchor": "middle",
},
{
"x": 1.18,
"x": 1.18 + layout_offset_x,
"y": 0.62,
"xref": "paper",
"yref": "paper",
Expand All @@ -652,7 +735,7 @@ def plot_objective_vs_constraints(
"yanchor": "middle",
},
{
"x": 1.18,
"x": 1.18 + layout_offset_x,
"y": 0.52,
"xref": "paper",
"yref": "paper",
Expand All @@ -674,6 +757,7 @@ def plot_objective_vs_constraints(
width=900,
height=600,
font={"size": 10},
legend={"x": 1 + layout_offset_x},
)

fig = go.Figure(data=plot_data, layout=layout)
Expand Down

0 comments on commit 9afd770

Please sign in to comment.