Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions .github/workflows/nightly-test-nvidia.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ on:
- 'nightly-test-multimodal-server-2-gpu'
- 'nightly-test-perf-4-gpu-b200'
- 'nightly-test-perf-8-gpu-b200'
- 'nightly-test-kernel-1-gpu-h100'
workflow_call:
inputs:
ref:
Expand Down Expand Up @@ -76,6 +77,42 @@ jobs:
- uses: ./.github/actions/upload-cuda-coredumps
if: always()

# JIT kernel full unit tests (expanded parameter ranges via SGLANG_JIT_KERNEL_RUN_FULL_TESTS)
nightly-test-kernel-1-gpu-h100:
  # Nightly run of the JIT kernel suite on a single H100 runner.
  # Gated to the upstream repository, and to dispatches whose job_filter is
  # empty, 'all', or names this job explicitly.
  if: github.repository == 'sgl-project/sglang' && (inputs.job_filter == '' || inputs.job_filter == 'all' || inputs.job_filter == 'nightly-test-kernel-1-gpu-h100')
  runs-on: 1-gpu-h100
  # Job-level ceiling; the individual steps below carry their own tighter limits.
  timeout-minutes: 240
  env:
    # Full jit_kernel test grids (see sglang.jit_kernel.utils.should_run_full_tests)
    SGLANG_JIT_KERNEL_RUN_FULL_TESTS: "1"
    # Match pr-test-jit-kernel workflow for consistent JIT warmup behavior.
    # Quoted so the value is an explicit string, like the "1" above, rather
    # than a YAML boolean that has to be coerced.
    SGLANG_JIT_DEEPGEMM_FAST_WARMUP: "true"
    # Allow maintenance bypass on default branch (same semantics as PR JIT workflow)
    SGLANG_PR_TEST_BYPASS_MAINTENANCE_ON_MAIN: ${{ github.ref == 'refs/heads/main' && 'true' || 'false' }}
  steps:
    - name: Checkout code
      uses: actions/checkout@v4
      with:
        # Prefer the ref passed in by the caller; fall back to the triggering ref.
        ref: ${{ inputs.ref || github.ref }}

    - uses: ./.github/actions/check-maintenance
      with:
        github-token: ${{ github.token }}

    - name: Install dependencies
      timeout-minutes: 20
      run: |
        bash scripts/ci/cuda/ci_install_dependency.sh

    - name: Run jit kernel nightly suite
      timeout-minutes: 60
      run: |
        cd test
        python3 run_suite.py --hw cuda --suite nightly-kernel-1-gpu --nightly --continue-on-error

    # Runs even when earlier steps fail (if: always()).
    - uses: ./.github/actions/upload-cuda-coredumps
      if: always()

# General tests - 4 GPU H100
nightly-test-general-4-gpu-h100:
if: github.repository == 'sgl-project/sglang' && (inputs.job_filter == '' || inputs.job_filter == 'all' || inputs.job_filter == 'nightly-test-general-4-gpu-h100')
Expand Down
54 changes: 4 additions & 50 deletions .github/workflows/pr-test-jit-kernel.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,36 +56,8 @@ jobs:
- name: Run test
timeout-minutes: 30
run: |
cd python/sglang/jit_kernel
pytest tests/

jit-kernel-unit-test-nightly:
if: |
github.event_name == 'schedule' &&
inputs.jit_kernel == 'true'
runs-on: 1-gpu-h100
timeout-minutes: 240
env:
SGLANG_JIT_KERNEL_RUN_FULL_TESTS: "1"
steps:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.pr_head_sha || inputs.git_ref || github.sha }}

- uses: ./.github/actions/check-maintenance
with:
github-token: ${{ github.token }}

- name: Install dependencies
timeout-minutes: 20
run: |
bash scripts/ci/cuda/ci_install_dependency.sh

- name: Run full nightly test
timeout-minutes: 60
run: |
cd python/sglang/jit_kernel
pytest tests/
cd test/
python3 run_suite.py --hw cuda --suite stage-b-kernel-unit-1-gpu-large

