diff --git a/pyro/distributions/transforms/spline.py b/pyro/distributions/transforms/spline.py index ba7d240a1b..2f32f61aac 100644 --- a/pyro/distributions/transforms/spline.py +++ b/pyro/distributions/transforms/spline.py @@ -254,6 +254,8 @@ def _monotonic_rational_spline( c = -input_delta * (inputs - input_cumheights) discriminant = b.pow(2) - 4 * a * c + # Make sure outside_interval input can be reversed as identity. + discriminant = discriminant.masked_fill(outside_interval_mask, 0) assert (discriminant >= 0).all() root = (2 * c) / (-b - torch.sqrt(discriminant))