Refactor old Distribution base class #5308
Originally posted by @brandonwillard in #5169 (comment)
Do we have a spec in mind for a replacement class? Maybe an explicit list of what needs to be removed and added with respect to the current class would be a good place to start.
I will update the issue tomorrow with those.
Updated the top post to mention every function of …
The question of RNGs with default updates exists anyway, but could arguably be offloaded to … (Line 968 in 75ea2a8)
Yeah, I like a functional approach there, and I agree …
What is the selling point of it? First of all, if we never create instances of any PyMC distribution, why do they exist at all? We distinguish between variables that are …
Pure Aesara RV … Then, PyMC distributions have a much richer API and behavior compared to Aesara … So we need to wrap the … Neither …
So we could move away from … Then every distribution needs to get a … Then we'd have to dispatch …
This once bit me really hard when I tried to demo Aesara to someone. If Aesara made the non-deterministic …
Can't we just check if the first argument is a string and react accordingly?
We already do that anyway, just via the cryptic MetaDistribution. Structurally, what difference does it make if the logp and logcdf are fake methods inside fake classes or real functions one after the other? Does this seem so bad? More importantly, users should not be calling these pseudo-methods themselves, because they expect as inputs the already parsed and symbolic canonical parameters of the distribution. For some (many) distributions these have nothing to do with what they would pass into …
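For comparison with the MetaDistribution route, here is a minimal sketch of the "check whether the first argument is a string" idea from the question above, written with plain Python stand-ins so it runs without PyMC. `make_normal`, `normal` and the dict used as a model are hypothetical; a real version would build a PyTensor variable and hand it to `Model.register_rv`.

```python
from typing import Any, Dict, Optional


def make_normal(mu: float = 0.0, sigma: float = 1.0) -> Dict[str, Any]:
    """Hypothetical stand-in for building an unnamed normal random variable."""
    return {"op": "normal", "mu": mu, "sigma": sigma}


def normal(*args: Any, model: Optional[Dict[str, Any]] = None, **kwargs: Any) -> Dict[str, Any]:
    """Create a normal variable, registering it only when a name is given.

    A leading string is treated as the variable name (model-registered
    variable); otherwise the call behaves like `.dist()` and returns an
    unnamed, unregistered variable.
    """
    if args and isinstance(args[0], str):
        if model is None:
            raise ValueError("A model is needed to register a named variable")
        name, *params = args
        rv = make_normal(*params, **kwargs)
        model[name] = rv  # stand-in for model.register_rv(rv, name=name)
        return rv
    return make_normal(*args, **kwargs)


model: Dict[str, Any] = {}
x = normal("x", 0.0, 2.0, model=model)  # named: registered in `model`
y = normal(1.0, sigma=3.0)              # unnamed: like Normal.dist(...)
```

One trade-off of this design: the branch is invisible to a type checker (everything goes through `*args`), and a distribution whose first parameter could legitimately be a string would be ambiguous.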
Posting this in case it helps... I came across an issue like this one on Discourse while trying to use v3 code in v4. In some of the plots in BCB, the authors use FreeRV.distribution.all_trees[..].predict_output() to create insightful plots. In v4 the error is 'TensorVariable' object has no attribute 'distribution'. I tried to solve it, but it seems that in v3, model.py used def var() and returned a <>RV pymc object where distribution=dist. However, in v4 we have def register_rv(), which returns an Aesara tensor variable? Are we breaking this functionality in v4? Sorry if it is something obvious I am missing.
@mitch-at-orika yes, there is a breaking change w.r.t. what's returned. In most cases that's not a problem, but it looks like … @aloctavodia I believe you can link to relevant …
Thanks for the explanation, Michael. It is reassuring that it was a BART-only option; I originally thought I had just missed this functionality of pymc vars until now.
Here is some pseudo-code that might suffice:

```python
from functools import wraps

import pytensor.tensor.random.basic as ptr

from pymc.distributions.continuous import get_tau_sigma
from pymc.distributions.shape_utils import convert_dims, shape_from_dims
from pymc.model import modelcontext
from pymc.pytensorf import convert_observed_data
from pymc.util import UNSET


def handle_shape(ndim_supp=None):
    """Convert the shape argument to the size argument used by PyTensor."""

    def inner_decorator(dist):
        @wraps(dist)
        def inner_func(*args, size=None, shape=None, **kwargs):
            if shape is not None and size is not None:
                raise ValueError("Cannot pass both size and shape")
            if shape is not None:
                # If needed, call dist without size to find out ndim_supp
                local_ndim_supp = (
                    dist(*args, **kwargs).owner.op.ndim_supp if ndim_supp is None else ndim_supp
                )
                # size only covers the batch dimensions, so drop the support dims
                size = shape if local_ndim_supp == 0 else shape[:-local_ndim_supp]
            return dist(*args, size=size, **kwargs)

        return inner_func

    return inner_decorator


def register_model_rv(dist, rv_type=None):
    """Register a random variable in a model context."""

    @wraps(dist)
    def inner_func(name, *args, dims=None, transform=UNSET, observed=None, model=None, **kwargs):
        # Resolve the model first: shape_from_dims needs it to look up dim lengths
        model = modelcontext(model)
        if dims is not None:
            dims = convert_dims(dims)
        if observed is not None:
            observed = convert_observed_data(observed)
        # The shape of the variable is determined from the following sources:
        # size or shape, otherwise dims, otherwise observed.
        if kwargs.get("size") is None and kwargs.get("shape") is None:
            if dims is not None:
                kwargs["shape"] = shape_from_dims(dims, model)
            elif observed is not None:
                kwargs["shape"] = tuple(observed.shape)
        rv = dist(*args, **kwargs)
        return model.register_rv(rv, name=name, dims=dims, transform=transform, observed=observed)

    # Monkey-patch useful attributes
    if rv_type is not None:
        inner_func.rv_type = rv_type
    inner_func.dist = dist
    return inner_func


@handle_shape(ndim_supp=0)
def normal_dist(mu=0, sigma=None, tau=None, **kwargs):
    _, sigma = get_tau_sigma(sigma=sigma, tau=tau)
    return ptr.normal(mu, sigma, **kwargs)


Normal = register_model_rv(normal_dist, rv_type=ptr.NormalRV)
```

This also makes writing distribution helpers simpler. Right now we have to define a redundant … (pymc/pymc/distributions/mixture.py, lines 551 to 560 in 7bb2ccd).
Instead, this could be done like this:

```python
from pymc.distributions.mixture import Mixture


def normal_mixture_dist(w, mu, sigma=None, tau=None, **kwargs):
    return Mixture.dist(w, Normal.dist(mu, sigma=sigma, tau=tau), **kwargs)


NormalMixture = register_model_rv(normal_mixture_dist)
```
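To make the control flow of the decorator proposal easy to try out without PyMC or PyTensor installed, here is a toy re-creation of it. `ToyRV`, the dict used as a model and the simplified `register_model_rv` signature are assumptions made for illustration; only the shape-to-size logic and the monkey-patched `.dist` attribute mirror the pseudo-code.

```python
from dataclasses import dataclass
from functools import wraps
from typing import Any, Callable, Dict, Optional, Tuple


@dataclass
class ToyRV:
    """Hypothetical stand-in for a PyTensor random variable."""
    op: str
    params: Tuple[Any, ...]
    size: Optional[Tuple[int, ...]]


def handle_shape(ndim_supp: int) -> Callable[[Callable[..., ToyRV]], Callable[..., ToyRV]]:
    """Translate a user-facing `shape` into the batch-only `size`."""

    def decorator(dist: Callable[..., ToyRV]) -> Callable[..., ToyRV]:
        @wraps(dist)
        def inner(*args: Any, size=None, shape=None, **kwargs: Any) -> ToyRV:
            if shape is not None and size is not None:
                raise ValueError("Cannot pass both size and shape")
            if shape is not None:
                shape = tuple(shape)
                size = shape if ndim_supp == 0 else shape[:-ndim_supp]
            return dist(*args, size=size, **kwargs)

        return inner

    return decorator


def register_model_rv(dist: Callable[..., ToyRV]) -> Callable[..., ToyRV]:
    """Wrap `dist` so that calling it with a name also registers the result."""

    @wraps(dist)
    def inner(name: str, *args: Any, model: Dict[str, ToyRV], **kwargs: Any) -> ToyRV:
        rv = dist(*args, **kwargs)
        model[name] = rv  # stand-in for model.register_rv(rv, name=name, ...)
        return rv

    # Monkey-patch the unregistered constructor so `Normal.dist(...)` works
    inner.dist = dist  # type: ignore[attr-defined]
    return inner


@handle_shape(ndim_supp=0)
def normal_dist(mu: float = 0.0, sigma: float = 1.0, size=None) -> ToyRV:
    return ToyRV("normal", (mu, sigma), size)


Normal = register_model_rv(normal_dist)

model: Dict[str, ToyRV] = {}
x = Normal("x", 0.0, 2.0, shape=(3,), model=model)  # registered
y = Normal.dist(1.0, shape=(2, 4))                  # unregistered
```

Only one object (`normal_dist`) is written by hand; both the registering callable and `.dist` are derived from it, which is the main appeal of the decorator route.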
@ricardoV94 I think I understand this, but just to be certain: you have written a function that constructs the correct pytensor TensorVariable, and then you have a wrapper class that associates that variable with whatever model context manager it is created within. Yes, I believe this should work. Here is an example of how you would type …
I think this introduces another potential problem, however. This new …
They are being monkey-patched here:

```python
    # Monkey-patch useful attributes
    if rv_type is not None:
        inner_func.rv_type = rv_type
    inner_func.dist = dist
    return inner_func
```

I am not sure that's the best approach, but the current fake classes also seem odd. Maybe what's done by … There is no other method that should be attached to the …
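One way to keep monkey-patched attributes visible to a static type checker is to describe the patched function with a `typing.Protocol` and `cast` to it. This is a sketch under that assumption; `ModelRVFactory`, `attach_dist` and the toy `_normal`/`_normal_dist` functions are hypothetical names, not PyMC API.

```python
from typing import Any, Callable, Protocol, cast, runtime_checkable


@runtime_checkable
class ModelRVFactory(Protocol):
    """Hypothetical protocol for a registering function that also carries `.dist`."""

    dist: Callable[..., Any]

    def __call__(self, name: str, *args: Any, **kwargs: Any) -> Any: ...


def attach_dist(register: Callable[..., Any], dist: Callable[..., Any]) -> ModelRVFactory:
    """Monkey-patch `dist` onto `register` and expose the result under a typed protocol."""
    register.dist = dist  # type: ignore[attr-defined]
    return cast(ModelRVFactory, register)


def _normal_dist(mu: float = 0.0, sigma: float = 1.0) -> tuple:
    return ("normal", mu, sigma)


def _normal(name: str, *args: Any, **kwargs: Any) -> tuple:
    return (name, _normal_dist(*args, **kwargs))


Normal = attach_dist(_normal, _normal_dist)
```

This only removes the "attribute not found" complaint: the protocol's `__call__` still takes `*args, **kwargs`, so named parameters such as `mu` and `sigma` remain invisible to the type checker.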
Here is a non-fake class that does the same:

```python
import pytensor.tensor.random.basic as ptr

from pymc.distributions.continuous import get_tau_sigma
from pymc.distributions.mixture import Mixture
from pymc.distributions.shape_utils import convert_dims, shape_from_dims
from pymc.model import modelcontext
from pymc.pytensorf import convert_observed_data
from pymc.util import UNSET


class Distribution:
    rv_type = None
    rv_op = None

    @classmethod
    def dist(cls, *args, size=None, shape=None, **kwargs):
        if shape is not None and size is not None:
            raise ValueError("Cannot pass both size and shape")
        if shape is not None:
            ndim_supp = getattr(cls.rv_type, "ndim_supp", None)
            if ndim_supp is None:
                # If needed, call rv_op without size to find out ndim_supp
                ndim_supp = cls.rv_op(*args, **kwargs).owner.op.ndim_supp
            size = shape if ndim_supp == 0 else shape[:-ndim_supp]
        return cls.rv_op(*args, size=size, **kwargs)

    def __call__(self, name, *args, dims=None, transform=UNSET, observed=None, model=None, **kwargs):
        # Resolve the model first: shape_from_dims needs it to look up dim lengths
        model = modelcontext(model)
        if dims is not None:
            dims = convert_dims(dims)
        if observed is not None:
            observed = convert_observed_data(observed)
        # The shape of the variable is determined from the sources:
        # size or shape, otherwise dims, otherwise observed.
        if kwargs.get("size") is None and kwargs.get("shape") is None:
            if dims is not None:
                kwargs["shape"] = shape_from_dims(dims, model)
            elif observed is not None:
                kwargs["shape"] = tuple(observed.shape)
        rv = self.dist(*args, **kwargs)
        return model.register_rv(rv, name=name, dims=dims, transform=transform, observed=observed)


class NormalDist(Distribution):
    rv_type = ptr.NormalRV

    @staticmethod
    def rv_op(mu=0, sigma=None, tau=None, **kwargs):
        _, sigma = get_tau_sigma(tau=tau, sigma=sigma)
        return ptr.normal(mu, sigma, **kwargs)


class NormalMixtureDist(Distribution):
    # If we subclass from a refactored `Mixture`, this `rv_type` would be obtained automatically
    rv_type = Mixture.rv_type

    @staticmethod
    def rv_op(w, mu, sigma=None, tau=None, **kwargs):
        _, sigma = get_tau_sigma(tau=tau, sigma=sigma)
        return Mixture.dist(w, Normal.dist(mu=mu, sigma=sigma), **kwargs)


Normal = NormalDist()
NormalMixture = NormalMixtureDist()
```

The sole point of it is that it provides a …
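Both the decorator and the class version convert `shape` into `size` with `shape if ndim_supp == 0 else shape[:-ndim_supp]`. The special case is needed because `shape[:-0]` is the same as `shape[:0]`, which is an empty tuple. A small standalone helper (the name `shape_to_size` is hypothetical) with a few worked values:

```python
from typing import Sequence, Tuple


def shape_to_size(shape: Sequence[int], ndim_supp: int) -> Tuple[int, ...]:
    """Drop the trailing support dimensions from `shape` to get `size`.

    `size` only describes the batch dimensions, while `shape` also includes
    the `ndim_supp` core dimensions of a single draw.
    """
    shape = tuple(shape)
    if ndim_supp == 0:
        # shape[:-0] would be shape[:0] == (), so scalars need their own branch
        return shape
    if len(shape) < ndim_supp:
        raise ValueError(f"shape {shape} has fewer than {ndim_supp} support dimensions")
    return shape[:-ndim_supp]


print(shape_to_size((3, 4), 0))     # scalar support: size equals shape
print(shape_to_size((3, 5), 1))     # vector support, e.g. a 5-dimensional draw
print(shape_to_size((2, 3, 3), 2))  # matrix support
```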
I believe we desire two interfaces per distribution that require similar (but not identical) signatures. For the case of the normal distribution, using only a minimal number of variables to represent the difference:

```python
def normal_dist(mu, sigma, **kwargs):
    ...


def Normal(name, mu, sigma, dims, observed, **kwargs):
    ...
```

I do not believe there is a way to programmatically produce a static function signature for … In this case … The only solution that I think makes sense is to type the signature out in both cases, using mechanisms like …
Yes, that's correct.
Ah yes, that makes sense! Thank you for clarifying these things for me!
I took a stab at rewriting the Normal class as a function (…
Looks good. However, the whole code inside … Also, how bad would it be to monkey-patch …
These are good points! I haven't quite worked out a decorator approach, but I figured out a way to keep the current API while refactoring things to a "function-based" approach. The key part is …
You need to consider observed size before you create the dist so that … Anyway, the downside of your last approach is that you need to manually duplicate the signature of the dist and call. Also, it doesn't seem much better to initialize …
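The shape-resolution step from the earlier sketches (explicit size/shape first, then dims, then observed data) can be isolated as a pure function that runs before the distribution is created, presumably so the variable is built with the right size from the start. This is a hypothetical helper, and `dim_lengths` is a plain dict standing in for the model's dimension lengths.

```python
from typing import Dict, Optional, Sequence, Tuple

Shape = Tuple[int, ...]


def resolve_shape(
    size: Optional[Sequence[int]] = None,
    shape: Optional[Sequence[int]] = None,
    dims: Optional[Sequence[str]] = None,
    dim_lengths: Optional[Dict[str, int]] = None,
    observed_shape: Optional[Sequence[int]] = None,
) -> Tuple[Optional[Shape], Optional[Shape]]:
    """Return (size, shape): explicit size/shape win, then dims, then observed."""
    if size is not None and shape is not None:
        raise ValueError("Cannot pass both size and shape")
    if size is not None:
        return tuple(size), None
    if shape is not None:
        return None, tuple(shape)
    if dims is not None:
        if dim_lengths is None:
            raise ValueError("dims require known dimension lengths")
        return None, tuple(dim_lengths[d] for d in dims)
    if observed_shape is not None:
        return None, tuple(observed_shape)
    return None, None
```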
Assuming that our goal is to have a function signature for the … I've refactored the gist into a small repo, just to be able to spread the functions and classes into separate modules.
The idea of the decorator was to wrap dist, so you only have to define one of the objects (dist in this case). We don't need classes if the only reason for them is so that Normal.dist exists. I'd rather just monkey-patch, tbh. There's nothing "class-y" about them otherwise. They don't hold state or do anything.
Right, but then what changed in your previous attempt? We moved … Not saying it's worse, just double-checking what we're trading off.
Right, I'm interpreting your question as something like "We wanted to move from a class-based approach to a function-based approach, but now we have moved from one class approach to another one; how is this better?". I think that's a very valid question. Just to summarize, the goals I see with this GitHub issue are: …
So, in PyMC today, …
Using that metaclass approach, we have no way to statically (by that I mean "in a manner that the type checker can infer without running the code") give our Distribution named arguments (e.g. …). Using only functions (…) …
My solution solves the problem with monkey-patching the function approach by statically defining the … The code flow is also completely linear: … The Distribution-class approach is nonlinear because it defines …
PyMC distribution classes are weird objects that hold RandomVariables, logp, logcdf and moment methods together (basically doing runtime dispatching) and manage most of the non-RandomVariable kwargs that users are familiar with (observed, transformed, size/dims) and behind the scenes actions like registration in the model.
This exists mostly for backwards compatibility with V3 and ease of developer refactoring, but the current result is far from pretty.
We need to figure out a more elegant/permanent architecture now that many things that existed to accommodate V3 limitations no longer hold.
Distribution
`Distribution` is currently performing the following tasks (pymc/pymc/distributions/distribution.py, Line 135 in 75ea2a8):
- … `FutureWarnings` for `testval` kwarg
- … (`tau` -> `sigma`). This is done by the `.dist` methods.
- … `logp`, `logcdf`, `random` methods
- … `.dist()` API to create an unnamed RV that is not registered in the model. This type of variable is necessary for use in Potentials and other distribution factories that use RVs as building blocks, such as Bound and Censored distributions, as well as Mixtures and Timeseries once they get refactored for V4

DistributionMeta
In addition, we have a `DistributionMeta` that does the following (pymc/pymc/distributions/distribution.py, Line 70 in 75ea2a8):
- … `logp`, `logcdf`, `moment`, `default_transform` methods defined in the old PyMC distributions to apply to the respective `rv_op`
- … `rv_op` type as subclass of the old-style PyMC distribution, so that V3 Discrete/Continuous subclass checks still work?

If we want to get rid of `Distribution`, we probably need to statically dispatch our methods to the respective `rv_op`. That is nothing special, and is how we do it for aeppl from the get-go: https://github.com/aesara-devs/aeppl/blob/38d0c2ea4ecf8505f85317047089ab9999d2f78e/aeppl/logprob.py#L104-L130
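Below is a standard-library sketch contrasting the two styles: a metaclass that registers a class's `logp` against its `rv_type` when the class is created (roughly the role `DistributionMeta` plays today), and a plain function registered directly with a dispatcher. It uses `functools.singledispatch` as the dispatcher; `NormalRV`, `ExponentialRV` and `logp_dispatch` are hypothetical stand-ins, not PyMC or aeppl APIs.

```python
import math
from functools import singledispatch
from typing import Any


class NormalRV:
    """Hypothetical stand-in for a PyTensor RandomVariable Op type."""


class ExponentialRV:
    """Another hypothetical Op type."""


@singledispatch
def logp_dispatch(op: Any, value: float, *params: float) -> float:
    """Dispatch on the type of the op that produced the variable."""
    raise NotImplementedError(f"No logp registered for {type(op).__name__}")


class DistributionMeta(type):
    """Roughly the current approach: register a class's logp when the class is created."""

    def __new__(mcs, name, bases, namespace, **kwargs):
        cls = super().__new__(mcs, name, bases, namespace, **kwargs)
        rv_type = namespace.get("rv_type")
        if rv_type is not None and "logp" in namespace:
            method = cls.logp  # the staticmethod resolves to a plain function here

            def _logp(op, value, *params, _method=method):
                return _method(value, *params)

            logp_dispatch.register(rv_type, _logp)
        return cls


class Normal(metaclass=DistributionMeta):
    rv_type = NormalRV

    @staticmethod
    def logp(value: float, mu: float, sigma: float) -> float:
        z = (value - mu) / sigma
        return -0.5 * z**2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)


# Static alternative: register a plain function directly against the op type,
# with no class or metaclass in between.
@logp_dispatch.register(ExponentialRV)
def exponential_logp(op: ExponentialRV, value: float, lam: float) -> float:
    return math.log(lam) - lam * value
```

Both routes end up in the same dispatch table; the static one simply skips the class whose only job is to carry the method to the registration step.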