Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Field specific view_anndata_setup #1315

Merged
merged 8 commits into from
Feb 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 15 additions & 19 deletions scvi/data/anndata/_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def get_state_registry(self, registry_key: str) -> attrdict:
]
)

def _view_summary_stats(self, console: rich.console.Console) -> None:
def _view_summary_stats(self) -> rich.table.Table:
"""Prints summary stats."""
t = rich.table.Table(title="Summary Statistics")
t.add_column(
Expand All @@ -277,9 +277,9 @@ def _view_summary_stats(self, console: rich.console.Console) -> None:
)
for stat_key, count in self.summary_stats.items():
t.add_row(stat_key, str(count))
console.print(t)
return t

def _view_data_registry(self, console: rich.console.Console) -> None:
def _view_data_registry(self) -> rich.table.Table:
"""Prints data registry."""
t = rich.table.Table(title="Data Registry")
t.add_column(
Expand All @@ -296,13 +296,6 @@ def _view_data_registry(self, console: rich.console.Console) -> None:
no_wrap=True,
overflow="fold",
)
t.add_column(
"State Registry Keys",
justify="center",
style="green",
no_wrap=True,
overflow="fold",
)

for registry_key, data_loc in self.data_registry.items():
attr_name = data_loc.attr_name
Expand All @@ -311,15 +304,11 @@ def _view_data_registry(self, console: rich.console.Console) -> None:
scvi_data_str = f"adata.{attr_name}"
else:
scvi_data_str = f"adata.{attr_name}['{attr_key}']"
t.add_row(registry_key, scvi_data_str)

state_registry = self.get_state_registry(registry_key)
state_registry_keys = ", ".join(state_registry.keys())

t.add_row(registry_key, scvi_data_str, state_registry_keys)

console.print(t)
return t

def view_registry(self) -> None:
def view_registry(self, hide_state_registries: bool = False) -> None:
"""Prints summary of the registry."""

version = self._registry[_constants._SCVI_VERSION_KEY]
Expand All @@ -328,5 +317,12 @@ def view_registry(self) -> None:
in_colab = "google.colab" in sys.modules
force_jupyter = None if not in_colab else True
console = rich.console.Console(force_jupyter=force_jupyter)
self._view_summary_stats(console)
self._view_data_registry(console)
console.print(self._view_summary_stats())
console.print(self._view_data_registry())

if not hide_state_registries:
for field in self.fields:
state_registry = self.get_state_registry(field.registry_key)
t = field.view_state_registry(state_registry)
if t is not None:
console.print(t)
23 changes: 20 additions & 3 deletions scvi/data/anndata/fields/_base_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
import pandas as pd
import rich
from anndata import AnnData

from scvi.data.anndata import _constants
Expand Down Expand Up @@ -94,8 +95,8 @@ def get_summary_stats(self, state_registry: dict) -> dict:
Parameters
----------
state_registry
Dictionary returned by `register_field`. Summary stats should always be a function
of information stored in this dictionary.
Dictionary returned by :meth:`~scvi.data.anndata.fields.BaseAnnDataField.register_field`.
Summary stats should always be a function of information stored in this dictionary.

Returns
-------
Expand All @@ -104,7 +105,23 @@ def get_summary_stats(self, state_registry: dict) -> dict:
This mapping is then combined with the mappings of other fields to make up
the summary stats mapping.
"""
return dict()

@abstractmethod
def view_state_registry(self, state_registry: dict) -> Optional[rich.table.Table]:
"""
Returns a :class:`rich.table.Table` summarizing a state registry produced by this field.

Parameters
----------
state_registry
Dictionary returned by :meth:`~scvi.data.anndata.fields.BaseAnnDataField.register_field`.
Printed summary should always be a function of information stored in this dictionary.

Returns
-------
state_registry_summary
Optional :class:`rich.table.Table` summarizing the ``state_registry``.
"""

def get_field_data(self, adata: AnnData) -> Union[np.ndarray, pd.DataFrame]:
"""Returns the requested data as determined by the field for a given AnnData object."""
Expand Down
4 changes: 4 additions & 0 deletions scvi/data/anndata/fields/_layer_field.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import warnings
from typing import Optional

import rich
from anndata import AnnData

from scvi.data._utils import _check_nonnegative_integers
Expand Down Expand Up @@ -88,3 +89,6 @@ def transfer_field(

def get_summary_stats(self, state_registry: dict) -> dict:
return state_registry.copy()

def view_state_registry(self, _state_registry: dict) -> Optional[rich.table.Table]:
return None
32 changes: 32 additions & 0 deletions scvi/data/anndata/fields/_obs_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Optional

import numpy as np
import rich
from anndata import AnnData
from pandas.api.types import CategoricalDtype

Expand Down Expand Up @@ -72,6 +73,9 @@ def transfer_field(
def get_summary_stats(self, _state_registry: dict) -> dict:
return {}

def view_state_registry(self, _state_registry: dict) -> Optional[rich.table.Table]:
return None


class CategoricalObsField(BaseObsField):
"""
Expand Down Expand Up @@ -163,3 +167,31 @@ def get_summary_stats(self, state_registry: dict) -> dict:
categorical_mapping = state_registry[self.CATEGORICAL_MAPPING_KEY]
n_categories = len(np.unique(categorical_mapping))
return {self.count_stat_key: n_categories}

