training acceleration via runtime semi-structured sparsity #184
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/184
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 02446fa with merge base 8a4e693.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```
# now you can run your normal training loop

# if you need to swap back from semi_sparse linear to normal linear, we provide a utility function
swap_semi_sparse_linear_with_linear(model)
```
What's up with having `_` at the end of the name though?
```diff
- swap_semi_sparse_linear_with_linear(model)
+ swap_semi_sparse_linear_with_linear_(model)
```
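For readers following along, here is a minimal sketch of how the two swap utilities fit together; the function names and the trailing-underscore (in-place) convention come from this PR, but the import path and exact call signatures are assumptions:

```python
import torch

# NOTE: import path and signatures are assumed for illustration; see the PR's README for the real ones.
from torchao.sparsity.prototype.training import (
    swap_linear_with_semi_sparse_linear_,
    swap_semi_sparse_linear_with_linear_,
)

model = torch.nn.Sequential(torch.nn.Linear(64, 64)).cuda().to(torch.bfloat16)

# swap nn.Linear -> SemiSparseLinear in place (trailing "_" = modifies `model` in place)
swap_linear_with_semi_sparse_linear_(model)

# ... run your normal training loop ...

# swap back to regular nn.Linear once training is done
swap_semi_sparse_linear_with_linear_(model)
```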
```
@@ -0,0 +1,53 @@
# Accelerated Sparse Training
```
why prototype? The API already seems quite nice - if it module-swaps on `nn.Linear`, it should be able to support most interesting models?
Will deprecating these APIs in the future be an issue then? I'm not sure that this `swap_linear` API is something that I want to commit to long term.
It's not an issue
```
    swap_semi_sparse_linear_with_linear_,
)

model = torch.nn.Sequential(torch.nn.Linear(64, 64)).cuda().to(torch.bfloat16)
```
does the example actually run faster? If not, try to have a minimal example that will run faster, plus a way of printing that speedup to the console. Also specify the supported SMs for those speedups, since if it's Ampere+ then 3090s and 4090s should also benefit from your work
Updated - the example will run faster, and I added the compute capability limitations earlier in the README. As for printing speedups, I don't think it's necessary since we have the benchmark script for that.

Modifying the benchmarks for the example, I see the following speedup (fw pass only, times in microseconds):

```
[------------------------------------------------ mlpfw -------------------------------------------------]
                                 |  act24  |  dense  |   w24   |  s24_inp_sparsify24  |  s24_inp_clone
1 threads: -----------------------------------------------------------------------------------------------
      f16 (44160,1024,4096,1024) |  4813.2 |  4031.0 |  3440.1 |        255.4         |     121.4
```
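For anyone who does want a quick console printout, a minimal sketch with `torch.utils.benchmark` could look like this; the swap call is the PR's API with an assumed signature, everything else is standard PyTorch:

```python
import torch
import torch.utils.benchmark as benchmark

def median_ms(model, x):
    # median wall-clock time of a forward pass, in milliseconds
    t = benchmark.Timer(stmt="model(x)", globals={"model": model, "x": x})
    return t.blocked_autorange().median * 1e3

# roughly the same MLP shape as the benchmark above
x = torch.randn(44160, 1024, dtype=torch.bfloat16, device="cuda")
dense = torch.nn.Sequential(torch.nn.Linear(1024, 4096)).cuda().to(torch.bfloat16)
sparse = torch.nn.Sequential(torch.nn.Linear(1024, 4096)).cuda().to(torch.bfloat16)
# swap_linear_with_semi_sparse_linear_(sparse)  # the PR's swap API; exact signature assumed

dense_ms, sparse_ms = median_ms(dense, x), median_ms(sparse, x)
print(f"dense: {dense_ms:.1f} ms | semi-sparse: {sparse_ms:.1f} ms | {dense_ms / sparse_ms:.2f}x")
```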
""" | ||
|
||
def forward(self, x): | ||
sparse_weight = semi_sparse_sparsify(self.weight, backend="cusparselt") |
nit on name, but how about `semi_sparsify`?
Sorry, this is a typo on my end - it should be `semi_structured_sparsify`
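In other words, the quoted line would read (sketch with the corrected name; rest of the method elided):

```python
    def forward(self, x):
        sparse_weight = semi_structured_sparsify(self.weight, backend="cusparselt")
        # ... rest of forward unchanged
```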
```
class _SparsifyFunc(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x: torch.Tensor, algo: str, backend: str):  # type: ignore[override]
```
could you put the supported backends in an enum?
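Something along these lines would do it; this is just a sketch, and the backend set (cuSPARSELt plus CUTLASS) is assumed from the `backend="cusparselt"` string in the quoted code rather than taken from this PR:

```python
import enum

class SparsifyBackend(str, enum.Enum):
    # assumed backend names; only "cusparselt" is visible in the quoted code
    CUSPARSELT = "cusparselt"
    CUTLASS = "cutlass"

# _SparsifyFunc.forward could then take `backend: SparsifyBackend` and still accept
# the existing string values by coercing them:
backend = SparsifyBackend("cusparselt")
assert backend is SparsifyBackend.CUSPARSELT
```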
```
)
else:
    if (
        tensor.compressed_swizzled_bitmask is None
```
n00b q: what's going on here?
```
    reference_sparse_tensor.compressed_swizzled_bitmask,
)

# Add pointwise ops to the dispatch table
```
how do you decide what to add here? It seems like you looked at what's most used in actual modeling code, but my question would be: why not all pointwise ops in PyTorch?
I just added what was necessary for our experiments. For some pointwise ops, like `mul`, you need to define the `sparsification_like_args_list`, so you cannot apply this naively to all pointwise ops.
But it would make sense to add support for all the naive ones - I can add that in a subsequent PR.
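To make the distinction concrete, a hypothetical version of such a dispatch table might look like the following; the handler and argument names here are made up for illustration and are not the PR's actual code:

```python
import functools
import torch

def _semi_sparse_pointwise(func, args, sparsify_like_args=()):
    # Placeholder handler: in the real implementation, the arguments listed in
    # `sparsify_like_args` would be re-sparsified with the shared 2:4 mask before
    # the dense op is applied to the packed values.
    return func(*args)

# "Naive" unary pointwise ops can share one generic handler, while ops like mul
# need to specify which operands must be sparsified with the same mask (what the
# PR calls sparsification_like_args_list).
SPARSE24_POINTWISE_DISPATCH = {
    torch.ops.aten.gelu.default: _semi_sparse_pointwise,
    torch.ops.aten.relu.default: _semi_sparse_pointwise,
    torch.ops.aten.mul.Tensor: functools.partial(
        _semi_sparse_pointwise, sparsify_like_args=(0, 1)
    ),
}
```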
```
swap_semi_sparse_linear_with_linear_(model_c)
for name, mod in model_c.named_modules():
    assert not isinstance(mod, SemiSparseLinear)
```
an interesting omission is that there's no compile support, but then you do have `allow_in_graph` calls in your code - so should we test compile support explicitly?
This should work with compile - I'll add a test
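A rough sketch of what that test could look like (the swap call uses the PR's API with an assumed signature):

```python
import unittest
import torch

class TestSemiSparseLinearCompile(unittest.TestCase):
    @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
    def test_compile_forward_backward(self):
        model = torch.nn.Sequential(torch.nn.Linear(64, 64)).cuda().to(torch.bfloat16)
        # swap_linear_with_semi_sparse_linear_(model)  # the PR's swap API; signature assumed
        compiled = torch.compile(model, fullgraph=True)
        x = torch.randn(64, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True)
        compiled(x).sum().backward()  # should run without graph breaks or errors

if __name__ == "__main__":
    unittest.main()
```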
@msaroufim @vkuzo I guess we missed it but this is the first technique we have in ao for training performance improvement :)
This PR adds support for training acceleration using the runtime semi-structured sparsity kernels that landed in core earlier: pytorch/pytorch#122350

It collects the necessary autograd functions to support training and packages them up in a replacement `nn.Linear` module, `SemiSparseLinear`, as well as a user API to swap out modules, `swap_linear_with_semi_sparse_linear_`. It also adds in some benchmarking code from xformers in order to measure the speedup of this module when applied to DINO shapes.

We have a blog post coming out with more details about how this works.

Testing:
```
python test/sparsity/test_fast_sparse_training.py
```

Benchmarking:
```
python benchmarks/benchmark_semi_sparse.py
```

For ViT-L MLP shapes we see the following results:
```
[------------------------------------------------ mlpfwbw -------------------------------------------------]
                                 |  act24   |  dense   |   w24   |  s24_inp_sparsify24  |  s24_inp_clone
1 threads: -------------------------------------------------------------------------------------------------
      f16 (44160,1024,4096,1024) |  11881.0 |  11534.3 |  9204.7 |        255.1         |     125.8

Times are in microseconds (us).
```
* install gguf
* moved model build to builder.py