From c573af15a36f0c879da982654554795823db6bda Mon Sep 17 00:00:00 2001 From: LiaoShiqi97 <70330114+LiaoShiqi97@users.noreply.github.com> Date: Fri, 21 Oct 2022 03:19:02 +0200 Subject: [PATCH] fix_reverse_out_bound_quadratic_spline (#3140) * fix_reverse_out_bound_quadratic_spline * fix_reverse_out_bound_quadratic_spline * add two spaces before comment * add sapce * space * lint Co-authored-by: Fritz Obermeyer --- pyro/distributions/transforms/spline.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyro/distributions/transforms/spline.py b/pyro/distributions/transforms/spline.py index ba7d240a1b..2f32f61aac 100644 --- a/pyro/distributions/transforms/spline.py +++ b/pyro/distributions/transforms/spline.py @@ -254,6 +254,8 @@ def _monotonic_rational_spline( c = -input_delta * (inputs - input_cumheights) discriminant = b.pow(2) - 4 * a * c + # Make sure outside_interval input can be reversed as identity. + discriminant = discriminant.masked_fill(outside_interval_mask, 0) assert (discriminant >= 0).all() root = (2 * c) / (-b - torch.sqrt(discriminant))