Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 100 additions & 66 deletions arviz/stats/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,16 +305,32 @@ def _ic_matrix(ics, ic_i):
return rows, cols, ic_i_val


def hpd(ary, credible_interval=None, circular=False, multimodal=False, skipna=False):
def hpd(
ary,
*,
group="posterior",
var_names=None,
credible_interval=None,
circular=False,
multimodal=False,
skipna=False,
**kwargs
Comment thread
OriolAbril marked this conversation as resolved.
):
Comment thread
percygautam marked this conversation as resolved.
"""
Calculate highest posterior density (HPD) of array for given credible_interval.

The HPD is the minimum width Bayesian credible interval (BCI).

Parameters
----------
ary : Numpy array
An array containing posterior samples
ary : obj
Object containing posterior samples.
Any object that can be converted to an az.InferenceData object.
Refer to documentation of az.convert_to_dataset for details.
group : str, optional
Specifies which InferenceData group should be used to calculate hpd. Defaults to 'posterior'
var_names : list
Names of variables to include in the hpd report
credible_interval : float, optional
Credible interval to compute. Defaults to 0.94.
circular : bool, optional
Expand Down Expand Up @@ -349,86 +365,104 @@ def hpd(ary, credible_interval=None, circular=False, multimodal=False, skipna=Fa
if not 1 >= credible_interval > 0:
raise ValueError("The value of credible_interval should be in the interval (0, 1]")

if ary.ndim > 1:
hpd_array = np.array(
[
hpd(
row,
credible_interval=credible_interval,
circular=circular,
multimodal=multimodal,
)
for row in ary.T
]
)
return hpd_array

if multimodal:
if skipna:
ary = ary[~np.isnan(ary)]

if ary.dtype.kind == "f":
density, lower, upper = _fast_kde(ary)
range_x = upper - lower
dx = range_x / len(density)
bins = np.linspace(lower, upper, len(density))
else:
bins = get_bins(ary)
_, density, _ = histogram(ary, bins=bins)
dx = np.diff(bins)[0]
return _hpd_multimodal(ary, credible_interval, skipna)

density *= dx
func_kwargs = {
"credible_interval": credible_interval,
"circular": circular,
"skipna": skipna,
Comment thread
percygautam marked this conversation as resolved.
"out_shape": (2,),
}
kwargs.setdefault("output_core_dims", [["hpd"]])

idx = np.argsort(-density)
intervals = bins[idx][density[idx].cumsum() <= credible_interval]
intervals.sort()
if isinstance(ary, np.ndarray):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be only if the array is 1d or 2d:

isarray = isinstance(ary, np.ndarray)
if isarray and ary.ndim <= 2:

If the array has 3 or more dimensions, it should assume ArviZ dim order: (chain, draw, *shape). hpd should still return a numpy array though:

...
hpd_data = _wrap_xarray_ufunc(func, ary, func_kwargs=func_kwargs, **kwargs)
hpd_data = hpd_data.dropna("mode", how="all") if multimodal else hpd_data
return hpd_data.x.values if isarray else hpd_data

if len(ary.shape) == 1:
return _hpd(ary, credible_interval, circular, skipna)
ary = convert_to_dataset(ary)
kwargs.setdefault("input_core_dims", [["chain"]])
return _wrap_xarray_ufunc(_hpd, ary, func_kwargs=func_kwargs, **kwargs).to_array().values[0]
Comment thread
percygautam marked this conversation as resolved.
Outdated

intervals_splitted = np.split(intervals, np.where(np.diff(intervals) >= dx * 1.1)[0] + 1)
ary = convert_to_dataset(ary, group=group)
var_names = _var_names(var_names, ary)

hpd_intervals = []
for interval in intervals_splitted:
if interval.size == 0:
hpd_intervals.append((bins[0], bins[0]))
else:
hpd_intervals.append((interval[0], interval[-1]))
ary = ary if var_names is None else ary[var_names]
Comment thread
percygautam marked this conversation as resolved.
Outdated

hpd_intervals = np.array(hpd_intervals)
kwargs.setdefault("input_core_dims", [["chain", "draw"]])
Comment thread
percygautam marked this conversation as resolved.
Outdated
return _wrap_xarray_ufunc(_hpd, ary, func_kwargs=func_kwargs, **kwargs)
Comment thread
percygautam marked this conversation as resolved.
Outdated

else:
if skipna:
nans = np.isnan(ary)
if not nans.all():
ary = ary[~nans]
n = len(ary)

if circular:
mean = st.circmean(ary, high=np.pi, low=-np.pi)
ary = ary - mean
ary = np.arctan2(np.sin(ary), np.cos(ary))
def _hpd(ary, credible_interval, circular, skipna):
"""Compute hpd over the flattened array."""
ary = ary.flatten()
if skipna:
nans = np.isnan(ary)
if not nans.all():
ary = ary[~nans]
n = len(ary)

if circular:
mean = st.circmean(ary, high=np.pi, low=-np.pi)
ary = ary - mean
ary = np.arctan2(np.sin(ary), np.cos(ary))

ary = np.sort(ary)
interval_idx_inc = int(np.floor(credible_interval * n))
n_intervals = n - interval_idx_inc
interval_width = ary[interval_idx_inc:] - ary[:n_intervals]
ary = np.sort(ary)
interval_idx_inc = int(np.floor(credible_interval * n))
n_intervals = n - interval_idx_inc
interval_width = ary[interval_idx_inc:] - ary[:n_intervals]

if len(interval_width) == 0:
raise ValueError("Too few elements for interval calculation. ")
if len(interval_width) == 0:
raise ValueError("Too few elements for interval calculation. ")

min_idx = np.argmin(interval_width)
hdi_min = ary[min_idx]
hdi_max = ary[min_idx + interval_idx_inc]
min_idx = np.argmin(interval_width)
hdi_min = ary[min_idx]
hdi_max = ary[min_idx + interval_idx_inc]

if circular:
hdi_min = hdi_min + mean
hdi_max = hdi_max + mean
hdi_min = np.arctan2(np.sin(hdi_min), np.cos(hdi_min))
hdi_max = np.arctan2(np.sin(hdi_max), np.cos(hdi_max))
if circular:
hdi_min = hdi_min + mean
hdi_max = hdi_max + mean
hdi_min = np.arctan2(np.sin(hdi_min), np.cos(hdi_min))
hdi_max = np.arctan2(np.sin(hdi_max), np.cos(hdi_max))

hpd_intervals = np.array([hdi_min, hdi_max])
hpd_intervals = np.array([hdi_min, hdi_max])

return hpd_intervals


def _hpd_multimodal(ary, credible_interval, skipna):
    """Compute hpd if the distribution is multimodal.

    Parameters
    ----------
    ary : numpy.ndarray
        Samples. The array is flattened before any computation.
    credible_interval : float
        Upper limit on the cumulative probability mass of the grid points
        kept, taken in order of decreasing density.
    skipna : bool
        If True, NaN values are removed before estimating the density.

    Returns
    -------
    numpy.ndarray
        One ``(lower, upper)`` row per contiguous run of selected grid
        points.
    """
    ary = ary.flatten()
    if skipna:
        ary = ary[~np.isnan(ary)]

    if ary.dtype.kind == "f":
        # Float data: density on an evenly spaced grid between lower and upper.
        # NOTE(review): assumes _fast_kde returns (density, lower, upper) - confirm.
        density, lower, upper = _fast_kde(ary)
        range_x = upper - lower
        dx = range_x / len(density)
        bins = np.linspace(lower, upper, len(density))
    else:
        # Non-float data: fall back to a histogram over computed bins.
        bins = get_bins(ary)
        _, density, _ = histogram(ary, bins=bins)
        dx = np.diff(bins)[0]

    # Scale the density by the grid step so the values can be accumulated.
    density *= dx

    # Keep grid points from most to least dense while the accumulated mass
    # stays within credible_interval, then restore ascending order.
    idx = np.argsort(-density)
    intervals = bins[idx][density[idx].cumsum() <= credible_interval]
    intervals.sort()

    # A gap of at least 1.1 grid steps between kept points starts a new interval.
    intervals_splitted = np.split(intervals, np.where(np.diff(intervals) >= dx * 1.1)[0] + 1)

    # An empty run collapses to a zero-width interval at the first bin.
    hpd_intervals = [
        (bins[0], bins[0]) if interval.size == 0 else (interval[0], interval[-1])
        for interval in intervals_splitted
    ]

    return np.array(hpd_intervals)


def loo(data, pointwise=False, reff=None, scale=None):
"""Pareto-smoothed importance sampling leave-one-out cross-validation.

Expand Down
6 changes: 6 additions & 0 deletions arviz/tests/base_tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ def test_hpd():
assert_array_almost_equal(interval, [-1.88, 1.88], 2)


def test_hpd_multidimension():
    """Check the shape of the hpd result for a 3d array input."""
    normal_sample = np.random.randn(12000, 10, 3)
    result = hpd(normal_sample)
    assert result.shape == (10, 3, 2)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line will have to be updated to check that the result shape is the desired (3, 2)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously, hpd was calculated over one dimension only for ndarrays. For backward compatibility, I have kept the default of calculating only over 'chain' for ndarrays, so the result would still be (10, 3, 2).

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue is that calculating hpd only over chain is a very bad default. We'll keep that behaviour (for now) in the 2d array case for backwards compatibility, but 3d arrays are not currently supported, so there is no backwards-compatibility constraint for them.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I have done the changes.



def test_hpd_multimodal():
normal_sample = np.concatenate(
(np.random.normal(-4, 1, 2500000), np.random.normal(2, 0.5, 2500000))
Expand Down