
Transform jax samples #4427

Merged: 5 commits into pymc-devs:master, Feb 12, 2021
Conversation

@martiningram (Contributor)

Hi there,

This pull request addresses the following issue: #4415

It uses the model's fastfn function to compute the values of all the variables listed in model.unobserved_RVs. This should include the transformed variables as well as any deterministic variables of interest. It's also what I believe is done in the pymc3 samplers currently -- it's a bit hidden but I believe it is done here.

One caveat: this implementation loops over chains and samples and is thus not particularly efficient. I have added a timing print statement to the code to make this easy to see. On the LKJ example, sampling took 20s (500 warmup, 500 sampling, 4 chains) and transforming took 7s. A nice improvement would be to turn the theano fastfn into a JAX function, which could then be evaluated much more efficiently using jax.vmap across the samples, but I didn't see an easy way to jax_funcify this function (just calling jax_funcify on it doesn't work). If someone knows how, I am happy to update the code.
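Roughly, the loop looks like this (a sketch only, not the exact diff; samples here is assumed to map free-RV names to arrays of shape (chains, draws, ...)):

# Sketch of the per-draw transform loop; `model` is the pymc3 Model.
eval_fn = model.fastfn(model.unobserved_RVs)

n_chains, n_draws = next(iter(samples.values())).shape[:2]
transformed = []
for c in range(n_chains):
    for d in range(n_draws):
        point = {name: value[c, d] for name, value in samples.items()}
        transformed.append(eval_fn(point))  # one list of outputs per draw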

Interested to hear your thoughts! I'm also planning to add an example notebook soon to show what this does.

@ricardoV94 (Member)

Probably dumb question, but can't we convert all transformed variables at the end of sampling with vectorized operations?

@martiningram (Contributor, Author) commented Jan 21, 2021

> Probably dumb question, but can't we convert all transformed variables at the end of sampling with vectorized operations?

Hey Ricardo,

In principle I think that's a good idea, but at least with this approach, the function returned by eval_fun = model.fastfn(var_names) complains if the dimensions don't match what it expects -- e.g. when calling it directly on the dictionary of samples:

eval_fun(samples)
TypeError: ('Wrong number of dimensions: expected 1, got 3 with shape (4, 500, 3).', 'Container name "chol_cholesky-cov-packed__"')

As mentioned in the pull request, if we could get a JAX version of the function, we could just use vmap to vectorise everything automatically. Failing that, maybe there's some theano trick that I'm missing here? (I basically don't know any theano.) I thought theano.map might work, but it doesn't seem to be able to map over dicts.

@junpenglao (Member)

> Probably dumb question, but can't we convert all transformed variables at the end of sampling with vectorized operations?

+1. We should jaxify the theano function, and then vmap it.

@junpenglao (Member)

To be more specific, we should have only one for-loop, over all the RVs: if an RV is transformed, grab the forward function, compile it into a JAX function, and then call jax.vmap(jax.vmap(forward))(...).
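For illustration, a toy version of that double vmap (forward here is just a stand-in for a transform's forward function compiled to JAX):

import jax
import jax.numpy as jnp

# Stand-in forward function with per-draw semantics (normalised exp):
forward = lambda x: jnp.exp(x) / jnp.exp(x).sum()
draws = jnp.zeros((4, 500, 3))  # (chains, draws, event shape)

# vmap twice to map over the two leading axes (chains, then draws):
constrained = jax.vmap(jax.vmap(forward))(draws)  # shape (4, 500, 3)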

@martiningram (Contributor, Author)

> > Probably dumb question, but can't we convert all transformed variables at the end of sampling with vectorized operations?
>
> +1. We should jaxify the theano function, and then vmap it.

I agree that this would be the best solution. Here's what happens if I naively try:

jax_funcify(model.fastfn(var_names))

gives:

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
<ipython-input-34-8156f409d522> in <module>
----> 1 jax_funcify(model.fastfn(var_names))

~/miniconda3/envs/pymc3/lib/python3.7/functools.py in wrapper(*args, **kw)
    838                             '1 positional argument')
    839 
--> 840         return dispatch(args[0].__class__)(*args, **kw)
    841 
    842     funcname = getattr(func, '__name__', 'singledispatch function')

~/miniconda3/envs/pymc3/lib/python3.7/site-packages/theano/link/jax/jax_dispatch.py in jax_funcify(op)
    195 def jax_funcify(op):
    196     """Create a JAX "perform" function for a Theano `Variable` and its `Op`."""
--> 197     raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}")
    198 
    199 

NotImplementedError: No JAX conversion for the given `Op`: <pymc3.model.FastPointFunc object at 0x7f88ad1f4b90>

Also:

jax_funcify(model.fastfn)

gives:

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
<ipython-input-35-d10111a9c6ba> in <module>
----> 1 jax_funcify(model.fastfn)

~/miniconda3/envs/pymc3/lib/python3.7/functools.py in wrapper(*args, **kw)
    838                             '1 positional argument')
    839 
--> 840         return dispatch(args[0].__class__)(*args, **kw)
    841 
    842     funcname = getattr(func, '__name__', 'singledispatch function')

~/miniconda3/envs/pymc3/lib/python3.7/site-packages/theano/link/jax/jax_dispatch.py in jax_funcify(op)
    195 def jax_funcify(op):
    196     """Create a JAX "perform" function for a Theano `Variable` and its `Op`."""
--> 197     raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}")
    198 
    199 

NotImplementedError: No JAX conversion for the given `Op`: <bound method Model.fastfn of <pymc3.model.Model object at 0x7f88afa06ed0>>

@martiningram (Contributor, Author) commented Jan 21, 2021

> To be more specific, we should have only one for-loop, over all the RVs: if an RV is transformed, grab the forward function, compile it into a JAX function, and then call jax.vmap(jax.vmap(forward))(...).

Yep, sounds good. One issue: would this cover deterministic variables too? If not, it might be better to use model.fastfn. If we did have that as a JAX function, I believe JAX's pytree features would mean we could just do:

jax.vmap(jax.vmap(fun))(samples)

over the dict of samples and it should all work just fine (but maybe I'm wrong about that).
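Something like this toy example (fun here is hypothetical):

import jax
import jax.numpy as jnp

# vmap treats a dict as a pytree, mapping over the leading axis of each leaf.
samples = {"a": jnp.ones((4, 500)), "b": jnp.ones((4, 500, 3))}
fun = lambda point: point["a"] + point["b"].sum()

out = jax.vmap(jax.vmap(fun))(samples)  # out.shape == (4, 500)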

@junpenglao (Member)

> over the dict of samples and it should all work just fine (but maybe I'm wrong about that).

Oh, you are right.
As for the error, that's because pymc3.model.FastPointFunc is already jitted (theano.function jits the function in https://github.com/pymc-devs/pymc3/blob/acb326149adffe03fd06aac6515ed58b682f646b/pymc3/model.py#L1239). Instead, try to do something similar to how the log_prob function is jaxified in the sample_jax function.
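Something like this pattern, perhaps (a rough sketch, assuming model is in scope and that jax_funcify on a FunctionGraph returns one callable per output):

import theano
from theano.link.jax.jax_dispatch import jax_funcify

# Sketch: jaxify a graph from the free RVs to the model's log-density,
# instead of passing the already-compiled FastPointFunc to jax_funcify.
fgraph = theano.graph.fg.FunctionGraph(model.free_RVs, [model.logpt])
logp_fn_jax = jax_funcify(fgraph)[0]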

@martiningram (Contributor, Author) commented Jan 21, 2021

> > over the dict of samples and it should all work just fine (but maybe I'm wrong about that).
>
> Oh, you are right.
> As for the error, that's because pymc3.model.FastPointFunc is already jitted (theano.function jits the function in https://github.com/pymc-devs/pymc3/blob/acb326149adffe03fd06aac6515ed58b682f646b/pymc3/model.py#L1239). Instead, try to do something similar to how the log_prob function is jaxified in the sample_jax function.

Sounds great, thanks, I'll look into that!
EDIT: thanks, I think I know what to do. It's late here in Australia, but I should hopefully have an update tomorrow!

@martiningram (Contributor, Author) commented Jan 21, 2021

Hi all,

Please take a look at the latest version, which uses jax_funcify and vmap. It's not quite as neat as I'd hoped, but I think it's much better: it requires only one loop over the RVs that have to be computed, and it's much faster, now taking only about a fifth of a second.

As a side note (it's not an issue with the code here): I did run into slightly odd behaviour which I thought I'd mention. When exploring how jax_funcify worked on the LKJ example:

import theano
from theano.link.jax.jax_dispatch import jax_funcify

graphs = {x.name: theano.graph.fg.FunctionGraph(model.free_RVs, [x]) for x in model.unobserved_RVs}
jax_fns = {x: jax_funcify(y) for x, y in graphs.items()}

I got:

{'chol_cholesky-cov-packed__': [],
 'μ': [],
 'chol': [<function theano.link.jax.jax_dispatch.jax_funcify_ViewOp.<locals>.viewop(x)>],
 'chol_stds': [<function theano.link.jax.jax_dispatch.jax_funcify_Identity.<locals>.identity(x)>],
 'chol_corr': [<function theano.link.jax.jax_dispatch.jax_funcify_Identity.<locals>.identity(x)>],
 'cov': [<function theano.link.jax.jax_dispatch.jax_funcify_Identity.<locals>.identity(x)>]}

What I thought was a little odd here are the first two entries: these are the inputs, and the jaxified function is just an empty list. I'm guessing this is because no computation is required for them, since they are already inputs, but I would have expected some kind of identity function or similar.

For the code in the pull request, I avoid this issue by first working out which RVs actually have to be computed:

free_rv_names = {x.name for x in model.free_RVs}
unobserved_names = {x.name for x in model.unobserved_RVs}

names_to_compute = unobserved_names - free_rv_names
ops_to_compute = [x for x in model.unobserved_RVs if x.name in names_to_compute]

and then only computing those, using jax.vmap(jax.vmap(fun)) for each one, as discussed. We could maybe avoid the for-loop entirely by combining the list of JAX functions into a single JAX function and vmapping that, but I'm not sure it's worth it; let me know what you think.
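Concretely, the shape of it is roughly this (a sketch, not the exact PR code; samples maps free-RV names to (chains, draws, ...) arrays):

import jax
import theano
from theano.link.jax.jax_dispatch import jax_funcify

# One FunctionGraph for all the ops to compute, jaxified into one
# function per output, then double-vmapped over (chains, draws).
fgraph = theano.graph.fg.FunctionGraph(model.free_RVs, ops_to_compute)
jax_fns = jax_funcify(fgraph)

inputs = [samples[rv.name] for rv in model.free_RVs]
for op, fn in zip(ops_to_compute, jax_fns):
    samples[op.name] = jax.vmap(jax.vmap(fn))(*inputs)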

Looking forward to hearing your thoughts and whether anything could be improved.

@michaelosthege added this to the vNext milestone Jan 22, 2021
@martiningram (Contributor, Author)

Hi all,

I've made one further change, and I think I may have cluttered everything by rebasing with upstream. I'm really sorry about that; I'm very new to all of this. I hope I haven't made things too horrible; please let me know if there's anything I should do to neaten up this pull request again.

The one additional change I have made is this commit here:
0b3ba7a

What this change does is use:

pm.util.get_default_varnames(list(samples.keys()), include_transformed=keep_untransformed)

where keep_untransformed is currently set to False, to discard the untransformed variables. This gets rid of the variables ending in _log__, and so on, which I think is desirable since most users will only want to see the transformed parameters. It also matches what arviz does: https://github.com/arviz-devs/arviz/blob/db41827a700b8bfb29f28566bea24728cc424a3f/arviz/data/io_pymc3.py#L246
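As a toy illustration of the filtering (assuming the usual _log__ suffix for auto-transformed variables):

import pymc3 as pm

names = ["sigma", "sigma_log__", "mu"]
print(pm.util.get_default_varnames(names, include_transformed=False))
# ['sigma', 'mu']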

Once again, sorry about the clutter. Let me know if there's something I should do, and of course if there is anything else to improve in this pull request.

@michaelosthege (Member)

Hi @martiningram,
Yes, the diff is mixed with changes that already happened on master; it looks like something went wrong with the rebase.
If I understood correctly, your PR makes only a few changes.
If fixing this through a rebase proves too difficult, it might be worthwhile to use the xkcd strategy:

  1. make a backup copy of your PyMC3 folder
  2. checkout the latest master
  3. delete the transform_jax_samples branch
  4. create a new transform_jax_samples branch
  5. re-apply your changes
  6. commit & force-push to override the old transform_jax_samples branch with the new one

When you consider it done, make sure to add a line to the release notes.

* Add `pymc3.sampling_jax._transform_samples` function which transforms draws

* Modify `pymc3.sampling_jax.sample_numpyro_nuts` function to use this function to return transformed samples

* Add release note
@martiningram (Contributor, Author)

Thanks a lot for your help @michaelosthege ! I think I've managed to follow your xkcd strategy, and I've added a line to the release notes, too. Let me know what you think!

@junpenglao (Member)

Could you add a small test? A Normal likelihood with a HalfNormal prior on sigma will do.

@martiningram (Contributor, Author)

> Could you add a small test? A Normal likelihood with a HalfNormal prior on sigma will do.

Sure, will do! I'm on holiday right now but should have something by early next week at the latest.

@martiningram (Contributor, Author) commented Feb 8, 2021

Hi @junpenglao, I've tried my hand at adding a small test. The test checks that the transformation from the log scale works correctly. Please let me know if I should change anything, or whether this is roughly what you had in mind! Note that to make this work, I had to add a keep_untransformed argument to sample_numpyro_nuts which, when True, keeps variables like sigma_log__ rather than discarding them, which is the default behaviour.
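For reference, the test is roughly shaped like this (a sketch rather than the exact test code; I'm assuming here that sample_numpyro_nuts returns an ArviZ InferenceData):

import numpy as np
import pymc3 as pm
from pymc3.sampling_jax import sample_numpyro_nuts

# Normal likelihood with a HalfNormal prior on sigma; check that the
# constrained draws equal exp() of the log-space draws.
with pm.Model():
    sigma = pm.HalfNormal("sigma", 1.0)
    pm.Normal("obs", mu=0.0, sigma=sigma, observed=np.random.randn(100))
    trace = sample_numpyro_nuts(keep_untransformed=True)

np.testing.assert_allclose(
    trace.posterior["sigma"].values,
    np.exp(trace.posterior["sigma_log__"].values),
)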

@michaelosthege (Member)

It looks like the test in the CI fails due to the lack of jax.
@twiecki @junpenglao how should we proceed? Should we install jax into the CI environment, or leave the test out of the CI?

@twiecki (Member) commented Feb 8, 2021 via email

@twiecki (Member) commented Feb 9, 2021

CC @MarcoGorelli on how to install JAX only for CI.

@MarcoGorelli (Contributor)

I think in the conda-envs files

@michaelosthege (Member)

@twiecki @MarcoGorelli as you can see above, I tried to add the jax dependency, but it didn't work.
I suspect dependency incompatibilities, but the job log doesn't really make sense to me, because it just reports "warnings" but then fails anyway.

I would really like to get 3.11.1 going. If we can't get this PR merged soon, I'd rather take it out of the 3.11.1 milestone.

@MarcoGorelli (Contributor)

In blackjax, jax is pinned very narrowly (I think Remi said it changes very frequently); does that need doing here too?

@MarcoGorelli (Contributor)

> the job log doesn't really make sense to me, because it just reports "warnings" but then fails anyway.

Which job are you referring to? I just had a look at the top one and it shows:

>       from numpyro.infer import MCMC, NUTS
E       ModuleNotFoundError: No module named 'numpyro'

pymc3/sampling_jax.py:128: ModuleNotFoundError

@michaelosthege (Member) commented Feb 11, 2021

I looked at the Windows job, because that one was ❌, and I did not notice that the others were still ⏳.

Adding jax to requirements-dev is okay. But adding numpyro too? I'm really sceptical about that. Call me conservative, but I think sampling_jax would be better suited as an add-on package.

@twiecki (Member) commented Feb 11, 2021

@michaelosthege I get your point and I do agree. However, given how important the JAX work is to our future, I'd like us not to put up any boundaries here. I heard from some users who could really use this as a replacement for pm.sample() and get a 2x speed-up on a really complicated model.

Until the blackjax samplers are more mature or we find a better solution I would recommend we bite the bullet and add it as a dependency here.

@michaelosthege (Member)

> Until the blackjax samplers are more mature or we find a better solution I would recommend we bite the bullet and add it as a dependency here.

Alternative idea: We split the Jax tests into their own CI job and install NumPyro just in that CI job. Then potential failures are at least separate from the rest.

@twiecki (Member) commented Feb 12, 2021

> Alternative idea: We split the Jax tests into their own CI job and install NumPyro just in that CI job. Then potential failures are at least separate from the rest.

This is the way.

@twiecki merged commit e46f490 into pymc-devs:master on Feb 12, 2021
@twiecki (Member) commented Feb 12, 2021

Thanks @martiningram and everyone else!
