-
-
Notifications
You must be signed in to change notification settings - Fork 985
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix_reverse_out_bound_quadratic_spline #3140
fix_reverse_out_bound_quadratic_spline #3140
Conversation
@@ -254,6 +254,7 @@ def _monotonic_rational_spline( | |||
c = -input_delta * (inputs - input_cumheights) | |||
|
|||
discriminant = b.pow(2) - 4 * a * c | |||
discriminant[outside_interval_mask] = 0 # added to make sure outside_interval input can be reversed as identity. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you'll need two spaces before the comment, and possibly to move the line comment to a separate line to satisfy max line length requirements. You can locally run make lint
to see lint errors.
Also you might need to use a non-inplace version like
discriminant = discriminant.masked_fill(outside_interval_mask, 0)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for your comment. I am new to pulling a request. So in this case, it is because of my non-standard comment that my request fails, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe so, we use the black automatic style checker to ensure all code follows the same style. This helps everyone collaborate by making the code more uniform.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I get it. Thanks. I will edit it soon
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have re-pulled it. Please check it.
Issue Description
quadratic spline in pyro.distributions.transforms.Spline does not consider outside_interval input at reverse.
According to the Neural spline flow(https://arxiv.org/abs/1906.04032), monotonic rational-quadratic transforms sets the boundary derivatives as 1 to match the linear tails. However, when it takes the y(sample from the target distribution ) to reverse back to the x(sample from the base distribution) in the function _monotonic_rational_spline, it does not consider the scenario where y is out of bound, which would render a negative discriminant and return Error.
Environment
For any bugs, please provide the following:
python version: 3.8.7
PyTorch version: 1.12.1+cpu
Pyro version: 1.8.2
Solution
I replace a fake discriminant for the out-of-bound input.
Origin:
discriminant = b.pow(2) - 4 * a * c
assert (discriminant >= 0).all()
Changed:
discriminant = b.pow(2) - 4 * a * c
discriminant[outside_interval_mask] = 0 # added to make sure outputs[outside_interval_mask] = inputs[outside_interval_mask]
assert (discriminant >= 0).all()