Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Blackjax's Sequential Monte Carlo over PyMC models. #6989

Closed

Conversation

ciguaran
Copy link
Contributor

@ciguaran ciguaran commented Nov 3, 2023

What is this PR about?

Allows to sample Pymc's models using Blackjax's Sequential Monte Carlo implementations. Apart from getting a Jax-based implementation for Blackjax, this PR allows for using HMC and NUTS as kernels, which aren't available in the existing PyMC implementation of SMC. Moreover, diagnosis are exposed and stored in the resulting arviz.InferenceData

Checklist

  • Explain important implementation details 👆
    In order to sample using BJ SMC, we need
  • code to compute jaxified logprior and loglikelihood functions, applyable over SMC Particles.
  • code to compute diagnosis over the SMC run.
  • code to build an arviz.InferenceData object from the sampler's output.

Major / Breaking Changes

  • ...
    None

New features

  • ...
  • Integration of Blackjax SMC samplers
  • Code to compute jaxified loglikelihood and logprior.
  • Diagnosis

Bugfixes

  • ...

Documentation

  • ...

Maintenance

  • ...

@junpenglao @aloctavodia

Copy link

codecov bot commented Nov 3, 2023

Codecov Report

Merging #6989 (123755a) into main (ec4407d) will decrease coverage by 0.91%.
The diff coverage is 12.21%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #6989      +/-   ##
==========================================
- Coverage   87.78%   86.87%   -0.91%     
==========================================
  Files         100      102       +2     
  Lines       16896    17017     +121     
==========================================
- Hits        14832    14784      -48     
- Misses       2064     2233     +169     
Files Coverage Δ
pymc/smc/kernels.py 77.14% <100.00%> (-20.31%) ⬇️
pymc/smc/from_blackjax/kernels.py 0.00% <0.00%> (ø)
pymc/smc/from_blackjax/sampling.py 0.00% <0.00%> (ø)

... and 4 files with indirect coverage changes

model = fast_model()
population = {"x": np.array([2, 3, 4])}
blackjax_particles = blackjax_particles_from_pymc_population(model, population)
chex.assert_trees_all_close(blackjax_particles, [np.array([[2], [3], [4]])])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont think PyMC use chex - use jax.tree_map(np.testing.assert..., ...) instead

Copy link
Contributor Author

@ciguaran ciguaran Nov 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like the dependency is transitive (pymc -> blackjax -> optax -> chex)

@ricardoV94
Copy link
Member

Can we add this to pymc-experimental first?

@ciguaran
Copy link
Contributor Author

Closing this one in favour of pymc-devs/pymc-extras#267

@ciguaran ciguaran closed this Nov 13, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants