W4A8 based on CUTLASS
alexsamardzic committed Nov 19, 2024
1 parent 26648c2 commit f4bffcc
Showing 13 changed files with 904 additions and 6 deletions.
53 changes: 53 additions & 0 deletions benchmarks/benchmark_s8s4_cutlass.py
@@ -0,0 +1,53 @@
import torch
import pandas as pd
from torchao.utils import benchmark_torch_function_in_microseconds
from torchao.ops import s8s4_linear_cutlass
from tqdm import tqdm


def get_problem(m, n, k):
groupsize = k

dev = torch.device("cuda")
A_ref = torch.randn((m, k), dtype=torch.half, device=dev)
B_ref = torch.randn((k, n), dtype=torch.half, device=dev)

A = torch.randint(-128, 127, (m, k), dtype=torch.int8, device=dev)
A_scale = torch.randn((m,), dtype=torch.half, device=dev)
B = torch.randint(-128, 127, size=(n, k // 2), dtype=torch.int8, device=dev)
B_scale = torch.randn((n,), dtype=torch.half, device=dev)
C = None

return A_ref, B_ref, A, A_scale, B, B_scale, C


def benchmark(m: int, k: int, n: int):
A_ref, B_ref, A, A_scale, B, B_scale, C = get_problem(m, n, k)

fp16_time = benchmark_torch_function_in_microseconds(torch.matmul, A_ref, B_ref)
s8s4_linear_cutlass_time = benchmark_torch_function_in_microseconds(
s8s4_linear_cutlass, A, A_scale, B, B_scale, C
)

return {
"m": m,
"k": k,
"n": n,
"fp16_latency (ms)": fp16_time,
"s8s4_linear_cutlass latency (ms)": s8s4_linear_cutlass_time,
"speedup (d/s)": fp16_time / s8s4_linear_cutlass_time,
}


if __name__ == "__main__":
k_vals = (8192, 8192, 8192, 28672)
n_vals = (8192, 10240, 57344, 8192)

results = []
for m in tqdm([1 << i for i in range(10)]):
for n, k in zip(n_vals, k_vals):
results.append(benchmark(m, k, n))

df = pd.DataFrame(results)
df.to_csv("s8s4_linear_cutlass_time_results.csv", index=False)
print(df.to_markdown(index=False))
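
For reference, a standalone call with the same tensor layouts that get_problem() constructs looks roughly like the sketch below. This is only an illustration (it assumes a CUDA build of torchao with the CUTLASS kernel compiled in; the m/n/k values are arbitrary):

import torch
from torchao.ops import s8s4_linear_cutlass

m, n, k = 16, 8192, 8192
dev = torch.device("cuda")
A = torch.randint(-128, 127, (m, k), dtype=torch.int8, device=dev)       # int8 activations
A_scale = torch.randn((m,), dtype=torch.half, device=dev)                # per-row activation scales
B = torch.randint(-128, 127, (n, k // 2), dtype=torch.int8, device=dev)  # int4 weights, two packed per byte
B_scale = torch.randn((n,), dtype=torch.half, device=dev)                # per-column weight scales
out = s8s4_linear_cutlass(A, A_scale, B, B_scale, None)                  # last argument is the optional bias
print(out.shape)                                                         # expected: (m, n)
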
14 changes: 14 additions & 0 deletions setup.py
@@ -65,6 +65,15 @@ def get_extensions():
extension = CUDAExtension if use_cuda else CppExtension

if not IS_WINDOWS:
try:
import cutlass_library
        except ImportError:
use_cutlass = False
else:
use_cutlass = True
cutlass_library_dir = os.path.dirname(cutlass_library.__file__)
cutlass_include_dir = os.path.join(cutlass_library_dir, "source", "include")

extra_link_args = []
extra_compile_args = {
"cxx": [
@@ -76,6 +85,11 @@ def get_extensions():
"-t=0",
]
}
if use_cutlass:
extra_compile_args["nvcc"].extend([
"-DTORCHAO_USE_CUTLASS",
"-I" + cutlass_include_dir,
])

if debug_mode:
extra_compile_args["cxx"].append("-g")
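
The CUTLASS detection above can be exercised on its own; the following is a minimal sketch of the same check that setup.py performs (cutlass_library is assumed to come from a CUTLASS Python installation; if the import fails, the extension simply builds without -DTORCHAO_USE_CUTLASS):

import os

try:
    import cutlass_library
except ImportError:
    print("cutlass_library not found; building without -DTORCHAO_USE_CUTLASS")
else:
    include_dir = os.path.join(
        os.path.dirname(cutlass_library.__file__), "source", "include"
    )
    print("CUTLASS headers expected at:", include_dir)
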
79 changes: 79 additions & 0 deletions test/test_s8s4_linear_cutlass.py
@@ -0,0 +1,79 @@
import itertools

import torch

import torchao
from torchao.quantization.utils import group_quantize_tensor_symmetric
from torchao.utils import compute_max_diff

import pytest


S8S4_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16]
S8S4_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64]
S8S4_LINEAR_CUTLASS_SIZE_MNK = [
(2, 512, 128),
(3, 2048, 2048),
(4, 3584, 640),
(13, 8704, 8576),
(26, 18944, 1664),
(67, 6656, 1408),
]
S8S4_LINEAR_CUTLASS_USE_BIAS = [False, True]
S8S4_LINEAR_CUTLASS_TEST_PARAMS = list(
itertools.product(
S8S4_LINEAR_CUTLASS_DTYPE,
S8S4_LINEAR_CUTLASS_BATCH_SIZE,
S8S4_LINEAR_CUTLASS_SIZE_MNK,
S8S4_LINEAR_CUTLASS_USE_BIAS,
)
)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize(
"dtype, batch_size, size_mnk, use_bias", S8S4_LINEAR_CUTLASS_TEST_PARAMS
)
def test_s8s4_linear_cutlass(dtype, batch_size, size_mnk, use_bias):
size_m, size_n, size_k = size_mnk

input = torch.randn((batch_size, size_m, size_k), dtype=dtype, device="cuda")
weight = torch.rand((size_n, size_k), dtype=dtype, device="cuda")
bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None

input_2d = input.view(-1, input.shape[-1])
input_2d_s8, input_2d_scales, input_2d_zeros = group_quantize_tensor_symmetric(
input_2d, 8, size_k, dtype
)
assert torch.all(input_2d_zeros == 0)
input_s8 = input_2d_s8.reshape(input.shape)
input_scales = input_2d_scales.reshape(input.shape[:-1])

weight_s8, weight_scales, weight_zeros = group_quantize_tensor_symmetric(
weight, 4, size_n, dtype
)
assert torch.all(weight_zeros == 0)
weight_s4 = ((weight_s8[:, 1::2] & 0xF) << 4) | (weight_s8[:, 0::2] & 0xF)

    # If torch.nn.functional.linear(input, weight, bias) were used as
    # the reference, the error would be too big. The calculation below is
    # approximately what the s8s4_linear_cutlass kernel is doing (except
    # that the matrix multiplication is over integers there).
size_m_2d = input_2d.shape[0]
output_ref = (
(input_2d_s8.to(dtype) @ weight_s8.to(dtype).T)
* input_2d_scales.view(size_m_2d, 1)
* weight_scales.view(1, size_n)
)
if bias is not None:
output_ref += bias
output_ref = output_ref.reshape(input.shape[:-1] + (size_n,))

fn_inputs = (input_s8, input_scales, weight_s4, weight_scales, bias)
try:
output = torchao.ops.s8s4_linear_cutlass(*fn_inputs)
    except NotImplementedError:
pytest.xfail("torchao.ops.s8s4_linear_cutlass() op not implemented")

max_diff = compute_max_diff(output, output_ref)
assert max_diff < 5e-3
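
The weight_s4 line above packs two signed 4-bit values into each int8 byte: the even-indexed element goes into the low nibble and the odd-indexed element into the high nibble. Below is a small sketch of that layout and of sign-correct unpacking, for illustration only (the kernel consumes the packed form directly):

import torch

w = torch.tensor([[1, -2, 7, -8]], dtype=torch.int8)     # int4 values held in int8 storage
packed = ((w[:, 1::2] & 0xF) << 4) | (w[:, 0::2] & 0xF)  # same packing as in the test above

def unpack_nibbles(p):
    lo = p & 0xF                                # low nibble: even-indexed elements
    hi = (p >> 4) & 0xF                         # high nibble: odd-indexed elements
    lo = torch.where(lo >= 8, lo - 16, lo)      # sign-extend the 4-bit values
    hi = torch.where(hi >= 8, hi - 16, hi)
    return torch.stack((lo, hi), dim=-1).reshape(p.shape[0], -1)

assert torch.equal(unpack_nibbles(packed), w)
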
12 changes: 11 additions & 1 deletion torchao/_models/llama/generate.py
Expand Up @@ -13,6 +13,7 @@
import torchao
import torch._dynamo.config
import torch._inductor.config
from torchao.quantization.quant_primitives import MappingType
from torchao.utils import get_model_size_in_bytes
from torchao.quantization.quant_primitives import MappingType
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
@@ -228,6 +229,15 @@ def main(
quantize_(model, int8_weight_only())
if "int8dq" in quantization:
quantize_(model, int8_dynamic_activation_int8_weight())
if "int8adq-int4w-symm" in quantization:
quantize_(
model,
int8_dynamic_activation_int4_weight(
group_size=None,
mapping_type=MappingType.SYMMETRIC,
act_mapping_type=MappingType.SYMMETRIC,
)
)
if "int4wo" in quantization:
if "hqq" in quantization:
use_hqq=True
@@ -486,7 +496,7 @@ def callback(x):
parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
parser.add_argument('-q', '--quantization', type=str,
help=(
'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, '
'Which quantization techniques to apply: int8dq, int8adq-int4w-symm, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, '
+'autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, '
+'embed-int8wo, marlin_qqq'
)
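
With this change the new path can be selected from the command line with -q int8adq-int4w-symm. For completeness, a minimal sketch of applying the same configuration directly follows; the toy module and the torchao.quantization import path are illustrative assumptions (generate.py uses the same helpers):

import torch
from torchao.quantization import quantize_, int8_dynamic_activation_int4_weight
from torchao.quantization.quant_primitives import MappingType

# Toy module standing in for the Llama model that generate.py loads.
model = torch.nn.Sequential(torch.nn.Linear(128, 256)).cuda().bfloat16()

quantize_(
    model,
    int8_dynamic_activation_int4_weight(
        group_size=None,                         # per-channel scales, no grouping
        mapping_type=MappingType.SYMMETRIC,      # symmetric weight quantization
        act_mapping_type=MappingType.SYMMETRIC,  # symmetric activation quantization
    ),
)
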
(diffs for the remaining changed files not shown)
