Skip to content

Commit

Permalink
Fixed attr_dict
Browse files Browse the repository at this point in the history
  • Loading branch information
canergen committed Jul 31, 2024
1 parent 14f343d commit 17282cd
Showing 1 changed file with 37 additions and 29 deletions.
66 changes: 37 additions & 29 deletions src/scvi/model/base/_archesmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pandas as pd
import torch
from anndata import AnnData
from lightning import LightningDataModule
from mudata import MuData
from scipy.sparse import csr_matrix

Expand Down Expand Up @@ -39,8 +40,9 @@ class ArchesMixin:
@devices_dsp.dedent
def load_query_data(
cls,
adata: AnnOrMuData,
reference_model: Union[str, BaseModelClass],
adata: None | AnnOrMuData = None,
reference_model: Union[str, BaseModelClass] = None,
datamodule: None | LightningDataModule = None,
inplace_subset_query_vars: bool = False,
accelerator: str = "auto",
device: Union[int, str] = "auto",
Expand Down Expand Up @@ -83,6 +85,11 @@ def load_query_data(
freeze_classifier
Whether to freeze classifier completely. Only applies to `SCANVI`.
"""
if reference_model is None:
raise ValueError("Please provide a reference model as string or loaded model.")
if adata is None and datamodule is None:
raise ValueError("Please provide either an AnnData or a datamodule.")

_, _, device = parse_device_args(
accelerator=accelerator,
devices=device,
Expand All @@ -92,44 +99,45 @@ def load_query_data(

attr_dict, var_names, load_state_dict = _get_loaded_data(reference_model, device=device)

if isinstance(adata, MuData):
for modality in adata.mod:
if adata is not None:
if isinstance(adata, MuData):
for modality in adata.mod:
if inplace_subset_query_vars:
logger.debug(f"Subsetting {modality} query vars to reference vars.")
adata[modality]._inplace_subset_var(var_names[modality])
_validate_var_names(adata[modality], var_names[modality])

else:
if inplace_subset_query_vars:
logger.debug(f"Subsetting {modality} query vars to reference vars.")
adata[modality]._inplace_subset_var(var_names[modality])
_validate_var_names(adata[modality], var_names[modality])
logger.debug("Subsetting query vars to reference vars.")
adata._inplace_subset_var(var_names)
_validate_var_names(adata, var_names)

else:
if inplace_subset_query_vars:
logger.debug("Subsetting query vars to reference vars.")
adata._inplace_subset_var(var_names)
_validate_var_names(adata, var_names)

if inplace_subset_query_vars:
logger.debug("Subsetting query vars to reference vars.")
adata._inplace_subset_var(var_names)
_validate_var_names(adata, var_names)
registry = attr_dict.pop("registry_")
if _MODEL_NAME_KEY in registry and registry[_MODEL_NAME_KEY] != cls.__name__:
raise ValueError("It appears you are loading a model from a different class.")

registry = attr_dict.pop("registry_")
if _MODEL_NAME_KEY in registry and registry[_MODEL_NAME_KEY] != cls.__name__:
raise ValueError("It appears you are loading a model from a different class.")
if _SETUP_ARGS_KEY not in registry:
raise ValueError(
"Saved model does not contain original setup inputs. "
"Cannot load the original setup."
)

if _SETUP_ARGS_KEY not in registry:
raise ValueError(
"Saved model does not contain original setup inputs. "
"Cannot load the original setup."
setup_method = getattr(cls, registry[_SETUP_METHOD_NAME])
setup_method(
adata,
source_registry=registry,
extend_categories=True,
allow_missing_labels=True,
**registry[_SETUP_ARGS_KEY],
)

setup_method = getattr(cls, registry[_SETUP_METHOD_NAME])
setup_method(
adata,
source_registry=registry,
extend_categories=True,
allow_missing_labels=True,
**registry[_SETUP_ARGS_KEY],
)

model = _initialize_model(cls, adata, attr_dict)
model = _initialize_model(cls, adata, datamodule, attr_dict)
adata_manager = model.get_anndata_manager(adata, required=True)

if REGISTRY_KEYS.CAT_COVS_KEY in adata_manager.data_registry:
Expand Down

0 comments on commit 17282cd

Please sign in to comment.