-
-
Notifications
You must be signed in to change notification settings - Fork 500
Revamping HPD #1117
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Revamping HPD #1117
Changes from 6 commits
d29cefb
2bb5646
7de2518
e4d8318
ae2e59d
f2e67b4
6dd3ee8
1d75d93
9f7ad1f
a24323e
ff176f4
82f15a4
0ff7d18
deaf5e9
c069dd5
000866c
e9d6ace
2aa68fe
cdea066
9e6ac55
6cfee48
19beaf2
758e9b5
b746b9b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -305,16 +305,32 @@ def _ic_matrix(ics, ic_i): | |
| return rows, cols, ic_i_val | ||
|
|
||
|
|
||
| def hpd(ary, credible_interval=None, circular=False, multimodal=False, skipna=False): | ||
| def hpd( | ||
| ary, | ||
| *, | ||
| group="posterior", | ||
| var_names=None, | ||
| credible_interval=None, | ||
| circular=False, | ||
| multimodal=False, | ||
| skipna=False, | ||
| **kwargs | ||
| ): | ||
|
percygautam marked this conversation as resolved.
|
||
| """ | ||
| Calculate highest posterior density (HPD) of array for given credible_interval. | ||
|
|
||
| The HPD is the minimum width Bayesian credible interval (BCI). | ||
|
|
||
| Parameters | ||
| ---------- | ||
| ary : Numpy array | ||
| An array containing posterior samples | ||
| ary : obj | ||
| onject containing posterior samples. | ||
| Any object that can be converted to an az.InferenceData object. | ||
| Refer to documentation of az.convert_to_dataset for details. | ||
| group : str, optional | ||
| Specifies which InferenceData group should be used to calculate hpd. Defaults to 'posterior' | ||
| var_names : list | ||
| Names of variables to include in the hpd report | ||
| credible_interval : float, optional | ||
| Credible interval to compute. Defaults to 0.94. | ||
| circular : bool, optional | ||
|
|
@@ -349,86 +365,104 @@ def hpd(ary, credible_interval=None, circular=False, multimodal=False, skipna=Fa | |
| if not 1 >= credible_interval > 0: | ||
| raise ValueError("The value of credible_interval should be in the interval (0, 1]") | ||
|
|
||
| if ary.ndim > 1: | ||
| hpd_array = np.array( | ||
| [ | ||
| hpd( | ||
| row, | ||
| credible_interval=credible_interval, | ||
| circular=circular, | ||
| multimodal=multimodal, | ||
| ) | ||
| for row in ary.T | ||
| ] | ||
| ) | ||
| return hpd_array | ||
|
|
||
| if multimodal: | ||
| if skipna: | ||
| ary = ary[~np.isnan(ary)] | ||
|
|
||
| if ary.dtype.kind == "f": | ||
| density, lower, upper = _fast_kde(ary) | ||
| range_x = upper - lower | ||
| dx = range_x / len(density) | ||
| bins = np.linspace(lower, upper, len(density)) | ||
| else: | ||
| bins = get_bins(ary) | ||
| _, density, _ = histogram(ary, bins=bins) | ||
| dx = np.diff(bins)[0] | ||
| return _hpd_multimodal(ary, credible_interval, skipna) | ||
|
|
||
| density *= dx | ||
| func_kwargs = { | ||
| "credible_interval": credible_interval, | ||
| "circular": circular, | ||
| "skipna": skipna, | ||
|
percygautam marked this conversation as resolved.
|
||
| "out_shape": (2,), | ||
| } | ||
| kwargs.setdefault("output_core_dims", [["hpd"]]) | ||
|
|
||
| idx = np.argsort(-density) | ||
| intervals = bins[idx][density[idx].cumsum() <= credible_interval] | ||
| intervals.sort() | ||
| if isinstance(ary, np.ndarray): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should be only if the array is 1d or 2d: If the array has 3 or more dimensions, it should assume ArviZ dim order: |
||
| if len(ary.shape) == 1: | ||
| return _hpd(ary, credible_interval, circular, skipna) | ||
| ary = convert_to_dataset(ary) | ||
| kwargs.setdefault("input_core_dims", [["chain"]]) | ||
| return _wrap_xarray_ufunc(_hpd, ary, func_kwargs=func_kwargs, **kwargs).to_array().values[0] | ||
|
percygautam marked this conversation as resolved.
Outdated
|
||
|
|
||
| intervals_splitted = np.split(intervals, np.where(np.diff(intervals) >= dx * 1.1)[0] + 1) | ||
| ary = convert_to_dataset(ary, group=group) | ||
| var_names = _var_names(var_names, ary) | ||
|
|
||
| hpd_intervals = [] | ||
| for interval in intervals_splitted: | ||
| if interval.size == 0: | ||
| hpd_intervals.append((bins[0], bins[0])) | ||
| else: | ||
| hpd_intervals.append((interval[0], interval[-1])) | ||
| ary = ary if var_names is None else ary[var_names] | ||
|
percygautam marked this conversation as resolved.
Outdated
|
||
|
|
||
| hpd_intervals = np.array(hpd_intervals) | ||
| kwargs.setdefault("input_core_dims", [["chain", "draw"]]) | ||
|
percygautam marked this conversation as resolved.
Outdated
|
||
| return _wrap_xarray_ufunc(_hpd, ary, func_kwargs=func_kwargs, **kwargs) | ||
|
percygautam marked this conversation as resolved.
Outdated
|
||
|
|
||
| else: | ||
| if skipna: | ||
| nans = np.isnan(ary) | ||
| if not nans.all(): | ||
| ary = ary[~nans] | ||
| n = len(ary) | ||
|
|
||
| if circular: | ||
| mean = st.circmean(ary, high=np.pi, low=-np.pi) | ||
| ary = ary - mean | ||
| ary = np.arctan2(np.sin(ary), np.cos(ary)) | ||
| def _hpd(ary, credible_interval, circular, skipna): | ||
| """Compute hpd over the flattened array.""" | ||
| ary = ary.flatten() | ||
| if skipna: | ||
| nans = np.isnan(ary) | ||
| if not nans.all(): | ||
| ary = ary[~nans] | ||
| n = len(ary) | ||
|
|
||
| if circular: | ||
| mean = st.circmean(ary, high=np.pi, low=-np.pi) | ||
| ary = ary - mean | ||
| ary = np.arctan2(np.sin(ary), np.cos(ary)) | ||
|
|
||
| ary = np.sort(ary) | ||
| interval_idx_inc = int(np.floor(credible_interval * n)) | ||
| n_intervals = n - interval_idx_inc | ||
| interval_width = ary[interval_idx_inc:] - ary[:n_intervals] | ||
| ary = np.sort(ary) | ||
| interval_idx_inc = int(np.floor(credible_interval * n)) | ||
| n_intervals = n - interval_idx_inc | ||
| interval_width = ary[interval_idx_inc:] - ary[:n_intervals] | ||
|
|
||
| if len(interval_width) == 0: | ||
| raise ValueError("Too few elements for interval calculation. ") | ||
| if len(interval_width) == 0: | ||
| raise ValueError("Too few elements for interval calculation. ") | ||
|
|
||
| min_idx = np.argmin(interval_width) | ||
| hdi_min = ary[min_idx] | ||
| hdi_max = ary[min_idx + interval_idx_inc] | ||
| min_idx = np.argmin(interval_width) | ||
| hdi_min = ary[min_idx] | ||
| hdi_max = ary[min_idx + interval_idx_inc] | ||
|
|
||
| if circular: | ||
| hdi_min = hdi_min + mean | ||
| hdi_max = hdi_max + mean | ||
| hdi_min = np.arctan2(np.sin(hdi_min), np.cos(hdi_min)) | ||
| hdi_max = np.arctan2(np.sin(hdi_max), np.cos(hdi_max)) | ||
| if circular: | ||
| hdi_min = hdi_min + mean | ||
| hdi_max = hdi_max + mean | ||
| hdi_min = np.arctan2(np.sin(hdi_min), np.cos(hdi_min)) | ||
| hdi_max = np.arctan2(np.sin(hdi_max), np.cos(hdi_max)) | ||
|
|
||
| hpd_intervals = np.array([hdi_min, hdi_max]) | ||
| hpd_intervals = np.array([hdi_min, hdi_max]) | ||
|
|
||
| return hpd_intervals | ||
|
|
||
|
|
||
| def _hpd_multimodal(ary, credible_interval, skipna): | ||
| """Compute hpd if the distribution is multimodal""" | ||
|
percygautam marked this conversation as resolved.
Outdated
|
||
| ary = ary.flatten() | ||
| if skipna: | ||
| ary = ary[~np.isnan(ary)] | ||
|
|
||
| if ary.dtype.kind == "f": | ||
| density, lower, upper = _fast_kde(ary) | ||
| range_x = upper - lower | ||
| dx = range_x / len(density) | ||
| bins = np.linspace(lower, upper, len(density)) | ||
| else: | ||
| bins = get_bins(ary) | ||
| _, density, _ = histogram(ary, bins=bins) | ||
| dx = np.diff(bins)[0] | ||
|
|
||
| density *= dx | ||
|
|
||
| idx = np.argsort(-density) | ||
| intervals = bins[idx][density[idx].cumsum() <= credible_interval] | ||
| intervals.sort() | ||
|
|
||
| intervals_splitted = np.split(intervals, np.where(np.diff(intervals) >= dx * 1.1)[0] + 1) | ||
|
|
||
| hpd_intervals = [] | ||
|
percygautam marked this conversation as resolved.
Outdated
|
||
| for interval in intervals_splitted: | ||
| if interval.size == 0: | ||
| hpd_intervals.append((bins[0], bins[0])) | ||
| else: | ||
| hpd_intervals.append((interval[0], interval[-1])) | ||
|
percygautam marked this conversation as resolved.
Outdated
|
||
|
|
||
| return np.array(hpd_intervals) | ||
|
|
||
|
|
||
| def loo(data, pointwise=False, reff=None, scale=None): | ||
| """Pareto-smoothed importance sampling leave-one-out cross-validation. | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -47,6 +47,12 @@ def test_hpd(): | |
| assert_array_almost_equal(interval, [-1.88, 1.88], 2) | ||
|
|
||
|
|
||
| def test_hpd_multidimension(): | ||
|
percygautam marked this conversation as resolved.
|
||
| normal_sample = np.random.randn(12000, 10, 3) | ||
| result = hpd(normal_sample) | ||
| assert result.shape == (10, 3, 2,) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This line will have to be updated to check that the result shape is the desired
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Earlier, we were calculating
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The issue is that calculating hpd only over chain is a very bad default, we'll keep the behaviour (for now) in 2d array case to keep backwards compatibility, but 3d arrays are not supported, so we do not have the backwards compatibility constraint.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay, I have done the changes. |
||
|
|
||
|
|
||
| def test_hpd_multimodal(): | ||
| normal_sample = np.concatenate( | ||
| (np.random.normal(-4, 1, 2500000), np.random.normal(2, 0.5, 2500000)) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.