Skip to content

Commit

Permalink
Merge pull request #1097 from YosefLab/scarches_tweaks
Browse files Browse the repository at this point in the history
scarches: deepcopy input model in memory, proper version check
  • Loading branch information
adamgayoso authored Jul 15, 2021
2 parents c1d14d0 + e55a905 commit ba22825
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions scvi/model/base/_archesmixin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import warnings
from copy import deepcopy
from typing import Optional, Union

import torch
Expand Down Expand Up @@ -82,14 +83,15 @@ def load_query_data(
attr_dict = {a[0]: a[1] for a in attr_dict if a[0][-1] == "_"}
scvi_setup_dict = attr_dict.pop("scvi_setup_dict_")
var_names = reference_model.adata.var_names
load_state_dict = reference_model.module.state_dict().copy()
load_state_dict = deepcopy(reference_model.module.state_dict())

if inplace_subset_query_vars:
logger.debug("Subsetting query vars to reference vars.")
adata._inplace_subset_var(var_names)
_validate_var_names(adata, var_names)

if scvi_setup_dict["scvi_version"] < "0.8":
version_split = scvi_setup_dict["scvi_version"].split(".")
if version_split[1] < "8" and version_split[0] == "0":
warnings.warn(
"Query integration should be performed using models trained with version >= 0.8"
)
Expand Down

0 comments on commit ba22825

Please sign in to comment.