From e6500666390bc18991174313ba3c6f95c057cd58 Mon Sep 17 00:00:00 2001 From: adarshxs Date: Sun, 30 Mar 2025 00:16:38 +0530 Subject: [PATCH 01/11] Parameterize --- .../tests/speculative/test_eagle_utils.py | 11 +- .../speculative/test_speculative_sampling.py | 36 ++- sgl-kernel/tests/test_awq_dequant.py | 1 - sgl-kernel/tests/test_cublas_grouped_gemm.py | 82 +++---- sgl-kernel/tests/test_deep_gemm.py | 232 ++++++++++++++++++ sgl-kernel/tests/test_fp8_blockwise_gemm.py | 98 +++----- sgl-kernel/tests/test_fp8_gemm.py | 80 +++--- sgl-kernel/tests/test_int8_gemm.py | 61 ++--- sgl-kernel/tests/test_per_token_quant_fp8.py | 1 - 9 files changed, 385 insertions(+), 217 deletions(-) create mode 100644 sgl-kernel/tests/test_deep_gemm.py diff --git a/sgl-kernel/tests/speculative/test_eagle_utils.py b/sgl-kernel/tests/speculative/test_eagle_utils.py index 1514029ecf1..12aa2e4981a 100644 --- a/sgl-kernel/tests/speculative/test_eagle_utils.py +++ b/sgl-kernel/tests/speculative/test_eagle_utils.py @@ -1,3 +1,4 @@ +import pytest import torch import torch.nn.functional as F from sgl_kernel import verify_tree_greedy @@ -85,14 +86,14 @@ def test_verify_tree_greedy(): print(f"{accept_index=}") print(f"{accept_token_num=}") - return predicts, accept_index, accept_token_num - - -if __name__ == "__main__": - predicts, accept_index, accept_token_num = test_verify_tree_greedy() + # Check the expected output. assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18] assert accept_index.tolist() == [ [0, 3, 4, 5], [6, 10, 11, -1], ] assert accept_token_num.tolist() == [3, 2] + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/sgl-kernel/tests/speculative/test_speculative_sampling.py b/sgl-kernel/tests/speculative/test_speculative_sampling.py index 2d45db2d04b..93f3f509357 100644 --- a/sgl-kernel/tests/speculative/test_speculative_sampling.py +++ b/sgl-kernel/tests/speculative/test_speculative_sampling.py @@ -1,3 +1,4 @@ +import pytest import torch import torch.nn.functional as F from sgl_kernel import tree_speculative_sampling_target_only @@ -97,26 +98,21 @@ def test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc print(f"{accept_index=}") print(f"{accept_token_num=}") - return predicts, accept_index, accept_token_num + if threshold_single == 1 and threshold_acc == 1: + assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18] + assert accept_index.tolist() == [ + [0, 3, 4, 5], + [6, 10, 11, -1], + ] + assert accept_token_num.tolist() == [3, 2] + elif threshold_single == 0 and threshold_acc == 0: + assert predicts.tolist() == [1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18] + assert accept_index.tolist() == [ + [0, 1, 2, -1], + [6, 10, 11, -1], + ] + assert accept_token_num.tolist() == [2, 2] if __name__ == "__main__": - predicts, accept_index, accept_token_num = ( - test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc=1) - ) - assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18] - assert accept_index.tolist() == [ - [0, 3, 4, 5], - [6, 10, 11, -1], - ] - assert accept_token_num.tolist() == [3, 2] - - predicts, accept_index, accept_token_num = ( - test_tree_speculative_sampling_target_only(threshold_single=0, threshold_acc=0) - ) - assert predicts.tolist() == [1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18] - assert accept_index.tolist() == [ - [0, 1, 2, -1], - [6, 10, 11, -1], - ] - assert accept_token_num.tolist() == [2, 2] + pytest.main([__file__]) diff --git a/sgl-kernel/tests/test_awq_dequant.py 
b/sgl-kernel/tests/test_awq_dequant.py index bad3e2c10ce..33380180b0f 100644 --- a/sgl-kernel/tests/test_awq_dequant.py +++ b/sgl-kernel/tests/test_awq_dequant.py @@ -128,5 +128,4 @@ def test_awq_dequant_compare_implementations( if __name__ == "__main__": - # Run the specific test function directly pytest.main([__file__]) diff --git a/sgl-kernel/tests/test_cublas_grouped_gemm.py b/sgl-kernel/tests/test_cublas_grouped_gemm.py index 9aac569f2dd..744fb17fd9e 100644 --- a/sgl-kernel/tests/test_cublas_grouped_gemm.py +++ b/sgl-kernel/tests/test_cublas_grouped_gemm.py @@ -1,49 +1,49 @@ -import unittest - +import pytest import torch from sgl_kernel import cublas_grouped_gemm def torch_grouped_gemm(a_array, b_array, out_dtype): - c_array = [] - for a, b in zip(a_array, b_array): - c_array.append(torch.matmul(a, b.t()).to(out_dtype)) - return c_array - - -class TestGroupedGemm(unittest.TestCase): - def _test_accuracy(self, Ms, Ns, Ks, out_dtype): - group_count = len(Ms) - a_array = [] - b_array = [] - c_array_cublas = [] - for i in range(group_count): - M, N, K = Ms[i], Ns[i], Ks[i] - a_array.append(torch.randn((M, K), device="cuda", dtype=out_dtype) * 5) - b_array.append(torch.randn((N, K), device="cuda", dtype=out_dtype) * 5) - c_array_cublas.append(torch.empty((M, N), device="cuda", dtype=out_dtype)) - - c_array_torch = torch_grouped_gemm(a_array, b_array, out_dtype) - cublas_grouped_gemm(a_array, b_array, c_array_cublas, out_dtype) - - for i in range(group_count): - M, N, K = Ms[i], Ns[i], Ks[i] - torch.testing.assert_close(c_array_torch[i], c_array_cublas[i]) - print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK") - - def test_accuracy(self): - Ms = [1, 16, 32, 256, 1024] - Ns = [2, 16, 128, 256, 4096] - Ks = [3, 16, 32, 512, 8192] - out_dtypes = [torch.float16, torch.bfloat16] - for out_dtype in out_dtypes: - self._test_accuracy(Ms, Ns, Ks, out_dtype) + return [torch.matmul(a, b.t()).to(out_dtype) for a, b in zip(a_array, b_array)] + + +# Skip if CUDA is not available or CUDA version is lower than 12.5 +skip_condition = not torch.cuda.is_available() or ( + torch.version.cuda is None + or tuple(map(int, torch.version.cuda.split("."))) < (12, 5) +) + +shape_params = [ + (1, 2, 3), + (16, 16, 16), + (32, 128, 32), + (256, 256, 512), + (1024, 4096, 8192), +] + + +@pytest.mark.skipif( + skip_condition, reason="CUDA not available or CUDA version lower than 12.5" +) +@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("M, N, K", shape_params) +def test_grouped_gemm_accuracy(out_dtype, M, N, K): + # Create input matrices for a single GEMM test + a = torch.randn((M, K), device="cuda", dtype=out_dtype) * 5 + b = torch.randn((N, K), device="cuda", dtype=out_dtype) * 5 + expected = torch.matmul(a, b.t()).to(out_dtype) + + a_array = [a] + b_array = [b] + c_array = [torch.empty((M, N), device="cuda", dtype=out_dtype)] + + result_torch = torch_grouped_gemm(a_array, b_array, out_dtype)[0] + cublas_grouped_gemm(a_array, b_array, c_array, out_dtype) + + torch.testing.assert_close(result_torch, expected) + torch.testing.assert_close(c_array[0], expected) + print(f"Test passed for M={M}, N={N}, K={K}, out_dtype={out_dtype}") if __name__ == "__main__": - if torch.cuda.is_available(): - cuda_version = tuple(map(int, torch.version.cuda.split("."))) - if cuda_version >= (12, 5): - unittest.main() - else: - print(f"Cuda version {cuda_version} lower than 12.5, not executing tests.") + pytest.main([__file__]) diff --git a/sgl-kernel/tests/test_deep_gemm.py 
b/sgl-kernel/tests/test_deep_gemm.py new file mode 100644 index 00000000000..e0dbe331ec4 --- /dev/null +++ b/sgl-kernel/tests/test_deep_gemm.py @@ -0,0 +1,232 @@ +import os +import random +from typing import Any, Tuple + +import deep_gemm +import pytest +import torch +from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor, jit + + +@pytest.fixture(autouse=True, scope="module") +def setup_deepgemm(): + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.manual_seed(0) + random.seed(0) + print("Library path:") + print(f" > {deep_gemm.__path__}\n") + + +def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 and x.size(1) % 128 == 0 + m, n = x.shape + x_view = x.view(m, -1, 128) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view( + m, n + ), (x_amax / 448.0).view(m, -1) + + +def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + x_padded = torch.zeros( + (ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device + ) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view( + x_view.size(0), x_view.size(2) + ) + + +def construct(m: int, k: int, n: int) -> Tuple[ + Tuple[torch.Tensor, torch.Tensor], + Tuple[torch.Tensor, torch.Tensor], + torch.Tensor, + torch.Tensor, +]: + x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) + y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16) + out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16) + ref_out = x @ y.t() + x_fp8, y_fp8 = per_token_cast_to_fp8(x), per_block_cast_to_fp8(y) + x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) + return x_fp8, y_fp8, out, ref_out + + +def construct_grouped( + num_groups: int, m: int, k: int, n: int, is_masked: bool +) -> Tuple[ + Tuple[torch.Tensor, torch.Tensor], + Tuple[torch.Tensor, torch.Tensor], + torch.Tensor, + torch.Tensor, +]: + x = torch.randn((num_groups, m, k), device="cuda", dtype=torch.bfloat16) + y = torch.randn((num_groups, n, k), device="cuda", dtype=torch.bfloat16) + out = torch.empty((num_groups, m, n), device="cuda", dtype=torch.bfloat16) + ref_out = torch.einsum("gmk,gnk->gmn", x, y) + assert m % 4 == 0, f"TMA alignment error: {m}" + x_fp8 = ( + torch.empty_like(x, dtype=torch.float8_e4m3fn), + torch.empty((num_groups, m, k // 128), device="cuda", dtype=torch.float), + ) + y_fp8 = ( + torch.empty_like(y, dtype=torch.float8_e4m3fn), + torch.empty( + (num_groups, (n + 127) // 128, k // 128), device="cuda", dtype=torch.float + ), + ) + for i in range(num_groups): + x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i]) + y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) + if not is_masked: + x_fp8 = (x_fp8[0].view(-1, k), per_token_cast_to_fp8(x.view(-1, k))[1]) + out, ref_out = out.view(-1, n), ref_out.view(-1, n) + x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) + return x_fp8, y_fp8, out, ref_out + + +@pytest.mark.parametrize( + "m,k,n", + [ + (64, 7168, 2112), + (64, 1536, 24576), + (64, 512, 32768), + (64, 16384, 7168), + (64, 7168, 4096), + (64, 2048, 7168), + (128, 7168, 2112), + (128, 1536, 24576), 
+        (128, 512, 32768),
+        (128, 16384, 7168),
+        (128, 7168, 4096),
+        (128, 2048, 7168),
+        (4096, 7168, 2112),
+        (4096, 1536, 24576),
+        (4096, 512, 32768),
+        (4096, 16384, 7168),
+        (4096, 7168, 4096),
+        (4096, 2048, 7168),
+    ],
+)
+def test_gemm(m: int, k: int, n: int):
+    x_fp8, y_fp8, out, ref_out = construct(m, k, n)
+    deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out)
+    diff = calc_diff(out, ref_out)
+    assert diff < 0.001, f"{m=}, {k=}, {n=}, {diff:.5f}"
+
+
+@pytest.mark.parametrize(
+    "num_groups,m,k,n",
+    [
+        (4, 8192, 7168, 4096),
+        (4, 8192, 2048, 7168),
+        (8, 4096, 7168, 4096),
+        (8, 4096, 2048, 7168),
+    ],
+)
+def test_m_grouped_gemm_contiguous(num_groups: int, m: int, k: int, n: int):
+    x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=False)
+    m_indices = (
+        torch.arange(0, num_groups, device="cuda", dtype=torch.int)
+        .unsqueeze(-1)
+        .expand(num_groups, m)
+        .contiguous()
+        .view(-1)
+    )
+    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices)
+    diff = calc_diff(out, ref_out)
+    assert diff < 0.001, f"m={m * num_groups}, {k=}, {n=}, {diff:.5f}"
+
+
+@pytest.mark.parametrize("num_groups,m", [(1, 1024), (2, 512), (4, 256)])
+@pytest.mark.parametrize("k,n", [(7168, 4096), (2048, 7168)])
+@pytest.mark.parametrize("trial", range(10))
+def test_m_grouped_gemm_masked(num_groups: int, m: int, k: int, n: int, trial: int):
+    masked_m_candidates = [c for c in (64, 128, 192, 256, 320, 384) if c <= m]
+    x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=True)
+    masked_m = torch.empty((num_groups,), device="cuda", dtype=torch.int)
+    for j in range(num_groups):
+        masked_m[j] = random.choice(masked_m_candidates)
+    expected_m = min(int(masked_m.float().mean()) + 1, m)
+    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+        x_fp8, y_fp8, out, masked_m, expected_m
+    )
+    for j in range(num_groups):
+        diff = calc_diff(out[j, : masked_m[j].item()], ref_out[j, : masked_m[j].item()])
+        assert (
+            diff < 0.001
+        ), f"{m=}, {k=}, {n=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}"
+
+
+class Capture:
+    def __init__(self) -> None:
+        self.read_fd = None
+        self.write_fd = None
+        self.saved_stdout = None
+        self.captured = None
+
+    def __enter__(self) -> Any:
+        self.read_fd, self.write_fd = os.pipe()
+        self.saved_stdout = os.dup(1)
+        os.dup2(self.write_fd, 1)
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
+        os.dup2(self.saved_stdout, 1)
+        os.close(self.write_fd)
+        with os.fdopen(self.read_fd, "r") as f:
+            self.captured = f.read()
+
+    def capture(self) -> str:
+        return self.captured
+
+
+def test_jit():
+    print(f"NVCC compiler: {jit.get_nvcc_compiler()}\n")
+    print("Generated code:")
+    args = (
+        ("lhs", torch.float8_e4m3fn),
+        ("rhs", torch.float8_e4m3fn),
+        ("scale", torch.float),
+        ("out", torch.bfloat16),
+        ("enable_double_streams", bool),
+        ("stream", torch.cuda.Stream),
+    )
+    body = "\n"
+    body += "std::cout << reinterpret_cast<uint64_t>(lhs) << std::endl;\n"
+    body += "std::cout << reinterpret_cast<uint64_t>(rhs) << std::endl;\n"
+    body += "std::cout << reinterpret_cast<uint64_t>(scale) << std::endl;\n"
+    body += "std::cout << reinterpret_cast<uint64_t>(out) << std::endl;\n"
+    body += "std::cout << enable_double_streams << std::endl;\n"
+    body += "std::cout << reinterpret_cast<uint64_t>(stream) << std::endl;\n"
+    code = jit.generate((), args, body)
+    print(code)
+    print("Building ...")
+    func = jit.build("test_func", args, code)
+    print("Running ...")
+    fp8_tensor = torch.empty((1,), dtype=torch.float8_e4m3fn, device="cuda")
+    fp32_tensor =
torch.empty((1,), dtype=torch.float, device="cuda") + bf16_tensor = torch.empty((1,), dtype=torch.bfloat16, device="cuda") + with Capture() as capture: + ret_val = func( + fp8_tensor, + fp8_tensor, + fp32_tensor, + bf16_tensor, + True, + torch.cuda.current_stream(), + ) + assert ret_val == 0, "Function did not return 0" + output = capture.capture() + ref_output = f"{fp8_tensor.data_ptr()}\n{fp8_tensor.data_ptr()}\n{fp32_tensor.data_ptr()}\n{bf16_tensor.data_ptr()}\n1\n{torch.cuda.current_stream().cuda_stream}\n" + assert output == ref_output, f"{output=}, {ref_output=}" + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/sgl-kernel/tests/test_fp8_blockwise_gemm.py b/sgl-kernel/tests/test_fp8_blockwise_gemm.py index 4ae7ae0355d..d432d974e4e 100644 --- a/sgl-kernel/tests/test_fp8_blockwise_gemm.py +++ b/sgl-kernel/tests/test_fp8_blockwise_gemm.py @@ -1,12 +1,13 @@ -import unittest +import os +import random from typing import Optional, Type +import pytest import torch from sgl_kernel import fp8_blockwise_scaled_mm def cdiv(a: int, b: int) -> int: - """Ceiling division.""" return -(a // -b) @@ -23,21 +24,6 @@ def baseline_scaled_mm( out_dtype: Type[torch.dtype], bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - - # We treat N-dimensional group scaling as extended numpy-style broadcasting - # in numpy simply stretches dimensions with an extent of 1 to match the - # the target shape by repeating the data along that dimension (broadcasting) - # , we extend these semantics to say if the extent of a dimension in the - # source shape is not 1 and does not match the target shape we repeat each - # element along that dimension src_shape[dim] // target_shape[dim] times - # example if we have: - # a = [[1, 2], and target_shape = (2, 4) - # [3, 4]] - # then we would expand a to: - # a = [[1, 1, 2, 2], - # [3, 3, 4, 4]] - # NOTE this function this function does not explicitly broadcast dimensions - # with an extent of 1, since this can be done implicitly by pytorch def group_broadcast(t, shape): for i, s in enumerate(shape): if t.shape[i] != s and t.shape[i] != 1: @@ -51,62 +37,44 @@ def group_broadcast(t, shape): scale_a = group_broadcast(scale_a, a.shape) scale_b = group_broadcast(scale_b, b.shape) - output = torch.mm( (scale_a * a.to(dtype=torch.float32)), (scale_b * b.to(dtype=torch.float32)) ).to(out_dtype) - if bias is not None: output = output + bias - return output -class TestFp8Gemm(unittest.TestCase): - def _test_accuracy_once(self, M, N, K, out_dtype, device): - fp8_info = torch.finfo(torch.float8_e4m3fn) - fp8_max, fp8_min = fp8_info.max, fp8_info.min - - a_fp32 = ( - (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max - ) - a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - - b_fp32 = ( - (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max - ) - b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn).t() - - scale_a_group_shape = (1, 128) - scale_b_group_shape = (128, 128) - scale_a_shape = scale_shape(a_fp8.shape, scale_a_group_shape) - scale_b_shape = scale_shape(b_fp8.shape, scale_b_group_shape) - - scale_a = torch.randn(scale_a_shape, device=device, dtype=torch.float32) * 0.001 - scale_b = torch.randn(scale_b_shape, device=device, dtype=torch.float32) * 0.001 - scale_a = scale_a.t().contiguous().t() - scale_b = scale_b.t().contiguous().t() - - o1 = torch.empty((M, N), device=device, dtype=torch.bfloat16) - o = baseline_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype) - o1 = 
fp8_blockwise_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype) - - rtol = 0.02 - atol = 1 - torch.testing.assert_close(o, o1, rtol=rtol, atol=atol) - print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK") - - def test_accuracy(self): - Ms = [1, 128, 512, 1024, 4096] - Ns = [128, 512, 1024, 4096] - Ks = [512, 1024, 4096, 8192, 16384] - out_dtypes = [torch.bfloat16, torch.float16] - for M in Ms: - for N in Ns: - for K in Ks: - for out_dtype in out_dtypes: - self._test_accuracy_once(M, N, K, out_dtype, "cuda") +def _test_accuracy_once(M, N, K, out_dtype, device): + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + a_fp32 = (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max + a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + b_fp32 = (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max + b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn).t() + scale_a_group_shape = (1, 128) + scale_b_group_shape = (128, 128) + scale_a_shape = scale_shape(a_fp8.shape, scale_a_group_shape) + scale_b_shape = scale_shape(b_fp8.shape, scale_b_group_shape) + scale_a = torch.randn(scale_a_shape, device=device, dtype=torch.float32) * 0.001 + scale_b = torch.randn(scale_b_shape, device=device, dtype=torch.float32) * 0.001 + scale_a = scale_a.t().contiguous().t() + scale_b = scale_b.t().contiguous().t() + o = baseline_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype) + o1 = fp8_blockwise_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype) + rtol = 0.02 + atol = 1 + torch.testing.assert_close(o, o1, rtol=rtol, atol=atol) + print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK") + + +@pytest.mark.parametrize("M", [1, 128, 512, 1024, 4096]) +@pytest.mark.parametrize("N", [128, 512, 1024, 4096]) +@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384]) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) +def test_accuracy(M, N, K, out_dtype): + _test_accuracy_once(M, N, K, out_dtype, "cuda") if __name__ == "__main__": - unittest.main() + pytest.main([__file__]) diff --git a/sgl-kernel/tests/test_fp8_gemm.py b/sgl-kernel/tests/test_fp8_gemm.py index 1a731865944..e70e62af26c 100644 --- a/sgl-kernel/tests/test_fp8_gemm.py +++ b/sgl-kernel/tests/test_fp8_gemm.py @@ -1,67 +1,49 @@ -import unittest - +import pytest import torch from sgl_kernel import fp8_scaled_mm def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias): o = torch.matmul(a.to(torch.float32), b.to(torch.float32)) - o = o.to(torch.float32) temp1 = o * scale_a.view(-1, 1) temp2 = temp1 * scale_b.view(1, -1) final = temp2.to(out_dtype) if bias is not None: final = final + bias.view(1, -1) - return final -class TestFp8Gemm(unittest.TestCase): - def _test_accuracy_once(self, M, N, K, with_bias, out_dtype, device): - fp8_info = torch.finfo(torch.float8_e4m3fn) - fp8_max, fp8_min = fp8_info.max, fp8_info.min - - a_fp32 = ( - (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max - ) - a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - - b_fp32 = ( - (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max - ) - b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - - scale_a = torch.randn((M,), device=device, dtype=torch.float32) * 0.001 - scale_b = torch.randn((N,), device=device, dtype=torch.float32) * 0.001 - if with_bias: - bias = torch.randn((N,), device=device, dtype=out_dtype) - else: - bias = None - o1 = 
torch.empty((M, N), device=device, dtype=torch.bfloat16) - b_fp8 = b_fp8.t() - o = torch_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias) - o1 = fp8_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias) - rtol = 0.02 - atol = 1 - torch.testing.assert_close(o, o1, rtol=rtol, atol=atol) - print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK") - - def test_accuracy(self): - Ms = [1, 128, 512, 1024, 4096] - Ns = [16, 128, 512, 1024, 4096] - Ks = [512, 1024, 4096, 8192, 16384] - bias_opts = [True, False] - out_dtypes = [torch.bfloat16, torch.float16] - for M in Ms: - for N in Ns: - for K in Ks: - for with_bias in bias_opts: - for out_dtype in out_dtypes: - self._test_accuracy_once( - M, N, K, with_bias, out_dtype, "cuda" - ) +def _test_accuracy_once(M, N, K, with_bias, out_dtype, device): + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + a_fp32 = (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max + a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + b_fp32 = (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max + b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + scale_a = torch.randn((M,), device=device, dtype=torch.float32) * 0.001 + scale_b = torch.randn((N,), device=device, dtype=torch.float32) * 0.001 + if with_bias: + bias = torch.randn((N,), device=device, dtype=out_dtype) + else: + bias = None + b_fp8 = b_fp8.t() + o = torch_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias) + o1 = fp8_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias) + rtol = 0.02 + atol = 1 + torch.testing.assert_close(o, o1, rtol=rtol, atol=atol) + print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK") + + +@pytest.mark.parametrize("M", [1, 128, 512, 1024, 4096]) +@pytest.mark.parametrize("N", [16, 128, 512, 1024, 4096]) +@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384]) +@pytest.mark.parametrize("with_bias", [True, False]) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) +def test_accuracy(M, N, K, with_bias, out_dtype): + _test_accuracy_once(M, N, K, with_bias, out_dtype, "cuda") if __name__ == "__main__": - unittest.main() + pytest.main([__file__]) diff --git a/sgl-kernel/tests/test_int8_gemm.py b/sgl-kernel/tests/test_int8_gemm.py index 951de314e03..d87a9a5aacf 100644 --- a/sgl-kernel/tests/test_int8_gemm.py +++ b/sgl-kernel/tests/test_int8_gemm.py @@ -1,5 +1,4 @@ -import unittest - +import pytest import torch from sgl_kernel import int8_scaled_mm from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm @@ -18,39 +17,31 @@ def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias): return o.to(out_dtype) -class TestInt8Gemm(unittest.TestCase): - def _test_accuracy_once(self, M, N, K, with_bias, out_dtype, device): - a = to_int8(torch.randn((M, K), device=device) * 5) - b = to_int8(torch.randn((N, K), device=device).t() * 5) - scale_a = torch.randn((M,), device="cuda", dtype=torch.float32) - scale_b = torch.randn((N,), device="cuda", dtype=torch.float32) - if with_bias: - bias = torch.randn((N,), device="cuda", dtype=out_dtype) * 10 - else: - bias = None - - o = int8_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) - o1 = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) - o2 = vllm_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) - torch.testing.assert_close(o, o1) - torch.testing.assert_close(o, o2) - print(f"M={M}, N={N}, K={K}, 
with_bias={with_bias}, out_dtype={out_dtype}: OK") - - def test_accuracy(self): - Ms = [1, 16, 32, 64, 128, 512, 1024, 4096, 8192] - Ns = [16, 128, 512, 1024, 4096, 8192, 16384] - Ks = [512, 1024, 4096, 8192, 16384] - bias_opts = [True, False] - out_dtypes = [torch.float16, torch.bfloat16] - for M in Ms: - for N in Ns: - for K in Ks: - for with_bias in bias_opts: - for out_dtype in out_dtypes: - self._test_accuracy_once( - M, N, K, with_bias, out_dtype, "cuda" - ) +def _test_accuracy_once(M, N, K, with_bias, out_dtype, device): + a = to_int8(torch.randn((M, K), device=device) * 5) + b = to_int8(torch.randn((N, K), device=device).t() * 5) + scale_a = torch.randn((M,), device="cuda", dtype=torch.float32) + scale_b = torch.randn((N,), device="cuda", dtype=torch.float32) + if with_bias: + bias = torch.randn((N,), device="cuda", dtype=out_dtype) * 10 + else: + bias = None + o = int8_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + o1 = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + o2 = vllm_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + torch.testing.assert_close(o, o1) + torch.testing.assert_close(o, o2) + print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK") + + +@pytest.mark.parametrize("M", [1, 16, 32, 64, 128, 512, 1024, 4096, 8192]) +@pytest.mark.parametrize("N", [16, 128, 512, 1024, 4096, 8192, 16384]) +@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384]) +@pytest.mark.parametrize("with_bias", [True, False]) +@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16]) +def test_accuracy(M, N, K, with_bias, out_dtype): + _test_accuracy_once(M, N, K, with_bias, out_dtype, "cuda") if __name__ == "__main__": - unittest.main() + pytest.main([__file__]) diff --git a/sgl-kernel/tests/test_per_token_quant_fp8.py b/sgl-kernel/tests/test_per_token_quant_fp8.py index 20b2722fce0..fe1e0afe3dd 100644 --- a/sgl-kernel/tests/test_per_token_quant_fp8.py +++ b/sgl-kernel/tests/test_per_token_quant_fp8.py @@ -51,5 +51,4 @@ def test_per_token_quant_compare_implementations( if __name__ == "__main__": - # Run the specific test function directly pytest.main([__file__]) From a47cc35bbdce2732c79e173498df0247c9862c55 Mon Sep 17 00:00:00 2001 From: Adarsh Shirawalmath <114558126+adarshxs@users.noreply.github.com> Date: Sun, 30 Mar 2025 01:10:40 +0530 Subject: [PATCH 02/11] use pytest --- .github/workflows/pr-test-sgl-kernel.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pr-test-sgl-kernel.yml b/.github/workflows/pr-test-sgl-kernel.yml index 6944f9a4412..91fa78fc000 100644 --- a/.github/workflows/pr-test-sgl-kernel.yml +++ b/.github/workflows/pr-test-sgl-kernel.yml @@ -89,7 +89,7 @@ jobs: timeout-minutes: 30 run: | cd sgl-kernel - find tests -name "test_*.py" | xargs -n 1 python3 + pytest tests/ - name: Uninstall dependencies run: | From 15790f8d643f5d0b3d84ba94330fe90106700312 Mon Sep 17 00:00:00 2001 From: adarshxs Date: Sun, 30 Mar 2025 10:20:20 +0530 Subject: [PATCH 03/11] fix all_reduce_test --- sgl-kernel/tests/test_trt_allreduce.py | 444 +++++++++++++++---------- 1 file changed, 264 insertions(+), 180 deletions(-) diff --git a/sgl-kernel/tests/test_trt_allreduce.py b/sgl-kernel/tests/test_trt_allreduce.py index 9bbc4e76fa8..79e106337ed 100644 --- a/sgl-kernel/tests/test_trt_allreduce.py +++ b/sgl-kernel/tests/test_trt_allreduce.py @@ -3,9 +3,9 @@ import random import socket import time -import unittest -from typing import Any, List, Optional +from typing import Any, Callable, List, Optional +import 
pytest import ray import sgl_kernel.allreduce as custom_ops import torch @@ -18,227 +18,311 @@ logger = logging.getLogger(__name__) +TEST_SIZES = [512, 4096, 32768, 262144, 524288, 1048576, 2097152] +WORLD_SIZES = [2, 4] +BUFFER_MAX_SIZE = 8 * 1024 * 1024 +BARRIER_MAX_SIZE = 8 * (24 + 2) * 8 +VLLM_MAX_SIZE = 8 * 1024 * 1024 + + def get_open_port() -> int: - # try ipv4 try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("127.0.0.1", 0)) return s.getsockname()[1] except OSError: - # try ipv6 with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: - s.bind(("127.0.0.1", 0)) + s.bind(("::1", 0)) return s.getsockname()[1] -def multi_process_parallel( - world_size: int, - cls: Any, - test_target: Any, -) -> None: - # Using ray helps debugging the error when it failed - # as compared to multiprocessing. - # NOTE: We need to set working_dir for distributed tests, - # otherwise we may get import errors on ray workers - ray.init(log_to_driver=True) - - distributed_init_port = get_open_port() - refs = [] - for rank in range(world_size): - refs.append(test_target.remote(cls, world_size, rank, distributed_init_port)) - ray.get(refs) - - ray.shutdown() - - -class TestCustomAllReduce(unittest.TestCase): - @classmethod - def setUpClass(cls): - random.seed(42) - cls.test_sizes = [512, 4096, 32768, 262144, 524288, 1048576, 2097152] - cls.world_sizes = [2, 4, 8] - - @staticmethod - def create_shared_buffer( - size_in_bytes: int, group: Optional[ProcessGroup] = None - ) -> List[int]: - """ - Creates a shared buffer and returns a list of pointers - representing the buffer on all processes in the group. - """ - lib = CudaRTLibrary() - pointer = lib.cudaMalloc(size_in_bytes) - handle = lib.cudaIpcGetMemHandle(pointer) - world_size = dist.get_world_size(group=group) - rank = dist.get_rank(group=group) - handles = [None] * world_size - dist.all_gather_object(handles, handle, group=group) - - pointers: List[int] = [] - for i, h in enumerate(handles): - if i == rank: - pointers.append(pointer.value) # type: ignore - else: - pointers.append(lib.cudaIpcOpenMemHandle(h).value) # type: ignore - - return pointers - - @staticmethod - def free_shared_buffer( - pointers: List[int], group: Optional[ProcessGroup] = None - ) -> None: - rank = dist.get_rank(group=group) - lib = CudaRTLibrary() - lib.cudaFree(ctypes.c_void_p(pointers[rank])) - - def test_correctness(self): - for world_size in self.world_sizes: - if world_size > torch.cuda.device_count(): - continue - multi_process_parallel(world_size, self, self.correctness) +def init_distributed_env(world_size, rank, distributed_init_port): + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + ranks = [i for i in range(world_size)] + distributed_init_method = f"tcp://127.0.0.1:{distributed_init_port}" + dist.init_process_group( + backend="nccl", + init_method=distributed_init_method, + rank=rank, + world_size=world_size, + ) + return None + + +def create_shared_buffer( + size_in_bytes: int, group: Optional[ProcessGroup] = None +) -> List[int]: + lib = CudaRTLibrary() + if not torch.cuda.is_initialized(): + torch.cuda.init() + + pointer = lib.cudaMalloc(size_in_bytes) + handle = lib.cudaIpcGetMemHandle(pointer) + world_size = dist.get_world_size(group=group) + rank = dist.get_rank(group=group) + + gathered_handles = [None] * world_size + dist.all_gather_object(gathered_handles, handle, group=group) + + pointers: List[int] = [0] * world_size + for i, h in enumerate(gathered_handles): + if h is None: + raise RuntimeError( + f"Rank 
{i} did not receive a valid handle from rank {rank}." + ) + if i == rank: + pointers[i] = pointer.value + else: + opened_ptr = lib.cudaIpcOpenMemHandle(h) + pointers[i] = opened_ptr.value - def test_performance(self): - for world_size in self.world_sizes: - if world_size > torch.cuda.device_count(): - continue - multi_process_parallel(world_size, self, self.performance) + dist.barrier(group=group) + return pointers - def init_custom_allreduce(self, rank, world_size, group): - buffer_max_size = 8 * 1024 * 1024 - barrier_max_size = 8 * (24 + 2) * 8 - self.buffer_ptrs = self.create_shared_buffer(buffer_max_size, group=group) - self.tmp_result_buffer_ptrs = self.create_shared_buffer( - buffer_max_size, group=group - ) - self.barrier_in_ptrs = self.create_shared_buffer(barrier_max_size, group=group) - self.barrier_out_ptrs = self.create_shared_buffer(barrier_max_size, group=group) - self.rank_data = torch.empty( - 8 * 1024 * 1024, dtype=torch.uint8, device=torch.device("cuda:0") - ) +def free_shared_buffer( + pointers: List[int], group: Optional[ProcessGroup] = None +) -> None: + if not pointers: + return + rank = dist.get_rank(group=group) + lib = CudaRTLibrary() + if rank < len(pointers) and pointers[rank] != 0: + try: + lib.cudaFree(ctypes.c_void_p(pointers[rank])) + except Exception as e: + logger.error( + f"Rank {rank}: Error freeing shared buffer pointer {pointers[rank]}: {e}" + ) + dist.barrier(group=group) - self.custom_ptr = custom_ops.init_custom_reduce( - rank, - world_size, - self.rank_data, - self.buffer_ptrs, - self.tmp_result_buffer_ptrs, - self.barrier_in_ptrs, - self.barrier_out_ptrs, - ) - def custom_allreduce(self, inp, out): - custom_ops.custom_reduce(self.custom_ptr, inp, out) - - def free_custom_allreduce(self, group): - self.free_shared_buffer(self.buffer_ptrs, group) - self.free_shared_buffer(self.tmp_result_buffer_ptrs, group) - self.free_shared_buffer(self.barrier_in_ptrs, group) - self.free_shared_buffer(self.barrier_out_ptrs, group) - custom_ops.custom_dispose(self.custom_ptr) - - def init_vllm_allreduce(self, rank, group): - self.vllm_rank = rank - self.vllm_max_size = 8 * 1024 * 1024 - self.vllm_meta_ptrs = self.create_shared_buffer( - vllm_ops.meta_size() + self.vllm_max_size, group=group - ) - self.vllm_buffer_ptrs = self.create_shared_buffer( - self.vllm_max_size, group=group +@ray.remote(num_gpus=1, max_calls=1) +def correctness_worker(rank, world_size, distributed_init_port): + group = init_distributed_env(world_size, rank, distributed_init_port) + worker_state = {} + + try: + worker_state["buffer_ptrs"] = create_shared_buffer(BUFFER_MAX_SIZE, group=group) + worker_state["tmp_result_buffer_ptrs"] = create_shared_buffer( + BUFFER_MAX_SIZE, group=group ) - self.vllm_rank_data = torch.empty( - 8 * 1024 * 1024, dtype=torch.uint8, device=torch.device("cuda:0") + worker_state["barrier_in_ptrs"] = create_shared_buffer( + BARRIER_MAX_SIZE, group=group ) - self.vllm_ptr = vllm_ops.init_custom_ar( - self.vllm_meta_ptrs, self.vllm_rank_data, rank, True + worker_state["barrier_out_ptrs"] = create_shared_buffer( + BARRIER_MAX_SIZE, group=group ) - vllm_ops.register_buffer(self.vllm_ptr, self.vllm_buffer_ptrs) - - def vllm_allreduce(self, inp, out): - vllm_ops.all_reduce( - self.vllm_ptr, - inp, - out, - self.vllm_buffer_ptrs[self.vllm_rank], - self.vllm_max_size, + worker_state["rank_data"] = torch.empty( + BUFFER_MAX_SIZE, dtype=torch.uint8, device=torch.cuda.current_device() ) - - def free_vllm_allreduce(self, group): - vllm_ops.dispose(self.vllm_ptr) - 
self.free_shared_buffer(self.vllm_meta_ptrs, group) - self.free_shared_buffer(self.vllm_buffer_ptrs, group) - - @staticmethod - def init_distributed_env(world_size, rank, distributed_init_port): - device = torch.device("cuda:0") - torch.cuda.set_device(device) - ranks = [i for i in range(world_size)] - distributed_init_method = f"tcp://localhost:{distributed_init_port}" - dist.init_process_group( - backend="nccl", - init_method=distributed_init_method, - rank=rank, - world_size=world_size, + worker_state["custom_ptr"] = custom_ops.init_custom_reduce( + rank, + world_size, + worker_state["rank_data"], + worker_state["buffer_ptrs"], + worker_state["tmp_result_buffer_ptrs"], + worker_state["barrier_in_ptrs"], + worker_state["barrier_out_ptrs"], ) - group = torch.distributed.new_group(ranks, backend="gloo") - return group - - # compare result with torch.distributed - @ray.remote(num_gpus=1, max_calls=1) - def correctness(self, world_size, rank, distributed_init_port): - group = self.init_distributed_env(world_size, rank, distributed_init_port) - - self.init_custom_allreduce(rank=rank, world_size=world_size, group=group) test_loop = 10 - for sz in self.test_sizes: + for sz in TEST_SIZES: for dtype in [torch.float32, torch.float16, torch.bfloat16]: - for _ in range(test_loop): + for i in range(test_loop): + dist.barrier(group=group) inp1 = torch.randint( 1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device() ) + inp_clone = inp1.clone() out1 = torch.empty_like(inp1) - self.custom_allreduce(inp1, out1) - dist.all_reduce(inp1, group=group) - torch.testing.assert_close(out1, inp1) + custom_ops.custom_reduce(worker_state["custom_ptr"], inp1, out1) + + dist.barrier(group=group) + + dist.all_reduce(inp_clone, group=group) + + dist.barrier(group=group) + torch.testing.assert_close(out1, inp_clone, rtol=1e-3, atol=1e-3) + + finally: + if "custom_ptr" in worker_state and worker_state["custom_ptr"]: + custom_ops.custom_dispose(worker_state["custom_ptr"]) + free_shared_buffer(worker_state.get("buffer_ptrs", []), group) + free_shared_buffer(worker_state.get("tmp_result_buffer_ptrs", []), group) + free_shared_buffer(worker_state.get("barrier_in_ptrs", []), group) + free_shared_buffer(worker_state.get("barrier_out_ptrs", []), group) + if dist.is_initialized(): + dist.destroy_process_group() + - self.free_custom_allreduce(group) +@ray.remote(num_gpus=1, max_calls=1) +def performance_worker(rank, world_size, distributed_init_port): + group = init_distributed_env(world_size, rank, distributed_init_port) + worker_state = {} - # compare performance with vllm - @ray.remote(num_gpus=1, max_calls=1) - def performance(self, world_size, rank, distributed_init_port): - group = self.init_distributed_env(world_size, rank, distributed_init_port) + try: + worker_state["vllm_meta_ptrs"] = create_shared_buffer( + vllm_ops.meta_size() + VLLM_MAX_SIZE, group=group + ) + worker_state["vllm_buffer_ptrs"] = create_shared_buffer( + VLLM_MAX_SIZE, group=group + ) + worker_state["vllm_rank_data"] = torch.empty( + VLLM_MAX_SIZE, dtype=torch.uint8, device=torch.cuda.current_device() + ) + worker_state["vllm_ptr"] = vllm_ops.init_custom_ar( + worker_state["vllm_meta_ptrs"], worker_state["vllm_rank_data"], rank, True + ) + vllm_ops.register_buffer( + worker_state["vllm_ptr"], worker_state["vllm_buffer_ptrs"] + ) + + worker_state["buffer_ptrs"] = create_shared_buffer(BUFFER_MAX_SIZE, group=group) + worker_state["tmp_result_buffer_ptrs"] = create_shared_buffer( + BUFFER_MAX_SIZE, group=group + ) + worker_state["barrier_in_ptrs"] = 
create_shared_buffer( + BARRIER_MAX_SIZE, group=group + ) + worker_state["barrier_out_ptrs"] = create_shared_buffer( + BARRIER_MAX_SIZE, group=group + ) + worker_state["rank_data"] = torch.empty( + BUFFER_MAX_SIZE, dtype=torch.uint8, device=torch.cuda.current_device() + ) + worker_state["custom_ptr"] = custom_ops.init_custom_reduce( + rank, + world_size, + worker_state["rank_data"], + worker_state["buffer_ptrs"], + worker_state["tmp_result_buffer_ptrs"], + worker_state["barrier_in_ptrs"], + worker_state["barrier_out_ptrs"], + ) - self.init_vllm_allreduce(rank, group) - self.init_custom_allreduce(rank=rank, world_size=world_size, group=group) + dist.barrier(group=group) - for sz in self.test_sizes: + for sz in TEST_SIZES: inp1 = torch.randint( 1, 16, (sz,), dtype=torch.float32, device=torch.cuda.current_device() ) out1 = torch.empty_like(inp1) - test_loop = 5000 - start = time.time() - for _ in range(test_loop): - self.custom_allreduce(inp1, out1) - elapse_custom = time.time() - start + test_loop = 100 - start = time.time() + torch.cuda.synchronize() + dist.barrier(group=group) + start_custom = time.time() + for _ in range(test_loop): + custom_ops.custom_reduce(worker_state["custom_ptr"], inp1, out1) + torch.cuda.synchronize() + dist.barrier(group=group) + elapse_custom = time.time() - start_custom + + torch.cuda.synchronize() + dist.barrier(group=group) + start_vllm = time.time() for _ in range(test_loop): - self.vllm_allreduce(inp1, out1) - elapse_vllm = time.time() - start + vllm_ops.all_reduce( + worker_state["vllm_ptr"], + inp1, + out1, + worker_state["vllm_buffer_ptrs"][rank], + VLLM_MAX_SIZE, + ) + torch.cuda.synchronize() + dist.barrier(group=group) + elapse_vllm = time.time() - start_vllm if rank == 0: logger.warning( - f"test_size = {sz}, world_size = {world_size}, " - f"vllm time = {elapse_vllm * 1000 / test_loop:.4f}ms, " - f"custom time = {elapse_custom * 1000 / test_loop:.4f}ms " + f"PERF: sz={sz}, world={world_size}, " + f"vllm={elapse_vllm * 1000 / test_loop:.4f}ms, " + f"custom={elapse_custom * 1000 / test_loop:.4f}ms" ) - self.free_custom_allreduce(group) - self.free_vllm_allreduce(group) + finally: + if "vllm_ptr" in worker_state and worker_state["vllm_ptr"]: + vllm_ops.dispose(worker_state["vllm_ptr"]) + free_shared_buffer(worker_state.get("vllm_meta_ptrs", []), group) + free_shared_buffer(worker_state.get("vllm_buffer_ptrs", []), group) + + if "custom_ptr" in worker_state and worker_state["custom_ptr"]: + custom_ops.custom_dispose(worker_state["custom_ptr"]) + free_shared_buffer(worker_state.get("buffer_ptrs", []), group) + free_shared_buffer(worker_state.get("tmp_result_buffer_ptrs", []), group) + free_shared_buffer(worker_state.get("barrier_in_ptrs", []), group) + free_shared_buffer(worker_state.get("barrier_out_ptrs", []), group) + + if dist.is_initialized(): + dist.destroy_process_group() + + +class TestCustomAllReduce: + + @pytest.fixture(scope="class", autouse=True) + def ray_controller(self): + if not ray.is_initialized(): + ray.init(log_to_driver=False, ignore_reinit_error=True) + yield + ray.shutdown() + + def test_correctness(self, request): + node_world_sizes = WORLD_SIZES + num_gpus = torch.cuda.device_count() + logger.info(f"Detected {num_gpus} GPUs for correctness test.") + + for world_size in node_world_sizes: + if world_size > num_gpus: + pytest.skip( + f"Skipping world_size={world_size}, requires {world_size} GPUs, found {num_gpus}" + ) + continue + + logger.info(f"Running correctness test with world_size={world_size}") + distributed_init_port = 
get_open_port() + refs = [ + correctness_worker.remote(rank, world_size, distributed_init_port) + for rank in range(world_size) + ] + try: + ray.get(refs) + logger.info(f"Correctness test PASSED for world_size={world_size}") + except Exception as e: + logger.error( + f"Correctness test FAILED for world_size={world_size}: {e}" + ) + pytest.fail(f"Correctness test failed for world_size={world_size}: {e}") + + def test_performance(self, request): + node_world_sizes = WORLD_SIZES + num_gpus = torch.cuda.device_count() + logger.info(f"Detected {num_gpus} GPUs for performance test.") + + for world_size in node_world_sizes: + if world_size > num_gpus: + pytest.skip( + f"Skipping world_size={world_size}, requires {world_size} GPUs, found {num_gpus}" + ) + continue + + logger.info(f"Running performance test with world_size={world_size}") + distributed_init_port = get_open_port() + refs = [ + performance_worker.remote(rank, world_size, distributed_init_port) + for rank in range(world_size) + ] + try: + ray.get(refs) + logger.info(f"Performance test COMPLETED for world_size={world_size}") + except Exception as e: + logger.error( + f"Performance test FAILED for world_size={world_size}: {e}" + ) + pytest.fail(f"Performance test failed for world_size={world_size}: {e}") if __name__ == "__main__": - unittest.main() + pytest.main([__file__]) From fbf82db7cfd05ba56f95121e8b2992fcc602def6 Mon Sep 17 00:00:00 2001 From: adarshxs Date: Sun, 30 Mar 2025 11:49:08 +0530 Subject: [PATCH 04/11] fix trt_allreduce --- sgl-kernel/tests/test_trt_allreduce.py | 426 +++++++++++++------------ 1 file changed, 217 insertions(+), 209 deletions(-) diff --git a/sgl-kernel/tests/test_trt_allreduce.py b/sgl-kernel/tests/test_trt_allreduce.py index 79e106337ed..f371cdeea67 100644 --- a/sgl-kernel/tests/test_trt_allreduce.py +++ b/sgl-kernel/tests/test_trt_allreduce.py @@ -3,9 +3,9 @@ import random import socket import time -from typing import Any, Callable, List, Optional +import unittest +from typing import Any, Dict, List, Optional, Tuple -import pytest import ray import sgl_kernel.allreduce as custom_ops import torch @@ -18,13 +18,6 @@ logger = logging.getLogger(__name__) -TEST_SIZES = [512, 4096, 32768, 262144, 524288, 1048576, 2097152] -WORLD_SIZES = [2, 4] -BUFFER_MAX_SIZE = 8 * 1024 * 1024 -BARRIER_MAX_SIZE = 8 * (24 + 2) * 8 -VLLM_MAX_SIZE = 8 * 1024 * 1024 - - def get_open_port() -> int: try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: @@ -32,297 +25,312 @@ def get_open_port() -> int: return s.getsockname()[1] except OSError: with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: - s.bind(("::1", 0)) + s.bind(("", 0, 0, 0)) return s.getsockname()[1] -def init_distributed_env(world_size, rank, distributed_init_port): - device = torch.device(f"cuda:{rank}") - torch.cuda.set_device(device) - ranks = [i for i in range(world_size)] - distributed_init_method = f"tcp://127.0.0.1:{distributed_init_port}" - dist.init_process_group( - backend="nccl", - init_method=distributed_init_method, - rank=rank, - world_size=world_size, - ) - return None - - -def create_shared_buffer( +def _create_shared_buffer( size_in_bytes: int, group: Optional[ProcessGroup] = None ) -> List[int]: + rank = dist.get_rank(group=group) + torch.cuda.set_device(rank) lib = CudaRTLibrary() - if not torch.cuda.is_initialized(): - torch.cuda.init() - pointer = lib.cudaMalloc(size_in_bytes) handle = lib.cudaIpcGetMemHandle(pointer) world_size = dist.get_world_size(group=group) - rank = dist.get_rank(group=group) - - gathered_handles = 
[None] * world_size - dist.all_gather_object(gathered_handles, handle, group=group) - - pointers: List[int] = [0] * world_size - for i, h in enumerate(gathered_handles): - if h is None: - raise RuntimeError( - f"Rank {i} did not receive a valid handle from rank {rank}." - ) + object_list = [None] * world_size + dist.all_gather_object(object_list, handle, group=group) + handles = object_list + pointers: List[int] = [] + for i, h in enumerate(handles): if i == rank: - pointers[i] = pointer.value + pointers.append(pointer.value) else: - opened_ptr = lib.cudaIpcOpenMemHandle(h) - pointers[i] = opened_ptr.value - + if h is None: + raise RuntimeError(f"Rank {rank} received None handle from rank {i}") + try: + opened_ptr = lib.cudaIpcOpenMemHandle(h) + pointers.append(opened_ptr.value) + except Exception as e: + raise RuntimeError( + f"Rank {rank} failed cudaIpcOpenMemHandle from rank {i}" + ) from e dist.barrier(group=group) return pointers -def free_shared_buffer( +def _free_shared_buffer( pointers: List[int], group: Optional[ProcessGroup] = None ) -> None: - if not pointers: - return rank = dist.get_rank(group=group) + torch.cuda.set_device(rank) lib = CudaRTLibrary() - if rank < len(pointers) and pointers[rank] != 0: + if pointers and rank < len(pointers): try: lib.cudaFree(ctypes.c_void_p(pointers[rank])) except Exception as e: - logger.error( - f"Rank {rank}: Error freeing shared buffer pointer {pointers[rank]}: {e}" - ) + logger.error(f"Rank {rank} failed to free shared buffer: {e}") dist.barrier(group=group) -@ray.remote(num_gpus=1, max_calls=1) -def correctness_worker(rank, world_size, distributed_init_port): - group = init_distributed_env(world_size, rank, distributed_init_port) - worker_state = {} +def _init_distributed_env(world_size, rank, distributed_init_port): + torch.cuda.set_device(rank) + distributed_init_method = f"tcp://127.0.0.1:{distributed_init_port}" + dist.init_process_group( + backend="nccl", + init_method=distributed_init_method, + rank=rank, + world_size=world_size, + ) + dist.barrier() + return dist.group.WORLD # Use default group + +@ray.remote(num_gpus=1, max_calls=1) +def run_correctness_task(world_size, rank, distributed_init_port, test_sizes): + group = _init_distributed_env(world_size, rank, distributed_init_port) + state = {} try: - worker_state["buffer_ptrs"] = create_shared_buffer(BUFFER_MAX_SIZE, group=group) - worker_state["tmp_result_buffer_ptrs"] = create_shared_buffer( - BUFFER_MAX_SIZE, group=group - ) - worker_state["barrier_in_ptrs"] = create_shared_buffer( - BARRIER_MAX_SIZE, group=group - ) - worker_state["barrier_out_ptrs"] = create_shared_buffer( - BARRIER_MAX_SIZE, group=group + buffer_max_size = 8 * 1024 * 1024 + barrier_max_size = 8 * (24 + 2) * 8 + + state["buffer_ptrs"] = _create_shared_buffer(buffer_max_size, group=group) + state["tmp_result_buffer_ptrs"] = _create_shared_buffer( + buffer_max_size, group=group ) - worker_state["rank_data"] = torch.empty( - BUFFER_MAX_SIZE, dtype=torch.uint8, device=torch.cuda.current_device() + state["barrier_in_ptrs"] = _create_shared_buffer(barrier_max_size, group=group) + state["barrier_out_ptrs"] = _create_shared_buffer(barrier_max_size, group=group) + state["rank_data"] = torch.empty( + buffer_max_size, dtype=torch.uint8, device=f"cuda:{rank}" ) - worker_state["custom_ptr"] = custom_ops.init_custom_reduce( + + state["custom_ptr"] = custom_ops.init_custom_reduce( rank, world_size, - worker_state["rank_data"], - worker_state["buffer_ptrs"], - worker_state["tmp_result_buffer_ptrs"], - 
worker_state["barrier_in_ptrs"], - worker_state["barrier_out_ptrs"], + state["rank_data"], + state["buffer_ptrs"], + state["tmp_result_buffer_ptrs"], + state["barrier_in_ptrs"], + state["barrier_out_ptrs"], ) + dist.barrier(group=group) test_loop = 10 - for sz in TEST_SIZES: + for sz in test_sizes: for dtype in [torch.float32, torch.float16, torch.bfloat16]: - for i in range(test_loop): - dist.barrier(group=group) - inp1 = torch.randint( - 1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device() + for _ in range(test_loop): + inp_torch = torch.randint( + 1, 16, (sz,), dtype=dtype, device=f"cuda:{rank}" ) - inp_clone = inp1.clone() - out1 = torch.empty_like(inp1) - - custom_ops.custom_reduce(worker_state["custom_ptr"], inp1, out1) + inp_custom = inp_torch.clone() + out_custom = torch.empty_like(inp_custom) + custom_ops.custom_reduce( + state["custom_ptr"], inp_custom, out_custom + ) dist.barrier(group=group) - dist.all_reduce(inp_clone, group=group) - + dist.all_reduce(inp_torch, group=group) dist.barrier(group=group) - torch.testing.assert_close(out1, inp_clone, rtol=1e-3, atol=1e-3) + + torch.testing.assert_close( + out_custom, inp_torch, rtol=1e-3, atol=1e-3 + ) finally: - if "custom_ptr" in worker_state and worker_state["custom_ptr"]: - custom_ops.custom_dispose(worker_state["custom_ptr"]) - free_shared_buffer(worker_state.get("buffer_ptrs", []), group) - free_shared_buffer(worker_state.get("tmp_result_buffer_ptrs", []), group) - free_shared_buffer(worker_state.get("barrier_in_ptrs", []), group) - free_shared_buffer(worker_state.get("barrier_out_ptrs", []), group) + if "custom_ptr" in state and state["custom_ptr"] is not None: + custom_ops.custom_dispose(state["custom_ptr"]) + if "buffer_ptrs" in state: + _free_shared_buffer(state["buffer_ptrs"], group) + if "tmp_result_buffer_ptrs" in state: + _free_shared_buffer(state["tmp_result_buffer_ptrs"], group) + if "barrier_in_ptrs" in state: + _free_shared_buffer(state["barrier_in_ptrs"], group) + if "barrier_out_ptrs" in state: + _free_shared_buffer(state["barrier_out_ptrs"], group) + dist.barrier(group=group) if dist.is_initialized(): - dist.destroy_process_group() + dist.destroy_process_group(group=group) @ray.remote(num_gpus=1, max_calls=1) -def performance_worker(rank, world_size, distributed_init_port): - group = init_distributed_env(world_size, rank, distributed_init_port) - worker_state = {} - +def run_performance_task(world_size, rank, distributed_init_port, test_sizes): + group = _init_distributed_env(world_size, rank, distributed_init_port) + custom_state = {} + vllm_state = {} try: - worker_state["vllm_meta_ptrs"] = create_shared_buffer( - vllm_ops.meta_size() + VLLM_MAX_SIZE, group=group + buffer_max_size = 8 * 1024 * 1024 + barrier_max_size = 8 * (24 + 2) * 8 + + custom_state["buffer_ptrs"] = _create_shared_buffer( + buffer_max_size, group=group ) - worker_state["vllm_buffer_ptrs"] = create_shared_buffer( - VLLM_MAX_SIZE, group=group + custom_state["tmp_result_buffer_ptrs"] = _create_shared_buffer( + buffer_max_size, group=group ) - worker_state["vllm_rank_data"] = torch.empty( - VLLM_MAX_SIZE, dtype=torch.uint8, device=torch.cuda.current_device() + custom_state["barrier_in_ptrs"] = _create_shared_buffer( + barrier_max_size, group=group ) - worker_state["vllm_ptr"] = vllm_ops.init_custom_ar( - worker_state["vllm_meta_ptrs"], worker_state["vllm_rank_data"], rank, True + custom_state["barrier_out_ptrs"] = _create_shared_buffer( + barrier_max_size, group=group ) - vllm_ops.register_buffer( - worker_state["vllm_ptr"], 
worker_state["vllm_buffer_ptrs"] + custom_state["rank_data"] = torch.empty( + buffer_max_size, dtype=torch.uint8, device=f"cuda:{rank}" ) - - worker_state["buffer_ptrs"] = create_shared_buffer(BUFFER_MAX_SIZE, group=group) - worker_state["tmp_result_buffer_ptrs"] = create_shared_buffer( - BUFFER_MAX_SIZE, group=group + custom_state["custom_ptr"] = custom_ops.init_custom_reduce( + rank, + world_size, + custom_state["rank_data"], + custom_state["buffer_ptrs"], + custom_state["tmp_result_buffer_ptrs"], + custom_state["barrier_in_ptrs"], + custom_state["barrier_out_ptrs"], ) - worker_state["barrier_in_ptrs"] = create_shared_buffer( - BARRIER_MAX_SIZE, group=group + dist.barrier(group=group) + + vllm_state["vllm_max_size"] = buffer_max_size + vllm_meta_buffer_size = vllm_ops.meta_size() + vllm_state["vllm_max_size"] + vllm_state["vllm_meta_ptrs"] = _create_shared_buffer( + vllm_meta_buffer_size, group=group ) - worker_state["barrier_out_ptrs"] = create_shared_buffer( - BARRIER_MAX_SIZE, group=group + vllm_state["vllm_buffer_ptrs"] = _create_shared_buffer( + vllm_state["vllm_max_size"], group=group ) - worker_state["rank_data"] = torch.empty( - BUFFER_MAX_SIZE, dtype=torch.uint8, device=torch.cuda.current_device() + vllm_state["vllm_rank_data"] = torch.empty( + buffer_max_size, dtype=torch.uint8, device=f"cuda:{rank}" ) - worker_state["custom_ptr"] = custom_ops.init_custom_reduce( - rank, - world_size, - worker_state["rank_data"], - worker_state["buffer_ptrs"], - worker_state["tmp_result_buffer_ptrs"], - worker_state["barrier_in_ptrs"], - worker_state["barrier_out_ptrs"], + vllm_state["vllm_ptr"] = vllm_ops.init_custom_ar( + vllm_state["vllm_meta_ptrs"], vllm_state["vllm_rank_data"], rank, True ) - + vllm_ops.register_buffer(vllm_state["vllm_ptr"], vllm_state["vllm_buffer_ptrs"]) dist.barrier(group=group) - for sz in TEST_SIZES: - inp1 = torch.randint( - 1, 16, (sz,), dtype=torch.float32, device=torch.cuda.current_device() + for sz in test_sizes: + inp = torch.randint( + 1, 16, (sz,), dtype=torch.float32, device=f"cuda:{rank}" ) - out1 = torch.empty_like(inp1) - test_loop = 100 + out = torch.empty_like(inp) + warmup_loop = 10 + test_loop = 50 # Reduced for CI speed - torch.cuda.synchronize() + for _ in range(warmup_loop): + custom_ops.custom_reduce(custom_state["custom_ptr"], inp, out) dist.barrier(group=group) + torch.cuda.synchronize(device=f"cuda:{rank}") start_custom = time.time() for _ in range(test_loop): - custom_ops.custom_reduce(worker_state["custom_ptr"], inp1, out1) - torch.cuda.synchronize() + custom_ops.custom_reduce(custom_state["custom_ptr"], inp, out) dist.barrier(group=group) + torch.cuda.synchronize(device=f"cuda:{rank}") elapse_custom = time.time() - start_custom - torch.cuda.synchronize() + for _ in range(warmup_loop): + vllm_ops.all_reduce( + vllm_state["vllm_ptr"], + inp, + out, + vllm_state["vllm_buffer_ptrs"][rank], + vllm_state["vllm_max_size"], + ) dist.barrier(group=group) + torch.cuda.synchronize(device=f"cuda:{rank}") start_vllm = time.time() for _ in range(test_loop): vllm_ops.all_reduce( - worker_state["vllm_ptr"], - inp1, - out1, - worker_state["vllm_buffer_ptrs"][rank], - VLLM_MAX_SIZE, + vllm_state["vllm_ptr"], + inp, + out, + vllm_state["vllm_buffer_ptrs"][rank], + vllm_state["vllm_max_size"], ) - torch.cuda.synchronize() dist.barrier(group=group) + torch.cuda.synchronize(device=f"cuda:{rank}") elapse_vllm = time.time() - start_vllm if rank == 0: logger.warning( - f"PERF: sz={sz}, world={world_size}, " - f"vllm={elapse_vllm * 1000 / test_loop:.4f}ms, " - 
f"custom={elapse_custom * 1000 / test_loop:.4f}ms" + f"Perf Test: size={sz}, world={world_size}, " + f"vLLM avg={(elapse_vllm * 1000 / test_loop):.4f}ms, " + f"Custom avg={(elapse_custom * 1000 / test_loop):.4f}ms" ) + dist.barrier(group=group) finally: - if "vllm_ptr" in worker_state and worker_state["vllm_ptr"]: - vllm_ops.dispose(worker_state["vllm_ptr"]) - free_shared_buffer(worker_state.get("vllm_meta_ptrs", []), group) - free_shared_buffer(worker_state.get("vllm_buffer_ptrs", []), group) - - if "custom_ptr" in worker_state and worker_state["custom_ptr"]: - custom_ops.custom_dispose(worker_state["custom_ptr"]) - free_shared_buffer(worker_state.get("buffer_ptrs", []), group) - free_shared_buffer(worker_state.get("tmp_result_buffer_ptrs", []), group) - free_shared_buffer(worker_state.get("barrier_in_ptrs", []), group) - free_shared_buffer(worker_state.get("barrier_out_ptrs", []), group) + if "custom_ptr" in custom_state and custom_state["custom_ptr"] is not None: + custom_ops.custom_dispose(custom_state["custom_ptr"]) + if "buffer_ptrs" in custom_state: + _free_shared_buffer(custom_state["buffer_ptrs"], group) + if "tmp_result_buffer_ptrs" in custom_state: + _free_shared_buffer(custom_state["tmp_result_buffer_ptrs"], group) + if "barrier_in_ptrs" in custom_state: + _free_shared_buffer(custom_state["barrier_in_ptrs"], group) + if "barrier_out_ptrs" in custom_state: + _free_shared_buffer(custom_state["barrier_out_ptrs"], group) + + if "vllm_ptr" in vllm_state and vllm_state["vllm_ptr"] is not None: + vllm_ops.dispose(vllm_state["vllm_ptr"]) + if "vllm_meta_ptrs" in vllm_state: + _free_shared_buffer(vllm_state["vllm_meta_ptrs"], group) + if "vllm_buffer_ptrs" in vllm_state: + _free_shared_buffer(vllm_state["vllm_buffer_ptrs"], group) + dist.barrier(group=group) if dist.is_initialized(): - dist.destroy_process_group() - + dist.destroy_process_group(group=group) -class TestCustomAllReduce: - @pytest.fixture(scope="class", autouse=True) - def ray_controller(self): - if not ray.is_initialized(): - ray.init(log_to_driver=False, ignore_reinit_error=True) - yield +def _multi_process_parallel(world_size: int, target_func: Any, args: Tuple): + distributed_init_port = get_open_port() + refs = [] + for rank in range(world_size): + task_args = (world_size, rank, distributed_init_port) + args + refs.append(target_func.remote(*task_args)) + try: + results = ray.get(refs) + return results + except ray.exceptions.RayTaskError as e: + logger.error(f"Ray task failed: {e}") + raise e + except Exception as e: + logger.error(f"An unexpected error occurred during ray.get: {e}") + raise e + + +class TestCustomAllReduce(unittest.TestCase): + + @classmethod + def setUpClass(cls): + random.seed(42) + cls.test_sizes = [512, 4096, 32768, 262144, 524288, 1048576, 2097152] + cls.world_sizes = [2, 4, 8] + ray.init(log_to_driver=True, num_cpus=1, include_dashboard=False) + + @classmethod + def tearDownClass(cls): ray.shutdown() - def test_correctness(self, request): - node_world_sizes = WORLD_SIZES - num_gpus = torch.cuda.device_count() - logger.info(f"Detected {num_gpus} GPUs for correctness test.") - - for world_size in node_world_sizes: - if world_size > num_gpus: - pytest.skip( - f"Skipping world_size={world_size}, requires {world_size} GPUs, found {num_gpus}" + def test_correctness(self): + for world_size in self.world_sizes: + if world_size > torch.cuda.device_count(): + logger.warning( + f"Skipping correctness test for world_size={world_size} due to insufficient GPUs ({torch.cuda.device_count()})." 
) continue + _multi_process_parallel( + world_size, run_correctness_task, (self.test_sizes,) + ) - logger.info(f"Running correctness test with world_size={world_size}") - distributed_init_port = get_open_port() - refs = [ - correctness_worker.remote(rank, world_size, distributed_init_port) - for rank in range(world_size) - ] - try: - ray.get(refs) - logger.info(f"Correctness test PASSED for world_size={world_size}") - except Exception as e: - logger.error( - f"Correctness test FAILED for world_size={world_size}: {e}" - ) - pytest.fail(f"Correctness test failed for world_size={world_size}: {e}") - - def test_performance(self, request): - node_world_sizes = WORLD_SIZES - num_gpus = torch.cuda.device_count() - logger.info(f"Detected {num_gpus} GPUs for performance test.") - - for world_size in node_world_sizes: - if world_size > num_gpus: - pytest.skip( - f"Skipping world_size={world_size}, requires {world_size} GPUs, found {num_gpus}" + def test_performance(self): + for world_size in self.world_sizes: + if world_size > torch.cuda.device_count(): + logger.warning( + f"Skipping performance test for world_size={world_size} due to insufficient GPUs ({torch.cuda.device_count()})." ) continue - - logger.info(f"Running performance test with world_size={world_size}") - distributed_init_port = get_open_port() - refs = [ - performance_worker.remote(rank, world_size, distributed_init_port) - for rank in range(world_size) - ] - try: - ray.get(refs) - logger.info(f"Performance test COMPLETED for world_size={world_size}") - except Exception as e: - logger.error( - f"Performance test FAILED for world_size={world_size}: {e}" - ) - pytest.fail(f"Performance test failed for world_size={world_size}: {e}") + _multi_process_parallel( + world_size, run_performance_task, (self.test_sizes,) + ) if __name__ == "__main__": - pytest.main([__file__]) + unittest.main() From b63fde944e38da8b14152a21b17443a211c9a0a8 Mon Sep 17 00:00:00 2001 From: adarshxs Date: Sun, 30 Mar 2025 11:58:00 +0530 Subject: [PATCH 05/11] fix lint --- sgl-kernel/tests/test_trt_allreduce.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sgl-kernel/tests/test_trt_allreduce.py b/sgl-kernel/tests/test_trt_allreduce.py index 87b97fb2da9..910bcb253a6 100644 --- a/sgl-kernel/tests/test_trt_allreduce.py +++ b/sgl-kernel/tests/test_trt_allreduce.py @@ -165,4 +165,3 @@ def correctness(self, world_size, rank, distributed_init_port): if __name__ == "__main__": unittest.main() - \ No newline at end of file From d6767f8f813079bbb37b9f6055ef57a7edf98d68 Mon Sep 17 00:00:00 2001 From: adarshxs Date: Sun, 30 Mar 2025 12:19:14 +0530 Subject: [PATCH 06/11] fix --- sgl-kernel/tests/test_cublas_grouped_gemm.py | 42 ++-- sgl-kernel/tests/test_deep_gemm.py | 232 ------------------- sgl-kernel/tests/test_fp8_blockwise_gemm.py | 14 ++ 3 files changed, 35 insertions(+), 253 deletions(-) delete mode 100644 sgl-kernel/tests/test_deep_gemm.py diff --git a/sgl-kernel/tests/test_cublas_grouped_gemm.py b/sgl-kernel/tests/test_cublas_grouped_gemm.py index 744fb17fd9e..a20b6e33648 100644 --- a/sgl-kernel/tests/test_cublas_grouped_gemm.py +++ b/sgl-kernel/tests/test_cublas_grouped_gemm.py @@ -7,42 +7,42 @@ def torch_grouped_gemm(a_array, b_array, out_dtype): return [torch.matmul(a, b.t()).to(out_dtype) for a, b in zip(a_array, b_array)] -# Skip if CUDA is not available or CUDA version is lower than 12.5 skip_condition = not torch.cuda.is_available() or ( torch.version.cuda is None or tuple(map(int, torch.version.cuda.split("."))) < (12, 5) ) -shape_params = [ - (1, 2, 3), - 
(16, 16, 16), - (32, 128, 32), - (256, 256, 512), - (1024, 4096, 8192), -] + +m_values = [1, 16, 32, 256, 1024] +n_values = [2, 16, 128, 256, 4096] +k_values = [3, 16, 32, 512, 8192] @pytest.mark.skipif( skip_condition, reason="CUDA not available or CUDA version lower than 12.5" ) @pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("M, N, K", shape_params) +@pytest.mark.parametrize("M", m_values) +@pytest.mark.parametrize("N", n_values) +@pytest.mark.parametrize("K", k_values) def test_grouped_gemm_accuracy(out_dtype, M, N, K): - # Create input matrices for a single GEMM test - a = torch.randn((M, K), device="cuda", dtype=out_dtype) * 5 - b = torch.randn((N, K), device="cuda", dtype=out_dtype) * 5 - expected = torch.matmul(a, b.t()).to(out_dtype) + try: + a = torch.randn((M, K), device="cuda", dtype=out_dtype) * 5 + b = torch.randn((N, K), device="cuda", dtype=out_dtype) * 5 + expected = torch.matmul(a, b.t()).to(out_dtype) + + a_array = [a] + b_array = [b] + c_array = [torch.empty((M, N), device="cuda", dtype=out_dtype)] - a_array = [a] - b_array = [b] - c_array = [torch.empty((M, N), device="cuda", dtype=out_dtype)] + result_torch = torch_grouped_gemm(a_array, b_array, out_dtype)[0] + cublas_grouped_gemm(a_array, b_array, c_array, out_dtype) - result_torch = torch_grouped_gemm(a_array, b_array, out_dtype)[0] - cublas_grouped_gemm(a_array, b_array, c_array, out_dtype) + torch.testing.assert_close(result_torch, expected) + torch.testing.assert_close(c_array[0], expected) - torch.testing.assert_close(result_torch, expected) - torch.testing.assert_close(c_array[0], expected) - print(f"Test passed for M={M}, N={N}, K={K}, out_dtype={out_dtype}") + except torch.cuda.OutOfMemoryError: + pytest.skip(f"Skipping M={M}, N={N}, K={K} due to OOM") if __name__ == "__main__": diff --git a/sgl-kernel/tests/test_deep_gemm.py b/sgl-kernel/tests/test_deep_gemm.py deleted file mode 100644 index e0dbe331ec4..00000000000 --- a/sgl-kernel/tests/test_deep_gemm.py +++ /dev/null @@ -1,232 +0,0 @@ -import os -import random -from typing import Any, Tuple - -import deep_gemm -import pytest -import torch -from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor, jit - - -@pytest.fixture(autouse=True, scope="module") -def setup_deepgemm(): - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - torch.manual_seed(0) - random.seed(0) - print("Library path:") - print(f" > {deep_gemm.__path__}\n") - - -def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - assert x.dim() == 2 and x.size(1) % 128 == 0 - m, n = x.shape - x_view = x.view(m, -1, 128) - x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) - return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view( - m, n - ), (x_amax / 448.0).view(m, -1) - - -def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - assert x.dim() == 2 - m, n = x.shape - x_padded = torch.zeros( - (ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device - ) - x_padded[:m, :n] = x - x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) - x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) - x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) - return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view( - x_view.size(0), x_view.size(2) - ) - - -def construct(m: int, k: int, n: int) -> Tuple[ - Tuple[torch.Tensor, torch.Tensor], - 
Tuple[torch.Tensor, torch.Tensor], - torch.Tensor, - torch.Tensor, -]: - x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) - y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16) - out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16) - ref_out = x @ y.t() - x_fp8, y_fp8 = per_token_cast_to_fp8(x), per_block_cast_to_fp8(y) - x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) - return x_fp8, y_fp8, out, ref_out - - -def construct_grouped( - num_groups: int, m: int, k: int, n: int, is_masked: bool -) -> Tuple[ - Tuple[torch.Tensor, torch.Tensor], - Tuple[torch.Tensor, torch.Tensor], - torch.Tensor, - torch.Tensor, -]: - x = torch.randn((num_groups, m, k), device="cuda", dtype=torch.bfloat16) - y = torch.randn((num_groups, n, k), device="cuda", dtype=torch.bfloat16) - out = torch.empty((num_groups, m, n), device="cuda", dtype=torch.bfloat16) - ref_out = torch.einsum("gmk,gnk->gmn", x, y) - assert m % 4 == 0, f"TMA alignment error: {m}" - x_fp8 = ( - torch.empty_like(x, dtype=torch.float8_e4m3fn), - torch.empty((num_groups, m, k // 128), device="cuda", dtype=torch.float), - ) - y_fp8 = ( - torch.empty_like(y, dtype=torch.float8_e4m3fn), - torch.empty( - (num_groups, (n + 127) // 128, k // 128), device="cuda", dtype=torch.float - ), - ) - for i in range(num_groups): - x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i]) - y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) - if not is_masked: - x_fp8 = (x_fp8[0].view(-1, k), per_token_cast_to_fp8(x.view(-1, k))[1]) - out, ref_out = out.view(-1, n), ref_out.view(-1, n) - x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) - return x_fp8, y_fp8, out, ref_out - - -@pytest.mark.parametrize( - "m,k,n", - [ - (64, 7168, 2112), - (64, 1536, 24576), - (64, 512, 32768), - (64, 16384, 7168), - (64, 7168, 4096), - (64, 2048, 7168), - (128, 7168, 2112), - (128, 1536, 24576), - (128, 512, 32768), - (128, 16384, 7168), - (128, 7168, 4096), - (128, 2048, 7168), - (4096, 7168, 2112), - (4096, 1536, 24576), - (4096, 512, 32768), - (4096, 16384, 7168), - (4096, 7168, 4096), - (4096, 2048, 7168), - ], -) -def test_gemm(m: int, k: int, n: int): - x_fp8, y_fp8, out, ref_out = construct(m, k, n) - deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out) - diff = calc_diff(out, ref_out) - assert diff < 0.001, f"{m=}, {k=}, {n=}, {diff:.5f}" - - -@pytest.mark.parametrize( - "num_groups,m,k,n", - [ - (4, 8192, 7168, 4096), - (4, 8192, 2048, 7168), - (8, 4096, 7168, 4096), - (8, 4096, 2048, 7168), - ], -) -def test_m_grouped_gemm_contiguous(num_groups: int, m: int, k: int, n: int): - x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=False) - m_indices = ( - torch.arange(0, num_groups, device="cuda", dtype=torch.int) - .unsqueeze(-1) - .expand(num_groups, m) - .contiguous() - .view(-1) - ) - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices) - diff = calc_diff(out, ref_out) - assert diff < 0.001, f"m={m * num_groups}, {k=}, {n=}, {diff:.5f}" - - -@pytest.mark.parametrize("num_groups,m", [(1, 1024), (2, 512), (4, 256)]) -@pytest.mark.parametrize("k,n", [(7168, 4096), (2048, 7168)]) -@pytest.mark.parametrize("trial", range(10)) -def test_m_grouped_gemm_masked(num_groups: int, m: int, k: int, n: int, trial: int): - masked_m_candidates = [c for c in (64, 128, 192, 256, 320, 384) if c <= m] - x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=True) - masked_m = torch.empty((num_groups,), device="cuda", dtype=torch.int) - for j in 
range(num_groups):
-        masked_m[j] = random.choice(masked_m_candidates)
-    expected_m = min(int(masked_m.float().mean()) + 1, m)
-    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
-        x_fp8, y_fp8, out, masked_m, expected_m
-    )
-    for j in range(num_groups):
-        diff = calc_diff(out[j, : masked_m[j].item()], ref_out[j, : masked_m[j].item()])
-        assert (
-            diff < 0.001
-        ), f"{m=}, {k=}, {n=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}"
-
-
-class Capture:
-    def __init__(self) -> None:
-        self.read_fd = None
-        self.write_fd = None
-        self.saved_stdout = None
-        self.captured = None
-
-    def __enter__(self) -> Any:
-        self.read_fd, self.write_fd = os.pipe()
-        self.saved_stdout = os.dup(1)
-        os.dup2(self.write_fd, 1)
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
-        os.dup2(self.saved_stdout, 1)
-        os.close(self.write_fd)
-        with os.fdopen(self.read_fd, "r") as f:
-            self.captured = f.read()
-
-    def capture(self) -> str:
-        return self.captured
-
-
-def test_jit():
-    print(f"NVCC compiler: {jit.get_nvcc_compiler()}\n")
-    print("Generated code:")
-    args = (
-        ("lhs", torch.float8_e4m3fn),
-        ("rhs", torch.float8_e4m3fn),
-        ("scale", torch.float),
-        ("out", torch.bfloat16),
-        ("enable_double_streams", bool),
-        ("stream", torch.cuda.Stream),
-    )
-    body = "\n"
-    body += "std::cout << reinterpret_cast(lhs) << std::endl;\n"
-    body += "std::cout << reinterpret_cast(rhs) << std::endl;\n"
-    body += "std::cout << reinterpret_cast(scale) << std::endl;\n"
-    body += "std::cout << reinterpret_cast(out) << std::endl;\n"
-    body += "std::cout << enable_double_streams << std::endl;\n"
-    body += "std::cout << reinterpret_cast(stream) << std::endl;\n"
-    code = jit.generate((), args, body)
-    print(code)
-    print("Building ...")
-    func = jit.build("test_func", args, code)
-    print("Running ...")
-    fp8_tensor = torch.empty((1,), dtype=torch.float8_e4m3fn, device="cuda")
-    fp32_tensor = torch.empty((1,), dtype=torch.float, device="cuda")
-    bf16_tensor = torch.empty((1,), dtype=torch.bfloat16, device="cuda")
-    with Capture() as capture:
-        ret_val = func(
-            fp8_tensor,
-            fp8_tensor,
-            fp32_tensor,
-            bf16_tensor,
-            True,
-            torch.cuda.current_stream(),
-        )
-        assert ret_val == 0, "Function did not return 0"
-    output = capture.capture()
-    ref_output = f"{fp8_tensor.data_ptr()}\n{fp8_tensor.data_ptr()}\n{fp32_tensor.data_ptr()}\n{bf16_tensor.data_ptr()}\n1\n{torch.cuda.current_stream().cuda_stream}\n"
-    assert output == ref_output, f"{output=}, {ref_output=}"
-
-
-if __name__ == "__main__":
-    pytest.main([__file__])
diff --git a/sgl-kernel/tests/test_fp8_blockwise_gemm.py b/sgl-kernel/tests/test_fp8_blockwise_gemm.py
index d432d974e4e..c9ca0135055 100644
--- a/sgl-kernel/tests/test_fp8_blockwise_gemm.py
+++ b/sgl-kernel/tests/test_fp8_blockwise_gemm.py
@@ -24,6 +24,20 @@ def baseline_scaled_mm(
     out_dtype: Type[torch.dtype],
     bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
+    # We treat N-dimensional group scaling as extended numpy-style broadcasting:
+    # numpy broadcasting simply stretches dimensions with an extent of 1 to match
+    # the target shape by repeating the data along that dimension.
+    # We extend these semantics: if the extent of a dimension in the
+    # source shape is not 1 and does not match the target shape, we repeat each
+    # element along that dimension src_shape[dim] // target_shape[dim] times.
+    # For example, if we have:
+    # a = [[1, 2],       and target_shape = (2, 4)
+    #      [3, 4]]
+    # then we would expand a to:
+    # a = [[1, 1, 2, 2],
+    #      [3, 3, 4, 4]]
+    # NOTE: this
function does not explicitly broadcast dimensions + # with an extent of 1, since this can be done implicitly by pytorch def group_broadcast(t, shape): for i, s in enumerate(shape): if t.shape[i] != s and t.shape[i] != 1: From f6d99314bbf5b76acd91761d729e99b1d1788fde Mon Sep 17 00:00:00 2001 From: adarshxs Date: Sun, 30 Mar 2025 12:29:05 +0530 Subject: [PATCH 07/11] fix --- sgl-kernel/tests/test_cublas_grouped_gemm.py | 35 ++++++++------------ 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/sgl-kernel/tests/test_cublas_grouped_gemm.py b/sgl-kernel/tests/test_cublas_grouped_gemm.py index a20b6e33648..70b3dc5cf25 100644 --- a/sgl-kernel/tests/test_cublas_grouped_gemm.py +++ b/sgl-kernel/tests/test_cublas_grouped_gemm.py @@ -13,36 +13,27 @@ def torch_grouped_gemm(a_array, b_array, out_dtype): ) -m_values = [1, 16, 32, 256, 1024] -n_values = [2, 16, 128, 256, 4096] -k_values = [3, 16, 32, 512, 8192] - - @pytest.mark.skipif( skip_condition, reason="CUDA not available or CUDA version lower than 12.5" ) @pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("M", m_values) -@pytest.mark.parametrize("N", n_values) -@pytest.mark.parametrize("K", k_values) +@pytest.mark.parametrize("M", [1, 16, 32, 256, 1024]) +@pytest.mark.parametrize("N", [2, 16, 128, 256, 4096]) +@pytest.mark.parametrize("K", [3, 16, 32, 512, 8192]) def test_grouped_gemm_accuracy(out_dtype, M, N, K): - try: - a = torch.randn((M, K), device="cuda", dtype=out_dtype) * 5 - b = torch.randn((N, K), device="cuda", dtype=out_dtype) * 5 - expected = torch.matmul(a, b.t()).to(out_dtype) - - a_array = [a] - b_array = [b] - c_array = [torch.empty((M, N), device="cuda", dtype=out_dtype)] + a = torch.randn((M, K), device="cuda", dtype=out_dtype) * 5 + b = torch.randn((N, K), device="cuda", dtype=out_dtype) * 5 + expected = torch.matmul(a, b.t()).to(out_dtype) - result_torch = torch_grouped_gemm(a_array, b_array, out_dtype)[0] - cublas_grouped_gemm(a_array, b_array, c_array, out_dtype) + a_array = [a] + b_array = [b] + c_array = [torch.empty((M, N), device="cuda", dtype=out_dtype)] - torch.testing.assert_close(result_torch, expected) - torch.testing.assert_close(c_array[0], expected) + result_torch = torch_grouped_gemm(a_array, b_array, out_dtype)[0] + cublas_grouped_gemm(a_array, b_array, c_array, out_dtype) - except torch.cuda.OutOfMemoryError: - pytest.skip(f"Skipping M={M}, N={N}, K={K} due to OOM") + torch.testing.assert_close(result_torch, expected) + torch.testing.assert_close(c_array[0], expected) if __name__ == "__main__": From 7e53a513323343e069beed56cf5368346e6d48bf Mon Sep 17 00:00:00 2001 From: adarshxs Date: Sun, 30 Mar 2025 13:05:41 +0530 Subject: [PATCH 08/11] pass CI --- sgl-kernel/tests/test_trt_allreduce.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sgl-kernel/tests/test_trt_allreduce.py b/sgl-kernel/tests/test_trt_allreduce.py index 910bcb253a6..2f0b808d5d9 100644 --- a/sgl-kernel/tests/test_trt_allreduce.py +++ b/sgl-kernel/tests/test_trt_allreduce.py @@ -30,6 +30,7 @@ def multi_process_parallel( world_size: int, test_target: Any, ) -> None: + mp.set_start_method("spawn", force=True) procs = [] distributed_init_port = get_open_port() for i in range(world_size): From c1de3405850cd51de1a5eee8aeb8f2cbd74a46d7 Mon Sep 17 00:00:00 2001 From: zhyncs Date: Sun, 30 Mar 2025 07:37:35 +0000 Subject: [PATCH 09/11] upd --- .github/workflows/pr-test-sgl-kernel.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/pr-test-sgl-kernel.yml 
b/.github/workflows/pr-test-sgl-kernel.yml index 91fa78fc000..c87f8d548b2 100644 --- a/.github/workflows/pr-test-sgl-kernel.yml +++ b/.github/workflows/pr-test-sgl-kernel.yml @@ -80,7 +80,8 @@ jobs: - name: Install run: | - pip3 install torch==2.5.1 && pip3 install pytest && pip3 install vllm==0.6.4.post1 + bash scripts/ci_install_dependency.sh + pip3 install torch==2.5.1 && pip3 install pytest && pip3 install vllm==0.7.2 pip3 uninstall sgl-kernel -y || true pip3 install sgl-kernel/dist/*whl --force-reinstall --no-deps pip3 list | grep sgl-kernel From 11d0d4b0e2904458ece365f35f12a0d563210e7c Mon Sep 17 00:00:00 2001 From: adarshxs Date: Sun, 30 Mar 2025 13:05:41 +0530 Subject: [PATCH 10/11] pass CI --- sgl-kernel/tests/test_trt_allreduce.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sgl-kernel/tests/test_trt_allreduce.py b/sgl-kernel/tests/test_trt_allreduce.py index 910bcb253a6..2f0b808d5d9 100644 --- a/sgl-kernel/tests/test_trt_allreduce.py +++ b/sgl-kernel/tests/test_trt_allreduce.py @@ -30,6 +30,7 @@ def multi_process_parallel( world_size: int, test_target: Any, ) -> None: + mp.set_start_method("spawn", force=True) procs = [] distributed_init_port = get_open_port() for i in range(world_size): From 3c69014d95f8446a5d8aeebba7421b7ef009387b Mon Sep 17 00:00:00 2001 From: adarshxs Date: Sun, 30 Mar 2025 13:44:49 +0530 Subject: [PATCH 11/11] add top level func --- sgl-kernel/tests/test_trt_allreduce.py | 224 ++++++++++++++----------- 1 file changed, 127 insertions(+), 97 deletions(-) diff --git a/sgl-kernel/tests/test_trt_allreduce.py b/sgl-kernel/tests/test_trt_allreduce.py index 2f0b808d5d9..242f226be56 100644 --- a/sgl-kernel/tests/test_trt_allreduce.py +++ b/sgl-kernel/tests/test_trt_allreduce.py @@ -13,155 +13,185 @@ from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary +def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes): + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + ranks = list(range(world_size)) + distributed_init_method = f"tcp://localhost:{distributed_init_port}" + dist.init_process_group( + backend="nccl", + init_method=distributed_init_method, + rank=rank, + world_size=world_size, + ) + group = dist.group.WORLD + + buffer_max_size = 8 * 1024 * 1024 + barrier_max_size = 8 * (24 + 2) * 8 + buffer_ptrs = None + tmp_result_buffer_ptrs = None + barrier_in_ptrs = None + barrier_out_ptrs = None + custom_ptr = None + + try: + buffer_ptrs = TestCustomAllReduce.create_shared_buffer( + buffer_max_size, group=group + ) + tmp_result_buffer_ptrs = TestCustomAllReduce.create_shared_buffer( + buffer_max_size, group=group + ) + barrier_in_ptrs = TestCustomAllReduce.create_shared_buffer( + barrier_max_size, group=group + ) + barrier_out_ptrs = TestCustomAllReduce.create_shared_buffer( + barrier_max_size, group=group + ) + + rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device=device) + + custom_ptr = custom_ops.init_custom_reduce( + rank, + world_size, + rank_data, + buffer_ptrs, + tmp_result_buffer_ptrs, + barrier_in_ptrs, + barrier_out_ptrs, + ) + + test_loop = 10 + for sz in test_sizes: + for dtype in [torch.float32, torch.float16, torch.bfloat16]: + for _ in range(test_loop): + inp1 = torch.randint(1, 16, (sz,), dtype=dtype, device=device) + inp1_ref = inp1.clone() + out1 = torch.empty_like(inp1) + + custom_ops.custom_reduce(custom_ptr, inp1, out1) + + dist.all_reduce(inp1_ref, group=group) + + torch.testing.assert_close(out1, inp1_ref) + + finally: + dist.barrier(group=group) + if 
custom_ptr is not None: + custom_ops.custom_dispose(custom_ptr) + if buffer_ptrs: + TestCustomAllReduce.free_shared_buffer(buffer_ptrs, group) + if tmp_result_buffer_ptrs: + TestCustomAllReduce.free_shared_buffer(tmp_result_buffer_ptrs, group) + if barrier_in_ptrs: + TestCustomAllReduce.free_shared_buffer(barrier_in_ptrs, group) + if barrier_out_ptrs: + TestCustomAllReduce.free_shared_buffer(barrier_out_ptrs, group) + + dist.destroy_process_group(group=group) + + def get_open_port() -> int: - # try ipv4 try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("127.0.0.1", 0)) return s.getsockname()[1] except OSError: - # try ipv6 with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: - s.bind(("127.0.0.1", 0)) + s.bind(("::1", 0)) return s.getsockname()[1] def multi_process_parallel( - world_size: int, - test_target: Any, + world_size: int, test_target: Any, target_args: tuple = () ) -> None: mp.set_start_method("spawn", force=True) + procs = [] distributed_init_port = get_open_port() for i in range(world_size): - proc = mp.Process( - target=test_target, - args=(world_size, i, distributed_init_port), - ) + proc_args = (world_size, i, distributed_init_port) + target_args + proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}") proc.start() procs.append(proc) for i in range(world_size): procs[i].join() - assert procs[i].exitcode == 0 + assert ( + procs[i].exitcode == 0 + ), f"Process {i} failed with exit code {procs[i].exitcode}" class TestCustomAllReduce(unittest.TestCase): - @classmethod - def setUpClass(cls): - random.seed(42) - cls.test_sizes = [512, 4096, 32768, 262144, 524288, 1048576, 2097152] - cls.world_sizes = [2, 4, 8] + test_sizes = [512, 4096, 32768, 262144, 524288, 1048576, 2097152] + world_sizes = [2, 4, 8] @staticmethod def create_shared_buffer( size_in_bytes: int, group: Optional[ProcessGroup] = None ) -> List[int]: - """ - Creates a shared buffer and returns a list of pointers - representing the buffer on all processes in the group. 
- """ lib = CudaRTLibrary() pointer = lib.cudaMalloc(size_in_bytes) handle = lib.cudaIpcGetMemHandle(pointer) + if group is None: + group = dist.group.WORLD world_size = dist.get_world_size(group=group) rank = dist.get_rank(group=group) - handles = [None] * world_size - dist.all_gather_object(handles, handle, group=group) + + handle_bytes = ctypes.string_at(ctypes.addressof(handle), ctypes.sizeof(handle)) + input_tensor = torch.ByteTensor(list(handle_bytes)).to(f"cuda:{rank}") + gathered_tensors = [torch.empty_like(input_tensor) for _ in range(world_size)] + dist.all_gather(gathered_tensors, input_tensor, group=group) + + handles = [] + handle_type = type(handle) + for tensor in gathered_tensors: + bytes_list = tensor.cpu().tolist() + bytes_data = bytes(bytes_list) + handle_obj = handle_type() + ctypes.memmove(ctypes.addressof(handle_obj), bytes_data, len(bytes_data)) + handles.append(handle_obj) pointers: List[int] = [] for i, h in enumerate(handles): if i == rank: - pointers.append(pointer.value) # type: ignore + pointers.append(pointer.value) else: - pointers.append(lib.cudaIpcOpenMemHandle(h).value) # type: ignore - + try: + opened_ptr = lib.cudaIpcOpenMemHandle(h) + pointers.append(opened_ptr.value) + except Exception as e: + print(f"Rank {rank}: Failed to open IPC handle from rank {i}: {e}") + raise + + dist.barrier(group=group) return pointers @staticmethod def free_shared_buffer( pointers: List[int], group: Optional[ProcessGroup] = None ) -> None: + if group is None: + group = dist.group.WORLD rank = dist.get_rank(group=group) lib = CudaRTLibrary() - lib.cudaFree(ctypes.c_void_p(pointers[rank])) + if pointers and len(pointers) > rank and pointers[rank] is not None: + lib.cudaFree(ctypes.c_void_p(pointers[rank])) + dist.barrier(group=group) def test_correctness(self): for world_size in self.world_sizes: - if world_size > torch.cuda.device_count(): + available_gpus = torch.cuda.device_count() + if world_size > available_gpus: + print( + f"Skipping world_size={world_size}, requires {world_size} GPUs, found {available_gpus}" + ) continue - multi_process_parallel(world_size, self.correctness) - print(f"custom allreduce tp = {world_size}: OK") - - def init_custom_allreduce(self, rank, world_size, group): - buffer_max_size = 8 * 1024 * 1024 - barrier_max_size = 8 * (24 + 2) * 8 - - self.buffer_ptrs = self.create_shared_buffer(buffer_max_size, group=group) - self.tmp_result_buffer_ptrs = self.create_shared_buffer( - buffer_max_size, group=group - ) - self.barrier_in_ptrs = self.create_shared_buffer(barrier_max_size, group=group) - self.barrier_out_ptrs = self.create_shared_buffer(barrier_max_size, group=group) - self.rank_data = torch.empty( - 8 * 1024 * 1024, dtype=torch.uint8, device=torch.device("cuda:0") - ) - - self.custom_ptr = custom_ops.init_custom_reduce( - rank, - world_size, - self.rank_data, - self.buffer_ptrs, - self.tmp_result_buffer_ptrs, - self.barrier_in_ptrs, - self.barrier_out_ptrs, - ) - - def custom_allreduce(self, inp, out): - custom_ops.custom_reduce(self.custom_ptr, inp, out) - def free_custom_allreduce(self, group): - self.free_shared_buffer(self.buffer_ptrs, group) - self.free_shared_buffer(self.tmp_result_buffer_ptrs, group) - self.free_shared_buffer(self.barrier_in_ptrs, group) - self.free_shared_buffer(self.barrier_out_ptrs, group) - custom_ops.custom_dispose(self.custom_ptr) - - @staticmethod - def init_distributed_env(world_size, rank, distributed_init_port): - device = torch.device("cuda:0") - torch.cuda.set_device(device) - ranks = [i for i in 
range(world_size)] - distributed_init_method = f"tcp://localhost:{distributed_init_port}" - dist.init_process_group( - backend="nccl", - init_method=distributed_init_method, - rank=rank, - world_size=world_size, - ) - group = torch.distributed.new_group(ranks, backend="gloo") - return group - - # compare result with torch.distributed - def correctness(self, world_size, rank, distributed_init_port): - group = self.init_distributed_env(world_size, rank, distributed_init_port) - - self.init_custom_allreduce(rank=rank, world_size=world_size, group=group) - - test_loop = 10 - for sz in self.test_sizes: - for dtype in [torch.float32, torch.float16, torch.bfloat16]: - for _ in range(test_loop): - inp1 = torch.randint( - 1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device() - ) - out1 = torch.empty_like(inp1) - self.custom_allreduce(inp1, out1) - - dist.all_reduce(inp1, group=group) - torch.testing.assert_close(out1, inp1) - - self.free_custom_allreduce(group) + print(f"Running test for world_size={world_size}") + multi_process_parallel( + world_size, _run_correctness_worker, target_args=(self.test_sizes,) + ) + print(f"custom allreduce tp = {world_size}: OK") if __name__ == "__main__":