Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New parallel coordinates plot #2590

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion ax/analysis/plotly/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

# pyre-strict

from ax.analysis.plotly.parallel_coordinates.parallel_coordinates import (
ParallelCoordinatesPlot,
)
from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard

__all__ = ["PlotlyAnalysis", "PlotlyAnalysisCard"]
__all__ = ["PlotlyAnalysis", "PlotlyAnalysisCard", "ParallelCoordinatesPlot"]
176 changes: 176 additions & 0 deletions ax/analysis/plotly/parallel_coordinates/parallel_coordinates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Optional

import numpy as np
import pandas as pd
from ax.analysis.analysis import AnalysisCardLevel

from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard
from ax.core.experiment import Experiment
from ax.core.objective import MultiObjective, ScalarizedObjective
from ax.exceptions.core import UnsupportedError, UserInputError
from ax.modelbridge.generation_strategy import GenerationStrategy
from plotly import graph_objects as go


class ParallelCoordinatesPlot(PlotlyAnalysis):
"""
Plotly Parcoords plot for a single metric, with one line per arm and dimensions for
each parameter in the search space. This plot is useful for understanding how
thoroughly the search space is explored as well as for identifying if there is any
clusertering for either good or bad parameterizations.

The DataFrame computed will contain one row per arm and the following columns:
- arm_name: The name of the arm
- METRIC_NAME: The observed mean of the metric specified
- **PARAMETER_NAME: The value of said parameter for the arm, for each parameter
"""

def __init__(self, metric_name: Optional[str] = None) -> None:
"""
Args:
metric_name: The name of the metric to plot. If not specified the objective
will be used. Note that the metric cannot be inferred for
multi-objective or scalarized-objective experiments.
"""

self.metric_name = metric_name

def compute(
self,
experiment: Optional[Experiment] = None,
generation_strategy: Optional[GenerationStrategy] = None,
) -> PlotlyAnalysisCard:
if experiment is None:
raise UserInputError("ParallelCoordinatesPlot requires an Experiment")

metric_name = self.metric_name or _select_metric(experiment=experiment)

df = _prepare_data(experiment=experiment, metric=metric_name)
fig = _prepare_plot(df=df, metric_name=metric_name)

return PlotlyAnalysisCard(
name=self.__class__.__name__,
title=f"Parallel Coordinates for {metric_name}",
subtitle="View arm parameterizations with their respective metric values",
level=AnalysisCardLevel.HIGH,
df=df,
blob=fig,
)


def _prepare_data(experiment: Experiment, metric: str) -> pd.DataFrame:
data_df = experiment.lookup_data().df
filtered_df = data_df.loc[data_df["metric_name"] == metric]

if filtered_df.empty:
raise ValueError(f"No data found for metric {metric}")

records = [
{
"arm_name": arm.name,
**arm.parameters,
metric: _find_mean_by_arm_name(df=filtered_df, arm_name=arm.name),
}
for trial in experiment.trials.values()
for arm in trial.arms
]

return pd.DataFrame.from_records(records)


def _prepare_plot(df: pd.DataFrame, metric_name: str) -> go.Figure:

# ParCoords requires that the dimensions are specified on continuous scales, so
# ChoiceParameters and FixedParameters must be preprocessed to allow for
# appropriate plotting.
parameter_dimensions = [
_get_parameter_dimension(series=df[col])
for col in df.columns
if col != "arm_name" and col != metric_name
]

return go.Figure(
go.Parcoords(
line={
"color": df[metric_name],
"showscale": True,
},
dimensions=[
*parameter_dimensions,
{
"label": _truncate_label(label=metric_name),
"values": df[metric_name].tolist(),
},
],
# Rotate the labels to allow them to be longer withoutoverlapping
labelangle=-45,
)
)


def _select_metric(experiment: Experiment) -> str:
if experiment.optimization_config is None:
raise ValueError(
"Cannot infer metric to plot from Experiment without OptimizationConfig"
)
objective = experiment.optimization_config.objective
if isinstance(objective, MultiObjective):
raise UnsupportedError(
"Cannot infer metric to plot from MultiObjective, please "
"specify a metric"
)
if isinstance(objective, ScalarizedObjective):
raise UnsupportedError(
"Cannot infer metric to plot from ScalarizedObjective, please "
"specify a metric"
)
return experiment.optimization_config.objective.metric.name


