Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom dataloader registry support #2932

Open
wants to merge 60 commits into
base: main
Choose a base branch
from

Conversation

ori-kron-wis
Copy link
Collaborator

No description provided.

@ori-kron-wis ori-kron-wis added this to the scvi-tools 1.2 milestone Aug 7, 2024
@ori-kron-wis ori-kron-wis self-assigned this Aug 7, 2024
@ori-kron-wis ori-kron-wis linked an issue Aug 7, 2024 that may be closed by this pull request
Copy link

codecov bot commented Aug 11, 2024

Codecov Report

Attention: Patch coverage is 53.70370% with 100 lines in your changes missing coverage. Please review.

Project coverage is 83.87%. Comparing base (6bb8d8c) to head (47376ca).

Files with missing lines Patch % Lines
src/scvi/model/base/_base_model.py 41.17% 70 Missing ⚠️
src/scvi/model/_scvi.py 52.94% 16 Missing ⚠️
src/scvi/model/base/_archesmixin.py 75.00% 8 Missing ⚠️
src/scvi/model/_scanvi.py 77.77% 4 Missing ⚠️
src/scvi/model/base/_save_load.py 75.00% 1 Missing ⚠️
src/scvi/model/base/_training_mixin.py 50.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2932      +/-   ##
==========================================
- Coverage   84.81%   83.87%   -0.95%     
==========================================
  Files         173      173              
  Lines       14793    14927     +134     
==========================================
- Hits        12547    12520      -27     
- Misses       2246     2407     +161     
Files with missing lines Coverage Δ
src/scvi/data/_utils.py 86.12% <100.00%> (+0.58%) ⬆️
src/scvi/external/stereoscope/_model.py 92.40% <ø> (ø)
src/scvi/external/stereoscope/_module.py 96.33% <ø> (ø)
src/scvi/model/_amortizedlda.py 94.11% <ø> (ø)
src/scvi/model/_autozi.py 95.40% <ø> (ø)
src/scvi/model/_condscvi.py 95.74% <ø> (ø)
src/scvi/model/_jaxscvi.py 92.30% <ø> (ø)
src/scvi/model/_linear_scvi.py 94.87% <ø> (ø)
src/scvi/model/_multivi.py 72.26% <ø> (ø)
src/scvi/model/_peakvi.py 87.09% <ø> (ø)
... and 7 more

... and 2 files with indirect coverage changes

