Skip to content

Commit

Permalink
added some fixes based on custom data loader test
Browse files Browse the repository at this point in the history
  • Loading branch information
ori-kron-wis committed Aug 1, 2024
1 parent 17282cd commit a4143f5
Show file tree
Hide file tree
Showing 4 changed files with 268 additions and 139 deletions.
14 changes: 10 additions & 4 deletions src/scvi/model/base/_archesmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ def load_query_data(
validate_single_device=True,
)

attr_dict, var_names, load_state_dict = _get_loaded_data(reference_model, device=device)
attr_dict, var_names, load_state_dict = _get_loaded_data(
reference_model, device=device, adata=adata
)

if adata is not None:
if isinstance(adata, MuData):
Expand Down Expand Up @@ -216,7 +218,7 @@ def prepare_query_anndata(
Query adata ready to use in `load_query_data` unless `return_reference_var_names`
in which case a pd.Index of reference var names is returned.
"""
_, var_names, _ = _get_loaded_data(reference_model, device="cpu")
_, var_names, _ = _get_loaded_data(reference_model, device="cpu", adata=adata)
var_names = pd.Index(var_names)

if return_reference_var_names:
Expand Down Expand Up @@ -364,15 +366,19 @@ def requires_grad(key):
par.requires_grad = False


def _get_loaded_data(reference_model, device=None):
def _get_loaded_data(reference_model, device=None, adata=None):
if isinstance(reference_model, str):
attr_dict, var_names, load_state_dict, _ = _load_saved_files(
reference_model, load_adata=False, map_location=device
)
else:
attr_dict = reference_model._get_user_attributes()
attr_dict = {a[0]: a[1] for a in attr_dict if a[0][-1] == "_"}
var_names = _get_var_names(reference_model.adata)
var_names = (
_get_var_names(reference_model.adata)
if attr_dict["registry_"]["setup_method_name"] != "setup_datamodule"
else _get_var_names(adata)

This comment has been minimized.

Copy link
@canergen

canergen Aug 5, 2024

Member

We should take var_names out of registry.

)
load_state_dict = deepcopy(reference_model.module.state_dict())

return attr_dict, var_names, load_state_dict
Expand Down
8 changes: 8 additions & 0 deletions tests/dataloaders/test_custom_dataloader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import os
from pprint import pprint

import numpy as np
import scanpy as sc
Expand Down Expand Up @@ -41,6 +42,11 @@
# Loading the model (just as a comparison)
model_orig_loaded = scvi.model.SCVI.load(model_dir, adata=adata)

# when loading from disk
scvi.model.SCVI.prepare_query_anndata(adata, model_dir)
# when passing an already-loaded model object
scvi.model.SCVI.prepare_query_anndata(adata, model_orig_loaded)

# Obtaining model outputs
SCVI_LATENT_KEY = "X_scVI"
latent = model_orig.get_latent_representation()
Expand All @@ -53,6 +59,8 @@
# adata_manager.get_state_registry(SCVI.REGISTRY_KEYS.X_KEY).to_dict()
adata_manager.registry[_constants._FIELD_REGISTRIES_KEY]

pprint(adata_manager.registry)

# Plot UMAP and save the figure for later check
sc.pp.neighbors(adata, use_rep="scvi", key_added="scvi")
sc.tl.umap(adata, neighbors_key="scvi")
Expand Down
Loading

0 comments on commit a4143f5

Please sign in to comment.