diff --git a/pyro/distributions/projected_normal.py b/pyro/distributions/projected_normal.py index 3dc14c2187..42e5c3b330 100644 --- a/pyro/distributions/projected_normal.py +++ b/pyro/distributions/projected_normal.py @@ -170,3 +170,24 @@ def _log_prob_3(concentration, value): ).log() return para_part + perp_part + + +@ProjectedNormal._register_log_prob(dim=4) +def _log_prob_4(concentration, value): + # We integrate along a ray, factorizing the integrand as a product of: + # a truncated normal distribution over coordinate t parallel to the ray, and + # a bivariate normal distribution over coordinate r perpendicular to the ray. + t = _dot(concentration, value) + t2 = t.square() + r2 = _dot(concentration, concentration) - t2 + perp_part = r2.mul(-0.5) - 1.5 * math.log(2 * math.pi) + + # This is the log of a definite integral, computed by mathematica: + # Integrate[x^3/(E^((x-t)^2/2) Sqrt[2 Pi]), {x, 0, Infinity}] + # = (2 + t^2)/(E^(t^2/2) Sqrt[2 Pi]) + (t (3 + t^2) (1 + Erf[t/Sqrt[2]]))/2 + para_part = ( + (2 + t2) * t2.mul(-0.5).exp() / (2 * math.pi) ** 0.5 + + t * (3 + t2) * (1 + (t * 0.5 ** 0.5).erf()) / 2 + ).log() + + return para_part + perp_part diff --git a/tests/distributions/conftest.py b/tests/distributions/conftest.py index 28680a881e..2067638623 100644 --- a/tests/distributions/conftest.py +++ b/tests/distributions/conftest.py @@ -529,6 +529,11 @@ def __init__(self, von_loc, von_conc, skewness): {"concentration": [2.0, 3.0], "test_data": [0.0, 1.0]}, {"concentration": [0.0, 0.0, 0.0], "test_data": [1.0, 0.0, 0.0]}, {"concentration": [-1.0, 2.0, 3.0], "test_data": [0.0, 0.0, 1.0]}, + {"concentration": [0.0, 0.0, 0.0, 0.0], "test_data": [1.0, 0.0, 0.0, 0.0]}, + { + "concentration": [-1.0, 2.0, 0.5, -0.5], + "test_data": [0.0, 1.0, 0.0, 0.0], + }, ], ), Fixture(