From d29cefbbd7879d340e61012abd443ec19e35570b Mon Sep 17 00:00:00 2001 From: percygautam Date: Thu, 12 Mar 2020 01:34:06 +0530 Subject: [PATCH 01/21] revamped hpd function --- arviz/stats/stats.py | 146 +++++++++++++++++++++---------------------- 1 file changed, 71 insertions(+), 75 deletions(-) diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index d29fff70e4..743e53953e 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -305,7 +305,7 @@ def _ic_matrix(ics, ic_i): return rows, cols, ic_i_val -def hpd(ary, credible_interval=None, circular=False, multimodal=False, skipna=False): +def hpd(ary, *, credible_interval=None, circular=False, multimodal=False, skipna=False): """ Calculate highest posterior density (HPD) of array for given credible_interval. @@ -343,91 +343,87 @@ def hpd(ary, credible_interval=None, circular=False, multimodal=False, skipna=Fa ...: data = np.random.normal(size=2000) ...: az.hpd(data, credible_interval=.68) """ - if credible_interval is None: - credible_interval = rcParams["stats.credible_interval"] - else: - if not 1 >= credible_interval > 0: - raise ValueError("The value of credible_interval should be in the interval (0, 1]") - if ary.ndim > 1: - hpd_array = np.array( - [ - hpd( - row, - credible_interval=credible_interval, - circular=circular, - multimodal=multimodal, - ) - for row in ary.T - ] - ) - return hpd_array + if isinstance(ary, np.ndarray): + ary = np.atleast_2d(ary) + ary = convert_to_dataset(ary, group="posterior") - if multimodal: - if skipna: - ary = ary[~np.isnan(ary)] + def _hpd(ary=ary, credible_interval=credible_interval, circular=circular, multimodal=multimodal, skipna=skipna): - if ary.dtype.kind == "f": - density, lower, upper = _fast_kde(ary) - range_x = upper - lower - dx = range_x / len(density) - bins = np.linspace(lower, upper, len(density)) + if credible_interval is None: + credible_interval = rcParams["stats.credible_interval"] else: - bins = get_bins(ary) - _, density, _ = histogram(ary, 
bins=bins) - dx = np.diff(bins)[0] - - density *= dx - - idx = np.argsort(-density) - intervals = bins[idx][density[idx].cumsum() <= credible_interval] - intervals.sort() - - intervals_splitted = np.split(intervals, np.where(np.diff(intervals) >= dx * 1.1)[0] + 1) - - hpd_intervals = [] - for interval in intervals_splitted: - if interval.size == 0: - hpd_intervals.append((bins[0], bins[0])) + if not 1 >= credible_interval > 0: + raise ValueError("The value of credible_interval should be in the interval (0, 1]") + + if multimodal: + if skipna: + ary = ary[~np.isnan(ary)] + + if ary.dtype.kind == "f": + density, lower, upper = _fast_kde(ary) + range_x = upper - lower + dx = range_x / len(density) + bins = np.linspace(lower, upper, len(density)) else: - hpd_intervals.append((interval[0], interval[-1])) + bins = get_bins(ary) + _, density, _ = histogram(ary, bins=bins) + dx = np.diff(bins)[0] - hpd_intervals = np.array(hpd_intervals) - - else: - if skipna: - nans = np.isnan(ary) - if not nans.all(): - ary = ary[~nans] - n = len(ary) + density *= dx - if circular: - mean = st.circmean(ary, high=np.pi, low=-np.pi) - ary = ary - mean - ary = np.arctan2(np.sin(ary), np.cos(ary)) + idx = np.argsort(-density) + intervals = bins[idx][density[idx].cumsum() <= credible_interval] + intervals.sort() - ary = np.sort(ary) - interval_idx_inc = int(np.floor(credible_interval * n)) - n_intervals = n - interval_idx_inc - interval_width = ary[interval_idx_inc:] - ary[:n_intervals] + intervals_splitted = np.split(intervals, np.where(np.diff(intervals) >= dx * 1.1)[0] + 1) - if len(interval_width) == 0: - raise ValueError("Too few elements for interval calculation. 
") + hpd_intervals = [] + for interval in intervals_splitted: + if interval.size == 0: + hpd_intervals.append((bins[0], bins[0])) + else: + hpd_intervals.append((interval[0], interval[-1])) - min_idx = np.argmin(interval_width) - hdi_min = ary[min_idx] - hdi_max = ary[min_idx + interval_idx_inc] - - if circular: - hdi_min = hdi_min + mean - hdi_max = hdi_max + mean - hdi_min = np.arctan2(np.sin(hdi_min), np.cos(hdi_min)) - hdi_max = np.arctan2(np.sin(hdi_max), np.cos(hdi_max)) - - hpd_intervals = np.array([hdi_min, hdi_max]) - - return hpd_intervals + hpd_intervals = np.array(hpd_intervals) + else: + if skipna: + nans = np.isnan(ary) + if not nans.all(): + ary = ary[~nans] + n = len(ary) + + if circular: + mean = st.circmean(ary, high=np.pi, low=-np.pi) + ary = ary - mean + ary = np.arctan2(np.sin(ary), np.cos(ary)) + + ary = np.sort(ary) + interval_idx_inc = int(np.floor(credible_interval * n)) + n_intervals = n - interval_idx_inc + interval_width = ary[interval_idx_inc:] - ary[:n_intervals] + + if len(interval_width) == 0: + raise ValueError("Too few elements for interval calculation. ") + + min_idx = np.argmin(interval_width) + hdi_min = ary[min_idx] + hdi_max = ary[min_idx + interval_idx_inc] + + if circular: + hdi_min = hdi_min + mean + hdi_max = hdi_max + mean + hdi_min = np.arctan2(np.sin(hdi_min), np.cos(hdi_min)) + hdi_max = np.arctan2(np.sin(hdi_max), np.cos(hdi_max)) + + hpd_intervals = np.array([hdi_min, hdi_max]) + + return hpd_intervals + + func_kwargs = {"out_shape": (2,)} + kwargs = {"input_core_dims": [["chain"]], "output_core_dims": [['hpd']]} + return _wrap_xarray_ufunc(_hpd, ary, func_kwargs=func_kwargs, **kwargs).to_array().values[0] def loo(data, pointwise=False, reff=None, scale=None): """Pareto-smoothed importance sampling leave-one-out cross-validation. 
From 2bb5646389640407f33a8167e0fedf815bdb456e Mon Sep 17 00:00:00 2001 From: percygautam Date: Sat, 14 Mar 2020 01:35:18 +0530 Subject: [PATCH 02/21] corrected hpd errors --- arviz/stats/stats.py | 151 +++++++++++++++++++++++-------------------- 1 file changed, 80 insertions(+), 71 deletions(-) diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index 743e53953e..c5f7edbf10 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -305,7 +305,7 @@ def _ic_matrix(ics, ic_i): return rows, cols, ic_i_val -def hpd(ary, *, credible_interval=None, circular=False, multimodal=False, skipna=False): +def hpd(ary, *, group =None, credible_interval=None, circular=False, multimodal=False, skipna=False): """ Calculate highest posterior density (HPD) of array for given credible_interval. @@ -343,87 +343,96 @@ def hpd(ary, *, credible_interval=None, circular=False, multimodal=False, skipna ...: data = np.random.normal(size=2000) ...: az.hpd(data, credible_interval=.68) """ - + if group is None: + group = ['chain', 'draw'] if isinstance(ary, np.ndarray): - ary = np.atleast_2d(ary) - ary = convert_to_dataset(ary, group="posterior") + data = convert_to_dataset(ary) + + kwargs = {"input_core_dims": [group], "output_core_dims": [['hpd']]} + func_kwargs = { + "credible_interval": credible_interval, + "circular": circular, + "multimodal": multimodal, + "skipna": skipna, + "out_shape": (2,), + } + return _wrap_xarray_ufunc(_hpd, data, func_kwargs=func_kwargs, **kwargs) + + +def _hpd(ary, credible_interval, circular, multimodal, skipna): + """Compute hpd over the flattened array.""" + ary = ary.flatten() + if credible_interval is None: + credible_interval = rcParams["stats.credible_interval"] + else: + if not 1 >= credible_interval > 0: + raise ValueError("The value of credible_interval should be in the interval (0, 1]") - def _hpd(ary=ary, credible_interval=credible_interval, circular=circular, multimodal=multimodal, skipna=skipna): + if multimodal: + if skipna: + ary = 
ary[~np.isnan(ary)] - if credible_interval is None: - credible_interval = rcParams["stats.credible_interval"] + if ary.dtype.kind == "f": + density, lower, upper = _fast_kde(ary) + range_x = upper - lower + dx = range_x / len(density) + bins = np.linspace(lower, upper, len(density)) else: - if not 1 >= credible_interval > 0: - raise ValueError("The value of credible_interval should be in the interval (0, 1]") - - if multimodal: - if skipna: - ary = ary[~np.isnan(ary)] - - if ary.dtype.kind == "f": - density, lower, upper = _fast_kde(ary) - range_x = upper - lower - dx = range_x / len(density) - bins = np.linspace(lower, upper, len(density)) + bins = get_bins(ary) + _, density, _ = histogram(ary, bins=bins) + dx = np.diff(bins)[0] + + density *= dx + + idx = np.argsort(-density) + intervals = bins[idx][density[idx].cumsum() <= credible_interval] + intervals.sort() + + intervals_splitted = np.split(intervals, np.where(np.diff(intervals) >= dx * 1.1)[0] + 1) + + hpd_intervals = [] + for interval in intervals_splitted: + if interval.size == 0: + hpd_intervals.append((bins[0], bins[0])) else: - bins = get_bins(ary) - _, density, _ = histogram(ary, bins=bins) - dx = np.diff(bins)[0] + hpd_intervals.append((interval[0], interval[-1])) - density *= dx + hpd_intervals = np.array(hpd_intervals) - idx = np.argsort(-density) - intervals = bins[idx][density[idx].cumsum() <= credible_interval] - intervals.sort() + else: + if skipna: + nans = np.isnan(ary) + if not nans.all(): + ary = ary[~nans] + n = len(ary) - intervals_splitted = np.split(intervals, np.where(np.diff(intervals) >= dx * 1.1)[0] + 1) + if circular: + mean = st.circmean(ary, high=np.pi, low=-np.pi) + ary = ary - mean + ary = np.arctan2(np.sin(ary), np.cos(ary)) - hpd_intervals = [] - for interval in intervals_splitted: - if interval.size == 0: - hpd_intervals.append((bins[0], bins[0])) - else: - hpd_intervals.append((interval[0], interval[-1])) + ary = np.sort(ary) + interval_idx_inc = 
int(np.floor(credible_interval * n)) + n_intervals = n - interval_idx_inc + interval_width = ary[interval_idx_inc:] - ary[:n_intervals] - hpd_intervals = np.array(hpd_intervals) + if len(interval_width) == 0: + raise ValueError("Too few elements for interval calculation. ") + + min_idx = np.argmin(interval_width) + hdi_min = ary[min_idx] + hdi_max = ary[min_idx + interval_idx_inc] + + if circular: + hdi_min = hdi_min + mean + hdi_max = hdi_max + mean + hdi_min = np.arctan2(np.sin(hdi_min), np.cos(hdi_min)) + hdi_max = np.arctan2(np.sin(hdi_max), np.cos(hdi_max)) + + hpd_intervals = np.array([hdi_min, hdi_max]) + + return hpd_intervals - else: - if skipna: - nans = np.isnan(ary) - if not nans.all(): - ary = ary[~nans] - n = len(ary) - - if circular: - mean = st.circmean(ary, high=np.pi, low=-np.pi) - ary = ary - mean - ary = np.arctan2(np.sin(ary), np.cos(ary)) - - ary = np.sort(ary) - interval_idx_inc = int(np.floor(credible_interval * n)) - n_intervals = n - interval_idx_inc - interval_width = ary[interval_idx_inc:] - ary[:n_intervals] - - if len(interval_width) == 0: - raise ValueError("Too few elements for interval calculation. ") - - min_idx = np.argmin(interval_width) - hdi_min = ary[min_idx] - hdi_max = ary[min_idx + interval_idx_inc] - - if circular: - hdi_min = hdi_min + mean - hdi_max = hdi_max + mean - hdi_min = np.arctan2(np.sin(hdi_min), np.cos(hdi_min)) - hdi_max = np.arctan2(np.sin(hdi_max), np.cos(hdi_max)) - - hpd_intervals = np.array([hdi_min, hdi_max]) - - return hpd_intervals - - func_kwargs = {"out_shape": (2,)} - kwargs = {"input_core_dims": [["chain"]], "output_core_dims": [['hpd']]} - return _wrap_xarray_ufunc(_hpd, ary, func_kwargs=func_kwargs, **kwargs).to_array().values[0] def loo(data, pointwise=False, reff=None, scale=None): """Pareto-smoothed importance sampling leave-one-out cross-validation. 
From 7de2518e7453e949565774a6a707406b7378db14 Mon Sep 17 00:00:00 2001 From: percygautam Date: Sat, 14 Mar 2020 01:38:08 +0530 Subject: [PATCH 03/21] add docstring --- arviz/stats/stats.py | 68 +++++++++++++++++++++++--------------------- 1 file changed, 35 insertions(+), 33 deletions(-) diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index c5f7edbf10..cdd568dd4b 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -44,7 +44,7 @@ def compare( - dataset_dict, ic=None, method="BB-pseudo-BMA", b_samples=1000, alpha=1, seed=None, scale=None + dataset_dict, ic=None, method="BB-pseudo-BMA", b_samples=1000, alpha=1, seed=None, scale=None ): r"""Compare models based on PSIS-LOO `loo` or WAIC `waic` cross-validation. @@ -305,7 +305,7 @@ def _ic_matrix(ics, ic_i): return rows, cols, ic_i_val -def hpd(ary, *, group =None, credible_interval=None, circular=False, multimodal=False, skipna=False): +def hpd(ary, *, group=None, credible_interval=None, circular=False, multimodal=False, skipna=False): """ Calculate highest posterior density (HPD) of array for given credible_interval. @@ -315,6 +315,8 @@ def hpd(ary, *, group =None, credible_interval=None, circular=False, multimodal= ---------- ary : Numpy array An array containing posterior samples + group : List + An list containing the dimensions to compute hpd credible_interval : float, optional Credible interval to compute. Defaults to 0.94. 
circular : bool, optional @@ -525,7 +527,7 @@ def loo(data, pointwise=False, reff=None, scale=None): ess_p = ess(posterior, method="mean") # this mean is over all data variables reff = ( - np.hstack([ess_p[v].values.flatten() for v in ess_p.data_vars]).mean() / n_samples + np.hstack([ess_p[v].values.flatten() for v in ess_p.data_vars]).mean() / n_samples ) log_weights, pareto_shape = psislw(-log_likelihood, reff) @@ -811,20 +813,20 @@ def r2_score(y_true, y_pred): def summary( - data, - var_names: Optional[List[str]] = None, - fmt: str = "wide", - kind: str = "all", - round_to=None, - include_circ=None, - stat_funcs=None, - extend=True, - credible_interval=None, - order="C", - index_origin=None, - skipna=False, - coords: Optional[CoordSpec] = None, - dims: Optional[DimSpec] = None, + data, + var_names: Optional[List[str]] = None, + fmt: str = "wide", + kind: str = "all", + round_to=None, + include_circ=None, + stat_funcs=None, + extend=True, + credible_interval=None, + order="C", + index_origin=None, + skipna=False, + coords: Optional[CoordSpec] = None, + dims: Optional[DimSpec] = None, ) -> Union[pd.DataFrame, xr.Dataset]: """Create a data frame with summary statistics. @@ -1407,22 +1409,22 @@ def _loo_pit(y, y_hat, log_weights): def apply_test_function( - idata, - func, - group="both", - var_names=None, - pointwise=False, - out_data_shape=None, - out_pp_shape=None, - out_name_data="T", - out_name_pp=None, - func_args=None, - func_kwargs=None, - ufunc_kwargs=None, - wrap_data_kwargs=None, - wrap_pp_kwargs=None, - inplace=True, - overwrite=None, + idata, + func, + group="both", + var_names=None, + pointwise=False, + out_data_shape=None, + out_pp_shape=None, + out_name_data="T", + out_name_pp=None, + func_args=None, + func_kwargs=None, + ufunc_kwargs=None, + wrap_data_kwargs=None, + wrap_pp_kwargs=None, + inplace=True, + overwrite=None, ): """Apply a Bayesian test function to an InferenceData object. 
From e4d83180851a61ae282face829e62361e58d912a Mon Sep 17 00:00:00 2001 From: percygautam Date: Sat, 14 Mar 2020 01:40:16 +0530 Subject: [PATCH 04/21] linting change --- arviz/stats/stats.py | 68 ++++++++++++++++++++++---------------------- 1 file changed, 34 insertions(+), 34 deletions(-) diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index cdd568dd4b..c99079bab3 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -44,7 +44,7 @@ def compare( - dataset_dict, ic=None, method="BB-pseudo-BMA", b_samples=1000, alpha=1, seed=None, scale=None + dataset_dict, ic=None, method="BB-pseudo-BMA", b_samples=1000, alpha=1, seed=None, scale=None ): r"""Compare models based on PSIS-LOO `loo` or WAIC `waic` cross-validation. @@ -346,11 +346,11 @@ def hpd(ary, *, group=None, credible_interval=None, circular=False, multimodal=F ...: az.hpd(data, credible_interval=.68) """ if group is None: - group = ['chain', 'draw'] + group = ["chain", "draw"] if isinstance(ary, np.ndarray): data = convert_to_dataset(ary) - kwargs = {"input_core_dims": [group], "output_core_dims": [['hpd']]} + kwargs = {"input_core_dims": [group], "output_core_dims": [["hpd"]]} func_kwargs = { "credible_interval": credible_interval, "circular": circular, @@ -527,7 +527,7 @@ def loo(data, pointwise=False, reff=None, scale=None): ess_p = ess(posterior, method="mean") # this mean is over all data variables reff = ( - np.hstack([ess_p[v].values.flatten() for v in ess_p.data_vars]).mean() / n_samples + np.hstack([ess_p[v].values.flatten() for v in ess_p.data_vars]).mean() / n_samples ) log_weights, pareto_shape = psislw(-log_likelihood, reff) @@ -813,20 +813,20 @@ def r2_score(y_true, y_pred): def summary( - data, - var_names: Optional[List[str]] = None, - fmt: str = "wide", - kind: str = "all", - round_to=None, - include_circ=None, - stat_funcs=None, - extend=True, - credible_interval=None, - order="C", - index_origin=None, - skipna=False, - coords: Optional[CoordSpec] = None, - dims: 
Optional[DimSpec] = None, + data, + var_names: Optional[List[str]] = None, + fmt: str = "wide", + kind: str = "all", + round_to=None, + include_circ=None, + stat_funcs=None, + extend=True, + credible_interval=None, + order="C", + index_origin=None, + skipna=False, + coords: Optional[CoordSpec] = None, + dims: Optional[DimSpec] = None, ) -> Union[pd.DataFrame, xr.Dataset]: """Create a data frame with summary statistics. @@ -1409,22 +1409,22 @@ def _loo_pit(y, y_hat, log_weights): def apply_test_function( - idata, - func, - group="both", - var_names=None, - pointwise=False, - out_data_shape=None, - out_pp_shape=None, - out_name_data="T", - out_name_pp=None, - func_args=None, - func_kwargs=None, - ufunc_kwargs=None, - wrap_data_kwargs=None, - wrap_pp_kwargs=None, - inplace=True, - overwrite=None, + idata, + func, + group="both", + var_names=None, + pointwise=False, + out_data_shape=None, + out_pp_shape=None, + out_name_data="T", + out_name_pp=None, + func_args=None, + func_kwargs=None, + ufunc_kwargs=None, + wrap_data_kwargs=None, + wrap_pp_kwargs=None, + inplace=True, + overwrite=None, ): """Apply a Bayesian test function to an InferenceData object. From ae2e59ded65b408549dc1078e08aa29ab9a94fe3 Mon Sep 17 00:00:00 2001 From: percygautam Date: Sat, 14 Mar 2020 17:38:02 +0530 Subject: [PATCH 05/21] minor nits --- arviz/stats/stats.py | 131 ++++++++++++++------------- arviz/tests/base_tests/test_stats.py | 6 +- 2 files changed, 70 insertions(+), 67 deletions(-) diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index c99079bab3..9cc7ed122b 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -305,7 +305,7 @@ def _ic_matrix(ics, ic_i): return rows, cols, ic_i_val -def hpd(ary, *, group=None, credible_interval=None, circular=False, multimodal=False, skipna=False): +def hpd(ary, *, credible_interval=None, circular=False, multimodal=False, skipna=False, **kwargs): """ Calculate highest posterior density (HPD) of array for given credible_interval. 
@@ -345,95 +345,98 @@ def hpd(ary, *, group=None, credible_interval=None, circular=False, multimodal=F ...: data = np.random.normal(size=2000) ...: az.hpd(data, credible_interval=.68) """ - if group is None: - group = ["chain", "draw"] + if credible_interval is None: + credible_interval = rcParams["stats.credible_interval"] + else: + if not 1 >= credible_interval > 0: + raise ValueError("The value of credible_interval should be in the interval (0, 1]") + + if multimodal: + return _hpd_multimodal(ary, credible_interval, skipna) + if isinstance(ary, np.ndarray): - data = convert_to_dataset(ary) + ary = convert_to_dataset(ary) - kwargs = {"input_core_dims": [group], "output_core_dims": [["hpd"]]} + kwargs.setdefault("input_core_dims", [["chain", "draw"]]) + kwargs.setdefault("output_core_dims", [["hpd"]]) func_kwargs = { "credible_interval": credible_interval, "circular": circular, - "multimodal": multimodal, "skipna": skipna, "out_shape": (2,), } - return _wrap_xarray_ufunc(_hpd, data, func_kwargs=func_kwargs, **kwargs) + return _wrap_xarray_ufunc(_hpd, ary, func_kwargs=func_kwargs, **kwargs) -def _hpd(ary, credible_interval, circular, multimodal, skipna): +def _hpd(ary, credible_interval, circular, skipna): """Compute hpd over the flattened array.""" ary = ary.flatten() - if credible_interval is None: - credible_interval = rcParams["stats.credible_interval"] - else: - if not 1 >= credible_interval > 0: - raise ValueError("The value of credible_interval should be in the interval (0, 1]") + if skipna: + nans = np.isnan(ary) + if not nans.all(): + ary = ary[~nans] + n = len(ary) - if multimodal: - if skipna: - ary = ary[~np.isnan(ary)] - - if ary.dtype.kind == "f": - density, lower, upper = _fast_kde(ary) - range_x = upper - lower - dx = range_x / len(density) - bins = np.linspace(lower, upper, len(density)) - else: - bins = get_bins(ary) - _, density, _ = histogram(ary, bins=bins) - dx = np.diff(bins)[0] + if circular: + mean = st.circmean(ary, high=np.pi, low=-np.pi) 
+ ary = ary - mean + ary = np.arctan2(np.sin(ary), np.cos(ary)) - density *= dx + ary = np.sort(ary) + interval_idx_inc = int(np.floor(credible_interval * n)) + n_intervals = n - interval_idx_inc + interval_width = ary[interval_idx_inc:] - ary[:n_intervals] - idx = np.argsort(-density) - intervals = bins[idx][density[idx].cumsum() <= credible_interval] - intervals.sort() + if len(interval_width) == 0: + raise ValueError("Too few elements for interval calculation. ") - intervals_splitted = np.split(intervals, np.where(np.diff(intervals) >= dx * 1.1)[0] + 1) + min_idx = np.argmin(interval_width) + hdi_min = ary[min_idx] + hdi_max = ary[min_idx + interval_idx_inc] - hpd_intervals = [] - for interval in intervals_splitted: - if interval.size == 0: - hpd_intervals.append((bins[0], bins[0])) - else: - hpd_intervals.append((interval[0], interval[-1])) + if circular: + hdi_min = hdi_min + mean + hdi_max = hdi_max + mean + hdi_min = np.arctan2(np.sin(hdi_min), np.cos(hdi_min)) + hdi_max = np.arctan2(np.sin(hdi_max), np.cos(hdi_max)) - hpd_intervals = np.array(hpd_intervals) + hpd_intervals = np.array([hdi_min, hdi_max]) - else: - if skipna: - nans = np.isnan(ary) - if not nans.all(): - ary = ary[~nans] - n = len(ary) + return hpd_intervals - if circular: - mean = st.circmean(ary, high=np.pi, low=-np.pi) - ary = ary - mean - ary = np.arctan2(np.sin(ary), np.cos(ary)) - ary = np.sort(ary) - interval_idx_inc = int(np.floor(credible_interval * n)) - n_intervals = n - interval_idx_inc - interval_width = ary[interval_idx_inc:] - ary[:n_intervals] +def _hpd_multimodal(ary, credible_interval, skipna): + """Compute hpd if the distribution is multimodal""" - if len(interval_width) == 0: - raise ValueError("Too few elements for interval calculation. 
") + if skipna: + ary = ary[~np.isnan(ary)] - min_idx = np.argmin(interval_width) - hdi_min = ary[min_idx] - hdi_max = ary[min_idx + interval_idx_inc] + if ary.dtype.kind == "f": + density, lower, upper = _fast_kde(ary) + range_x = upper - lower + dx = range_x / len(density) + bins = np.linspace(lower, upper, len(density)) + else: + bins = get_bins(ary) + _, density, _ = histogram(ary, bins=bins) + dx = np.diff(bins)[0] - if circular: - hdi_min = hdi_min + mean - hdi_max = hdi_max + mean - hdi_min = np.arctan2(np.sin(hdi_min), np.cos(hdi_min)) - hdi_max = np.arctan2(np.sin(hdi_max), np.cos(hdi_max)) + density *= dx - hpd_intervals = np.array([hdi_min, hdi_max]) + idx = np.argsort(-density) + intervals = bins[idx][density[idx].cumsum() <= credible_interval] + intervals.sort() - return hpd_intervals + intervals_splitted = np.split(intervals, np.where(np.diff(intervals) >= dx * 1.1)[0] + 1) + + hpd_intervals = [] + for interval in intervals_splitted: + if interval.size == 0: + hpd_intervals.append((bins[0], bins[0])) + else: + hpd_intervals.append((interval[0], interval[-1])) + + return np.array(hpd_intervals) def loo(data, pointwise=False, reff=None, scale=None): diff --git a/arviz/tests/base_tests/test_stats.py b/arviz/tests/base_tests/test_stats.py index ab38aea502..e7f6353435 100644 --- a/arviz/tests/base_tests/test_stats.py +++ b/arviz/tests/base_tests/test_stats.py @@ -44,7 +44,7 @@ def non_centered_eight(): def test_hpd(): normal_sample = np.random.randn(5000000) interval = hpd(normal_sample) - assert_array_almost_equal(interval, [-1.88, 1.88], 2) + assert_array_almost_equal(interval.x.values, [-1.88, 1.88], 2) def test_hpd_multimodal(): @@ -58,7 +58,7 @@ def test_hpd_multimodal(): def test_hpd_circular(): normal_sample = np.random.vonmises(np.pi, 1, 5000000) interval = hpd(normal_sample, circular=True) - assert_array_almost_equal(interval, [0.6, -0.6], 1) + assert_array_almost_equal(interval.x.values, [0.6, -0.6], 1) def test_hpd_bad_ci(): @@ -72,7 +72,7 @@ 
def test_hpd_skipna(): interval = hpd(normal_sample[10:]) normal_sample[:10] = np.nan interval_ = hpd(normal_sample, skipna=True) - assert_array_almost_equal(interval, interval_) + assert_array_almost_equal(interval.x.values, interval_.x.values) def test_r2_score(): From f2e67b481b7c90beaaf65924df4f16e6a7ff33b9 Mon Sep 17 00:00:00 2001 From: percygautam Date: Fri, 20 Mar 2020 01:21:13 +0530 Subject: [PATCH 06/21] allow input to be ndarray/dataset --- arviz/stats/stats.py | 46 +++++++++++++++++++++------- arviz/tests/base_tests/test_stats.py | 12 ++++++-- 2 files changed, 44 insertions(+), 14 deletions(-) diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index 9cc7ed122b..5afc823d4a 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -305,7 +305,17 @@ def _ic_matrix(ics, ic_i): return rows, cols, ic_i_val -def hpd(ary, *, credible_interval=None, circular=False, multimodal=False, skipna=False, **kwargs): +def hpd( + ary, + *, + group="posterior", + var_names=None, + credible_interval=None, + circular=False, + multimodal=False, + skipna=False, + **kwargs +): """ Calculate highest posterior density (HPD) of array for given credible_interval. @@ -313,10 +323,14 @@ def hpd(ary, *, credible_interval=None, circular=False, multimodal=False, skipna Parameters ---------- - ary : Numpy array - An array containing posterior samples - group : List - An list containing the dimensions to compute hpd + ary : obj + onject containing posterior samples. + Any object that can be converted to an az.InferenceData object. + Refer to documentation of az.convert_to_dataset for details. + group : str, optional + Specifies which InferenceData group should be used to calculate hpd. Defaults to 'posterior' + var_names : list + Names of variables to include in the hpd report credible_interval : float, optional Credible interval to compute. Defaults to 0.94. 
circular : bool, optional @@ -354,17 +368,27 @@ def hpd(ary, *, credible_interval=None, circular=False, multimodal=False, skipna if multimodal: return _hpd_multimodal(ary, credible_interval, skipna) - if isinstance(ary, np.ndarray): - ary = convert_to_dataset(ary) - - kwargs.setdefault("input_core_dims", [["chain", "draw"]]) - kwargs.setdefault("output_core_dims", [["hpd"]]) func_kwargs = { "credible_interval": credible_interval, "circular": circular, "skipna": skipna, "out_shape": (2,), } + kwargs.setdefault("output_core_dims", [["hpd"]]) + + if isinstance(ary, np.ndarray): + if len(ary.shape) == 1: + return _hpd(ary, credible_interval, circular, skipna) + ary = convert_to_dataset(ary) + kwargs.setdefault("input_core_dims", [["chain"]]) + return _wrap_xarray_ufunc(_hpd, ary, func_kwargs=func_kwargs, **kwargs).to_array().values[0] + + ary = convert_to_dataset(ary, group=group) + var_names = _var_names(var_names, ary) + + ary = ary if var_names is None else ary[var_names] + + kwargs.setdefault("input_core_dims", [["chain", "draw"]]) return _wrap_xarray_ufunc(_hpd, ary, func_kwargs=func_kwargs, **kwargs) @@ -407,7 +431,7 @@ def _hpd(ary, credible_interval, circular, skipna): def _hpd_multimodal(ary, credible_interval, skipna): """Compute hpd if the distribution is multimodal""" - + ary = ary.flatten() if skipna: ary = ary[~np.isnan(ary)] diff --git a/arviz/tests/base_tests/test_stats.py b/arviz/tests/base_tests/test_stats.py index e7f6353435..98ebdc8629 100644 --- a/arviz/tests/base_tests/test_stats.py +++ b/arviz/tests/base_tests/test_stats.py @@ -44,7 +44,13 @@ def non_centered_eight(): def test_hpd(): normal_sample = np.random.randn(5000000) interval = hpd(normal_sample) - assert_array_almost_equal(interval.x.values, [-1.88, 1.88], 2) + assert_array_almost_equal(interval, [-1.88, 1.88], 2) + + +def test_hpd_multidimension(): + normal_sample = np.random.randn(12000, 10, 3) + result = hpd(normal_sample) + assert result.shape == (10, 3, 2,) def test_hpd_multimodal(): 
@@ -58,7 +64,7 @@ def test_hpd_multimodal(): def test_hpd_circular(): normal_sample = np.random.vonmises(np.pi, 1, 5000000) interval = hpd(normal_sample, circular=True) - assert_array_almost_equal(interval.x.values, [0.6, -0.6], 1) + assert_array_almost_equal(interval, [0.6, -0.6], 1) def test_hpd_bad_ci(): @@ -72,7 +78,7 @@ def test_hpd_skipna(): interval = hpd(normal_sample[10:]) normal_sample[:10] = np.nan interval_ = hpd(normal_sample, skipna=True) - assert_array_almost_equal(interval.x.values, interval_.x.values) + assert_array_almost_equal(interval, interval_) def test_r2_score(): From 6dd3ee85e3d951630381c22cc9cdb5124b7f13f6 Mon Sep 17 00:00:00 2001 From: percygautam Date: Fri, 20 Mar 2020 23:37:24 +0530 Subject: [PATCH 07/21] add tests for unimodal case --- arviz/stats/stats.py | 13 ++++++++----- arviz/tests/base_tests/test_stats.py | 23 +++++++++++++++++++++++ 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index 5afc823d4a..03e496873b 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -324,11 +324,11 @@ def hpd( Parameters ---------- ary : obj - onject containing posterior samples. + object containing posterior samples. Any object that can be converted to an az.InferenceData object. Refer to documentation of az.convert_to_dataset for details. group : str, optional - Specifies which InferenceData group should be used to calculate hpd. Defaults to 'posterior' + Specifies which InferenceData group should be used to calculate hpd. Defaults to 'posterior' var_names : list Names of variables to include in the hpd report credible_interval : float, optional @@ -342,10 +342,14 @@ def hpd( modes are well separated. skipna : bool If true ignores nan values when computing the hpd interval. Defaults to false. + kwargs : dict, optional + Additional keywords passed to ax.scatter + For example, to calculate hpd over "chain" dimension, Pass "input_core_dims" to "chain" in + the dictionary. 
Returns ------- - np.ndarray + np.ndarray or xarray.Dataset, depending upon input lower(s) and upper(s) values of the interval(s). Examples @@ -381,14 +385,13 @@ def hpd( return _hpd(ary, credible_interval, circular, skipna) ary = convert_to_dataset(ary) kwargs.setdefault("input_core_dims", [["chain"]]) - return _wrap_xarray_ufunc(_hpd, ary, func_kwargs=func_kwargs, **kwargs).to_array().values[0] + return _wrap_xarray_ufunc(_hpd, ary, func_kwargs=func_kwargs, **kwargs).x.values ary = convert_to_dataset(ary, group=group) var_names = _var_names(var_names, ary) ary = ary if var_names is None else ary[var_names] - kwargs.setdefault("input_core_dims", [["chain", "draw"]]) return _wrap_xarray_ufunc(_hpd, ary, func_kwargs=func_kwargs, **kwargs) diff --git a/arviz/tests/base_tests/test_stats.py b/arviz/tests/base_tests/test_stats.py index 98ebdc8629..e20bf6df73 100644 --- a/arviz/tests/base_tests/test_stats.py +++ b/arviz/tests/base_tests/test_stats.py @@ -47,12 +47,35 @@ def test_hpd(): assert_array_almost_equal(interval, [-1.88, 1.88], 2) +def test_hpd_2darray(): + normal_sample = np.random.randn(12000, 5) + result = hpd(normal_sample) + assert result.shape == (5, 2,) + + def test_hpd_multidimension(): normal_sample = np.random.randn(12000, 10, 3) result = hpd(normal_sample) assert result.shape == (10, 3, 2,) +def test_hpd_idata(): + data = load_arviz_data("centered_eight") + normal_sample = data.posterior + result = hpd(normal_sample) + assert isinstance(result, Dataset) + assert result.dims == {"school": 8, "hpd": 2} + + result = hpd(normal_sample, **{"input_core_dims": [["chain"]]}) + assert isinstance(result, Dataset) + assert result.dims == {"draw": 500, "hpd": 2, "school": 8} + + result = hpd(normal_sample, var_names=["mu", "theta"]) + assert isinstance(result, Dataset) + assert result.dims == {"hpd": 2, "school": 8} + assert list(result.data_vars.keys()) == ["mu", "theta"] + + def test_hpd_multimodal(): normal_sample = np.concatenate( (np.random.normal(-4, 1, 
2500000), np.random.normal(2, 0.5, 2500000)) From 1d75d932b5103115e9e7a06c94abfacb766b15b2 Mon Sep 17 00:00:00 2001 From: percygautam Date: Sat, 21 Mar 2020 01:36:57 +0530 Subject: [PATCH 08/21] test group and other minor nits for unimodal --- arviz/stats/stats.py | 29 ++++++++++++++++++++++++---- arviz/tests/base_tests/test_stats.py | 23 ++++++++++++++++------ 2 files changed, 42 insertions(+), 10 deletions(-) diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index 03e496873b..81cc250dae 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -343,9 +343,8 @@ def hpd( skipna : bool If true ignores nan values when computing the hpd interval. Defaults to false. kwargs : dict, optional - Additional keywords passed to ax.scatter - For example, to calculate hpd over "chain" dimension, Pass "input_core_dims" to "chain" in - the dictionary. + Additional keywords passed to `wrap_xarray_ufunc`. + See the docstring of :obj:`wrap_xarray_ufunc method `. Returns ------- @@ -362,6 +361,28 @@ def hpd( ...: import numpy as np ...: data = np.random.normal(size=2000) ...: az.hpd(data, credible_interval=.68) + + Calculate the hpd of a dataset: + + .. ipython:: + + In [1]: import arviz as az + ...: data = az.load_arviz_data('centered_eight') + ...: az.hpd(data) + + We can also calculate the hpd of some of the variables of dataset: + + .. ipython:: + + In [1]: az.hpd(data, var_names=["mu", "theta"]) + + If we want to calculate the hpd over specified dimension of dataset, + we can pass `input_core_dims` by kwargs: + + .. 
ipython:: + + In [1]: az.hpd(data, **{"input_core_dims": [["chain"]]}) + """ if credible_interval is None: credible_interval = rcParams["stats.credible_interval"] @@ -390,7 +411,7 @@ def hpd( ary = convert_to_dataset(ary, group=group) var_names = _var_names(var_names, ary) - ary = ary if var_names is None else ary[var_names] + ary = ary[var_names] if var_names else ary return _wrap_xarray_ufunc(_hpd, ary, func_kwargs=func_kwargs, **kwargs) diff --git a/arviz/tests/base_tests/test_stats.py b/arviz/tests/base_tests/test_stats.py index e20bf6df73..9ebfeaaf0c 100644 --- a/arviz/tests/base_tests/test_stats.py +++ b/arviz/tests/base_tests/test_stats.py @@ -59,23 +59,34 @@ def test_hpd_multidimension(): assert result.shape == (10, 3, 2,) -def test_hpd_idata(): - data = load_arviz_data("centered_eight") - normal_sample = data.posterior - result = hpd(normal_sample) +def test_hpd_idata(centered_eight): + data = centered_eight.posterior + result = hpd(data) assert isinstance(result, Dataset) assert result.dims == {"school": 8, "hpd": 2} - result = hpd(normal_sample, **{"input_core_dims": [["chain"]]}) + result = hpd(data, **{"input_core_dims": [["chain"]]}) assert isinstance(result, Dataset) assert result.dims == {"draw": 500, "hpd": 2, "school": 8} - result = hpd(normal_sample, var_names=["mu", "theta"]) + +def test_hpd_idata_varnames(centered_eight): + data = centered_eight.posterior + result = hpd(data, var_names=["mu", "theta"]) assert isinstance(result, Dataset) assert result.dims == {"hpd": 2, "school": 8} assert list(result.data_vars.keys()) == ["mu", "theta"] +def test_hpd_idata_group(centered_eight): + result_posterior = hpd(centered_eight, group="posterior", var_names="mu") + result_prior = hpd(centered_eight, group="prior", var_names="mu") + assert result_prior.dims == {"hpd": 2} + print(result_posterior.mu.values, result_prior.mu.values) + assert result_posterior.mu.values[0] > result_prior.mu.values[0] + assert result_posterior.mu.values[1] > 
result_prior.mu.values[1] + + def test_hpd_multimodal(): normal_sample = np.concatenate( (np.random.normal(-4, 1, 2500000), np.random.normal(2, 0.5, 2500000)) From a24323ee5d07c4321e4c2f38d94e07263c196e6f Mon Sep 17 00:00:00 2001 From: percygautam Date: Sun, 22 Mar 2020 22:48:07 +0530 Subject: [PATCH 09/21] corrected failed tests --- arviz/plots/backends/bokeh/densityplot.py | 4 ++-- arviz/plots/backends/bokeh/forestplot.py | 2 +- arviz/plots/backends/bokeh/violinplot.py | 2 +- arviz/plots/backends/matplotlib/densityplot.py | 4 ++-- arviz/plots/backends/matplotlib/forestplot.py | 2 +- arviz/plots/backends/matplotlib/violinplot.py | 2 +- arviz/stats/stats.py | 3 ++- arviz/tests/base_tests/test_stats.py | 6 +++--- 8 files changed, 13 insertions(+), 12 deletions(-) diff --git a/arviz/plots/backends/bokeh/densityplot.py b/arviz/plots/backends/bokeh/densityplot.py index 3bc83a76ea..93ea1c5069 100644 --- a/arviz/plots/backends/bokeh/densityplot.py +++ b/arviz/plots/backends/bokeh/densityplot.py @@ -124,7 +124,7 @@ def _d_helper( if vec.dtype.kind == "f": if credible_interval != 1: - hpd_ = hpd(vec, credible_interval, multimodal=False) + hpd_ = hpd(vec, credible_interval=credible_interval, multimodal=False) new_vec = vec[(vec >= hpd_[0]) & (vec <= hpd_[1])] else: new_vec = vec @@ -174,7 +174,7 @@ def _d_helper( ) else: - xmin, xmax = hpd(vec, credible_interval, multimodal=False) + xmin, xmax = hpd(vec, credible_interval=credible_interval, multimodal=False) bins = get_bins(vec) _, hist, edges = histogram(vec, bins=bins) diff --git a/arviz/plots/backends/bokeh/forestplot.py b/arviz/plots/backends/bokeh/forestplot.py index 0d19ef8c67..6c24baf80d 100644 --- a/arviz/plots/backends/bokeh/forestplot.py +++ b/arviz/plots/backends/bokeh/forestplot.py @@ -548,7 +548,7 @@ def treeplot(self, qlist, credible_interval): """Get data for each treeplot for the variable.""" for y, _, label, values, color in self.iterator(): ntiles = np.percentile(values.flatten(), qlist) - ntiles[0], 
ntiles[-1] = hpd(values.flatten(), credible_interval, multimodal=False) + ntiles[0], ntiles[-1] = hpd(values.flatten(), credible_interval=credible_interval, multimodal=False) yield y, label, ntiles, color def ridgeplot(self, mult, ridgeplot_kind): diff --git a/arviz/plots/backends/bokeh/violinplot.py b/arviz/plots/backends/bokeh/violinplot.py index fc0264be7a..bebe3e4d42 100644 --- a/arviz/plots/backends/bokeh/violinplot.py +++ b/arviz/plots/backends/bokeh/violinplot.py @@ -66,7 +66,7 @@ def plot_violin( ax_.scatter(rug_x, val, **rug_kwargs) per = np.percentile(val, [25, 75, 50]) - hpd_intervals = hpd(val, credible_interval, multimodal=False) + hpd_intervals = hpd(val, credible_interval=credible_interval, multimodal=False) if quartiles: ax_.line( diff --git a/arviz/plots/backends/matplotlib/densityplot.py b/arviz/plots/backends/matplotlib/densityplot.py index 4d52b2b9c3..034cca4dac 100644 --- a/arviz/plots/backends/matplotlib/densityplot.py +++ b/arviz/plots/backends/matplotlib/densityplot.py @@ -134,7 +134,7 @@ def _d_helper( """ if vec.dtype.kind == "f": if credible_interval != 1: - hpd_ = hpd(vec, credible_interval, multimodal=False) + hpd_ = hpd(vec, credible_interval=credible_interval, multimodal=False) new_vec = vec[(vec >= hpd_[0]) & (vec <= hpd_[1])] else: new_vec = vec @@ -154,7 +154,7 @@ def _d_helper( ax.fill_between(x, density, color=color, alpha=shade) else: - xmin, xmax = hpd(vec, credible_interval, multimodal=False) + xmin, xmax = hpd(vec, credible_interval=credible_interval, multimodal=False) bins = get_bins(vec) if outline: ax.hist(vec, bins=bins, color=color, histtype="step", align="left") diff --git a/arviz/plots/backends/matplotlib/forestplot.py b/arviz/plots/backends/matplotlib/forestplot.py index d68671dcf8..f2a2118fb5 100644 --- a/arviz/plots/backends/matplotlib/forestplot.py +++ b/arviz/plots/backends/matplotlib/forestplot.py @@ -500,7 +500,7 @@ def treeplot(self, qlist, credible_interval): """Get data for each treeplot for the variable.""" 
for y, _, label, values, color in self.iterator(): ntiles = np.percentile(values.flatten(), qlist) - ntiles[0], ntiles[-1] = hpd(values.flatten(), credible_interval, multimodal=False) + ntiles[0], ntiles[-1] = hpd(values.flatten(), credible_interval=credible_interval, multimodal=False) yield y, label, ntiles, color def ridgeplot(self, mult, ridgeplot_kind): diff --git a/arviz/plots/backends/matplotlib/violinplot.py b/arviz/plots/backends/matplotlib/violinplot.py index 9c533103c6..e3c6066229 100644 --- a/arviz/plots/backends/matplotlib/violinplot.py +++ b/arviz/plots/backends/matplotlib/violinplot.py @@ -58,7 +58,7 @@ def plot_violin( ax_.plot(rug_x, val, **rug_kwargs) per = np.percentile(val, [25, 75, 50]) - hpd_intervals = hpd(val, credible_interval, multimodal=False) + hpd_intervals = hpd(val, credible_interval=credible_interval, multimodal=False) if quartiles: ax_.plot([0, 0], per[:2], lw=linewidth * 3, color="k", solid_capstyle="round") diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index 81cc250dae..44237bb9c6 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -328,7 +328,8 @@ def hpd( Any object that can be converted to an az.InferenceData object. Refer to documentation of az.convert_to_dataset for details. group : str, optional - Specifies which InferenceData group should be used to calculate hpd. Defaults to 'posterior' + Specifies which InferenceData group should be used to calculate hpd. 
+ Defaults to 'posterior' var_names : list Names of variables to include in the hpd report credible_interval : float, optional diff --git a/arviz/tests/base_tests/test_stats.py b/arviz/tests/base_tests/test_stats.py index 9ebfeaaf0c..76374d24b9 100644 --- a/arviz/tests/base_tests/test_stats.py +++ b/arviz/tests/base_tests/test_stats.py @@ -82,9 +82,9 @@ def test_hpd_idata_group(centered_eight): result_posterior = hpd(centered_eight, group="posterior", var_names="mu") result_prior = hpd(centered_eight, group="prior", var_names="mu") assert result_prior.dims == {"hpd": 2} - print(result_posterior.mu.values, result_prior.mu.values) - assert result_posterior.mu.values[0] > result_prior.mu.values[0] - assert result_posterior.mu.values[1] > result_prior.mu.values[1] + range_posterior = result_posterior.mu.values[1] - result_posterior.mu.values[0] + range_prior = result_prior.mu.values[1] - result_prior.mu.values[0] + assert range_posterior < range_prior def test_hpd_multimodal(): From ff176f4702fd6af1fff50a34285689fe5a5cc36c Mon Sep 17 00:00:00 2001 From: percygautam Date: Sun, 22 Mar 2020 23:50:28 +0530 Subject: [PATCH 10/21] added sel argument and its tests --- arviz/stats/stats.py | 95 ++++++++++++++++------------ arviz/tests/base_tests/test_stats.py | 6 ++ 2 files changed, 59 insertions(+), 42 deletions(-) diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index 44237bb9c6..3760dfb455 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -44,7 +44,7 @@ def compare( - dataset_dict, ic=None, method="BB-pseudo-BMA", b_samples=1000, alpha=1, seed=None, scale=None + dataset_dict, ic=None, method="BB-pseudo-BMA", b_samples=1000, alpha=1, seed=None, scale=None ): r"""Compare models based on PSIS-LOO `loo` or WAIC `waic` cross-validation. 
@@ -306,15 +306,16 @@ def _ic_matrix(ics, ic_i): def hpd( - ary, - *, - group="posterior", - var_names=None, - credible_interval=None, - circular=False, - multimodal=False, - skipna=False, - **kwargs + ary, + *, + group="posterior", + var_names=None, + sel=None, + credible_interval=None, + circular=False, + multimodal=False, + skipna=False, + **kwargs ): """ Calculate highest posterior density (HPD) of array for given credible_interval. @@ -330,8 +331,10 @@ def hpd( group : str, optional Specifies which InferenceData group should be used to calculate hpd. Defaults to 'posterior' - var_names : list + var_names : list, optional Names of variables to include in the hpd report + sel: dict, optional + To calculate hpd over selection on all groups. credible_interval : float, optional Credible interval to compute. Defaults to 0.94. circular : bool, optional @@ -384,6 +387,12 @@ def hpd( In [1]: az.hpd(data, **{"input_core_dims": [["chain"]]}) + We can also calculate the hpd over a particular selection over all groups: + + .. 
ipython:: + + In [1]: az.hpd(data, sel={"chain":[0, 1, 3]}, **{"input_core_dims": [["draw"]]}) + """ if credible_interval is None: credible_interval = rcParams["stats.credible_interval"] @@ -410,6 +419,8 @@ def hpd( return _wrap_xarray_ufunc(_hpd, ary, func_kwargs=func_kwargs, **kwargs).x.values ary = convert_to_dataset(ary, group=group) + if sel is not None: + ary = ary.sel(**sel) var_names = _var_names(var_names, ary) ary = ary[var_names] if var_names else ary @@ -579,7 +590,7 @@ def loo(data, pointwise=False, reff=None, scale=None): ess_p = ess(posterior, method="mean") # this mean is over all data variables reff = ( - np.hstack([ess_p[v].values.flatten() for v in ess_p.data_vars]).mean() / n_samples + np.hstack([ess_p[v].values.flatten() for v in ess_p.data_vars]).mean() / n_samples ) log_weights, pareto_shape = psislw(-log_likelihood, reff) @@ -865,20 +876,20 @@ def r2_score(y_true, y_pred): def summary( - data, - var_names: Optional[List[str]] = None, - fmt: str = "wide", - kind: str = "all", - round_to=None, - include_circ=None, - stat_funcs=None, - extend=True, - credible_interval=None, - order="C", - index_origin=None, - skipna=False, - coords: Optional[CoordSpec] = None, - dims: Optional[DimSpec] = None, + data, + var_names: Optional[List[str]] = None, + fmt: str = "wide", + kind: str = "all", + round_to=None, + include_circ=None, + stat_funcs=None, + extend=True, + credible_interval=None, + order="C", + index_origin=None, + skipna=False, + coords: Optional[CoordSpec] = None, + dims: Optional[DimSpec] = None, ) -> Union[pd.DataFrame, xr.Dataset]: """Create a data frame with summary statistics. 
@@ -1461,22 +1472,22 @@ def _loo_pit(y, y_hat, log_weights): def apply_test_function( - idata, - func, - group="both", - var_names=None, - pointwise=False, - out_data_shape=None, - out_pp_shape=None, - out_name_data="T", - out_name_pp=None, - func_args=None, - func_kwargs=None, - ufunc_kwargs=None, - wrap_data_kwargs=None, - wrap_pp_kwargs=None, - inplace=True, - overwrite=None, + idata, + func, + group="both", + var_names=None, + pointwise=False, + out_data_shape=None, + out_pp_shape=None, + out_name_data="T", + out_name_pp=None, + func_args=None, + func_kwargs=None, + ufunc_kwargs=None, + wrap_data_kwargs=None, + wrap_pp_kwargs=None, + inplace=True, + overwrite=None, ): """Apply a Bayesian test function to an InferenceData object. diff --git a/arviz/tests/base_tests/test_stats.py b/arviz/tests/base_tests/test_stats.py index 76374d24b9..85817d568b 100644 --- a/arviz/tests/base_tests/test_stats.py +++ b/arviz/tests/base_tests/test_stats.py @@ -87,6 +87,12 @@ def test_hpd_idata_group(centered_eight): assert range_posterior < range_prior +def test_hpd_sel(centered_eight): + data = centered_eight.posterior + result = hpd(data, sel={"chain":[0, 1, 3]}, **{"input_core_dims": [["draw"]]}) + assert_array_equal(result.coords["chain"], [0, 1, 3]) + + def test_hpd_multimodal(): normal_sample = np.concatenate( (np.random.normal(-4, 1, 2500000), np.random.normal(2, 0.5, 2500000)) From 82f15a4743b29ab9c2d3eca7137281e9075eb54f Mon Sep 17 00:00:00 2001 From: percygautam Date: Mon, 23 Mar 2020 02:34:45 +0530 Subject: [PATCH 11/21] minor nits --- arviz/plots/backends/bokeh/densityplot.py | 4 +- arviz/plots/backends/bokeh/forestplot.py | 2 +- arviz/plots/backends/bokeh/violinplot.py | 2 +- .../plots/backends/matplotlib/densityplot.py | 4 +- arviz/plots/backends/matplotlib/forestplot.py | 2 +- arviz/plots/backends/matplotlib/violinplot.py | 2 +- arviz/stats/stats.py | 97 +++++++++---------- 7 files changed, 56 insertions(+), 57 deletions(-) diff --git 
a/arviz/plots/backends/bokeh/densityplot.py b/arviz/plots/backends/bokeh/densityplot.py index 93ea1c5069..3bc83a76ea 100644 --- a/arviz/plots/backends/bokeh/densityplot.py +++ b/arviz/plots/backends/bokeh/densityplot.py @@ -124,7 +124,7 @@ def _d_helper( if vec.dtype.kind == "f": if credible_interval != 1: - hpd_ = hpd(vec, credible_interval=credible_interval, multimodal=False) + hpd_ = hpd(vec, credible_interval, multimodal=False) new_vec = vec[(vec >= hpd_[0]) & (vec <= hpd_[1])] else: new_vec = vec @@ -174,7 +174,7 @@ def _d_helper( ) else: - xmin, xmax = hpd(vec, credible_interval=credible_interval, multimodal=False) + xmin, xmax = hpd(vec, credible_interval, multimodal=False) bins = get_bins(vec) _, hist, edges = histogram(vec, bins=bins) diff --git a/arviz/plots/backends/bokeh/forestplot.py b/arviz/plots/backends/bokeh/forestplot.py index 6c24baf80d..0d19ef8c67 100644 --- a/arviz/plots/backends/bokeh/forestplot.py +++ b/arviz/plots/backends/bokeh/forestplot.py @@ -548,7 +548,7 @@ def treeplot(self, qlist, credible_interval): """Get data for each treeplot for the variable.""" for y, _, label, values, color in self.iterator(): ntiles = np.percentile(values.flatten(), qlist) - ntiles[0], ntiles[-1] = hpd(values.flatten(), credible_interval=credible_interval, multimodal=False) + ntiles[0], ntiles[-1] = hpd(values.flatten(), credible_interval, multimodal=False) yield y, label, ntiles, color def ridgeplot(self, mult, ridgeplot_kind): diff --git a/arviz/plots/backends/bokeh/violinplot.py b/arviz/plots/backends/bokeh/violinplot.py index bebe3e4d42..fc0264be7a 100644 --- a/arviz/plots/backends/bokeh/violinplot.py +++ b/arviz/plots/backends/bokeh/violinplot.py @@ -66,7 +66,7 @@ def plot_violin( ax_.scatter(rug_x, val, **rug_kwargs) per = np.percentile(val, [25, 75, 50]) - hpd_intervals = hpd(val, credible_interval=credible_interval, multimodal=False) + hpd_intervals = hpd(val, credible_interval, multimodal=False) if quartiles: ax_.line( diff --git 
a/arviz/plots/backends/matplotlib/densityplot.py b/arviz/plots/backends/matplotlib/densityplot.py index 034cca4dac..4d52b2b9c3 100644 --- a/arviz/plots/backends/matplotlib/densityplot.py +++ b/arviz/plots/backends/matplotlib/densityplot.py @@ -134,7 +134,7 @@ def _d_helper( """ if vec.dtype.kind == "f": if credible_interval != 1: - hpd_ = hpd(vec, credible_interval=credible_interval, multimodal=False) + hpd_ = hpd(vec, credible_interval, multimodal=False) new_vec = vec[(vec >= hpd_[0]) & (vec <= hpd_[1])] else: new_vec = vec @@ -154,7 +154,7 @@ def _d_helper( ax.fill_between(x, density, color=color, alpha=shade) else: - xmin, xmax = hpd(vec, credible_interval=credible_interval, multimodal=False) + xmin, xmax = hpd(vec, credible_interval, multimodal=False) bins = get_bins(vec) if outline: ax.hist(vec, bins=bins, color=color, histtype="step", align="left") diff --git a/arviz/plots/backends/matplotlib/forestplot.py b/arviz/plots/backends/matplotlib/forestplot.py index f2a2118fb5..d68671dcf8 100644 --- a/arviz/plots/backends/matplotlib/forestplot.py +++ b/arviz/plots/backends/matplotlib/forestplot.py @@ -500,7 +500,7 @@ def treeplot(self, qlist, credible_interval): """Get data for each treeplot for the variable.""" for y, _, label, values, color in self.iterator(): ntiles = np.percentile(values.flatten(), qlist) - ntiles[0], ntiles[-1] = hpd(values.flatten(), credible_interval=credible_interval, multimodal=False) + ntiles[0], ntiles[-1] = hpd(values.flatten(), credible_interval, multimodal=False) yield y, label, ntiles, color def ridgeplot(self, mult, ridgeplot_kind): diff --git a/arviz/plots/backends/matplotlib/violinplot.py b/arviz/plots/backends/matplotlib/violinplot.py index e3c6066229..9c533103c6 100644 --- a/arviz/plots/backends/matplotlib/violinplot.py +++ b/arviz/plots/backends/matplotlib/violinplot.py @@ -58,7 +58,7 @@ def plot_violin( ax_.plot(rug_x, val, **rug_kwargs) per = np.percentile(val, [25, 75, 50]) - hpd_intervals = hpd(val, 
credible_interval=credible_interval, multimodal=False) + hpd_intervals = hpd(val, credible_interval, multimodal=False) if quartiles: ax_.plot([0, 0], per[:2], lw=linewidth * 3, color="k", solid_capstyle="round") diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index 3760dfb455..404f25f107 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -44,7 +44,7 @@ def compare( - dataset_dict, ic=None, method="BB-pseudo-BMA", b_samples=1000, alpha=1, seed=None, scale=None + dataset_dict, ic=None, method="BB-pseudo-BMA", b_samples=1000, alpha=1, seed=None, scale=None ): r"""Compare models based on PSIS-LOO `loo` or WAIC `waic` cross-validation. @@ -306,16 +306,15 @@ def _ic_matrix(ics, ic_i): def hpd( - ary, - *, - group="posterior", - var_names=None, - sel=None, - credible_interval=None, - circular=False, - multimodal=False, - skipna=False, - **kwargs + ary, + credible_interval=None, + circular=False, + multimodal=False, + skipna=False, + group="posterior", + var_names=None, + sel=None, + **kwargs ): """ Calculate highest posterior density (HPD) of array for given credible_interval. @@ -328,13 +327,6 @@ def hpd( object containing posterior samples. Any object that can be converted to an az.InferenceData object. Refer to documentation of az.convert_to_dataset for details. - group : str, optional - Specifies which InferenceData group should be used to calculate hpd. - Defaults to 'posterior' - var_names : list, optional - Names of variables to include in the hpd report - sel: dict, optional - To calculate hpd over selection on all groups. credible_interval : float, optional Credible interval to compute. Defaults to 0.94. circular : bool, optional @@ -346,6 +338,13 @@ def hpd( modes are well separated. skipna : bool If true ignores nan values when computing the hpd interval. Defaults to false. + group : str, optional + Specifies which InferenceData group should be used to calculate hpd. 
+ Defaults to 'posterior' + var_names : list, optional + Names of variables to include in the hpd report + sel: dict, optional + To calculate hpd over selection on all groups. kwargs : dict, optional Additional keywords passed to `wrap_xarray_ufunc`. See the docstring of :obj:`wrap_xarray_ufunc method `. @@ -590,7 +589,7 @@ def loo(data, pointwise=False, reff=None, scale=None): ess_p = ess(posterior, method="mean") # this mean is over all data variables reff = ( - np.hstack([ess_p[v].values.flatten() for v in ess_p.data_vars]).mean() / n_samples + np.hstack([ess_p[v].values.flatten() for v in ess_p.data_vars]).mean() / n_samples ) log_weights, pareto_shape = psislw(-log_likelihood, reff) @@ -876,20 +875,20 @@ def r2_score(y_true, y_pred): def summary( - data, - var_names: Optional[List[str]] = None, - fmt: str = "wide", - kind: str = "all", - round_to=None, - include_circ=None, - stat_funcs=None, - extend=True, - credible_interval=None, - order="C", - index_origin=None, - skipna=False, - coords: Optional[CoordSpec] = None, - dims: Optional[DimSpec] = None, + data, + var_names: Optional[List[str]] = None, + fmt: str = "wide", + kind: str = "all", + round_to=None, + include_circ=None, + stat_funcs=None, + extend=True, + credible_interval=None, + order="C", + index_origin=None, + skipna=False, + coords: Optional[CoordSpec] = None, + dims: Optional[DimSpec] = None, ) -> Union[pd.DataFrame, xr.Dataset]: """Create a data frame with summary statistics. 
@@ -1472,22 +1471,22 @@ def _loo_pit(y, y_hat, log_weights): def apply_test_function( - idata, - func, - group="both", - var_names=None, - pointwise=False, - out_data_shape=None, - out_pp_shape=None, - out_name_data="T", - out_name_pp=None, - func_args=None, - func_kwargs=None, - ufunc_kwargs=None, - wrap_data_kwargs=None, - wrap_pp_kwargs=None, - inplace=True, - overwrite=None, + idata, + func, + group="both", + var_names=None, + pointwise=False, + out_data_shape=None, + out_pp_shape=None, + out_name_data="T", + out_name_pp=None, + func_args=None, + func_kwargs=None, + ufunc_kwargs=None, + wrap_data_kwargs=None, + wrap_pp_kwargs=None, + inplace=True, + overwrite=None, ): """Apply a Bayesian test function to an InferenceData object. From 0ff7d180b7295253e5ab8d189a18fddef5b21fc6 Mon Sep 17 00:00:00 2001 From: percygautam Date: Mon, 23 Mar 2020 02:51:21 +0530 Subject: [PATCH 12/21] pydocstyle change --- arviz/stats/stats.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index 404f25f107..f67bfa04a7 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -465,7 +465,7 @@ def _hpd(ary, credible_interval, circular, skipna): def _hpd_multimodal(ary, credible_interval, skipna): - """Compute hpd if the distribution is multimodal""" + """Compute hpd if the distribution is multimodal.""" ary = ary.flatten() if skipna: ary = ary[~np.isnan(ary)] From deaf5e94b8a6f7a8fdc5c8fe9c9d77527aacb3c8 Mon Sep 17 00:00:00 2001 From: percygautam Date: Tue, 24 Mar 2020 17:29:46 +0530 Subject: [PATCH 13/21] minor changes --- arviz/stats/stats.py | 4 ++-- arviz/tests/base_tests/test_stats.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index f67bfa04a7..6ec85414be 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -384,13 +384,13 @@ def hpd( .. 
ipython:: - In [1]: az.hpd(data, **{"input_core_dims": [["chain"]]}) + In [1]: az.hpd(data, input_core_dims = [["chain"]]) We can also calculate the hpd over a particular selection over all groups: .. ipython:: - In [1]: az.hpd(data, sel={"chain":[0, 1, 3]}, **{"input_core_dims": [["draw"]]}) + In [1]: az.hpd(data, sel={"chain":[0, 1, 3]}, input_core_dims = [["draw"]]) """ if credible_interval is None: diff --git a/arviz/tests/base_tests/test_stats.py b/arviz/tests/base_tests/test_stats.py index 85817d568b..586cd0d74c 100644 --- a/arviz/tests/base_tests/test_stats.py +++ b/arviz/tests/base_tests/test_stats.py @@ -65,7 +65,7 @@ def test_hpd_idata(centered_eight): assert isinstance(result, Dataset) assert result.dims == {"school": 8, "hpd": 2} - result = hpd(data, **{"input_core_dims": [["chain"]]}) + result = hpd(data, input_core_dims=[["chain"]]) assert isinstance(result, Dataset) assert result.dims == {"draw": 500, "hpd": 2, "school": 8} @@ -89,7 +89,7 @@ def test_hpd_idata_group(centered_eight): def test_hpd_sel(centered_eight): data = centered_eight.posterior - result = hpd(data, sel={"chain":[0, 1, 3]}, **{"input_core_dims": [["draw"]]}) + result = hpd(data, sel={"chain": [0, 1, 3]}, input_core_dims=[["draw"]]) assert_array_equal(result.coords["chain"], [0, 1, 3]) From c069dd5d045dcafe4ff187b8a7b13b96d784f313 Mon Sep 17 00:00:00 2001 From: percygautam Date: Thu, 26 Mar 2020 00:30:06 +0530 Subject: [PATCH 14/21] hpd multimodal --- arviz/stats/stats.py | 19 ++++++++++--------- arviz/stats/stats_utils.py | 9 ++++++++- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index 6ec85414be..7b0d0d131b 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -399,24 +399,25 @@ def hpd( if not 1 >= credible_interval > 0: raise ValueError("The value of credible_interval should be in the interval (0, 1]") - if multimodal: - return _hpd_multimodal(ary, credible_interval, skipna) - func_kwargs = { 
"credible_interval": credible_interval, - "circular": circular, "skipna": skipna, - "out_shape": (2,), } - kwargs.setdefault("output_core_dims", [["hpd"]]) + kwargs.setdefault("output_core_dims", [["hpd", "hpd2"] if multimodal else ["hpd"]]) + if not multimodal: + func_kwargs["circular"] = circular + + func = _hpd_multimodal if multimodal else _hpd if isinstance(ary, np.ndarray): if len(ary.shape) == 1: - return _hpd(ary, credible_interval, circular, skipna) + return func(ary, **func_kwargs) + func_kwargs["out_shape"] = (10, 2,) if multimodal else (2,) ary = convert_to_dataset(ary) kwargs.setdefault("input_core_dims", [["chain"]]) - return _wrap_xarray_ufunc(_hpd, ary, func_kwargs=func_kwargs, **kwargs).x.values + return _wrap_xarray_ufunc(func, ary, func_kwargs=func_kwargs, **kwargs).x.values + func_kwargs["out_shape"] = (10, 2,) if multimodal else (2,) ary = convert_to_dataset(ary, group=group) if sel is not None: ary = ary.sel(**sel) @@ -424,7 +425,7 @@ def hpd( ary = ary[var_names] if var_names else ary - return _wrap_xarray_ufunc(_hpd, ary, func_kwargs=func_kwargs, **kwargs) + return _wrap_xarray_ufunc(func, ary, func_kwargs=func_kwargs, **kwargs) def _hpd(ary, credible_interval, circular, skipna): diff --git a/arviz/stats/stats_utils.py b/arviz/stats/stats_utils.py index e26af7b3cb..92d9a2afe5 100644 --- a/arviz/stats/stats_utils.py +++ b/arviz/stats/stats_utils.py @@ -129,9 +129,16 @@ def _ufunc(*args, out=None, out_shape=None, **kwargs): msg = "Shape incorrect for `out`: {}.".format(out.shape) msg += " Correct shape is {}".format(arys[-1].shape[:-n_dims]) raise TypeError(msg) + func_out = np.empty(out_shape) for idx in np.ndindex(out.shape[:n_dims_out]): arys_idx = [ary[idx].ravel() if ravel else ary[idx] for ary in arys] - out[idx] = np.asarray(func(*arys_idx, *args[n_input:], **kwargs))[index] + ret = np.asarray(func(*arys_idx, *args[n_input:], **kwargs)) + if ret.shape[0] != out_shape[0]: + func_out[:] = np.NaN + func_out[:ret.shape[0], :] = ret + else: 
+ func_out = ret + out[idx] = np.asarray(func_out)[index] return out def _multi_ufunc(*args, out=None, out_shape=None, **kwargs): From 000866c4687a2069a3c0e6efa6f128362c7a57fc Mon Sep 17 00:00:00 2001 From: percygautam Date: Fri, 27 Mar 2020 02:28:30 +0530 Subject: [PATCH 15/21] changes to hpd multimodal --- arviz/stats/stats.py | 51 +++++++++++++++++----------- arviz/stats/stats_utils.py | 9 +---- arviz/tests/base_tests/test_stats.py | 4 +-- 3 files changed, 35 insertions(+), 29 deletions(-) diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index 7b0d0d131b..94e0438aee 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -12,7 +12,7 @@ from scipy.optimize import minimize import xarray as xr -from ..plots.plot_utils import _fast_kde, get_bins +from ..plots.plot_utils import _fast_kde, get_bins, get_coords from ..data import convert_to_inference_data, convert_to_dataset, InferenceData, CoordSpec, DimSpec from .diagnostics import _multichain_statistics, _mc_error, ess from .stats_utils import ( @@ -313,7 +313,8 @@ def hpd( skipna=False, group="posterior", var_names=None, - sel=None, + coords=None, + max_modes=10, **kwargs ): """ @@ -343,8 +344,10 @@ def hpd( Defaults to 'posterior' var_names : list, optional Names of variables to include in the hpd report - sel: dict, optional - To calculate hpd over selection on all groups. + coords: mapping, optional + Specifies the subset over which to calculate hpd. + max_modes: int, optional + Specifies the maximum number of modes for multimodal case. kwargs : dict, optional Additional keywords passed to `wrap_xarray_ufunc`. See the docstring of :obj:`wrap_xarray_ufunc method `. @@ -390,7 +393,7 @@ def hpd( .. 
ipython:: - In [1]: az.hpd(data, sel={"chain":[0, 1, 3]}, input_core_dims = [["draw"]]) + In [1]: az.hpd(data, coords={"chain":[0, 1, 3]}, input_core_dims = [["draw"]]) """ if credible_interval is None: @@ -402,30 +405,35 @@ def hpd( func_kwargs = { "credible_interval": credible_interval, "skipna": skipna, + "out_shape": (max_modes, 2,) if multimodal else (2,), } - kwargs.setdefault("output_core_dims", [["hpd", "hpd2"] if multimodal else ["hpd"]]) + kwargs.setdefault("output_core_dims", [["hpd", "mode"] if multimodal else ["hpd"]]) if not multimodal: func_kwargs["circular"] = circular + if multimodal: + func_kwargs["max_modes"] = max_modes func = _hpd_multimodal if multimodal else _hpd if isinstance(ary, np.ndarray): if len(ary.shape) == 1: - return func(ary, **func_kwargs) - func_kwargs["out_shape"] = (10, 2,) if multimodal else (2,) + func_kwargs.pop("out_shape") + out_func = func(ary, **func_kwargs) + out = out_func[~np.isnan(out_func)] + return out.reshape(out.shape[0] // 2, 2) if multimodal else out_func ary = convert_to_dataset(ary) kwargs.setdefault("input_core_dims", [["chain"]]) - return _wrap_xarray_ufunc(func, ary, func_kwargs=func_kwargs, **kwargs).x.values + res = _wrap_xarray_ufunc(func, ary, func_kwargs=func_kwargs, **kwargs) + return res.dropna("mode", how="all").x.values if multimodal else res.x.values - func_kwargs["out_shape"] = (10, 2,) if multimodal else (2,) ary = convert_to_dataset(ary, group=group) - if sel is not None: - ary = ary.sel(**sel) + if coords is not None: + ary = get_coords(ary, coords) var_names = _var_names(var_names, ary) - ary = ary[var_names] if var_names else ary - return _wrap_xarray_ufunc(func, ary, func_kwargs=func_kwargs, **kwargs) + res = _wrap_xarray_ufunc(func, ary, func_kwargs=func_kwargs, **kwargs) + return res.dropna("mode", how="all") if multimodal else res def _hpd(ary, credible_interval, circular, skipna): @@ -465,7 +473,7 @@ def _hpd(ary, credible_interval, circular, skipna): return hpd_intervals -def 
_hpd_multimodal(ary, credible_interval, skipna): +def _hpd_multimodal(ary, credible_interval, skipna, max_modes): """Compute hpd if the distribution is multimodal.""" ary = ary.flatten() if skipna: @@ -489,12 +497,17 @@ def _hpd_multimodal(ary, credible_interval, skipna): intervals_splitted = np.split(intervals, np.where(np.diff(intervals) >= dx * 1.1)[0] + 1) - hpd_intervals = [] - for interval in intervals_splitted: + hpd_intervals = np.full((max_modes, 2,), np.nan,) + for i, interval in enumerate(intervals_splitted): + if i == max_modes: + warnings.warn( + "found more modes than {0}, returning only the first {0} modes", max_modes + ) + break if interval.size == 0: - hpd_intervals.append((bins[0], bins[0])) + hpd_intervals[i] = np.asarray([bins[0], bins[0]]) else: - hpd_intervals.append((interval[0], interval[-1])) + hpd_intervals[i] = np.asarray([interval[0], interval[-1]]) return np.array(hpd_intervals) diff --git a/arviz/stats/stats_utils.py b/arviz/stats/stats_utils.py index 92d9a2afe5..e26af7b3cb 100644 --- a/arviz/stats/stats_utils.py +++ b/arviz/stats/stats_utils.py @@ -129,16 +129,9 @@ def _ufunc(*args, out=None, out_shape=None, **kwargs): msg = "Shape incorrect for `out`: {}.".format(out.shape) msg += " Correct shape is {}".format(arys[-1].shape[:-n_dims]) raise TypeError(msg) - func_out = np.empty(out_shape) for idx in np.ndindex(out.shape[:n_dims_out]): arys_idx = [ary[idx].ravel() if ravel else ary[idx] for ary in arys] - ret = np.asarray(func(*arys_idx, *args[n_input:], **kwargs)) - if ret.shape[0] != out_shape[0]: - func_out[:] = np.NaN - func_out[:ret.shape[0], :] = ret - else: - func_out = ret - out[idx] = np.asarray(func_out)[index] + out[idx] = np.asarray(func(*arys_idx, *args[n_input:], **kwargs))[index] return out def _multi_ufunc(*args, out=None, out_shape=None, **kwargs): diff --git a/arviz/tests/base_tests/test_stats.py b/arviz/tests/base_tests/test_stats.py index 586cd0d74c..7464d87aad 100644 --- a/arviz/tests/base_tests/test_stats.py +++ 
b/arviz/tests/base_tests/test_stats.py @@ -87,9 +87,9 @@ def test_hpd_idata_group(centered_eight): assert range_posterior < range_prior -def test_hpd_sel(centered_eight): +def test_hpd_coords(centered_eight): data = centered_eight.posterior - result = hpd(data, sel={"chain": [0, 1, 3]}, input_core_dims=[["draw"]]) + result = hpd(data, coords={"chain": [0, 1, 3]}, input_core_dims=[["draw"]]) assert_array_equal(result.coords["chain"], [0, 1, 3]) From e9d6ace9383c5d51ad7dfe9b4e16ab5b272c71ca Mon Sep 17 00:00:00 2001 From: percygautam Date: Fri, 27 Mar 2020 02:30:47 +0530 Subject: [PATCH 16/21] minor nits --- arviz/stats/stats.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index 94e0438aee..4dc4f9f773 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -418,9 +418,9 @@ def hpd( if isinstance(ary, np.ndarray): if len(ary.shape) == 1: func_kwargs.pop("out_shape") - out_func = func(ary, **func_kwargs) - out = out_func[~np.isnan(out_func)] - return out.reshape(out.shape[0] // 2, 2) if multimodal else out_func + hpd_data = func(ary, **func_kwargs) + out = hpd_data[~np.isnan(hpd_data)] + return out.reshape(out.shape[0] // 2, 2) if multimodal else hpd_data ary = convert_to_dataset(ary) kwargs.setdefault("input_core_dims", [["chain"]]) res = _wrap_xarray_ufunc(func, ary, func_kwargs=func_kwargs, **kwargs) @@ -432,8 +432,8 @@ def hpd( var_names = _var_names(var_names, ary) ary = ary[var_names] if var_names else ary - res = _wrap_xarray_ufunc(func, ary, func_kwargs=func_kwargs, **kwargs) - return res.dropna("mode", how="all") if multimodal else res + hpd_data = _wrap_xarray_ufunc(func, ary, func_kwargs=func_kwargs, **kwargs) + return hpd_data.dropna("mode", how="all") if multimodal else hpd_data def _hpd(ary, credible_interval, circular, skipna): From 2aa68fe9a8f988df6e24585d7f3e94678ecabff9 Mon Sep 17 00:00:00 2001 From: percygautam Date: Fri, 27 Mar 2020 23:13:13 +0530 Subject: [PATCH 
17/21] nits --- arviz/stats/stats.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index 4dc4f9f773..9f5d767660 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -410,7 +410,7 @@ def hpd( kwargs.setdefault("output_core_dims", [["hpd", "mode"] if multimodal else ["hpd"]]) if not multimodal: func_kwargs["circular"] = circular - if multimodal: + else: func_kwargs["max_modes"] = max_modes func = _hpd_multimodal if multimodal else _hpd @@ -419,8 +419,8 @@ def hpd( if len(ary.shape) == 1: func_kwargs.pop("out_shape") hpd_data = func(ary, **func_kwargs) - out = hpd_data[~np.isnan(hpd_data)] - return out.reshape(out.shape[0] // 2, 2) if multimodal else hpd_data + return hpd_data[~np.isnan(hpd_data).all(axis=1), :] if multimodal else hpd_data + ary = convert_to_dataset(ary) kwargs.setdefault("input_core_dims", [["chain"]]) res = _wrap_xarray_ufunc(func, ary, func_kwargs=func_kwargs, **kwargs) From cdea066bf67e4181b5c3464cb9cf057f5b4efcd6 Mon Sep 17 00:00:00 2001 From: percygautam Date: Sat, 28 Mar 2020 02:26:25 +0530 Subject: [PATCH 18/21] final nits --- arviz/stats/stats.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index 9f5d767660..83c984b3ca 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -418,7 +418,7 @@ def hpd( if isinstance(ary, np.ndarray): if len(ary.shape) == 1: func_kwargs.pop("out_shape") - hpd_data = func(ary, **func_kwargs) + hpd_data = func(ary, **func_kwargs) # pylint: disable=unexpected-keyword-arg return hpd_data[~np.isnan(hpd_data).all(axis=1), :] if multimodal else hpd_data ary = convert_to_dataset(ary) From 6cfee4808b89493d1d20a3b5c5c5a0de3c1a0936 Mon Sep 17 00:00:00 2001 From: percygautam Date: Sun, 29 Mar 2020 22:38:07 +0530 Subject: [PATCH 19/21] add to changelog --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 
75cad75182..b32f47233b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,10 +8,11 @@ * Add `num_chains` and `pred_dims` arguments to io_pyro #1090 * Integrate jointplot into pairplot, add point-estimate and overlay of plot kinds (#1079) * Allow xarray.Dataarray input for plots.(#1120) +* Revamped the `hpd` function to make it work with muti-dimensional arrays (#1117) ### Maintenance and fixes * Fixed behaviour of `credible_interval=None` in `plot_posterior` (#1115) * Fixed hist kind of `plot_dist` with multidimensional input (#1115) -* Fixed `TypeError` in `transform` argument of `plot_density` and `plot_forest` when `InferenceData is a list or tuple (#1121)` +* Fixed `TypeError` in `transform` argument of `plot_density` and `plot_forest` when `InferenceData is a list or tuple` (#1121) ### Deprecation ### Documentation From 19beaf20628319b596e9fed81bed3869ad6bbb4c Mon Sep 17 00:00:00 2001 From: percygautam Date: Mon, 30 Mar 2020 23:12:24 +0530 Subject: [PATCH 20/21] changes --- CHANGELOG.md | 4 ++-- arviz/stats/stats.py | 17 ++++++++--------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b32f47233b..e4d6ec1b1a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,11 +8,11 @@ * Add `num_chains` and `pred_dims` arguments to io_pyro #1090 * Integrate jointplot into pairplot, add point-estimate and overlay of plot kinds (#1079) * Allow xarray.Dataarray input for plots.(#1120) -* Revamped the `hpd` function to make it work with muti-dimensional arrays (#1117) +* Revamped the `hpd` function to make it work with multidimensional arrays, InferenceData and xarray objects (#1117) ### Maintenance and fixes * Fixed behaviour of `credible_interval=None` in `plot_posterior` (#1115) * Fixed hist kind of `plot_dist` with multidimensional input (#1115) -* Fixed `TypeError` in `transform` argument of `plot_density` and `plot_forest` when `InferenceData is a list or tuple` (#1121) +* Fixed `TypeError` in `transform` argument of 
`plot_density` and `plot_forest` when `InferenceData` is a list or tuple (#1121) ### Deprecation ### Documentation diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index 83c984b3ca..53395129b8 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -415,16 +415,14 @@ def hpd( func = _hpd_multimodal if multimodal else _hpd - if isinstance(ary, np.ndarray): - if len(ary.shape) == 1: - func_kwargs.pop("out_shape") - hpd_data = func(ary, **func_kwargs) # pylint: disable=unexpected-keyword-arg - return hpd_data[~np.isnan(hpd_data).all(axis=1), :] if multimodal else hpd_data + isarray = isinstance(ary, np.ndarray) + if isarray and ary.ndim <= 1: + func_kwargs.pop("out_shape") + hpd_data = func(ary, **func_kwargs) # pylint: disable=unexpected-keyword-arg + return hpd_data[~np.isnan(hpd_data).all(axis=1), :] if multimodal else hpd_data - ary = convert_to_dataset(ary) + if isarray: kwargs.setdefault("input_core_dims", [["chain"]]) - res = _wrap_xarray_ufunc(func, ary, func_kwargs=func_kwargs, **kwargs) - return res.dropna("mode", how="all").x.values if multimodal else res.x.values ary = convert_to_dataset(ary, group=group) if coords is not None: @@ -433,7 +431,8 @@ def hpd( ary = ary[var_names] if var_names else ary hpd_data = _wrap_xarray_ufunc(func, ary, func_kwargs=func_kwargs, **kwargs) - return hpd_data.dropna("mode", how="all") if multimodal else hpd_data + hpd_data = hpd_data.dropna("mode", how="all") if multimodal else hpd_data + return hpd_data.x.values if isarray else hpd_data def _hpd(ary, credible_interval, circular, skipna): From 758e9b5b5d54b64873de28a4deb8d721ab89069e Mon Sep 17 00:00:00 2001 From: percygautam Date: Tue, 31 Mar 2020 22:37:19 +0530 Subject: [PATCH 21/21] final changes --- arviz/stats/stats.py | 2 +- arviz/tests/base_tests/test_stats.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index 53395129b8..cec2e3f836 100644 --- a/arviz/stats/stats.py +++ 
b/arviz/stats/stats.py @@ -421,7 +421,7 @@ def hpd( hpd_data = func(ary, **func_kwargs) # pylint: disable=unexpected-keyword-arg return hpd_data[~np.isnan(hpd_data).all(axis=1), :] if multimodal else hpd_data - if isarray: + if isarray and ary.ndim == 2: kwargs.setdefault("input_core_dims", [["chain"]]) ary = convert_to_dataset(ary, group=group) diff --git a/arviz/tests/base_tests/test_stats.py b/arviz/tests/base_tests/test_stats.py index 7464d87aad..10dd9d72c4 100644 --- a/arviz/tests/base_tests/test_stats.py +++ b/arviz/tests/base_tests/test_stats.py @@ -56,7 +56,7 @@ def test_hpd_2darray(): def test_hpd_multidimension(): normal_sample = np.random.randn(12000, 10, 3) result = hpd(normal_sample) - assert result.shape == (10, 3, 2,) + assert result.shape == (3, 2,) def test_hpd_idata(centered_eight):