Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/sglang/jit_kernel/all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def config_pull(
def _jit_custom_all_reduce_pull_module(dtype: torch.dtype, world_size: int):
args = make_cpp_args(dtype, world_size, is_arch_support_pdl())
return load_jit(
"custom_all_reduce",
"custom_all_reduce_pull",
*args,
extra_ldflags=["-lcuda"],
cuda_files=["distributed/custom_all_reduce_pull.cuh"],
Expand All @@ -107,7 +107,7 @@ def _jit_custom_all_reduce_pull_module(dtype: torch.dtype, world_size: int):
def _jit_custom_all_reduce_push_module(dtype: torch.dtype, world_size: int):
args = make_cpp_args(dtype, world_size, is_arch_support_pdl())
return load_jit(
"custom_all_reduce",
"custom_all_reduce_push",
*args,
extra_ldflags=["-lcuda"],
cuda_files=["distributed/custom_all_reduce_push.cuh"],
Expand Down
56 changes: 39 additions & 17 deletions python/sglang/jit_kernel/tests/test_custom_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,33 @@

import itertools
import logging
import multiprocessing as mp
import os
import subprocess
import sys
from typing import Optional
from typing import Dict, Optional, Tuple

import pytest
import torch
import torch.distributed as dist
from tqdm import tqdm

import sglang.srt.distributed.parallel_state as ps
from sglang.jit_kernel.all_reduce import AllReduceAlgo
from sglang.jit_kernel.all_reduce import (
AllReduceAlgo,
_jit_custom_all_reduce_pull_module,
_jit_custom_all_reduce_push_module,
)
from sglang.srt.distributed.device_communicators.custom_all_reduce_v2 import (
CustomAllReduceV2,
)
from sglang.test.ci.ci_register import register_cuda_ci

register_cuda_ci(
est_time=500,
est_time=300,
suite="stage-b-kernel-unit-8-gpu-h200",
)
register_cuda_ci(
est_time=500,
est_time=300,
suite="nightly-kernel-8-gpu-h200",
nightly=True,
)
Expand Down Expand Up @@ -67,22 +71,21 @@
]
USE_GRAPH_OPTIONS = [True, False]
TEST_CONFIG = itertools.product(TEST_SIZES, TEST_DTYPES, SHOTS, USE_GRAPH_OPTIONS)
TEST_LAYERS = 2
TEST_LAYERS = 4
TEST_LOOP = 16

# ---------------------------------------------------------------------------
# Test class (runs via pytest, launches torchrun subprocesses)
# ---------------------------------------------------------------------------


def run_torchrun(nproc: int, timeout: int = 300) -> None:
def _run_torchrun(nproc: int, timeout: int = 300) -> None:
"""Launch this script as a torchrun worker and assert success."""
cmd = [
"torchrun",
f"--nproc_per_node={nproc}",
__file__,
]
os.environ["DISABLE_PBAR"] = "1"
result = subprocess.run(
cmd,
stdout=subprocess.PIPE,
Expand All @@ -96,14 +99,37 @@ def run_torchrun(nproc: int, timeout: int = 300) -> None:
)


@pytest.mark.parametrize("nproc", [2, 3, 4, 5, 6, 7, 8])
def _compile_one(dtype: torch.dtype, world_size: int):
    """Build both custom all-reduce JIT modules (push and pull) for one config.

    Runs in a spawned subprocess (see the precompile driver), so a compiler
    crash here cannot take down the parent test process.
    """
    for build_module in (
        _jit_custom_all_reduce_push_module,
        _jit_custom_all_reduce_pull_module,
    ):
        build_module(dtype, world_size)


def _precompile_kernels() -> None:
    """Compile every (dtype, world_size) custom all-reduce kernel variant.

    Each config is compiled in its own spawned subprocess so builds run in
    parallel and a crash in one compile cannot take down the test process.

    Raises:
        RuntimeError: listing every config whose compile subprocess exited
            with a nonzero status.
    """
    # NOTE: even when device count < 8, we should be able to compile all.
    # Use an explicit "spawn" context rather than mp.set_start_method():
    # set_start_method() raises RuntimeError if the start method was already
    # set, which can happen when this function runs more than once in the
    # same pytest process (parametrized invocations, repeated runs).
    ctx = mp.get_context("spawn")
    compile_space = itertools.product(TEST_DTYPES, [2, 3, 4, 5, 6, 7, 8])
    process_map: Dict[Tuple[torch.dtype, int], mp.Process] = {
        config: ctx.Process(target=_compile_one, args=config)
        for config in compile_space
    }
    for process in process_map.values():
        process.start()
    # Join every child before reporting, so a failure in one config does not
    # leave the remaining compile subprocesses unreaped.
    failures = []
    for (dtype, world_size), process in process_map.items():
        process.join()
        if process.exitcode != 0:
            failures.append(f"{world_size=} {dtype=}")
    if failures:
        raise RuntimeError(
            "Custom All Reduce compile failed for: " + ", ".join(failures)
        )


@pytest.mark.parametrize("nproc", [1, 2, 3, 4, 5, 6, 7, 8])
def test_custom_allreduce(nproc: int) -> None:
    """Exercise the custom all-reduce at the given world size via torchrun.

    ``nproc == 1`` is a special case that only pre-compiles the kernel
    variants (NOTE: speeds up the overall test run by warming the JIT cache
    before the multi-GPU invocations).
    """
    if nproc == 1:
        return _precompile_kernels()

    device_count = torch.cuda.device_count()
    if device_count >= nproc:
        _run_torchrun(nproc)
    else:
        pytest.skip(
            f"Requires at least {nproc} GPUs, but only {device_count} available"
        )


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -192,8 +218,6 @@ def run_eager(x: torch.Tensor) -> torch.Tensor:
dist.all_reduce(out_ref, group=nccl_group)
out_jit = run_fn(inp)
num_errors += not torch.all(out_jit == out_ref)
torch.cuda.synchronize()
nccl_group.barrier().wait()
if num_errors > 0:
return RuntimeError(
f"Test failed for {size=}, {dtype=}, {algo=}, "
Expand All @@ -211,9 +235,7 @@ def worker_main() -> None:

logging.disable(logging.INFO) # Suppress internal logging for cleaner test output
items = list(enumerate(TEST_CONFIG))
disable_pbar = os.environ.get("DISABLE_PBAR", "0") == "1" or rank != 0
pbar = tqdm(items, desc=f"Testing {world_size} GPUs", disable=disable_pbar)
for i, (size, dtype, algo, use_graph) in pbar:
for i, (size, dtype, algo, use_graph) in items:
error = worker_test(device, nccl_group, comm, size, dtype, use_graph, algo)
if error is not None:
print(
Expand All @@ -222,7 +244,7 @@ def worker_main() -> None:
f"Error: {error}"
)
# communicate the result to rank 0 for logging
result = torch.tensor([int(error is not None)], device=device)
result = torch.tensor([int(error is not None)])
dist.all_reduce(result, group=cpu_group)
failed = bool(result.item())
if failed:
Expand All @@ -239,4 +261,4 @@ def worker_main() -> None:
if "LOCAL_RANK" in os.environ:
worker_main()
else:
sys.exit(pytest.main([__file__, "-v", "-s"]))
sys.exit(pytest.main([__file__, "-x", "-vv", "-s"]))
Loading