def view_state_registry(self, state_registry: dict) -> Optional[rich.table.Table]:
source_key = state_registry[self.ORIGINAL_ATTR_KEY]
mapping = state_registry[self.CATEGORICAL_MAPPING_KEY]
t = rich.table.Table(title=f"{self.registry_key} State Registry")
t.add_column(
"Source Location",
justify="center",
style="dodger_blue1",
no_wrap=True,
overflow="fold",
)
t.add_column(
"Categories", justify="center", style="green", no_wrap=True, overflow="fold"
)
t.add_column(
"scvi-tools Encoding",
justify="center",
style="dark_violet",
no_wrap=True,
overflow="fold",
)
for i, cat in enumerate(mapping):
if i == 0:
t.add_row("adata.obs['{}']".format(source_key), str(cat), str(i))
else:
t.add_row("", str(cat), str(i))
return t
51 changes: 51 additions & 0 deletions scvi/data/anndata/fields/_obsm_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
import pandas as pd
import rich
from anndata import AnnData
from pandas.api.types import CategoricalDtype

Expand Down Expand Up @@ -151,6 +152,9 @@ def get_summary_stats(self, state_registry: dict) -> dict:
n_obsm_cols = len(state_registry[self.COLUMN_NAMES_KEY])
return {self.count_stat_key: n_obsm_cols}

def view_state_registry(self, state_registry: dict) -> Optional[rich.table.Table]:
return None


class JointObsField(BaseObsmField):
"""
Expand Down Expand Up @@ -234,6 +238,22 @@ def get_summary_stats(self, _state_registry: dict) -> dict:
n_obs_keys = len(self.obs_keys)
return {self.count_stat_key: n_obs_keys}

def view_state_registry(self, state_registry: dict) -> Optional[rich.table.Table]:
if self.is_empty:
return None

t = rich.table.Table(title=f"{self.registry_key} State Registry")
t.add_column(
"Source Location",
justify="center",
style="dodger_blue1",
no_wrap=True,
overflow="fold",
)
for key in state_registry[self.COLUMNS_KEY]:
t.add_row("adata.obs['{}']".format(key))
return t


class CategoricalJointObsField(JointObsField):
"""
Expand Down Expand Up @@ -330,3 +350,34 @@ def get_summary_stats(self, _state_registry: dict) -> dict:
return {
self.count_stat_key: n_obs_keys,
}

def view_state_registry(self, state_registry: dict) -> Optional[rich.table.Table]:
if self.is_empty:
return None

t = rich.table.Table(title=f"{self.registry_key} State Registry")
t.add_column(
"Source Location",
justify="center",
style="dodger_blue1",
no_wrap=True,
overflow="fold",
)
t.add_column(
"Categories", justify="center", style="green", no_wrap=True, overflow="fold"
)
t.add_column(
"scvi-tools Encoding",
justify="center",
style="dark_violet",
no_wrap=True,
overflow="fold",
)
for key, mappings in state_registry[self.MAPPINGS_KEY].items():
for i, mapping in enumerate(mappings):
if i == 0:
t.add_row("adata.obs['{}']".format(key), str(mapping), str(i))
else:
t.add_row("", str(mapping), str(i))
t.add_row("", "")
return t
6 changes: 4 additions & 2 deletions scvi/model/base/_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,9 @@ def setup_anndata(
on a model-specific instance of :class:`~scvi.data.anndata.AnnDataManager`.
"""

def view_anndata_setup(self, adata: Optional[AnnData] = None) -> None:
def view_anndata_setup(
self, adata: Optional[AnnData] = None, hide_state_registries: bool = False
) -> None:
"""
Print summary of the setup for the initial AnnData or a given AnnData object.

Expand All @@ -558,4 +560,4 @@ def view_anndata_setup(self, adata: Optional[AnnData] = None) -> None:
f"Given AnnData not setup with {self.__class__.__name__}. "
"Cannot view setup summary."
)
adata_manager.view_registry()
adata_manager.view_registry(hide_state_registries=hide_state_registries)
2 changes: 2 additions & 0 deletions tests/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,8 @@ def test_new_setup_compat():
continuous_covariate_keys=["cont1", "cont2"],
)
adata_manager = SCVI.manager_store[adata.uns[_constants._SCVI_UUID_KEY]]
model = SCVI(adata)
model.view_anndata_setup(hide_state_registries=True)

# Backwards compatibility test.
adata2_manager = manager_from_setup_dict(SCVI, adata2, LEGACY_SETUP_DICT)
Expand Down