def _find_mean_by_arm_name(
df: pd.DataFrame,
arm_name: str,
) -> float:
# Given a dataframe with arm_name and mean columns, find the mean for a given
# arm_name. If an arm_name is not found (as can happen if the arm is still running
# or has failed) return NaN.
series = df.loc[df["arm_name"] == arm_name]["mean"]

if series.empty:
return np.nan

return series.item()


def _get_parameter_dimension(series: pd.Series) -> Dict[str, Any]:
# For numeric parameters allow Plotly to infer tick attributes. Note: booleans are
# considered numeric, but in this case we want to treat them as categorical.
if pd.api.types.is_numeric_dtype(series) and not pd.api.types.is_bool_dtype(series):
return {
"tickvals": None,
"ticktext": None,
"label": _truncate_label(label=str(series.name)),
"values": series.tolist(),
}

# For non-numeric parameters, sort, map onto an integer scale, and provide
# corresponding tick attributes
mapping = {v: k for k, v in enumerate(sorted(series.unique()))}

return {
"tickvals": [_truncate_label(label=str(val)) for val in mapping.values()],
"ticktext": [_truncate_label(label=str(key)) for key in mapping.keys()],
"label": _truncate_label(label=str(series.name)),
"values": series.map(mapping).tolist(),
}


def _truncate_label(label: str, n: int = 18) -> str:
if len(label) > n:
return label[:n] + "..."
return label
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import pandas as pd
from ax.analysis.analysis import AnalysisCardLevel
from ax.analysis.plotly.parallel_coordinates.parallel_coordinates import (
_get_parameter_dimension,
_select_metric,
ParallelCoordinatesPlot,
)
from ax.exceptions.core import UnsupportedError, UserInputError
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import (
get_branin_experiment,
get_experiment_with_multi_objective,
get_experiment_with_scalarized_objective_and_outcome_constraint,
)


class TestParallelCoordinatesPlot(TestCase):
def test_compute(self) -> None:
analysis = ParallelCoordinatesPlot("branin")
experiment = get_branin_experiment(with_completed_trial=True)

with self.assertRaisesRegex(UserInputError, "requires an Experiment"):
analysis.compute()

card = analysis.compute(experiment=experiment)
self.assertEqual(card.name, "ParallelCoordinatesPlot")
self.assertEqual(card.title, "Parallel Coordinates for branin")
self.assertEqual(
card.subtitle,
"View arm parameterizations with their respective metric values",
)
self.assertEqual(card.level, AnalysisCardLevel.HIGH)
self.assertEqual({*card.df.columns}, {"arm_name", "branin", "x1", "x2"})
self.assertIsNotNone(card.blob)
self.assertEqual(card.blob_annotation, "plotly")

analysis_no_metric = ParallelCoordinatesPlot()
_ = analysis_no_metric.compute(experiment=experiment)

def test_select_metric(self) -> None:
experiment = get_branin_experiment()
experiment_no_optimization_config = get_branin_experiment(
has_optimization_config=False
)
experiment_multi_objective = get_experiment_with_multi_objective()
experiment_scalarized_objective = (
get_experiment_with_scalarized_objective_and_outcome_constraint()
)

self.assertEqual(_select_metric(experiment=experiment), "branin")

with self.assertRaisesRegex(ValueError, "OptimizationConfig"):
_select_metric(experiment=experiment_no_optimization_config)

with self.assertRaisesRegex(UnsupportedError, "MultiObjective"):
_select_metric(experiment=experiment_multi_objective)

with self.assertRaisesRegex(UnsupportedError, "ScalarizedObjective"):
_select_metric(experiment=experiment_scalarized_objective)

def test_get_parameter_dimension(self) -> None:
range_series = pd.Series([0, 1, 2, 3], name="range")
range_dimension = _get_parameter_dimension(series=range_series)
self.assertEqual(
range_dimension,
{
"tickvals": None,
"ticktext": None,
"label": "range",
"values": range_series.tolist(),
},
)

choice_series = pd.Series(["foo", "bar", "baz"], name="choice")
choice_dimension = _get_parameter_dimension(series=choice_series)
self.assertEqual(
choice_dimension,
{
"tickvals": ["0", "1", "2"],
"ticktext": ["bar", "baz", "foo"],
"label": "choice",
"values": [2, 0, 1],
},
)