Skip to content

Commit

Permalink
W4A8 based on CUTLASS
Browse files Browse the repository at this point in the history
CUTLASS-based s8s4_linear_cutlass() operator is introduced, performing
linear transformation over quantized 8-bit input and quantized 4-bit
weight tensors, with corresponding floating point scale tensors
attached.

A benchmark script, for comparing performance of MM based on this
linear operator with MM over 16-bit floating point tensors is supplied
in benchmarks/benchmarks/benchmark_s8s4_cutlass.py.

The Llama generator script torchao/_models/llama/generate.py is
changed, to add "int8adq-int4w-symm" quantization as an option, that
will in turn activate s8s4_linear_cutlass() operator.  With this type
of quantization activated, i.e. if generate.py script run as follows:

python generate.py --compile --precision=torch.float16 -q int8adq-int4w-symm

the generator achieves around 133 tok/sec on A100, vs. around 93
tok/sec without quantization, i.e. when generate.py script run as
follows:

python generate.py --compile --precision=torch.float16
  • Loading branch information
alexsamardzic committed Jan 3, 2025
1 parent d9fe2c2 commit 9f824fe
Show file tree
Hide file tree
Showing 20 changed files with 1,023 additions and 7 deletions.
1 change: 1 addition & 0 deletions .github/workflows/float8_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ jobs:
runner: ${{ matrix.runs-on }}
gpu-arch-type: ${{ matrix.gpu-arch-type }}
gpu-arch-version: ${{ matrix.gpu-arch-version }}
submodules: recursive
script: |
conda create -n venv python=3.9 -y
conda activate venv
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/nightly_smoke_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ jobs:
runner: ${{ matrix.runs-on }}
gpu-arch-type: ${{ matrix.gpu-arch-type }}
gpu-arch-version: ${{ matrix.gpu-arch-version }}
submodules: recursive
script: |
python -m pip install --upgrade pip
pip install ${{ matrix.torch-spec }}
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/regression_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ jobs:
runner: ${{ matrix.runs-on }}
gpu-arch-type: ${{ matrix.gpu-arch-type }}
gpu-arch-version: ${{ matrix.gpu-arch-version }}
submodules: recursive
script: |
conda create -n venv python=3.9 -y
conda activate venv
Expand Down Expand Up @@ -93,6 +94,7 @@ jobs:
runner: ${{ matrix.runs-on }}
gpu-arch-type: ${{ matrix.gpu-arch-type }}
gpu-arch-version: ${{ matrix.gpu-arch-version }}
submodules: recursive
script: |
conda create -n venv python=3.9 -y
conda activate venv
Expand Down
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "third_party/cutlass"]
path = third_party/cutlass
url = https://github.com/NVIDIA/cutlass
53 changes: 53 additions & 0 deletions benchmarks/benchmark_s8s4_cutlass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import torch
import pandas as pd
from torchao.utils import benchmark_torch_function_in_microseconds
from torchao.ops import s8s4_linear_cutlass
from tqdm import tqdm


def get_problem(m, n, k):
groupsize = k

dev = torch.device("cuda")
A_ref = torch.randn((m, k), dtype=torch.half, device=dev)
B_ref = torch.randn((k, n), dtype=torch.half, device=dev)

A = torch.randint(-128, 127, (m, k), dtype=torch.int8, device=dev)
A_scale = torch.randn((m,), dtype=torch.half, device=dev)
B = torch.randint(-128, 127, size=(n, k // 2), dtype=torch.int8, device=dev)
B_scale = torch.randn((n,), dtype=torch.half, device=dev)
C = None

return A_ref, B_ref, A, A_scale, B, B_scale, C


def benchmark(m: int, k: int, n: int):
A_ref, B_ref, A, A_scale, B, B_scale, C = get_problem(m, n, k)

fp16_time = benchmark_torch_function_in_microseconds(torch.matmul, A_ref, B_ref)
s8s4_linear_cutlass_time = benchmark_torch_function_in_microseconds(
s8s4_linear_cutlass, A, A_scale, B, B_scale, C
)

return {
"m": m,
"k": k,
"n": n,
"fp16_latency (ms)": fp16_time,
"s8s4_linear_cutlass latency (ms)": s8s4_linear_cutlass_time,
"speedup (d/s)": fp16_time / s8s4_linear_cutlass_time,
}


if __name__ == "__main__":
k_vals = (8192, 8192, 8192, 28672)
n_vals = (8192, 10240, 57344, 8192)

results = []
for m in tqdm([1 << i for i in range(10)]):
for n, k in zip(n_vals, k_vals):
results.append(benchmark(m, k, n))

df = pd.DataFrame(results)
df.to_csv("s8s4_linear_cutlass_time_results.csv", index=False)
print(df.to_markdown(index=False))
12 changes: 12 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,18 @@ def get_extensions():
extra_compile_args["nvcc"].append("-g")
extra_link_args.append("/DEBUG")

use_cutlass = False
if use_cuda and not IS_WINDOWS:
use_cutlass = True
this_dir = os.path.abspath(os.path.curdir)
cutlass_dir = os.path.join(this_dir, "third_party", "cutlass")
cutlass_include_dir = os.path.join(cutlass_dir, "include")
if use_cutlass:
extra_compile_args["nvcc"].extend([
"-DTORCHAO_USE_CUTLASS",
"-I" + cutlass_include_dir,
])

this_dir = os.path.dirname(os.path.curdir)
extensions_dir = os.path.join(this_dir, "torchao", "csrc")
sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True))
Expand Down
11 changes: 10 additions & 1 deletion test/dtypes/test_affine_quantized.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
run_tests,
)

