Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PyMC Implementation of Pathfinder VI #386

Closed
wants to merge 10 commits into from
Closed

Conversation

aphc14
Copy link

@aphc14 aphc14 commented Oct 31, 2024

Summary:

  • Adds a PyMC implementation of Pathfinder Variational Inference using PyTensor operations. The new implementation allows users to choose between PyMC and BlackJAX backends while maintaining the same API.
  • Added lbfgs.py module implementing L-BFGS optimisation with history tracking
  • Extended pathfinder.py with PyMC implementation using PyTensor operations.

Note: Another draft PR will be sent that focuses on a PyTensor symbolic implementation using pytensor.function. I've sent two PR drafts to get feedback on which version would be better.

import pymc as pm
import pymc_experimental as pmx

with pm.Model() as model:
    # ... model definition ...
    idata = pmx.fit(
        method="pathfinder",
        inference_backend="pymc",  # or "blackjax"
        random_seed=42
    )
with model:
    # eight_schools_model
    idata = pmx.fit(model=model, method="pathfinder", random_seed=41, inference_backend="pymc")
    
    # New implementation now passes this assertion! :)
    # FIXME: pathfinder doesn't find a reasonable mean! Fix bug or choose model pathfinder can handle
    np.testing.assert_allclose(idata.posterior["mu"].mean(), 5.0, atol=1.0)

    # But, it also fails this :(
    # FIXME: now the tau is being underestimated. getting tau around 1.5.
    # np.testing.assert_allclose(idata.posterior["tau"].mean(), 4.15, atol=0.5)

`fit_pathfinder`
- Edited `fit_pathfinder` to produce `pathfinder_state`, `pathfinder_info`, `pathfinder_samples` and `pathfinder_idata` for closer examination of the outputs.
- Changed the `num_samples` argument name to `num_draws` to avoid `TypeError` got multiple values for keyword argument 'num_samples'.
- Initial points are automatically set to jitter as jitter is required for pathfinder.

Extras
- New function 'get_jaxified_logp_ravel_inputs' to simplify previous code structure in fit_pathfinder.

Tests
- Added extra test for pathfinder to test pathfinder_info variables and pathfinder_idata  are consistent for a given random seed.
Add a new PyMC-based implementation of Pathfinder VI that uses PyTensor operations which provides support for both PyMC and BlackJAX backends in fit_pathfinder.
@twiecki
Copy link
Member

twiecki commented Oct 31, 2024

This looks great @aphc14, can you find me on linkedin?

- Implemented  in  to support running multiple Pathfinder instances in parallel.
- Implemented  function in  for Pareto Smoothed Importance Resampling (PSIR).
- Moved relevant pathfinder files into the  directory.
- Updated tests to reflect changes in the Pathfinder implementation and added tests for new functionalities.
@aphc14
Copy link
Author

aphc14 commented Nov 4, 2024

@twiecki yup, added you :)
@fonnesbeck, FYI, Multipath Pathfinder has just been implemented!

aphc14 added a commit to aphc14/pymc-experimental that referenced this pull request Nov 7, 2024
@aphc14 aphc14 closed this Nov 11, 2024
@aphc14
Copy link
Author

aphc14 commented Nov 11, 2024

closed this in favour of #387

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants