def hpd(
    ary,
    credible_interval=None,
    circular=False,
    multimodal=False,
    skipna=False,
    group="posterior",
    var_names=None,
    coords=None,
    max_modes=10,
    **kwargs
):
    """
    Calculate highest posterior density (HPD) of array for given credible_interval.

    Parameters
    ----------
    ary : obj
        Object containing posterior samples.
        Any object that can be converted to an az.InferenceData object.
        Refer to documentation of az.convert_to_dataset for details.
    credible_interval : float, optional
        Credible interval to compute. Must lie in the interval (0, 1]. When None,
        ``rcParams["stats.credible_interval"]`` is used.
    circular : bool, optional
        Whether to treat the samples as a circular variable in the range
        [-np.pi, np.pi]. Only used when ``multimodal`` is False.
    multimodal : bool
        If true, the density of the samples is estimated and one interval is
        returned per contiguous high-density region (mode), up to ``max_modes``
        intervals.
    skipna : bool
        If true ignores nan values when computing the hpd interval. Defaults to false.
    group : str, optional
        Specifies which InferenceData group should be used to calculate hpd.
        Defaults to 'posterior'
    var_names : list, optional
        Names of variables to include in the hpd report
    coords : mapping, optional
        Specifies the subset over which to calculate hpd.
    max_modes : int, optional
        Specifies the maximum number of modes for the multimodal case.
    kwargs : dict, optional
        Additional keywords passed to `wrap_xarray_ufunc`.
        See the docstring of :obj:`wrap_xarray_ufunc method <wrap_xarray_ufunc>`.

    Returns
    -------
    np.ndarray or xarray.Dataset, depending upon input
        lower(s) and upper(s) values of the interval(s).

    Raises
    ------
    ValueError
        If ``credible_interval`` is outside (0, 1], or if no candidate interval
        exists for the unimodal computation (empty input, or
        ``credible_interval`` equal to 1, for which the window spans every sample).

    Examples
    --------
    Calculate the hpd of a Normal random variable:

    .. ipython::

        In [1]: import arviz as az
           ...: import numpy as np
           ...: data = np.random.normal(size=2000)
           ...: az.hpd(data, credible_interval=.68)

    Calculate the hpd of a dataset:

    .. ipython::

        In [1]: import arviz as az
           ...: data = az.load_arviz_data('centered_eight')
           ...: az.hpd(data)

    We can also calculate the hpd of some of the variables of dataset:

    .. ipython::

        In [1]: az.hpd(data, var_names=["mu", "theta"])

    If we want to calculate the hpd over specified dimension of dataset,
    we can pass `input_core_dims` by kwargs:

    .. ipython::

        In [1]: az.hpd(data, input_core_dims = [["chain"]])

    We can also calculate the hpd over a particular selection of coordinates:

    .. ipython::

        In [1]: az.hpd(data, coords={"chain":[0, 1, 3]}, input_core_dims = [["draw"]])

    """
    if credible_interval is None:
        credible_interval = rcParams["stats.credible_interval"]

    if not 1 >= credible_interval > 0:
        raise ValueError("The value of credible_interval should be in the interval (0, 1]")

    func_kwargs = {
        "credible_interval": credible_interval,
        "skipna": skipna,
        "out_shape": (max_modes, 2) if multimodal else (2,),
    }
    # Core dims must be listed in the same order as the axes of ``out_shape``:
    # the helper returns (max_modes, 2), i.e. one (lower, upper) row per mode.
    kwargs.setdefault("output_core_dims", [["mode", "hpd"] if multimodal else ["hpd"]])
    if multimodal:
        func_kwargs["max_modes"] = max_modes
        func = _hpd_multimodal
    else:
        func_kwargs["circular"] = circular
        func = _hpd

    isarray = isinstance(ary, np.ndarray)
    if isarray and ary.ndim <= 1:
        # Direct call: ``out_shape`` is only meaningful to the xarray wrapper,
        # the helpers themselves do not accept it.
        func_kwargs.pop("out_shape")
        hpd_data = func(ary, **func_kwargs)  # pylint: disable=unexpected-keyword-arg
        # The multimodal helper pads unused modes with NaN rows; drop them.
        return hpd_data[~np.isnan(hpd_data).all(axis=1), :] if multimodal else hpd_data

    if isarray and ary.ndim == 2:
        # Reduce only over the dimension that the conversion below names "chain",
        # so a (samples, variables) array yields one interval per column.
        kwargs.setdefault("input_core_dims", [["chain"]])

    ary = convert_to_dataset(ary, group=group)
    if coords is not None:
        ary = get_coords(ary, coords)
    var_names = _var_names(var_names, ary)
    ary = ary[var_names] if var_names else ary

    hpd_data = _wrap_xarray_ufunc(func, ary, func_kwargs=func_kwargs, **kwargs)
    # Remove modes that are NaN padding for every variable.
    hpd_data = hpd_data.dropna("mode", how="all") if multimodal else hpd_data
    # Array input was stored under the variable name "x"; hand back a plain array.
    return hpd_data.x.values if isarray else hpd_data


