Add semi-structured sparse + dynamic int8 subclasses #36
Conversation
Can you move the benchmark_sam and other .py files to one of the other directories? Maybe make a torch/benchmarks dir?
int_data = w_int_repr.contiguous()
int_data = torch._cslt_compress(int_data)
Is it currently possible to replace this with
from torch.sparse import to_sparse_semi_structured
int_data = to_sparse_semi_structured(int_data)
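For context, a rough sketch of what that swap would look like; the tensor below is only a stand-in for w_int_repr (the real one comes from the quantization subclass), and it assumes a CUDA GPU that supports 2:4 semi-structured sparsity:

# Sketch only: stand-in int8 weight with a valid 2:4 sparsity pattern.
import torch
from torch.sparse import to_sparse_semi_structured

w_int_repr = torch.tensor([[0, 0, 1, 1]], dtype=torch.int8, device="cuda").tile(128, 64)  # (128, 256)
int_data = w_int_repr.contiguous()
int_data = to_sparse_semi_structured(int_data)  # backend-agnostic packing, instead of torch._cslt_compress(int_data)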
Leaving this one here because it's the special cuSPARSELt fuse-mul path, but I have changed the subclass to be backend agnostic (Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight).
Side note: do you care much about the naming convention? This name is so long I kind of want to change it to something simpler like QuantizedSemiSparseLinearWeight.
class Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight(QuantizedLinearWeightBase):
Is sparsity also implemented with a tensor subclass? I thought we should be able to compose them in some way?
We don't have nested subclassing support for tracing currently, so we can't compose them yet :( hence why we're landing in prototype.
I can tag you in the issue I'll make to raise this for core.
Approving for prototype. Thanks for sending this :D
afc9d3d to 4e0c8b3
…bs/ao into jcaip/quant+sparse_subclasses
This PR adds int8 dynamic quantization + semi-structured sparsity support to torchao.
This is implemented by extending the existing quantization subclasses to use sparse ops.
Ideally we would be able to compose subclasses and call to_sparse_semi_structured from inside the quantization subclass, but at the moment nested subclass tracing does not work with torch.compile, and for things like fusing scales into the sparse multiply you would probably want to implement it like this anyway. In particular, this PR adds two new subclasses:
For the cuSPARSELt subclass, I can extend Int8DynamicallyQuantizedWeightBase by storing the compressed representation in W_int_repr. FuseMulWeight will fuse one of the multiplies for the dequant into the cuSPARSELt matmul op. However, cuSPARSELt expects this in a float32 format, so this eats into our previous speedups, since we're now passing this as a bfloat16 tensor.
For the general subclass, however, I need to extend QuantizedLinearWeightBase, because I need to pass two tensors (packed and meta) for the CUTLASS sparse mm op. This relies on to_sparse_semi_structured to decide between CUTLASS and cuSPARSELt, which is the right choice for the UI but makes benchmarking between them kind of difficult, since it's a class var that decides which backend gets used. Maybe we should add a flag to to_sparse_semi_structured, because you might mix between CUTLASS and cuSPARSELt.
I've also added a benchmarking script for SAM. I don't know how we plan on handling dependencies in torchao, but let me know if there's a better place for that.
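For reference, the class-var switch mentioned above looks roughly like the sketch below. _FORCE_CUTLASS is a private attribute of SparseSemiStructuredTensor, so its name and default can change between PyTorch releases, and the weight here is just an illustrative fp16 2:4-sparse tensor:

# Sketch: the backend used by to_sparse_semi_structured is picked by a class variable,
# not by an argument, which is what makes side-by-side benchmarking awkward.
import torch
from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured

A = torch.Tensor([0, 0, 1, 1]).tile(128, 32).half().cuda()  # (128, 128), valid 2:4 pattern

SparseSemiStructuredTensor._FORCE_CUTLASS = True   # pack with the CUTLASS backend
A_cutlass = to_sparse_semi_structured(A)

SparseSemiStructuredTensor._FORCE_CUTLASS = False  # pack with the cuSPARSELt backend
A_cslt = to_sparse_semi_structured(A)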
On batch size 32, I see a 1.16x speedup over the bfloat16 torch.compile baseline, from 21.96 to 25.54 img/s.
Other benchmarks (BS=16)