jit-kernel-benchmark-test:
if: |
Expand All @@ -111,23 +83,5 @@ jobs:
- name: Run benchmark tests
timeout-minutes: 45
run: |
cd python/sglang/jit_kernel/benchmark
echo "Running jit-kernel benchmark tests in CI mode..."

failures=()

for bench_file in bench_*.py; do
echo "Testing $bench_file..."
if ! timeout 120 python3 "$bench_file"; then
failures+=("$bench_file")
fi
echo "Completed $bench_file"
echo "---"
done

if [ ${#failures[@]} -ne 0 ]; then
echo "The following benchmark tests failed: ${failures[*]}"
exit 1
fi

echo "All jit-kernel benchmark tests completed successfully!"
cd test/
python3 run_suite.py --hw cuda --suite stage-b-kernel-benchmark-1-gpu-large
3 changes: 3 additions & 0 deletions python/sglang/jit_kernel/benchmark/bench_awq_dequantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@

from sglang.jit_kernel.awq_dequantize import awq_dequantize as jit_awq_dequantize
from sglang.jit_kernel.benchmark.utils import run_benchmark
from sglang.test.ci.ci_register import register_cuda_ci
from sglang.utils import is_in_ci

register_cuda_ci(est_time=5, suite="stage-b-kernel-benchmark-1-gpu-large")

try:
from sgl_kernel import awq_dequantize as aot_awq_dequantize

Expand Down
3 changes: 3 additions & 0 deletions python/sglang/jit_kernel/benchmark/bench_clamp_position.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
)
from sglang.jit_kernel.clamp_position import clamp_position_cuda
from sglang.srt.utils import get_compiler_backend
from sglang.test.ci.ci_register import register_cuda_ci

register_cuda_ci(est_time=13, suite="stage-b-kernel-benchmark-1-gpu-large")

SIZE_LIST = get_benchmark_range(
full_range=[2**n for n in range(4, 16)],
Expand Down
3 changes: 3 additions & 0 deletions python/sglang/jit_kernel/benchmark/bench_concat_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@
from sglang.jit_kernel.benchmark.utils import run_benchmark
from sglang.jit_kernel.concat_mla import concat_mla_absorb_q as jit_absorb_q
from sglang.jit_kernel.concat_mla import concat_mla_k as jit_k
from sglang.test.ci.ci_register import register_cuda_ci
from sglang.utils import is_in_ci

register_cuda_ci(est_time=6, suite="stage-b-kernel-benchmark-1-gpu-large")

IS_CI = is_in_ci()

NUM_LOCAL_HEADS = 128
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@
import torch.distributed as dist

from sglang.jit_kernel.benchmark.utils import is_in_ci
from sglang.test.ci.ci_register import register_cuda_ci

register_cuda_ci(
est_time=120,
suite="stage-b-kernel-benchmark-1-gpu-large",
disabled="requires multi-GPU, self-skips in CI",
)

DTYPE_MAP = {
"float16": torch.float16,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@

from sglang.jit_kernel.benchmark.utils import run_benchmark
from sglang.jit_kernel.norm import fused_add_rmsnorm as jit_fused_add_rmsnorm
from sglang.test.ci.ci_register import register_cuda_ci
from sglang.utils import is_in_ci

register_cuda_ci(est_time=6, suite="stage-b-kernel-benchmark-1-gpu-large")

IS_CI = is_in_ci()


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
ScaleResidualLayerNormScaleShift,
ScaleResidualRMSNormScaleShift,
)
from sglang.test.ci.ci_register import register_cuda_ci
from sglang.utils import is_in_ci

register_cuda_ci(est_time=17, suite="stage-b-kernel-benchmark-1-gpu-large")

if is_in_ci():
B_RANGE, S_RANGE, D_RANGE = [1], [128], [1024]
else:
Expand Down
3 changes: 3 additions & 0 deletions python/sglang/jit_kernel/benchmark/bench_hadamard.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
run_benchmark,
)
from sglang.jit_kernel.hadamard import hadamard_transform
from sglang.test.ci.ci_register import register_cuda_ci

register_cuda_ci(est_time=5, suite="stage-b-kernel-benchmark-1-gpu-large")

# AOT kernel: might not be available in all environments.
# This is used for performance baseline comparison.
Expand Down
3 changes: 3 additions & 0 deletions python/sglang/jit_kernel/benchmark/bench_hicache.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
transfer_hicache_all_layer,
transfer_hicache_one_layer,
)
from sglang.test.ci.ci_register import register_cuda_ci

