Add blocksparse_int_addmm. Eliminate unnecessary contiguous calls, which leads to a performance increase. #891

Draft: wants to merge 2 commits into base: main
Conversation


@pearu pearu commented Sep 16, 2024

As in the title.

This PR is created on top of #821 and requires pytorch/pytorch#136104 .

The diff between the jcaip/int8-bsr and pearu/int8-bsr branches is shown below; a dense-reference sketch of the new op's semantics follows the diff.

diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 402948ad..d9e5e56f 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -1179,7 +1179,7 @@ def _linear_int8_act_int8_weight_block_sparse_impl(input_tensor, weight_tensor,
     w_vals = weight_tensor.layout_tensor
     w_scales = weight_tensor.layout_tensor.scale
     tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
-    tmp_t = tmp.t().contiguous()
+    tmp_t = tmp.t()
 
     # # Need to put this into custom op
     # weight_bsr = torch.sparse_bsr_tensor(w_vals.crow_indices(),
@@ -1197,21 +1197,15 @@ def _linear_int8_act_int8_weight_block_sparse_impl(input_tensor, weight_tensor,
     # input = torch.zeros(M, N, dtype=torch.int32, device=dense.device)
     # y = _int_bsr_dense_addmm(input, weight_bsr, tmp_t).t().contiguous()
 
-    y = torch.ops.blocksparse.int_mm(w_vals.crow_indices(),
-                                          w_vals.col_indices(),
-                                          w_vals.values(),
-                                          w_vals.shape[0],
-                                          w_vals.shape[1],
-                                          tmp_t)
 
-    # breakpoint()
-
-
-    y = x_scales.reshape(-1, 1) * y
-
-    y = (y * w_scales).reshape(
-        *x_vals_int8.shape[:-1], y.shape[-1]
-    )
+    y = torch.ops.blocksparse.int_addmm(w_vals.crow_indices(),
+                                        w_vals.col_indices(),
+                                        w_vals.values(),
+                                        tmp_t,
+                                        w_scales,
+                                        x_scales.reshape(-1))
+    y_shape = (*x_vals_int8.shape[:-1], w_scales.shape[-1])
+    y = y.reshape(*y_shape)
 
     # can downcast only at the very end
     output_dtype = input_tensor.dtype
diff --git a/torchao/sparsity/prototype/superblock/benchmark.py b/torchao/sparsity/prototype/superblock/benchmark.py
index c960eb7b..20df85f7 100644
--- a/torchao/sparsity/prototype/superblock/benchmark.py
+++ b/torchao/sparsity/prototype/superblock/benchmark.py
@@ -13,6 +13,7 @@ import torch.utils.data
 import utils
 from torch import nn
 from torch.sparse._triton_ops_meta import optimize_bsr_dense_addmm
+from torch.sparse._triton_ops_meta import dump as store_tuned_kernel_params
 from torchao.sparsity.prototype.superblock.utils import accelerate_with_sparsity, simulate_sparsity
 from torchao.utils import benchmark_model, profiler_runner
 
@@ -34,15 +35,30 @@ def main(args):
     # BSR kernel tuning
     if args.bsr and args.tune_kernel_params:
         print("Tuning kernel params")
+        kwargs = dict(
+            dtype=torch.int8 if args.quantization else dtype,
+            sparsity=args.sparsity_linear, verbose=True,
+            # per blocksparse_int_addmm:
+            alpha=1, beta=0, use_left_alpha=True, use_right_alpha=True,
+            # force tuning because existing tuning parameters are
+            # computed for use_left/right_alpha=False, however, it
+            # turns out that re-tuning for use_left/right_alpha=True
+            # leads to the same set of tuning parameters:
+            # force=True
+        )
         if args.model == "vit_b_16":
-            optimize_bsr_dense_addmm(3072, 768, 50432, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True)
-            optimize_bsr_dense_addmm(768, 3072, 50432, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True)
+            optimize_bsr_dense_addmm(3072, 768, 50432, args.bsr, args.bsr, **kwargs)
+            optimize_bsr_dense_addmm(768, 3072, 50432, args.bsr, args.bsr, **kwargs)
         elif args.model == "vit_h_14":
-            optimize_bsr_dense_addmm(5120, 1280, 65792, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True)
-            optimize_bsr_dense_addmm(1280, 5120, 65792, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True)
+            optimize_bsr_dense_addmm(5120, 1280, 65792, args.bsr, args.bsr, **kwargs)
+            optimize_bsr_dense_addmm(1280, 5120, 65792, args.bsr, args.bsr, **kwargs)
         else:
             raise NotImplementedError("Tuning kernel params for this model is not supported yet.")
-
+        # Warning: the following call will overwrite the source code
+        # of torch.sparse._triton_ops_meta (hence it is commented out
+        # by default) but when used, it enables reusing the tuned
+        # parameters in subsequent runs of this script:
+        # store_tuned_kernel_params()
     print("Creating model")
     model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes)

diff --git a/torchao/sparsity/prototype/superblock/blocksparse.py b/torchao/sparsity/prototype/superblock/blocksparse.py
index 5bf458ea..8e520ac9 100644
--- a/torchao/sparsity/prototype/superblock/blocksparse.py
+++ b/torchao/sparsity/prototype/superblock/blocksparse.py
@@ -5,7 +5,7 @@ from typing import Optional, Tuple, List, Dict, Any, Callable
 from torch.utils._python_dispatch import return_and_correct_aliasing
 from torchao.utils import TorchAOBaseTensor
 from torchao.quantization.quant_api import _get_linear_subclass_inserter
-from torch.sparse._triton_ops import bsr_dense_mm, _int_bsr_dense_addmm, broadcast_batch_dims
+from torch.sparse._triton_ops import bsr_dense_mm, _int_bsr_dense_addmm, broadcast_batch_dims, bsr_dense_addmm
 
 aten = torch.ops.aten
 
@@ -41,6 +41,31 @@ def blocksparse_int_mm_abstract(crow_indices: torch.Tensor, col_indices: torch.T
     new_shape = (A.shape[-1], M)
     return torch.empty(new_shape, dtype=torch.int8, device=A.device)
 
+
+@torch.library.custom_op("blocksparse::int_addmm", mutates_args=())
+def blocksparse_int_addmm(crow_indices: torch.Tensor,
+                          col_indices: torch.Tensor,
+                          values: torch.Tensor,
+                          A: torch.Tensor,
+                          left_alpha: torch.Tensor,
+                          right_alpha: torch.Tensor) -> torch.Tensor:
+    assert values.dtype == torch.int8
+    M = left_alpha.shape[-1]
+    K = A.shape[-2]
+    N = A.shape[-1]
+    weight_bsr = torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(M, K))
+    original_batch_dims_broadcasted = broadcast_batch_dims(blocksparse_int_addmm, weight_bsr, A)
+    out = A.new_empty(original_batch_dims_broadcasted + (M, N))
+    return bsr_dense_addmm(out, weight_bsr, A, alpha=1, beta=0, out=out, left_alpha=left_alpha, right_alpha=right_alpha).t()
+
+
+@torch.library.register_fake("blocksparse::int_addmm")
+def blocksparse_int_addmm_abstract(crow_indices: torch.Tensor, col_indices: torch.Tensor, values: torch.Tensor, A: torch.Tensor, left_alpha: torch.Tensor, right_alpha: torch.Tensor) -> torch.Tensor:
+    N = A.shape[-1]
+    M = left_alpha.shape[-1]
+    return torch.empty((N, M), dtype=torch.int8, device=A.device)
+
+
 # Subclass definition
 class BlockSparseTensor(TorchAOBaseTensor):
     bsr_crow_indices: Optional[torch.Tensor]

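For context, here is a minimal dense-tensor sketch (a reference of my own, not the Triton kernel and not part of the diff) of what the new fused blocksparse::int_addmm computes compared with the previous int_mm followed by two separate scalings; tensor names and sizes are illustrative only:

import torch

def int_addmm_reference(W_int8, X_int8_t, w_scales, x_scales):
    # Fused semantics, as I read the diff: left_alpha * (W @ A) * right_alpha,
    # then a transpose, matching bsr_dense_addmm(out, weight_bsr, A, alpha=1,
    # beta=0, left_alpha=w_scales, right_alpha=x_scales).t() above.
    acc = W_int8.to(torch.int32) @ X_int8_t.to(torch.int32)   # (M, tokens), int32 accumulator
    return (w_scales[:, None] * acc * x_scales[None, :]).t()  # (tokens, M), float

def int_mm_then_scale_reference(W_int8, X_int8_t, w_scales, x_scales):
    # The pre-PR path: an integer matmul followed by two separate scaling steps.
    y = (W_int8.to(torch.int32) @ X_int8_t.to(torch.int32)).t()  # blocksparse.int_mm
    y = x_scales.reshape(-1, 1) * y                               # per-token activation scale
    return y * w_scales                                           # per-output-channel weight scale

tokens, K, M = 8, 64, 32
X = torch.randint(-128, 127, (tokens, K), dtype=torch.int8)   # int8 activations
W = torch.randint(-128, 127, (M, K), dtype=torch.int8)        # int8 weights (dense stand-in for the BSR weight)
x_scales, w_scales = torch.rand(tokens), torch.rand(M)
assert torch.allclose(int_addmm_reference(W, X.t(), w_scales, x_scales),
                      int_mm_then_scale_reference(W, X.t(), w_scales, x_scales))

Besides folding the two scale multiplications into the matmul epilogue, the fused call accepts tmp.t() as a strided view, so the tmp.t().contiguous() copy is no longer materialized; this presumably accounts for part of the memory reduction reported below.
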
As a result, the following performance test

python torchao/sparsity/prototype/superblock/benchmark.py --model vit_h_14   --batch-size 256   --sparsity-linear 0.8   --sp-linear-tile-size 64 --bsr 64 --sparsity bsr --quantization  --tune-kernel-params

leads to

340.710 ms
2.935 img/s
Memory: 2592

which should be compared to the previous state:

581.503 ms
1.720 img/s
Memory: 5004
