-
Notifications
You must be signed in to change notification settings - Fork 16
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
WIP: Implement a denormalize
custom Jaxpr operator simplifying MCX logpdfs
#71
base: master
Are you sure you want to change the base?
Conversation
…f by removing constants.
denormalize
custom Jaxpr operator simplifying MCX logpdfsdenormalize
custom Jaxpr operator simplifying MCX logpdfs
@rlouf As we discussed on Slack, there is quite a bit of additional complexity to add to this PR to handle properly the support I'll start with a fairly dummy implementation, getting it working, and I think then we can iterate on it to make it less naive and using more properly symbolic programming concept (I started looking at Oryx codebase on that). |
That sounds like a very good plan to me! I'll have a closer look too when my big PR is merged. |
f8f3e6b
to
965f6dd
Compare
Overview
This PR is implementing a generic
denormalize
decorator which removes normalizing constants in a logpdf. Per call to contributions in #65.Implementation
The current implementation is a two passes algorithm:
add
andsub
operations where one of the input is a constant;Once the latter simplifying mapping is found, the rest of decorator code is just a simple execution pass on the Jaxpr, skipping the operations where a simplifying mapping exists.
Limitations
Even though we try to have a fairly generic implementation, some simplifications are not supported at the moment. For instance, we do not propagate constant simplification in
concat
ormul
operations. These cases could be supported in the future, if it happens to be a performance bottleneck in MCX.