Spatial AlignmentProblem for spatial omics slides with different numbers of cells shuffles cells randomly #801

Open
hspitzer opened this issue Feb 19, 2025 · 1 comment

@hspitzer

Hi,

I've been experimenting with this really nice package for the last few days to align spatial transcriptomics and spatial proteomics slides, so far within a single modality, though I'm also keen to try multimodal alignment. I have come across a potential bug, possibly related to #798, where the cells seem to be completely shuffled after alignment.

This only occurs when the slides I am aligning have different numbers of cells. In addition, I get an error about incompatible shapes when trying to align slides with the "star" policy. I believe these two things might be related, so I am putting both in the same issue, but I'm happy to split them up if that is more convenient.

I would be thankful for any pointers or help on this!

Here is a reproducible example (adapted from the spatial alignment example):

```python
from IPython.display import display  # `display` is only predefined inside Jupyter
from moscot import datasets
from moscot.problems.space import AlignmentProblem
import scanpy as sc
import squidpy as sq

adata = datasets.sim_align()
# add x coordinate to obs to visualise it later on
adata.obs['x'] = adata.obsm['spatial'][:, 0]
# subsample adata to get slightly different numbers of cells for every batch
# if we don't do this, the errors don't occur and everything behaves normally
sc.pp.subsample(adata, fraction=0.99)
display(adata.obs.groupby('batch').count())
sq.pl.spatial_scatter(adata, shape=None, library_id="batch", color=["batch", 'x'])
```

This results in the following image:

[Image: spatial scatter of the subsampled slides, colored by `batch` and `x`]

```python
ap = AlignmentProblem(adata=adata)
ap = ap.prepare(batch_key="batch", policy="sequential")
ap = ap.solve()
ap.align(key_added="spatial_warp", reference='0')
sq.pl.spatial_scatter(
    adata, shape=None, spatial_key="spatial_warp", library_key='batch', color=["batch", 'x']
)
```

Now, after the alignment, the `x` values are completely shuffled across cells, even though the points themselves look nicely aligned.

[Image: spatial scatter of the aligned coordinates in `spatial_warp`, colored by `batch` and `x`]
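One mechanism that would produce this symptom, sketched here under stated assumptions: the traceback further down shows that `align` builds its output with `np.vstack([aligned_maps[k] for k in self._policy._cat])`, i.e. one block of coordinates per batch, stacked in category order. If the rows of `adata` are not grouped by batch in that same order (we assume this can happen after `sc.pp.subsample`, but have not verified it), writing the stacked array back row by row gives cells the coordinates of other cells. The NumPy sketch uses an identity "alignment", so a correct result equals the input; `stack_by_category` and `scatter_back` are hypothetical helpers, not moscot functions.

```python
import numpy as np


def stack_by_category(coords, batch, categories):
    """Stack one coordinate block per batch, in category order (mirrors the vstack in the traceback)."""
    return np.vstack([coords[batch == c] for c in categories])


def scatter_back(stacked, batch, categories):
    """Hypothetical fix: write each stacked block back to the rows it was taken from."""
    out = np.empty_like(stacked)
    start = 0
    for c in categories:
        rows = np.flatnonzero(batch == c)
        out[rows] = stacked[start:start + len(rows)]
        start += len(rows)
    return out


categories = np.array(["0", "1", "2"])
# rows interleaved across batches (assumed situation after random subsampling)
batch = np.array(["1", "0", "2", "0", "1", "2"])
coords = np.arange(12.0).reshape(6, 2)

# identity "alignment": a correct result must reproduce `coords` row by row
stacked = stack_by_category(coords, batch, categories)
naive = stacked  # assumes stacked row i belongs to adata row i
fixed = scatter_back(stacked, batch, categories)

print(np.allclose(naive, coords))  # False: coordinates land on the wrong cells
print(np.allclose(fixed, coords))  # True
```

When every batch already forms a contiguous block in category order, the naive and reordered results coincide, which would be consistent with the unsubsampled data behaving normally.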

Second, here is the error that occurs when using the "star" policy:

```python
ap = AlignmentProblem(adata=adata)
ap = ap.prepare(batch_key="batch", policy="star", reference='0')
ap = ap.solve()
ap.align(key_added="spatial_warp", reference='0')
```
Error message:

```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/Users/hannah.spitzer/projects/collab_lipidomics/bugreport_moscot.ipynb Cell 11 line 1
----> 1 ap.align(key_added="spatial_warp",  reference='0')
      2 sq.pl.spatial_scatter(
      3     adata, shape=None, spatial_key="spatial_warp", library_key='batch', color=["batch", 'x']
      4 )

File ~/miniconda3/envs/analysis310/lib/python3.10/site-packages/moscot/problems/space/_mixins.py:138, in SpatialAlignmentMixin.align(self, reference, mode, spatial_key, key_added)
    135 if spatial_key is None:
    136     spatial_key = self.spatial_key
--> 138 aligned_maps, aligned_metadata = self._interpolate_scheme(
    139     reference=reference, mode=mode, spatial_key=spatial_key  # type: ignore[arg-type]
    140 )
    141 aligned_basis = np.vstack([aligned_maps[k] for k in self._policy._cat])
    143 if key_added is None:

File ~/miniconda3/envs/analysis310/lib/python3.10/site-packages/moscot/problems/space/_mixins.py:82, in SpatialAlignmentMixin._interpolate_scheme(self, reference, mode, spatial_key)
     79         steps[reference, start, False] = self._policy.plan(start=reference, end=start)
     81 for (start, end, forward), path in steps.items():
---> 82     tmap = self._interpolate_transport(path=path, scale_by_marginals=True)
     83     # make `tmap` to have shape `(m, n_ref)` and apply it to `src` of shape `(n_ref, d)`
     84     key, tmap = (start, tmap) if forward else (end, tmap.T)

File ~/miniconda3/envs/analysis310/lib/python3.10/site-packages/moscot/base/problems/_mixins.py:407, in AnalysisMixin._interpolate_transport(self, path, scale_by_marginals, **_)
    405 # TODO(@MUCDK, @giovp, discuss what exactly this function should do, seems like it could be more generic)
    406 fst, *rest = path
--> 407 return self.solutions[fst].chain([self.solutions[r] for r in rest], scale_by_marginals=scale_by_marginals)

File ~/miniconda3/envs/analysis310/lib/python3.10/site-packages/moscot/base/output.py:196, in BaseDiscreteSolverOutput.chain(self, outputs, scale_by_marginals)
    194 op = self.as_linear_operator(scale_by_marginals)
    195 for out in outputs:
--> 196     op *= out.as_linear_operator(scale_by_marginals)
    198 return op

File ~/miniconda3/envs/analysis310/lib/python3.10/site-packages/scipy/sparse/linalg/_interface.py:433, in LinearOperator.__mul__(self, x)
    432 def __mul__(self, x):
--> 433     return self.dot(x)

File ~/miniconda3/envs/analysis310/lib/python3.10/site-packages/scipy/sparse/linalg/_interface.py:457, in LinearOperator.dot(self, x)
    442 """Matrix-matrix or matrix-vector multiplication.
    443 
    444 Parameters
   (...)
    454 
    455 """
    456 if isinstance(x, LinearOperator):
--> 457     return _ProductLinearOperator(self, x)
    458 elif np.isscalar(x):
    459     return _ScaledLinearOperator(self, x)

File ~/miniconda3/envs/analysis310/lib/python3.10/site-packages/scipy/sparse/linalg/_interface.py:722, in _ProductLinearOperator.__init__(self, A, B)
    720     raise ValueError('both operands have to be a LinearOperator')
    721 if A.shape[1] != B.shape[0]:
--> 722     raise ValueError(f'cannot multiply {A} and {B}: shape mismatch')
    723 super().__init__(_get_dtype([A, B]),
    724                                              (A.shape[0], B.shape[1]))
    725 self.args = (A, B)

ValueError: cannot multiply <393x398 _CustomLinearOperator with dtype=float32> and <397x398 _CustomLinearOperator with dtype=float32>: shape mismatch
```
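The final `ValueError` is SciPy's shape check for a product of two `LinearOperator`s: `op *= ...` goes through `LinearOperator.dot`, as the traceback shows, which requires `A.shape[1] == B.shape[0]`. Both operators in the message have 398 columns; assuming 398 is the number of cells in the reference slide `'0'`, they look like two maps onto the reference, and composing them through the reference would need one of them transposed. The sketch below reproduces only the shape logic with dense stand-ins of the same shapes (values are arbitrary; nothing here is moscot code):

```python
import numpy as np
from scipy.sparse.linalg import aslinearoperator

# dense stand-ins with the shapes from the error message (values are arbitrary)
a = aslinearoperator(np.ones((393, 398), dtype=np.float32))
b = aslinearoperator(np.ones((397, 398), dtype=np.float32))

try:
    a.dot(b)  # inner dimensions 398 vs 397 do not match
    mismatch_raised = False
except ValueError as err:
    mismatch_raised = True
    print(err)  # SciPy's "cannot multiply ...: shape mismatch" message

# composing through the shared 398-cell side requires transposing one operator
composed = a.dot(b.T)  # (393 x 398) @ (398 x 397)
print(composed.shape)  # (393, 397)
```

With equal cell counts per slide, every such map is square, so both `a.dot(b)` and `a.dot(b.T)` are shape-valid and a missing transpose would not raise; that might explain why the error only appears after subsampling, though this is a guess.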
@giovp
Member

giovp commented Feb 24, 2025

Thanks @hspitzer for reporting! Indeed, it seems to be related to #798. Either something changed in the way the indices are handled, or it has to do with scipy's `LinearOperator`, which might have undergone some changes.
