Merged
2 changes: 1 addition & 1 deletion .github/workflows/pr-test-sgl-kernel.yml
@@ -89,7 +89,7 @@ jobs:
timeout-minutes: 30
run: |
cd sgl-kernel
find tests -name "test_*.py" | xargs -n 1 python3
pytest tests/

- name: Uninstall dependencies
run: |
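The CI step above now runs one pytest session over the whole tests/ directory instead of launching each test_*.py file in its own python3 process via find | xargs. A minimal sketch of an equivalent programmatic invocation (assumed, not part of this diff):

import sys

import pytest

if __name__ == "__main__":
    # Collect every tests/**/test_*.py module in a single session and exit
    # with pytest's status code so the CI job fails when any test fails.
    sys.exit(pytest.main(["tests/"]))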
11 changes: 6 additions & 5 deletions sgl-kernel/tests/speculative/test_eagle_utils.py
@@ -1,3 +1,4 @@
import pytest
import torch
import torch.nn.functional as F
from sgl_kernel import verify_tree_greedy
@@ -85,14 +86,14 @@ def test_verify_tree_greedy():
print(f"{accept_index=}")
print(f"{accept_token_num=}")

return predicts, accept_index, accept_token_num


if __name__ == "__main__":
predicts, accept_index, accept_token_num = test_verify_tree_greedy()
# Check the expected output.
assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
assert accept_index.tolist() == [
[0, 3, 4, 5],
[6, 10, 11, -1],
]
assert accept_token_num.tolist() == [3, 2]


if __name__ == "__main__":
pytest.main([__file__])
36 changes: 16 additions & 20 deletions sgl-kernel/tests/speculative/test_speculative_sampling.py
@@ -1,3 +1,4 @@
import pytest
import torch
import torch.nn.functional as F
from sgl_kernel import tree_speculative_sampling_target_only
@@ -97,26 +98,21 @@ def test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc
print(f"{accept_index=}")
print(f"{accept_token_num=}")

return predicts, accept_index, accept_token_num
if threshold_single == 1 and threshold_acc == 1:
assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
assert accept_index.tolist() == [
[0, 3, 4, 5],
[6, 10, 11, -1],
]
assert accept_token_num.tolist() == [3, 2]
elif threshold_single == 0 and threshold_acc == 0:
assert predicts.tolist() == [1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18]
assert accept_index.tolist() == [
[0, 1, 2, -1],
[6, 10, 11, -1],
]
assert accept_token_num.tolist() == [2, 2]


if __name__ == "__main__":
predicts, accept_index, accept_token_num = (
test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc=1)
)
assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
assert accept_index.tolist() == [
[0, 3, 4, 5],
[6, 10, 11, -1],
]
assert accept_token_num.tolist() == [3, 2]

predicts, accept_index, accept_token_num = (
test_tree_speculative_sampling_target_only(threshold_single=0, threshold_acc=0)
)
assert predicts.tolist() == [1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18]
assert accept_index.tolist() == [
[0, 1, 2, -1],
[6, 10, 11, -1],
]
assert accept_token_num.tolist() == [2, 2]
pytest.main([__file__])
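Because threshold_single and threshold_acc are plain default arguments, the second expected-output branch can also be driven explicitly, e.g. from a script in the same directory. Illustrative only (hypothetical driver, not part of this diff):

# Hypothetical driver placed next to test_speculative_sampling.py.
from test_speculative_sampling import test_tree_speculative_sampling_target_only

# Exercise the branch that expects accept_token_num == [2, 2].
test_tree_speculative_sampling_target_only(threshold_single=0, threshold_acc=0)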
1 change: 0 additions & 1 deletion sgl-kernel/tests/test_awq_dequant.py
@@ -128,5 +128,4 @@ def test_awq_dequant_compare_implementations(


if __name__ == "__main__":
# Run the specific test function directly
pytest.main([__file__])
82 changes: 41 additions & 41 deletions sgl-kernel/tests/test_cublas_grouped_gemm.py
@@ -1,49 +1,49 @@
import unittest

import pytest
import torch
from sgl_kernel import cublas_grouped_gemm


def torch_grouped_gemm(a_array, b_array, out_dtype):
c_array = []
for a, b in zip(a_array, b_array):
c_array.append(torch.matmul(a, b.t()).to(out_dtype))
return c_array


class TestGroupedGemm(unittest.TestCase):
def _test_accuracy(self, Ms, Ns, Ks, out_dtype):
group_count = len(Ms)
a_array = []
b_array = []
c_array_cublas = []
for i in range(group_count):
M, N, K = Ms[i], Ns[i], Ks[i]
a_array.append(torch.randn((M, K), device="cuda", dtype=out_dtype) * 5)
b_array.append(torch.randn((N, K), device="cuda", dtype=out_dtype) * 5)
c_array_cublas.append(torch.empty((M, N), device="cuda", dtype=out_dtype))

c_array_torch = torch_grouped_gemm(a_array, b_array, out_dtype)
cublas_grouped_gemm(a_array, b_array, c_array_cublas, out_dtype)

for i in range(group_count):
M, N, K = Ms[i], Ns[i], Ks[i]
torch.testing.assert_close(c_array_torch[i], c_array_cublas[i])
print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK")

def test_accuracy(self):
Ms = [1, 16, 32, 256, 1024]
Ns = [2, 16, 128, 256, 4096]
Ks = [3, 16, 32, 512, 8192]
out_dtypes = [torch.float16, torch.bfloat16]
for out_dtype in out_dtypes:
self._test_accuracy(Ms, Ns, Ks, out_dtype)
return [torch.matmul(a, b.t()).to(out_dtype) for a, b in zip(a_array, b_array)]


skip_condition = not torch.cuda.is_available() or (
torch.version.cuda is None
or tuple(map(int, torch.version.cuda.split("."))) < (12, 5)
)


m_values = [1, 16, 32, 256, 1024]
n_values = [2, 16, 128, 256, 4096]
k_values = [3, 16, 32, 512, 8192]


@pytest.mark.skipif(
skip_condition, reason="CUDA not available or CUDA version lower than 12.5"
)
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("M", m_values)
@pytest.mark.parametrize("N", n_values)
@pytest.mark.parametrize("K", k_values)
def test_grouped_gemm_accuracy(out_dtype, M, N, K):
try:
a = torch.randn((M, K), device="cuda", dtype=out_dtype) * 5
b = torch.randn((N, K), device="cuda", dtype=out_dtype) * 5
expected = torch.matmul(a, b.t()).to(out_dtype)

a_array = [a]
b_array = [b]
c_array = [torch.empty((M, N), device="cuda", dtype=out_dtype)]

result_torch = torch_grouped_gemm(a_array, b_array, out_dtype)[0]
cublas_grouped_gemm(a_array, b_array, c_array, out_dtype)

torch.testing.assert_close(result_torch, expected)
torch.testing.assert_close(c_array[0], expected)

except torch.cuda.OutOfMemoryError:
pytest.skip(f"Skipping M={M}, N={N}, K={K} due to OOM")


if __name__ == "__main__":
if torch.cuda.is_available():
cuda_version = tuple(map(int, torch.version.cuda.split(".")))
if cuda_version >= (12, 5):
unittest.main()
else:
print(f"Cuda version {cuda_version} lower than 12.5, not executing tests.")
pytest.main([__file__])
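The skip_condition above turns torch.version.cuda into an integer tuple so versions compare numerically rather than as strings; a standalone sketch of that comparison (the example version string is assumed):

# torch.version.cuda is a string such as "12.4"; an int tuple compares
# component by component, so (12, 4) < (12, 5), whereas "12.10" < "12.5"
# would order incorrectly as strings.
version_string = "12.4"
version_tuple = tuple(map(int, version_string.split(".")))
print(version_tuple)            # (12, 4)
print(version_tuple < (12, 5))  # True -> the tests would be skipped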
84 changes: 33 additions & 51 deletions sgl-kernel/tests/test_fp8_blockwise_gemm.py
@@ -1,12 +1,13 @@
import unittest
import os
import random
from typing import Optional, Type

import pytest
import torch
from sgl_kernel import fp8_blockwise_scaled_mm


def cdiv(a: int, b: int) -> int:
"""Ceiling division."""
return -(a // -b)


@@ -23,7 +24,6 @@ def baseline_scaled_mm(
out_dtype: Type[torch.dtype],
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:

# We treat N-dimensional group scaling as extended numpy-style broadcasting
# in numpy simply stretches dimensions with an extent of 1 to match the
target shape by repeating the data along that dimension (broadcasting)
@@ -51,62 +51,44 @@ def group_broadcast(t, shape):

scale_a = group_broadcast(scale_a, a.shape)
scale_b = group_broadcast(scale_b, b.shape)

output = torch.mm(
(scale_a * a.to(dtype=torch.float32)), (scale_b * b.to(dtype=torch.float32))
).to(out_dtype)

if bias is not None:
output = output + bias

return output


class TestFp8Gemm(unittest.TestCase):
def _test_accuracy_once(self, M, N, K, out_dtype, device):
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min

a_fp32 = (
(torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
)
a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

b_fp32 = (
(torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
)
b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn).t()

scale_a_group_shape = (1, 128)
scale_b_group_shape = (128, 128)
scale_a_shape = scale_shape(a_fp8.shape, scale_a_group_shape)
scale_b_shape = scale_shape(b_fp8.shape, scale_b_group_shape)

scale_a = torch.randn(scale_a_shape, device=device, dtype=torch.float32) * 0.001
scale_b = torch.randn(scale_b_shape, device=device, dtype=torch.float32) * 0.001
scale_a = scale_a.t().contiguous().t()
scale_b = scale_b.t().contiguous().t()

o1 = torch.empty((M, N), device=device, dtype=torch.bfloat16)
o = baseline_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype)
o1 = fp8_blockwise_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype)

rtol = 0.02
atol = 1
torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK")

def test_accuracy(self):
Ms = [1, 128, 512, 1024, 4096]
Ns = [128, 512, 1024, 4096]
Ks = [512, 1024, 4096, 8192, 16384]
out_dtypes = [torch.bfloat16, torch.float16]
for M in Ms:
for N in Ns:
for K in Ks:
for out_dtype in out_dtypes:
self._test_accuracy_once(M, N, K, out_dtype, "cuda")
def _test_accuracy_once(M, N, K, out_dtype, device):
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
a_fp32 = (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
b_fp32 = (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn).t()
scale_a_group_shape = (1, 128)
scale_b_group_shape = (128, 128)
scale_a_shape = scale_shape(a_fp8.shape, scale_a_group_shape)
scale_b_shape = scale_shape(b_fp8.shape, scale_b_group_shape)
scale_a = torch.randn(scale_a_shape, device=device, dtype=torch.float32) * 0.001
scale_b = torch.randn(scale_b_shape, device=device, dtype=torch.float32) * 0.001
scale_a = scale_a.t().contiguous().t()
scale_b = scale_b.t().contiguous().t()
o = baseline_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype)
o1 = fp8_blockwise_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype)
rtol = 0.02
atol = 1
torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK")


@pytest.mark.parametrize("M", [1, 128, 512, 1024, 4096])
@pytest.mark.parametrize("N", [128, 512, 1024, 4096])
@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
def test_accuracy(M, N, K, out_dtype):
_test_accuracy_once(M, N, K, out_dtype, "cuda")


if __name__ == "__main__":
unittest.main()
pytest.main([__file__])
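baseline_scaled_mm in this file applies blockwise scales via group broadcasting (see its comment above): each scale entry covers one block of the fp8 tensor and is repeated across that block before the elementwise multiply. An illustrative sketch with assumed toy sizes:

import torch

block = 4
scale = torch.tensor([0.5, 2.0])            # one scale value per 4-wide group
stretched = scale.repeat_interleave(block)  # [0.5, 0.5, 0.5, 0.5, 2.0, 2.0, 2.0, 2.0]
x = torch.ones(8)
print(x * stretched)                        # each group multiplied by its own scale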
80 changes: 31 additions & 49 deletions sgl-kernel/tests/test_fp8_gemm.py
@@ -1,67 +1,49 @@
import unittest

import pytest
import torch
from sgl_kernel import fp8_scaled_mm


def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias):
o = torch.matmul(a.to(torch.float32), b.to(torch.float32))

o = o.to(torch.float32)
temp1 = o * scale_a.view(-1, 1)
temp2 = temp1 * scale_b.view(1, -1)
final = temp2.to(out_dtype)
if bias is not None:
final = final + bias.view(1, -1)

return final


class TestFp8Gemm(unittest.TestCase):
def _test_accuracy_once(self, M, N, K, with_bias, out_dtype, device):
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min

a_fp32 = (
(torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
)
a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

b_fp32 = (
(torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
)
b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

scale_a = torch.randn((M,), device=device, dtype=torch.float32) * 0.001
scale_b = torch.randn((N,), device=device, dtype=torch.float32) * 0.001
if with_bias:
bias = torch.randn((N,), device=device, dtype=out_dtype)
else:
bias = None
o1 = torch.empty((M, N), device=device, dtype=torch.bfloat16)
b_fp8 = b_fp8.t()
o = torch_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
o1 = fp8_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
rtol = 0.02
atol = 1
torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")

def test_accuracy(self):
Ms = [1, 128, 512, 1024, 4096]
Ns = [16, 128, 512, 1024, 4096]
Ks = [512, 1024, 4096, 8192, 16384]
bias_opts = [True, False]
out_dtypes = [torch.bfloat16, torch.float16]
for M in Ms:
for N in Ns:
for K in Ks:
for with_bias in bias_opts:
for out_dtype in out_dtypes:
self._test_accuracy_once(
M, N, K, with_bias, out_dtype, "cuda"
)
def _test_accuracy_once(M, N, K, with_bias, out_dtype, device):
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
a_fp32 = (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
b_fp32 = (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
scale_a = torch.randn((M,), device=device, dtype=torch.float32) * 0.001
scale_b = torch.randn((N,), device=device, dtype=torch.float32) * 0.001
if with_bias:
bias = torch.randn((N,), device=device, dtype=out_dtype)
else:
bias = None
b_fp8 = b_fp8.t()
o = torch_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
o1 = fp8_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
rtol = 0.02
atol = 1
torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")


@pytest.mark.parametrize("M", [1, 128, 512, 1024, 4096])
@pytest.mark.parametrize("N", [16, 128, 512, 1024, 4096])
@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384])
@pytest.mark.parametrize("with_bias", [True, False])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
def test_accuracy(M, N, K, with_bias, out_dtype):
_test_accuracy_once(M, N, K, with_bias, out_dtype, "cuda")


if __name__ == "__main__":
unittest.main()
pytest.main([__file__])
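test_accuracy in this file stacks five parametrize decorators; pytest expands stacked decorators into their full cross-product, which is 5 * 5 * 5 * 2 * 2 = 500 cases here. A small self-contained sketch of that expansion:

import pytest

@pytest.mark.parametrize("x", [1, 2])
@pytest.mark.parametrize("y", ["a", "b", "c"])
def test_cross_product(x, y):
    # pytest generates 2 * 3 = 6 independent test items from this function.
    assert isinstance(x, int) and isinstance(y, str)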