Refactoring to support Structured VI #2416
Conversation
pymc3/variational/opvi.py
Outdated
```python
class GroupApproximation(object):
    """
    Grouped Approximation that is used for modelling mutual dependencies
    fro a specified group of variables.
```
typo: fro->for
API change proposal №1: local variables are held in a separate group, so that you do not pass them in with everything else. On a low level (the high-level API can be less verbose) it will roughly look like:

```python
approx = Approximation([
    MeanField([local_rv],  # group of local variables
              shapes=dict([(local_rv, (s, 12))]),  # where `s` is symbolic
              params=dict(mu=mu, rho=rho)),  # where mu, rho have valid shape: (flexible, latent_ndim);
                                             # in this case the flexible shape is `s` and latent ndim is 12
    MeanField(-1)  # delayed init with `-1` instead of a group; will capture all the rest of the variables
])
```

As a bonus, the following construction will be possible:

```python
approx = Approximation([
    FullRank([local_rv],  # group of local variables
             shapes=dict([(local_rv, (s, 12))]),  # where `s` is symbolic
             params=dict(mu=mu, L_tril=L_tril)),  # where mu, L_tril have valid shape: (flexible, param_standard_shape);
                                                  # in this case the flexible shape is `s` and latent ndim is 12, so
                                                  # mu.shape = (s, 12)
                                                  # L_tril.shape = (s, 12 * (12 + 1) / 2)
    MeanField(-1)  # delayed init with `-1` instead of a group; will capture all the rest of the variables
])
```

Holding backward compatibility for the first iterations will make future work super hard. The API change is still under discussion; feel free to suggest ideas and criticize solutions. The main idea behind it is handling local_rv as a separate group with its own flexible shape that depends on a dynamic shape. Right now I try to do it with 3d tensors where the 2nd dim holds the flexible shape. I hope to implement broadcastable Approximations so that one can parametrize any distribution for local variables. Even in this setup it's all hard enough.
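To make the 3d layout concrete, here is a small numpy sketch of how the description above reads (axis roles and the numbers are taken from the example; this is an illustration, not PR code):

```python
import numpy as np

n_samples, s, latent_ndim = 5, 100, 12     # MC draws, flexible (minibatch) dim, latent dim
z = np.zeros((n_samples, s, latent_ndim))  # posterior draws for the local group
mu = np.zeros((s, latent_ndim))            # params with shape (flexible, latent_ndim)
# params broadcast over the leading sample axis, which is what makes
# "broadcastable Approximations" possible
assert (z + mu).shape == (n_samples, s, latent_ndim)
```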
Personally I feel free to change the API for 3.2 as we have released 3.1, but I'm still a bit worried about regular users. They might not like such rapid changes in the API. For a top-level API like
I also invite @fonnesbeck for discussion.
How does it compare to the current API? Are the global and local vars passed together to the Approximation?
I think a reasonable way to go is to restrict so-called

So we'll have no way to capture interactions between local and global variables, nor local with local. This is exactly what we have now, but much more flexible. Here one can parametrize local variables with flows or FullRank, or with separate groups that have more complex distributions than MeanField. BTW I find it too hard to implement a convenient API that allows more than one variable in a local group.
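A minimal sketch of that restriction in terms of the grouping API proposed above (all names reuse the earlier example; the exact spelling is still under discussion):

```python
approx = Approximation([
    MeanField([local_rv],  # local group: exactly one variable
              shapes=dict([(local_rv, (s, 12))]),
              params=dict(mu=mu, rho=rho)),
    FullRank(-1),  # all remaining (global) variables; dependencies are captured
])                 # only inside this group, never across groups
```

Since each group carries its own approximating family, local-global and local-local interactions across groups are simply not representable, which matches the restriction described.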
pymc3/variational/opvi.py
Outdated
```python
self._mid_size = size
self.vmap = dict()
self.ndim = 0
self._global_ = True
```
What’s the trailing underscore for?
I use this to indicate that it should not be changed, even internally, after init.
I see. Another option would be to use all caps (e.g. `self.GLOBAL`). This is common for fixed variables, such as constants in Python. Not sure what PEP8 has to say about that, but the Google style guide uses this convention.
Maybe
I think all caps is more about class attributes, not instance-specific ones.
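A toy illustration of the two conventions being debated (not code from the PR):

```python
class Group:
    GLOBAL = True              # class-level constant: all-caps, shared by all instances

    def __init__(self):
        self._global_ = True   # instance attribute; the trailing underscore is used in
                               # this PR to signal "do not mutate after __init__"
```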
I still need to properly read these papers. Could there not be default groupings based on, say, hierarchical structure in the model?

Whether a drastic API change is worth it depends on what structured VI gives us beyond what can be achieved with normalizing flows. The copula VI paper shows a 4x slowdown compared to mean field; is this comparable to what we are getting with NFVI? I haven't seen these implemented anywhere, which seems odd. Would have thought Edward would have them. Any idea why not?
NFVI is not scalable in terms of model size. Of course there is a slowdown, but there is a tradeoff: we can vary flow length and thus computational cost. I want to make this tradeoff more flexible and require flows or FullRank only when needed.
Remark: NFVI is the most scalable approach we currently have for complex models, but we can make it better.
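For instance, flow length can be chosen per problem. A sketch using pymc3's flow-formula spelling (the exact formula grammar and `model` are assumptions here, not taken from this thread):

```python
import pymc3 as pm

with model:  # `model` is a placeholder pm.Model
    cheap = pm.fit(method='nfvi=planar*2')   # short flow: cheaper, less expressive
    rich = pm.fit(method='nfvi=planar*16')   # longer flow: more compute, more expressive
```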
We can ask @dustinvtran for some opinions.
I find the implementation not easy; I need a careful architecture choice first.
Interested to hear from @dustinvtran too.
I agree. If two local variables should interact with each other, these variables can be merged when input to the network for posterior parameter estimation. In my opinion, you can first implement the "OK" case. Then, if required, the other two cases can be considered.
In practice, the tradeoff for any algorithm is its perceived benefit vs its complexity of implementation. I often use mean-field and structured approaches (i.e., low-rank Gaussians) as simple baselines. They're fast and work well on many problems. Copula VI is difficult to implement if the library doesn't already have extensive support for copula distributions like R does, but it's also robust. I haven't seen normalizing flows, operator VI, and many others used for problems besides training deep generative models. This is because they're quite difficult to train in practice, which means they're generally restricted to experts who are knowledgeable enough to get them to work.
Edward supports structured approximations. E.g., you just specify a Cholesky-parameterized normal distribution as your approximating family. If TensorFlow had support for copula distributions, that would also just work.™
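For comparison, the closest PyMC3 analogue of that Cholesky-parameterized normal family is full-rank ADVI; a minimal sketch on a toy model (illustrative only, not code from this PR):

```python
import numpy as np
import pymc3 as pm

with pm.Model():
    mu = pm.Normal('mu', 0., 1., shape=3)
    pm.Normal('obs', mu, 1., observed=np.random.randn(100, 3))
    # full-rank Gaussian approximation, i.e. Cholesky-parameterized covariance
    approx = pm.fit(method='fullrank_advi', n=5000)
```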
What is the connection between low-rank (structured) approximations and Householder flows?
I'm still a bit stuck on the architecture choice. I want to make more assumptions about local vars. They are the following:

How do you feel about it?
I am still going through the paper...
@ferrine Does the second choice mean that the first dimension is the minibatch size? I don't understand the first choice.
That's right. The first one means the user should have only one subsampling dimension there.
I can't come up with any ideas about when "one flexible dim" is required. I think the 1st dim should always be the minibatch size for latent variables.
I feel like we're talking about the same thing. The 1st dim is for minibatches and nothing else.
It is flexible because it depends on the batch size.
@ferrine I don't think it is a good idea to rename it as
@aseyboldt Could you please review my last commit? I've decided to use dynamic class dispatching so a single implementation can serve for
@junpenglao I see your point and had the same concerns. I'm thinking about a better name for Group. Maybe {Group|Local|Global}Approx?
I feel like it is a good direction, but I still have a bad design. After calling `__new__` I'll have a Local or Global group, but not an approximation instance. Thus, calling `__init__` will give invalid results.
I have 2 solutions:

```python
# metaclass' call
def __call__(cls, *args, **kwargs):
    is_local = kwargs.get('local', False)
    if is_local:
        _cls = Local
    else:
        _cls = Global
    instance = object.__new__(_cls)
    cls.__init__(instance, *args, **kwargs)
    return instance
```

Which is better?
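For reference, a runnable, self-contained version of that dispatch could look as follows (`Group`, `Local`, and `Global` are stand-ins for the actual group classes; `pop` is used instead of `get` so `local` is not forwarded to `__init__`):

```python
class GroupMeta(type):
    # dispatch at instantiation time: Group(..., local=True) builds a Local,
    # anything else builds a Global
    def __call__(cls, *args, **kwargs):
        is_local = kwargs.pop('local', False)
        _cls = Local if is_local else Global
        instance = object.__new__(_cls)
        _cls.__init__(instance, *args, **kwargs)
        return instance


class Group(metaclass=GroupMeta):
    def __init__(self, group=None):
        self.group = group


class Local(Group):
    pass


class Global(Group):
    pass


assert isinstance(Group(local=True), Local)
assert isinstance(Group(), Global)
```

Note that `Local()` and `Global()` themselves also go through the same `__call__`, so direct instantiation still dispatches on the `local` flag.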
I think the first one will be better.
pymc3/math.py
Outdated
```python
def make_node(self, *matrices):
    if not matrices:
        raise ValueError('Got no matrices to allocate')
```
Just “No matrices to allocate” is more appropriate.
```python
def get_transformed(z):
    if hasattr(z, 'transformed'):
        z = z.transformed
    return z
```
Need newline
Done
Do we need an example or a notebook on how approximation groups are used? It won't be entirely clear to users.
I do not yet have a good example for that. It can be done after special features like Gumbel-Softmax.
@taku-y do you have any further comments? Otherwise, this is ready to merge!
@ferrine @junpenglao I'm sorry for not responding. The code and documentation look good to me. Though I didn't run the notebooks, I believe it's fine to merge.
Thanks @ferrine, this is a massive piece of work and adds amazing new functionality. Thanks also to all the reviewers!
```python
try:
    approx = Approximation([cls([three_var_model.one], batched=True, **kw),
                            Group(None, vfam='mf')])
    inference = pm.KLqp(approx)
    approx = inference.fit(3, obj_n_mc=2)
```
I see a typo here: `batched` -> `rowwise`
PR?
Some readings:
Discussions: