Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add wrapper for running blackjax pathfinder #72

Merged
merged 31 commits into from
Sep 9, 2022
Merged

Conversation

twiecki
Copy link
Member

@twiecki twiecki commented Sep 7, 2022

Lot's of help from @ricardoV94

CC @rlouf

pymc_experimental/inference/pathfinder.py Outdated Show resolved Hide resolved
pymc_experimental/inference/pathfinder.py Outdated Show resolved Hide resolved
pymc_experimental/tests/test_pathfinder.py Outdated Show resolved Hide resolved
@twiecki
Copy link
Member Author

twiecki commented Sep 8, 2022

Adressed comments.

@twiecki
Copy link
Member Author

twiecki commented Sep 8, 2022

Ready for next round.

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, just two hard-coded values still in there.

pymc_experimental/inference/pathfinder.py Outdated Show resolved Hide resolved
pymc_experimental/inference/pathfinder.py Outdated Show resolved Hide resolved
@ricardoV94
Copy link
Member

Just checking if the tests pass, they were failing weirdly before.

@ricardoV94
Copy link
Member

Maybe you just need to rebase from main, because @ferrine did some changes? (if they fail again)

@twiecki twiecki force-pushed the blackjax_pathfinder branch from afc09ea to 03406cc Compare September 8, 2022 15:47
@twiecki
Copy link
Member Author

twiecki commented Sep 8, 2022

Oh, of course it doesn't work on windows...

@ColCarroll
Copy link
Member

Not sure who this comment is for, but:

  1. it'd be cool to be able to use the covariance estimates from the lbfgs iterations to initialize a mass matrix for NUTS/HMC
  2. it'd be even cooler to use this as an automatic initialization for those samplers: you could use the condition number of the covariance, and the desired acceptance probability to also estimate an initial step size. i'm specifically suggesting having a path towards providing pm.sample(init='pathfinder'). while in the experimental directory, it perhaps makes most sense to have a sample_pathfinder function?

neither of these are blockers, just hopeful suggestions!

@ColCarroll
Copy link
Member

Sorry, looking more at the blackjax API, it seems like you'd get both the covariance and the initial values from just init, so I'm describing a different function from this one:

  • this one does a variational fit + generates (approximate) samples
  • i'd like to do a more efficient job initializing MCMC and generating (asymptotically, but not actually) exact samples

@rlouf
Copy link

rlouf commented Sep 8, 2022

  1. it'd be cool to be able to use the covariance estimates from the lbfgs iterations to initialize a mass matrix for NUTS/HMC

We have an adaptation routine that does just that. We really need documentation 🙃

@ricardoV94
Copy link
Member

Oh, of course it doesn't work on windows...

pytest.skip based on sys.platform.

@ricardoV94
Copy link
Member

Goddamn, isn't there a way to maybe turn the import into a noop?

Good old try/except

@twiecki
Copy link
Member Author

twiecki commented Sep 9, 2022

It's green :party:!

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about this instead?

pymc_experimental/inference/fit.py Show resolved Hide resolved
pymc_experimental/inference/fit.py Outdated Show resolved Hide resolved
try:
from pymc_experimental.inference import fit_pathfinder
except ImportError as exc:
raise RuntimeError("Need JAX/ Blackjax / wahever to use `pathfinder`") from exc
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not acceptable

@ricardoV94 ricardoV94 changed the title Add wrapper for running blackjax pathfinder. Add wrapper for running blackjax pathfinder Sep 9, 2022
@ricardoV94 ricardoV94 merged commit 75aa37b into main Sep 9, 2022
@twiecki twiecki deleted the blackjax_pathfinder branch September 9, 2022 15:32
@twiecki twiecki mentioned this pull request Sep 10, 2022
@ricardoV94
Copy link
Member

ricardoV94 commented Sep 10, 2022

It's green :party:!

Haha it was green because no tests were run 🤦

@twiecki
Copy link
Member Author

twiecki commented Sep 10, 2022

FML.

@twiecki twiecki restored the blackjax_pathfinder branch September 10, 2022 09:27
twiecki added a commit that referenced this pull request Sep 10, 2022
twiecki added a commit that referenced this pull request Sep 10, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants