Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add wrapper for running blackjax pathfinder #72

Merged
merged 31 commits into from
Sep 9, 2022
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
26d4db0
Add wrapper for running blackjax pathfinder.
twiecki Sep 7, 2022
9cc95ec
Run black.
twiecki Sep 7, 2022
1e8ad4a
Run precommit.
twiecki Sep 7, 2022
a4cf339
Add blackjax to requirements.
twiecki Sep 7, 2022
43b6f8e
Do not make import optional.
twiecki Sep 7, 2022
3a3b2d7
Add more kwargs. Add license. Improve tests. Add doc string. Add fit …
twiecki Sep 8, 2022
74de0f9
Add fit function to base namespace.
twiecki Sep 8, 2022
d4e9ab4
Update copyright year.
twiecki Sep 8, 2022
1bf3473
Add type to random_seed and better init. Test for correct shapes.
twiecki Sep 8, 2022
79e89df
Update pymc_experimental/inference/pathfinder.py
twiecki Sep 8, 2022
03406cc
Update pymc_experimental/inference/pathfinder.py
twiecki Sep 8, 2022
4f0dc4e
Fix import of fit function.
twiecki Sep 8, 2022
2cfeed9
Add fit.py.
twiecki Sep 8, 2022
746d1d1
Skip on windows.
twiecki Sep 8, 2022
75e1358
Skip on windows. Move imports inside so that we do not error on windows.
twiecki Sep 8, 2022
ca2001d
skipif instead of xfail.
twiecki Sep 9, 2022
b6d2843
try/except blackjax import.
twiecki Sep 9, 2022
73f9e5c
try/except blackjax import.
twiecki Sep 9, 2022
cf3d0de
Update pymc_experimental/inference/fit.py
twiecki Sep 9, 2022
6deca7b
Update pymc_experimental/inference/fit.py
twiecki Sep 9, 2022
0bbfb56
Move blackjax to dev reqs.
twiecki Sep 9, 2022
7f7fcb3
Make import non-optional.
twiecki Sep 9, 2022
e2dad01
Precommit.
twiecki Sep 9, 2022
4fc5e89
Change imports.
twiecki Sep 9, 2022
7bdb10a
Call fit() from test.
twiecki Sep 9, 2022
8f07c25
Fix fit import.
twiecki Sep 9, 2022
350b77d
Fix kwargs.
twiecki Sep 9, 2022
7734b03
Add blackjax to test env.
twiecki Sep 9, 2022
625caf0
Only look for tests in test subdir.
twiecki Sep 9, 2022
2d0a435
Only look for tests in test subdir.
twiecki Sep 9, 2022
0dca76c
Only look for tests in test subdir.
twiecki Sep 9, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pymc_experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@
_log.addHandler(handler)


from pymc_experimental import distributions, gp, utils
from pymc_experimental import distributions, gp, inference, utils
from pymc_experimental.bart import *
from pymc_experimental.inference.fit import fit
2 changes: 1 addition & 1 deletion pymc_experimental/bart/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2020 The PyMC Developers
# Copyright 2022 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
1 change: 1 addition & 0 deletions pymc_experimental/inference/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from pymc_experimental.inference.pathfinder import fit_pathfinder
35 changes: 35 additions & 0 deletions pymc_experimental/inference/fit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright 2022 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pymc_experimental.inference import fit_pathfinder
twiecki marked this conversation as resolved.
Show resolved Hide resolved


def fit(method, *kwargs):
"""
Fit a model with an inference algorithm

Parameters
----------
method : str
Which inference method to run.
Supported: pathfinder

kwargs are passed on.

Returns
-------
arviz.InferenceData
"""
if method == "pathfinder":
return fit_pathfinder(**kwargs)
twiecki marked this conversation as resolved.
Show resolved Hide resolved
154 changes: 154 additions & 0 deletions pymc_experimental/inference/pathfinder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# Copyright 2022 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings

try:
import blackjax
import jax
import jax.numpy as jnp
import jax.random as random
from pymc.sampling_jax import get_jaxified_graph
except ImportError:
warnings.warn("Can't import blackjax. Pathfinder will not be available.")

import collections
import sys
from typing import Optional

import arviz as az
import numpy as np
import pymc as pm
from pymc import modelcontext
from pymc.sampling import RandomSeed, _get_seeds_per_chain
from pymc.util import get_default_varnames


def convert_flat_trace_to_idata(
samples,
dims=None,
coords=None,
include_transformed=False,
postprocessing_backend="cpu",
model=None,
):

