diff --git a/scvi/model/base/_archesmixin.py b/scvi/model/base/_archesmixin.py index 8ba780ed98..92446f45af 100644 --- a/scvi/model/base/_archesmixin.py +++ b/scvi/model/base/_archesmixin.py @@ -195,7 +195,7 @@ def prepare_query_anndata( Query adata ready to use in `load_query_data` unless `return_reference_var_names` in which case a pd.Index of reference var names is returned. """ - _, var_names, _ = _get_loaded_data(reference_model) + _, var_names, _ = _get_loaded_data(reference_model, device="cpu") var_names = pd.Index(var_names) if return_reference_var_names: