From 1bacd029c6e01b5812e7aa0654a7cc4c3f168b64 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aleksandar=20Samard=C5=BEi=C4=87?=
Date: Mon, 9 Sep 2024 21:32:06 +0200
Subject: [PATCH] W4A8 based on CUTLASS

---
 setup.py                                      |   7 +
 test/test_s8s4_linear_cutlass.py              |  46 ++
 .../s8s4_linear_cutlass.cu                    | 420 ++++++++++++++++++
 torchao/csrc/s8s4_linear_cutlass.cpp          |   8 +
 torchao/dtypes/affine_quantized_tensor.py     |  52 ++-
 torchao/ops.py                                |  33 ++
 torchao/quantization/quant_api.py             |   8 +-
 7 files changed, 572 insertions(+), 2 deletions(-)
 create mode 100644 test/test_s8s4_linear_cutlass.py
 create mode 100644 torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu
 create mode 100644 torchao/csrc/s8s4_linear_cutlass.cpp

diff --git a/setup.py b/setup.py
index 229e18eec6..400655fdf5 100644
--- a/setup.py
+++ b/setup.py
@@ -65,6 +65,12 @@ def get_extensions():
     extension = CUDAExtension if use_cuda else CppExtension

     if not IS_WINDOWS:
+        import cutlass_library
+        cutlass_library_dir = os.path.dirname(cutlass_library.__file__)
+        cutlass_include_dir = os.path.join(cutlass_library_dir, "source", "include")
+        # FIXME: remove this once CUTLASS package updated to include int4/int8 MM
+        cutlass_include_dir = "/data/quansight/scratch/cutlass/include"
+
         extra_link_args = []
         extra_compile_args = {
             "cxx": [
@@ -74,6 +80,7 @@ def get_extensions():
             "nvcc": [
                 "-O3" if not debug_mode else "-O0",
                 "-t=0",
+                "-I" + cutlass_include_dir,
             ]
         }

diff --git a/test/test_s8s4_linear_cutlass.py b/test/test_s8s4_linear_cutlass.py
new file mode 100644
index 0000000000..78209ba598
--- /dev/null
+++ b/test/test_s8s4_linear_cutlass.py
@@ -0,0 +1,46 @@
+# FIXME: move this test to the appropriate test file!!!
+
+import copy
+
+from torchao.quantization import quantize_
+from torchao.quantization.quant_api import int8_dynamic_activation_int4_weight
+
+import torch
+from torch.testing._internal.common_utils import (
+    TestCase,
+    run_tests,
+)
+
+import pytest
+
+
+class ToyModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = torch.nn.Linear(128, 32)
+
+    def forward(self, x):
+        x = self.linear(x)
+        return x
+
+
+class TestS8S4LinearCUTLASS(TestCase):
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    def test_s8s4_linear_cutlass_(self):
+        # FIXME: remove this!
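+        # The fixed seed keeps the random input below deterministic, so the
+        # loose allclose comparison against the fp16 reference stays stable.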
+        torch.manual_seed(0)
+
+        input = torch.rand((64, 128)).half().cuda()
+        model = ToyModel().half().cuda()
+
+        output_ref = model(input)
+
+        modelq = copy.deepcopy(model)
+        quantize_(modelq, int8_dynamic_activation_int4_weight(group_size=128))
+        output = modelq(input)
+
+        assert torch.allclose(output, output_ref, rtol=1e-1, atol=0)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu b/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu
new file mode 100644
index 0000000000..426b7034b3
--- /dev/null
+++ b/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu
@@ -0,0 +1,420 @@
+#include 
+
+#include 
+#include 
+#include 
+
+#if defined(_MSC_VER) || (CUDA_VERSION < 11080)
+#else
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#define CUTLASS_STATUS_CHECK(status)                      \
+  {                                                       \
+    TORCH_CHECK(status == cutlass::Status::kSuccess,      \
+                __func__, " : Got CUTLASS error: ",       \
+                cutlassGetStatusString(status));          \
+  }
+#endif
+
+namespace torchao {
+
+#if defined(_MSC_VER) || (CUDA_VERSION < 11080)
+#else
+template<
+    typename ElementA,
+    typename ElementAScale,
+    typename ElementB,
+    typename ElementBScale,
+    typename ElementC,
+    typename ElementAccumulator,
+    typename ElementEpilogue,
+    typename ElementOutput,
+    typename ThreadblockShape,
+    typename WarpShape,
+    typename InstructionShape>
+void s8s4_linear_kernel_cutlass(
+    const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
+    const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
+    const at::Tensor& tensor_c, at::Tensor& tensor_d) {
+  using LayoutA = cutlass::layout::RowMajor;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutC = cutlass::layout::RowMajor;
+
+  const int m = tensor_a.size(0);
+  const int n = tensor_b.size(0);
+  const int k = tensor_a.size(1);
+
+  constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
+  constexpr int AlignmentAScale = 128 / cutlass::sizeof_bits<ElementAScale>::value;
+  constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
+  constexpr int AlignmentBScale = 128 / cutlass::sizeof_bits<ElementBScale>::value;
+  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
+  constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;
+
+  // FIXME: re-check this!!!
+  // Check for current CUTLASS limitations w.r.t. alignments.
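+  // (For the element types this kernel is instantiated with at the bottom of
+  //  this file -- int8 A, int4 B, fp16 C/output -- this means k must be a
+  //  multiple of 16 and of 32, respectively, and n a multiple of 8.)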
+  TORCH_CHECK(k % AlignmentA == 0,
+              __func__, " : Number of columns of tensor A must be divisible ",
+              "by ", AlignmentA);
+  TORCH_CHECK(k % AlignmentB == 0,
+              __func__, " : Number of columns of tensor B must be divisible ",
+              "by ", AlignmentB);
+  TORCH_CHECK(n % AlignmentC == 0,
+              __func__, " : Number of columns of tensor C must be divisible ",
+              "by ", AlignmentC);
+
+  using SmArch = cutlass::arch::Sm80;
+  using ThreadblockSwizzle = cutlass::gemm::threadblock::ThreadblockSwizzleStreamK;
+  constexpr auto NumStages = 4;
+
+  constexpr auto NumEVTEpilogueStages = 1;
+
+  using TensorAScaleTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
+      ThreadblockShape,
+      WarpShape,
+      ElementAScale,
+      AlignmentAScale,
+      NumEVTEpilogueStages
+  >;
+  using TensorBScaleTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
+      ThreadblockShape,
+      WarpShape,
+      ElementBScale,
+      AlignmentBScale,
+      NumEVTEpilogueStages
+  >;
+  using TensorCTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
+      ThreadblockShape,
+      WarpShape,
+      ElementC,
+      AlignmentC,
+      NumEVTEpilogueStages
+  >;
+  using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
+      ThreadblockShape,
+      WarpShape,
+      ElementOutput,
+      AlignmentOutput,
+      NumEVTEpilogueStages
+  >;
+
+  using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
+
+  using TensorAScale =
+      cutlass::epilogue::threadblock::VisitorColBroadcast<
+          TensorAScaleTileThreadMap,
+          ElementAScale,
+          cute::Stride<cute::_1, cute::_0, int64_t>>;
+  using TensorAScaleArguments = typename TensorAScale::Arguments;
+
+  using TensorBScale =
+      cutlass::epilogue::threadblock::VisitorRowBroadcast<
+          TensorBScaleTileThreadMap,
+          ElementBScale,
+          cute::Stride<cute::_0, cute::_1, int64_t>>;
+  using TensorBScaleArguments = typename TensorBScale::Arguments;
+
+  using TensorC =
+      cutlass::epilogue::threadblock::VisitorRowBroadcast<
+          TensorCTileThreadMap,
+          ElementC,
+          cute::Stride<cute::_0, cute::_1, int64_t>>;
+  using TensorCArguments = typename TensorC::Arguments;
+
+  using ApplyAScale = cutlass::epilogue::threadblock::VisitorCompute<
+      cutlass::multiplies, ElementEpilogue, ElementEpilogue,
+      cutlass::FloatRoundStyle::round_to_nearest
+  >;
+  using EVTApplyAScale = cutlass::epilogue::threadblock::Sm80EVT<
+      ApplyAScale,
+      Accum,
+      TensorAScale>;
+
+  using ApplyBScale = cutlass::epilogue::threadblock::VisitorCompute<
+      cutlass::multiplies, ElementEpilogue, ElementEpilogue,
+      cutlass::FloatRoundStyle::round_to_nearest
+  >;
+  using EVTApplyBScale = cutlass::epilogue::threadblock::Sm80EVT<
+      ApplyBScale,
+      EVTApplyAScale,
+      TensorBScale>;
+
+  using ApplySum = cutlass::epilogue::threadblock::VisitorCompute<
+      cutlass::plus, ElementEpilogue, ElementEpilogue,
+      cutlass::FloatRoundStyle::round_to_nearest
+  >;
+  using EVTApplySum = cutlass::epilogue::threadblock::Sm80EVT<
+      ApplySum,
+      EVTApplyBScale,
+      TensorC>;
+
+  using Output = cutlass::epilogue::threadblock::VisitorAuxStore<
+      OutputTileThreadMap, ElementOutput, cutlass::FloatRoundStyle::round_to_nearest,
+      cute::Stride<int64_t, cute::_1, int64_t> // StrideMNL
+  >;
+
+  using EVTOutput = cutlass::epilogue::threadblock::Sm80EVT<
+      Output,
+      EVTApplySum>;
+
+  using EVTKernel =
+      typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
+          ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA,
+          ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB,
+          ElementC, LayoutC, AlignmentC,
+          ElementAccumulator,
+          ElementEpilogue,
+          cutlass::arch::OpClassTensorOp,
+          SmArch,
+          ThreadblockShape,
+          WarpShape,
+          InstructionShape,
+          EVTOutput,
+          ThreadblockSwizzle,
+          NumStages,
+          cutlass::arch::OpMultiplyAddMixedInputUpcast,
+          NumEVTEpilogueStages
+      >::GemmKernel;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalBase<EVTKernel>;
+
+  cutlass::gemm::GemmCoord problem_size(m, n, k);
+  constexpr auto SplitKFactor = 1;
+
+  TensorAScaleArguments tensor_a_scale_arguments{
+    (ElementAScale*)tensor_a_scale.data_ptr(),
+    ElementAScale(1),
+    {cute::_1{}, cute::_0{}, problem_size.m()}
+  };
+  TensorBScaleArguments tensor_b_scale_arguments{
+    (ElementBScale*)tensor_b_scale.data_ptr(),
+    ElementBScale(1),
+    {cute::_0{}, cute::_1{}, problem_size.n()}
+  };
+  TensorCArguments tensor_c_arguments{
+    (ElementC*)tensor_c.data_ptr(),
+    ElementC(0),
+    {cute::_0{}, cute::_1{}, problem_size.n()}
+  };
+  typename Output::Arguments output_arguments{
+    (ElementOutput*)tensor_d.data_ptr(),
+    {problem_size.n(), cute::_1{}, problem_size.mn().product()}
+  };
+  typename EVTOutput::Arguments callback_arguments{
+    {
+      {
+        {
+          {},                        // Accum
+          tensor_a_scale_arguments,  // TensorAScale
+          {}                         // ApplyAScale
+        },                           // EVTApplyAScale
+        tensor_b_scale_arguments,    // TensorBScale
+        {},                          // ApplyBScale
+      },                             // EVTApplyBScale
+      tensor_c_arguments,            // TensorC
+      {}                             // ApplySum
+    },                               // EVTApplySum
+    output_arguments                 // Output
+  };                                 // EVTOutput
+  constexpr auto AvailSms = -1;
+
+  typename Gemm::Arguments arguments(
+    cutlass::gemm::GemmUniversalMode::kGemm,
+    problem_size,
+    SplitKFactor,
+    callback_arguments,           // arguments of EVT callbacks
+    (ElementA*)tensor_a.data_ptr(),
+    (ElementB*)tensor_b.data_ptr(),
+    nullptr,                      // ptr C (unused)
+    nullptr,                      // ptr D (unused)
+    problem_size.mk().product(),  // batch stride A
+    problem_size.nk().product(),  // batch stride B
+    0,                            // batch stride C (unused)
+    0,                            // batch stride D (unused)
+    problem_size.k(),             // stride A
+    problem_size.k(),             // stride B
+    0,                            // stride C (unused)
+    0,                            // stride D (unused)
+    AvailSms);
+
+  Gemm gemm_op;
+
+  cutlass::Status status;
+
+  // Verify that GEMM operation with given arguments can be performed
+  // by CUTLASS.
+  status = gemm_op.can_implement(arguments);
+  CUTLASS_STATUS_CHECK(status);
+
+  // Allocate workspace for CUTLASS mixed datatypes GEMM kernel.
+  const auto workspace_size = Gemm::get_workspace_size(arguments);
+  auto workspace = tensor_a.new_empty({(int64_t)workspace_size},
+                                      at::TensorOptions().dtype(at::kByte));
+
+  // Initialize CUTLASS mixed datatypes GEMM object.
+  status = gemm_op.initialize(arguments, workspace.data_ptr(),
+                              at::cuda::getCurrentCUDAStream());
+  CUTLASS_STATUS_CHECK(status);
+
+  // Perform mixed datatypes GEMM operation.
+  status = gemm_op.run(at::cuda::getCurrentCUDAStream());
+  CUTLASS_STATUS_CHECK(status);
+}
+#endif
+
+// Perform linear operation, using the corresponding CUTLASS mixed
+// data-types GEMM kernel, on the given arguments:
+//   result = (input * input_scale) @ (weight * weight_scale).T + bias
+// Notes: The "input_scale" tensor is expected to be a vector, of size
+// equal to the number of rows of the "input" tensor.  The "weight_scale"
+// tensor is expected to be a vector, of size equal to the number of rows
+// of the "weight" tensor.  The "bias" tensor is expected to be a vector,
+// of size equal to the number of rows of the "weight" tensor.
+at::Tensor
+s8s4_linear_cutlass(const at::Tensor& input, const at::Tensor& input_scale,
+                    const at::Tensor& weight, const at::Tensor& weight_scale,
+                    const at::Tensor& bias) {
+#if defined(_MSC_VER) || (CUDA_VERSION < 11080)
+  AT_ERROR(__func__, " : CUTLASS not supported");
+  return at::Tensor{};
+#else
+
+  // For now, only CC 8.x devices are supported.
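+  // (The kernel above is instantiated for cutlass::arch::Sm80 with the
+  //  mixed-input-upcast tensor-op path, which requires compute capability
+  //  8.x; hence the hard check below.)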
+  const auto dprops = at::cuda::getCurrentDeviceProperties();
+  const auto is_sm8x = dprops->major == 8;
+  TORCH_CHECK(is_sm8x,
+              __func__, " : Supported only on GPUs with compute capability "
+              "8.x");
+
+  // Validate datatypes of input tensors.
+  TORCH_CHECK(input.dtype() == at::kChar,
+              __func__, " : The input datatype ", input.dtype(),
+              " not supported");
+  TORCH_CHECK(input_scale.dtype() == at::kHalf,
+              __func__, " : The input scale datatype ", input_scale.dtype(),
+              " not supported");
+  TORCH_CHECK(weight.dtype() == at::kChar,
+              __func__, " : The weight datatype ", weight.dtype(),
+              " not supported");
+  TORCH_CHECK(weight_scale.dtype() == at::kHalf,
+              __func__, " : The weight scale datatype ", weight_scale.dtype(),
+              " not supported");
+  TORCH_CHECK(bias.dtype() == at::kHalf,
+              __func__, " : The bias datatype ", bias.dtype(),
+              " not supported");
+
+  // Validate layouts of input tensors.
+  TORCH_CHECK(input.layout() == at::Layout::Strided,
+              __func__, " : Expected input argument to be strided, got layout ",
+              input.layout());
+  TORCH_CHECK(input.dim() == 2,
+              __func__, " : Expected input argument to be 2D tensor, got ",
+              input.dim(), " dims");
+  const auto input_strides = input.strides();
+  TORCH_CHECK(input_strides[0] >= 1 && input_strides[1] == 1,
+              __func__, " : Invalid strides for input argument: row stride = ",
+              input_strides[0], ", column stride = ", input_strides[1]);
+  TORCH_CHECK(input_scale.layout() == at::Layout::Strided,
+              __func__, " : Expected input scale argument to be strided, got "
+              "layout ", input_scale.layout());
+  TORCH_CHECK(input_scale.dim() == 1,
+              __func__, " : Expected input scale argument to be 1D tensor, ",
+              "got ", input_scale.dim(), " dims");
+  const auto input_scale_strides = input_scale.strides();
+  TORCH_CHECK(input_scale_strides[0] == 1,
+              __func__, " : Invalid strides for input scale argument: element ",
+              "stride = ", input_scale_strides[0]);
+  TORCH_CHECK(weight.layout() == at::Layout::Strided,
+              __func__,
+              " : Expected weight argument to be strided, got layout ",
+              weight.layout());
+  TORCH_CHECK(weight.dim() == 2,
+              __func__, " : Expected weight argument to be 2D tensor, got ",
+              weight.dim(), " dims");
+  const auto weight_strides = weight.strides();
+  TORCH_CHECK(weight_strides[0] >= 1 && weight_strides[1] == 1,
+              __func__, " : Invalid strides for weight argument: row stride = ",
+              weight_strides[0], ", column stride = ", weight_strides[1]);
+  TORCH_CHECK(weight_scale.layout() == at::Layout::Strided,
+              __func__, " : Expected weight scale argument to be strided, got "
+              "layout ", weight_scale.layout());
+  TORCH_CHECK(weight_scale.dim() == 1 || weight_scale.dim() == 2,
+              __func__, " : Expected weight scale argument to be 1D or 2D ",
+              "tensor, got ", weight_scale.dim(), " dims");
+  const auto weight_scale_strides = weight_scale.strides();
+  TORCH_CHECK(weight_scale_strides[0] == 1,
+              __func__, " : Invalid strides for weight scale argument: ",
+              "element stride = ", weight_scale_strides[0]);
+  TORCH_CHECK(bias.layout() == at::Layout::Strided,
+              __func__, " : Expected bias argument to be strided, got layout ",
+              bias.layout());
+  TORCH_CHECK(bias.dim() == 1,
+              __func__, " : Expected bias argument to be 1D tensor, got ",
+              bias.dim(), " dims");
+  const auto bias_strides = bias.strides();
+  TORCH_CHECK(bias_strides[0] == 1,
+              __func__, " : Invalid strides for bias argument: element stride ",
+              "= ", bias_strides[0]);
+
+  // Validate sizes of input tensors.
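+  // (The weight tensor stores two int4 values per int8 byte, so its logical
+  //  number of columns is twice its storage width; hence the factor of 2 in
+  //  the first check below.)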
+  TORCH_CHECK(input.size(1) == 2 * weight.size(1),
+              __func__, " : Expected input argument to have ",
+              2 * weight.size(1), " columns, but got ", input.size(1));
+  TORCH_CHECK(input_scale.numel() == input.size(0),
+              __func__, " : Expected input scale argument to have ",
+              input.size(0), " elements, got ", input_scale.numel(),
+              " elements");
+  TORCH_CHECK(weight_scale.numel() == weight.size(0),
+              __func__, " : Expected weight scale argument to have ",
+              weight.size(0), " elements, got ", weight_scale.numel(),
+              " elements");
+  TORCH_CHECK(bias.numel() == weight.size(0),
+              __func__, " : Expected bias argument to have ", weight.size(0),
+              " elements, got ", bias.numel(), " elements");
+
+  // Introduce alias names for arguments, according to the CUTLASS
+  // naming conventions.
+  const auto& tensor_a = input;
+  const auto& tensor_a_scale = input_scale;
+  const auto& tensor_b = weight;
+  const auto& tensor_b_scale = weight_scale;
+  const auto& tensor_c = bias;
+
+  // Create output tensor.
+  at::Tensor tensor_d =
+      tensor_a_scale.new_empty({tensor_a.size(0), tensor_b.size(0)});
+
+  using ElementA = int8_t;
+  using ElementAScale = cutlass::half_t;
+  using ElementB = cutlass::int4b_t;
+  using ElementBScale = cutlass::half_t;
+  using ElementC = cutlass::half_t;
+  using ElementAccumulator = int32_t;
+  using ElementEpilogue = cutlass::half_t;
+  using ElementOutput = cutlass::half_t;
+  using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
+  using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
+  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
+
+  s8s4_linear_kernel_cutlass<
+      ElementA, ElementAScale, ElementB, ElementBScale, ElementC,
+      ElementAccumulator, ElementEpilogue, ElementOutput, ThreadblockShape,
+      WarpShape, InstructionShape>(
+      tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
+      tensor_d);
+
+  return tensor_d;
+#endif
+}
+
+TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
+  m.impl("torchao::s8s4_linear_cutlass", &s8s4_linear_cutlass);
+}
+
+}  // namespace torchao
diff --git a/torchao/csrc/s8s4_linear_cutlass.cpp b/torchao/csrc/s8s4_linear_cutlass.cpp
new file mode 100644
index 0000000000..cc82ff5bfe
--- /dev/null
+++ b/torchao/csrc/s8s4_linear_cutlass.cpp
@@ -0,0 +1,8 @@
+#include 
+#include 
+#include 
+
+TORCH_LIBRARY_FRAGMENT(torchao, m) {
+  m.impl_abstract_pystub("torchao.ops");
+  m.def("s8s4_linear_cutlass(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor");
+}
diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 142a49a368..2966e873c6 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -1130,6 +1130,15 @@ def _aqt_is_int8_reduced_range(aqt):
         aqt.quant_max is None or aqt.quant_max == 127
     )
 
+def _aqt_is_int4(aqt):
+    """Check if an AffineQuantizedTensor is int4 quantized Tensor"""
+    # TODO: use torch.int4
+    return (
+        aqt.layout_tensor.dtype == torch.int8 and
+        (aqt.quant_min is None or aqt.quant_min == -8) and
+        (aqt.quant_max is None or aqt.quant_max == 7)
+    )
+
 def _aqt_is_uint4(aqt):
     """Check if an AffineQuantizedTensor is uint4 quantized Tensor"""
     # TODO: use torch.uint4
@@ -1151,7 +1160,9 @@ def _aqt_is_uint4(aqt):
 def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias):
     return (
         isinstance(input_tensor, AffineQuantizedTensor) and
-        _aqt_is_int8_reduced_range(input_tensor) and
+        # FIXME: revert this when zero-point properly handled by kernel!!!
+        #_aqt_is_int8_reduced_range(input_tensor) and
+        not _aqt_is_int8_reduced_range(input_tensor) and
         isinstance(weight_tensor, AffineQuantizedTensor) and
         weight_tensor.is_cuda and
         input_tensor.dtype == weight_tensor.dtype and
@@ -1469,6 +1480,44 @@ def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, b
     return out
 
+def _linear_int8_act_int4_weight_cutlass_check(input_tensor, weight_tensor, bias):
+    # FIXME: refine these checks!!!
+    return (
+        isinstance(input_tensor, AffineQuantizedTensor) and
+        _aqt_is_int8(input_tensor) and
+        input_tensor.dtype == torch.float16 and
+        len(input_tensor.shape) == 2 and
+        isinstance(weight_tensor, AffineQuantizedTensor) and
+        _aqt_is_int4(weight_tensor) and
+        weight_tensor.dtype == torch.float16 and
+        len(weight_tensor.shape) == 2
+    )
+
+def _linear_int8_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias):
+    from torchao.ops import s8s4_linear_cutlass
+
+    assert isinstance(input_tensor, AffineQuantizedTensor)
+    assert isinstance(weight_tensor, AffineQuantizedTensor)
+
+    input = input_tensor.layout_tensor.int_data
+    input_scale = input_tensor.layout_tensor.scale
+
+    weight = weight_tensor.layout_tensor.int_data
+    weight_scale = weight_tensor.layout_tensor.scale
+
+    out = s8s4_linear_cutlass(
+        input, input_scale, weight, weight_scale, bias
+    )
+
+    # FIXME: remove this!
+    m, k = input.shape
+    n, _ = weight.shape
+    # This is the calculation that s8s4_linear_cutlass() is performing.
+    #out = (input.to(torch.half) @ weight.to(torch.half).T) * input_scale.view(m, 1).expand(m, n) * weight_scale.expand(m, n) + bias
+
+    return out
+
+
 def _register_aqt_quantized_linear_dispatches():
     for dispatch_condition, impl in [
         (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl),
@@ -1478,6 +1527,7 @@ def _register_aqt_quantized_linear_dispatches():
         (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl),
         (_linear_f16_act_floatx_weight_check, _linear_f16_act_floatx_weight_impl),
         (_linear_fp_act_int4_weight_sparse_marlin_check, _linear_fp_act_int4_weight_sparse_marlin_impl),
+        (_linear_int8_act_int4_weight_cutlass_check, _linear_int8_act_int4_weight_cutlass_impl),
     ]:
         register_aqt_quantized_linear_dispatch(dispatch_condition, impl)
 
diff --git a/torchao/ops.py b/torchao/ops.py
index 7f7adab864..154d59caea 100644
--- a/torchao/ops.py
+++ b/torchao/ops.py
@@ -261,3 +261,36 @@ def _(
     torch._check(workspace.numel() >= min_workspace_size, lambda: f"workspace.numel = {workspace.numel()} is below min_workspace_size = {min_workspace_size}")
 
     return torch.empty((x.size(0), s.size(1)), dtype=x.dtype, device=x.device)
+
+
+def s8s4_linear_cutlass(
+    input: Tensor,
+    input_scale: Tensor,
+    weight: Tensor,
+    weight_scale: Tensor,
+    bias: Tensor,
+) -> Tensor:
+    # FIXME: write docs!!!
+    """
+    """
+
+    # FIXME: this should be done by quantizer!!!
+    weight = ((weight[:, 1::2] & 0xF) << 4) | (weight[:, 0::2] & 0xF)
+
+    return torch.ops.torchao.s8s4_linear_cutlass.default(
+        input, input_scale, weight, weight_scale, bias
+    )
+
+
+@register_custom_op(f"torchao::s8s4_linear_cutlass")
+def _(
+    input: Tensor,
+    input_scale: Tensor,
+    weight: Tensor,
+    weight_scale: Tensor,
+    bias: Tensor,
+) -> Tensor:
+    # FIXME: implement all checks from s8s4_linear_cutlass() here!!!
+
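+    # (Meta/fake implementation: only the output shape and dtype matter here.
+    # The kernel writes an (m, n) tensor in the dtype of input_scale.)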
+    m, n = input.size(0), weight.size(0)
+    return torch.empty((m, n), dtype=input_scale.dtype, device=input.device)
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index cf5aab2800..8bf20e9a8a 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -490,6 +490,9 @@ def apply_int8_dynamic_activation_int4_weight_quant(weight, group_size=32, mappi
     # input settings
     input_quant_func = _int8_asymm_per_token_quant
 
+    # FIXME: remove this when zero-point properly handled by kernel!!!
+    input_quant_func = _int8_symm_per_token_reduced_range_quant
+
     weight = to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps)
     weight = to_linear_activation_quantized(weight, input_quant_func)
     return weight
@@ -575,7 +578,10 @@ def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor:
     eps = 1e-5
     quant_min = -127
     quant_max = 127
-    return to_affine_quantized_intx(x, mapping_type, _get_per_token_block_size(x), target_dtype, eps=eps, quant_min=quant_min, quant_max=quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
+
+    # FIXME: revert this when zero-point properly handled by kernel!!!
+    #return to_affine_quantized_intx(x, mapping_type, _get_per_token_block_size(x), target_dtype, eps=eps, quant_min=quant_min, quant_max=quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
+    return to_affine_quantized_intx(x, mapping_type, _get_per_token_block_size(x), target_dtype, eps=eps, quant_min=quant_min, quant_max=quant_max, scale_dtype=torch.float16 if x.dtype == torch.float16 else None)
 
 def int8_dynamic_activation_int8_weight(layout_type=PlainLayoutType()):
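
Below is a minimal sketch, for review purposes, of the numerics this patch wires together. The helper names (pack_int4_rows, reference_w4a8_linear) are hypothetical and not part of the patch; the packing mirrors the nibble interleaving done in torchao/ops.py:s8s4_linear_cutlass(), and the reference math mirrors the formula left in the FIXME comment of _linear_int8_act_int4_weight_cutlass_impl().

    import torch

    def pack_int4_rows(w: torch.Tensor) -> torch.Tensor:
        # Pack two signed int4 values (stored unpacked in an int8 tensor with
        # values in [-8, 7]) into one int8 byte, low nibble first -- the same
        # packing applied before calling the CUTLASS kernel.
        return ((w[:, 1::2] & 0xF) << 4) | (w[:, 0::2] & 0xF)

    def reference_w4a8_linear(input_s8, input_scale, weight_s4, weight_scale, bias):
        # result = (input * input_scale) @ (weight * weight_scale).T + bias,
        # with per-row (per-token and per-output-channel) fp16 scales.
        m, n = input_s8.shape[0], weight_s4.shape[0]
        out = input_s8.to(torch.half) @ weight_s4.to(torch.half).T
        return out * input_scale.view(m, 1) * weight_scale.view(1, n) + bias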