diff --git a/src/gluonts/lab/__init__.py b/src/gluonts/lab/__init__.py new file mode 100644 index 0000000000..ad0b91c549 --- /dev/null +++ b/src/gluonts/lab/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +from .viz import ( + plot_time_series, + plot_univariate_forecast, + plot_multivariate_forecast, + plot_forecast_comparison, +) + +__all__ = [ + "plot_time_series", + "plot_univariate_forecast", + "plot_multivariate_forecast", + "plot_forecast_comparison", +] diff --git a/src/gluonts/lab/helpers.py b/src/gluonts/lab/helpers.py new file mode 100644 index 0000000000..e8f34c6c78 --- /dev/null +++ b/src/gluonts/lab/helpers.py @@ -0,0 +1,116 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +from typing import List, Optional, Collection, Union + +import pandas as pd +from matplotlib import pyplot as plt + +from gluonts.model import Forecast +from gluonts.exceptions import GluonTSUserError + + +def get_percentiles(prediction_intervals): + percentiles_list = [50.0] + [ + 50.0 + sign * percentile / 2.0 + for percentile in prediction_intervals + for sign in [-1.0, +1.0] + ] + return sorted(set(percentiles_list)) + + +def plot_forecast( + forecast: Forecast, + axis: plt.axis, + prediction_intervals: Collection[float], + plot_mean: bool, + color: str, + marker: Optional[str] = None, + label_prefix: str = "", + label_suffix: str = "", +): + """Helper function for plotting a single forecast + + Parameters + ---------- + axis + plt.axis to plot on + forecast + Forecast to plot + ... + """ + + interval_count = len(prediction_intervals) + + percentiles = get_percentiles(prediction_intervals) + predictions = [forecast.quantile(p / 100.0) for p in percentiles] + + if plot_mean: + axis.plot( + forecast.index.to_timestamp(), + forecast.mean, + color=color, + ls=":", + label=f"{label_prefix}mean prediction{label_suffix}", + marker=marker, + ) + + # median prediction + p50_data = predictions[interval_count] + p50_series = pd.Series(data=p50_data, index=forecast.index.to_timestamp()) + label = f"{label_prefix}median prediction" + axis.plot( + p50_series, label=label, linestyle="--", color=color, marker=marker + ) + + # percentile prediction intervals + alphas_lower_half = [(p / 100.0) ** 0.3 for p in percentiles] + alphas = alphas_lower_half + alphas_lower_half[::-1] + for interval_idx in range(interval_count): + p = 100 - percentiles[interval_idx] * 2 + label = f"{label_prefix}{p}% prediction interval" + + # plot lower and upper half of median individually to keep colors true + area_info = [ + {"label": label, "idx": interval_idx}, # give label only once + {"label": None, "idx": interval_count * 2 - interval_idx - 1}, + ] + for info in area_info: + axis.fill_between( + forecast.index.to_timestamp(), + predictions[info["idx"]], + predictions[info["idx"] + 1], + facecolor=color, + alpha=alphas[interval_idx], + interpolate=True, + label=info["label"], + ) + + +def read_input_for_marker_or_color( + value: Union[str, List[str]], entry_count: int, entity_name: str +): + """normalize input for marker/color into list of length `entry_count`""" + result = value + if isinstance(result, str): + result = [result] + if len(result) == 0: + raise GluonTSUserError(f"'{entity_name}' can't be empty list") + + # repeat if necesarry to match entry_count + result_idx = 0 + while len(result) < entry_count: + result.append(result[result_idx]) + result_idx += 1 + + return result diff --git a/src/gluonts/lab/usage_example.py b/src/gluonts/lab/usage_example.py new file mode 100644 index 0000000000..8e6c0306c6 --- /dev/null +++ b/src/gluonts/lab/usage_example.py @@ -0,0 +1,144 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +from matplotlib.dates import DateFormatter + +from gluonts.mx import SimpleFeedForwardEstimator, Trainer, DeepVAREstimator +from gluonts.mx.distribution import MultivariateGaussianOutput +from gluonts.dataset.repository.datasets import get_dataset +from gluonts.dataset.artificial import default_synthetic +from gluonts.dataset.common import TrainDatasets +from gluonts.dataset.multivariate_grouper import MultivariateGrouper +from gluonts.evaluation import make_evaluation_predictions +from gluonts.lab.viz import ( + plot_forecast_comparison, + plot_time_series, + plot_univariate_forecast, + plot_multivariate_forecast, +) + + +def get_m4_ts_and_forecast(): + dataset = get_dataset("m4_hourly") + estimator = SimpleFeedForwardEstimator( + num_hidden_dimensions=[10], + prediction_length=dataset.metadata.prediction_length, + context_length=100, + trainer=Trainer( + ctx="cpu", epochs=5, learning_rate=1e-3, num_batches_per_epoch=100 + ), + ) + predictor = estimator.train(dataset.train) + + forecast_it, ts_it = make_evaluation_predictions( + dataset=dataset.test, + predictor=predictor, + num_samples=100, + ) + forecast_entry = next(iter(forecast_it)) + ts_entry = next(iter(ts_it)) + + return forecast_entry, ts_entry + + +def get_multivariate_ts_and_forecast(): + def load_multivariate_synthetic_dataset(): + dataset_info, train_ds, test_ds = default_synthetic() + grouper_train = MultivariateGrouper(max_target_dim=10) + grouper_test = MultivariateGrouper(num_test_dates=1, max_target_dim=10) + metadata = dataset_info.metadata + metadata.prediction_length = dataset_info.prediction_length + return TrainDatasets( + metadata=dataset_info.metadata, + train=grouper_train(train_ds), + test=grouper_test(test_ds), + ) + + dataset = load_multivariate_synthetic_dataset() + target_dim = int(dataset.metadata.feat_static_cat[0].cardinality) + metadata = dataset.metadata + + estimator = DeepVAREstimator( + num_cells=20, + num_layers=1, + pick_incomplete=True, + prediction_length=metadata.prediction_length, + target_dim=target_dim, + freq=metadata.freq, + distr_output=MultivariateGaussianOutput(dim=target_dim), + scaling=False, + trainer=Trainer( + epochs=10, learning_rate=1e-10, num_batches_per_epoch=1 + ), + ) + + predictor = estimator.train(training_data=dataset.train) + forecast_it, ts_it = make_evaluation_predictions(dataset.test, predictor) + + forecast_entry = next(iter(forecast_it)) + ts_it = next(iter(ts_it)) + return forecast_entry, ts_it + + +def ts_example(): + _, ts_entry = get_m4_ts_and_forecast() + fig, _ = plot_time_series(ts_entry) + + +def univariate_example(): + forecast_entry, ts_entry = get_m4_ts_and_forecast() + + fig, ax = plot_univariate_forecast( + forecast=forecast_entry, + time_series=ts_entry[-100:], + plot_mean=True, + label_prefix="first entry - ", + show_plot=False, + ) + + # do more custom things before showing the plot + ax.tick_params(axis="x", labelrotation=45) + ax.xaxis.set_major_formatter(DateFormatter("%d-%m-%Y %H:%M")) + ax.legend(loc="upper left") + fig.show() + + +def multivariate_example(): + forecast_entry, ts_entry = get_multivariate_ts_and_forecast() + + fig, ax = plot_multivariate_forecast( + forecast=forecast_entry, + time_series=ts_entry, + variates_to_plot=[0, 1, 4], + color=["g", "r", "b"], + marker=["^", "v"], + ) + + +def comparison_example(): + forecast_entry, ts_entry = get_m4_ts_and_forecast() + forecast_entry2, ts_entry2 = get_m4_ts_and_forecast() + forecasts = [forecast_entry, forecast_entry2] + + fig, ax = plot_forecast_comparison( + forecasts=forecasts, + time_series=ts_entry, + show_plot=False, + use_subplots=False, + ) + + +ts_example() +univariate_example() +multivariate_example() +comparison_example() diff --git a/src/gluonts/lab/viz.py b/src/gluonts/lab/viz.py new file mode 100644 index 0000000000..533e937dbc --- /dev/null +++ b/src/gluonts/lab/viz.py @@ -0,0 +1,330 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +from os import PathLike +from typing import List, Optional, Collection, Tuple, Union + +import pandas as pd +from matplotlib import pyplot as plt +from .helpers import plot_forecast, read_input_for_marker_or_color + +from gluonts.model import Forecast +from gluonts.exceptions import GluonTSUserError + + +def plot_time_series( + time_series: Union[pd.Series, pd.DataFrame], + train_test_separator: Optional[pd.Timestamp] = None, + label: str = "target", + color="black", + plt_context: Optional[Tuple[plt.figure, plt.axis]] = None, + save_path: Optional[Union[str, bytes, PathLike]] = None, + show_plot: bool = True, +): + """Plots a univariate or multivariate time series + + Parameters + ---------- + time_series + data points to plot, in Pandas format + train_test_separator + if provided a time stamp, draws a vertical line at this position + label + label to use in call to plot + color + line color, defaults to "black" + plt_context + tuple (fig, ax) to use; if None, new plot will be created + save_path + specifies where to save the plot + show_plot + wether or not to show the plot using matplolibs `show`, defaults to True + """ + if plt_context is None: + plt_context = plt.subplots() + fig, axis = plt_context + + if isinstance(time_series, pd.Series): + time_series = time_series.to_frame() + + axis.plot( + time_series.to_timestamp().index, + time_series.values, + label=label, + color=color, + ) + + if train_test_separator is not None: + axis.axvline(train_test_separator, color="r") + + if save_path: + fig.savefig(save_path) + if show_plot: + fig.show() + + return fig, axis + + +def plot_univariate_forecast( + forecast: Optional[Forecast], + time_series: Optional[Union[pd.DataFrame, pd.Series]] = None, + prediction_intervals: Collection[float] = (50.0, 90.0), + plot_mean: bool = False, + figsize: Tuple[int] = (10, 10), + xlabel: str = "Time", + ylabel: str = "Value", + label_prefix: str = "", + label_suffix: str = "", + color: str = "g", + marker: Optional[str] = None, + legend_location: str = "upper left", + plot_grid: bool = True, + show_plot: bool = True, + save_path: Optional[Union[str, bytes, PathLike]] = None, +) -> Tuple[plt.figure, plt.axis]: + """Plots prediction intervals of a single probabilistic forecast + + Parameters + ---------- + forecast + Forecast to plot + time_series + ground truth for comparison, can start at earlier time than forecast + prediction_intervals + a collection of numbers between 0 and 100 specifying what prediction + intervals to plot - the larger a value, the fainter the color + plot_mean + wether or not to plot the forecast mean, defaults to False + ... + """ + + for c in prediction_intervals: + if not 0.0 <= c <= 100.0: + raise GluonTSUserError( + f"Prediction interval {c} is not between 0 and 100" + ) + + dim_count = forecast.dim() + assert ( + dim_count == 1 + ), f"Expected univariate forecast but got {dim_count} dimensions(s)" + + fig, axis = plt.subplots(1, 1, figsize=figsize) + + if time_series is not None: + plot_time_series( + time_series=time_series, + label=f"{label_prefix}target{label_suffix}", + plt_context=(fig, axis), + ) + + plot_forecast( + forecast=forecast, + prediction_intervals=prediction_intervals, + axis=axis, + plot_mean=plot_mean, + color=color, + marker=marker, + label_prefix=label_prefix, + ) + + axis.legend(loc=legend_location) + axis.set_xlabel(xlabel) + axis.set_ylabel(ylabel) + if plot_grid: + axis.grid(which="both") + + if save_path: + fig.savefig(save_path) + if show_plot: + fig.show() + + return fig, axis + + +def plot_multivariate_forecast( + forecast: Forecast, + time_series: Optional[Union[pd.DataFrame, pd.Series]] = None, + prediction_intervals: Collection[float] = (50.0, 90.0), + variates_to_plot: Optional[List[int]] = None, + plot_mean: bool = False, + figsize: Tuple[int] = (10, 10), + xlabel: str = "Time", + ylabel: str = "Value", + label_prefix: Optional[str] = None, + legend_location: str = "upper left", + plot_grid: bool = True, + color: Union[str, List[str]] = "g", + marker: Union[str, List[str]] = "o", + use_subplots: bool = True, + show_plot: bool = True, + save_path: Optional[Union[str, bytes, PathLike]] = None, +): + def dim_suffix(dim: int) -> str: + return f" (dim {dim})" + + for c in prediction_intervals: + if not 0.0 <= c <= 100.0: + raise GluonTSUserError( + f"Prediction interval {c} is not between 0 and 100" + ) + + dim_count = forecast.dim() + assert ( + dim_count > 1 + ), f"Expected multivariate forecast but got {dim_count} dimension(s)" + + dim_count = forecast.dim() + if variates_to_plot is None: + variates_to_plot = list(range(dim_count)) + variates_to_plot = sorted(set(variates_to_plot)) + + if not all(0 <= dim < dim_count for dim in variates_to_plot): + raise GluonTSUserError( + "Each dim in variates_to_plot must be in range " + "0 <= dim < forecast.dim()" + ) + + color = read_input_for_marker_or_color(color, dim_count, "color") + plot_markers = marker is not None + if plot_markers: + marker = read_input_for_marker_or_color(marker, dim_count, "marker") + + label_prefix = "" if label_prefix is None else label_prefix + + subplot_count = len(variates_to_plot) + if use_subplots: + fig, axes = plt.subplots( + subplot_count, 1, figsize=figsize, sharex=True, sharey=True + ) + else: + fig, axis = plt.subplots(1, 1, figsize=figsize) + axes = [axis] * subplot_count # always use the same axis to draw on + + if time_series is not None: + for axis, dim in zip(axes, variates_to_plot): + plot_time_series( + time_series=time_series[dim], + label=f"{label_prefix}target{dim_suffix(dim)}", + plt_context=(fig, axis), + ) + + for axis, dim in zip(axes, variates_to_plot): + label_suffix = dim_suffix(dim) + plot_forecast( + forecast=forecast.copy_dim(dim), + prediction_intervals=prediction_intervals, + axis=axis, + plot_mean=plot_mean, + color=color[dim], + label_prefix=label_prefix, + label_suffix=label_suffix, + marker=marker[dim] if plot_markers else None, + ) + + fig.supxlabel(xlabel) + fig.supylabel(ylabel) + + for axis in axes: + handles, labels = axis.get_legend_handles_labels() + axis.legend(handles, labels, loc=legend_location) + if plot_grid: + axis.grid(which="both") + + if save_path: + fig.savefig(save_path) + if show_plot: + fig.show() + + return fig, axes + + +def plot_forecast_comparison( + forecasts: Collection[Forecast], + time_series: Optional[Union[pd.DataFrame, pd.Series]] = None, + prediction_intervals: Collection[float] = (50.0, 90.0), + plot_mean: bool = False, + figsize: Tuple[int] = (10, 10), + xlabel: str = "Time", + ylabel: str = "Value", + label_prefix: Optional[str] = None, + legend_location: str = "upper left", + plot_grid: bool = True, + color: Union[str, List[str]] = "g", + marker: Optional[List[str]] = None, + use_subplots: bool = True, + show_plot: bool = True, + save_path: Optional[Union[str, bytes, PathLike]] = None, +): + for c in prediction_intervals: + if not 0.0 <= c <= 100.0: + raise GluonTSUserError( + f"Prediction interval {c} is not between 0 and 100" + ) + + forecast_count = len(forecasts) + + color = read_input_for_marker_or_color(color, forecast_count, "color") + + plot_markers = marker is not None + if plot_markers: + marker = read_input_for_marker_or_color( + marker, forecast_count, "marker" + ) + + label_prefix = "" if label_prefix is None else label_prefix + + if use_subplots: + fig, axes = plt.subplots( + forecast_count, 1, figsize=figsize, sharex=True, sharey=True + ) + else: + fig, axis = plt.subplots(1, 1, figsize=figsize) + axes = [axis] * forecast_count # always use the same axis to draw on + + if time_series is not None: + for axis, forecast_id in zip(axes, range(forecast_count)): + plot_time_series( + time_series=time_series, + label=f"{label_prefix}target", + plt_context=(fig, axis), + ) + + if forecasts is not None: + for axis, forecast_id in zip(axes, range(forecast_count)): + plot_forecast( + forecast=forecasts[forecast_id], + prediction_intervals=prediction_intervals, + axis=axis, + plot_mean=plot_mean, + color=color[forecast_id], + label_prefix=label_prefix, + label_suffix=str(forecast_id), + marker=marker[forecast_id] if plot_markers else None, + ) + + fig.supxlabel(xlabel) + fig.supylabel(ylabel) + + for axis in axes: + handles, labels = axis.get_legend_handles_labels() + axis.legend(handles, labels, loc=legend_location) + if plot_grid: + axis.grid(which="both") + + if save_path: + fig.savefig(save_path) + if show_plot: + fig.show() + + return fig, axes