Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Implement a denormalize custom Jaxpr operator simplifying MCX logpdfs #71

Open
wants to merge 46 commits into
base: master
Choose a base branch
from

Conversation

balancap
Copy link

@balancap balancap commented Jan 31, 2021

Overview

This PR is implementing a generic denormalize decorator which removes normalizing constants in a logpdf. Per call to contributions in #65.

Implementation

The current implementation is a two passes algorithm:

  • Forward pass to find all constant variables in the Jaxpr graph;
  • Backward pass to find all simplifying assignment, but skipping add and sub operations where one of the input is a constant;

Once the latter simplifying mapping is found, the rest of decorator code is just a simple execution pass on the Jaxpr, skipping the operations where a simplifying mapping exists.

Limitations

Even though we try to have a fairly generic implementation, some simplifications are not supported at the moment. For instance, we do not propagate constant simplification in concat or mul operations. These cases could be supported in the future, if it happens to be a performance bottleneck in MCX.

@balancap balancap changed the title Implement a denormalize custom Jaxpr operator simplifying MCX logpdfs WIP: Implement a denormalize custom Jaxpr operator simplifying MCX logpdfs Feb 3, 2021
@balancap
Copy link
Author

balancap commented Feb 3, 2021

@rlouf As we discussed on Slack, there is quite a bit of additional complexity to add to this PR to handle properly the support select condition appearing in lot of distributions logpdf.

I'll start with a fairly dummy implementation, getting it working, and I think then we can iterate on it to make it less naive and using more properly symbolic programming concept (I started looking at Oryx codebase on that).

@rlouf
Copy link
Owner

rlouf commented Feb 3, 2021

That sounds like a very good plan to me! I'll have a closer look too when my big PR is merged.

@rlouf rlouf force-pushed the master branch 3 times, most recently from f8f3e6b to 965f6dd Compare February 23, 2021 11:28
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants