update _cell_transition_online
selmanozleyen committed Oct 9, 2024
1 parent: eb431e2 · commit: 8f70c1e
Showing 1 changed file with 112 additions and 125 deletions.
237 changes: 112 additions & 125 deletions src/moscot/base/problems/_mixins.py
@@ -1,9 +1,11 @@
from __future__ import annotations

import types
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generic,
Iterable,
Literal,
@@ -175,6 +177,38 @@ def _cell_transition(
)
return tm

def _annotation_aggregation_transition(
self: AnalysisMixinProtocol[K, B],
annotations_1: list[Any],
annotations_2: list[Any],
df: pd.Series,
func: Callable[..., ArrayLike],
) -> pd.DataFrame:
n1 = len(annotations_1)
n2 = len(annotations_2)
tm_arr = np.zeros((n1, n2))

# Factorize the annotation values in `df`
codes, uniques = pd.factorize(df.values)
# Map annotations in 'annotations_2' to indices in 'uniques'
annotations_in_df_to_idx = {annotation: idx for idx, annotation in enumerate(uniques)}
annotations_2_codes = [annotations_in_df_to_idx.get(annotation, -1) for annotation in annotations_2]

for i, subset in enumerate(annotations_1):
result = func(
subset=subset,
)
# Compute sums over 'codes' weighted by 'result'
sums = np.bincount(codes, weights=result.squeeze(), minlength=len(uniques))
dist = [sums[code] if code != -1 else 0 for code in annotations_2_codes]
tm_arr[i, :] = dist

return pd.DataFrame(
tm_arr,
index=annotations_1,
columns=annotations_2,
)
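
The helper above turns a per-cell mass vector into per-annotation totals by pairing `pd.factorize` (label → integer code) with `np.bincount` (weighted sum per code). A minimal standalone sketch of that pattern, using made-up labels and masses rather than moscot data:

# Sketch of the factorize + bincount aggregation used by _annotation_aggregation_transition.
# Labels and mass values below are invented for illustration only.
import numpy as np
import pandas as pd

cell_annotations = pd.Series(["B", "T", "T", "NK", "B"])  # hypothetical per-cell labels
mass = np.array([0.10, 0.30, 0.20, 0.25, 0.15])           # hypothetical pushed/pulled mass per cell

codes, uniques = pd.factorize(cell_annotations.values)    # integer code per cell, unique labels
sums = np.bincount(codes, weights=mass, minlength=len(uniques))
print(dict(zip(uniques, sums)))                           # {'B': 0.25, 'T': 0.5, 'NK': 0.25}

Annotations requested in `annotations_2` but absent from `uniques` get code -1 via the `.get(annotation, -1)` lookup and are assigned zero mass in `dist`.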

def _cell_transition_online(
self: AnalysisMixinProtocol[K, B],
key: Optional[str],
@@ -208,9 +242,9 @@ def _cell_transition_online(
key if other_adata is None else other_key,
target,
)
df_source = df_source.rename(columns={source_annotation_key: "res_annotation"})
df_target = df_target.rename(columns={target_annotation_key: "res_annotation"})
res_annotation_key = "res_annotation"
df_source = df_source.rename(columns={source_annotation_key: res_annotation_key})
df_target = df_target.rename(columns={target_annotation_key: res_annotation_key})
source_annotations_verified, target_annotations_verified = _validate_annotations(
df_source=df_source,
df_target=df_target,
@@ -221,67 +255,46 @@ def _cell_transition_online(
aggregation_mode=aggregation_mode,
forward=forward,
)

move_op = self.push if forward else self.pull
move_op_const_kwargs = {
"source": source,
"target": target,
"normalize": True,
"return_all": False,
"scale_by_marginals": False,
"key_added": None,
}
if aggregation_mode == "annotation":
df_target["distribution"] = 0
df_source["distribution"] = 0
tm = pd.DataFrame(
np.zeros((len(source_annotations_verified), len(target_annotations_verified))),
index=source_annotations_verified,
columns=target_annotations_verified,
func = partial(
move_op,
data=source_annotation_key if forward else target_annotation_key,
split_mass=False,
**move_op_const_kwargs,
)
if forward:
tm = self._annotation_aggregation_transition( # type: ignore[attr-defined]
source=source,
target=target,
annotation_key=source_annotation_key,
annotations_1=source_annotations_verified,
annotations_2=target_annotations_verified,
df=df_target,
df_key=res_annotation_key,
tm=tm,
forward=True,
)
else:
tm = self._annotation_aggregation_transition( # type: ignore[attr-defined]
source=source,
target=target,
annotation_key=target_annotation_key,
annotations_1=target_annotations_verified,
annotations_2=source_annotations_verified,
df=df_source,
df_key=res_annotation_key,
tm=tm,
forward=False,
)
df = (df_target if forward else df_source)[res_annotation_key]
tm = self._annotation_aggregation_transition( # type: ignore[attr-defined]
annotations_1=source_annotations_verified if forward else target_annotations_verified,
annotations_2=target_annotations_verified if forward else source_annotations_verified,
df=df,
func=func,
)

elif aggregation_mode == "cell":
tm = pd.DataFrame(columns=target_annotations_verified if forward else source_annotations_verified)
if forward:
tm = self._cell_aggregation_transition( # type: ignore[attr-defined]
source=source,
target=target,
annotation_key=res_annotation_key,
annotations_1=source_annotations_verified,
annotations_2=target_annotations_verified,
df_1=df_target,
df_2=df_source,
tm=tm,
batch_size=batch_size,
forward=True,
)
else:
tm = self._cell_aggregation_transition( # type: ignore[attr-defined]
source=source,
target=target,
annotation_key=res_annotation_key,
annotations_1=target_annotations_verified,
annotations_2=source_annotations_verified,
df_1=df_source,
df_2=df_target,
tm=tm,
batch_size=batch_size,
forward=False,
)
func = partial(
move_op,
data=None,
split_mass=True,
**move_op_const_kwargs,
)
tm = self._cell_aggregation_transition( # type: ignore[attr-defined]
df_from=df_source if forward else df_target,
df_to=df_target if forward else df_source,
annotation_key=res_annotation_key,
annotations=target_annotations_verified if forward else source_annotations_verified,
batch_size=batch_size,
func=func,
)

else:
raise NotImplementedError(f"Aggregation mode `{aggregation_mode!r}` is not yet implemented.")
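
Both aggregation branches now share one calling convention: pick `self.push` or `self.pull` as `move_op`, freeze the arguments that never change with `functools.partial`, and let the aggregation helper supply only what varies per call (`subset`). A small sketch of that idea with a stand-in function; `fake_push` and its arguments are illustrative, not moscot's API:

# Sketch: freeze constant keyword arguments once, pass only the varying one per call.
from functools import partial

import numpy as np

def fake_push(source, target, subset=None, normalize=True, **_):
    # Stand-in for self.push / self.pull; returns a fake mass vector.
    rng = np.random.default_rng(0)
    mass = rng.random(subset[1] if subset is not None else 5)
    return mass / mass.sum() if normalize else mass

const_kwargs = {"source": 0, "target": 1, "normalize": True}
func = partial(fake_push, **const_kwargs)
print(func(subset=(0, 3)))  # each helper call only decides the subset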

@@ -480,77 +493,51 @@ def _flatten(self: AnalysisMixinProtocol[K, B], data: dict[K, ArrayLike], *, key
tmp[mask] = np.squeeze(v)
return tmp

def _annotation_aggregation_transition(
self: AnalysisMixinProtocol[K, B],
source: K,
target: K,
annotation_key: str,
annotations_1: list[Any],
annotations_2: list[Any],
df_key: str,
df: pd.DataFrame,
tm: pd.DataFrame,
forward: bool,
) -> pd.DataFrame:
if not forward:
tm = tm.T
func = self.push if forward else self.pull
for subset in annotations_1:
result = func( # TODO(@MUCDK) check how to make compatible with all policies
source=source,
target=target,
data=annotation_key,
subset=subset,
normalize=True,
return_all=False,
scale_by_marginals=False,
split_mass=False,
key_added=None,
)
df["distribution"] = result
cell_dist = df[df[df_key].isin(annotations_2)].groupby(df_key, observed=False).sum(numeric_only=True)
cell_dist /= cell_dist.sum()
tm.loc[subset, :] = [
cell_dist.loc[annotation, "distribution"] if annotation in cell_dist.distribution.index else 0
for annotation in annotations_2
]
return tm

def _cell_aggregation_transition(
self: AnalysisMixinProtocol[K, B],
source: str,
target: str,
annotation_key: str,
# TODO(MUCDK): unused variables, del below
annotations_1: list[Any],
annotations_2: list[Any],
df_1: pd.DataFrame,
df_2: pd.DataFrame,
tm: pd.DataFrame,
df_from: pd.DataFrame,
df_to: pd.DataFrame,
annotations: list[Any],
batch_size: Optional[int],
forward: bool,
func: Callable[..., ArrayLike],
) -> pd.DataFrame:
func = self.push if forward else self.pull

# Factorize annotations in df_to
annotations_in_df_to = df_to[annotation_key].values
codes_to, uniques_to = pd.factorize(annotations_in_df_to)
# Map annotations in 'annotations' to codes
annotations_to_code = {annotation: idx for idx, annotation in enumerate(uniques_to)}
annotations_codes = [annotations_to_code.get(annotation, -1) for annotation in annotations]
n_annotations = len(annotations)
n_from_cells = len(df_from)

if batch_size is None:
batch_size = len(df_2)
for batch in range(0, len(df_2), batch_size):
result = func( # TODO(@MUCDK) check how to make compatible with all policies
source=source,
target=target,
data=None,
subset=(batch, batch_size),
normalize=True,
return_all=False,
scale_by_marginals=False,
split_mass=True,
key_added=None,
)
current_cells = df_2.iloc[range(batch, min(batch + batch_size, len(df_2)))].index.tolist()
df_1.loc[:, current_cells] = result
to_app = df_1[df_1[annotation_key].isin(annotations_2)].groupby(annotation_key).sum().transpose()
tm = pd.concat([tm, to_app], verify_integrity=True, axis=0)
df_1 = df_1.drop(current_cells, axis=1)
return tm
batch_size = n_from_cells

tm_arr = np.zeros((n_from_cells, n_annotations))
index = df_from.index

# Process in batches
for batch_start in range(0, n_from_cells, batch_size):
batch_end = min(batch_start + batch_size, n_from_cells)
subset = (batch_start, batch_end - batch_start)
result = func(subset=subset)
# Result shape: (n_to_cells, batch_size)
# For each cell in the batch, we compute the sum over annotations
for i in range(batch_end - batch_start):
cell_distribution = result[:, i]
# Aggregate over annotations using bincount
sums = np.bincount(
codes_to,
weights=cell_distribution,
minlength=len(uniques_to),
)
# Map sums onto the requested annotation order via `annotations_codes`
dist = [sums[code] if code != -1 else 0 for code in annotations_codes]
tm_arr[batch_start + i, :] = dist

return pd.DataFrame(tm_arr, index=index, columns=annotations)
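
The rewritten cell-level aggregation walks the source cells in chunks of `batch_size`, asks the transport map for the distribution of each chunk (a matrix with one column per source cell), and bins every column by the target annotation codes. A self-contained sketch of that loop with a stand-in transport function; the shapes, labels, and `fake_transport` below are made up for illustration:

# Sketch of the batched cell aggregation: chunk the source cells, get a
# (n_to_cells, chunk_size) block per chunk, and bincount each column by annotation.
import numpy as np
import pandas as pd

n_from_cells, n_to_cells, batch_size = 7, 4, 3
to_labels = np.array(["B", "T", "T", "NK"])        # made-up target annotations
codes_to, uniques_to = pd.factorize(to_labels)

def fake_transport(subset):
    # Stand-in for the frozen push/pull call; subset = (start, length).
    start, length = subset
    rng = np.random.default_rng(start)
    return rng.random((n_to_cells, length))        # one column per source cell in the chunk

tm = np.zeros((n_from_cells, len(uniques_to)))
for start in range(0, n_from_cells, batch_size):
    length = min(batch_size, n_from_cells - start)
    block = fake_transport((start, length))
    for i in range(length):
        tm[start + i] = np.bincount(codes_to, weights=block[:, i], minlength=len(uniques_to))

print(pd.DataFrame(tm, columns=uniques_to).round(2))

Compared with the previous implementation, results are written into a preallocated array instead of repeatedly `pd.concat`-ing onto a growing frame, which avoids copying the accumulated table on every batch.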

# adapted from:
# https://github.com/theislab/cellrank/blob/master/cellrank/_utils/_utils.py#L392
