Skip to content

Commit

Permalink
Pass manager to data loader (#1280)
Browse files Browse the repository at this point in the history
* pass manager to data loader

* address comment
  • Loading branch information
justjhong authored Nov 30, 2021
1 parent 5b79915 commit a8e4720
Show file tree
Hide file tree
Showing 11 changed files with 118 additions and 78 deletions.
2 changes: 1 addition & 1 deletion scvi/data/anndata/_compat.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from anndata import AnnData

from . import _constants
from ._manager import AnnDataManager
from .fields import (
CategoricalJointObsField,
CategoricalObsField,
LayerField,
NumericalJointObsField,
)
from .manager import AnnDataManager


def manager_from_setup_dict(
Expand Down
4 changes: 2 additions & 2 deletions scvi/data/anndata/fields/_layer_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(
else _constants._ADATA_ATTRS.LAYERS
)
self._attr_key = layer
self._is_count_data = is_count_data
self.is_count_data = is_count_data

@property
def registry_key(self):
Expand All @@ -56,7 +56,7 @@ def validate_field(self, adata: AnnData) -> None:
super().validate_field(adata)
x = self.get_field(adata)

if self._is_count_data and not _check_nonnegative_integers(x):
if self.is_count_data and not _check_nonnegative_integers(x):
logger_data_loc = (
"adata.X" if self.attr_key is None else f"adata.layers[{self.attr_key}]"
)
Expand Down
10 changes: 3 additions & 7 deletions scvi/data/anndata/_manager.py → scvi/data/anndata/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,14 @@ def add_field(self, field: Type[BaseAnnDataField]) -> None:
), "Fields have been frozen. Create a new AnnDataManager object for additional fields."
self.fields.add(field)

def _register_fields(
def register_fields(
self,
adata: AnnData,
source_setup_dict: Optional[dict] = None,
**transfer_kwargs
):
"""
Helper function with registers each field associated with this instance.
Registers each field associated with this instance with the AnnData object.
Either registers or transfers the setup from `source_setup_dict` if passed in.
Expand Down Expand Up @@ -114,10 +114,6 @@ def _register_fields(

self._assign_uuid()

def register_fields(self, adata: AnnData):
"""Registers each field associated with this instance with the AnnData object."""
return self._register_fields(adata)

def transfer_setup(
self, adata_target: AnnData, source_setup_dict: Optional[dict] = None, **kwargs
) -> AnnDataManager:
Expand All @@ -143,7 +139,7 @@ def transfer_setup(
)
fields = self.fields
new_adata_manager = self.__class__(fields)
new_adata_manager._register_fields(adata_target, setup_dict, **kwargs)
new_adata_manager.register_fields(adata_target, setup_dict, **kwargs)
return new_adata_manager

def get_adata_uuid(self) -> UUID:
Expand Down
17 changes: 10 additions & 7 deletions scvi/dataloaders/_ann_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
import logging
from typing import Optional, Union

import anndata
import numpy as np
import torch
from torch.utils.data import DataLoader

from scvi.data.anndata.manager import AnnDataManager

from ._anntorchdataset import AnnTorchDataset

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -91,8 +92,8 @@ class AnnDataLoader(DataLoader):
Parameters
----------
adata
An anndata objects
adata_manager
AnnDataManager object that has been created via setup_anndata.
shuffle
Whether the data should be shuffled
indices
Expand All @@ -109,7 +110,7 @@ class AnnDataLoader(DataLoader):

def __init__(
self,
adata: anndata.AnnData,
adata_manager: AnnDataManager,
shuffle=False,
indices=None,
batch_size=128,
Expand All @@ -118,11 +119,11 @@ def __init__(
**data_loader_kwargs,
):

if "_scvi" not in adata.uns.keys():
if adata_manager.adata is None:
raise ValueError("Please run setup_anndata() on your anndata object first.")

if data_and_attributes is not None:
data_registry = adata.uns["_scvi"]["data_registry"]
data_registry = adata_manager.get_data_registry()
for key in data_and_attributes.keys():
if key not in data_registry.keys():
raise ValueError(
Expand All @@ -131,7 +132,9 @@ def __init__(
)
)

self.dataset = AnnTorchDataset(adata, getitem_tensors=data_and_attributes)
self.dataset = AnnTorchDataset(
adata_manager.adata, getitem_tensors=data_and_attributes
)

sampler_kwargs = {
"batch_size": batch_size,
Expand Down
1 change: 0 additions & 1 deletion scvi/dataloaders/_anntorchdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ def setup_getitem(self):
----------
getitem_tensors:
Either a list of keys in the scvi data registry to return when getitem is called
or
Examples
--------
Expand Down
11 changes: 6 additions & 5 deletions scvi/dataloaders/_concat_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
from typing import List, Optional, Union

import numpy as np
from anndata import AnnData
from torch.utils.data import DataLoader

from scvi.data.anndata.manager import AnnDataManager

from ._ann_dataloader import AnnDataLoader


Expand All @@ -14,8 +15,8 @@ class ConcatDataLoader(DataLoader):
Parameters
----------
adata
AnnData object that has been registered via setup_anndata.
adata_manager
AnnDataManager object that has been created via setup_anndata.
indices_list
List where each element is a list of indices in the adata to load
shuffle
Expand All @@ -32,7 +33,7 @@ class ConcatDataLoader(DataLoader):

def __init__(
self,
adata: AnnData,
adata_manager: AnnDataManager,
indices_list: List[List[int]],
shuffle: bool = False,
batch_size: int = 128,
Expand All @@ -44,7 +45,7 @@ def __init__(
for indices in indices_list:
self.dataloaders.append(
AnnDataLoader(
adata,
adata_manager,
indices=indices,
shuffle=shuffle,
batch_size=batch_size,
Expand Down
Loading

0 comments on commit a8e4720

Please sign in to comment.