-
Notifications
You must be signed in to change notification settings - Fork 346
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adapt all internal models to new setup (#1301)
* adapt LDA * adapt linearscvi * remove _get_var_names_from_setup_anndata * adapt peakvi * adapt autozi * adapt scanvi * fix scanvi test * fix totalvi test * fix dataloader tests * fix multiple cov tests * adapt condscvi * adapt destvi * adapt multivi * fix setup compat test * remove get_from_registry util * fix scanvi and peakvi scarches tests * fix backwards compat tests and default missing summary stat in models * address comment * Adapt all external models to new setup (#1302) * adapt cellassign * adapt gimvi * adapt solo model * adapt stereoscope
- Loading branch information
Showing
35 changed files
with
864 additions
and
654 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
from typing import Optional, Union | ||
|
||
import numpy as np | ||
from anndata import AnnData | ||
from pandas.api.types import CategoricalDtype | ||
|
||
from scvi.data.anndata._utils import _make_obs_column_categorical | ||
|
||
from ._obs_field import CategoricalObsField | ||
|
||
|
||
class LabelsWithUnlabeledObsField(CategoricalObsField): | ||
""" | ||
An AnnDataField for labels which include explicitly unlabeled cells. | ||
Remaps the unlabeled category to the final index if present in labels. | ||
The unlabeled category is a specific category name specified by the user. | ||
Parameters | ||
---------- | ||
registry_key | ||
Key to register field under in data registry. | ||
obs_key | ||
Key to access the field in the AnnData obs mapping. If None, defaults to `registry_key`. | ||
unlabeled_category | ||
Value assigned to unlabeled cells. | ||
""" | ||
|
||
UNLABELED_CATEGORY = "unlabeled_category" | ||
WAS_REMAPPED = "was_remapped" | ||
|
||
def __init__( | ||
self, | ||
registry_key: str, | ||
obs_key: Optional[str], | ||
unlabeled_category: Union[str, int, float], | ||
) -> None: | ||
super().__init__(registry_key, obs_key) | ||
self._unlabeled_category = unlabeled_category | ||
|
||
def _remap_unlabeled_to_final_category( | ||
self, adata: AnnData, mapping: np.ndarray | ||
) -> dict: | ||
labels = self._get_original_column(adata) | ||
|
||
if self._unlabeled_category in labels: | ||
unlabeled_idx = np.where(mapping == self._unlabeled_category) | ||
unlabeled_idx = unlabeled_idx[0][0] | ||
# move unlabeled category to be the last position | ||
mapping[unlabeled_idx], mapping[-1] = mapping[-1], mapping[unlabeled_idx] | ||
cat_dtype = CategoricalDtype(categories=mapping, ordered=True) | ||
# rerun setup for the batch column | ||
mapping = _make_obs_column_categorical( | ||
adata, | ||
self._original_attr_key, | ||
self.attr_key, | ||
categorical_dtype=cat_dtype, | ||
return_mapping=True, | ||
) | ||
remapped = True | ||
else: | ||
remapped = False | ||
|
||
return { | ||
self.CATEGORICAL_MAPPING_KEY: mapping, | ||
self.ORIGINAL_ATTR_KEY: self._original_attr_key, | ||
self.UNLABELED_CATEGORY: self._unlabeled_category, | ||
self.WAS_REMAPPED: remapped, | ||
} | ||
|
||
def register_field(self, adata: AnnData) -> dict: | ||
if self.is_default: | ||
self._setup_default_attr(adata) | ||
|
||
state_registry = super().register_field(adata) | ||
mapping = state_registry[self.CATEGORICAL_MAPPING_KEY] | ||
return self._remap_unlabeled_to_final_category(adata, mapping) | ||
|
||
def transfer_field( | ||
self, | ||
state_registry: dict, | ||
adata_target: AnnData, | ||
extend_categories: bool = False, | ||
**kwargs, | ||
) -> dict: | ||
transfer_state_registry = super().transfer_field( | ||
state_registry, adata_target, extend_categories=extend_categories, **kwargs | ||
) | ||
mapping = transfer_state_registry[self.CATEGORICAL_MAPPING_KEY] | ||
return self._remap_unlabeled_to_final_category(adata_target, mapping) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.