register_cuda_ci(est_time=29, suite="stage-b-kernel-benchmark-1-gpu-large")

DISABLE_TORCH = os.environ.get("DISABLE_TORCH", "0") == "1"
PAGE_SIZE = 1
Expand Down
3 changes: 3 additions & 0 deletions python/sglang/jit_kernel/benchmark/bench_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@
from sglang.jit_kernel.benchmark.utils import run_benchmark
from sglang.jit_kernel.norm import fused_add_rmsnorm as jit_fused_add_rmsnorm
from sglang.jit_kernel.norm import rmsnorm as jit_rmsnorm
from sglang.test.ci.ci_register import register_cuda_ci
from sglang.utils import is_in_ci

register_cuda_ci(est_time=5, suite="stage-b-kernel-benchmark-1-gpu-large")

IS_CI = is_in_ci()

DTYPE = torch.bfloat16
Expand Down
7 changes: 7 additions & 0 deletions python/sglang/jit_kernel/benchmark/bench_norm_impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,15 @@
from sglang.jit_kernel.norm import fused_add_rmsnorm as jit_fused_add_rmsnorm
from sglang.jit_kernel.norm import rmsnorm as jit_rmsnorm
from sglang.jit_kernel.utils import KERNEL_PATH
from sglang.test.ci.ci_register import register_cuda_ci
from sglang.utils import is_in_ci

register_cuda_ci(
est_time=120,
suite="stage-b-kernel-benchmark-1-gpu-large",
disabled="self-skips in CI, standalone tool",
)

os.environ.setdefault("FLASHINFER_DISABLE_VERSION_CHECK", "1")

REPO_ROOT = KERNEL_PATH.parents[2]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
scaled_fp4_quant,
)
from sglang.srt.utils import is_sm100_supported
from sglang.test.ci.ci_register import register_cuda_ci

register_cuda_ci(est_time=5, suite="stage-b-kernel-benchmark-1-gpu-large")

FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
Expand Down
3 changes: 3 additions & 0 deletions python/sglang/jit_kernel/benchmark/bench_nvfp4_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from sglang.jit_kernel.benchmark.utils import get_benchmark_range, run_benchmark
from sglang.jit_kernel.nvfp4 import scaled_fp4_quant
from sglang.srt.utils import is_sm100_supported
from sglang.test.ci.ci_register import register_cuda_ci

register_cuda_ci(est_time=5, suite="stage-b-kernel-benchmark-1-gpu-large")

FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
Expand Down
3 changes: 3 additions & 0 deletions python/sglang/jit_kernel/benchmark/bench_nvfp4_scaled_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from sglang.jit_kernel.benchmark.utils import get_benchmark_range, run_benchmark
from sglang.jit_kernel.nvfp4 import cutlass_scaled_fp4_mm, scaled_fp4_quant
from sglang.srt.utils import is_sm100_supported
from sglang.test.ci.ci_register import register_cuda_ci

register_cuda_ci(est_time=5, suite="stage-b-kernel-benchmark-1-gpu-large")

FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

from sglang.jit_kernel.benchmark.utils import get_benchmark_range, run_benchmark
from sglang.jit_kernel.per_tensor_quant_fp8 import per_tensor_quant_fp8
from sglang.test.ci.ci_register import register_cuda_ci

register_cuda_ci(est_time=5, suite="stage-b-kernel-benchmark-1-gpu-large")

try:
from vllm import _custom_ops as ops
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@
)
from sglang.srt.utils import is_hip
from sglang.srt.utils.bench_utils import bench_kineto
from sglang.test.ci.ci_register import register_cuda_ci
from sglang.utils import is_in_ci

register_cuda_ci(est_time=13, suite="stage-b-kernel-benchmark-1-gpu-large")

IS_CI = is_in_ci()

_is_hip = is_hip()
Expand Down
3 changes: 3 additions & 0 deletions python/sglang/jit_kernel/benchmark/bench_qknorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
)
from sglang.jit_kernel.norm import fused_inplace_qknorm
from sglang.srt.utils import get_current_device_stream_fast
from sglang.test.ci.ci_register import register_cuda_ci