def _hpd(ary, credible_interval, circular, skipna):
    """Compute the narrowest interval holding ``credible_interval`` of the flattened array.

    Returns a length-2 array ``[lower, upper]``. Raises ``ValueError`` when there
    is no candidate window (empty input or a window spanning all samples).
    """
    ary = ary.flatten()
    if skipna:
        nans = np.isnan(ary)
        # An all-NaN input is left untouched rather than emptied.
        if not nans.all():
            ary = ary[~nans]
    n = len(ary)

    if circular:
        # Centre on the circular mean and wrap into [-pi, pi] so that sorting
        # is meaningful; the shift is undone on the result below.
        mean = st.circmean(ary, high=np.pi, low=-np.pi)
        ary = ary - mean
        ary = np.arctan2(np.sin(ary), np.cos(ary))

    ary = np.sort(ary)
    interval_idx_inc = int(np.floor(credible_interval * n))
    n_intervals = n - interval_idx_inc
    # Width of every window covering ``interval_idx_inc`` consecutive sorted samples.
    interval_width = ary[interval_idx_inc:] - ary[:n_intervals]

    if len(interval_width) == 0:
        raise ValueError("Too few elements for interval calculation. ")

    min_idx = np.argmin(interval_width)
    hdi_min = ary[min_idx]
    hdi_max = ary[min_idx + interval_idx_inc]

    if circular:
        hdi_min = hdi_min + mean
        hdi_max = hdi_max + mean
        hdi_min = np.arctan2(np.sin(hdi_min), np.cos(hdi_min))
        hdi_max = np.arctan2(np.sin(hdi_max), np.cos(hdi_max))

    return np.array([hdi_min, hdi_max])


def _hpd_multimodal(ary, credible_interval, skipna, max_modes):
    """Compute one hpd interval per mode of the flattened array.

    Returns an array of shape ``(max_modes, 2)``; rows beyond the number of
    modes found are filled with NaN. A warning is emitted when more than
    ``max_modes`` modes are found and the surplus is discarded.
    """
    ary = ary.flatten()
    if skipna:
        ary = ary[~np.isnan(ary)]

    if ary.dtype.kind == "f":
        density, lower, upper = _fast_kde(ary)
        range_x = upper - lower
        dx = range_x / len(density)
        bins = np.linspace(lower, upper, len(density))
    else:
        bins = get_bins(ary)
        _, density, _ = histogram(ary, bins=bins)
        dx = np.diff(bins)[0]

    density *= dx

    # Keep the highest-density grid points until their mass reaches the target.
    idx = np.argsort(-density)
    intervals = bins[idx][density[idx].cumsum() <= credible_interval]
    intervals.sort()

    # A gap larger than one grid step separates two modes.
    intervals_splitted = np.split(intervals, np.where(np.diff(intervals) >= dx * 1.1)[0] + 1)

    hpd_intervals = np.full((max_modes, 2), np.nan)
    for i, interval in enumerate(intervals_splitted):
        if i == max_modes:
            # The message must be formatted here: the second positional argument
            # of warnings.warn is the warning category, not a format argument.
            warnings.warn(
                "found more modes than {0}, returning only the first {0} modes".format(max_modes)
            )
            break
        if interval.size == 0:
            hpd_intervals[i] = np.asarray([bins[0], bins[0]])
        else:
            hpd_intervals[i] = np.asarray([interval[0], interval[-1]])

    return hpd_intervals