from torchao.dtypes import Int4CPULayout, SemiSparseLayout
from torchao.dtypes import Int4CPULayout, CutlassInt4PackedLayout, SemiSparseLayout
from torchao.quantization import (
float8_weight_only,
int4_weight_only,
Expand Down Expand Up @@ -48,6 +48,15 @@ def get_quantization_functions(
)
else:
base_functions.append(int4_weight_only(group_size=32))
if device == "cuda":
base_functions.append(
int8_dynamic_activation_int4_weight(
group_size=None,
mapping_type=MappingType.SYMMETRIC,
act_mapping_type=MappingType.SYMMETRIC,
layout=CutlassInt4PackedLayout(),
)
)

if do_sparse:
base_functions.append(
Expand Down
80 changes: 80 additions & 0 deletions test/test_s8s4_linear_cutlass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import itertools

import torch

import torchao
from torchao.ops import s8s4_linear_cutlass
from torchao.quantization.utils import group_quantize_tensor_symmetric
from torchao.utils import compute_max_diff

import pytest


S8S4_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16]
S8S4_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64]
S8S4_LINEAR_CUTLASS_SIZE_MNK = [
(2, 512, 128),
(3, 2048, 2048),
(4, 3584, 640),
(13, 8704, 8576),
(26, 18944, 1664),
(67, 6656, 1408),
]
S8S4_LINEAR_CUTLASS_USE_BIAS = [False, True]
S8S4_LINEAR_CUTLASS_TEST_PARAMS = list(
itertools.product(
S8S4_LINEAR_CUTLASS_DTYPE,
S8S4_LINEAR_CUTLASS_BATCH_SIZE,
S8S4_LINEAR_CUTLASS_SIZE_MNK,
S8S4_LINEAR_CUTLASS_USE_BIAS,
)
)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize(
"dtype, batch_size, size_mnk, use_bias", S8S4_LINEAR_CUTLASS_TEST_PARAMS
)
def test_s8s4_linear_cutlass(dtype, batch_size, size_mnk, use_bias):
size_m, size_n, size_k = size_mnk

input = torch.randn((batch_size, size_m, size_k), dtype=dtype, device="cuda")
weight = torch.rand((size_n, size_k), dtype=dtype, device="cuda")
bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None

input_2d = input.view(-1, input.shape[-1])
input_2d_s8, input_2d_scales, input_2d_zeros = group_quantize_tensor_symmetric(
input_2d, 8, size_k, dtype
)
assert torch.all(input_2d_zeros == 0)
input_s8 = input_2d_s8.reshape(input.shape)
input_scales = input_2d_scales.reshape(input.shape[:-1])

weight_s8, weight_scales, weight_zeros = group_quantize_tensor_symmetric(
weight, 4, size_n, dtype
)
assert torch.all(weight_zeros == 0)
weight_s4 = ((weight_s8[:, 1::2] & 0xF) << 4) | (weight_s8[:, 0::2] & 0xF)

# If torch.nn.functional.linear(input, weight, bias) used as
# reference, the error would be too big. The calculation below is
# approximately what s8s4_linear_cutlass kernel is doing (except
# that matrrix multiplication is over integers there)).
size_m_2d = input_2d.shape[0]
output_ref = (
(input_2d_s8.to(dtype) @ weight_s8.to(dtype).T)
* input_2d_scales.view(size_m_2d, 1)
* weight_scales.view(1, size_n)
)
if bias is not None:
output_ref += bias
output_ref = output_ref.reshape(input.shape[:-1] + (size_n,))

fn_inputs = (input_s8, input_scales, weight_s4, weight_scales, bias)
try:
output = s8s4_linear_cutlass(*fn_inputs)
except NotImplementedError as e:
pytest.xfail("s8s4_linear_cutlass() op not implemented")

max_diff = compute_max_diff(output, output_ref)
assert max_diff < 5e-3
1 change: 1 addition & 0 deletions third_party/cutlass
Submodule cutlass added at bf9da7
13 changes: 12 additions & 1 deletion torchao/_models/llama/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,17 @@ def ffn_or_attn_only(mod, fqn):
]
), f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}"
quantize_(model, int4_weight_only(group_size=group_size))
elif "int8adq-int4w-symm" in quantization:
from torchao.dtypes import CutlassInt4PackedLayout
quantize_(
model,
int8_dynamic_activation_int4_weight(
group_size=None,
mapping_type=MappingType.SYMMETRIC,
act_mapping_type=MappingType.SYMMETRIC,
layout=CutlassInt4PackedLayout(),
)
)
if "marlin" in quantization:
if "qqq" in quantization:
from torchao.dtypes import MarlinQQQLayout
Expand Down Expand Up @@ -1058,7 +1069,7 @@ def callback(x):
help=(
"Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, "
+ "autoquant-int4, autoquant-gemlite-int4, autoquant-float8, autoquant-sparse, autoquant-all, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, "
+ "embed-int8wo, marlin_qqq, gemlite-<pack_bitwidth>-<nbits>-<groupsize>"
+ "embed-int8wo, marlin_qqq, gemlite-<pack_bitwidth>-<nbits>-<groupsize>, int8adq-int4w-symm"
),
)
parser.add_argument(
Expand Down
Loading

0 comments on commit 9f824fe

Please sign in to comment.