
2:4 Sparse GEMM API #1728

Open
tsengalb99 opened this issue Feb 18, 2025 · 3 comments

@tsengalb99

I have a setup with dense matrices $A, B$ and a 2:4 sparsity mask $M$. Is there an API in torchao where I can perform $A (B \odot M)^T$ and get the speedups from 2:4 GEMMs? That is, instead of having a pre-sparsified matrix $B$, I want to apply $M$ to $B$ online and then do the sparse GEMM.
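For concreteness, the dense reference for the computation being asked about can be sketched in plain NumPy. The `make_24_mask` helper here is illustrative only (not a torchao API), and assumes 2:4 groups run along the last axis of $B$:

```python
import numpy as np

def make_24_mask(B):
    """Build a 2:4 mask: in every contiguous group of 4 entries along the
    last axis, keep the 2 largest-magnitude entries and zero the rest."""
    rows, cols = B.shape
    assert cols % 4 == 0
    groups = np.abs(B).reshape(rows, cols // 4, 4)
    # argsort is ascending, so the last two indices are the two largest.
    top2 = np.argsort(groups, axis=-1)[..., 2:]
    mask = np.zeros_like(groups, dtype=bool)
    np.put_along_axis(mask, top2, True, axis=-1)
    return mask.reshape(rows, cols)

rng = np.random.default_rng(0)
A = rng.standard_normal((8, 16))
B = rng.standard_normal((8, 16))
M = make_24_mask(B)

# Dense reference for A (B ⊙ M)^T that a 2:4 sparse GEMM should reproduce.
y_ref = A @ (B * M).T
```

A 2:4-aware kernel would skip the zeroed half of `B * M` instead of materializing it densely.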

@jcaip
Contributor

jcaip commented Feb 18, 2025

Hi @tsengalb99

We don't have a public API for this, but you may be able to hack something together with `semi_structured_sparsify_like`.

See https://github.com/pytorch/pytorch/blob/c9a15d980f249ad3697822476f658946d7907b44/test/test_sparse_semi_structured.py#L755 for an example of how to use the private API `torch._sparse_semi_structured_apply`.
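Semantically, "sparsify like" means taking the 2:4 nonzero layout of one tensor and imposing it on another. Since `torch._sparse_semi_structured_apply` is private and needs CUDA sparse tensor cores, here is a NumPy sketch of just the semantics; `sparsify_like` is a hypothetical helper name, and the real torch API additionally packs the result into the compressed form the sparse GEMM kernels consume:

```python
import numpy as np

def sparsify_like(dense, pattern):
    """Zero out `dense` wherever `pattern` is zero, so `dense` inherits
    the 2:4 layout of `pattern` (semantic sketch only)."""
    return dense * (pattern != 0)

rng = np.random.default_rng(1)
# Make `pattern` 2:4 sparse by zeroing 2 of every 4 along the last axis.
pattern = rng.standard_normal((4, 8))
pattern.reshape(4, 2, 4)[..., :2] = 0.0

B = rng.standard_normal((4, 8))
B_sparse = sparsify_like(B, pattern)
```

`B_sparse` now has exactly the nonzero positions of `pattern`, which is the "apply a mask online, then do the sparse GEMM" step from the question.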

@tsengalb99
Author

tsengalb99 commented Feb 19, 2025

I tried doing the following:

```python
sm = torch.sparse.to_sparse_semi_structured(B * M)
y = torch.mm(sm, A.T).T
```

This sometimes works, but other times I get

```
NotImplementedError: `SparseSemiStructuredTensorCUSPARSELT` matmul: operation is not supported
```

raised from

```
[rank7]:   File "/home/alberttseng/miniconda3/lib/python3.12/site-packages/torch/sparse/_semi_structured_ops.py", line 122, in semi_sparse_mm
[rank7]:     res = A._mm(B_padded)
```

Do you know how to fix this?

@jcaip
Contributor

jcaip commented Feb 20, 2025

Do you have a script to repro @tsengalb99? I wouldn't expect a transient error here; I wonder what it could be.

In any case, I doubt your approach would be faster except on very large matrices. To be faster, I think you'd have to do something like what's outlined here. Note that this approach doesn't use `to_sparse_semi_structured`; it uses `torch._sparse_semi_structured_apply` instead.