register_cuda_ci(est_time=10, suite="stage-b-kernel-benchmark-1-gpu-large")

alt_stream = torch.cuda.Stream()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@
from sglang.jit_kernel.benchmark.utils import run_benchmark
from sglang.jit_kernel.norm import fused_inplace_qknorm_across_heads
from sglang.srt.utils import get_current_device_stream_fast
from sglang.test.ci.ci_register import register_cuda_ci
from sglang.utils import is_in_ci

register_cuda_ci(est_time=12, suite="stage-b-kernel-benchmark-1-gpu-large")

IS_CI = is_in_ci()

alt_stream = torch.cuda.Stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@
fuse_layernorm_scale_shift_gate_select01_kernel,
fuse_residual_layernorm_scale_shift_gate_select01_kernel,
)
from sglang.test.ci.ci_register import register_cuda_ci
from sglang.utils import is_in_ci

register_cuda_ci(est_time=13, suite="stage-b-kernel-benchmark-1-gpu-large")

if is_in_ci():
B_RANGE, S_RANGE, D_RANGE = [1], [128], [3072]
else:
Expand Down
3 changes: 3 additions & 0 deletions python/sglang/jit_kernel/benchmark/bench_renorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@
import triton.testing

from sglang.jit_kernel.benchmark.utils import run_benchmark_no_cudagraph
from sglang.test.ci.ci_register import register_cuda_ci
from sglang.utils import is_in_ci

register_cuda_ci(est_time=5, suite="stage-b-kernel-benchmark-1-gpu-large")


def torch_top_k_renorm_probs(probs, top_k):
"""Vectorized PyTorch implementation of top-k renormalization."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
)
from sglang.jit_kernel.resolve_future_token_ids import resolve_future_token_ids_cuda
from sglang.srt.utils import get_compiler_backend
from sglang.test.ci.ci_register import register_cuda_ci

register_cuda_ci(est_time=10, suite="stage-b-kernel-benchmark-1-gpu-large")

SIZE_LIST = get_benchmark_range(
full_range=[2**n for n in range(4, 16)], # 16 … 32K elements
Expand Down
3 changes: 3 additions & 0 deletions python/sglang/jit_kernel/benchmark/bench_rmsnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
run_benchmark,
)
from sglang.jit_kernel.norm import rmsnorm as jit_rmsnorm
from sglang.test.ci.ci_register import register_cuda_ci

register_cuda_ci(est_time=21, suite="stage-b-kernel-benchmark-1-gpu-large")


def sglang_aot_rmsnorm(
Expand Down
3 changes: 3 additions & 0 deletions python/sglang/jit_kernel/benchmark/bench_rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
get_benchmark_range,
run_benchmark,
)
from sglang.test.ci.ci_register import register_cuda_ci

register_cuda_ci(est_time=6, suite="stage-b-kernel-benchmark-1-gpu-large")

MAX_SEQ_LEN = 131072
ROPE_BASE = 10000.0
Expand Down
3 changes: 3 additions & 0 deletions python/sglang/jit_kernel/benchmark/bench_store_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
get_benchmark_range,
)
from sglang.jit_kernel.kvcache import store_cache
from sglang.test.ci.ci_register import register_cuda_ci

register_cuda_ci(est_time=9, suite="stage-b-kernel-benchmark-1-gpu-large")


def sglang_jit_store_cache(
Expand Down
6 changes: 3 additions & 3 deletions python/sglang/jit_kernel/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@
import torch

from sglang.jit_kernel.debug_utils import maybe_wrap_jit_kernel_debug

logger = logging.getLogger(__name__)

from sglang.jit_kernel.utils import (
cache_once,
is_arch_support_pdl,
Expand All @@ -20,6 +17,9 @@
from tvm_ffi.module import Module


logger = logging.getLogger(__name__)


@cache_once
def _jit_qknorm_module(head_dim: int, dtype: torch.dtype) -> Module:
args = make_cpp_args(head_dim, is_arch_support_pdl(), dtype)
Expand Down
Loading
Loading