Pyro integration #895
Conversation
@vitkl, if you have a chance, it would be great if you could let me know whether this code is general enough for your use case. I'm not super familiar with Pyro, but by having everything in a torch nn module (see the test case), save/load etc. of the high-level API will work automatically. The thing to note about PyTorch Lightning is that you need to write device-agnostic code; the trainer handles moving things to the GPU for you. This code works on my GPU.
Thanks @adamgayoso ! We are currently quite busy with paper revision. We (I and @yozhikoff) will try using the cell2location model with your code next week (hopefully) or the week thereafter. Given your comments, getting this to work will probably need some tweaking. Maybe I am missing something, but where is [...]? As far as I understand pyro plates, they need to be given the index from the data loader and other args ([...]). We currently do this by providing indices as an argument to the model / guide functions:

    def create_plates(self, x_data, idx, cell2sample, cell2covar):
        return [pyro.plate("obs_axis", self.n_obs, dim=-2,
                           subsample_size=self.minibatch_size,
                           subsample=idx),
                pyro.plate("var_axis", self.n_var, dim=-1),
                pyro.plate("factor_axis", self.n_fact, dim=-2),
                pyro.plate("experim_axis", self.n_experim, dim=-2)]

    def model(self, x_data, idx, cell2sample, cell2covar):
        obs_axis, var_axis, factor_axis, experim_axis = self.create_plates(
            x_data, idx, cell2sample, cell2covar)

        with var_axis, factor_axis:
            gene_loadings_fg = pyro.sample(
                'gene_loadings_fg',
                dist.Gamma(torch.ones([self.n_fact, self.n_var]),
                           torch.ones([self.n_fact, self.n_var])))

        with var_axis:
            gene_alpha_g = pyro.sample(
                'gene_alpha_g', dist.Exponential(torch.ones([1, self.n_var])))

        with experim_axis:
            detection_mean = pyro.sample(
                'detection_mean',
                dist.Gamma(torch.ones([self.n_experim, 1]),
                           torch.ones([self.n_experim, 1])))

        with obs_axis as ind:
            cell_norm = pyro.sample(
                'cell_norm',
                dist.Gamma(200, 200 / torch.mm(cell2sample[ind], detection_mean)))

        with var_axis:
            with obs_axis as ind:
                mu_biol = torch.mm(cell2covar, gene_loadings_fg)
                # from scVI
                total_count, logits = _convert_mean_disp_to_counts_logits(
                    mu_biol, gene_alpha_g, eps=1e-8)
                # I did not manage to make pyro work with the scVI
                # NegativeBinomial distribution class
                pyro.sample('data_target',
                            dist.NegativeBinomial(total_count=total_count,
                                                  logits=logits),
                            obs=x_data)

However, we are still figuring out how to use pyro plates correctly and need to do more testing. Do failed checks mean that this code does not work?
Yes, but the issue is just that I need to tell it not to use the GPU, as I tested it locally with a GPU. I'll fix that.
This is definitely opaque. What happens is that setup_anndata() adds a data registry to the anndata itself, which is then used by our data loader. So our AnnDataLoader returns a dictionary with keys (accessible via [...]).

And going back to your original comment about mixing autoguides with NN encoders: that's not something that's really in our control, as much as it is a question of whether you can do that in Pyro easily.
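As an illustration of the dict-style batches described above, here is a generic PyTorch sketch (`DictDataset` and its keys are made-up names for illustration, not the scvi-tools registry):

```python
import torch
from torch.utils.data import Dataset, DataLoader

class DictDataset(Dataset):
    """Toy dataset yielding dict items, in the style of AnnDataLoader batches."""
    def __init__(self, X):
        self.X = X
    def __len__(self):
        return len(self.X)
    def __getitem__(self, i):
        # each item carries the expression row and its cell index
        return {"X": self.X[i], "ind_x": i}

loader = DataLoader(DictDataset(torch.randn(8, 4)), batch_size=4)
batch = next(iter(loader))       # default collate keeps the dict structure
print(sorted(batch.keys()))      # ['X', 'ind_x']
print(tuple(batch["X"].shape))   # (4, 4)
```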
Codecov Report

    @@            Coverage Diff             @@
    ##           master     #895      +/-   ##
    ==========================================
    - Coverage   89.27%   89.23%   -0.05%
    ==========================================
      Files          73       73
      Lines        5604     5647      +43
    ==========================================
    + Hits         5003     5039      +36
    - Misses        601      608       +7

Continue to review the full report at Codecov.
I added a simple version of scVI with Pyro. I think this is enough to show that our API makes it easy to implement models with Pyro as well! Once this is merged, we can add a skeleton in pyro to make this integration clearer!
what's with the codacy?
This looks very exciting. Thanks for putting this together! I will try adding cell2location tomorrow or on Monday.
Is "ind_x" in this line an index in the minibatch, or is it a column in `st_adata.obs`?
https://github.com/YosefLab/scvi-tools/blob/cda7bf82edd56d6580fb2e0f525e396f6350b124/scvi/external/stereoscope/_model.py#L165
With AnnotationDataLoader or otherwise, is it possible to load and use the whole dataset in each training iteration rather than mini-batches?
it should be easy to change the pytorch data loader behavior to send the full dataset at each iteration (maximal batch size, and un-randomized sampling of indices)
the only caveat here is that it would be put on cuda each time -- this is a new functionality we need to add: full-batch data loading on GPU.

@vitkl it's actually both; we make an obs column of this 1d array before training, and it gets loaded by the data loader, so in each minibatch we know which index each cell has.
This is exactly what I am asking about: is it possible to load the full data once rather than in every batch? Will try adding cell2location tomorrow; full data would have been faster to test (we haven't fully sorted out minibatch training in pyro yet).
I'm not really sure how these autoguides work, but it's not happy on load, as when it's initialized it has no parameters. I think its parameters get initialized after some data passes through?
Yes, but it will be slower if using the GPU; just set the batch size to the right size and shuffle to False.
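The full-batch trick described above can be sketched with a plain PyTorch DataLoader (generic code, not the scvi-tools loader): a batch size equal to the dataset size plus `shuffle=False` gives full-batch, in-order loading.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

X = torch.arange(10.0).unsqueeze(1)   # 10 "cells", 1 feature
ds = TensorDataset(X)

# batch size = dataset size, no shuffling -> one full, ordered batch
loader = DataLoader(ds, batch_size=len(ds), shuffle=False)
batches = list(loader)

print(len(batches))                    # 1
print(torch.equal(batches[0][0], X))   # True: whole dataset, original order
```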
@galenxing @romain-lopez I think I figured it out: if using AutoGuides (which, from my understanding, are like ADVI), you have to run some data through the model so the guides get their torch params, and THEN you can load from the state dict. Otherwise the guides have no torch params (I believe pyro needs to trace through the model to infer the guide, and this requires data).

In other words, I think we need to do the following:

1. If the underlying module is Pyro and there are AutoGuides, run the model's train method for 1 training step just before loading the state dict (might need to clear the param store before too?)

There are probably a few good ways we can detect the Pyro/autoguide prerequisite.
That's correct. AutoGuides provide an easy way to do ADVI by using appropriately transformed normal distributions for the VI approximation. Yes, the AutoGuides need to be passed the data and all other arguments, just as the model function is. I think this is used to guess dimensions. It is also a good idea to clear the param store.
@fritzo @jamestwebber
@romain-lopez exciting! In case it helps, I implemented scanvi here; also see here.
Looks great, glad to see you're trying out Pyro 😄
I have only a couple of warnings about sharp edges like PyroModule requirements. Let me know if you have any questions or want to chat (Pyro slack, PyTorch slack, zoom, ...).
scvi/lightning/_trainingplans.py (Outdated)

    self,
    pyro_module: PyroBaseModuleClass,
    lr: float = 1e-3,
    loss_fn: Callable = pyro.infer.Trace_ELBO(),
Warning: it's dangerous to set stateful default arguments. In this case the same default Trace_ELBO instance will be shared by all models, and (details below) bad things may happen. I'd recommend defaulting to None and doing a standard `if loss_fn is None: loss_fn = pyro.infer.Trace_ELBO()`.

Each Trace_ELBO instance guesses and stores the number of plates in its model, and assumes it will be associated with only a single model. If you use it with a different model with a different number of plates, you might see tensor shape errors.
tests/models/test_pyro.py (Outdated)

    class BayesianRegression(PyroModule, PyroBaseModuleClass):
        def __init__(self, in_features, out_features):
            super().__init__()
            ...
            self._auto_guide = AutoDiagonalNormal(self)
Technical detail: one of the rules of PyroModules is that the model and guide must be separate PyroModules, and cannot be contained in a single PyroModule; however, they can be contained in a single nn.Module. You might consider refactoring to separate PyroBaseModuleClass from some sort of PyroModelClass like this (where s <: t means issubclass(s, t)):

    PyroBaseModuleClass <: nn.Module
        .model : PyroBaseModelClass <: PyroModule
        .guide : AutoNormal <: PyroModule

Basically, "a model can't have its guide as an attribute". The reason is due to PyroModule naming schemes and caching of pyro.sample calls. If model and guide are contained in a single PyroModule, there may be weird conflicts in both names and the pyro.sample cache.
this is super helpful!
tests/models/test_pyro.py (Outdated)

    # score against actual counts
    pyro.sample("obs", x_dist.to_event(1), obs=x)
    ...
    @pyro_method
See note above: a PyroModule model cannot own its guide; you'll need an outer nn.Module to contain both.
Wonderful! Thanks @martinjankowiak and @fritzo! We will be in touch if we have more questions!
also @vitkl the API changed a bit, please take a look at the bayesian regression example. The wrapper NN Module class needs to have a [...]
    try:
        model.module.load_state_dict(model_state_dict)
    except RuntimeError as err:
        if isinstance(model.module, PyroBaseModuleClass):
nice!
tests/models/test_pyro.py (Outdated)

    # sets a prior over the weight vector
    # self.linear.weight = PyroSample(
    #     dist.Normal(self.zero, self.one)
    #     .expand([out_features, in_features])
should we keep those, uncomment, or delete?
sounds great!
Initial implementation of a Pyro base module class and a Pyro training plan (a PyTorch Lightning class). See the example in the tests file I added.