Skip to content

Commit

Permalink
fixed ruff B028
Browse files Browse the repository at this point in the history
  • Loading branch information
alitinet committed Jul 19, 2024
1 parent d8b13c5 commit c25355b
Show file tree
Hide file tree
Showing 16 changed files with 760 additions and 652 deletions.
12 changes: 5 additions & 7 deletions src/multimil/data/_preprocessing.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import warnings
from typing import List, Optional, Union

import anndata as ad
import numpy as np
import pandas as pd


def organize_multiome_anndatas(
adatas: List[List[Union[ad.AnnData, None]]],
layers: Optional[List[List[Union[str, None]]]] = None,
adatas: list[list[ad.AnnData | None]],
layers: list[list[str | None]] | None = None,
):
"""Concatenate all the input anndata objects.
Expand All @@ -26,7 +25,6 @@ def organize_multiome_anndatas(
# TODO: add checks for layers
# TODO: add check that len of modalities is the same as len of losses, etc


# needed for scArches operation setup
datasets_lengths = {}
datasets_obs_names = {}
Expand All @@ -45,8 +43,8 @@ def organize_multiome_anndatas(
stacklevel=2,
)
# check that all adatas in the same modality have the same number of features
if (mod_length := modality_lengths.get(f'{mod}', None)) is None:
modality_lengths[f'{mod}'] = adata.shape[1]
if (mod_length := modality_lengths.get(f"{mod}", None)) is None:
modality_lengths[f"{mod}"] = adata.shape[1]
else:
if adata.shape[1] != mod_length:
raise ValueError(
Expand Down Expand Up @@ -78,7 +76,7 @@ def organize_multiome_anndatas(
for mod, modality_adatas in enumerate(adatas):
for i, adata in enumerate(modality_adatas):
if not isinstance(adata, ad.AnnData) and adata is None:
X_zeros = np.zeros((datasets_lengths[i], modality_lengths[f'{mod}']))
X_zeros = np.zeros((datasets_lengths[i], modality_lengths[f"{mod}"]))
adatas[mod][i] = ad.AnnData(X_zeros, dtype=X_zeros.dtype)
adatas[mod][i].obs_names = datasets_obs_names[i]
adatas[mod][i].var_names = modality_var_names[mod]
Expand Down
9 changes: 4 additions & 5 deletions src/multimil/dataloaders/_ann_dataloader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import copy
import itertools
from typing import Optional, Union

import numpy as np
import torch
Expand Down Expand Up @@ -34,7 +33,7 @@ def __init__(
batch_size: int,
min_size_per_class: int,
shuffle: bool = True,
drop_last: Union[bool, int] = True,
drop_last: bool | int = True,
shuffle_classes: bool = True,
):
if drop_last > batch_size:
Expand Down Expand Up @@ -168,9 +167,9 @@ def __init__(
indices=None,
batch_size=128,
min_size_per_class=None,
data_and_attributes: Optional[dict] = None,
drop_last: Union[bool, int] = True,
sampler: Optional[torch.utils.data.sampler.Sampler] = StratifiedSampler,
data_and_attributes: dict | None = None,
drop_last: bool | int = True,
sampler: torch.utils.data.sampler.Sampler | None = StratifiedSampler,
**data_loader_kwargs,
):
if adata_manager.adata is None:
Expand Down
6 changes: 2 additions & 4 deletions src/multimil/dataloaders/_data_splitting.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from typing import Optional

from scvi.data import AnnDataManager
from scvi.dataloaders import DataSplitter

from ..dataloaders._ann_dataloader import GroupAnnDataLoader
from multimil.dataloaders._ann_dataloader import GroupAnnDataLoader


# adjusted from scvi-tools
Expand Down Expand Up @@ -33,7 +31,7 @@ def __init__(
adata_manager: AnnDataManager,
group_column: str,
train_size: float = 0.9,
validation_size: Optional[float] = None,
validation_size: float | None = None,
use_gpu: bool = False,
**kwargs,
):
Expand Down
4 changes: 1 addition & 3 deletions src/multimil/distributions/_mmd.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import List, Optional

import torch


Expand All @@ -21,7 +19,7 @@ def gaussian_kernel(
self,
x: torch.Tensor,
y: torch.Tensor,
gamma: Optional[List[float]] = None,
gamma: list[float] | None = None,
) -> torch.Tensor:
"""Apply Guassian kernel.
Expand Down
3 changes: 1 addition & 2 deletions src/multimil/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from ._multivae import MultiVAE
from ._mil import MILClassifier
from ._multivae import MultiVAE
from ._multivae_mil import MultiVAE_MIL


__all__ = ["MultiVAE", "MILClassifier", "MultiVAE_MIL"]
Loading

0 comments on commit c25355b

Please sign in to comment.