1 change: 0 additions & 1 deletion sgl-kernel/tests/test_awq_dequant.py

@@ -128,5 +128,4 @@ def test_awq_dequant_compare_implementations(
 
 
 if __name__ == "__main__":
-    # Run the specific test function directly
     pytest.main([__file__])
64 changes: 32 additions & 32 deletions sgl-kernel/tests/test_cublas_grouped_gemm.py
@@ -1,5 +1,4 @@
-import unittest
-
+import pytest
 import torch
 from sgl_kernel import cublas_grouped_gemm
 
@@ -11,39 +10,40 @@ def torch_grouped_gemm(a_array, b_array, out_dtype):
     return c_array
 
 
-class TestGroupedGemm(unittest.TestCase):
-    def _test_accuracy(self, Ms, Ns, Ks, out_dtype):
-        group_count = len(Ms)
-        a_array = []
-        b_array = []
-        c_array_cublas = []
-        for i in range(group_count):
-            M, N, K = Ms[i], Ns[i], Ks[i]
-            a_array.append(torch.randn((M, K), device="cuda", dtype=out_dtype) * 5)
-            b_array.append(torch.randn((N, K), device="cuda", dtype=out_dtype) * 5)
-            c_array_cublas.append(torch.empty((M, N), device="cuda", dtype=out_dtype))
-
-        c_array_torch = torch_grouped_gemm(a_array, b_array, out_dtype)
-        cublas_grouped_gemm(a_array, b_array, c_array_cublas, out_dtype)
-        for i in range(group_count):
-            M, N, K = Ms[i], Ns[i], Ks[i]
-            torch.testing.assert_close(c_array_torch[i], c_array_cublas[i])
-            print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK")
-
-    def test_accuracy(self):
-        Ms = [1, 16, 32, 256, 1024]
-        Ns = [2, 16, 128, 256, 4096]
-        Ks = [3, 16, 32, 512, 8192]
-        out_dtypes = [torch.float16, torch.bfloat16]
-        for out_dtype in out_dtypes:
-            self._test_accuracy(Ms, Ns, Ks, out_dtype)
+# skip if CUDA is not available or CUDA < 12.5
+skip_condition = not torch.cuda.is_available() or (
+    torch.version.cuda is None
+    or tuple(map(int, torch.version.cuda.split("."))) < (12, 5)
+)
+
+
+@pytest.mark.skipif(
+    skip_condition, reason="CUDA not available or CUDA version lower than 12.5"
+)
+@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
+def test_grouped_gemm_accuracy(out_dtype):
+    Ms = [1, 16, 32, 256, 1024]

[Review comment from a Member on the line above: "For different shapes you should use @pytest.mark.parametrize instead of for loop"; see the parametrization sketch after this diff.]

+    Ns = [2, 16, 128, 256, 4096]
+    Ks = [3, 16, 32, 512, 8192]
+    group_count = len(Ms)
+
+    a_array = []
+    b_array = []
+    c_array_cublas = []
+
+    for i in range(group_count):
+        M, N, K = Ms[i], Ns[i], Ks[i]
+        a_array.append(torch.randn((M, K), device="cuda", dtype=out_dtype) * 5)
+        b_array.append(torch.randn((N, K), device="cuda", dtype=out_dtype) * 5)
+        c_array_cublas.append(torch.empty((M, N), device="cuda", dtype=out_dtype))
+
+    c_array_torch = torch_grouped_gemm(a_array, b_array, out_dtype)
+    cublas_grouped_gemm(a_array, b_array, c_array_cublas, out_dtype)
+
+    for i in range(group_count):
+        torch.testing.assert_close(c_array_torch[i], c_array_cublas[i])
+        print(f"M={Ms[i]}, N={Ns[i]}, K={Ks[i]}, out_dtype={out_dtype}: OK")
 
 
 if __name__ == "__main__":
-    if torch.cuda.is_available():
-        cuda_version = tuple(map(int, torch.version.cuda.split(".")))
-        if cuda_version >= (12, 5):
-            unittest.main()
-        else:
-            print(f"Cuda version {cuda_version} lower than 12.5, not executing tests.")
+    pytest.main([__file__])
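Note on the review comment above: one way to follow the @pytest.mark.parametrize suggestion for shapes is sketched below. This is not the code merged in this PR, just an illustration. It assumes a one-element group is a valid call (i.e. that cublas_grouped_gemm accepts single-element lists), that the reference result for an (M, K) input against an (N, K) weight is a @ b.T, which matches the shapes used in the test above, and that the test name test_grouped_gemm_single_shape and the shape tuples (pairings of the Ms/Ns/Ks lists) are illustrative.

import pytest
import torch
from sgl_kernel import cublas_grouped_gemm

# Same skip condition as the test introduced in this PR.
skip_condition = not torch.cuda.is_available() or (
    torch.version.cuda is None
    or tuple(map(int, torch.version.cuda.split("."))) < (12, 5)
)


@pytest.mark.skipif(
    skip_condition, reason="CUDA not available or CUDA version lower than 12.5"
)
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize(
    "M, N, K",
    [(1, 2, 3), (16, 16, 16), (32, 128, 32), (256, 256, 512), (1024, 4096, 8192)],
)
def test_grouped_gemm_single_shape(M, N, K, out_dtype):
    # One parametrized case per shape: a is (M, K), b is (N, K), output is (M, N).
    a = torch.randn((M, K), device="cuda", dtype=out_dtype) * 5
    b = torch.randn((N, K), device="cuda", dtype=out_dtype) * 5
    c = torch.empty((M, N), device="cuda", dtype=out_dtype)

    # Reference result from torch; the kernel writes into c in place.
    expected = torch.matmul(a, b.t())
    cublas_grouped_gemm([a], [b], [c], out_dtype)
    torch.testing.assert_close(c, expected)

Parametrizing per shape gives each (M, N, K, out_dtype) combination its own test ID and failure report, but it exercises one group per call, so a multi-group test like the one in this PR is still needed to cover the grouped path itself.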