From 8f70c1ebe959ea69339c6d47f7ecf3f45a624147 Mon Sep 17 00:00:00 2001
From: selmanozleyen
Date: Wed, 9 Oct 2024 17:35:34 +0200
Subject: [PATCH] update _cell_transition_online

---
 src/moscot/base/problems/_mixins.py | 237 +++++++++++++---
 1 file changed, 112 insertions(+), 125 deletions(-)

diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py
index c87cb373..059a112b 100644
--- a/src/moscot/base/problems/_mixins.py
+++ b/src/moscot/base/problems/_mixins.py
@@ -1,9 +1,11 @@
 from __future__ import annotations
 
 import types
+from functools import partial
 from typing import (
     TYPE_CHECKING,
     Any,
+    Callable,
     Generic,
     Iterable,
     Literal,
@@ -175,6 +177,38 @@ def _cell_transition(
         )
         return tm
 
+    def _annotation_aggregation_transition(
+        self: AnalysisMixinProtocol[K, B],
+        annotations_1: list[Any],
+        annotations_2: list[Any],
+        df: pd.DataFrame,
+        func: Callable[..., ArrayLike],
+    ) -> pd.DataFrame:
+        n1 = len(annotations_1)
+        n2 = len(annotations_2)
+        tm_arr = np.zeros((n1, n2))
+
+        # Factorize the annotation labels in `df`
+        codes, uniques = pd.factorize(df.values)
+        # Map annotations in 'annotations_2' to indices in 'uniques'
+        annotations_in_df_to_idx = {annotation: idx for idx, annotation in enumerate(uniques)}
+        annotations_2_codes = [annotations_in_df_to_idx.get(annotation, -1) for annotation in annotations_2]
+
+        for i, subset in enumerate(annotations_1):
+            result = func(
+                subset=subset,
+            )
+            # Compute sums over 'codes' weighted by 'result'
+            sums = np.bincount(codes, weights=result.squeeze(), minlength=len(uniques))
+            dist = [sums[code] if code != -1 else 0 for code in annotations_2_codes]
+            tm_arr[i, :] = dist
+
+        return pd.DataFrame(
+            tm_arr,
+            index=annotations_1,
+            columns=annotations_2,
+        )
+
     def _cell_transition_online(
         self: AnalysisMixinProtocol[K, B],
         key: Optional[str],
@@ -208,9 +242,9 @@ def _cell_transition_online(
             key if other_adata is None else other_key,
             target,
         )
-        df_source = df_source.rename(columns={source_annotation_key: "res_annotation"})
-        df_target = df_target.rename(columns={target_annotation_key: "res_annotation"})
         res_annotation_key = "res_annotation"
+        df_source = df_source.rename(columns={source_annotation_key: res_annotation_key})
+        df_target = df_target.rename(columns={target_annotation_key: res_annotation_key})
         source_annotations_verified, target_annotations_verified = _validate_annotations(
             df_source=df_source,
             df_target=df_target,
@@ -221,67 +255,46 @@ def _cell_transition_online(
             aggregation_mode=aggregation_mode,
             forward=forward,
         )
-
+        move_op = self.push if forward else self.pull
+        move_op_const_kwargs = {
+            "source": source,
+            "target": target,
+            "normalize": True,
+            "return_all": False,
+            "scale_by_marginals": False,
+            "key_added": None,
+        }
         if aggregation_mode == "annotation":
-            df_target["distribution"] = 0
-            df_source["distribution"] = 0
-            tm = pd.DataFrame(
-                np.zeros((len(source_annotations_verified), len(target_annotations_verified))),
-                index=source_annotations_verified,
-                columns=target_annotations_verified,
+            func = partial(
+                move_op,
+                data=source_annotation_key if forward else target_annotation_key,
+                split_mass=False,
+                **move_op_const_kwargs,
             )
-            if forward:
-                tm = self._annotation_aggregation_transition(  # type: ignore[attr-defined]
-                    source=source,
-                    target=target,
-                    annotation_key=source_annotation_key,
-                    annotations_1=source_annotations_verified,
-                    annotations_2=target_annotations_verified,
-                    df=df_target,
-                    df_key=res_annotation_key,
-                    tm=tm,
-                    forward=True,
-                )
-            else:
-                tm = self._annotation_aggregation_transition(  # type: ignore[attr-defined]
-                    source=source,
-                    target=target,
-                    annotation_key=target_annotation_key,
-                    annotations_1=target_annotations_verified,
-                    annotations_2=source_annotations_verified,
-                    df=df_source,
-                    df_key=res_annotation_key,
-                    tm=tm,
-                    forward=False,
-                )
+            df = (df_target if forward else df_source)[res_annotation_key]
+            tm = self._annotation_aggregation_transition(  # type: ignore[attr-defined]
+                annotations_1=source_annotations_verified if forward else target_annotations_verified,
+                annotations_2=target_annotations_verified if forward else source_annotations_verified,
+                df=df,
+                func=func,
+            )
+
         elif aggregation_mode == "cell":
-            tm = pd.DataFrame(columns=target_annotations_verified if forward else source_annotations_verified)
-            if forward:
-                tm = self._cell_aggregation_transition(  # type: ignore[attr-defined]
-                    source=source,
-                    target=target,
-                    annotation_key=res_annotation_key,
-                    annotations_1=source_annotations_verified,
-                    annotations_2=target_annotations_verified,
-                    df_1=df_target,
-                    df_2=df_source,
-                    tm=tm,
-                    batch_size=batch_size,
-                    forward=True,
-                )
-            else:
-                tm = self._cell_aggregation_transition(  # type: ignore[attr-defined]
-                    source=source,
-                    target=target,
-                    annotation_key=res_annotation_key,
-                    annotations_1=target_annotations_verified,
-                    annotations_2=source_annotations_verified,
-                    df_1=df_source,
-                    df_2=df_target,
-                    tm=tm,
-                    batch_size=batch_size,
-                    forward=False,
-                )
+            func = partial(
+                move_op,
+                data=None,
+                split_mass=True,
+                **move_op_const_kwargs,
+            )
+            tm = self._cell_aggregation_transition(  # type: ignore[attr-defined]
+                df_from=df_source if forward else df_target,
+                df_to=df_target if forward else df_source,
+                annotation_key=res_annotation_key,
+                annotations=target_annotations_verified if forward else source_annotations_verified,
+                batch_size=batch_size,
+                func=func,
+            )
+
         else:
             raise NotImplementedError(f"Aggregation mode `{aggregation_mode!r}` is not yet implemented.")
 
@@ -480,77 +493,51 @@ def _flatten(self: AnalysisMixinProtocol[K, B], data: dict[K, ArrayLike], *, key
             tmp[mask] = np.squeeze(v)
         return tmp
 
-    def _annotation_aggregation_transition(
-        self: AnalysisMixinProtocol[K, B],
-        source: K,
-        target: K,
-        annotation_key: str,
-        annotations_1: list[Any],
-        annotations_2: list[Any],
-        df_key: str,
-        df: pd.DataFrame,
-        tm: pd.DataFrame,
-        forward: bool,
-    ) -> pd.DataFrame:
-        if not forward:
-            tm = tm.T
-        func = self.push if forward else self.pull
-        for subset in annotations_1:
-            result = func(  # TODO(@MUCDK) check how to make compatible with all policies
-                source=source,
-                target=target,
-                data=annotation_key,
-                subset=subset,
-                normalize=True,
-                return_all=False,
-                scale_by_marginals=False,
-                split_mass=False,
-                key_added=None,
-            )
-            df["distribution"] = result
-            cell_dist = df[df[df_key].isin(annotations_2)].groupby(df_key, observed=False).sum(numeric_only=True)
-            cell_dist /= cell_dist.sum()
-            tm.loc[subset, :] = [
-                cell_dist.loc[annotation, "distribution"] if annotation in cell_dist.distribution.index else 0
-                for annotation in annotations_2
-            ]
-        return tm
-
     def _cell_aggregation_transition(
         self: AnalysisMixinProtocol[K, B],
-        source: str,
-        target: str,
         annotation_key: str,
-        # TODO(MUCDK): unused variables, del below
-        annotations_1: list[Any],
-        annotations_2: list[Any],
-        df_1: pd.DataFrame,
-        df_2: pd.DataFrame,
-        tm: pd.DataFrame,
+        df_from: pd.DataFrame,
+        df_to: pd.DataFrame,
+        annotations: list[Any],
         batch_size: Optional[int],
-        forward: bool,
+        func: Callable[..., ArrayLike],
     ) -> pd.DataFrame:
-        func = self.push if forward else self.pull
+
+        # Factorize annotations in df_to
+        annotations_in_df_to = df_to[annotation_key].values
+        codes_to, uniques_to = pd.factorize(annotations_in_df_to)
+        # Map annotations in 'annotations' to codes
+        annotations_to_code = {annotation: idx for idx, annotation in enumerate(uniques_to)}
+        annotations_codes = [annotations_to_code.get(annotation, -1) for annotation in annotations]
+        n_annotations = len(annotations)
+        n_from_cells = len(df_from)
+
         if batch_size is None:
-            batch_size = len(df_2)
-        for batch in range(0, len(df_2), batch_size):
-            result = func(  # TODO(@MUCDK) check how to make compatible with all policies
-                source=source,
-                target=target,
-                data=None,
-                subset=(batch, batch_size),
-                normalize=True,
-                return_all=False,
-                scale_by_marginals=False,
-                split_mass=True,
-                key_added=None,
-            )
-            current_cells = df_2.iloc[range(batch, min(batch + batch_size, len(df_2)))].index.tolist()
-            df_1.loc[:, current_cells] = result
-            to_app = df_1[df_1[annotation_key].isin(annotations_2)].groupby(annotation_key).sum().transpose()
-            tm = pd.concat([tm, to_app], verify_integrity=True, axis=0)
-            df_1 = df_1.drop(current_cells, axis=1)
-        return tm
+            batch_size = n_from_cells
+
+        tm_arr = np.zeros((n_from_cells, n_annotations))
+        index = df_from.index
+
+        # Process in batches
+        for batch_start in range(0, n_from_cells, batch_size):
+            batch_end = min(batch_start + batch_size, n_from_cells)
+            subset = (batch_start, batch_end - batch_start)
+            result = func(subset=subset)
+            # Result shape: (n_to_cells, batch_size)
+            # For each cell in the batch, aggregate its mass per annotation
+            for i in range(batch_end - batch_start):
+                cell_distribution = result[:, i]
+                # Aggregate over annotations using bincount
+                sums = np.bincount(
+                    codes_to,
+                    weights=cell_distribution,
+                    minlength=len(uniques_to),
+                )
+                # Map sums to the requested annotations
+                dist = [sums[code] if code != -1 else 0 for code in annotations_codes]
+                tm_arr[batch_start + i, :] = dist
+
+        return pd.DataFrame(tm_arr, index=index, columns=annotations)
 
     # adapted from:
     # https://github.com/theislab/cellrank/blob/master/cellrank/_utils/_utils.py#L392
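
Both new helpers replace the old pandas groupby/concat aggregation with pd.factorize plus np.bincount: annotation labels are encoded as integer codes once, each pushed or pulled mass vector is summed per code, and the sums are then reordered to the requested annotation list, with zero for labels absent from the data. A minimal, self-contained sketch of that aggregation step is given below; the toy data and variable names are hypothetical and not part of the patch.

import numpy as np
import pandas as pd

# Hypothetical toy inputs: one annotation label per target cell and the
# mass that a single push/pull assigned to each of those cells.
cell_annotations = pd.Series(["B", "T", "B", "NK", "T"])
mass = np.array([0.125, 0.25, 0.125, 0.25, 0.25])

# Encode labels as integer codes, then sum the mass per code.
codes, uniques = pd.factorize(cell_annotations.values)
sums = np.bincount(codes, weights=mass, minlength=len(uniques))

# Reorder the sums to a requested list of annotations; labels absent from
# the data (here "DC") get zero mass, mirroring the -1 sentinel in the patch.
requested = ["T", "B", "DC"]
code_of = {a: i for i, a in enumerate(uniques)}
row = [float(sums[code_of[a]]) if a in code_of else 0.0 for a in requested]
print(dict(zip(requested, row)))  # {'T': 0.5, 'B': 0.25, 'DC': 0.0}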