model = modelcontext(model)
init_position_dict = model.initial_point()
trace = collections.defaultdict(list)
astart = pm.blocking.DictToArrayBijection.map(init_position_dict)
for sample in samples:
raveld_vars = pm.blocking.RaveledVars(sample, astart.point_map_info)
point = pm.blocking.DictToArrayBijection.rmap(raveld_vars, init_position_dict)
for p, v in point.items():
trace[p].append(v.tolist())

trace = {k: np.asarray(v)[None, ...] for k, v in trace.items()}

var_names = model.unobserved_value_vars
vars_to_sample = list(get_default_varnames(var_names, include_transformed=include_transformed))
print("Transforming variables...", file=sys.stdout)
jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
result = jax.vmap(jax.vmap(jax_fn))(
*jax.device_put(list(trace.values()), jax.devices(postprocessing_backend)[0])
)

trace = {v.name: r for v, r in zip(vars_to_sample, result)}
idata = az.from_dict(trace, dims=dims, coords=coords)

return idata


def fit_pathfinder(
iterations=5_000,
random_seed: Optional[RandomSeed] = None,
postprocessing_backend="cpu",
ftol=1e-4,
model=None,
):
"""
Fit the pathfinder algorithm as implemented in blackjax

Requires the JAX backend

Parameters
----------
iterations : int
Number of iterations to run.
random_seed : int
Random seed to set.
postprocessing_backend : str
Where to compute transformations of the trace.
"cpu" or "gpu".
ftol : float
Floating point tolerance

Returns
-------
arviz.InferenceData

Reference
---------
https://arxiv.org/abs/2108.03782
"""

(random_seed,) = _get_seeds_per_chain(random_seed, 1)

model = modelcontext(model)

rvs = [rv.name for rv in model.value_vars]
init_position_dict = model.initial_point()
init_position = [init_position_dict[rv] for rv in rvs]

new_logprob, new_input = pm.aesaraf.join_nonshared_inputs(
init_position_dict, (model.logp(),), model.value_vars, ()
)

logprob_fn_list = get_jaxified_graph([new_input], new_logprob)

def logprob_fn(x):
return logprob_fn_list(x)[0]

dim = sum(v.size for v in init_position_dict.values())

rng_key = random.PRNGKey(random_seed)
w0 = random.multivariate_normal(rng_key, 2.0 + jnp.zeros(dim), jnp.eye(dim))
path = blackjax.vi.pathfinder.init(rng_key, logprob_fn, w0, return_path=True, ftol=ftol)

pathfinder = blackjax.kernels.pathfinder(rng_key, logprob_fn, ftol=ftol)
state = pathfinder.init(w0)

def inference_loop(rng_key, kernel, initial_state, num_samples):
@jax.jit
def one_step(state, rng_key):
state, info = kernel(rng_key, state)
return state, (state, info)

keys = jax.random.split(rng_key, num_samples)
return jax.lax.scan(one_step, initial_state, keys)

_, rng_key = random.split(rng_key)
print("Running pathfinder...", file=sys.stdout)
_, (_, samples) = inference_loop(rng_key, pathfinder.step, state, iterations)

dims = {
var_name: [dim for dim in dims if dim is not None]
for var_name, dims in model.RV_dims.items()
}

idata = convert_flat_trace_to_idata(
samples, postprocessing_backend=postprocessing_backend, coords=model.coords, dims=dims
)

return idata
48 changes: 48 additions & 0 deletions pymc_experimental/tests/test_pathfinder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright 2022 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

import numpy as np
import pymc as pm
import pytest

import pymc_experimental as pmx


@pytest.mark.skipif(sys.platform == "win32", reason="JAX not supported on windows.")
def test_pathfinder():
# Data of the Eight Schools Model
J = 8
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

with pm.Model() as model:

mu = pm.Normal("mu", mu=0.0, sigma=10.0)
tau = pm.HalfCauchy("tau", 5.0)

theta = pm.Normal("theta", mu=0, sigma=1, shape=J)
theta_1 = mu + tau * theta
obs = pm.Normal("obs", mu=theta, sigma=sigma, shape=J, observed=y)

idata = pmx.inference.fit_pathfinder(iterations=100)

assert idata is not None
twiecki marked this conversation as resolved.
Show resolved Hide resolved
assert "theta" in idata.posterior._variables.keys()
assert "tau" in idata.posterior._variables.keys()
assert "mu" in idata.posterior._variables.keys()
assert idata.posterior["mu"].shape == (1, 100)
assert idata.posterior["tau"].shape == (1, 100)
assert idata.posterior["theta"].shape == (1, 100, 8)
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
pymc>=4.0.1
xhistogram
blackjax