@@ -232,6 +229,112 @@ def setup_anndata(
adata_manager.register_fields(adata, **kwargs)
cls.register_manager(adata_manager)

@staticmethod
def _get_summary_stats_from_registry(registry: dict) -> attrdict:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be in base or somewhere else. It doesn't make sense in scVI.


@classmethod
@setup_anndata_dsp.dedent
def setup_datamodule(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to be defined for each model. Can we reduce the amount of duplicate code by adding something similar to setup_anndata?

@setup_anndata_dsp.dedent
def setup_datamodule(
cls,
datamodule, # TODO: what to put here?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be pytorch.DataLoader right? Martin has done typing for it in the current code.

"state_registry": {
"n_obs": datamodule.n_obs,
"n_vars": datamodule.n_vars,
"column_names": [str(i) for i in column_names], # TODO: from adata (czi)?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not following?

_validate_var_names(adata[modality], var_names[modality])
logger.debug("Subsetting query vars to reference vars.")
adata._inplace_subset_var(var_names)
_validate_var_names(adata, var_names)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to verify for dataloaders that the gene names are matching.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And the order.

logger.debug("Subsetting query vars to reference vars.")
adata._inplace_subset_var(var_names)
_validate_var_names(adata, var_names)
registry = attr_dict.pop("registry_")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check and the ones below are independent of datamodule or AnnData, right? Remove the indent.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure why every code here is displayed as modified?

@@ -202,7 +215,7 @@ def prepare_query_anndata(
Query adata ready to use in `load_query_data` unless `return_reference_var_names`
in which case a pd.Index of reference var names is returned.
"""
_, var_names, _ = _get_loaded_data(reference_model, device="cpu")
_, var_names, _ = _get_loaded_data(reference_model, device="cpu", adata=adata)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this work with a dataloader?

@@ -350,15 +363,15 @@ def requires_grad(key):
par.requires_grad = False


def _get_loaded_data(reference_model, device=None):
def _get_loaded_data(reference_model, device=None, adata=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need adata here?

self.registry_ = registry
self.summary_stats = _get_summary_stats_from_registry(registry)
elif self.__class__.__name__ == "GIMVI":
# note some models do accept empty registry/adata (e.g: gimvi)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not following this one. What is the exception with GIMVI?

else:
return self._adata_manager.get_from_registry(registry_key)

# def get_from_registry(self, registry_key: str) -> np.ndarray | pd.DataFrame:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's this?

else:
# Case where correct AnnDataManager is found, replay registration as necessary.
adata_manager.validate()

return adata

def transfer_fields(self, adata: AnnOrMuData, **kwargs) -> AnnData:
"""Transfer fields from a model to an AnnData object."""
if self.adata:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where do we need transfer_fields? can we make it work with datamodule?

@@ -627,8 +711,7 @@ def save(

# save the model state dict and the trainer state dict only
model_state_dict = self.module.state_dict()

var_names = _get_var_names(self.adata, legacy_mudata_format=legacy_mudata_format)
var_names = self.get_var_names(legacy_mudata_format=legacy_mudata_format)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need two get_var_names function?

"Saved model does not contain original setup inputs. "
"Cannot load the original setup."
)
_validate_var_names(adata, var_names)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be validated also for a dataloader.


def get_state_registry(self, registry_key: str) -> attrdict:
"""Returns the state registry for the AnnDataField registered with this instance."""
return attrdict(self.registry_[_FIELD_REGISTRIES_KEY][registry_key][_STATE_REGISTRY_KEY])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this work with dataloader. Documentation should be updated then.

@@ -133,7 +133,10 @@ def _initialize_model(cls, adata, attr_dict):
if "pretrained_model" in non_kwargs.keys():
non_kwargs.pop("pretrained_model")

model = cls(adata, **non_kwargs, **kwargs)
if not adata:
adata = None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is adata false here? Do we need a default value for registry?

if max_epochs is None:
if datamodule is None:
if self.adata is not None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should take here n_obs from summary stats to make it compatible with a dataloader.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See below we don't need the if statement.

experiment_name = "mus_musculus"
obs_value_filter = 'is_primary_data == True and tissue_general in ["kidney"] and nnz >= 3000'

# This is under comments just to save time (selecting highly varkable genes):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove this block, we don't need it.

dataloader_kwargs={"num_workers": 0, "persistent_workers": False},
)

# table of genes should be filtered by soma_joinid - but we should keep the encoded indexes
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need to.

model = scvi.model.SCVI(adata_orig, n_latent=10)
model.train(max_epochs=1)

# TODO: do we need to apply those functions to any census model as is?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not getting it.

_ = model.get_reconstruction_error(dataloader=dataloader)
_ = model.get_latent_representation(dataloader=dataloader)

scvi.model.SCVI.prepare_query_anndata(adata_orig, reference_model=model)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test this with a second model trained using dataloader

n_layers = 1
n_latent = 50

scvi.model._scvi.SCVI.setup_datamodule(datamodule) # takes time
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this part. It's nice.


pprint(datamodule.registry)

batch_size = 1024
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does batch size have an effect. I thought it's defined by the datamodule?

# _ = model_census2.get_latent_representation()

# takes time
adata = cellxgene_census.get_anndata(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just download 10 cells - see below with obs_value_filter.

var_coords=hv_idx,
)

# TODO: do we need to put inside (or is it alrady pre-made) - perhaps need to tell CZI
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume we need to make it.

adata.obs["batch"] = adata.obs[batch_keys].agg("".join, axis=1).astype("category")

scvi.model.SCVI.prepare_query_anndata(adata, save_path)
scvi.model.SCVI.load_query_data(registry=datamodule.registry, reference_model=save_path)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should have more tests that actually fail - using different genes without prepare_query_anndata and different batch categories. Assert that it fails.


scvi.model.SCVI.prepare_query_anndata(adata, model_census2)

scvi.model.SCVI.setup_anndata(adata, batch_key="batch") # needed?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

checking that an AnnData model can be trained using datamodule. Do we really want it?


user_attributes_model_census3 = model_census3._get_user_attributes()
pprint(user_attributes_model_census3)
_ = model_census3.get_elbo()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

uses AnnData for inference?

scvi.model.SCVI.prepare_query_anndata(adata, model_census3)
scvi.model.SCVI.load_query_data(adata, model_census3)

datamodule_inference = CensusSCVIDataModule(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check here that using different genes and different batches fails. You can take much fewer cells here, like 1000.

# Create a dataloder of a CZI module
datapipe = datamodule_inference.datapipe
dataloader = experiment_dataloader(datapipe, num_workers=0, persistent_workers=False)
mapped_dataloader = (
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's this?


model = SCVI(adata, n_latent=n_latent)
model.train(max_epochs=1)
dataloader = model._make_data_loader(adata)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does model._make_data_loader exist for all models? We should then add the test to the other models as well?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the dataloader sufficient to also setup the model and does setup_datamodule work for it?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Fix custom dataloader registry
2 participants