-
Notifications
You must be signed in to change notification settings - Fork 179
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
26648c2
commit f4bffcc
Showing
13 changed files
with
904 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import torch | ||
import pandas as pd | ||
from torchao.utils import benchmark_torch_function_in_microseconds | ||
from torchao.ops import s8s4_linear_cutlass | ||
from tqdm import tqdm | ||
|
||
|
||
def get_problem(m, n, k): | ||
groupsize = k | ||
|
||
dev = torch.device("cuda") | ||
A_ref = torch.randn((m, k), dtype=torch.half, device=dev) | ||
B_ref = torch.randn((k, n), dtype=torch.half, device=dev) | ||
|
||
A = torch.randint(-128, 127, (m, k), dtype=torch.int8, device=dev) | ||
A_scale = torch.randn((m,), dtype=torch.half, device=dev) | ||
B = torch.randint(-128, 127, size=(n, k // 2), dtype=torch.int8, device=dev) | ||
B_scale = torch.randn((n,), dtype=torch.half, device=dev) | ||
C = None | ||
|
||
return A_ref, B_ref, A, A_scale, B, B_scale, C | ||
|
||
|
||
def benchmark(m: int, k: int, n: int): | ||
A_ref, B_ref, A, A_scale, B, B_scale, C = get_problem(m, n, k) | ||
|
||
fp16_time = benchmark_torch_function_in_microseconds(torch.matmul, A_ref, B_ref) | ||
s8s4_linear_cutlass_time = benchmark_torch_function_in_microseconds( | ||
s8s4_linear_cutlass, A, A_scale, B, B_scale, C | ||
) | ||
|
||
return { | ||
"m": m, | ||
"k": k, | ||
"n": n, | ||
"fp16_latency (ms)": fp16_time, | ||
"s8s4_linear_cutlass latency (ms)": s8s4_linear_cutlass_time, | ||
"speedup (d/s)": fp16_time / s8s4_linear_cutlass_time, | ||
} | ||
|
||
|
||
if __name__ == "__main__": | ||
k_vals = (8192, 8192, 8192, 28672) | ||
n_vals = (8192, 10240, 57344, 8192) | ||
|
||
results = [] | ||
for m in tqdm([1 << i for i in range(10)]): | ||
for n, k in zip(n_vals, k_vals): | ||
results.append(benchmark(m, k, n)) | ||
|
||
df = pd.DataFrame(results) | ||
df.to_csv("s8s4_linear_cutlass_time_results.csv", index=False) | ||
print(df.to_markdown(index=False)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import itertools | ||
|
||
import torch | ||
|
||
import torchao | ||
from torchao.quantization.utils import group_quantize_tensor_symmetric | ||
from torchao.utils import compute_max_diff | ||
|
||
import pytest | ||
|
||
|
||
S8S4_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16] | ||
S8S4_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64] | ||
S8S4_LINEAR_CUTLASS_SIZE_MNK = [ | ||
(2, 512, 128), | ||
(3, 2048, 2048), | ||
(4, 3584, 640), | ||
(13, 8704, 8576), | ||
(26, 18944, 1664), | ||
(67, 6656, 1408), | ||
] | ||
S8S4_LINEAR_CUTLASS_USE_BIAS = [False, True] | ||
S8S4_LINEAR_CUTLASS_TEST_PARAMS = list( | ||
itertools.product( | ||
S8S4_LINEAR_CUTLASS_DTYPE, | ||
S8S4_LINEAR_CUTLASS_BATCH_SIZE, | ||
S8S4_LINEAR_CUTLASS_SIZE_MNK, | ||
S8S4_LINEAR_CUTLASS_USE_BIAS, | ||
) | ||
) | ||
|
||
|
||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") | ||
@pytest.mark.parametrize( | ||
"dtype, batch_size, size_mnk, use_bias", S8S4_LINEAR_CUTLASS_TEST_PARAMS | ||
) | ||
def test_s8s4_linear_cutlass(dtype, batch_size, size_mnk, use_bias): | ||
size_m, size_n, size_k = size_mnk | ||
|
||
input = torch.randn((batch_size, size_m, size_k), dtype=dtype, device="cuda") | ||
weight = torch.rand((size_n, size_k), dtype=dtype, device="cuda") | ||
bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None | ||
|
||
input_2d = input.view(-1, input.shape[-1]) | ||
input_2d_s8, input_2d_scales, input_2d_zeros = group_quantize_tensor_symmetric( | ||
input_2d, 8, size_k, dtype | ||
) | ||
assert torch.all(input_2d_zeros == 0) | ||
input_s8 = input_2d_s8.reshape(input.shape) | ||
input_scales = input_2d_scales.reshape(input.shape[:-1]) | ||
|
||
weight_s8, weight_scales, weight_zeros = group_quantize_tensor_symmetric( | ||
weight, 4, size_n, dtype | ||
) | ||
assert torch.all(weight_zeros == 0) | ||
weight_s4 = ((weight_s8[:, 1::2] & 0xF) << 4) | (weight_s8[:, 0::2] & 0xF) | ||
|
||
# If torch.nn.functional.linear(input, weight, bias) used as | ||
# reference, the error would be too big. The calculation below is | ||
# approximately what s8s4_linear_cutlass kernel is doing (except | ||
# that matrrix multiplication is over integers there)). | ||
size_m_2d = input_2d.shape[0] | ||
output_ref = ( | ||
(input_2d_s8.to(dtype) @ weight_s8.to(dtype).T) | ||
* input_2d_scales.view(size_m_2d, 1) | ||
* weight_scales.view(1, size_n) | ||
) | ||
if bias is not None: | ||
output_ref += bias | ||
output_ref = output_ref.reshape(input.shape[:-1] + (size_n,)) | ||
|
||
fn_inputs = (input_s8, input_scales, weight_s4, weight_scales, bias) | ||
try: | ||
output = torchao.ops.s8s4_linear_cutlass(*fn_inputs) | ||
except NotImplementedError as e: | ||
pytest.xfail("torchao.ops.s8s4_linear_cutlass() op not implemented") | ||
|
||
max_diff = compute_max_diff(output, output_ref) | ||
assert max_diff < 5e-3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.