From b3b2b6de89b9cadd2e8753967dc532b288f3cf56 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Tue, 6 Feb 2024 12:08:45 -0800 Subject: [PATCH 01/26] wip --- benchmark_sam.py | 65 ++-- microbenchmark.py | 334 +++++++++++++++++++ results.txt | 12 + torchao/quantization/dynamic_quant_sparse.py | 154 ++------- torchao/quantization/quant_api.py | 11 +- torchao/quantization/subclass.py | 102 ++++-- 6 files changed, 493 insertions(+), 185 deletions(-) create mode 100644 microbenchmark.py create mode 100644 results.txt diff --git a/benchmark_sam.py b/benchmark_sam.py index 5ab3f28013..973a650ff2 100644 --- a/benchmark_sam.py +++ b/benchmark_sam.py @@ -4,16 +4,17 @@ from segment_anything import sam_model_registry from torch.utils.benchmark import Timer from torchao.sparsity import apply_fake_sparsity, apply_sparse - -from torchao.quantization.dynamic_quant_sparse import apply_int4_dynamic_quant_sparse +from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured sam_checkpoint_base_path = "/home/jessecai/local/MODELS" model_type = 'vit_h' model_name = 'sam_vit_h_4b8939.pth' checkpoint_path = f"{sam_checkpoint_base_path}/{model_name}" -batchsize = 16 +batchsize = 32 only_one_block = False +torch._dynamo.reset() + @torch.no_grad() def benchmark(f, *args, **kwargs): for _ in range(3): @@ -45,51 +46,49 @@ def get_sam_model(only_one_block=False, batchsize=1): torch._inductor.config.coordinate_descent_tuning = True torch._inductor.config.coordinate_descent_check_all_directions = True torch._inductor.config.force_fuse_int_mm_with_mul = True - -change_linear_weights_to_int8_dqtensors(model) +SparseSemiStructuredTensor._FORCE_CUTLASS = False +change_linear_weights_to_int8_dq_semi_structured_sparsetensors(model) model_c = torch.compile(model, mode='max-autotune') quant_res = benchmark(model_c, image) -print(f"bf16 compiled runtime of the final quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB") +print(f"bf16 cusparselt compiled runtime of the final quant + sparsified block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB") -del model_c, model, image +# del model, model_c model, image = get_sam_model(only_one_block, batchsize) model = model.to(torch.bfloat16) image = image.to(torch.bfloat16) -apply_sparse(model) torch._inductor.config.epilogue_fusion = False torch._inductor.config.coordinate_descent_tuning = True torch._inductor.config.coordinate_descent_check_all_directions = True torch._inductor.config.force_fuse_int_mm_with_mul = True +SparseSemiStructuredTensor._FORCE_CUTLASS = True +change_linear_weights_to_int8_dq_semi_structured_sparsetensors(model) model_c = torch.compile(model, mode='max-autotune') quant_res = benchmark(model_c, image) -print(f"bf16 compiled runtime of the final sparsified block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB") +print(f"bf16 cutlass compiled runtime of the final quant + sparsified block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB") -del model_c, model, image -model, image = get_sam_model(only_one_block, batchsize) -model = model.to(torch.bfloat16) -image = image.to(torch.bfloat16) -change_linear_weights_to_int8_dq_semi_structured_sparsetensors(model) -torch._inductor.config.epilogue_fusion = False -torch._inductor.config.coordinate_descent_tuning = True -torch._inductor.config.coordinate_descent_check_all_directions = True -torch._inductor.config.force_fuse_int_mm_with_mul = True -model_c = torch.compile(model, 
mode='max-autotune')
-quant_res = benchmark(model_c, image)
+# del model, model_c
+# model, image = get_sam_model(only_one_block, batchsize)
+# model = model.to(torch.bfloat16)
+# image = image.to(torch.bfloat16)
+# torch._inductor.config.epilogue_fusion = False
+# torch._inductor.config.coordinate_descent_tuning = True
+# torch._inductor.config.coordinate_descent_check_all_directions = True
+# torch._inductor.config.force_fuse_int_mm_with_mul = True
+# change_linear_weights_to_int8_dqtensors(model)
+# model_c = torch.compile(model, mode='max-autotune')
+# quant_res = benchmark(model_c, image)

-print(f"bf16 compiled runtime of the final quant + sparsified block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
+# print(f"bf16 compiled runtime of the final quant block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")

-del model_c, model, image
-model, image = get_sam_model(only_one_block, batchsize)
-model = model.to(torch.bfloat16)
-image = image.to(torch.bfloat16)
-apply_int4_dynamic_quant_sparse(model)
-torch._inductor.config.epilogue_fusion = False
-torch._inductor.config.coordinate_descent_tuning = True
-torch._inductor.config.coordinate_descent_check_all_directions = True
-torch._inductor.config.force_fuse_int_mm_with_mul = True
-model_c = torch.compile(model, mode='max-autotune')
-quant_res = benchmark(model_c, image)
+# del model, model_c
+# model, image = get_sam_model(only_one_block, batchsize)
+# model = model.to(torch.bfloat16)
+# image = image.to(torch.bfloat16)
+# SparseSemiStructuredTensor._FORCE_CUTLASS = True
+# apply_sparse(model)
+# model_c = torch.compile(model, mode='max-autotune')
+# quant_res = benchmark(model_c, image)

-print(f"bf16 compiled runtime of the final quant + sparsified block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
+# print(f"bf16 compiled runtime of the final sparsified block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
diff --git a/microbenchmark.py b/microbenchmark.py
new file mode 100644
index 0000000000..2198142831
--- /dev/null
+++ b/microbenchmark.py
@@ -0,0 +1,334 @@
+import argparse
+import random
+
+import pandas as pd
+import torch
+import torch.utils.benchmark as benchmark
+from torch import nn
+from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured
+from torch.ao.pruning import WeightNormSparsifier
+from tqdm import tqdm
+
+import math
+
+import torch
+import torch.nn.functional as F
+import itertools
+import torch.utils.benchmark as benchmark
+import math
+
+dtype = torch.float16
+device = "cuda"
+torch.manual_seed(42)
+
+
+torch.set_printoptions(
+    precision=2,
+    threshold=None,
+    edgeitems=16,
+    linewidth=480,
+    profile=None,
+    sci_mode=False,
+)
+
+def create_blocked_tensor(M, N, blocksize, sparsity):
+    assert sparsity <= 1.0 and sparsity >= 0.0, \
+        "sparsity should be a value between 0 and 1"
+    A = torch.bernoulli(torch.full((M//blocksize, N//blocksize),
+                        1 - sparsity, dtype=torch.bfloat16, device=device))
+    A = torch.repeat_interleave(A, blocksize, dim=0)
+    A = torch.repeat_interleave(A, blocksize, dim=1)
+    return A.contiguous()
+
+
+def create_24_tensor(M, N):
+    A = torch.randn(M, N, device="cuda")
+
+    choices = [[0, 1], [1, 0]]
+    mask_entries = [random.choice(choices) for i in range(M * N // 2)]
+
+    mask = torch.tensor(mask_entries).cuda().bool().reshape(M, N)
+
+    A.masked_fill_(~mask, 0)
+
+    return A.contiguous()
+
+
+def benchmark_in_us(f, *args, **kwargs):
+    t0 = 
benchmark.Timer( + stmt="f(*args, **kwargs)", + globals={"args": args, "kwargs": kwargs, "f": f} + ) + return int(t0.blocked_autorange().mean * 1e6) + + +def run_benchmark(input_shape, weight_shape, dtype, sparsity=None, backend=None, blocksize=None, sparsity_level=None): + + m, k = weight_shape + n, k = math.prod(input_shape[:-1]), input_shape[-1] + + if sparsity == "blocksparse": + A = create_blocked_tensor(m, k, blocksize=blocksize, sparsity=sparsity_level).to(dtype) + A_sparse = A.to_sparse_bsr(blocksize=blocksize) + + elif sparsity == "24": + # blocksize = 4 + # sparsity_level = 0.5 + if backend == "cutlass": + SparseSemiStructuredTensor._FORCE_CUTLASS = True + elif backend == "cusparselt": + SparseSemiStructuredTensor._FORCE_CUTLASS = False + else: + raise ValueError("Wrong value for backend") + + A = create_24_tensor(m, k).to(dtype) + A_sparse = to_sparse_semi_structured(A) + + # b = torch.randn(m, device="cuda").to(dtype) + x = torch.randn(n, k).to(dtype).cuda() + + + # get timing speedups + # handle int_mm custom + if dtype == torch.int8: + dense_time = benchmark_in_us(torch._int_mm, A, x.t()) + dense_output = torch._int_mm(A, x.t()).to(torch.float32).t() + else: + dense_time = benchmark_in_us(F.linear, x, A) + dense_output = F.linear(x, A).to(torch.float32) + + sparse_time = benchmark_in_us(F.linear, x, A_sparse) + sparse_output = F.linear(x, A_sparse).to(torch.float32) + + ratio = dense_time / sparse_time + + + if backend == "cusparselt": + # grab optimal alg id for cusparselt + padded = A_sparse._pad_tensor_for_matmul(x) + if dtype is torch.int8: + out_dtype = torch.bfloat16 + optimal_alg_id = torch._cslt_sparse_mm_search(A_sparse.compressed_tensor_cusparselt, padded.t()) + # print("optimal alg_id", optimal_alg_id) + else: + optimal_alg_id = None + + # sanity check correctness + correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3) + + # # in depth checks + # dense_output = F.linear(x.to(torch.float32), A.to(torch.float32)) + + # diff = ~torch.isclose(dense_output, sparse_output) + + # dense_output_diff = dense_output[diff] + # sparse_output_diff = sparse_output[diff] + + # sparse_output_diff_nonzero = sparse_output_diff.nonzero() + # dense_output_diff = dense_output_diff[sparse_output_diff_nonzero] + # sparse_output_diff = sparse_output_diff[sparse_output_diff_nonzero] + + # outside_atol = ~((dense_output_diff - sparse_output_diff).abs() < 1e-3) + + # larger_dense_output_diff = dense_output_diff[outside_atol] + # larger_sparse_output_diff = sparse_output_diff[outside_atol] + + # pos = (1 - (larger_dense_output_diff / larger_sparse_output_diff)).abs().argmax().item() + + return { + "dtype": str(dtype), + "m": m, + "k": k, + "n": n, + "sparse_latency (us)": sparse_time, + "dense_latency (us)": dense_time, + "speedup (d/s)": f"{ratio:.3f}", + "correct": correct, + # "sparse v dense diff": f"{larger_dense_output_diff[pos]:+11.7f} vs. 
{larger_sparse_output_diff[pos]:+11.7f}", + "sparsity type": sparsity, + "backend": backend, + "blocksize": blocksize, + "sparsity level": sparsity_level, + "optimal_alg_id": optimal_alg_id, + } + +if __name__ == "__main__": + dtype_lookup = { + "int8": torch.int8, + "fp16": torch.float16, + "bf16": torch.bfloat16, + "fp32": torch.float32, + } + + parser = argparse.ArgumentParser(description="GPU Sparsity Kernel Microbenchmarks") + parser.add_argument( + "--mode", + type=str, + choices=[ + "nvidia-bert", + "sam-shapes", + "nvidia-fixed-k", + "nvidia-fixed-mn", + "optimize-matmul-block-sparse", + ], + ) + parser.add_argument( + "--dtype", + type=str, + choices=dtype_lookup.keys(), + default="fp16", + ) + parser.add_argument("--backend", type=str, choices=["cutlass", "cusparselt"], default="cusparselt") + parser.add_argument("--function", type=str, choices=["linear", "mm"], default="linear") + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("-contiguous", action="store_true") + parser.add_argument("-save", action="store_true") + args = parser.parse_args() + + eval_fn = run_benchmark + + print(f"Started benchmark: {args.mode} | dtype: {args.dtype}") + dtype = dtype_lookup[args.dtype] + + if args.mode == "nvidia-bert": + bert_shapes = [ + (3072, 1024, 16384), + (4096, 1024, 16384), + (1024, 1024, 16384), + (1024, 4096, 16384), + ] + results = [ + eval_fn(m, k, n, dtype, sparsity="blocksparse", blocksize=64, sparsity_level=0.8) + for (m, k, n) in tqdm(bert_shapes) + ] + + results += [ + eval_fn(m, k, n, dtype, sparsity="24", backend="cusparselt") + for (m, k, n) in tqdm(bert_shapes) + ] + + if args.mode == "optimize-matmul-block-sparse": + batch_size = args.batch_size + + sam_shapes = [ + (torch.Size([batch_size, 64, 64, 1280]), torch.Size([5120, 1280])), + ] + + from collections import defaultdict + results = [] + total_runtime = defaultdict(int) + + for (activation_shape, weight_shape) in tqdm(sam_shapes): + for blocksize in [64]: + for sparsity_level in range(0, 100): + sparsity_level = float(sparsity_level) / 100 + result = run_benchmark( + activation_shape, + weight_shape, + dtype, + sparsity="blocksparse", + blocksize=blocksize, + sparsity_level=sparsity_level) + total_runtime[f"{blocksize}_{sparsity_level}"] += 32 * result["sparse_latency (us)"] + results.append(result) + + if args.mode == "sam-shapes": + batch_size = args.batch_size + + sam_shapes = [ + (torch.Size([batch_size, 64, 64, 1280]), torch.Size([5120, 1280])), + (torch.Size([batch_size, 64, 64, 5120]), torch.Size([1280, 5120])), + (torch.Size([25 * batch_size, 14, 14, 1280]), torch.Size([3840, 1280])), + (torch.Size([25 * batch_size, 14, 14, 1280]), torch.Size([1280, 1280])), + ] + + from collections import defaultdict + results = [] + total_runtime = defaultdict(int) + + for (activation_shape, weight_shape) in tqdm(sam_shapes): + for backend in ["cutlass", "cusparselt"]: + result = run_benchmark( + activation_shape, + weight_shape, + dtype, + sparsity="24", + backend=backend) + + blocksize = None + sparsity_level = 0.5 + total_runtime[f"{backend}"] += 32 * result["sparse_latency (us)"] + results.append(result) + # for blocksize in [64]: + # for sparsity_level in [0.8, 0.9]: + # result = run_benchmark( + # activation_shape, + # weight_shape, + # dtype, + # sparsity="blocksparse", + # blocksize=blocksize, + # sparsity_level=sparsity_level) + # total_runtime[f"{blocksize}_{sparsity_level}"] += 32 * result["sparse_latency (us)"] + # results.append(result) + + total_runtime["dense"] += 32 * 
result["dense_latency (us)"] + + for line in total_runtime: + print(line, total_runtime[line], sep="\t") + + + elif args.mode == "nvidia-fixed-k": + mn_vals = [ + 3072, + 4096, + 5120, + 6144, + 7168, + 8192, + 9216, + 10240, + 11264, + 12288, + 13312, + 14336, + 15360, + 16384, + 17408, + 18432, + 19456, + 20480, + ] + results = ( + eval_fn(mn, 10240, mn, dtype, args.contiguous, args.backend) + for mn in tqdm(mn_vals) + ) + + elif args.mode == "nvidia-fixed-mn": + k_vals = [ + 2560, + 3840, + 5120, + 6400, + 7680, + 8960, + 10240, + 11520, + 12800, + 14080, + 15360, + 16640, + 17920, + 19200, + 20480, + ] + results = ( + eval_fn(10240, k, 10240, dtype, args.contiguous, args.backend) + for k in tqdm(k_vals) + ) + + df = pd.DataFrame.from_records(results) + if args.save: + save_file = f"{args.mode}_{args.dtype}_{args.backend}.csv" + df.to_csv(save_file) + print(f"Finished benchmark: {args.mode} saved results to {save_file}") + print(df) diff --git a/results.txt b/results.txt new file mode 100644 index 0000000000..bba973deb8 --- /dev/null +++ b/results.txt @@ -0,0 +1,12 @@ +res.txt + + +module swapping (ao_benchmarks) | latency (bs32) +----------------------------------------------------------------------------- +Baseline (bf16 compiled) | 1636.584 | +Semi-structured sparse (cutlass bf16 compiled) | | 1325.59 +Semi-structured sparse (cusparselt bf16 compiled) | 1389.318 | 1316.92 +Dynamic quant (int8 compiled) | 1404.085 | 1319.46 +Semi-structured sparse + dynamic quant (int8 compiled) | 1370.230 | 1293.68 +24 sparse + dynamic quant (int8 compiled) + fuse dequant| 1278.547 | +cutlass 24 sparse + dynamic quant (int8 compiled) | | 1293.11 diff --git a/torchao/quantization/dynamic_quant_sparse.py b/torchao/quantization/dynamic_quant_sparse.py index 1489c96d94..fa2de12200 100644 --- a/torchao/quantization/dynamic_quant_sparse.py +++ b/torchao/quantization/dynamic_quant_sparse.py @@ -12,20 +12,26 @@ from torchao.quantization import quant_api from torchao.sparsity import apply_fake_sparsity +from torch.sparse import SparseSemiStructuredTensor + +import math + +FUSE_DEQUANT = False + # Quant + Sparse helper functinos def sparse_quant_int8_dynamic_per_token_linear( x, w_vals_int8, + w_meta_int32, w_scales, bias, out_dtype=torch.float32, - fuse_dequant=True, ): # like F.linear, but with int8 dynamic quantization of activation, # and a quantized weight x_vals_int8, x_scales = quantize_activation_per_token_absmax(x) mm_out = sparse_quant_int8_per_token_matmul( - x_vals_int8, x_scales, w_vals_int8, w_scales, out_dtype, fuse_dequant=fuse_dequant) + x_vals_int8, x_scales, w_vals_int8, w_meta_int32, w_scales, out_dtype) if bias is not None: mm_out += bias return mm_out @@ -34,133 +40,45 @@ def sparse_quant_int8_per_token_matmul( x_vals_int8, x_scales, w_vals_int8, + w_meta_int32, w_scales, out_dtype=torch.float32, - fuse_dequant=True, ): - # Quantized sparse matmul of int8 operands that accumulates to fp16 and returns - # out_dtype. This matmul uses cuSPARSELt as a backend. - - # Assumes that activation and weight quantization are symmetric, - # i.e. act_zp and w_zp is 0. - # Assumes that weight quantization is per-channel. - # NOTE: sparsity is only compatible with symmetric (zero-preserving) quantization techniques. 
- - # see - # https://github.com/google/gemmlowp/blob/master/doc/quantization.md - # for an overview of quantized matmul compute - - # in scalar form, assuming out_dtype is fp32 and zw == 0: - # - # Y_i_j_fp32 = sx * sw dot(X_i, W_j) - # assert x_vals_int8.dtype == torch.int8, \ f'x dtype {x_vals_int8.dtype} not yet supported' assert w_vals_int8.dtype == torch.int8, \ f'w dtype {w_vals_int8.dtype} not yet supported' - assert w_scales.dtype == out_dtype, \ - f'{w_scales.dtype} does not match {out_dtype}' - - # - # 1. do the matrix form of dot(X_i, W_j) - # + # assert w_scales.dtype == out_dtype, \ + # f'{w_scales.dtype} does not match {out_dtype}' + if w_meta_int32 is not None: + assert w_meta_int32.dtype == torch.int32, \ + f'{w_meta_int32.dtype} not yet supported' - # For sparse matmul, we need one of the input operands to be transposed. - # This is because cuSPARSELt only supports int8 matmul for specific formats: - # https://docs.nvidia.com/cuda/cusparselt/functions.html#matmul-descriptor-functions - # Because we currently only support the first input to the operand being sparse, - # we cannot transpose w_vals_int8, so instead we transpose x_vals_int8. tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]).contiguous() - # Since cuSPARSELt does not have support for int32 output, we instead use the fp16 kernel - # instead, by setting out_dtype. - # y_dot_fp16 = torch._sparse_semi_structured_linear(tmp, w_vals_int8, out_dtype=torch.float16) - y_dot_fp16 = torch._cslt_sparse_mm(w_vals_int8, tmp.t(), out_dtype=torch.float16).t() - y_dot_fp32 = y_dot_fp16.reshape(*x_vals_int8.shape[:-1], -1).to(out_dtype) - - # - # 2. rescale the output - # - # in cases with large matrices, y_dot_int32 can grow sufficiently - # large that y_dot_int32 * a float16 scale is greater than the maximum - # value of a float 16, (which results in a value of inf even if multiplying - # by the other scale would bring it within the expected range) - - # assert x_scales.dtype == torch.float, f"x_scales needs to be a torch.float32 but got {x_scales.dtype}" - - y = y_dot_fp32 * x_scales * w_scales + assert x_scales.dtype in [ + torch.float, + torch.bfloat16, + ], f"x_scales needs to be a torch.float32 or torch.bfloat16 but got {x_scales.dtype}" + + if w_meta_int32 is None: + # if FUSE_DEQUANT: + y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm(w_vals_int8, tmp.t(), alpha=w_scales.to(torch.float32), out_dtype=torch.bfloat16).t() + y_dot_bf16_w_scales_fused = y_dot_bf16_w_scales_fused.reshape(*x_vals_int8.shape[:-1], -1) + y = y_dot_bf16_w_scales_fused * x_scales + # else: + # y_dot_int32 = torch._cslt_sparse_mm(w_vals_int8, tmp.t(), out_dtype=torch.int32).t() + + # y = (y_dot_int32 * x_scales.reshape(-1, 1) * w_scales).reshape( + # *x_vals_int8.shape[:-1], y_dot_int32.shape[-1] + # ) + else: + y_dot_int32 = torch._sparse_semi_structured_linear(tmp, w_vals_int8, w_meta_int32.view(torch.int32), out_dtype=torch.int32) + + y = (y_dot_int32 * x_scales.reshape(-1, 1) * w_scales).reshape( + *x_vals_int8.shape[:-1], y_dot_int32.shape[-1] + ) # can downcast only at the very end y = y.to(out_dtype) return y - -class SparseDynamicallyPerAxisQuantizedLinear(torch.nn.Linear): - """ - This class is a replacement for `torch.nn.Linear`, implementing sparse dynamic quantization on - the input across all axes except for the last axis. 
- """ - - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True - ): - super().__init__(in_features, out_features, bias) - - def forward(self, X: torch.Tensor) -> torch.Tensor: - """ - Performs the forward pass of the sparse quantized linear layer. - - This method applies dynamic quantization to the input tensor across all axes except - the last axis using the `quant_int8_dynamic_per_token_linear` function. - - We artifically limit the quantization value to int4 range to ensure we stay within the range of fp16. - This method will use cuSPASRELt to perform sparse matmul. - - Args: - X (torch.Tensor): The input tensor to the sparse quantized linear layer. - Returns: - torch.Tensor: The output tensor after the sparse quantized matmul and rescale. - """ - Y = sparse_quant_int8_dynamic_per_token_linear( - X, self.W_int_repr, self.W_scales, self.bias, X.dtype, fuse_dequant=self.fuse_dequant) - return Y - - @classmethod - def from_float(cls, mod: torch.nn.Linear, fuse_dequant=True) -> 'SparseDynamicallyPerAxisQuantizedLinear': - """ - Converts a `mod` of class `torch.nn.Linear` to the sparse dynamically quantized version of it. - Note: this class does not require calibration. - Args: - mod (torch.nn.Linear): The original `torch.nn.Linear` module to convert. - Returns: - SparseDynamicallyPerAxisQuantizedLinear: The converted sparse quantized linear module. - """ - - # create the new module with a toy size to ensure initialization is fast - fake_in_features, fake_out_features = 8, 8 - new_mod = cls( - fake_in_features, fake_out_features, bias=mod.bias is not None) - new_mod.in_features = mod.in_features - new_mod.out_features = mod.out_features - # NOTE: We artifically clamp the values to int4 quantization to ensure we stay within the - # dynamic range of fp16 - W_int_repr, W_scales, _W_zps = dynamically_quantize_per_channel( - mod.weight, -8, 7, torch.int8) - new_mod.register_buffer('W_int_repr', torch._cslt_compress(W_int_repr.contiguous())) - new_mod.register_buffer('W_scales', W_scales) - new_mod.bias = mod.bias - new_mod.fuse_dequant = fuse_dequant - del new_mod.weight - - device_to_use = next(mod.parameters()).device - new_mod.to(device_to_use) - return new_mod - -def apply_int4_dynamic_quant_sparse(model, fuse_dequant=False): - apply_fake_sparsity(model) - quant_api._replace_with_custom_fn_if_matches_filter( - model, - partial(SparseDynamicallyPerAxisQuantizedLinear.from_float, fuse_dequant=fuse_dequant), - lambda mod, fqn: isinstance(mod, torch.nn.Linear)) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 1202ec399a..1585ee4f95 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -22,7 +22,8 @@ from .subclass import ( QuantizedLinearWeightBase, Int8DynamicallyQuantizedLinearWeight, - Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight, + Int8DynamicallyQuantized24CutlassLinearWeight, + Int8DynamicallyQuantized24CusparseltLinearWeight, Int8WeightOnlyQuantizedLinearWeight, Int4WeightOnlyQuantizedLinearWeight, ) @@ -159,9 +160,15 @@ def change_linear_weights_to_int4_woqtensors(model, **kwargs): def change_linear_weights_to_int8_dq_semi_structured_sparsetensors(model, **kwargs): filter_fn = kwargs.pop("filter_fn", _is_linear) + from torch.sparse import SparseSemiStructuredTensor + if SparseSemiStructuredTensor._FORCE_CUTLASS: + subclass = Int8DynamicallyQuantized24CutlassLinearWeight + else: + subclass = Int8DynamicallyQuantized24CusparseltLinearWeight + 
_replace_with_custom_fn_if_matches_filter( model, - _get_subclass_inserter(Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight, **kwargs), + _get_subclass_inserter(subclass, **kwargs), filter_fn, ) diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index f7047d8cbd..4e9e3d1b39 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -18,10 +18,13 @@ from .utils import find_multiple import warnings +from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured + __all__ = [ "Int8DynamicallyQuantizedLinearWeight", - "Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight", + "Int8DynamicallyQuantized24CutlassLinearWeight", + "Int8DynamicallyQuantized24CusparseltLinearWeight", "Int8WeightOnlyQuantizedLinearWeight", "Int4WeightOnlyQuantizedLinearWeight", ] @@ -277,23 +280,45 @@ def from_float(cls, input_float, qmin=-128, qmax=127): int_data, w_scales, False, input_float.shape, dtype=input_float.dtype ) -class Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight(QuantizedLinearWeightBase): + +class Int8DynamicallyQuantized24CusparseltLinearWeight(Int8DynamicallyQuantizedLinearWeight): @staticmethod - def __new__(cls, int_data, q_scales, transposed, shape, **kwargs): + def _quantized_op(act_mat, w_qtensor, bias): + return sparse_quant_int8_dynamic_per_token_linear( + act_mat, w_qtensor.int_data, None, w_qtensor.q_scales, bias, act_mat.dtype + ) + + @classmethod + def from_float(cls, input_float, qmin=-128, qmax=127): + + assert input_float.is_cuda + + w_int_repr, w_scales, _ = dynamically_quantize_per_channel( + input_float, qmin, qmax, torch.int8 + ) + + int_data = w_int_repr.contiguous() + + + int_data = torch._cslt_compress(int_data) + + return cls( + int_data, w_scales, False, input_float.shape, dtype=input_float.dtype, + ) + +class Int8DynamicallyQuantized24CutlassLinearWeight(QuantizedLinearWeightBase): + + @staticmethod + def __new__(cls, int_data, mask_meta, q_scales, transposed, shape, **kwargs): kwargs["dtype"] = kwargs.get("dtype", q_scales.dtype) return super().__new__(cls, int_data, transposed, shape, **kwargs) # type: ignore[attr-defined] - def __init__(self, int_data, q_scales, transposed, shape, **kwargs): + def __init__(self, int_data, mask_meta, q_scales, transposed, shape, **kwargs): self.q_scales = q_scales + self.mask_meta = mask_meta super().__init__(int_data, transposed) - @staticmethod - def _quantized_op(act_mat, w_qtensor, bias): - return sparse_quant_int8_dynamic_per_token_linear( - act_mat, w_qtensor.int_data, w_qtensor.q_scales, bias, act_mat.dtype - ) - def dequantize(self, dtype=None): """ Obtain the dequantized version of the quantized tensor subclass @@ -320,6 +345,7 @@ def to(self, *args, **kwargs): kwargs = self._get_to_kwargs(*args, **kwargs) return self.__class__( self.int_data.to(kwargs["device"]), + self.mask_meta.to(kwargs["device"]), self.q_scales.to(kwargs["device"]), self.transposed, self.shape, @@ -328,50 +354,62 @@ def to(self, *args, **kwargs): def _apply_fn_to_data(self, fn): return self.__class__( - fn(self.int_data), fn(self.q_scales), self.transposed, self.shape, dtype=self.dtype + fn(self.int_data), + fn(self.mask_meta), + fn(self.q_scales), + self.transposed, + self.shape, + dtype=self.dtype ) def _change_shape(self, shape): return self.__class__( - self.int_data, self.q_scales, self.transposed, shape, dtype=self.dtype + self.int_data, + self.mask_meta, + self.q_scales, + self.transposed, + shape, + dtype=self.dtype ) def __tensor_flatten__(self): - return 
["int_data", "q_scales"], [self.transposed, self.dtype, self.shape] + return ["int_data", "mask_meta", "q_scales"], [self.transposed, self.dtype, self.shape] @classmethod def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None): int_data, q_scales = tensor_data_dict["int_data"], tensor_data_dict["q_scales"] + mask_meta = tensor_data_dict["mask_meta"] transposed, dtype, shape = tensor_attributes - return cls(int_data, q_scales, transposed, shape if outer_size is None else outer_size, dtype=dtype, strides=outer_stride) + return cls(int_data, mask_meta, q_scales, transposed, shape if outer_size is None else outer_size, dtype=dtype, strides=outer_stride) + + @staticmethod + def _quantized_op(act_mat, w_qtensor, bias): + return sparse_quant_int8_dynamic_per_token_linear( + act_mat, w_qtensor.int_data, w_qtensor.mask_meta, w_qtensor.q_scales, bias, act_mat.dtype + ) @classmethod def from_float(cls, input_float, qmin=-128, qmax=127): - """ - Method used to convert a linear weight tensor to an instance of the - Int8DynamicallyQuantizedLinearWeight subclass. - Example usage:: + assert input_float.is_cuda - model.lin_mod.weight = ( - Int8DynamicallyQuantizedLinearWeight.from_float(model.lin_mod.weight) - ) - """ w_int_repr, w_scales, _ = dynamically_quantize_per_channel( input_float, qmin, qmax, torch.int8 ) - # the desired representation shape for fast quantized matmul is - # transposed compared to how it's stored as a linear weight, - # i.e. we want in_channels as dim=0 and out_channels (and quantized axis) as dim=1 - # however the external representation of our tensor will maintain the correct - # shape attribute which needs to be tracked directly. - int_data = w_int_repr.contiguous().t() - if cls is not Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight: - int_data = int_data.contiguous() - int_data = torch._cslt_compress(int_data) + int_data = w_int_repr.contiguous() + + sparse_tensor = to_sparse_semi_structured(int_data).t() + + if sparse_tensor.compressed_tensor_cusparselt is None: + int_data = sparse_tensor.sparse_tensor_cutlass + mask_meta = sparse_tensor.meta_tensor_cutlass + else: + int_data = sparse_tensor.compressed_tensor_cusparselt + mask_meta = None + return cls( - int_data, w_scales, False, input_float.shape, dtype=input_float.dtype + int_data, mask_meta, w_scales, False, input_float.shape, dtype=input_float.dtype ) From 5cc766ea678153a0e9531ae7b141d7fa33b02294 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 21 Mar 2024 13:28:01 -0700 Subject: [PATCH 02/26] test --- benchmark_sam.py | 94 ---------------- results.txt | 12 --- torchao/quantization/dynamic_quant_sparse.py | 100 ++++++++++-------- torchao/quantization/subclass.py | 8 +- .../sparsity/microbenchmarks.py | 61 +++++------ torchao/sparsity/{sparse.py => sparse_api.py} | 0 6 files changed, 90 insertions(+), 185 deletions(-) delete mode 100644 benchmark_sam.py delete mode 100644 results.txt rename microbenchmark.py => torchao/sparsity/microbenchmarks.py (86%) rename torchao/sparsity/{sparse.py => sparse_api.py} (100%) diff --git a/benchmark_sam.py b/benchmark_sam.py deleted file mode 100644 index 973a650ff2..0000000000 --- a/benchmark_sam.py +++ /dev/null @@ -1,94 +0,0 @@ -import torch -from torchao.quantization import change_linear_weights_to_int8_dqtensors -from torchao.quantization.quant_api import change_linear_weights_to_int8_dq_semi_structured_sparsetensors -from segment_anything import sam_model_registry -from torch.utils.benchmark import Timer -from 
torchao.sparsity import apply_fake_sparsity, apply_sparse -from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured - -sam_checkpoint_base_path = "/home/jessecai/local/MODELS" -model_type = 'vit_h' -model_name = 'sam_vit_h_4b8939.pth' -checkpoint_path = f"{sam_checkpoint_base_path}/{model_name}" -batchsize = 32 -only_one_block = False - -torch._dynamo.reset() - -@torch.no_grad() -def benchmark(f, *args, **kwargs): - for _ in range(3): - f(*args, **kwargs) - torch.cuda.synchronize() - - torch.cuda.reset_peak_memory_stats() - t0 = Timer( - stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f} - ) - res = t0.adaptive_autorange(.03, min_run_time=.2, max_run_time=20) - return {'time':res.median * 1e3, 'memory': torch.cuda.max_memory_allocated()/1e9} - -def get_sam_model(only_one_block=False, batchsize=1): - sam = sam_model_registry[model_type](checkpoint=checkpoint_path).cuda() - model = sam.image_encoder.eval() - image = torch.randn(batchsize, 3, 1024, 1024, device='cuda') - - # code to use just a single block of the model - if only_one_block: - model = model.blocks[0] - image = torch.randn(batchsize, 64, 64, 1280, device='cuda') - return model, image - -model, image = get_sam_model(only_one_block, batchsize) -model = model.to(torch.bfloat16) -image = image.to(torch.bfloat16) -torch._inductor.config.epilogue_fusion = False -torch._inductor.config.coordinate_descent_tuning = True -torch._inductor.config.coordinate_descent_check_all_directions = True -torch._inductor.config.force_fuse_int_mm_with_mul = True -SparseSemiStructuredTensor._FORCE_CUTLASS = False -change_linear_weights_to_int8_dq_semi_structured_sparsetensors(model) -model_c = torch.compile(model, mode='max-autotune') -quant_res = benchmark(model_c, image) - -print(f"bf16 cusparselt compiled runtime of the final quant + sparsified block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB") - -# del model, model_c -model, image = get_sam_model(only_one_block, batchsize) -model = model.to(torch.bfloat16) -image = image.to(torch.bfloat16) -torch._inductor.config.epilogue_fusion = False -torch._inductor.config.coordinate_descent_tuning = True -torch._inductor.config.coordinate_descent_check_all_directions = True -torch._inductor.config.force_fuse_int_mm_with_mul = True -SparseSemiStructuredTensor._FORCE_CUTLASS = True -change_linear_weights_to_int8_dq_semi_structured_sparsetensors(model) -model_c = torch.compile(model, mode='max-autotune') -quant_res = benchmark(model_c, image) - -print(f"bf16 cutlass compiled runtime of the final quant + sparsified block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB") - -# del model, model_c -# model, image = get_sam_model(only_one_block, batchsize) -# model = model.to(torch.bfloat16) -# image = image.to(torch.bfloat16) -# torch._inductor.config.epilogue_fusion = False -# torch._inductor.config.coordinate_descent_tuning = True -# torch._inductor.config.coordinate_descent_check_all_directions = True -# torch._inductor.config.force_fuse_int_mm_with_mul = True -# change_linear_weights_to_int8_dqtensors(model) -# model_c = torch.compile(model, mode='max-autotune') -# quant_res = benchmark(model_c, image) - -# print(f"bf16 compiled runtime of the final quant block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB") - -# del model, model_c -# model, image = get_sam_model(only_one_block, batchsize) -# model = model.to(torch.bfloat16) -# image = image.to(torch.bfloat16) -# 
SparseSemiStructuredTensor._FORCE_CUTLASS = True -# apply_sparse(model) -# model_c = torch.compile(model, mode='max-autotune') -# quant_res = benchmark(model_c, image) - -# print(f"bf16 compiled runtime of the final sparsified block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB") diff --git a/results.txt b/results.txt deleted file mode 100644 index bba973deb8..0000000000 --- a/results.txt +++ /dev/null @@ -1,12 +0,0 @@ -res.txt - - -module swapping (ao_benchmarks) | latency (bs32) ------------------------------------------------------------------------------ -Baseline (bf16 compiled) | 1636.584 | -Semi-structured sparse (cutlass bf16 compiled) | | 1325.59 -Semi-structured sparse (cusparselt bf16 compiled) | 1389.318 | 1316.92 -Dynamic quant (int8 compiled) | 1404.085 | 1319.46 -Semi-structured sparse + dynamic quant (int8 compiled) | 1370.230 | 1293.68 -24 sparse + dynamic quant (int8 compiled) + fuse dequant| 1278.547 | -cutlass 24 sparse + dynamic quant (int8 compiled) | | 1293.11 diff --git a/torchao/quantization/dynamic_quant_sparse.py b/torchao/quantization/dynamic_quant_sparse.py index fa2de12200..33538074b8 100644 --- a/torchao/quantization/dynamic_quant_sparse.py +++ b/torchao/quantization/dynamic_quant_sparse.py @@ -2,58 +2,86 @@ import torch.nn as nn from typing import Tuple, Optional -from functools import partial - from torchao.quantization.quant_primitives import ( dynamically_quantize_per_channel, quant_int8_dynamic_per_token_linear, quantize_activation_per_token_absmax ) -from torchao.quantization import quant_api -from torchao.sparsity import apply_fake_sparsity from torch.sparse import SparseSemiStructuredTensor -import math - -FUSE_DEQUANT = False - # Quant + Sparse helper functinos -def sparse_quant_int8_dynamic_per_token_linear( + +def sparse_quant_int8_dynamic_cutlass_linear( x, w_vals_int8, w_meta_int32, w_scales, bias, - out_dtype=torch.float32, + out_dtype, ): - # like F.linear, but with int8 dynamic quantization of activation, - # and a quantized weight x_vals_int8, x_scales = quantize_activation_per_token_absmax(x) - mm_out = sparse_quant_int8_per_token_matmul( + mm_out = sparse_quant_int8_cutlass_matmul( x_vals_int8, x_scales, w_vals_int8, w_meta_int32, w_scales, out_dtype) + + if bias is not None: + mm_out += bias + return mm_out + +def sparse_quant_int8_dynamic_cslt_linear( + x, + w_vals_int8, + w_scales, + bias, + out_dtype, +): + x_vals_int8, x_scales = quantize_activation_per_token_absmax(x) + mm_out = sparse_quant_int8_cslt_matmul( + x_vals_int8, x_scales, w_vals_int8, w_scales, out_dtype) + if bias is not None: mm_out += bias return mm_out -def sparse_quant_int8_per_token_matmul( + +def sparse_quant_int8_cslt_matmul( x_vals_int8, x_scales, w_vals_int8, - w_meta_int32, w_scales, - out_dtype=torch.float32, + out_dtype, ): - assert x_vals_int8.dtype == torch.int8, \ - f'x dtype {x_vals_int8.dtype} not yet supported' - assert w_vals_int8.dtype == torch.int8, \ - f'w dtype {w_vals_int8.dtype} not yet supported' - # assert w_scales.dtype == out_dtype, \ - # f'{w_scales.dtype} does not match {out_dtype}' - if w_meta_int32 is not None: - assert w_meta_int32.dtype == torch.int32, \ - f'{w_meta_int32.dtype} not yet supported' + assert x_vals_int8.dtype == torch.int8, f'x dtype {x_vals_int8.dtype} not yet supported' + assert w_vals_int8.dtype == torch.int8, f'w dtype {w_vals_int8.dtype} not yet supported' + assert w_scales.dtype == out_dtype, f'{w_scales.dtype} does not match {out_dtype}' + + tmp = x_vals_int8.reshape(-1, 
x_vals_int8.shape[-1]).contiguous() + + assert x_scales.dtype in [ + torch.float, + torch.bfloat16, + ], f"x_scales needs to be a torch.float32 or torch.bfloat16 but got {x_scales.dtype}" + + y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm(w_vals_int8, tmp.t(), alpha=w_scales, out_dtype=torch.bfloat16).t() + y = (y_dot_bf16_w_scales_fused* x_scales.reshape(-1, 1)).reshape( + *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1] + ) + y = y.to(out_dtype) + return y + +def sparse_quant_int8_cutlass_matmul( + x_vals_int8, + x_scales, + w_vals_int8, + w_meta_int32, + w_scales, + out_dtype, +): + assert x_vals_int8.dtype == torch.int8, f'x dtype {x_vals_int8.dtype} not yet supported' + assert w_vals_int8.dtype == torch.int8, f'w dtype {w_vals_int8.dtype} not yet supported' + assert w_scales.dtype == out_dtype, f'{w_scales.dtype} does not match {out_dtype}' + assert w_meta_int32.dtype == torch.int32, f'{w_meta_int32.dtype} not yet supported' tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]).contiguous() @@ -62,23 +90,9 @@ def sparse_quant_int8_per_token_matmul( torch.bfloat16, ], f"x_scales needs to be a torch.float32 or torch.bfloat16 but got {x_scales.dtype}" - if w_meta_int32 is None: - # if FUSE_DEQUANT: - y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm(w_vals_int8, tmp.t(), alpha=w_scales.to(torch.float32), out_dtype=torch.bfloat16).t() - y_dot_bf16_w_scales_fused = y_dot_bf16_w_scales_fused.reshape(*x_vals_int8.shape[:-1], -1) - y = y_dot_bf16_w_scales_fused * x_scales - # else: - # y_dot_int32 = torch._cslt_sparse_mm(w_vals_int8, tmp.t(), out_dtype=torch.int32).t() - - # y = (y_dot_int32 * x_scales.reshape(-1, 1) * w_scales).reshape( - # *x_vals_int8.shape[:-1], y_dot_int32.shape[-1] - # ) - else: - y_dot_int32 = torch._sparse_semi_structured_linear(tmp, w_vals_int8, w_meta_int32.view(torch.int32), out_dtype=torch.int32) - - y = (y_dot_int32 * x_scales.reshape(-1, 1) * w_scales).reshape( - *x_vals_int8.shape[:-1], y_dot_int32.shape[-1] - ) - # can downcast only at the very end + y_dot_int32 = torch._sparse_semi_structured_linear(tmp, w_vals_int8, w_meta_int32.view(torch.int32), out_dtype=torch.int32) + y = (y_dot_int32 * x_scales.reshape(-1, 1) * w_scales).reshape( + *x_vals_int8.shape[:-1], y_dot_int32.shape[-1] + ) y = y.to(out_dtype) return y diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index 4e9e3d1b39..12dd5753a5 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -14,7 +14,7 @@ quant_int8_dynamic_per_token_linear, unpack_tinygemm_scales_and_zeros, ) -from .dynamic_quant_sparse import sparse_quant_int8_dynamic_per_token_linear +from .dynamic_quant_sparse import sparse_quant_int8_dynamic_cutlass_linear, sparse_quant_int8_dynamic_cusparselt_linear from .utils import find_multiple import warnings @@ -285,8 +285,8 @@ class Int8DynamicallyQuantized24CusparseltLinearWeight(Int8DynamicallyQuantizedL @staticmethod def _quantized_op(act_mat, w_qtensor, bias): - return sparse_quant_int8_dynamic_per_token_linear( - act_mat, w_qtensor.int_data, None, w_qtensor.q_scales, bias, act_mat.dtype + return sparse_quant_int8_dynamic_cusparselt_linear( + act_mat, w_qtensor.int_data, w_qtensor.q_scales, bias, act_mat.dtype ) @classmethod @@ -384,7 +384,7 @@ def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=No @staticmethod def _quantized_op(act_mat, w_qtensor, bias): - return sparse_quant_int8_dynamic_per_token_linear( + return sparse_quant_int8_dynamic_cutlass_linear( act_mat, 
w_qtensor.int_data, w_qtensor.mask_meta, w_qtensor.q_scales, bias, act_mat.dtype ) diff --git a/microbenchmark.py b/torchao/sparsity/microbenchmarks.py similarity index 86% rename from microbenchmark.py rename to torchao/sparsity/microbenchmarks.py index 2198142831..8a909b9206 100644 --- a/microbenchmark.py +++ b/torchao/sparsity/microbenchmarks.py @@ -236,10 +236,8 @@ def run_benchmark(input_shape, weight_shape, dtype, sparsity=None, backend=None, batch_size = args.batch_size sam_shapes = [ - (torch.Size([batch_size, 64, 64, 1280]), torch.Size([5120, 1280])), - (torch.Size([batch_size, 64, 64, 5120]), torch.Size([1280, 5120])), - (torch.Size([25 * batch_size, 14, 14, 1280]), torch.Size([3840, 1280])), - (torch.Size([25 * batch_size, 14, 14, 1280]), torch.Size([1280, 1280])), + (torch.Size([batch_size, 256, 3072]), torch.Size([768, 3072])), + (torch.Size([batch_size, 256, 768]), torch.Size([3072, 768])), ] from collections import defaultdict @@ -247,35 +245,34 @@ def run_benchmark(input_shape, weight_shape, dtype, sparsity=None, backend=None, total_runtime = defaultdict(int) for (activation_shape, weight_shape) in tqdm(sam_shapes): - for backend in ["cutlass", "cusparselt"]: - result = run_benchmark( - activation_shape, - weight_shape, - dtype, - sparsity="24", - backend=backend) - - blocksize = None - sparsity_level = 0.5 - total_runtime[f"{backend}"] += 32 * result["sparse_latency (us)"] - results.append(result) - # for blocksize in [64]: - # for sparsity_level in [0.8, 0.9]: - # result = run_benchmark( - # activation_shape, - # weight_shape, - # dtype, - # sparsity="blocksparse", - # blocksize=blocksize, - # sparsity_level=sparsity_level) - # total_runtime[f"{blocksize}_{sparsity_level}"] += 32 * result["sparse_latency (us)"] - # results.append(result) - - total_runtime["dense"] += 32 * result["dense_latency (us)"] - - for line in total_runtime: - print(line, total_runtime[line], sep="\t") + # for backend in ["cutlass", "cusparselt"]: + # result = run_benchmark( + # activation_shape, + # weight_shape, + # dtype, + # sparsity="24", + # backend=backend) + + # blocksize = None + # sparsity_level = 0.5 + # total_runtime[f"{backend}"] += 32 * result["sparse_latency (us)"] + # results.append(result) + for blocksize in [8, 16, 32, 64]: + for sparsity_level in [0.8, 0.9]: + result = run_benchmark( + activation_shape, + weight_shape, + dtype, + sparsity="blocksparse", + blocksize=blocksize, + sparsity_level=sparsity_level) + # total_runtime[f"{blocksize}_{sparsity_level}"] += 32 * result["sparse_latency (us)"] + results.append(result) + + # total_runtime["dense"] += 32 * result["dense_latency (us)"] + # for line in total_runtime: + # print(line, total_runtime[line], sep="\t") elif args.mode == "nvidia-fixed-k": mn_vals = [ diff --git a/torchao/sparsity/sparse.py b/torchao/sparsity/sparse_api.py similarity index 100% rename from torchao/sparsity/sparse.py rename to torchao/sparsity/sparse_api.py From cd785b535330861002b176ba1464b65e68058481 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Wed, 27 Mar 2024 14:49:14 -0700 Subject: [PATCH 03/26] wip --- torchao/quantization/subclass.py | 23 ++++---- torchao/sparsity/benchmark_sam.py | 87 +++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+), 13 deletions(-) create mode 100644 torchao/sparsity/benchmark_sam.py diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index 12dd5753a5..51b46a1dce 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -14,11 +14,11 @@ 
quant_int8_dynamic_per_token_linear, unpack_tinygemm_scales_and_zeros, ) -from .dynamic_quant_sparse import sparse_quant_int8_dynamic_cutlass_linear, sparse_quant_int8_dynamic_cusparselt_linear +from .dynamic_quant_sparse import sparse_quant_int8_dynamic_cutlass_linear, sparse_quant_int8_dynamic_cslt_linear from .utils import find_multiple import warnings -from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured +from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured, SparseSemiStructuredTensorCUTLASS __all__ = [ @@ -285,7 +285,7 @@ class Int8DynamicallyQuantized24CusparseltLinearWeight(Int8DynamicallyQuantizedL @staticmethod def _quantized_op(act_mat, w_qtensor, bias): - return sparse_quant_int8_dynamic_cusparselt_linear( + return sparse_quant_int8_dynamic_cslt_linear( act_mat, w_qtensor.int_data, w_qtensor.q_scales, bias, act_mat.dtype ) @@ -398,18 +398,15 @@ def from_float(cls, input_float, qmin=-128, qmax=127): ) int_data = w_int_repr.contiguous() - - sparse_tensor = to_sparse_semi_structured(int_data).t() - - if sparse_tensor.compressed_tensor_cusparselt is None: - int_data = sparse_tensor.sparse_tensor_cutlass - mask_meta = sparse_tensor.meta_tensor_cutlass - else: - int_data = sparse_tensor.compressed_tensor_cusparselt - mask_meta = None + sparse_tensor = SparseSemiStructuredTensorCUTLASS.from_dense(int_data).t() return cls( - int_data, mask_meta, w_scales, False, input_float.shape, dtype=input_float.dtype + sparse_tensor.packed, + sparse_tensor.meta, + w_scales, + False, + input_float.shape, + dtype=input_float.dtype ) diff --git a/torchao/sparsity/benchmark_sam.py b/torchao/sparsity/benchmark_sam.py new file mode 100644 index 0000000000..9db6628c22 --- /dev/null +++ b/torchao/sparsity/benchmark_sam.py @@ -0,0 +1,87 @@ +import torch +from torchao.quantization import change_linear_weights_to_int8_dqtensors +from torchao.quantization import change_linear_weights_to_int8_dq_semi_structured_sparsetensors +from segment_anything import sam_model_registry +from torch.utils.benchmark import Timer + +sam_checkpoint_base_path = "/home/jessecai/local/MODELS" +model_type = 'vit_h' +model_name = 'sam_vit_h_4b8939.pth' +checkpoint_path = f"{sam_checkpoint_base_path}/{model_name}" +batchsize = 16 +only_one_block = False + + +@torch.no_grad() +def benchmark(f, *args, **kwargs): + for _ in range(3): + f(*args, **kwargs) + torch.cuda.synchronize() + + torch.cuda.reset_peak_memory_stats() + t0 = Timer( + stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f} + ) + res = t0.adaptive_autorange(.03, min_run_time=.2, max_run_time=20) + return {'time':res.median * 1e3, 'memory': torch.cuda.max_memory_allocated()/1e9} + +def get_sam_model(only_one_block=False, batchsize=1): + sam = sam_model_registry[model_type](checkpoint=checkpoint_path).cuda() + model = sam.image_encoder.eval() + image = torch.randn(batchsize, 3, 1024, 1024, device='cuda') + + # code to use just a single block of the model + if only_one_block: + model = model.blocks[0] + image = torch.randn(batchsize, 64, 64, 1280, device='cuda') + return model, image + +print("BENCHMARKING") + +try: + model, image = get_sam_model(False, batchsize) + model = model.to(torch.bfloat16) + image = image.to(torch.bfloat16) + model_c = torch.compile(model, mode='max-autotune') + quant_res = benchmark(model_c, image) + print(f"bf16 compiled runtime of the compiled full model is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB") + # bf16 compiled runtime of the compiled 
full model is 729.65ms and peak memory 23.96GB + + del model_c, model, image + torch._inductor.config.epilogue_fusion = False + torch._inductor.config.coordinate_descent_tuning = True + torch._inductor.config.coordinate_descent_check_all_directions = True + torch._inductor.config.force_fuse_int_mm_with_mul = True + model, image = get_sam_model(False, batchsize) + model = model.to(torch.bfloat16) + image = image.to(torch.bfloat16) + change_linear_weights_to_int8_dqtensors(model) + model_c = torch.compile(model, mode='max-autotune') + quant_res = benchmark(model_c, image) + print(f"bf16 compiled runtime of the quantized full model is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB") + # bf16 compiled runtime of the quantized full model is 677.28ms and peak memory 24.93GB + + del model_c, model, image + model, image = get_sam_model(False, batchsize) + model = model.to(torch.bfloat16) + image = image.to(torch.bfloat16) + from torch.sparse import SparseSemiStructuredTensor + SparseSemiStructuredTensor._FORCE_CUTLASS = False + change_linear_weights_to_int8_dq_semi_structured_sparsetensors(model) + model_c = torch.compile(model, mode='max-autotune') + quant_res = benchmark(model_c, image) + print(f"bf16 compiled runtime of the sparse + quantized full model is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB") + # bf16 compiled runtime of the quantized full model is 677.28ms and peak memory 24.93GB + + # del model_c, model, image + # model, image = get_sam_model(False, batchsize) + # model = model.to(torch.bfloat16) + # image = image.to(torch.bfloat16) + # change_linear_weights_to_int8_dq_semi_structured_sparsetensors(model) + # model_c = torch.compile(model, mode='max-autotune') + # quant_res = benchmark(model_c, image) + # print(f"bf16 compiled runtime of the sparse + quantized full model is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB") + # # bf16 compiled runtime of the quantized full model is 677.28ms and peak memory 24.93GB + +except Exception as e: + print("unable to run full model: ", e) From d084b91d46a56aa217d0a7c1ea10f271346b837c Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Fri, 29 Mar 2024 10:50:18 -0700 Subject: [PATCH 04/26] wip --- torchao/sparsity/benchmark_sam.py | 86 ++++++++++----------- torchao/sparsity/dynamic_quant_sparse.py | 98 ++++++++++++++++++++++++ 2 files changed, 140 insertions(+), 44 deletions(-) create mode 100644 torchao/sparsity/dynamic_quant_sparse.py diff --git a/torchao/sparsity/benchmark_sam.py b/torchao/sparsity/benchmark_sam.py index 9db6628c22..dec6300f4a 100644 --- a/torchao/sparsity/benchmark_sam.py +++ b/torchao/sparsity/benchmark_sam.py @@ -38,50 +38,48 @@ def get_sam_model(only_one_block=False, batchsize=1): print("BENCHMARKING") -try: - model, image = get_sam_model(False, batchsize) - model = model.to(torch.bfloat16) - image = image.to(torch.bfloat16) - model_c = torch.compile(model, mode='max-autotune') - quant_res = benchmark(model_c, image) - print(f"bf16 compiled runtime of the compiled full model is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB") - # bf16 compiled runtime of the compiled full model is 729.65ms and peak memory 23.96GB +model, image = get_sam_model(False, batchsize) +model = model.to(torch.bfloat16) +image = image.to(torch.bfloat16) +model_c = torch.compile(model, mode='max-autotune') +quant_res = benchmark(model_c, image) +print(f"bf16 compiled runtime of the compiled full model is {quant_res['time']:0.2f}ms and peak memory 
{quant_res['memory']: 0.2f}GB") +# bf16 compiled runtime of the compiled full model is 729.65ms and peak memory 23.96GB - del model_c, model, image - torch._inductor.config.epilogue_fusion = False - torch._inductor.config.coordinate_descent_tuning = True - torch._inductor.config.coordinate_descent_check_all_directions = True - torch._inductor.config.force_fuse_int_mm_with_mul = True - model, image = get_sam_model(False, batchsize) - model = model.to(torch.bfloat16) - image = image.to(torch.bfloat16) - change_linear_weights_to_int8_dqtensors(model) - model_c = torch.compile(model, mode='max-autotune') - quant_res = benchmark(model_c, image) - print(f"bf16 compiled runtime of the quantized full model is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB") - # bf16 compiled runtime of the quantized full model is 677.28ms and peak memory 24.93GB +del model_c, model, image +torch._inductor.config.epilogue_fusion = False +torch._inductor.config.coordinate_descent_tuning = True +torch._inductor.config.coordinate_descent_check_all_directions = True +torch._inductor.config.force_fuse_int_mm_with_mul = True +model, image = get_sam_model(False, batchsize) +model = model.to(torch.bfloat16) +image = image.to(torch.bfloat16) +change_linear_weights_to_int8_dqtensors(model) +model_c = torch.compile(model, mode='max-autotune') +quant_res = benchmark(model_c, image) +print(f"bf16 compiled runtime of the quantized full model is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB") +# bf16 compiled runtime of the quantized full model is 677.28ms and peak memory 24.93GB - del model_c, model, image - model, image = get_sam_model(False, batchsize) - model = model.to(torch.bfloat16) - image = image.to(torch.bfloat16) - from torch.sparse import SparseSemiStructuredTensor - SparseSemiStructuredTensor._FORCE_CUTLASS = False - change_linear_weights_to_int8_dq_semi_structured_sparsetensors(model) - model_c = torch.compile(model, mode='max-autotune') - quant_res = benchmark(model_c, image) - print(f"bf16 compiled runtime of the sparse + quantized full model is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB") - # bf16 compiled runtime of the quantized full model is 677.28ms and peak memory 24.93GB +del model_c, model, image +model, image = get_sam_model(False, batchsize) +model = model.to(torch.bfloat16) +image = image.to(torch.bfloat16) +from torch.sparse import SparseSemiStructuredTensor +SparseSemiStructuredTensor._FORCE_CUTLASS = True +change_linear_weights_to_int8_dq_semi_structured_sparsetensors(model) +model_c = torch.compile(model, mode='max-autotune') +quant_res = benchmark(model_c, image) +print(f"bf16 compiled runtime of the 2:4 sparse CUTLASS + quantized full model is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB") +# bf16 compiled runtime of the quantized full model is 677.28ms and peak memory 24.93GB - # del model_c, model, image - # model, image = get_sam_model(False, batchsize) - # model = model.to(torch.bfloat16) - # image = image.to(torch.bfloat16) - # change_linear_weights_to_int8_dq_semi_structured_sparsetensors(model) - # model_c = torch.compile(model, mode='max-autotune') - # quant_res = benchmark(model_c, image) - # print(f"bf16 compiled runtime of the sparse + quantized full model is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB") - # # bf16 compiled runtime of the quantized full model is 677.28ms and peak memory 24.93GB - -except Exception as e: - print("unable to run full 
model: ", e) +del model_c, model, image +model, image = get_sam_model(False, batchsize) +model = model.to(torch.bfloat16) +image = image.to(torch.bfloat16) +from torch.sparse import SparseSemiStructuredTensor +SparseSemiStructuredTensor._FORCE_CUTLASS = False +change_linear_weights_to_int8_dq_semi_structured_sparsetensors(model) +model_c = torch.compile(model, mode='max-autotune') +quant_res = benchmark(model_c, image) +print(f"bf16 compiled runtime of the 2:4 sparse cuSPARSELt + quantized full model is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB") +# bf16 compiled runtime of the quantized full model is 677.28ms and peak memory 24.93GB diff --git a/torchao/sparsity/dynamic_quant_sparse.py b/torchao/sparsity/dynamic_quant_sparse.py new file mode 100644 index 0000000000..33538074b8 --- /dev/null +++ b/torchao/sparsity/dynamic_quant_sparse.py @@ -0,0 +1,98 @@ +import torch +import torch.nn as nn +from typing import Tuple, Optional + +from torchao.quantization.quant_primitives import ( + dynamically_quantize_per_channel, + quant_int8_dynamic_per_token_linear, + quantize_activation_per_token_absmax +) + +from torch.sparse import SparseSemiStructuredTensor + +# Quant + Sparse helper functinos + +def sparse_quant_int8_dynamic_cutlass_linear( + x, + w_vals_int8, + w_meta_int32, + w_scales, + bias, + out_dtype, +): + x_vals_int8, x_scales = quantize_activation_per_token_absmax(x) + mm_out = sparse_quant_int8_cutlass_matmul( + x_vals_int8, x_scales, w_vals_int8, w_meta_int32, w_scales, out_dtype) + + if bias is not None: + mm_out += bias + return mm_out + +def sparse_quant_int8_dynamic_cslt_linear( + x, + w_vals_int8, + w_scales, + bias, + out_dtype, +): + x_vals_int8, x_scales = quantize_activation_per_token_absmax(x) + mm_out = sparse_quant_int8_cslt_matmul( + x_vals_int8, x_scales, w_vals_int8, w_scales, out_dtype) + + if bias is not None: + mm_out += bias + return mm_out + + +def sparse_quant_int8_cslt_matmul( + x_vals_int8, + x_scales, + w_vals_int8, + w_scales, + out_dtype, +): + + assert x_vals_int8.dtype == torch.int8, f'x dtype {x_vals_int8.dtype} not yet supported' + assert w_vals_int8.dtype == torch.int8, f'w dtype {w_vals_int8.dtype} not yet supported' + assert w_scales.dtype == out_dtype, f'{w_scales.dtype} does not match {out_dtype}' + + tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]).contiguous() + + assert x_scales.dtype in [ + torch.float, + torch.bfloat16, + ], f"x_scales needs to be a torch.float32 or torch.bfloat16 but got {x_scales.dtype}" + + y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm(w_vals_int8, tmp.t(), alpha=w_scales, out_dtype=torch.bfloat16).t() + y = (y_dot_bf16_w_scales_fused* x_scales.reshape(-1, 1)).reshape( + *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1] + ) + y = y.to(out_dtype) + return y + +def sparse_quant_int8_cutlass_matmul( + x_vals_int8, + x_scales, + w_vals_int8, + w_meta_int32, + w_scales, + out_dtype, +): + assert x_vals_int8.dtype == torch.int8, f'x dtype {x_vals_int8.dtype} not yet supported' + assert w_vals_int8.dtype == torch.int8, f'w dtype {w_vals_int8.dtype} not yet supported' + assert w_scales.dtype == out_dtype, f'{w_scales.dtype} does not match {out_dtype}' + assert w_meta_int32.dtype == torch.int32, f'{w_meta_int32.dtype} not yet supported' + + tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]).contiguous() + + assert x_scales.dtype in [ + torch.float, + torch.bfloat16, + ], f"x_scales needs to be a torch.float32 or torch.bfloat16 but got {x_scales.dtype}" + + y_dot_int32 = 
torch._sparse_semi_structured_linear(tmp, w_vals_int8, w_meta_int32.view(torch.int32), out_dtype=torch.int32) + y = (y_dot_int32 * x_scales.reshape(-1, 1) * w_scales).reshape( + *x_vals_int8.shape[:-1], y_dot_int32.shape[-1] + ) + y = y.to(out_dtype) + return y From 646a15791200b3f329512ff2c8b8f42fd5d80931 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Fri, 29 Mar 2024 10:50:38 -0700 Subject: [PATCH 05/26] wip --- torchao/quantization/dynamic_quant_sparse.py | 98 -------------------- torchao/quantization/subclass.py | 5 +- 2 files changed, 3 insertions(+), 100 deletions(-) delete mode 100644 torchao/quantization/dynamic_quant_sparse.py diff --git a/torchao/quantization/dynamic_quant_sparse.py b/torchao/quantization/dynamic_quant_sparse.py deleted file mode 100644 index 33538074b8..0000000000 --- a/torchao/quantization/dynamic_quant_sparse.py +++ /dev/null @@ -1,98 +0,0 @@ -import torch -import torch.nn as nn -from typing import Tuple, Optional - -from torchao.quantization.quant_primitives import ( - dynamically_quantize_per_channel, - quant_int8_dynamic_per_token_linear, - quantize_activation_per_token_absmax -) - -from torch.sparse import SparseSemiStructuredTensor - -# Quant + Sparse helper functinos - -def sparse_quant_int8_dynamic_cutlass_linear( - x, - w_vals_int8, - w_meta_int32, - w_scales, - bias, - out_dtype, -): - x_vals_int8, x_scales = quantize_activation_per_token_absmax(x) - mm_out = sparse_quant_int8_cutlass_matmul( - x_vals_int8, x_scales, w_vals_int8, w_meta_int32, w_scales, out_dtype) - - if bias is not None: - mm_out += bias - return mm_out - -def sparse_quant_int8_dynamic_cslt_linear( - x, - w_vals_int8, - w_scales, - bias, - out_dtype, -): - x_vals_int8, x_scales = quantize_activation_per_token_absmax(x) - mm_out = sparse_quant_int8_cslt_matmul( - x_vals_int8, x_scales, w_vals_int8, w_scales, out_dtype) - - if bias is not None: - mm_out += bias - return mm_out - - -def sparse_quant_int8_cslt_matmul( - x_vals_int8, - x_scales, - w_vals_int8, - w_scales, - out_dtype, -): - - assert x_vals_int8.dtype == torch.int8, f'x dtype {x_vals_int8.dtype} not yet supported' - assert w_vals_int8.dtype == torch.int8, f'w dtype {w_vals_int8.dtype} not yet supported' - assert w_scales.dtype == out_dtype, f'{w_scales.dtype} does not match {out_dtype}' - - tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]).contiguous() - - assert x_scales.dtype in [ - torch.float, - torch.bfloat16, - ], f"x_scales needs to be a torch.float32 or torch.bfloat16 but got {x_scales.dtype}" - - y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm(w_vals_int8, tmp.t(), alpha=w_scales, out_dtype=torch.bfloat16).t() - y = (y_dot_bf16_w_scales_fused* x_scales.reshape(-1, 1)).reshape( - *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1] - ) - y = y.to(out_dtype) - return y - -def sparse_quant_int8_cutlass_matmul( - x_vals_int8, - x_scales, - w_vals_int8, - w_meta_int32, - w_scales, - out_dtype, -): - assert x_vals_int8.dtype == torch.int8, f'x dtype {x_vals_int8.dtype} not yet supported' - assert w_vals_int8.dtype == torch.int8, f'w dtype {w_vals_int8.dtype} not yet supported' - assert w_scales.dtype == out_dtype, f'{w_scales.dtype} does not match {out_dtype}' - assert w_meta_int32.dtype == torch.int32, f'{w_meta_int32.dtype} not yet supported' - - tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]).contiguous() - - assert x_scales.dtype in [ - torch.float, - torch.bfloat16, - ], f"x_scales needs to be a torch.float32 or torch.bfloat16 but got {x_scales.dtype}" - - y_dot_int32 = 
torch._sparse_semi_structured_linear(tmp, w_vals_int8, w_meta_int32.view(torch.int32), out_dtype=torch.int32) - y = (y_dot_int32 * x_scales.reshape(-1, 1) * w_scales).reshape( - *x_vals_int8.shape[:-1], y_dot_int32.shape[-1] - ) - y = y.to(out_dtype) - return y diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index 51b46a1dce..ec7dc2edc9 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -18,7 +18,8 @@ from .utils import find_multiple import warnings -from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured, SparseSemiStructuredTensorCUTLASS +from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured +from torch.sparse.semi_structured import SparseSemiStructuredTensorCUTLASS __all__ = [ @@ -398,7 +399,7 @@ def from_float(cls, input_float, qmin=-128, qmax=127): ) int_data = w_int_repr.contiguous() - sparse_tensor = SparseSemiStructuredTensorCUTLASS.from_dense(int_data).t() + sparse_tensor = SparseSemiStructuredTensorCUTLASS.from_dense(int_data) return cls( sparse_tensor.packed, From f1a947f88f8110a1fe8ffd9ad3bbc6de98c07b26 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Fri, 29 Mar 2024 11:41:29 -0700 Subject: [PATCH 06/26] refactor --- torchao/quantization/quant_api.py | 26 ----- torchao/quantization/subclass.py | 135 ----------------------- torchao/sparsity/benchmark_sam.py | 2 +- torchao/sparsity/dynamic_quant_sparse.py | 133 +++++++++++++++++++++- torchao/sparsity/sparse_api.py | 19 ++++ 5 files changed, 152 insertions(+), 163 deletions(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index f381b799b0..5074231b15 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -26,16 +26,9 @@ from .utils import TORCH_VERSION_AFTER_2_4 from .subclass import ( -<<<<<<< HEAD QuantizedLinearWeightBase, Int8DynamicallyQuantizedLinearWeight, - Int8DynamicallyQuantized24CutlassLinearWeight, - Int8DynamicallyQuantized24CusparseltLinearWeight, - Int8WeightOnlyQuantizedLinearWeight, -======= ->>>>>>> main Int4WeightOnlyQuantizedLinearWeight, - Int8DynamicallyQuantizedLinearWeight, Int8WeightOnlyQuantizedLinearWeight, QuantizedLinearWeightBase, ) @@ -200,25 +193,6 @@ def change_linear_weights_to_int4_woqtensors(model, **kwargs): ) -<<<<<<< HEAD -def change_linear_weights_to_int8_dq_semi_structured_sparsetensors(model, **kwargs): - filter_fn = kwargs.pop("filter_fn", _is_linear) - - from torch.sparse import SparseSemiStructuredTensor - if SparseSemiStructuredTensor._FORCE_CUTLASS: - subclass = Int8DynamicallyQuantized24CutlassLinearWeight - else: - subclass = Int8DynamicallyQuantized24CusparseltLinearWeight - - _replace_with_custom_fn_if_matches_filter( - model, - _get_subclass_inserter(subclass, **kwargs), - filter_fn, - ) - - -======= ->>>>>>> main def swap_conv2d_1x1_to_linear(model, filter_fn=None): """ Changes all conv2d 1x1 modules to equivalent linear modules so that they can then be quantized. 
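For context on the helper being dropped from quant_api here (it moves into torchao/sparsity/sparse_api.py later in this commit): all of the change_linear_weights_to_* entry points rely on the same weight-swapping pattern, walking the module tree and, for each nn.Linear accepted by a filter, rebuilding its weight as a tensor subclass via from_float(). A rough sketch of that pattern; swap_linear_weights below is a simplified stand-in for the real _replace_with_custom_fn_if_matches_filter and _get_subclass_inserter helpers, not their actual implementation:

import torch
import torch.nn as nn

def swap_linear_weights(model: nn.Module, subclass, filter_fn=None):
    # default: swap every plain nn.Linear
    if filter_fn is None:
        filter_fn = lambda mod, name: isinstance(mod, nn.Linear)
    for name, mod in model.named_modules():
        if filter_fn(mod, name):
            # from_float quantizes/packs the dense fp weight into the subclass
            mod.weight = nn.Parameter(subclass.from_float(mod.weight), requires_grad=False)
    return model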
diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index 721153b00a..537099f67a 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -16,17 +16,11 @@ quant_int8_dynamic_per_token_linear, unpack_tinygemm_scales_and_zeros, ) -from .dynamic_quant_sparse import sparse_quant_int8_dynamic_cutlass_linear, sparse_quant_int8_dynamic_cslt_linear from .utils import find_multiple -from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured -from torch.sparse.semi_structured import SparseSemiStructuredTensorCUTLASS - __all__ = [ "Int8DynamicallyQuantizedLinearWeight", - "Int8DynamicallyQuantized24CutlassLinearWeight", - "Int8DynamicallyQuantized24CusparseltLinearWeight", "Int8WeightOnlyQuantizedLinearWeight", "Int4WeightOnlyQuantizedLinearWeight", ] @@ -314,135 +308,6 @@ def from_float(cls, input_float, qmin=-128, qmax=127): ) -class Int8DynamicallyQuantized24CusparseltLinearWeight(Int8DynamicallyQuantizedLinearWeight): - - @staticmethod - def _quantized_op(act_mat, w_qtensor, bias): - return sparse_quant_int8_dynamic_cslt_linear( - act_mat, w_qtensor.int_data, w_qtensor.q_scales, bias, act_mat.dtype - ) - - @classmethod - def from_float(cls, input_float, qmin=-128, qmax=127): - - assert input_float.is_cuda - - w_int_repr, w_scales, _ = dynamically_quantize_per_channel( - input_float, qmin, qmax, torch.int8 - ) - - int_data = w_int_repr.contiguous() - - - int_data = torch._cslt_compress(int_data) - - return cls( - int_data, w_scales, False, input_float.shape, dtype=input_float.dtype, - ) - -class Int8DynamicallyQuantized24CutlassLinearWeight(QuantizedLinearWeightBase): - - @staticmethod - def __new__(cls, int_data, mask_meta, q_scales, transposed, shape, **kwargs): - kwargs["dtype"] = kwargs.get("dtype", q_scales.dtype) - return super().__new__(cls, int_data, transposed, shape, **kwargs) # type: ignore[attr-defined] - - def __init__(self, int_data, mask_meta, q_scales, transposed, shape, **kwargs): - self.q_scales = q_scales - self.mask_meta = mask_meta - super().__init__(int_data, transposed) - - def dequantize(self, dtype=None): - """ - Obtain the dequantized version of the quantized tensor subclass - """ - dq_t = dequantize_per_channel( - self.int_data.t(), self.q_scales, 0, self.dtype if dtype is None else dtype - ).to(self.dtype) - # data was transposed to dequantize so make sure shape is correct - return dq_t if not self.transposed else dq_t.t() - - def int_repr(self): - """ - Get the internal integer representation of the quantized tensor - """ - return self.int_data if self.transposed else self.int_data.t() - - def q_params(self): - """ - Get the quantization scales for the quantized tensor - """ - return {"q_scales": self.q_scales} - - def to(self, *args, **kwargs): - kwargs = self._get_to_kwargs(*args, **kwargs) - return self.__class__( - self.int_data.to(kwargs["device"]), - self.mask_meta.to(kwargs["device"]), - self.q_scales.to(kwargs["device"]), - self.transposed, - self.shape, - **kwargs, - ) - - def _apply_fn_to_data(self, fn): - return self.__class__( - fn(self.int_data), - fn(self.mask_meta), - fn(self.q_scales), - self.transposed, - self.shape, - dtype=self.dtype - ) - - def _change_shape(self, shape): - return self.__class__( - self.int_data, - self.mask_meta, - self.q_scales, - self.transposed, - shape, - dtype=self.dtype - ) - - def __tensor_flatten__(self): - return ["int_data", "mask_meta", "q_scales"], [self.transposed, self.dtype, self.shape] - - @classmethod - def __tensor_unflatten__(cls, 
tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None): - int_data, q_scales = tensor_data_dict["int_data"], tensor_data_dict["q_scales"] - mask_meta = tensor_data_dict["mask_meta"] - transposed, dtype, shape = tensor_attributes - return cls(int_data, mask_meta, q_scales, transposed, shape if outer_size is None else outer_size, dtype=dtype, strides=outer_stride) - - @staticmethod - def _quantized_op(act_mat, w_qtensor, bias): - return sparse_quant_int8_dynamic_cutlass_linear( - act_mat, w_qtensor.int_data, w_qtensor.mask_meta, w_qtensor.q_scales, bias, act_mat.dtype - ) - - @classmethod - def from_float(cls, input_float, qmin=-128, qmax=127): - - assert input_float.is_cuda - - w_int_repr, w_scales, _ = dynamically_quantize_per_channel( - input_float, qmin, qmax, torch.int8 - ) - - int_data = w_int_repr.contiguous() - sparse_tensor = SparseSemiStructuredTensorCUTLASS.from_dense(int_data) - - return cls( - sparse_tensor.packed, - sparse_tensor.meta, - w_scales, - False, - input_float.shape, - dtype=input_float.dtype - ) - - class Int8WeightOnlyQuantizedLinearWeight(Int8DynamicallyQuantizedLinearWeight): """ A Tensor subclass that when applied to a weight used in a linear op/module, diff --git a/torchao/sparsity/benchmark_sam.py b/torchao/sparsity/benchmark_sam.py index dec6300f4a..d0d0d992ac 100644 --- a/torchao/sparsity/benchmark_sam.py +++ b/torchao/sparsity/benchmark_sam.py @@ -1,6 +1,6 @@ import torch from torchao.quantization import change_linear_weights_to_int8_dqtensors -from torchao.quantization import change_linear_weights_to_int8_dq_semi_structured_sparsetensors +from torchao.sparsity.sparse_api import change_linear_weights_to_int8_dq_semi_structured_sparsetensors from segment_anything import sam_model_registry from torch.utils.benchmark import Timer diff --git a/torchao/sparsity/dynamic_quant_sparse.py b/torchao/sparsity/dynamic_quant_sparse.py index 33538074b8..d7137b7974 100644 --- a/torchao/sparsity/dynamic_quant_sparse.py +++ b/torchao/sparsity/dynamic_quant_sparse.py @@ -8,7 +8,9 @@ quantize_activation_per_token_absmax ) -from torch.sparse import SparseSemiStructuredTensor +from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight, QuantizedLinearWeightBase + +from torch.sparse import SparseSemiStructuredTensor, SparseSemiStructuredTensorCUTLASS # Quant + Sparse helper functinos @@ -96,3 +98,132 @@ def sparse_quant_int8_cutlass_matmul( ) y = y.to(out_dtype) return y + + +class Int8DynamicallyQuantized24CusparseltLinearWeight(Int8DynamicallyQuantizedLinearWeight): + + @staticmethod + def _quantized_op(act_mat, w_qtensor, bias): + return sparse_quant_int8_dynamic_cslt_linear( + act_mat, w_qtensor.int_data, w_qtensor.q_scales, bias, act_mat.dtype + ) + + @classmethod + def from_float(cls, input_float, qmin=-128, qmax=127): + + assert input_float.is_cuda + + w_int_repr, w_scales, _ = dynamically_quantize_per_channel( + input_float, qmin, qmax, torch.int8 + ) + + int_data = w_int_repr.contiguous() + + + int_data = torch._cslt_compress(int_data) + + return cls( + int_data, w_scales, False, input_float.shape, dtype=input_float.dtype, + ) + +class Int8DynamicallyQuantized24CutlassLinearWeight(QuantizedLinearWeightBase): + + @staticmethod + def __new__(cls, int_data, mask_meta, q_scales, transposed, shape, **kwargs): + kwargs["dtype"] = kwargs.get("dtype", q_scales.dtype) + return super().__new__(cls, int_data, transposed, shape, **kwargs) # type: ignore[attr-defined] + + def __init__(self, int_data, mask_meta, q_scales, transposed, shape, 
**kwargs): + self.q_scales = q_scales + self.mask_meta = mask_meta + super().__init__(int_data, transposed) + + def dequantize(self, dtype=None): + """ + Obtain the dequantized version of the quantized tensor subclass + """ + dq_t = dequantize_per_channel( + self.int_data.t(), self.q_scales, 0, self.dtype if dtype is None else dtype + ).to(self.dtype) + # data was transposed to dequantize so make sure shape is correct + return dq_t if not self.transposed else dq_t.t() + + def int_repr(self): + """ + Get the internal integer representation of the quantized tensor + """ + return self.int_data if self.transposed else self.int_data.t() + + def q_params(self): + """ + Get the quantization scales for the quantized tensor + """ + return {"q_scales": self.q_scales} + + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + return self.__class__( + self.int_data.to(kwargs["device"]), + self.mask_meta.to(kwargs["device"]), + self.q_scales.to(kwargs["device"]), + self.transposed, + self.shape, + **kwargs, + ) + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.int_data), + fn(self.mask_meta), + fn(self.q_scales), + self.transposed, + self.shape, + dtype=self.dtype + ) + + def _change_shape(self, shape): + return self.__class__( + self.int_data, + self.mask_meta, + self.q_scales, + self.transposed, + shape, + dtype=self.dtype + ) + + def __tensor_flatten__(self): + return ["int_data", "mask_meta", "q_scales"], [self.transposed, self.dtype, self.shape] + + @classmethod + def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None): + int_data, q_scales = tensor_data_dict["int_data"], tensor_data_dict["q_scales"] + mask_meta = tensor_data_dict["mask_meta"] + transposed, dtype, shape = tensor_attributes + return cls(int_data, mask_meta, q_scales, transposed, shape if outer_size is None else outer_size, dtype=dtype, strides=outer_stride) + + @staticmethod + def _quantized_op(act_mat, w_qtensor, bias): + return sparse_quant_int8_dynamic_cutlass_linear( + act_mat, w_qtensor.int_data, w_qtensor.mask_meta, w_qtensor.q_scales, bias, act_mat.dtype + ) + + @classmethod + def from_float(cls, input_float, qmin=-128, qmax=127): + + assert input_float.is_cuda + + w_int_repr, w_scales, _ = dynamically_quantize_per_channel( + input_float, qmin, qmax, torch.int8 + ) + + int_data = w_int_repr.contiguous() + sparse_tensor = SparseSemiStructuredTensorCUTLASS.from_dense(int_data) + + return cls( + sparse_tensor.packed, + sparse_tensor.meta, + w_scales, + False, + input_float.shape, + dtype=input_float.dtype + ) diff --git a/torchao/sparsity/sparse_api.py b/torchao/sparsity/sparse_api.py index 8ecc622cfd..989095ef5e 100644 --- a/torchao/sparsity/sparse_api.py +++ b/torchao/sparsity/sparse_api.py @@ -1,6 +1,7 @@ import torch from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor +from torchao.sparsity.dynamic_quant_sparse import Int8DynamicallyQuantized24CusparseltLinearWeight, Int8DynamicallyQuantized24CutlassLinearWeight # Sparsity helper functions def apply_fake_sparsity(model): @@ -27,3 +28,21 @@ def apply_sparse(model): for name, mod in model.named_modules(): if isinstance(mod, torch.nn.Linear): mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mod.weight)) + + +def change_linear_weights_to_int8_dq_semi_structured_sparsetensors(model, **kwargs): + filter_fn = kwargs.pop("filter_fn", _is_linear) + + from torch.sparse import SparseSemiStructuredTensor + if SparseSemiStructuredTensor._FORCE_CUTLASS: + 
subclass = Int8DynamicallyQuantized24CutlassLinearWeight + else: + subclass = Int8DynamicallyQuantized24CusparseltLinearWeight + + _replace_with_custom_fn_if_matches_filter( + model, + _get_subclass_inserter(subclass, **kwargs), + filter_fn, + ) + + From af1bbbe77086e61bcb4cc2dd1da156b1f2ce133d Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Fri, 29 Mar 2024 11:43:04 -0700 Subject: [PATCH 07/26] fix quant api --- torchao/quantization/quant_api.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 5074231b15..4194ceb9be 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -26,9 +26,8 @@ from .utils import TORCH_VERSION_AFTER_2_4 from .subclass import ( - QuantizedLinearWeightBase, - Int8DynamicallyQuantizedLinearWeight, Int4WeightOnlyQuantizedLinearWeight, + Int8DynamicallyQuantizedLinearWeight, Int8WeightOnlyQuantizedLinearWeight, QuantizedLinearWeightBase, ) From e193f2ffe05535d1130c71725b7b4b2ffc9b513b Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Fri, 29 Mar 2024 11:44:37 -0700 Subject: [PATCH 08/26] wip --- .../sparsity => benchmarks}/benchmark_sam.py | 0 torchao/sparsity/microbenchmarks.py | 331 ------------------ 2 files changed, 331 deletions(-) rename {torchao/sparsity => benchmarks}/benchmark_sam.py (100%) delete mode 100644 torchao/sparsity/microbenchmarks.py diff --git a/torchao/sparsity/benchmark_sam.py b/benchmarks/benchmark_sam.py similarity index 100% rename from torchao/sparsity/benchmark_sam.py rename to benchmarks/benchmark_sam.py diff --git a/torchao/sparsity/microbenchmarks.py b/torchao/sparsity/microbenchmarks.py deleted file mode 100644 index 8a909b9206..0000000000 --- a/torchao/sparsity/microbenchmarks.py +++ /dev/null @@ -1,331 +0,0 @@ -import argparse -import random - -import pandas as pd -import torch -import torch.utils.benchmark as benchmark -from torch import nn -from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured -from torch.ao.pruning import WeightNormSparsifier -from tqdm import tqdm - -import math - -import torch -import torch.nn.functional as F -import itertools -import torch.utils.benchmark as benchmark -import math - -dtype = torch.float16 -device = "cuda" -torch.manual_seed(42) - - -torch.set_printoptions( - precision=2, - threshold=None, - edgeitems=16, - linewidth=480, - profile=None, - sci_mode=False, -) - -def create_blocked_tensor(M, N, blocksize, sparsity): - assert sparsity <= 1.0 and sparsity >= 0.0, \ - "sparsity should be a value between 0 and 1" - A = torch.bernoulli(torch.full((M//blocksize, N//blocksize), - 1 - sparsity, dtype=torch.bfloat16, device=device)) - A = torch.repeat_interleave(A, blocksize, dim=0) - A = torch.repeat_interleave(A, blocksize, dim=1) - return A.contiguous() - - -def create_24_tensor(M, N): - A = torch.randn(weight_shape, device="cuda") - - choices = [[0, 1], [1, 0]] - mask_entries = [random.choice(choices) for i in range(M * N // 2)] - - mask = torch.tensor(mask_entries).cuda().bool().reshape(M, N) - - A.masked_fill_(~mask, 0) - - return A.contiguous() - - -def benchmark_in_us(f, *args, **kwargs): - t0 = benchmark.Timer( - stmt="f(*args, **kwargs)", - globals={"args": args, "kwargs": kwargs, "f": f} - ) - return int(t0.blocked_autorange().mean * 1e6) - - -def run_benchmark(input_shape, weight_shape, dtype, sparsity=None, backend=None, blocksize=None, sparsity_level=None): - - m, k = weight_shape - n, k = math.prod(input_shape[:-1]), input_shape[-1] - - if 
sparsity == "blocksparse": - A = create_blocked_tensor(m, k, blocksize=blocksize, sparsity=sparsity_level).to(dtype) - A_sparse = A.to_sparse_bsr(blocksize=blocksize) - - elif sparsity == "24": - # blocksize = 4 - # sparsity_level = 0.5 - if backend == "cutlass": - SparseSemiStructuredTensor._FORCE_CUTLASS = True - elif backend == "cusparselt": - SparseSemiStructuredTensor._FORCE_CUTLASS = False - else: - raise ValueError("Wrong value for backend") - - A = create_24_tensor(m, k).to(dtype) - A_sparse = to_sparse_semi_structured(A) - - # b = torch.randn(m, device="cuda").to(dtype) - x = torch.randn(n, k).to(dtype).cuda() - - - # get timing speedups - # handle int_mm custom - if dtype == torch.int8: - dense_time = benchmark_in_us(torch._int_mm, A, x.t()) - dense_output = torch._int_mm(A, x.t()).to(torch.float32).t() - else: - dense_time = benchmark_in_us(F.linear, x, A) - dense_output = F.linear(x, A).to(torch.float32) - - sparse_time = benchmark_in_us(F.linear, x, A_sparse) - sparse_output = F.linear(x, A_sparse).to(torch.float32) - - ratio = dense_time / sparse_time - - - if backend == "cusparselt": - # grab optimal alg id for cusparselt - padded = A_sparse._pad_tensor_for_matmul(x) - if dtype is torch.int8: - out_dtype = torch.bfloat16 - optimal_alg_id = torch._cslt_sparse_mm_search(A_sparse.compressed_tensor_cusparselt, padded.t()) - # print("optimal alg_id", optimal_alg_id) - else: - optimal_alg_id = None - - # sanity check correctness - correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3) - - # # in depth checks - # dense_output = F.linear(x.to(torch.float32), A.to(torch.float32)) - - # diff = ~torch.isclose(dense_output, sparse_output) - - # dense_output_diff = dense_output[diff] - # sparse_output_diff = sparse_output[diff] - - # sparse_output_diff_nonzero = sparse_output_diff.nonzero() - # dense_output_diff = dense_output_diff[sparse_output_diff_nonzero] - # sparse_output_diff = sparse_output_diff[sparse_output_diff_nonzero] - - # outside_atol = ~((dense_output_diff - sparse_output_diff).abs() < 1e-3) - - # larger_dense_output_diff = dense_output_diff[outside_atol] - # larger_sparse_output_diff = sparse_output_diff[outside_atol] - - # pos = (1 - (larger_dense_output_diff / larger_sparse_output_diff)).abs().argmax().item() - - return { - "dtype": str(dtype), - "m": m, - "k": k, - "n": n, - "sparse_latency (us)": sparse_time, - "dense_latency (us)": dense_time, - "speedup (d/s)": f"{ratio:.3f}", - "correct": correct, - # "sparse v dense diff": f"{larger_dense_output_diff[pos]:+11.7f} vs. 
{larger_sparse_output_diff[pos]:+11.7f}", - "sparsity type": sparsity, - "backend": backend, - "blocksize": blocksize, - "sparsity level": sparsity_level, - "optimal_alg_id": optimal_alg_id, - } - -if __name__ == "__main__": - dtype_lookup = { - "int8": torch.int8, - "fp16": torch.float16, - "bf16": torch.bfloat16, - "fp32": torch.float32, - } - - parser = argparse.ArgumentParser(description="GPU Sparsity Kernel Microbenchmarks") - parser.add_argument( - "--mode", - type=str, - choices=[ - "nvidia-bert", - "sam-shapes", - "nvidia-fixed-k", - "nvidia-fixed-mn", - "optimize-matmul-block-sparse", - ], - ) - parser.add_argument( - "--dtype", - type=str, - choices=dtype_lookup.keys(), - default="fp16", - ) - parser.add_argument("--backend", type=str, choices=["cutlass", "cusparselt"], default="cusparselt") - parser.add_argument("--function", type=str, choices=["linear", "mm"], default="linear") - parser.add_argument("--batch_size", type=int, default=32) - parser.add_argument("-contiguous", action="store_true") - parser.add_argument("-save", action="store_true") - args = parser.parse_args() - - eval_fn = run_benchmark - - print(f"Started benchmark: {args.mode} | dtype: {args.dtype}") - dtype = dtype_lookup[args.dtype] - - if args.mode == "nvidia-bert": - bert_shapes = [ - (3072, 1024, 16384), - (4096, 1024, 16384), - (1024, 1024, 16384), - (1024, 4096, 16384), - ] - results = [ - eval_fn(m, k, n, dtype, sparsity="blocksparse", blocksize=64, sparsity_level=0.8) - for (m, k, n) in tqdm(bert_shapes) - ] - - results += [ - eval_fn(m, k, n, dtype, sparsity="24", backend="cusparselt") - for (m, k, n) in tqdm(bert_shapes) - ] - - if args.mode == "optimize-matmul-block-sparse": - batch_size = args.batch_size - - sam_shapes = [ - (torch.Size([batch_size, 64, 64, 1280]), torch.Size([5120, 1280])), - ] - - from collections import defaultdict - results = [] - total_runtime = defaultdict(int) - - for (activation_shape, weight_shape) in tqdm(sam_shapes): - for blocksize in [64]: - for sparsity_level in range(0, 100): - sparsity_level = float(sparsity_level) / 100 - result = run_benchmark( - activation_shape, - weight_shape, - dtype, - sparsity="blocksparse", - blocksize=blocksize, - sparsity_level=sparsity_level) - total_runtime[f"{blocksize}_{sparsity_level}"] += 32 * result["sparse_latency (us)"] - results.append(result) - - if args.mode == "sam-shapes": - batch_size = args.batch_size - - sam_shapes = [ - (torch.Size([batch_size, 256, 3072]), torch.Size([768, 3072])), - (torch.Size([batch_size, 256, 768]), torch.Size([3072, 768])), - ] - - from collections import defaultdict - results = [] - total_runtime = defaultdict(int) - - for (activation_shape, weight_shape) in tqdm(sam_shapes): - # for backend in ["cutlass", "cusparselt"]: - # result = run_benchmark( - # activation_shape, - # weight_shape, - # dtype, - # sparsity="24", - # backend=backend) - - # blocksize = None - # sparsity_level = 0.5 - # total_runtime[f"{backend}"] += 32 * result["sparse_latency (us)"] - # results.append(result) - for blocksize in [8, 16, 32, 64]: - for sparsity_level in [0.8, 0.9]: - result = run_benchmark( - activation_shape, - weight_shape, - dtype, - sparsity="blocksparse", - blocksize=blocksize, - sparsity_level=sparsity_level) - # total_runtime[f"{blocksize}_{sparsity_level}"] += 32 * result["sparse_latency (us)"] - results.append(result) - - # total_runtime["dense"] += 32 * result["dense_latency (us)"] - - # for line in total_runtime: - # print(line, total_runtime[line], sep="\t") - - elif args.mode == "nvidia-fixed-k": - 
mn_vals = [ - 3072, - 4096, - 5120, - 6144, - 7168, - 8192, - 9216, - 10240, - 11264, - 12288, - 13312, - 14336, - 15360, - 16384, - 17408, - 18432, - 19456, - 20480, - ] - results = ( - eval_fn(mn, 10240, mn, dtype, args.contiguous, args.backend) - for mn in tqdm(mn_vals) - ) - - elif args.mode == "nvidia-fixed-mn": - k_vals = [ - 2560, - 3840, - 5120, - 6400, - 7680, - 8960, - 10240, - 11520, - 12800, - 14080, - 15360, - 16640, - 17920, - 19200, - 20480, - ] - results = ( - eval_fn(10240, k, 10240, dtype, args.contiguous, args.backend) - for k in tqdm(k_vals) - ) - - df = pd.DataFrame.from_records(results) - if args.save: - save_file = f"{args.mode}_{args.dtype}_{args.backend}.csv" - df.to_csv(save_file) - print(f"Finished benchmark: {args.mode} saved results to {save_file}") - print(df) From db7d98db20c210a2a3f07d56cf4ecbb93aaa7ab3 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Fri, 29 Mar 2024 15:20:06 -0700 Subject: [PATCH 09/26] updated script and cleaned up api --- benchmarks/benchmark_sam.py | 98 +++++++++++++++++----------------- torchao/sparsity/sparse_api.py | 15 +++--- 2 files changed, 59 insertions(+), 54 deletions(-) diff --git a/benchmarks/benchmark_sam.py b/benchmarks/benchmark_sam.py index d0d0d992ac..e1dbd07834 100644 --- a/benchmarks/benchmark_sam.py +++ b/benchmarks/benchmark_sam.py @@ -1,16 +1,22 @@ +from pprint import pprint + import torch from torchao.quantization import change_linear_weights_to_int8_dqtensors -from torchao.sparsity.sparse_api import change_linear_weights_to_int8_dq_semi_structured_sparsetensors +from torchao.sparsity.sparse_api import change_linear_weights_to_int8_dq_24_sparsetensors +from torchao.sparsity.sparse_api import apply_sparse from segment_anything import sam_model_registry from torch.utils.benchmark import Timer +from torch.sparse import SparseSemiStructuredTensor sam_checkpoint_base_path = "/home/jessecai/local/MODELS" model_type = 'vit_h' model_name = 'sam_vit_h_4b8939.pth' checkpoint_path = f"{sam_checkpoint_base_path}/{model_name}" -batchsize = 16 -only_one_block = False +torch._inductor.config.epilogue_fusion = False +torch._inductor.config.coordinate_descent_tuning = False +torch._inductor.config.coordinate_descent_check_all_directions = False +torch._inductor.config.force_fuse_int_mm_with_mul = False @torch.no_grad() def benchmark(f, *args, **kwargs): @@ -36,50 +42,46 @@ def get_sam_model(only_one_block=False, batchsize=1): image = torch.randn(batchsize, 64, 64, 1280, device='cuda') return model, image -print("BENCHMARKING") +def mlp_only(mod, name): + return isinstance(mod, torch.nn.Linear) and "mlp" in name -model, image = get_sam_model(False, batchsize) -model = model.to(torch.bfloat16) -image = image.to(torch.bfloat16) -model_c = torch.compile(model, mode='max-autotune') -quant_res = benchmark(model_c, image) -print(f"bf16 compiled runtime of the compiled full model is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB") -# bf16 compiled runtime of the compiled full model is 729.65ms and peak memory 23.96GB +def attention_only(mod, name): + return isinstance(mod, torch.nn.Linear) and "mlp" not in name -del model_c, model, image -torch._inductor.config.epilogue_fusion = False -torch._inductor.config.coordinate_descent_tuning = True -torch._inductor.config.coordinate_descent_check_all_directions = True -torch._inductor.config.force_fuse_int_mm_with_mul = True -model, image = get_sam_model(False, batchsize) -model = model.to(torch.bfloat16) -image = image.to(torch.bfloat16) 
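The copy-pasted per-configuration blocks being deleted in this hunk are folded into a single run_once() helper (continued below), with layer selection routed through filter callbacks of the form (module, fully_qualified_name) -> bool. A minimal sketch of composing the two swap APIs the way run_once does, mirroring the mlp_only and attention_only filters added in this patch; quantize_and_sparsify is a hypothetical wrapper, not part of the change:

import torch
from torchao.quantization import change_linear_weights_to_int8_dqtensors
from torchao.sparsity.sparse_api import change_linear_weights_to_int8_dq_24_sparsetensors

def mlp_only(mod, name):
    return isinstance(mod, torch.nn.Linear) and "mlp" in name

def attention_only(mod, name):
    return isinstance(mod, torch.nn.Linear) and "mlp" not in name

def quantize_and_sparsify(model):
    # 2:4 sparse + int8 dynamic quant on the MLP linears, plain int8 dynamic quant elsewhere
    change_linear_weights_to_int8_dq_24_sparsetensors(model, filter_fn=mlp_only)
    change_linear_weights_to_int8_dqtensors(model, filter_fn=attention_only)
    return torch.compile(model, mode="max-autotune")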
-change_linear_weights_to_int8_dqtensors(model) -model_c = torch.compile(model, mode='max-autotune') -quant_res = benchmark(model_c, image) -print(f"bf16 compiled runtime of the quantized full model is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB") -# bf16 compiled runtime of the quantized full model is 677.28ms and peak memory 24.93GB - -del model_c, model, image -model, image = get_sam_model(False, batchsize) -model = model.to(torch.bfloat16) -image = image.to(torch.bfloat16) -from torch.sparse import SparseSemiStructuredTensor -SparseSemiStructuredTensor._FORCE_CUTLASS = True -change_linear_weights_to_int8_dq_semi_structured_sparsetensors(model) -model_c = torch.compile(model, mode='max-autotune') -quant_res = benchmark(model_c, image) -print(f"bf16 compiled runtime of the 2:4 sparse CUTLASS + quantized full model is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB") -# bf16 compiled runtime of the quantized full model is 677.28ms and peak memory 24.93GB - -del model_c, model, image -model, image = get_sam_model(False, batchsize) -model = model.to(torch.bfloat16) -image = image.to(torch.bfloat16) -from torch.sparse import SparseSemiStructuredTensor -SparseSemiStructuredTensor._FORCE_CUTLASS = False -change_linear_weights_to_int8_dq_semi_structured_sparsetensors(model) -model_c = torch.compile(model, mode='max-autotune') -quant_res = benchmark(model_c, image) -print(f"bf16 compiled runtime of the 2:4 sparse cuSPARSELt + quantized full model is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB") -# bf16 compiled runtime of the quantized full model is 677.28ms and peak memory 24.93GB +def run_once(label, dtype=torch.bfloat16, batchsize=16, compile=True, quantize=False, sparse=False): + res = { + "label": label, + "batchsize": batchsize, + "dtype": dtype, + "compile": compile, + "quantize": quantize, + "sparse": sparse, + } + + model, image = get_sam_model(False, batchsize) + model = model.to(dtype) + image = image.to(dtype) + + if sparse and quantize: + SparseSemiStructuredTensor._FORCE_CUTLASS = (sparse == "cutlass") + change_linear_weights_to_int8_dq_24_sparsetensors(model, filter_fn=mlp_only) + change_linear_weights_to_int8_dqtensors(model, filter_fn=attention_only) + elif quantize: + change_linear_weights_to_int8_dqtensors(model) + elif sparse: + SparseSemiStructuredTensor._FORCE_CUTLASS = (sparse == "cutlass") + apply_sparse(model) + + if compile: + model = torch.compile(model, mode='max-autotune') + + res.update(benchmark(model, image)) + pprint(res) + + return res + +print("BENCHMARKING") +run_once("baseline") +run_once("quant", quantize=True) +run_once("sparse", sparse="cusparselt") +run_once("quant+sparse(mlp)", quantize=True, sparse="cusparselt") diff --git a/torchao/sparsity/sparse_api.py b/torchao/sparsity/sparse_api.py index 989095ef5e..887514381c 100644 --- a/torchao/sparsity/sparse_api.py +++ b/torchao/sparsity/sparse_api.py @@ -1,8 +1,11 @@ import torch -from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor +from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor, SparseSemiStructuredTensorCUTLASS, SparseSemiStructuredTensorCUSPARSELT from torchao.sparsity.dynamic_quant_sparse import Int8DynamicallyQuantized24CusparseltLinearWeight, Int8DynamicallyQuantized24CutlassLinearWeight +from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter, _get_subclass_inserter, _is_linear +from torch.ao.pruning import WeightNormSparsifier + # 
Sparsity helper functions def apply_fake_sparsity(model): """ @@ -10,7 +13,6 @@ def apply_fake_sparsity(model): It uses the torch.ao.pruning flow. """ # torch.ao.pruning flow - from torch.ao.pruning import WeightNormSparsifier sparse_config = [] for name, mod in model.named_modules(): if isinstance(mod, torch.nn.Linear): @@ -23,17 +25,18 @@ def apply_fake_sparsity(model): sparsifier.step() sparsifier.squash_mask() -def apply_sparse(model): +def apply_sparse(model, **kwargs): + filter_fn = kwargs.pop("filter_fn", _is_linear) + apply_fake_sparsity(model) for name, mod in model.named_modules(): - if isinstance(mod, torch.nn.Linear): + if filter_fn(name, mod): mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mod.weight)) -def change_linear_weights_to_int8_dq_semi_structured_sparsetensors(model, **kwargs): +def change_linear_weights_to_int8_dq_24_sparsetensors(model, **kwargs): filter_fn = kwargs.pop("filter_fn", _is_linear) - from torch.sparse import SparseSemiStructuredTensor if SparseSemiStructuredTensor._FORCE_CUTLASS: subclass = Int8DynamicallyQuantized24CutlassLinearWeight else: From f9e344901f668165ad7dd15693cd888ebdb43eed Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Fri, 29 Mar 2024 15:24:00 -0700 Subject: [PATCH 10/26] update --- benchmarks/benchmark_sam.py | 9 ++++----- torchao/sparsity/__init__.py | 5 ++++- torchao/sparsity/dynamic_quant_sparse.py | 2 +- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/benchmarks/benchmark_sam.py b/benchmarks/benchmark_sam.py index e1dbd07834..25166ea418 100644 --- a/benchmarks/benchmark_sam.py +++ b/benchmarks/benchmark_sam.py @@ -2,8 +2,7 @@ import torch from torchao.quantization import change_linear_weights_to_int8_dqtensors -from torchao.sparsity.sparse_api import change_linear_weights_to_int8_dq_24_sparsetensors -from torchao.sparsity.sparse_api import apply_sparse +from torchao.sparsity import change_linear_weights_to_int8_dq_24_sparsetensors, apply_sparse from segment_anything import sam_model_registry from torch.utils.benchmark import Timer from torch.sparse import SparseSemiStructuredTensor @@ -81,7 +80,7 @@ def run_once(label, dtype=torch.bfloat16, batchsize=16, compile=True, quantize=F return res print("BENCHMARKING") -run_once("baseline") -run_once("quant", quantize=True) -run_once("sparse", sparse="cusparselt") +# run_once("baseline") +# run_once("quant", quantize=True) +# run_once("sparse", sparse="cusparselt") run_once("quant+sparse(mlp)", quantize=True, sparse="cusparselt") diff --git a/torchao/sparsity/__init__.py b/torchao/sparsity/__init__.py index f4e2838f4a..bf47b1afea 100644 --- a/torchao/sparsity/__init__.py +++ b/torchao/sparsity/__init__.py @@ -6,8 +6,11 @@ from .wanda import WandaSparsifier # noqa: F403 from .utils import PerChannelNormObserver # noqa: F403 +from .sparse_api import change_linear_weights_to_int8_dq_24_sparsetensors, apply_sparse __all__ = [ "WandaSparsifier", - "PerChannelNormObserver" + "PerChannelNormObserver", + "apply_sparse", + "change_linear_weights_to_int8_dq_24_sparsetensors", ] diff --git a/torchao/sparsity/dynamic_quant_sparse.py b/torchao/sparsity/dynamic_quant_sparse.py index d7137b7974..3638065659 100644 --- a/torchao/sparsity/dynamic_quant_sparse.py +++ b/torchao/sparsity/dynamic_quant_sparse.py @@ -12,7 +12,7 @@ from torch.sparse import SparseSemiStructuredTensor, SparseSemiStructuredTensorCUTLASS -# Quant + Sparse helper functinos +# Qunt + Sparse helper functinos def sparse_quant_int8_dynamic_cutlass_linear( x, From 65e290ca0467b4c165161628223a700f2ced1cc7 Mon Sep 17 
00:00:00 2001 From: Jesse Cai Date: Mon, 1 Apr 2024 08:16:39 -0700 Subject: [PATCH 11/26] wip --- benchmarks/benchmark_sam.py | 52 ++++++++++++------------ torchao/sparsity/dynamic_quant_sparse.py | 46 ++++++++------------- torchao/sparsity/sparse_api.py | 5 +-- 3 files changed, 45 insertions(+), 58 deletions(-) diff --git a/benchmarks/benchmark_sam.py b/benchmarks/benchmark_sam.py index 25166ea418..9b4d46b99e 100644 --- a/benchmarks/benchmark_sam.py +++ b/benchmarks/benchmark_sam.py @@ -1,5 +1,5 @@ from pprint import pprint - +import pandas as pd import torch from torchao.quantization import change_linear_weights_to_int8_dqtensors from torchao.sparsity import change_linear_weights_to_int8_dq_24_sparsetensors, apply_sparse @@ -12,10 +12,10 @@ model_name = 'sam_vit_h_4b8939.pth' checkpoint_path = f"{sam_checkpoint_base_path}/{model_name}" -torch._inductor.config.epilogue_fusion = False -torch._inductor.config.coordinate_descent_tuning = False -torch._inductor.config.coordinate_descent_check_all_directions = False -torch._inductor.config.force_fuse_int_mm_with_mul = False +torch._inductor.config.epilogue_fusion = True +torch._inductor.config.coordinate_descent_tuning = True +torch._inductor.config.coordinate_descent_check_all_directions = True +torch._inductor.config.force_fuse_int_mm_with_mul = True @torch.no_grad() def benchmark(f, *args, **kwargs): @@ -41,46 +41,46 @@ def get_sam_model(only_one_block=False, batchsize=1): image = torch.randn(batchsize, 64, 64, 1280, device='cuda') return model, image -def mlp_only(mod, name): - return isinstance(mod, torch.nn.Linear) and "mlp" in name - -def attention_only(mod, name): - return isinstance(mod, torch.nn.Linear) and "mlp" not in name - -def run_once(label, dtype=torch.bfloat16, batchsize=16, compile=True, quantize=False, sparse=False): +def run_once(label, dtype=torch.bfloat16, batchsize=16, compile=True, quantize=False, sparsify=False): res = { "label": label, "batchsize": batchsize, "dtype": dtype, "compile": compile, "quantize": quantize, - "sparse": sparse, + "sparsify": sparsify, } model, image = get_sam_model(False, batchsize) model = model.to(dtype) image = image.to(dtype) - if sparse and quantize: - SparseSemiStructuredTensor._FORCE_CUTLASS = (sparse == "cutlass") - change_linear_weights_to_int8_dq_24_sparsetensors(model, filter_fn=mlp_only) - change_linear_weights_to_int8_dqtensors(model, filter_fn=attention_only) + if sparsify and quantize: + SparseSemiStructuredTensor._FORCE_CUTLASS = (sparsify == "cutlass") + change_linear_weights_to_int8_dq_24_sparsetensors(model) elif quantize: change_linear_weights_to_int8_dqtensors(model) - elif sparse: - SparseSemiStructuredTensor._FORCE_CUTLASS = (sparse == "cutlass") + elif sparsify: + SparseSemiStructuredTensor._FORCE_CUTLASS = (sparsify == "cutlass") apply_sparse(model) if compile: model = torch.compile(model, mode='max-autotune') res.update(benchmark(model, image)) - pprint(res) - + print(f"{label} finished in {res['time']} and {res['memory']} run with {res['batchsize']} batchsize, {res['dtype']} dtype, {res['compile']} compile, {res['quantize']} quantize, {res['sparsify']} sparsify") return res -print("BENCHMARKING") -# run_once("baseline") -# run_once("quant", quantize=True) -# run_once("sparse", sparse="cusparselt") -run_once("quant+sparse(mlp)", quantize=True, sparse="cusparselt") + +if __name__ == "__main__": + ALL_RUNS = [] + print("BENCHMARKING") + ALL_RUNS.append(run_once("baseline")) + ALL_RUNS.append(run_once("quant", quantize=True)) + ALL_RUNS.append(run_once("sparse", 
sparse="cusparselt")) + ALL_RUNS.append(run_once("sparse", sparse="cutlass")) + ALL_RUNS.append(run_once("quant+sparse (fuse one mul)", quantize=True, sparse="cusparselt")) + ALL_RUNS.append(run_once("quant+sparse", quantize=True, sparse="cutlass")) + df = pd.DataFrame(ALL_RUNS) + df.to_csv("sam_benchmark_results.csv") + print(df) diff --git a/torchao/sparsity/dynamic_quant_sparse.py b/torchao/sparsity/dynamic_quant_sparse.py index 3638065659..2f3d41ce33 100644 --- a/torchao/sparsity/dynamic_quant_sparse.py +++ b/torchao/sparsity/dynamic_quant_sparse.py @@ -10,11 +10,11 @@ from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight, QuantizedLinearWeightBase -from torch.sparse import SparseSemiStructuredTensor, SparseSemiStructuredTensorCUTLASS +from torch.sparse import SparseSemiStructuredTensor, SparseSemiStructuredTensorCUTLASS, to_sparse_semi_structured # Qunt + Sparse helper functinos -def sparse_quant_int8_dynamic_cutlass_linear( +def sparse_quant_int8_dynamic_linear( x, w_vals_int8, w_meta_int32, @@ -23,29 +23,17 @@ def sparse_quant_int8_dynamic_cutlass_linear( out_dtype, ): x_vals_int8, x_scales = quantize_activation_per_token_absmax(x) - mm_out = sparse_quant_int8_cutlass_matmul( - x_vals_int8, x_scales, w_vals_int8, w_meta_int32, w_scales, out_dtype) + if w_meta_int32 is None: + mm_out = sparse_quant_int8_cslt_matmul( + x_vals_int8, x_scales, w_vals_int8, w_scales, out_dtype) + else: + mm_out = sparse_quant_int8_cutlass_matmul( + x_vals_int8, x_scales, w_vals_int8, w_meta_int32, w_scales, out_dtype) if bias is not None: mm_out += bias return mm_out -def sparse_quant_int8_dynamic_cslt_linear( - x, - w_vals_int8, - w_scales, - bias, - out_dtype, -): - x_vals_int8, x_scales = quantize_activation_per_token_absmax(x) - mm_out = sparse_quant_int8_cslt_matmul( - x_vals_int8, x_scales, w_vals_int8, w_scales, out_dtype) - - if bias is not None: - mm_out += bias - return mm_out - - def sparse_quant_int8_cslt_matmul( x_vals_int8, x_scales, @@ -56,7 +44,7 @@ def sparse_quant_int8_cslt_matmul( assert x_vals_int8.dtype == torch.int8, f'x dtype {x_vals_int8.dtype} not yet supported' assert w_vals_int8.dtype == torch.int8, f'w dtype {w_vals_int8.dtype} not yet supported' - assert w_scales.dtype == out_dtype, f'{w_scales.dtype} does not match {out_dtype}' + # assert w_scales.dtype == out_dtype, f'{w_scales.dtype} does not match {out_dtype}' tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]).contiguous() @@ -65,11 +53,12 @@ def sparse_quant_int8_cslt_matmul( torch.bfloat16, ], f"x_scales needs to be a torch.float32 or torch.bfloat16 but got {x_scales.dtype}" - y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm(w_vals_int8, tmp.t(), alpha=w_scales, out_dtype=torch.bfloat16).t() - y = (y_dot_bf16_w_scales_fused* x_scales.reshape(-1, 1)).reshape( + y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm(w_vals_int8, tmp.t(), alpha=w_scales.to(torch.float), out_dtype=torch.bfloat16).t() + y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape( *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1] ) y = y.to(out_dtype) + return y def sparse_quant_int8_cutlass_matmul( @@ -104,12 +93,12 @@ class Int8DynamicallyQuantized24CusparseltLinearWeight(Int8DynamicallyQuantizedL @staticmethod def _quantized_op(act_mat, w_qtensor, bias): - return sparse_quant_int8_dynamic_cslt_linear( - act_mat, w_qtensor.int_data, w_qtensor.q_scales, bias, act_mat.dtype + return sparse_quant_int8_dynamic_linear( + act_mat, w_qtensor.int_data, None, w_qtensor.q_scales, bias, act_mat.dtype ) 
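For reference, both sparse paths dispatched above compute the same quantized linear: int8 activations with per-token scales against an int8 2:4-sparse weight with per-channel scales, with both scales applied to the accumulator (the cuSPARSELt path fuses w_scales into the matmul via alpha and accumulates to bf16, while the CUTLASS path returns a raw int32 accumulator and applies both scales afterwards). A dense reference sketch of that arithmetic, intended only for checking the math, not the actual kernels:

import torch

def reference_int8_dynamic_linear(x_int8, x_scales, w_int8, w_scales, bias, out_dtype):
    # x_int8: (..., k) int8 activations, x_scales: per-token scales
    # w_int8: (n, k) int8 weight (dense stand-in for the 2:4-sparse operand), w_scales: (n,) per-channel scales
    tmp = x_int8.reshape(-1, x_int8.shape[-1])
    # float32 matmul stands in for the kernels' int32/bf16 accumulation
    y = tmp.to(torch.float32) @ w_int8.to(torch.float32).t()
    y = y * x_scales.reshape(-1, 1).to(torch.float32) * w_scales.to(torch.float32)
    y = y.reshape(*x_int8.shape[:-1], w_int8.shape[0]).to(out_dtype)
    return y if bias is None else y + bias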
@classmethod - def from_float(cls, input_float, qmin=-128, qmax=127): + def from_float(cls, input_float, qmin=-8, qmax=7): assert input_float.is_cuda @@ -118,14 +107,13 @@ def from_float(cls, input_float, qmin=-128, qmax=127): ) int_data = w_int_repr.contiguous() - - int_data = torch._cslt_compress(int_data) return cls( int_data, w_scales, False, input_float.shape, dtype=input_float.dtype, ) + class Int8DynamicallyQuantized24CutlassLinearWeight(QuantizedLinearWeightBase): @staticmethod @@ -203,7 +191,7 @@ def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=No @staticmethod def _quantized_op(act_mat, w_qtensor, bias): - return sparse_quant_int8_dynamic_cutlass_linear( + return sparse_quant_int8_dynamic_linear( act_mat, w_qtensor.int_data, w_qtensor.mask_meta, w_qtensor.q_scales, bias, act_mat.dtype ) diff --git a/torchao/sparsity/sparse_api.py b/torchao/sparsity/sparse_api.py index 887514381c..fdf8e35ee8 100644 --- a/torchao/sparsity/sparse_api.py +++ b/torchao/sparsity/sparse_api.py @@ -30,12 +30,13 @@ def apply_sparse(model, **kwargs): apply_fake_sparsity(model) for name, mod in model.named_modules(): - if filter_fn(name, mod): + if filter_fn(mod, name): mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mod.weight)) def change_linear_weights_to_int8_dq_24_sparsetensors(model, **kwargs): filter_fn = kwargs.pop("filter_fn", _is_linear) + use_experimental = kwargs.pop("use_experimental", False) if SparseSemiStructuredTensor._FORCE_CUTLASS: subclass = Int8DynamicallyQuantized24CutlassLinearWeight @@ -47,5 +48,3 @@ def change_linear_weights_to_int8_dq_24_sparsetensors(model, **kwargs): _get_subclass_inserter(subclass, **kwargs), filter_fn, ) - - From 31706d5ed619ff447b88be957affbe49d0e113cb Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Mon, 1 Apr 2024 12:13:26 -0700 Subject: [PATCH 12/26] formatted api and added per-linear tuning to script --- benchmarks/benchmark_sam.py | 111 +++++++++++++++++------ benchmarks/sam_benchmark_results.csv | 3 + torchao/sparsity/__init__.py | 6 +- torchao/sparsity/dynamic_quant_sparse.py | 106 ++++++++++++++++------ torchao/sparsity/sparse_api.py | 34 ++++--- 5 files changed, 189 insertions(+), 71 deletions(-) create mode 100644 benchmarks/sam_benchmark_results.csv diff --git a/benchmarks/benchmark_sam.py b/benchmarks/benchmark_sam.py index 9b4d46b99e..aec5028b55 100644 --- a/benchmarks/benchmark_sam.py +++ b/benchmarks/benchmark_sam.py @@ -2,10 +2,21 @@ import pandas as pd import torch from torchao.quantization import change_linear_weights_to_int8_dqtensors -from torchao.sparsity import change_linear_weights_to_int8_dq_24_sparsetensors, apply_sparse +from torchao.sparsity import change_linear_weights_to_int8_dq_semi_structured_sparsetensors +from torchao.sparsity.sparse_api import apply_sparse_semi_structured, apply_fake_sparsity +from torchao.sparsity.dynamic_quant_sparse import Int8DynamicallyQuantized24CusparseltLinearWeight, Int8DynamicallyQuantized24CutlassLinearWeight from segment_anything import sam_model_registry from torch.utils.benchmark import Timer -from torch.sparse import SparseSemiStructuredTensor +from torch.sparse import SparseSemiStructuredTensor, SparseSemiStructuredTensorCUTLASS, SparseSemiStructuredTensorCUSPARSELT +from torchao.quantization.quant_api import ( + _replace_with_custom_fn_if_matches_filter, + _get_subclass_inserter, + _is_linear, + QuantizedLinearWeightBase, + Int8DynamicallyQuantizedLinearWeight, +) +from itertools import product +from tqdm import tqdm sam_checkpoint_base_path = 
"/home/jessecai/local/MODELS" model_type = 'vit_h' @@ -41,46 +52,88 @@ def get_sam_model(only_one_block=False, batchsize=1): image = torch.randn(batchsize, 64, 64, 1280, device='cuda') return model, image -def run_once(label, dtype=torch.bfloat16, batchsize=16, compile=True, quantize=False, sparsify=False): +def qkv_only(mod, name): + return isinstance(mod, torch.nn.Linear) and 'qkv' in name + +def proj_only(mod, name): + return isinstance(mod, torch.nn.Linear) and 'proj' in name + +def lin1_only(mod, name): + return isinstance(mod, torch.nn.Linear) and 'lin1' in name + +def lin2_only(mod, name): + return isinstance(mod, torch.nn.Linear) and 'lin2' in name + +SUBCLASSES = { + # "baseline": None, + "quant": Int8DynamicallyQuantizedLinearWeight, + "sparse (cutlass)": SparseSemiStructuredTensorCUTLASS, + "sparse (cusparselt)": SparseSemiStructuredTensorCUSPARSELT, + "quant+sparse (cusparselt fuse one mul)": Int8DynamicallyQuantized24CusparseltLinearWeight, + "quant+sparse (cutlass)": Int8DynamicallyQuantized24CutlassLinearWeight, +} + +def run_once(dtype=torch.bfloat16, batchsize=32, compile=True, qkv=None, proj=None, lin1=None, lin2=None): res = { - "label": label, "batchsize": batchsize, "dtype": dtype, "compile": compile, - "quantize": quantize, - "sparsify": sparsify, + "qkv" : qkv, + "proj": proj, + "lin1": lin1, + "lin2": lin2, } + with torch.no_grad(): + model, image = get_sam_model(False, batchsize) + model = model.to(dtype) + image = image.to(dtype) - model, image = get_sam_model(False, batchsize) - model = model.to(dtype) - image = image.to(dtype) + # 2:4 prune model + apply_fake_sparsity(model) + option_and_filter_fn = zip([qkv, proj, lin1, lin2], + [qkv_only, proj_only, lin1_only, lin2_only]) - if sparsify and quantize: - SparseSemiStructuredTensor._FORCE_CUTLASS = (sparsify == "cutlass") - change_linear_weights_to_int8_dq_24_sparsetensors(model) - elif quantize: - change_linear_weights_to_int8_dqtensors(model) - elif sparsify: - SparseSemiStructuredTensor._FORCE_CUTLASS = (sparsify == "cutlass") - apply_sparse(model) + for option, filter_fn in option_and_filter_fn: + if option: + subclass = SUBCLASSES[option] + if issubclass(subclass, SparseSemiStructuredTensor): + for name, mod in model.named_modules(): + if filter_fn(mod, name): + mod.weight = torch.nn.Parameter(subclass.from_dense(mod.weight)) + # replace with to_sparse_semi_structured + elif issubclass(subclass, QuantizedLinearWeightBase): + _replace_with_custom_fn_if_matches_filter(model, _get_subclass_inserter(subclass), filter_fn) - if compile: - model = torch.compile(model, mode='max-autotune') + # for name, mod in model.named_modules(): + # if isinstance(mod, torch.nn.Linear): + # print(name, mod.weight.data.__class__.__name__) - res.update(benchmark(model, image)) - print(f"{label} finished in {res['time']} and {res['memory']} run with {res['batchsize']} batchsize, {res['dtype']} dtype, {res['compile']} compile, {res['quantize']} quantize, {res['sparsify']} sparsify") - return res + if compile: + model = torch.compile(model, mode='max-autotune') + res.update(benchmark(model, image)) + pprint(res) + return res if __name__ == "__main__": - ALL_RUNS = [] print("BENCHMARKING") - ALL_RUNS.append(run_once("baseline")) - ALL_RUNS.append(run_once("quant", quantize=True)) - ALL_RUNS.append(run_once("sparse", sparse="cusparselt")) - ALL_RUNS.append(run_once("sparse", sparse="cutlass")) - ALL_RUNS.append(run_once("quant+sparse (fuse one mul)", quantize=True, sparse="cusparselt")) - ALL_RUNS.append(run_once("quant+sparse", 
quantize=True, sparse="cutlass")) + # ALL_RUNS = [run_once(qkv=qkv, proj=proj, lin1=lin1, lin2=lin2) + # for (qkv, proj, lin1, lin2) + # in tqdm(list(product(SUBCLASSES, SUBCLASSES, SUBCLASSES, SUBCLASSES)))] + ALL_RUNS = [ + run_once(), + run_once(qkv="quant", proj="quant", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)"), + # run_once(qkv="sparse (cusparselt)", proj="sparse (cusparselt)", lin1="quant+sparse (cusparselt fuse one mul)", lin2="quant+sparse (cutlass)"), + # run_once(qkv="sparse (cusparselt)", proj="sparse (cusparselt)", lin1="quant+sparse (cusparselt fuse one mul)", lin2="quant+sparse (cusparselt fuse one mul)"), + run_once(qkv="sparse (cusparselt)", proj="sparse (cusparselt)", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)"), + ] df = pd.DataFrame(ALL_RUNS) df.to_csv("sam_benchmark_results.csv") print(df) + + +# optimal order goes +# qkv - cusparselt +# proj - cusparselt +# lin1 - int8 cusparselt / cutlass +# lin2 - cutlass diff --git a/benchmarks/sam_benchmark_results.csv b/benchmarks/sam_benchmark_results.csv new file mode 100644 index 0000000000..5f64ed9a1d --- /dev/null +++ b/benchmarks/sam_benchmark_results.csv @@ -0,0 +1,3 @@ +,batchsize,dtype,compile,qkv,proj,lin1,lin2,time,memory +0,32,torch.bfloat16,True,quant,quant,quant+sparse (cutlass),quant+sparse (cutlass),32.57098700851202,5.451383808 +1,32,torch.bfloat16,True,sparse (cusparselt),sparse (cusparselt),quant+sparse (cutlass),quant+sparse (cutlass),32.90631342679262,5.452028928 diff --git a/torchao/sparsity/__init__.py b/torchao/sparsity/__init__.py index bf47b1afea..f8ceee7b9b 100644 --- a/torchao/sparsity/__init__.py +++ b/torchao/sparsity/__init__.py @@ -6,11 +6,11 @@ from .wanda import WandaSparsifier # noqa: F403 from .utils import PerChannelNormObserver # noqa: F403 -from .sparse_api import change_linear_weights_to_int8_dq_24_sparsetensors, apply_sparse +from .sparse_api import change_linear_weights_to_int8_dq_semi_structured_sparsetensors, apply_sparse_semi_structured __all__ = [ "WandaSparsifier", "PerChannelNormObserver", - "apply_sparse", - "change_linear_weights_to_int8_dq_24_sparsetensors", + "apply_sparse_semi_structured", + "change_linear_weights_to_int8_dq_semi_structured_sparsetensors", ] diff --git a/torchao/sparsity/dynamic_quant_sparse.py b/torchao/sparsity/dynamic_quant_sparse.py index 2f3d41ce33..5b280fb32e 100644 --- a/torchao/sparsity/dynamic_quant_sparse.py +++ b/torchao/sparsity/dynamic_quant_sparse.py @@ -5,35 +5,47 @@ from torchao.quantization.quant_primitives import ( dynamically_quantize_per_channel, quant_int8_dynamic_per_token_linear, - quantize_activation_per_token_absmax + quantize_activation_per_token_absmax, ) -from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight, QuantizedLinearWeightBase +from torchao.quantization.subclass import ( + Int8DynamicallyQuantizedLinearWeight, + QuantizedLinearWeightBase, +) -from torch.sparse import SparseSemiStructuredTensor, SparseSemiStructuredTensorCUTLASS, to_sparse_semi_structured +from torch.sparse import ( + SparseSemiStructuredTensor, + SparseSemiStructuredTensorCUTLASS, + to_sparse_semi_structured, +) # Qunt + Sparse helper functinos + def sparse_quant_int8_dynamic_linear( - x, - w_vals_int8, - w_meta_int32, - w_scales, - bias, - out_dtype, + x : torch.Tensor, + w_vals_int8_packed : torch.Tensor, + w_meta_int32 : Optional[torch.Tensor], + w_scales : torch.Tensor, + bias : Optional[torch.Tensor], + out_dtype : torch.dtype, ): x_vals_int8, x_scales = quantize_activation_per_token_absmax(x) 
+ # w_meta_int32 is either None or meta tensor if w_meta_int32 is None: mm_out = sparse_quant_int8_cslt_matmul( - x_vals_int8, x_scales, w_vals_int8, w_scales, out_dtype) + x_vals_int8, x_scales, w_vals_int8_packed, w_scales, out_dtype + ) else: mm_out = sparse_quant_int8_cutlass_matmul( - x_vals_int8, x_scales, w_vals_int8, w_meta_int32, w_scales, out_dtype) + x_vals_int8, x_scales, w_vals_int8_packed, w_meta_int32, w_scales, out_dtype + ) if bias is not None: mm_out += bias return mm_out + def sparse_quant_int8_cslt_matmul( x_vals_int8, x_scales, @@ -42,8 +54,12 @@ def sparse_quant_int8_cslt_matmul( out_dtype, ): - assert x_vals_int8.dtype == torch.int8, f'x dtype {x_vals_int8.dtype} not yet supported' - assert w_vals_int8.dtype == torch.int8, f'w dtype {w_vals_int8.dtype} not yet supported' + assert ( + x_vals_int8.dtype == torch.int8 + ), f"x dtype {x_vals_int8.dtype} not yet supported" + assert ( + w_vals_int8.dtype == torch.int8 + ), f"w dtype {w_vals_int8.dtype} not yet supported" # assert w_scales.dtype == out_dtype, f'{w_scales.dtype} does not match {out_dtype}' tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]).contiguous() @@ -53,7 +69,9 @@ def sparse_quant_int8_cslt_matmul( torch.bfloat16, ], f"x_scales needs to be a torch.float32 or torch.bfloat16 but got {x_scales.dtype}" - y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm(w_vals_int8, tmp.t(), alpha=w_scales.to(torch.float), out_dtype=torch.bfloat16).t() + y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm( + w_vals_int8, tmp.t(), alpha=w_scales, out_dtype=torch.bfloat16 + ).t() y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape( *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1] ) @@ -61,6 +79,7 @@ def sparse_quant_int8_cslt_matmul( return y + def sparse_quant_int8_cutlass_matmul( x_vals_int8, x_scales, @@ -69,10 +88,14 @@ def sparse_quant_int8_cutlass_matmul( w_scales, out_dtype, ): - assert x_vals_int8.dtype == torch.int8, f'x dtype {x_vals_int8.dtype} not yet supported' - assert w_vals_int8.dtype == torch.int8, f'w dtype {w_vals_int8.dtype} not yet supported' - assert w_scales.dtype == out_dtype, f'{w_scales.dtype} does not match {out_dtype}' - assert w_meta_int32.dtype == torch.int32, f'{w_meta_int32.dtype} not yet supported' + assert ( + x_vals_int8.dtype == torch.int8 + ), f"x dtype {x_vals_int8.dtype} not yet supported" + assert ( + w_vals_int8.dtype == torch.int8 + ), f"w dtype {w_vals_int8.dtype} not yet supported" + assert w_scales.dtype == out_dtype, f"{w_scales.dtype} does not match {out_dtype}" + assert w_meta_int32.dtype == torch.int32, f"{w_meta_int32.dtype} not yet supported" tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]).contiguous() @@ -81,7 +104,9 @@ def sparse_quant_int8_cutlass_matmul( torch.bfloat16, ], f"x_scales needs to be a torch.float32 or torch.bfloat16 but got {x_scales.dtype}" - y_dot_int32 = torch._sparse_semi_structured_linear(tmp, w_vals_int8, w_meta_int32.view(torch.int32), out_dtype=torch.int32) + y_dot_int32 = torch._sparse_semi_structured_linear( + tmp, w_vals_int8, w_meta_int32.view(torch.int32), out_dtype=torch.int32 + ) y = (y_dot_int32 * x_scales.reshape(-1, 1) * w_scales).reshape( *x_vals_int8.shape[:-1], y_dot_int32.shape[-1] ) @@ -89,7 +114,9 @@ def sparse_quant_int8_cutlass_matmul( return y -class Int8DynamicallyQuantized24CusparseltLinearWeight(Int8DynamicallyQuantizedLinearWeight): +class Int8DynamicallyQuantized24CusparseltLinearWeight( + Int8DynamicallyQuantizedLinearWeight +): @staticmethod def _quantized_op(act_mat, w_qtensor, bias): @@ 
-110,7 +137,11 @@ def from_float(cls, input_float, qmin=-8, qmax=7): int_data = torch._cslt_compress(int_data) return cls( - int_data, w_scales, False, input_float.shape, dtype=input_float.dtype, + int_data, + w_scales, + False, + input_float.shape, + dtype=input_float.dtype, ) @@ -166,7 +197,7 @@ def _apply_fn_to_data(self, fn): fn(self.q_scales), self.transposed, self.shape, - dtype=self.dtype + dtype=self.dtype, ) def _change_shape(self, shape): @@ -176,23 +207,42 @@ def _change_shape(self, shape): self.q_scales, self.transposed, shape, - dtype=self.dtype + dtype=self.dtype, ) def __tensor_flatten__(self): - return ["int_data", "mask_meta", "q_scales"], [self.transposed, self.dtype, self.shape] + return ["int_data", "mask_meta", "q_scales"], [ + self.transposed, + self.dtype, + self.shape, + ] @classmethod - def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None): + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None + ): int_data, q_scales = tensor_data_dict["int_data"], tensor_data_dict["q_scales"] mask_meta = tensor_data_dict["mask_meta"] transposed, dtype, shape = tensor_attributes - return cls(int_data, mask_meta, q_scales, transposed, shape if outer_size is None else outer_size, dtype=dtype, strides=outer_stride) + return cls( + int_data, + mask_meta, + q_scales, + transposed, + shape if outer_size is None else outer_size, + dtype=dtype, + strides=outer_stride, + ) @staticmethod def _quantized_op(act_mat, w_qtensor, bias): return sparse_quant_int8_dynamic_linear( - act_mat, w_qtensor.int_data, w_qtensor.mask_meta, w_qtensor.q_scales, bias, act_mat.dtype + act_mat, + w_qtensor.int_data, + w_qtensor.mask_meta, + w_qtensor.q_scales, + bias, + act_mat.dtype, ) @classmethod @@ -213,5 +263,5 @@ def from_float(cls, input_float, qmin=-128, qmax=127): w_scales, False, input_float.shape, - dtype=input_float.dtype + dtype=input_float.dtype, ) diff --git a/torchao/sparsity/sparse_api.py b/torchao/sparsity/sparse_api.py index fdf8e35ee8..a918f6e096 100644 --- a/torchao/sparsity/sparse_api.py +++ b/torchao/sparsity/sparse_api.py @@ -1,11 +1,23 @@ - import torch -from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor, SparseSemiStructuredTensorCUTLASS, SparseSemiStructuredTensorCUSPARSELT -from torchao.sparsity.dynamic_quant_sparse import Int8DynamicallyQuantized24CusparseltLinearWeight, Int8DynamicallyQuantized24CutlassLinearWeight - -from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter, _get_subclass_inserter, _is_linear +from torch.sparse import ( + to_sparse_semi_structured, + SparseSemiStructuredTensor, + SparseSemiStructuredTensorCUTLASS, + SparseSemiStructuredTensorCUSPARSELT, +) +from torchao.sparsity.dynamic_quant_sparse import ( + Int8DynamicallyQuantized24CusparseltLinearWeight, + Int8DynamicallyQuantized24CutlassLinearWeight, +) + +from torchao.quantization.quant_api import ( + _replace_with_custom_fn_if_matches_filter, + _get_subclass_inserter, + _is_linear, +) from torch.ao.pruning import WeightNormSparsifier + # Sparsity helper functions def apply_fake_sparsity(model): """ @@ -18,14 +30,15 @@ def apply_fake_sparsity(model): if isinstance(mod, torch.nn.Linear): sparse_config.append({"tensor_fqn": f"{name}.weight"}) - sparsifier = WeightNormSparsifier(sparsity_level=1.0, - sparse_block_shape=(1,4), - zeros_per_block=2) + sparsifier = WeightNormSparsifier( + sparsity_level=1.0, sparse_block_shape=(1, 4), zeros_per_block=2 + ) 
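The sparsifier configuration above (1x4 blocks with two zeros per block) is what produces the 2:4 semi-structured pattern the sparse kernels require. A quick structural check could look like the following sketch, which assumes a 2-D weight whose row length is divisible by 4:

    import torch

    def is_24_sparse(w: torch.Tensor) -> bool:
        # at most two nonzeros in every group of four consecutive elements along a row
        groups = w.reshape(w.shape[0], -1, 4)
        return bool((groups.ne(0).sum(dim=-1) <= 2).all())
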
sparsifier.prepare(model, sparse_config) sparsifier.step() sparsifier.squash_mask() -def apply_sparse(model, **kwargs): + +def apply_sparse_semi_structured(model, **kwargs): filter_fn = kwargs.pop("filter_fn", _is_linear) apply_fake_sparsity(model) @@ -34,9 +47,8 @@ def apply_sparse(model, **kwargs): mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mod.weight)) -def change_linear_weights_to_int8_dq_24_sparsetensors(model, **kwargs): +def change_linear_weights_to_int8_dq_semi_structured_sparsetensors(model, **kwargs): filter_fn = kwargs.pop("filter_fn", _is_linear) - use_experimental = kwargs.pop("use_experimental", False) if SparseSemiStructuredTensor._FORCE_CUTLASS: subclass = Int8DynamicallyQuantized24CutlassLinearWeight From 0ce93c851d164f26aa309d5feacfe5ef4c8db06a Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Mon, 1 Apr 2024 16:19:19 -0700 Subject: [PATCH 13/26] update --- benchmarks/benchmark_sam.py | 68 +++++++++----------- benchmarks/sam_benchmark_results.csv | 8 ++- torchao/sparsity/dynamic_quant_sparse.py | 82 ++++++++++++++++++++++-- 3 files changed, 109 insertions(+), 49 deletions(-) diff --git a/benchmarks/benchmark_sam.py b/benchmarks/benchmark_sam.py index aec5028b55..0174f95770 100644 --- a/benchmarks/benchmark_sam.py +++ b/benchmarks/benchmark_sam.py @@ -4,7 +4,7 @@ from torchao.quantization import change_linear_weights_to_int8_dqtensors from torchao.sparsity import change_linear_weights_to_int8_dq_semi_structured_sparsetensors from torchao.sparsity.sparse_api import apply_sparse_semi_structured, apply_fake_sparsity -from torchao.sparsity.dynamic_quant_sparse import Int8DynamicallyQuantized24CusparseltLinearWeight, Int8DynamicallyQuantized24CutlassLinearWeight +from torchao.sparsity.dynamic_quant_sparse import Int8DynamicallyQuantized24CusparseltLinearWeight, Int8DynamicallyQuantized24CutlassLinearWeight, Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight from segment_anything import sam_model_registry from torch.utils.benchmark import Timer from torch.sparse import SparseSemiStructuredTensor, SparseSemiStructuredTensorCUTLASS, SparseSemiStructuredTensorCUSPARSELT @@ -65,16 +65,17 @@ def lin2_only(mod, name): return isinstance(mod, torch.nn.Linear) and 'lin2' in name SUBCLASSES = { - # "baseline": None, - "quant": Int8DynamicallyQuantizedLinearWeight, - "sparse (cutlass)": SparseSemiStructuredTensorCUTLASS, - "sparse (cusparselt)": SparseSemiStructuredTensorCUSPARSELT, - "quant+sparse (cusparselt fuse one mul)": Int8DynamicallyQuantized24CusparseltLinearWeight, - "quant+sparse (cutlass)": Int8DynamicallyQuantized24CutlassLinearWeight, + "quant" : Int8DynamicallyQuantizedLinearWeight, + "quant+sparse (cutlass)" : Int8DynamicallyQuantized24CutlassLinearWeight, + "quant+sparse (cusparselt)" : Int8DynamicallyQuantized24CusparseltLinearWeight, + "quant+sparse (cusparselt fuse mul)" : Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight, + # "sparse (cutlass)" : SparseSemiStructuredTensorCUTLASS, + # "sparse (cusparselt)" : SparseSemiStructuredTensorCUSPARSELT, } -def run_once(dtype=torch.bfloat16, batchsize=32, compile=True, qkv=None, proj=None, lin1=None, lin2=None): +def run_once(block_only=False, dtype=torch.bfloat16, batchsize=32, compile=True, qkv=None, proj=None, lin1=None, lin2=None): res = { + "block_only": block_only, "batchsize": batchsize, "dtype": dtype, "compile": compile, @@ -84,56 +85,45 @@ def run_once(dtype=torch.bfloat16, batchsize=32, compile=True, qkv=None, proj=No "lin2": lin2, } with torch.no_grad(): - model, image = get_sam_model(False, 
batchsize) + model, image = get_sam_model(block_only, batchsize) model = model.to(dtype) image = image.to(dtype) # 2:4 prune model apply_fake_sparsity(model) - option_and_filter_fn = zip([qkv, proj, lin1, lin2], - [qkv_only, proj_only, lin1_only, lin2_only]) + option_and_filter_fn = zip([qkv, proj, lin1, lin2], [qkv_only, proj_only, lin1_only, lin2_only]) for option, filter_fn in option_and_filter_fn: - if option: - subclass = SUBCLASSES[option] - if issubclass(subclass, SparseSemiStructuredTensor): - for name, mod in model.named_modules(): - if filter_fn(mod, name): - mod.weight = torch.nn.Parameter(subclass.from_dense(mod.weight)) - # replace with to_sparse_semi_structured - elif issubclass(subclass, QuantizedLinearWeightBase): - _replace_with_custom_fn_if_matches_filter(model, _get_subclass_inserter(subclass), filter_fn) - - # for name, mod in model.named_modules(): - # if isinstance(mod, torch.nn.Linear): - # print(name, mod.weight.data.__class__.__name__) + subclass = SUBCLASSES.get(option, None) + if subclass and issubclass(subclass, SparseSemiStructuredTensor): + # replace with to_sparse_semi_structured + for name, mod in model.named_modules(): + if filter_fn(mod, name): + mod.weight = torch.nn.Parameter(subclass.from_dense(mod.weight)) + elif subclass and issubclass(subclass, QuantizedLinearWeightBase): + _replace_with_custom_fn_if_matches_filter(model, _get_subclass_inserter(subclass), filter_fn) if compile: model = torch.compile(model, mode='max-autotune') res.update(benchmark(model, image)) - pprint(res) + res["img/s"] = 1 / (res['time'] / 1000 / res['batchsize']) return res if __name__ == "__main__": print("BENCHMARKING") - # ALL_RUNS = [run_once(qkv=qkv, proj=proj, lin1=lin1, lin2=lin2) - # for (qkv, proj, lin1, lin2) - # in tqdm(list(product(SUBCLASSES, SUBCLASSES, SUBCLASSES, SUBCLASSES)))] + # ALL_RUNS = [run_once(qkv="quant+sparse (cusparselt)", proj="quant", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)")] + # for option in tqdm(SUBCLASSES)] ALL_RUNS = [ run_once(), - run_once(qkv="quant", proj="quant", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)"), - # run_once(qkv="sparse (cusparselt)", proj="sparse (cusparselt)", lin1="quant+sparse (cusparselt fuse one mul)", lin2="quant+sparse (cutlass)"), - # run_once(qkv="sparse (cusparselt)", proj="sparse (cusparselt)", lin1="quant+sparse (cusparselt fuse one mul)", lin2="quant+sparse (cusparselt fuse one mul)"), - run_once(qkv="sparse (cusparselt)", proj="sparse (cusparselt)", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)"), + run_once(qkv="quant", proj="quant", lin1="quant", lin2="quant"), + run_once(qkv="quant+sparse (cusparselt)", proj="quant+sparse (cusparselt)", lin1="quant+sparse (cusparselt)", lin2="quant+sparse (cutlass)"), + run_once(qkv="quant+sparse (cusparselt)", proj="quant", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)"), + run_once(qkv="quant", proj="quant", lin1="quant+sparse (cusparselt)", lin2="quant+sparse (cusparselt)"), + run_once(qkv="sparse (cusparselt)", proj="sparse (cusparselt)", lin1="sparse (cusparselt)", lin2="sparse (cusparselt)"), + run_once(qkv="sparse (cutlass)", proj="sparse (cutlass)", lin1="sparse (cutlass)", lin2="sparse (cutlass)"), + run_once(qkv="quant+sparse (cutlass)", proj="quant+sparse (cutlass)", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)"), ] df = pd.DataFrame(ALL_RUNS) df.to_csv("sam_benchmark_results.csv") print(df) - - -# optimal order goes -# qkv - cusparselt -# proj - cusparselt -# lin1 - int8 cusparselt / 
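The img/s column added in run_once above is just the per-batch latency (reported in milliseconds) converted to throughput. For example, the baseline row in the results that follow works out as:

    # time is the benchmarked latency for one batch of 32 images, in ms
    time_ms, batchsize = 1457.04, 32
    img_per_s = batchsize / (time_ms / 1000)   # ~21.96 img/s, matching the CSV
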
cutlass -# lin2 - cutlass diff --git a/benchmarks/sam_benchmark_results.csv b/benchmarks/sam_benchmark_results.csv index 5f64ed9a1d..7cfb27fafc 100644 --- a/benchmarks/sam_benchmark_results.csv +++ b/benchmarks/sam_benchmark_results.csv @@ -1,3 +1,5 @@ -,batchsize,dtype,compile,qkv,proj,lin1,lin2,time,memory -0,32,torch.bfloat16,True,quant,quant,quant+sparse (cutlass),quant+sparse (cutlass),32.57098700851202,5.451383808 -1,32,torch.bfloat16,True,sparse (cusparselt),sparse (cusparselt),quant+sparse (cutlass),quant+sparse (cutlass),32.90631342679262,5.452028928 +,block_only,batchsize,dtype,compile,qkv,proj,lin1,lin2,time,memory,img/s +0,False,32,torch.bfloat16,True,,,,,1457.0417301729321,28.280423936,21.96230851686177 +1,False,32,torch.bfloat16,True,quant,quant,quant,quant,1318.5919532552361,28.261341696,24.268311300551254 +2,False,32,torch.bfloat16,True,quant+sparse (cusparselt),quant,quant+sparse (cutlass),quant+sparse (cutlass),1253.1237555667758,28.18694656,25.536184960061433 +3,False,32,torch.bfloat16,True,quant+sparse (cutlass),quant+sparse (cutlass),quant+sparse (cutlass),quant+sparse (cutlass),1290.4946617782116,27.837008896,24.796693041648258 diff --git a/torchao/sparsity/dynamic_quant_sparse.py b/torchao/sparsity/dynamic_quant_sparse.py index 5b280fb32e..a8c0e06fa9 100644 --- a/torchao/sparsity/dynamic_quant_sparse.py +++ b/torchao/sparsity/dynamic_quant_sparse.py @@ -16,7 +16,6 @@ from torch.sparse import ( SparseSemiStructuredTensor, SparseSemiStructuredTensorCUTLASS, - to_sparse_semi_structured, ) # Qunt + Sparse helper functinos @@ -29,22 +28,60 @@ def sparse_quant_int8_dynamic_linear( w_scales : torch.Tensor, bias : Optional[torch.Tensor], out_dtype : torch.dtype, + fuse_mul=False, ): x_vals_int8, x_scales = quantize_activation_per_token_absmax(x) # w_meta_int32 is either None or meta tensor if w_meta_int32 is None: - mm_out = sparse_quant_int8_cslt_matmul( - x_vals_int8, x_scales, w_vals_int8_packed, w_scales, out_dtype - ) + if fuse_mul: + mm_out = sparse_quant_int8_cslt_matmul_fuse_mul( + x_vals_int8, x_scales, w_vals_int8_packed, w_scales, out_dtype, + ) + else: + mm_out = sparse_quant_int8_cslt_matmul( + x_vals_int8, x_scales, w_vals_int8_packed, w_scales, out_dtype, + ) else: mm_out = sparse_quant_int8_cutlass_matmul( - x_vals_int8, x_scales, w_vals_int8_packed, w_meta_int32, w_scales, out_dtype + x_vals_int8, x_scales, w_vals_int8_packed, w_meta_int32, w_scales, out_dtype, ) if bias is not None: mm_out += bias return mm_out +def sparse_quant_int8_cslt_matmul_fuse_mul( + x_vals_int8, + x_scales, + w_vals_int8, + w_scales, + out_dtype, +): + + assert ( + x_vals_int8.dtype == torch.int8 + ), f"x dtype {x_vals_int8.dtype} not yet supported" + assert ( + w_vals_int8.dtype == torch.int8 + ), f"w dtype {w_vals_int8.dtype} not yet supported" + # assert w_scales.dtype == out_dtype, f'{w_scales.dtype} does not match {out_dtype}' + + tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]).contiguous() + + assert x_scales.dtype in [ + torch.float, + torch.bfloat16, + ], f"x_scales needs to be a torch.float32 or torch.bfloat16 but got {x_scales.dtype}" + + y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm( + w_vals_int8, tmp.t(), alpha=w_scales.to(torch.float32), out_dtype=torch.bfloat16 + ).t() + y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape( + *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1] + ) + y = y.to(out_dtype) + + return y def sparse_quant_int8_cslt_matmul( x_vals_int8, @@ -70,9 +107,9 @@ def sparse_quant_int8_cslt_matmul( ], f"x_scales needs to 
be a torch.float32 or torch.bfloat16 but got {x_scales.dtype}" y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm( - w_vals_int8, tmp.t(), alpha=w_scales, out_dtype=torch.bfloat16 + w_vals_int8, tmp.t(), out_dtype=torch.bfloat16 ).t() - y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape( + y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1) * w_scales).reshape( *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1] ) y = y.to(out_dtype) @@ -144,6 +181,37 @@ def from_float(cls, input_float, qmin=-8, qmax=7): dtype=input_float.dtype, ) +class Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight( + Int8DynamicallyQuantizedLinearWeight +): + + @staticmethod + def _quantized_op(act_mat, w_qtensor, bias): + return sparse_quant_int8_dynamic_linear( + act_mat, w_qtensor.int_data, None, w_qtensor.q_scales, bias, act_mat.dtype, + fuse_mul=True + ) + + @classmethod + def from_float(cls, input_float, qmin=-8, qmax=7): + + assert input_float.is_cuda + + w_int_repr, w_scales, _ = dynamically_quantize_per_channel( + input_float, qmin, qmax, torch.int8 + ) + + int_data = w_int_repr.contiguous() + int_data = torch._cslt_compress(int_data) + + return cls( + int_data, + w_scales, + False, + input_float.shape, + dtype=input_float.dtype, + ) + class Int8DynamicallyQuantized24CutlassLinearWeight(QuantizedLinearWeightBase): From a7d535933bed2fb8b46aaf38663960646981b51a Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Mon, 1 Apr 2024 16:31:40 -0700 Subject: [PATCH 14/26] clean up imports --- benchmarks/benchmark_sam.py | 17 +++++++----- torchao/sparsity/__init__.py | 12 +++++++-- torchao/sparsity/dynamic_quant_sparse.py | 4 +-- torchao/sparsity/sparse_api.py | 34 ++---------------------- 4 files changed, 23 insertions(+), 44 deletions(-) diff --git a/benchmarks/benchmark_sam.py b/benchmarks/benchmark_sam.py index 0174f95770..ac8ae02a3e 100644 --- a/benchmarks/benchmark_sam.py +++ b/benchmarks/benchmark_sam.py @@ -1,10 +1,5 @@ -from pprint import pprint import pandas as pd import torch -from torchao.quantization import change_linear_weights_to_int8_dqtensors -from torchao.sparsity import change_linear_weights_to_int8_dq_semi_structured_sparsetensors -from torchao.sparsity.sparse_api import apply_sparse_semi_structured, apply_fake_sparsity -from torchao.sparsity.dynamic_quant_sparse import Int8DynamicallyQuantized24CusparseltLinearWeight, Int8DynamicallyQuantized24CutlassLinearWeight, Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight from segment_anything import sam_model_registry from torch.utils.benchmark import Timer from torch.sparse import SparseSemiStructuredTensor, SparseSemiStructuredTensorCUTLASS, SparseSemiStructuredTensorCUSPARSELT @@ -15,6 +10,14 @@ QuantizedLinearWeightBase, Int8DynamicallyQuantizedLinearWeight, ) +from torchao.quantization import change_linear_weights_to_int8_dqtensors +from torchao.sparsity import ( + apply_sparse_semi_structured, + apply_fake_sparsity, + Int8DynamicallyQuantized24CusparseltLinearWeight, + Int8DynamicallyQuantized24CutlassLinearWeight, + Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight, +) from itertools import product from tqdm import tqdm @@ -69,8 +72,8 @@ def lin2_only(mod, name): "quant+sparse (cutlass)" : Int8DynamicallyQuantized24CutlassLinearWeight, "quant+sparse (cusparselt)" : Int8DynamicallyQuantized24CusparseltLinearWeight, "quant+sparse (cusparselt fuse mul)" : Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight, - # "sparse (cutlass)" : SparseSemiStructuredTensorCUTLASS, - # "sparse (cusparselt)" : 
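The fuse-mul variant introduced above folds the per-channel weight scale into the alpha argument of torch._cslt_sparse_mm instead of applying it in a separate elementwise multiply; the two paths are algebraically the same. Its from_float quantizes per output channel into the narrow [-8, 7] range (int4-range values stored in an int8 tensor) before cuSPARSELt compression. A symmetric per-channel sketch, illustrative rather than the torchao dynamically_quantize_per_channel primitive:

    import torch

    def quantize_per_channel_sketch(w: torch.Tensor, qmin: int = -8, qmax: int = 7):
        # one scale per output channel (per row of the weight)
        w_scales = w.abs().amax(dim=1) / qmax
        w_int = torch.round(w / w_scales.unsqueeze(1)).clamp(qmin, qmax).to(torch.int8)
        return w_int, w_scales
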
SparseSemiStructuredTensorCUSPARSELT, + "sparse (cutlass)" : SparseSemiStructuredTensorCUTLASS, + "sparse (cusparselt)" : SparseSemiStructuredTensorCUSPARSELT, } def run_once(block_only=False, dtype=torch.bfloat16, batchsize=32, compile=True, qkv=None, proj=None, lin1=None, lin2=None): diff --git a/torchao/sparsity/__init__.py b/torchao/sparsity/__init__.py index f8ceee7b9b..3540f0b226 100644 --- a/torchao/sparsity/__init__.py +++ b/torchao/sparsity/__init__.py @@ -6,11 +6,19 @@ from .wanda import WandaSparsifier # noqa: F403 from .utils import PerChannelNormObserver # noqa: F403 -from .sparse_api import change_linear_weights_to_int8_dq_semi_structured_sparsetensors, apply_sparse_semi_structured +from .sparse_api import apply_sparse_semi_structured, apply_fake_sparsity +from .dynamic_quant_sparse import ( + Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight, + Int8DynamicallyQuantized24CusparseltLinearWeight, + Int8DynamicallyQuantized24CutlassLinearWeight +) __all__ = [ "WandaSparsifier", "PerChannelNormObserver", "apply_sparse_semi_structured", - "change_linear_weights_to_int8_dq_semi_structured_sparsetensors", + "apply_fake_sparsity", + "Int8DynamicallyQuantizedCusparseltLinearFuseMulWeight", + "Int8DynamicallyQuantizedCusparseltLinearWeight", + "Int8DynamicallyQuantizedCutlassLinearWeight", ] diff --git a/torchao/sparsity/dynamic_quant_sparse.py b/torchao/sparsity/dynamic_quant_sparse.py index a8c0e06fa9..df2457c2ab 100644 --- a/torchao/sparsity/dynamic_quant_sparse.py +++ b/torchao/sparsity/dynamic_quant_sparse.py @@ -18,9 +18,7 @@ SparseSemiStructuredTensorCUTLASS, ) -# Qunt + Sparse helper functinos - - +# Quant + Sparse helper functinos def sparse_quant_int8_dynamic_linear( x : torch.Tensor, w_vals_int8_packed : torch.Tensor, diff --git a/torchao/sparsity/sparse_api.py b/torchao/sparsity/sparse_api.py index a918f6e096..90e35a4121 100644 --- a/torchao/sparsity/sparse_api.py +++ b/torchao/sparsity/sparse_api.py @@ -1,22 +1,7 @@ import torch -from torch.sparse import ( - to_sparse_semi_structured, - SparseSemiStructuredTensor, - SparseSemiStructuredTensorCUTLASS, - SparseSemiStructuredTensorCUSPARSELT, -) -from torchao.sparsity.dynamic_quant_sparse import ( - Int8DynamicallyQuantized24CusparseltLinearWeight, - Int8DynamicallyQuantized24CutlassLinearWeight, -) - -from torchao.quantization.quant_api import ( - _replace_with_custom_fn_if_matches_filter, - _get_subclass_inserter, - _is_linear, -) from torch.ao.pruning import WeightNormSparsifier - +from torch.sparse import to_sparse_semi_structured +from torchao.quantization.quant_api import _is_linear # Sparsity helper functions def apply_fake_sparsity(model): @@ -45,18 +30,3 @@ def apply_sparse_semi_structured(model, **kwargs): for name, mod in model.named_modules(): if filter_fn(mod, name): mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mod.weight)) - - -def change_linear_weights_to_int8_dq_semi_structured_sparsetensors(model, **kwargs): - filter_fn = kwargs.pop("filter_fn", _is_linear) - - if SparseSemiStructuredTensor._FORCE_CUTLASS: - subclass = Int8DynamicallyQuantized24CutlassLinearWeight - else: - subclass = Int8DynamicallyQuantized24CusparseltLinearWeight - - _replace_with_custom_fn_if_matches_filter( - model, - _get_subclass_inserter(subclass, **kwargs), - filter_fn, - ) From 3b8969638a92044b63530f2c454033d51de3b1a4 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Mon, 22 Apr 2024 16:48:07 -0700 Subject: [PATCH 15/26] updated files --- benchmarks/benchmark_sam.py | 30 +++++++++--------- sam_benchmark_results.csv | 
2 ++ torchao/sparsity/__init__.py | 6 ++-- torchao/sparsity/dynamic_quant_sparse.py | 40 ++---------------------- 4 files changed, 21 insertions(+), 57 deletions(-) create mode 100644 sam_benchmark_results.csv diff --git a/benchmarks/benchmark_sam.py b/benchmarks/benchmark_sam.py index ac8ae02a3e..33dab865fa 100644 --- a/benchmarks/benchmark_sam.py +++ b/benchmarks/benchmark_sam.py @@ -14,8 +14,7 @@ from torchao.sparsity import ( apply_sparse_semi_structured, apply_fake_sparsity, - Int8DynamicallyQuantized24CusparseltLinearWeight, - Int8DynamicallyQuantized24CutlassLinearWeight, + Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight, Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight, ) from itertools import product @@ -69,9 +68,8 @@ def lin2_only(mod, name): SUBCLASSES = { "quant" : Int8DynamicallyQuantizedLinearWeight, - "quant+sparse (cutlass)" : Int8DynamicallyQuantized24CutlassLinearWeight, - "quant+sparse (cusparselt)" : Int8DynamicallyQuantized24CusparseltLinearWeight, - "quant+sparse (cusparselt fuse mul)" : Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight, + "quant+sparse (cutlass)" : Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight, + "quant+sparse (cusparselt)" : Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight, "sparse (cutlass)" : SparseSemiStructuredTensorCUTLASS, "sparse (cusparselt)" : SparseSemiStructuredTensorCUSPARSELT, } @@ -115,18 +113,18 @@ def run_once(block_only=False, dtype=torch.bfloat16, batchsize=32, compile=True, if __name__ == "__main__": print("BENCHMARKING") - # ALL_RUNS = [run_once(qkv="quant+sparse (cusparselt)", proj="quant", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)")] + ALL_RUNS = [run_once(qkv="quant+sparse (cutlass)", proj="quant", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)")] # for option in tqdm(SUBCLASSES)] - ALL_RUNS = [ - run_once(), - run_once(qkv="quant", proj="quant", lin1="quant", lin2="quant"), - run_once(qkv="quant+sparse (cusparselt)", proj="quant+sparse (cusparselt)", lin1="quant+sparse (cusparselt)", lin2="quant+sparse (cutlass)"), - run_once(qkv="quant+sparse (cusparselt)", proj="quant", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)"), - run_once(qkv="quant", proj="quant", lin1="quant+sparse (cusparselt)", lin2="quant+sparse (cusparselt)"), - run_once(qkv="sparse (cusparselt)", proj="sparse (cusparselt)", lin1="sparse (cusparselt)", lin2="sparse (cusparselt)"), - run_once(qkv="sparse (cutlass)", proj="sparse (cutlass)", lin1="sparse (cutlass)", lin2="sparse (cutlass)"), - run_once(qkv="quant+sparse (cutlass)", proj="quant+sparse (cutlass)", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)"), - ] + # ALL_RUNS = [ + # run_once(), + # run_once(qkv="quant", proj="quant", lin1="quant", lin2="quant"), + # run_once(qkv="quant+sparse (cusparselt)", proj="quant+sparse (cusparselt)", lin1="quant+sparse (cusparselt)", lin2="quant+sparse (cutlass)"), + # run_once(qkv="quant+sparse (cusparselt)", proj="quant", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)"), + # run_once(qkv="quant", proj="quant", lin1="quant+sparse (cusparselt)", lin2="quant+sparse (cusparselt)"), + # run_once(qkv="sparse (cusparselt)", proj="sparse (cusparselt)", lin1="sparse (cusparselt)", lin2="sparse (cusparselt)"), + # run_once(qkv="sparse (cutlass)", proj="sparse (cutlass)", lin1="sparse (cutlass)", lin2="sparse (cutlass)"), + # run_once(qkv="quant+sparse (cutlass)", proj="quant+sparse (cutlass)", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)"), + # ] df = 
pd.DataFrame(ALL_RUNS) df.to_csv("sam_benchmark_results.csv") print(df) diff --git a/sam_benchmark_results.csv b/sam_benchmark_results.csv new file mode 100644 index 0000000000..af34848157 --- /dev/null +++ b/sam_benchmark_results.csv @@ -0,0 +1,2 @@ +,block_only,batchsize,dtype,compile,qkv,proj,lin1,lin2,time,memory,img/s +0,False,32,torch.bfloat16,True,quant+sparse (cutlass),quant,quant+sparse (cutlass),quant+sparse (cutlass),1269.7276677936316,28.195974144,25.202254634338605 diff --git a/torchao/sparsity/__init__.py b/torchao/sparsity/__init__.py index 3540f0b226..0576224ea8 100644 --- a/torchao/sparsity/__init__.py +++ b/torchao/sparsity/__init__.py @@ -8,9 +8,8 @@ from .utils import PerChannelNormObserver # noqa: F403 from .sparse_api import apply_sparse_semi_structured, apply_fake_sparsity from .dynamic_quant_sparse import ( + Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight, Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight, - Int8DynamicallyQuantized24CusparseltLinearWeight, - Int8DynamicallyQuantized24CutlassLinearWeight ) __all__ = [ @@ -19,6 +18,5 @@ "apply_sparse_semi_structured", "apply_fake_sparsity", "Int8DynamicallyQuantizedCusparseltLinearFuseMulWeight", - "Int8DynamicallyQuantizedCusparseltLinearWeight", - "Int8DynamicallyQuantizedCutlassLinearWeight", + "Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight", ] diff --git a/torchao/sparsity/dynamic_quant_sparse.py b/torchao/sparsity/dynamic_quant_sparse.py index df2457c2ab..f6a79bf052 100644 --- a/torchao/sparsity/dynamic_quant_sparse.py +++ b/torchao/sparsity/dynamic_quant_sparse.py @@ -13,10 +13,7 @@ QuantizedLinearWeightBase, ) -from torch.sparse import ( - SparseSemiStructuredTensor, - SparseSemiStructuredTensorCUTLASS, -) +from torch.sparse import to_sparse_semi_structured # Quant + Sparse helper functinos def sparse_quant_int8_dynamic_linear( @@ -148,37 +145,6 @@ def sparse_quant_int8_cutlass_matmul( y = y.to(out_dtype) return y - -class Int8DynamicallyQuantized24CusparseltLinearWeight( - Int8DynamicallyQuantizedLinearWeight -): - - @staticmethod - def _quantized_op(act_mat, w_qtensor, bias): - return sparse_quant_int8_dynamic_linear( - act_mat, w_qtensor.int_data, None, w_qtensor.q_scales, bias, act_mat.dtype - ) - - @classmethod - def from_float(cls, input_float, qmin=-8, qmax=7): - - assert input_float.is_cuda - - w_int_repr, w_scales, _ = dynamically_quantize_per_channel( - input_float, qmin, qmax, torch.int8 - ) - - int_data = w_int_repr.contiguous() - int_data = torch._cslt_compress(int_data) - - return cls( - int_data, - w_scales, - False, - input_float.shape, - dtype=input_float.dtype, - ) - class Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight( Int8DynamicallyQuantizedLinearWeight ): @@ -211,7 +177,7 @@ def from_float(cls, input_float, qmin=-8, qmax=7): ) -class Int8DynamicallyQuantized24CutlassLinearWeight(QuantizedLinearWeightBase): +class Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight(QuantizedLinearWeightBase): @staticmethod def __new__(cls, int_data, mask_meta, q_scales, transposed, shape, **kwargs): @@ -321,7 +287,7 @@ def from_float(cls, input_float, qmin=-128, qmax=127): ) int_data = w_int_repr.contiguous() - sparse_tensor = SparseSemiStructuredTensorCUTLASS.from_dense(int_data) + sparse_tensor = to_sparse_semi_structured(int_data) return cls( sparse_tensor.packed, From a7b4f8bd2fe4df1f98269a8edf17bda8b686cede Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Mon, 22 Apr 2024 16:57:55 -0700 Subject: [PATCH 16/26] remove file --- sam_benchmark_results.csv | 2 -- 1 
file changed, 2 deletions(-) delete mode 100644 sam_benchmark_results.csv diff --git a/sam_benchmark_results.csv b/sam_benchmark_results.csv deleted file mode 100644 index af34848157..0000000000 --- a/sam_benchmark_results.csv +++ /dev/null @@ -1,2 +0,0 @@ -,block_only,batchsize,dtype,compile,qkv,proj,lin1,lin2,time,memory,img/s -0,False,32,torch.bfloat16,True,quant+sparse (cutlass),quant,quant+sparse (cutlass),quant+sparse (cutlass),1269.7276677936316,28.195974144,25.202254634338605 From dafcd6354f3444bbc58fc5bc3fd4ad622cdbba0f Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Mon, 22 Apr 2024 17:06:54 -0700 Subject: [PATCH 17/26] remove fuse mul API --- benchmarks/benchmark_sam.py | 2 +- torchao/sparsity/__init__.py | 6 +----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/benchmarks/benchmark_sam.py b/benchmarks/benchmark_sam.py index 33dab865fa..e7e6220320 100644 --- a/benchmarks/benchmark_sam.py +++ b/benchmarks/benchmark_sam.py @@ -15,8 +15,8 @@ apply_sparse_semi_structured, apply_fake_sparsity, Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight, - Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight, ) +from torchao.sparsity.dynamic_quant_sparse import Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight from itertools import product from tqdm import tqdm diff --git a/torchao/sparsity/__init__.py b/torchao/sparsity/__init__.py index 0576224ea8..fe4a2008a3 100644 --- a/torchao/sparsity/__init__.py +++ b/torchao/sparsity/__init__.py @@ -7,16 +7,12 @@ from .wanda import WandaSparsifier # noqa: F403 from .utils import PerChannelNormObserver # noqa: F403 from .sparse_api import apply_sparse_semi_structured, apply_fake_sparsity -from .dynamic_quant_sparse import ( - Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight, - Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight, -) +from .dynamic_quant_sparse import Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight __all__ = [ "WandaSparsifier", "PerChannelNormObserver", "apply_sparse_semi_structured", "apply_fake_sparsity", - "Int8DynamicallyQuantizedCusparseltLinearFuseMulWeight", "Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight", ] From 070d7735a83cd3c61bd530816090395fd941ba23 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Mon, 22 Apr 2024 17:55:33 -0700 Subject: [PATCH 18/26] add test --- test/sparsity/test_sparse_api.py | 65 ++++++++++++++++++++++++ torchao/sparsity/dynamic_quant_sparse.py | 3 +- 2 files changed, 67 insertions(+), 1 deletion(-) create mode 100644 test/sparsity/test_sparse_api.py diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py new file mode 100644 index 0000000000..083cd41ab9 --- /dev/null +++ b/test/sparsity/test_sparse_api.py @@ -0,0 +1,65 @@ +import logging +import unittest + +import torch +from torch import nn + +from torchao.sparsity import apply_fake_sparsity, apply_sparse_semi_structured, Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight +from torchao.quantization.quant_api import ( + _replace_with_custom_fn_if_matches_filter, + _get_subclass_inserter, + _is_linear, +) +from torch.testing._internal.common_utils import TestCase + + +logging.basicConfig( + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO +) + +class TestSemiStructuredSparse(TestCase): + + def test_sparse(self): + input = torch.rand((128, 128), device="cuda").half() + model = ( + nn.Sequential( + nn.Linear(128, 256), + nn.Linear(256, 128), + ) + .half() + .cuda() + ) + + apply_fake_sparsity(model) + dense_result = model(input) + + 
apply_sparse_semi_structured(model) + sparse_result = model(input) + + assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3) + + +class TestQuantSemiSparse(TestCase): + + def test_quant_semi_sparse(self): + input = torch.rand((128, 128), device="cuda").half() + model = ( + nn.Sequential( + nn.Linear(128, 256), + nn.Linear(256, 128), + ) + .half() + .cuda() + ) + + apply_fake_sparsity(model) + dense_result = model(input) + + _replace_with_custom_fn_if_matches_filter(model, _get_subclass_inserter(Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight), _is_linear) + sparse_result = model(input) + + assert torch.allclose(dense_result, sparse_result, rtol=1e-1, atol=1e-1) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchao/sparsity/dynamic_quant_sparse.py b/torchao/sparsity/dynamic_quant_sparse.py index f6a79bf052..aced0b82c2 100644 --- a/torchao/sparsity/dynamic_quant_sparse.py +++ b/torchao/sparsity/dynamic_quant_sparse.py @@ -6,6 +6,7 @@ dynamically_quantize_per_channel, quant_int8_dynamic_per_token_linear, quantize_activation_per_token_absmax, + dequantize_per_channel, ) from torchao.quantization.subclass import ( @@ -194,7 +195,7 @@ def dequantize(self, dtype=None): Obtain the dequantized version of the quantized tensor subclass """ dq_t = dequantize_per_channel( - self.int_data.t(), self.q_scales, 0, self.dtype if dtype is None else dtype + self.int_data, self.q_scales, 0, self.dtype if dtype is None else dtype ).to(self.dtype) # data was transposed to dequantize so make sure shape is correct return dq_t if not self.transposed else dq_t.t() From 0a8e226c3d67d23038b1c2b2d16e971241da1ff9 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Tue, 23 Apr 2024 18:49:26 -0700 Subject: [PATCH 19/26] move to prototype --- benchmarks/benchmark_sam.py | 3 +-- torchao/sparsity/__init__.py | 2 -- torchao/sparsity/{ => prototype}/dynamic_quant_sparse.py | 0 3 files changed, 1 insertion(+), 4 deletions(-) rename torchao/sparsity/{ => prototype}/dynamic_quant_sparse.py (100%) diff --git a/benchmarks/benchmark_sam.py b/benchmarks/benchmark_sam.py index e7e6220320..14c2d8bc5e 100644 --- a/benchmarks/benchmark_sam.py +++ b/benchmarks/benchmark_sam.py @@ -14,9 +14,8 @@ from torchao.sparsity import ( apply_sparse_semi_structured, apply_fake_sparsity, - Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight, ) -from torchao.sparsity.dynamic_quant_sparse import Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight +from torchao.sparsity.prototype.dynamic_quant_sparse import Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight, Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight from itertools import product from tqdm import tqdm diff --git a/torchao/sparsity/__init__.py b/torchao/sparsity/__init__.py index fe4a2008a3..6621d086d0 100644 --- a/torchao/sparsity/__init__.py +++ b/torchao/sparsity/__init__.py @@ -7,12 +7,10 @@ from .wanda import WandaSparsifier # noqa: F403 from .utils import PerChannelNormObserver # noqa: F403 from .sparse_api import apply_sparse_semi_structured, apply_fake_sparsity -from .dynamic_quant_sparse import Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight __all__ = [ "WandaSparsifier", "PerChannelNormObserver", "apply_sparse_semi_structured", "apply_fake_sparsity", - "Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight", ] diff --git a/torchao/sparsity/dynamic_quant_sparse.py b/torchao/sparsity/prototype/dynamic_quant_sparse.py similarity index 100% rename from torchao/sparsity/dynamic_quant_sparse.py rename to 
torchao/sparsity/prototype/dynamic_quant_sparse.py From 4e0c8b33daf8ed45a880186a6fc7e4939e2a7606 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Tue, 23 Apr 2024 20:34:19 -0700 Subject: [PATCH 20/26] fix tests --- test/sparsity/test_sparse_api.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py index 083cd41ab9..9d1007004e 100644 --- a/test/sparsity/test_sparse_api.py +++ b/test/sparsity/test_sparse_api.py @@ -4,7 +4,8 @@ import torch from torch import nn -from torchao.sparsity import apply_fake_sparsity, apply_sparse_semi_structured, Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight +from torchao.sparsity import apply_fake_sparsity, apply_sparse_semi_structured +from torchao.sparsity.prototype import Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight from torchao.quantization.quant_api import ( _replace_with_custom_fn_if_matches_filter, _get_subclass_inserter, From db9ca9cf64b0a412e1c88c4c484eb01d25b50037 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Wed, 24 Apr 2024 02:59:32 -0700 Subject: [PATCH 21/26] fix test --- test/sparsity/test_sparse_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py index 9d1007004e..2f2965753a 100644 --- a/test/sparsity/test_sparse_api.py +++ b/test/sparsity/test_sparse_api.py @@ -5,7 +5,7 @@ from torch import nn from torchao.sparsity import apply_fake_sparsity, apply_sparse_semi_structured -from torchao.sparsity.prototype import Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight +from torchao.sparsity.prototype.dynamic_quant_sparse import Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight from torchao.quantization.quant_api import ( _replace_with_custom_fn_if_matches_filter, _get_subclass_inserter, From 7a9b6f9117c7ad29765e2281ae58e6eeb369cbfb Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 25 Apr 2024 11:40:48 -0700 Subject: [PATCH 22/26] added init --- torchao/sparsity/prototype/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 torchao/sparsity/prototype/__init__.py diff --git a/torchao/sparsity/prototype/__init__.py b/torchao/sparsity/prototype/__init__.py new file mode 100644 index 0000000000..e69de29bb2 From fb3beb6b058dc23c34a74bd4967ee1220697f4d2 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 25 Apr 2024 11:59:21 -0700 Subject: [PATCH 23/26] updated test --- test/sparsity/test_sparse_api.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py index 2f2965753a..6c27de00b5 100644 --- a/test/sparsity/test_sparse_api.py +++ b/test/sparsity/test_sparse_api.py @@ -11,6 +11,7 @@ _get_subclass_inserter, _is_linear, ) +from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3 from torch.testing._internal.common_utils import TestCase @@ -20,6 +21,7 @@ class TestSemiStructuredSparse(TestCase): + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_sparse(self): input = torch.rand((128, 128), device="cuda").half() model = ( @@ -42,6 +44,7 @@ def test_sparse(self): class TestQuantSemiSparse(TestCase): + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_quant_semi_sparse(self): input = torch.rand((128, 128), device="cuda").half() model = ( From 1a1edcc53b71130f16d680eda39d69f432f8e5d6 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 25 Apr 2024 12:56:08 -0700 Subject: [PATCH 24/26] fix test --- 
test/sparsity/test_sparse_api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py index 6c27de00b5..1774ff8364 100644 --- a/test/sparsity/test_sparse_api.py +++ b/test/sparsity/test_sparse_api.py @@ -23,7 +23,7 @@ class TestSemiStructuredSparse(TestCase): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_sparse(self): - input = torch.rand((128, 128), device="cuda").half() + input = torch.rand((128, 128)).half().cuda() model = ( nn.Sequential( nn.Linear(128, 256), @@ -46,7 +46,7 @@ class TestQuantSemiSparse(TestCase): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_quant_semi_sparse(self): - input = torch.rand((128, 128), device="cuda").half() + input = torch.rand((128, 128)).half().cuda() model = ( nn.Sequential( nn.Linear(128, 256), From 461a30699faed599250e60a7d50a338cf14bee1c Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 25 Apr 2024 19:38:51 -0700 Subject: [PATCH 25/26] skip on pt 2.2 --- test/sparsity/test_sparse_api.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py index 1774ff8364..afc9119c91 100644 --- a/test/sparsity/test_sparse_api.py +++ b/test/sparsity/test_sparse_api.py @@ -21,6 +21,7 @@ class TestSemiStructuredSparse(TestCase): + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "pytorch 2.3+ feature") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_sparse(self): input = torch.rand((128, 128)).half().cuda() @@ -44,6 +45,7 @@ def test_sparse(self): class TestQuantSemiSparse(TestCase): + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "pytorch 2.3+ feature") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_quant_semi_sparse(self): input = torch.rand((128, 128)).half().cuda() From ddc5dead26204334943f9127a30cd36eb84975d6 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Fri, 26 Apr 2024 08:59:51 -0700 Subject: [PATCH 26/26] typo --- test/sparsity/test_sparse_api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py index afc9119c91..83c0544f6e 100644 --- a/test/sparsity/test_sparse_api.py +++ b/test/sparsity/test_sparse_api.py @@ -21,7 +21,7 @@ class TestSemiStructuredSparse(TestCase): - @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "pytorch 2.3+ feature") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "pytorch 2.3+ feature") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_sparse(self): input = torch.rand((128, 128)).half().cuda() @@ -45,7 +45,7 @@ def test_sparse(self): class TestQuantSemiSparse(TestCase): - @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "pytorch 2.3+ feature") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "pytorch 2.3+ feature") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_quant_semi_sparse(self): input = torch.rand((128, 128)).half().cuda()
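With the CUDA and version gates above in place, the new tests are skipped unless a CUDA device and a recent enough PyTorch are available; when both are present they can be run directly with python test/sparsity/test_sparse_api.py (the file defines a unittest entry point) or through pytest test/sparsity/test_sparse_api.py.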