-
-
Notifications
You must be signed in to change notification settings - Fork 52
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add wrapper for running blackjax pathfinder #72
Merged
Merged
Changes from 23 commits
Commits
Show all changes
31 commits
Select commit
Hold shift + click to select a range
26d4db0
Add wrapper for running blackjax pathfinder.
twiecki 9cc95ec
Run black.
twiecki 1e8ad4a
Run precommit.
twiecki a4cf339
Add blackjax to requirements.
twiecki 43b6f8e
Do not make import optional.
twiecki 3a3b2d7
Add more kwargs. Add license. Improve tests. Add doc string. Add fit …
twiecki 74de0f9
Add fit function to base namespace.
twiecki d4e9ab4
Update copyright year.
twiecki 1bf3473
Add type to random_seed and better init. Test for correct shapes.
twiecki 79e89df
Update pymc_experimental/inference/pathfinder.py
twiecki 03406cc
Update pymc_experimental/inference/pathfinder.py
twiecki 4f0dc4e
Fix import of fit function.
twiecki 2cfeed9
Add fit.py.
twiecki 746d1d1
Skip on windows.
twiecki 75e1358
Skip on windows. Move imports inside so that we do not error on windows.
twiecki ca2001d
skipif instead of xfail.
twiecki b6d2843
try/except blackjax import.
twiecki 73f9e5c
try/except blackjax import.
twiecki cf3d0de
Update pymc_experimental/inference/fit.py
twiecki 6deca7b
Update pymc_experimental/inference/fit.py
twiecki 0bbfb56
Move blackjax to dev reqs.
twiecki 7f7fcb3
Make import non-optional.
twiecki e2dad01
Precommit.
twiecki 4fc5e89
Change imports.
twiecki 7bdb10a
Call fit() from test.
twiecki 8f07c25
Fix fit import.
twiecki 350b77d
Fix kwargs.
twiecki 7734b03
Add blackjax to test env.
twiecki 625caf0
Only look for tests in test subdir.
twiecki 2d0a435
Only look for tests in test subdir.
twiecki 0dca76c
Only look for tests in test subdir.
twiecki File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from pymc_experimental.inference.pathfinder import fit_pathfinder |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# Copyright 2022 The PyMC Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
def fit(method, *kwargs): | ||
""" | ||
Fit a model with an inference algorithm | ||
|
||
Parameters | ||
---------- | ||
method : str | ||
Which inference method to run. | ||
Supported: pathfinder | ||
|
||
kwargs are passed on. | ||
|
||
Returns | ||
------- | ||
arviz.InferenceData | ||
""" | ||
if method == "pathfinder": | ||
try: | ||
from pymc_experimental.inference import fit_pathfinder | ||
except ImportError as exc: | ||
raise RuntimeError("Need JAX/ Blackjax / wahever to use `pathfinder`") from exc | ||
return fit_pathfinder(**kwargs) | ||
twiecki marked this conversation as resolved.
Show resolved
Hide resolved
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
# Copyright 2022 The PyMC Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
import collections | ||
import sys | ||
from typing import Optional | ||
|
||
import arviz as az | ||
import blackjax | ||
import jax | ||
import jax.numpy as jnp | ||
import jax.random as random | ||
import numpy as np | ||
import pymc as pm | ||
from pymc import modelcontext | ||
from pymc.sampling import RandomSeed, _get_seeds_per_chain | ||
from pymc.sampling_jax import get_jaxified_graph | ||
from pymc.util import get_default_varnames | ||
|
||
|
||
def convert_flat_trace_to_idata( | ||
samples, | ||
dims=None, | ||
coords=None, | ||
include_transformed=False, | ||
postprocessing_backend="cpu", | ||
model=None, | ||
): | ||
|
||
model = modelcontext(model) | ||
init_position_dict = model.initial_point() | ||
trace = collections.defaultdict(list) | ||
astart = pm.blocking.DictToArrayBijection.map(init_position_dict) | ||
for sample in samples: | ||
raveld_vars = pm.blocking.RaveledVars(sample, astart.point_map_info) | ||
point = pm.blocking.DictToArrayBijection.rmap(raveld_vars, init_position_dict) | ||
for p, v in point.items(): | ||
trace[p].append(v.tolist()) | ||
|
||
trace = {k: np.asarray(v)[None, ...] for k, v in trace.items()} | ||
|
||
var_names = model.unobserved_value_vars | ||
vars_to_sample = list(get_default_varnames(var_names, include_transformed=include_transformed)) | ||
print("Transforming variables...", file=sys.stdout) | ||
jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample) | ||
result = jax.vmap(jax.vmap(jax_fn))( | ||
*jax.device_put(list(trace.values()), jax.devices(postprocessing_backend)[0]) | ||
) | ||
|
||
trace = {v.name: r for v, r in zip(vars_to_sample, result)} | ||
idata = az.from_dict(trace, dims=dims, coords=coords) | ||
|
||
return idata | ||
|
||
|
||
def fit_pathfinder( | ||
iterations=5_000, | ||
random_seed: Optional[RandomSeed] = None, | ||
postprocessing_backend="cpu", | ||
ftol=1e-4, | ||
model=None, | ||
): | ||
""" | ||
Fit the pathfinder algorithm as implemented in blackjax | ||
|
||
Requires the JAX backend | ||
|
||
Parameters | ||
---------- | ||
iterations : int | ||
Number of iterations to run. | ||
random_seed : int | ||
Random seed to set. | ||
postprocessing_backend : str | ||
Where to compute transformations of the trace. | ||
"cpu" or "gpu". | ||
ftol : float | ||
Floating point tolerance | ||
|
||
Returns | ||
------- | ||
arviz.InferenceData | ||
|
||
Reference | ||
--------- | ||
https://arxiv.org/abs/2108.03782 | ||
""" | ||
|
||
(random_seed,) = _get_seeds_per_chain(random_seed, 1) | ||
|
||
model = modelcontext(model) | ||
|
||
rvs = [rv.name for rv in model.value_vars] | ||
init_position_dict = model.initial_point() | ||
init_position = [init_position_dict[rv] for rv in rvs] | ||
|
||
new_logprob, new_input = pm.aesaraf.join_nonshared_inputs( | ||
init_position_dict, (model.logp(),), model.value_vars, () | ||
) | ||
|
||
logprob_fn_list = get_jaxified_graph([new_input], new_logprob) | ||
|
||
def logprob_fn(x): | ||
return logprob_fn_list(x)[0] | ||
|
||
dim = sum(v.size for v in init_position_dict.values()) | ||
|
||
rng_key = random.PRNGKey(random_seed) | ||
w0 = random.multivariate_normal(rng_key, 2.0 + jnp.zeros(dim), jnp.eye(dim)) | ||
path = blackjax.vi.pathfinder.init(rng_key, logprob_fn, w0, return_path=True, ftol=ftol) | ||
|
||
pathfinder = blackjax.kernels.pathfinder(rng_key, logprob_fn, ftol=ftol) | ||
state = pathfinder.init(w0) | ||
|
||
def inference_loop(rng_key, kernel, initial_state, num_samples): | ||
@jax.jit | ||
def one_step(state, rng_key): | ||
state, info = kernel(rng_key, state) | ||
return state, (state, info) | ||
|
||
keys = jax.random.split(rng_key, num_samples) | ||
return jax.lax.scan(one_step, initial_state, keys) | ||
|
||
_, rng_key = random.split(rng_key) | ||
print("Running pathfinder...", file=sys.stdout) | ||
_, (_, samples) = inference_loop(rng_key, pathfinder.step, state, iterations) | ||
|
||
dims = { | ||
var_name: [dim for dim in dims if dim is not None] | ||
for var_name, dims in model.RV_dims.items() | ||
} | ||
|
||
idata = convert_flat_trace_to_idata( | ||
samples, postprocessing_backend=postprocessing_backend, coords=model.coords, dims=dims | ||
) | ||
|
||
return idata |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# Copyright 2022 The PyMC Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import sys | ||
|
||
import numpy as np | ||
import pymc as pm | ||
import pytest | ||
|
||
import pymc_experimental as pmx | ||
|
||
|
||
@pytest.mark.skipif(sys.platform == "win32", reason="JAX not supported on windows.") | ||
def test_pathfinder(): | ||
# Data of the Eight Schools Model | ||
J = 8 | ||
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]) | ||
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]) | ||
|
||
with pm.Model() as model: | ||
|
||
mu = pm.Normal("mu", mu=0.0, sigma=10.0) | ||
tau = pm.HalfCauchy("tau", 5.0) | ||
|
||
theta = pm.Normal("theta", mu=0, sigma=1, shape=J) | ||
theta_1 = mu + tau * theta | ||
obs = pm.Normal("obs", mu=theta, sigma=sigma, shape=J, observed=y) | ||
|
||
idata = pmx.inference.fit_pathfinder(iterations=100) | ||
|
||
assert idata is not None | ||
twiecki marked this conversation as resolved.
Show resolved
Hide resolved
|
||
assert "theta" in idata.posterior._variables.keys() | ||
assert "tau" in idata.posterior._variables.keys() | ||
assert "mu" in idata.posterior._variables.keys() | ||
assert idata.posterior["mu"].shape == (1, 100) | ||
assert idata.posterior["tau"].shape == (1, 100) | ||
assert idata.posterior["theta"].shape == (1, 100, 8) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
dask[all] | ||
blackjax |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not acceptable