Skip to content

Commit

Permalink
Merge pull request #310 from alan-turing-institute/show-plots
Browse files Browse the repository at this point in the history
Show plots
  • Loading branch information
mastoffel authored Feb 25, 2025
2 parents 9b9705c + 0d5b06d commit a344091
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 35 deletions.
4 changes: 2 additions & 2 deletions autoemulate/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ def plot_cv(
"actual_vs_predicted" for plotting observed values (y-axis) vs. the predicted values (x-axis).
"residual_vs_predicted" for plotting the residuals, i.e. difference between observed and predicted values, (y-axis) vs. the predicted values (x-axis).
n_cols : int
Number of columns in the plot grid.
Maximum number of columns in the plot grid.
figsize : tuple, optional
Overrides the default figure size, in inches, e.g. (6, 4).
output_index : int
Expand Down Expand Up @@ -512,7 +512,7 @@ def plot_eval(
"actual_vs_predicted" draws the observed values (y-axis) vs. the predicted values (x-axis) (default).
"residual_vs_predicted" draws the residuals, i.e. difference between observed and predicted values, (y-axis) vs. the predicted values (x-axis).
n_cols : int, optional
Number of columns in the plot grid for multi-output. Default is 2.
Maximum number of columns in the plot grid for multi-output. Default is 3.
output_index : list, int
Index of the output to plot. Either a single index or a list of indices.
input_index : list, int
Expand Down
75 changes: 60 additions & 15 deletions autoemulate/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,30 @@ def _predict_with_optional_std(model, X_test):
return y_test_pred, y_test_std


def _calculate_subplot_layout(n_plots, n_cols=3):
"""Calculate optimal number of rows and columns for subplots.
Parameters
----------
n_plots : int
Number of plots to display
n_cols : int, optional
Maximum number of columns allowed, default is 3
Returns
-------
tuple
(n_rows, n_cols) for the subplot layout
"""
if n_plots <= 1:
return (1, 1)

n_cols = min(n_plots, n_cols)
n_rows = (n_plots + n_cols - 1) // n_cols

return n_rows, n_cols


def _plot_single_fold(
cv_results,
X,
Expand Down Expand Up @@ -185,7 +209,7 @@ def _plot_best_fold_per_model(
y : array-like, shape (n_samples, n_outputs)
Simulation output.
n_cols : int, optional
The number of columns in the plot. Default is 3.
The maximum number of columns in the plot. Default is 3.
plot : str, optional
The type of plot to draw:
“standard" or "residual”.
Expand All @@ -198,7 +222,7 @@ def _plot_best_fold_per_model(
"""

n_models = len(cv_results)
n_rows = int(np.ceil(n_models / n_cols))
n_rows, n_cols = _calculate_subplot_layout(n_models, n_cols)

if figsize is None:
figsize = (4 * n_cols, 3 * n_rows)
Expand All @@ -225,8 +249,6 @@ def _plot_best_fold_per_model(
for j in range(i + 1, len(axs)):
axs[j].set_visible(False)
plt.tight_layout()
# prevent double plotting in notebooks
plt.close(fig)
return fig


Expand Down Expand Up @@ -254,7 +276,7 @@ def _plot_model_folds(
model_name : str
The name of the model to plot.
n_cols : int, optional
The number of columns in the plot. Default is 5.
The maximum number of columns in the plot. Default is 3.
plot : str, optional
The type of plot to draw:
“standard” or “residual”.
Expand All @@ -267,7 +289,7 @@ def _plot_model_folds(
"""

n_folds = len(cv_results[model_name]["estimator"])
n_rows = int(np.ceil(n_folds / n_cols))
n_rows, n_cols = _calculate_subplot_layout(n_folds, n_cols)

if figsize is None:
figsize = (4 * n_cols, 3 * n_rows)
Expand All @@ -293,8 +315,6 @@ def _plot_model_folds(
axs[j].set_visible(False)

plt.tight_layout()
# prevent double plotting in notebooks
plt.close(fig)
return fig


Expand Down Expand Up @@ -322,7 +342,7 @@ def _plot_cv(
model_name : (str, optional)
The name of the model to plot. If None, the best (largest R^2) fold for each model will be plotted.
n_cols : int, optional
The number of columns in the plot. Default is 3.
The maximum number of columns in the plot. Default is 3.
plot : str, optional
The type of plot to draw:
“standard” draws the observed values (y-axis) vs. the predicted values (x-axis) (default).
Expand Down Expand Up @@ -355,7 +375,7 @@ def _plot_cv(
cv_results, X, y, n_cols, style, figsize, output_index, input_index
)

return figure
return _display_figure(figure)


def _plot_model(
Expand Down Expand Up @@ -384,7 +404,7 @@ def _plot_model(
"residual" draws the residuals, i.e. difference between observed and predicted values, (y-axis) vs. the predicted values (x-axis).
"Xy" draws the input features vs. the output variables, including predictions.
n_cols : int, optional
The number of columns in the plot. Default is 2.
The maximum number of columns in the plot. Default is 3.
figsize : tuple, optional
Overrides the default figure size.
input_index : int or list of int, optional
Expand Down Expand Up @@ -426,7 +446,7 @@ def _plot_model(
n_plots = len(output_index)

# Calculate number of rows
n_rows = int(np.ceil(n_plots / n_cols))
n_rows, n_cols = _calculate_subplot_layout(n_plots, n_cols)

# Set up the figure
if figsize is None:
Expand Down Expand Up @@ -478,9 +498,7 @@ def _plot_model(
ax.set_visible(False)
plt.tight_layout()

# prevent double plotting in notebooks
plt.close(fig)
return fig
return _display_figure(fig)


def _plot_Xy(
Expand Down Expand Up @@ -577,3 +595,30 @@ def _plot_Xy(
transform=ax.transAxes,
verticalalignment="bottom",
)


def _display_figure(fig):
"""
Display a matplotlib figure appropriately based on the environment (Jupyter notebook or terminal).
Args:
fig: matplotlib figure object to display
Returns:
fig: the input figure object
"""
# Are we in Jupyter?
try:
is_jupyter = get_ipython().__class__.__name__ == "ZMQInteractiveShell"
except NameError:
is_jupyter = False

if is_jupyter:
# we don't show otherwise it will double plot
plt.close(fig)
return fig
else:
# in terminal, show the plot
plt.show()
plt.close(fig)
return fig
5 changes: 2 additions & 3 deletions autoemulate/sensitivity_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from SALib.analyze.sobol import analyze
from SALib.sample.sobol import sample

from autoemulate.plotting import _display_figure
from autoemulate.utils import _ensure_2d


Expand Down Expand Up @@ -298,7 +299,5 @@ def _plot_sensitivity_analysis(results, index="S1", n_cols=None, figsize=None):
)

plt.tight_layout()
# prevent double plotting in notebooks
plt.close(fig)

return fig
return _display_figure(fig)
41 changes: 26 additions & 15 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,18 +263,20 @@ def test__plot_cv_input_range(ae_multi_output, monkeypatch):
_plot_cv(cv_results, X, y, input_index=2)


# ------------------------------ most important tests, does it work? ----------------
# ------------------------------ test plot_cv ----------------------------------
# # ------------------------------ most important tests, does it work? ----------------
# # ------------------------------ test plot_cv ----------------------------------


# test plots with best cv per model, Xy plot
def test_plot_cv(ae_single_output):
# # test plots with best cv per model, Xy plot
def test_plot_cv(ae_single_output, monkeypatch):
monkeypatch.setattr(plt, "show", lambda: None)
fig = ae_single_output.plot_cv(style="Xy")
assert isinstance(fig, plt.Figure)
assert len(fig.axes) == 3


def test_plot_cv_input_index(ae_single_output):
def test_plot_cv_input_index(ae_single_output, monkeypatch):
monkeypatch.setattr(plt, "show", lambda: None)
fig = ae_single_output.plot_cv(input_index=1)
assert isinstance(fig, plt.Figure)
assert len(fig.axes) == 3
Expand All @@ -285,7 +287,8 @@ def test_plot_cv_input_index_out_of_range(ae_single_output):
ae_single_output.plot_cv(input_index=2)


def test_plot_cv_output_index(ae_multi_output):
def test_plot_cv_output_index(ae_multi_output, monkeypatch):
monkeypatch.setattr(plt, "show", lambda: None)
fig = ae_multi_output.plot_cv(output_index=1)
assert isinstance(fig, plt.Figure)
assert len(fig.axes) == 3
Expand All @@ -297,13 +300,15 @@ def test_plot_cv_output_index_out_of_range(ae_multi_output):


# test plots with best cv per model, standard [;pt]
def test_plot_cv_actual_vs_predicted(ae_single_output):
def test_plot_cv_actual_vs_predicted(ae_single_output, monkeypatch):
monkeypatch.setattr(plt, "show", lambda: None)
fig = ae_single_output.plot_cv(style="actual_vs_predicted")
assert isinstance(fig, plt.Figure)
assert len(fig.axes) == 3


def test_plot_cv_output_index_actual_vs_predicted(ae_multi_output):
def test_plot_cv_output_index_actual_vs_predicted(ae_multi_output, monkeypatch):
monkeypatch.setattr(plt, "show", lambda: None)
fig = ae_multi_output.plot_cv(style="actual_vs_predicted", output_index=1)
assert isinstance(fig, plt.Figure)
assert len(fig.axes) == 3
Expand All @@ -315,19 +320,22 @@ def test_plot_cv_output_index_actual_vs_predicted_out_of_range(ae_multi_output):


# test plots with all cv folds for a single model
def test_plot_cv_model(ae_single_output):
def test_plot_cv_model(ae_single_output, monkeypatch):
monkeypatch.setattr(plt, "show", lambda: None)
fig = ae_single_output.plot_cv(model="gp")
assert isinstance(fig, plt.Figure)
assert len(fig.axes) == 6 # 5 cv folds, but three columns so 6 subplots are made


def test_plot_cv_model_input_index(ae_single_output):
def test_plot_cv_model_input_index(ae_single_output, monkeypatch):
monkeypatch.setattr(plt, "show", lambda: None)
fig = ae_single_output.plot_cv(model="gp", input_index=1)
assert isinstance(fig, plt.Figure)
assert len(fig.axes) == 6


def test_plot_cv_model_output_index(ae_multi_output):
def test_plot_cv_model_output_index(ae_multi_output, monkeypatch):
monkeypatch.setattr(plt, "show", lambda: None)
fig = ae_multi_output.plot_cv(model="gp", output_index=1)
assert isinstance(fig, plt.Figure)
assert len(fig.axes) == 6
Expand All @@ -343,8 +351,9 @@ def test_plot_cv_model_output_index_out_of_range(ae_multi_output):
ae_multi_output.plot_cv(model="gp", output_index=2)


# ------------------------------ test _plot_model ------------------------------
def test__plot_model_int(ae_single_output):
# # ------------------------------ test _plot_model ------------------------------
def test__plot_model_int(ae_single_output, monkeypatch):
monkeypatch.setattr(plt, "show", lambda: None)
fig = _plot_model(
ae_single_output.get_model(name="gp"),
ae_single_output.X,
Expand All @@ -357,7 +366,8 @@ def test__plot_model_int(ae_single_output):
assert all(term in fig.axes[0].get_title() for term in ["X", "y", "vs."])


def test__plot_model_list(ae_single_output):
def test__plot_model_list(ae_single_output, monkeypatch):
monkeypatch.setattr(plt, "show", lambda: None)
fig = _plot_model(
ae_single_output.get_model(name="gp"),
ae_single_output.X,
Expand All @@ -382,7 +392,8 @@ def test__plot_model_int_out_of_range(ae_single_output):
)


def test__plot_model_actual_vs_predicted(ae_single_output):
def test__plot_model_actual_vs_predicted(ae_single_output, monkeypatch):
monkeypatch.setattr(plt, "show", lambda: None)
fig = _plot_model(
ae_single_output.get_model(name="gp"),
ae_single_output.X,
Expand Down

0 comments on commit a344091

Please sign in to comment.