diff --git a/python/sglang/jit_kernel/all_reduce.py b/python/sglang/jit_kernel/all_reduce.py index dd02100822d3..48f763259beb 100644 --- a/python/sglang/jit_kernel/all_reduce.py +++ b/python/sglang/jit_kernel/all_reduce.py @@ -95,7 +95,7 @@ def config_pull( def _jit_custom_all_reduce_pull_module(dtype: torch.dtype, world_size: int): args = make_cpp_args(dtype, world_size, is_arch_support_pdl()) return load_jit( - "custom_all_reduce", + "custom_all_reduce_pull", *args, extra_ldflags=["-lcuda"], cuda_files=["distributed/custom_all_reduce_pull.cuh"], @@ -107,7 +107,7 @@ def _jit_custom_all_reduce_pull_module(dtype: torch.dtype, world_size: int): def _jit_custom_all_reduce_push_module(dtype: torch.dtype, world_size: int): args = make_cpp_args(dtype, world_size, is_arch_support_pdl()) return load_jit( - "custom_all_reduce", + "custom_all_reduce_push", *args, extra_ldflags=["-lcuda"], cuda_files=["distributed/custom_all_reduce_push.cuh"], diff --git a/python/sglang/jit_kernel/tests/test_custom_all_reduce.py b/python/sglang/jit_kernel/tests/test_custom_all_reduce.py index 365761ddfe88..2d7e0253eb43 100644 --- a/python/sglang/jit_kernel/tests/test_custom_all_reduce.py +++ b/python/sglang/jit_kernel/tests/test_custom_all_reduce.py @@ -16,29 +16,33 @@ import itertools import logging +import multiprocessing as mp import os import subprocess import sys -from typing import Optional +from typing import Dict, Optional, Tuple import pytest import torch import torch.distributed as dist -from tqdm import tqdm import sglang.srt.distributed.parallel_state as ps -from sglang.jit_kernel.all_reduce import AllReduceAlgo +from sglang.jit_kernel.all_reduce import ( + AllReduceAlgo, + _jit_custom_all_reduce_pull_module, + _jit_custom_all_reduce_push_module, +) from sglang.srt.distributed.device_communicators.custom_all_reduce_v2 import ( CustomAllReduceV2, ) from sglang.test.ci.ci_register import register_cuda_ci register_cuda_ci( - est_time=500, + est_time=300, suite="stage-b-kernel-unit-8-gpu-h200", ) register_cuda_ci( - est_time=500, + est_time=300, suite="nightly-kernel-8-gpu-h200", nightly=True, ) @@ -67,7 +71,7 @@ ] USE_GRAPH_OPTIONS = [True, False] TEST_CONFIG = itertools.product(TEST_SIZES, TEST_DTYPES, SHOTS, USE_GRAPH_OPTIONS) -TEST_LAYERS = 2 +TEST_LAYERS = 4 TEST_LOOP = 16 # --------------------------------------------------------------------------- @@ -75,14 +79,13 @@ # --------------------------------------------------------------------------- -def run_torchrun(nproc: int, timeout: int = 300) -> None: +def _run_torchrun(nproc: int, timeout: int = 300) -> None: """Launch this script as a torchrun worker and assert success.""" cmd = [ "torchrun", f"--nproc_per_node={nproc}", __file__, ] - os.environ["DISABLE_PBAR"] = "1" result = subprocess.run( cmd, stdout=subprocess.PIPE, @@ -96,14 +99,37 @@ def run_torchrun(nproc: int, timeout: int = 300) -> None: ) -@pytest.mark.parametrize("nproc", [2, 3, 4, 5, 6, 7, 8]) +def _compile_one(dtype: torch.dtype, world_size: int): + _jit_custom_all_reduce_push_module(dtype, world_size) + _jit_custom_all_reduce_pull_module(dtype, world_size) + + +def _precompile_kernels() -> None: + # NOTE: even when device count < 8, we should be able to compile all + process_map: Dict[Tuple[torch.dtype, int], mp.Process] = {} + COMPILE_SPACE = itertools.product(TEST_DTYPES, [2, 3, 4, 5, 6, 7, 8]) + mp.set_start_method("spawn") + for config in COMPILE_SPACE: + process_map[config] = mp.Process(target=_compile_one, args=config) + for process in process_map.values(): + process.start() + for (dtype, world_size), process in process_map.items(): + process.join() + if process.exitcode != 0: + raise RuntimeError(f"Custom All Reduce {world_size=} {dtype=} failed") + + +@pytest.mark.parametrize("nproc", [1, 2, 3, 4, 5, 6, 7, 8]) def test_custom_allreduce(nproc: int) -> None: + if nproc == 1: # NOTE: special case to speed up tests + return _precompile_kernels() + device_count = torch.cuda.device_count() if device_count < nproc: pytest.skip( f"Requires at least {nproc} GPUs, but only {device_count} available" ) - run_torchrun(nproc) + _run_torchrun(nproc) # --------------------------------------------------------------------------- @@ -192,8 +218,6 @@ def run_eager(x: torch.Tensor) -> torch.Tensor: dist.all_reduce(out_ref, group=nccl_group) out_jit = run_fn(inp) num_errors += not torch.all(out_jit == out_ref) - torch.cuda.synchronize() - nccl_group.barrier().wait() if num_errors > 0: return RuntimeError( f"Test failed for {size=}, {dtype=}, {algo=}, " @@ -211,9 +235,7 @@ def worker_main() -> None: logging.disable(logging.INFO) # Suppress internal logging for cleaner test output items = list(enumerate(TEST_CONFIG)) - disable_pbar = os.environ.get("DISABLE_PBAR", "0") == "1" or rank != 0 - pbar = tqdm(items, desc=f"Testing {world_size} GPUs", disable=disable_pbar) - for i, (size, dtype, algo, use_graph) in pbar: + for i, (size, dtype, algo, use_graph) in items: error = worker_test(device, nccl_group, comm, size, dtype, use_graph, algo) if error is not None: print( @@ -222,7 +244,7 @@ def worker_main() -> None: f"Error: {error}" ) # communicate the result to rank 0 for logging - result = torch.tensor([int(error is not None)], device=device) + result = torch.tensor([int(error is not None)]) dist.all_reduce(result, group=cpu_group) failed = bool(result.item()) if failed: @@ -239,4 +261,4 @@ def worker_main() -> None: if "LOCAL_RANK" in os.environ: worker_main() else: - sys.exit(pytest.main([__file__, "-v", "-s"])) + sys.exit(pytest.main([__file__, "-x", "-vv", "-s"]))