Conjugate relationship marginalisation #358

Open
3 tasks
theorashid opened this issue Jul 5, 2024 · 4 comments

Comments

@theorashid
Contributor

cc: @larryshamalama

Implement conjugate relationships in PyMC via rewrites. Saves us working them out by hand.

For starters:

  • Normal-Normal
  • Beta-Binomial
  • Gamma-Poisson

nimble has a list of possibilities that we can add.
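As a rough illustration of what such rewrites would have to encode, here is a minimal sketch of the closed-form posterior updates for the three pairs listed above. The function names and parameterisations (standard deviations for the normal, a rate for the gamma) are assumptions made for this example, not an existing PyMC API.

```python
def normal_normal(mu0, tau0, sigma, ys):
    """Prior mu ~ Normal(mu0, tau0); y_i ~ Normal(mu, sigma) with sigma known.

    tau0 and sigma are standard deviations. Returns posterior (mean, sd).
    """
    prior_prec = 1.0 / tau0**2
    post_prec = prior_prec + len(ys) / sigma**2
    post_mean = (prior_prec * mu0 + sum(ys) / sigma**2) / post_prec
    return post_mean, post_prec**-0.5


def beta_binomial(a, b, successes, trials):
    """Prior p ~ Beta(a, b); successes ~ Binomial(trials, p). Returns (a', b')."""
    return a + successes, b + trials - successes


def gamma_poisson(shape, rate, counts):
    """Prior lam ~ Gamma(shape, rate); counts_i ~ Poisson(lam). Returns (shape', rate')."""
    return shape + sum(counts), rate + len(counts)


print(beta_binomial(1, 1, successes=7, trials=10))
print(gamma_poisson(2.0, 1.0, counts=[3, 5]))
print(normal_normal(0.0, 1.0, 1.0, ys=[2.0]))
```

Each update only adds sufficient statistics to the prior parameters, which is what makes these pairs candidates for an automatic rewrite.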

@ricardoV94
Member

What is the idea of conjugacy? I'm familiar with conjugate priors, but those are not the same as marginalization? Instead they provide closed form solutions for posteriors?

Triple ? means I may be missing the point :)
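For reference, in a conjugate pair the same algebra that yields the closed-form posterior also yields a closed-form marginal likelihood. A worked Gamma-Poisson case, assuming a rate parameterisation of the Gamma (the symbols here are illustrative and not taken from the comment above):

```latex
\begin{aligned}
\theta &\sim \mathrm{Gamma}(\alpha, \beta), \qquad x \mid \theta \sim \mathrm{Poisson}(\theta) \\
p(\theta \mid x) &\propto \theta^{\alpha + x - 1} e^{-(\beta + 1)\theta}
  \;\Rightarrow\; \theta \mid x \sim \mathrm{Gamma}(\alpha + x,\ \beta + 1) \\
p(x) &= \int_0^\infty p(x \mid \theta)\, p(\theta)\, d\theta
  = \frac{\Gamma(\alpha + x)}{\Gamma(\alpha)\, x!}
    \left(\frac{\beta}{\beta + 1}\right)^{\alpha}
    \left(\frac{1}{\beta + 1}\right)^{x}
\end{aligned}
```

The second line is the posterior a conjugate sampler would draw from; the third is the marginal (a negative binomial) obtained by integrating θ out.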

@jessegrabowski
Member

This issue made me think of this paper: https://arxiv.org/abs/2302.00564

But maybe this is thinking about fully conjugate models, not intermediate relationships?

@theorashid
Contributor Author

theorashid commented Jul 8, 2024

I was thinking we would follow nimble's conjugate (Gibbs) samplers (probably a better name for the issue). So any prior with a sampling (dependent) node (their language) which is conjugate can be rewritten.

But we could also have rewrites that take advantage of the properties of normal or exponential distributions. For example (borrowed from an old chat with numpyro devs):

  • If X ~ Normal(0, np.eye(10)) and y ~ Normal(0, 1) then X + y ~ MVN(0, np.eye(10) + 1)
  • If X ~ Normal(0, np.eye(10)) then cumsum(X) ~ MVN(0, scale_tril=np.tril(np.ones((10, 10))))
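Both identities follow from the rule that a linear map A applied to a standard normal vector has covariance A Aᵀ. The sketch below checks them numerically, assuming `np.eye(10)` in the bullets denotes an identity covariance and that X and y are independent; the variable names are chosen for this example.

```python
import numpy as np

n = 10

# Stack the independent standard normals as z = [X_1, ..., X_n, y], so Cov(z) = I.
# X + y (y broadcast to every component) is the linear map A @ z with A = [I_n | 1].
A = np.hstack([np.eye(n), np.ones((n, 1))])
cov_sum = A @ A.T  # Cov(A z) = A I A^T

# cumsum(X) is the linear map L @ X with L lower-triangular ones,
# so L is a valid scale_tril and the covariance is L @ L.T.
L = np.tril(np.ones((n, n)))
cov_cumsum = L @ L.T

print(np.allclose(cov_sum, np.eye(n) + 1))
print(np.allclose(L @ np.arange(1.0, n + 1), np.cumsum(np.arange(1.0, n + 1))))
```

Note that `np.eye(10) + 1` relies on broadcasting: it adds 1 to every entry, giving the identity plus a matrix of ones.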

PS. Since I appreciate not many people around here have ever used nimble, here's an example of the default MCMC config for a simple BUGS model.

pumpCode <- nimbleCode({ 
  # Define relationships between nodes
  for (i in 1:N){
      theta[i] ~ dgamma(alpha,beta)
      lambda[i] <- theta[i]*t[i]
      x[i] ~ dpois(lambda[i])
  }
  # Set priors
  alpha ~ dexp(1.0)
  beta ~ dgamma(0.1,1.0)
})
...
pumpMCMC <- buildMCMC(pumpModel)
## ===== Monitors =====
## thin = 1: alpha, beta
## ===== Samplers =====
## RW sampler (1)
##   - alpha
## conjugate sampler (11)
##   - beta
##   - theta[]  (10 elements)

So it uses conjugate samplers where possible (theta: gamma -> poisson, beta: gamma -> gamma), and samples everything else with the default non-conjugate RW sampler.
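To make the two conjugate updates concrete, here is a minimal Python sketch of one conjugate sweep for the model above, holding alpha fixed (in nimble it would be updated separately by the RW sampler). The function names and the toy values for `x` and `t` are made up for illustration; the conditionals are derived from the model code, with `dgamma` taken as shape/rate.

```python
import numpy as np


def theta_conditional(alpha, beta, x, t):
    """Shape and rate of theta[i] | alpha, beta, x[i] (gamma prior, Poisson count)."""
    return alpha + x, beta + t


def beta_conditional(alpha, theta, prior_shape=0.1, prior_rate=1.0):
    """Shape and rate of beta | alpha, theta (gamma prior on the rate of a gamma)."""
    return prior_shape + alpha * theta.size, prior_rate + theta.sum()


def conjugate_sweep(rng, alpha, beta, x, t):
    """Draw theta, then beta, each from its closed-form conditional."""
    shape, rate = theta_conditional(alpha, beta, x, t)
    theta = rng.gamma(shape, 1.0 / rate)  # NumPy's gamma takes scale = 1 / rate
    shape_b, rate_b = beta_conditional(alpha, theta)
    beta = rng.gamma(shape_b, 1.0 / rate_b)
    return theta, beta


# Toy inputs, purely illustrative.
x = np.array([2, 0, 4])
t = np.array([10.0, 5.0, 20.0])
rng = np.random.default_rng(0)
theta, beta = conjugate_sweep(rng, alpha=1.0, beta=1.0, x=x, t=t)
```

A full sampler would alternate this sweep with a random-walk update for alpha, mirroring the sampler assignment nimble prints.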

@ricardoV94
Member

ricardoV94 commented Jul 8, 2024

Okay, that clarifies it. I just hadn't heard of conjugacy as marginalization. I am not sure how we should handle this. My best guess, when thinking about this in the past, was to define a conjugate step sampler that can take draws from an (arbitrary) posterior distribution that has a closed-form solution.

Perhaps the easiest option is a helper, find_conjugate_steps, that would return the specialized step samplers, which could then be passed to pm.sample. The API would look something like:

with pm.Model() as m:
  ...
  conjugate_steps = pmx.find_conjugate_steps()
  idata = pm.sample(step=conjugate_steps)

That would have a natural fallback when conjugate steps can't be found, and users could also exclude some steps if they don't like them.
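As a sketch of the detection logic such a helper might need, the toy below works on a plain dictionary rather than a real PyMC model. Everything here is hypothetical: the model encoding, the `CONJUGATE_LINKS` table and this `find_conjugate_steps` signature are invented for illustration and do not correspond to any existing PyMC or pymc-experimental function.

```python
# Hypothetical table: (prior family, child family, role the prior plays in the child).
CONJUGATE_LINKS = {
    ("gamma", "poisson", "rate"),
    ("gamma", "gamma", "rate"),
    ("beta", "binomial", "p"),
    ("normal", "normal", "mu"),
}


def find_conjugate_steps(model, exclude=()):
    """Split unobserved variables into conjugate ones and fallback ones.

    `model` maps a variable name to {"dist", "parents": {role: name}, "observed"}.
    """
    conjugate, fallback = {}, []
    for name, node in model.items():
        if node.get("observed", False):
            continue
        # Every way this variable enters a child distribution.
        links = {
            (node["dist"], child["dist"], role)
            for child in model.values()
            for role, parent in child["parents"].items()
            if parent == name
        }
        if name not in exclude and links and links <= CONJUGATE_LINKS:
            conjugate[name] = sorted(links)
        else:
            fallback.append(name)  # left to the default step assignment
    return conjugate, fallback


# Toy encoding of the pump model from earlier in the thread.
pump = {
    "alpha": {"dist": "exponential", "parents": {}},
    "beta": {"dist": "gamma", "parents": {}},
    "theta": {"dist": "gamma", "parents": {"shape": "alpha", "rate": "beta"}},
    "x": {"dist": "poisson", "parents": {"rate": "theta"}, "observed": True},
}
conjugate, fallback = find_conjugate_steps(pump)
print(sorted(conjugate), fallback)
```

Keying the table on the parameter role matters: a gamma prior on the rate of a gamma is conjugate, while the same prior on the shape is not. The `exclude` argument stands in for the opt-out mentioned above.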

Otherwise we would need to re-sketch the step sampler assignment logic that exists in PyMC, as that eagerly assigns a variable to a sampler if that variable is of a certain type (or if the model logp can be differentiated wrt it), and doesn't really have a nice place for reasoning about the whole model (I could be wrong here).

3 participants