diff --git a/.github/workflows/diffusion-ci-gt-gen.yml b/.github/workflows/diffusion-ci-gt-gen.yml index 92844245bde4..9dad8ed00c16 100644 --- a/.github/workflows/diffusion-ci-gt-gen.yml +++ b/.github/workflows/diffusion-ci-gt-gen.yml @@ -22,6 +22,10 @@ permissions: contents: write actions: read +env: + SGLANG_IS_IN_CI: true + SGLANG_CUDA_COREDUMP: "1" + jobs: multimodal-diffusion-gen-1gpu: if: github.repository == 'sgl-project/sglang' @@ -40,6 +44,8 @@ jobs: run: bash scripts/ci/cuda/ci_install_dependency.sh diffusion - name: Generate outputs + env: + RUNAI_STREAMER_MEMORY_LIMIT: 0 run: | cd python python -m sglang.multimodal_gen.test.scripts.gen_diffusion_ci_outputs \ @@ -56,6 +62,11 @@ jobs: path: python/diffusion-ci-outputs retention-days: 7 + - name: Publish GT images to sglang-bot/sglang-ci-data + env: + GITHUB_TOKEN: ${{ secrets.GH_PAT_FOR_NIGHTLY_CI_DATA }} + run: python scripts/ci/utils/diffusion/publish_diffusion_gt.py --source-dir python/diffusion-ci-outputs + multimodal-diffusion-gen-2gpu: if: github.repository == 'sgl-project/sglang' runs-on: 2-gpu-h100 @@ -73,6 +84,8 @@ jobs: run: bash scripts/ci/cuda/ci_install_dependency.sh diffusion - name: Generate outputs + env: + RUNAI_STREAMER_MEMORY_LIMIT: 0 run: | cd python python -m sglang.multimodal_gen.test.scripts.gen_diffusion_ci_outputs \ @@ -89,27 +102,42 @@ jobs: path: python/diffusion-ci-outputs retention-days: 7 - diffusion-ci-push: - needs: [multimodal-diffusion-gen-1gpu, multimodal-diffusion-gen-2gpu] + - name: Publish GT images to sglang-bot/sglang-ci-data + env: + GITHUB_TOKEN: ${{ secrets.GH_PAT_FOR_NIGHTLY_CI_DATA }} + run: python scripts/ci/utils/diffusion/publish_diffusion_gt.py --source-dir python/diffusion-ci-outputs + + multimodal-diffusion-gen-b200: if: github.repository == 'sgl-project/sglang' - runs-on: ubuntu-latest + runs-on: 4-gpu-b200 + timeout-minutes: 240 steps: - name: Checkout code uses: actions/checkout@v4 - - - name: Download artifacts - uses: actions/download-artifact@v4 
with: - pattern: diffusion-gen-* - path: combined - merge-multiple: true + ref: ${{ inputs.ref || github.ref }} + + - name: Install dependencies + run: bash scripts/ci/cuda/ci_install_dependency.sh diffusion - - name: Collect image files + - name: Generate outputs + env: + RUNAI_STREAMER_MEMORY_LIMIT: 0 run: | - mkdir -p gt_images - find combined \( -name "*.png" -o -name "*.jpg" -o -name "*.jpeg" -o -name "*.webp" \) -type f -exec cp -f {} gt_images/ \; + cd python + python -m sglang.multimodal_gen.test.scripts.gen_diffusion_ci_outputs \ + --suite 1-gpu-b200 \ + --out-dir ./diffusion-ci-outputs \ + ${{ inputs.case_ids != '' && format('--case-ids {0}', inputs.case_ids) || '' }} + + - name: Upload artifact + uses: actions/upload-artifact@v4 + with: + name: diffusion-gen-b200 + path: python/diffusion-ci-outputs + retention-days: 7 - name: Publish GT images to sglang-bot/sglang-ci-data env: GITHUB_TOKEN: ${{ secrets.GH_PAT_FOR_NIGHTLY_CI_DATA }} - run: python scripts/ci/utils/diffusion/publish_diffusion_gt.py --source-dir gt_images + run: python scripts/ci/utils/diffusion/publish_diffusion_gt.py --source-dir python/diffusion-ci-outputs diff --git a/.github/workflows/nightly-test-nvidia.yml b/.github/workflows/nightly-test-nvidia.yml index f3b33b8cbc9f..f523f4c8dbbe 100644 --- a/.github/workflows/nightly-test-nvidia.yml +++ b/.github/workflows/nightly-test-nvidia.yml @@ -76,7 +76,7 @@ jobs: python3 run_suite.py --hw cuda --suite nightly-1-gpu --nightly --continue-on-error - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() # JIT kernel full unit tests (expanded parameter ranges via SGLANG_JIT_KERNEL_RUN_FULL_TESTS) nightly-test-kernel-1-gpu-h100: @@ -110,7 +110,7 @@ jobs: python3 run_suite.py --hw cuda --suite nightly-kernel-1-gpu --nightly --continue-on-error - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() nightly-test-kernel-8-gpu-h200: if: github.repository == 'sgl-project/sglang' && (inputs.job_filter == '' 
|| inputs.job_filter == 'all' || inputs.job_filter == 'nightly-test-kernel-8-gpu-h200') @@ -140,7 +140,7 @@ jobs: python3 run_suite.py --hw cuda --suite nightly-kernel-8-gpu-h200 --nightly --continue-on-error - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() # General tests - 4 GPU H100 nightly-test-general-4-gpu-h100: @@ -165,7 +165,7 @@ jobs: python3 run_suite.py --hw cuda --suite nightly-4-gpu --nightly --continue-on-error - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() # General tests - 8 GPU H200 nightly-test-general-8-gpu-h200: @@ -249,7 +249,7 @@ jobs: if-no-files-found: ignore - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() with: artifact-suffix: ${{ matrix.partition }} @@ -280,7 +280,7 @@ jobs: python3 run_suite.py --hw cuda --suite nightly-8-gpu-h20 --nightly --continue-on-error - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() # General tests - 8 GPU B200 nightly-test-general-8-gpu-b200: @@ -353,7 +353,7 @@ jobs: if-no-files-found: ignore - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() with: artifact-suffix: ${{ matrix.partition }} @@ -380,7 +380,7 @@ jobs: python3 run_suite.py --hw cuda --suite nightly-eval-text-2-gpu --nightly --continue-on-error --timeout-per-file 4500 - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() # Text model performance tests nightly-test-text-perf-2-gpu-h100: @@ -418,7 +418,7 @@ jobs: python3 scripts/ci/utils/publish_traces.py --traces-dir test/performance_profiles_text_models - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() # VLM accuracy tests nightly-test-vlm-accuracy-2-gpu-h100: @@ -443,7 +443,7 @@ jobs: python3 run_suite.py --hw cuda --suite nightly-eval-vlm-2-gpu --nightly --continue-on-error --timeout-per-file 9000 - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() # VLM performance 
tests nightly-test-vlm-perf-2-gpu-h100: @@ -481,7 +481,7 @@ jobs: python3 scripts/ci/utils/publish_traces.py --traces-dir test/performance_profiles_vlms - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() # diffusion performance tests nightly-test-multimodal-server-1-gpu: @@ -538,7 +538,7 @@ jobs: if-no-files-found: ignore - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() with: artifact-suffix: ${{ matrix.part }} @@ -596,7 +596,7 @@ jobs: if-no-files-found: ignore - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() with: artifact-suffix: ${{ matrix.part }} @@ -623,7 +623,7 @@ jobs: python3 run_suite.py --hw cuda --suite nightly-4-gpu-b200 --nightly --continue-on-error --timeout-per-file 12000 - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() # Specialized B200 tests - 8 GPU, for specific backends and configs nightly-test-specialized-8-gpu-b200: @@ -652,7 +652,7 @@ jobs: python3 run_suite.py --hw cuda --suite nightly-8-gpu-b200 --nightly --continue-on-error --timeout-per-file 2400 - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() # Diffusion cross-framework comparison nightly-test-diffusion-comparison: @@ -716,7 +716,7 @@ jobs: if-no-files-found: ignore - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() # Consolidate performance metrics from all jobs consolidate-metrics: diff --git a/.github/workflows/pr-test-multimodal-gen.yml b/.github/workflows/pr-test-multimodal-gen.yml index a91b6c2e927a..1fd8ed24eb0e 100644 --- a/.github/workflows/pr-test-multimodal-gen.yml +++ b/.github/workflows/pr-test-multimodal-gen.yml @@ -100,7 +100,7 @@ jobs: $CONTINUE_ON_ERROR_FLAG - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() with: artifact-suffix: ${{ matrix.part }} @@ -155,7 +155,7 @@ jobs: $CONTINUE_ON_ERROR_FLAG - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: 
failure() with: artifact-suffix: ${{ matrix.part }} @@ -175,6 +175,7 @@ jobs: with: ref: ${{ inputs.pr_head_sha || inputs.git_ref || github.sha }} + - uses: ./.github/actions/check-stage-health - uses: ./.github/actions/check-maintenance @@ -203,7 +204,7 @@ jobs: $CONTINUE_ON_ERROR_FLAG - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() multimodal-gen-unit-test: if: | diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 4be332e5928e..ff64a9c3d4d2 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -602,7 +602,7 @@ jobs: python3 run_suite.py --hw cuda --suite stage-a-test-1-gpu-small $CONTINUE_ON_ERROR_FLAG - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() stage-a-test-cpu: needs: [check-changes, call-gate] @@ -711,7 +711,7 @@ jobs: python3 run_suite.py --hw cuda --suite stage-b-test-1-gpu-small --auto-partition-id ${{ matrix.partition }} --auto-partition-size 8 $CONTINUE_ON_ERROR_FLAG - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() with: artifact-suffix: ${{ matrix.partition }} @@ -767,7 +767,7 @@ jobs: python3 run_suite.py --hw cuda --suite stage-b-test-1-gpu-large --auto-partition-id ${{ matrix.partition }} --auto-partition-size 14 --timeout-per-file 1800 $CONTINUE_ON_ERROR_FLAG - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() with: artifact-suffix: ${{ matrix.partition }} @@ -822,7 +822,7 @@ jobs: python3 run_suite.py --hw cuda --suite stage-b-test-2-gpu-large --auto-partition-id ${{ matrix.partition }} --auto-partition-size 4 $CONTINUE_ON_ERROR_FLAG - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() with: artifact-suffix: ${{ matrix.partition }} @@ -880,7 +880,7 @@ jobs: python3 -m pytest -q python/sglang/jit_kernel/tests/test_flash_attention_4.py - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() call-multimodal-gen-tests: needs: 
[check-changes, call-gate, sgl-kernel-build-wheels] @@ -962,7 +962,7 @@ jobs: python3 run_suite.py --hw cuda --suite stage-c-test-4-gpu-h100 --auto-partition-id ${{ matrix.part }} --auto-partition-size 3 $CONTINUE_ON_ERROR_FLAG - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() with: artifact-suffix: ${{ matrix.part }} @@ -1030,7 +1030,7 @@ jobs: python3 run_suite.py --hw cuda --suite stage-c-test-8-gpu-h200 --auto-partition-id ${{ matrix.part }} --auto-partition-size 4 $CONTINUE_ON_ERROR_FLAG - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() with: artifact-suffix: ${{ matrix.part }} @@ -1086,7 +1086,7 @@ jobs: python3 run_suite.py --hw cuda --suite stage-c-test-8-gpu-h20 --auto-partition-id ${{ matrix.part }} --auto-partition-size 2 $CONTINUE_ON_ERROR_FLAG - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() with: artifact-suffix: ${{ matrix.part }} @@ -1148,7 +1148,7 @@ jobs: python3 run_suite.py --hw cuda --suite stage-c-test-deepep-4-gpu-h100 $CONTINUE_ON_ERROR_FLAG - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() stage-c-test-deepep-8-gpu-h200: needs: [check-changes, call-gate, wait-for-stage-b] @@ -1209,7 +1209,7 @@ jobs: python3 run_suite.py --hw cuda --suite stage-c-test-deepep-8-gpu-h200 $CONTINUE_ON_ERROR_FLAG - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() stage-c-test-4-gpu-b200: needs: [check-changes, call-gate, wait-for-stage-b] @@ -1262,7 +1262,7 @@ jobs: python3 run_suite.py --hw cuda --suite stage-c-test-4-gpu-b200 --auto-partition-id ${{ matrix.part }} --auto-partition-size 4 --timeout-per-file 1800 $CONTINUE_ON_ERROR_FLAG - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() with: artifact-suffix: ${{ matrix.part }} @@ -1316,7 +1316,7 @@ jobs: # python3 run_suite.py --hw cuda --suite stage-c-test-4-gpu-gb200 --timeout-per-file 3600 $CONTINUE_ON_ERROR_FLAG # # - uses: 
./.github/actions/upload-cuda-coredumps - # if: always() + # if: failure() pr-test-finish: needs: diff --git a/.github/workflows/rerun-test.yml b/.github/workflows/rerun-test.yml index 431b69474c1c..64e930a527d8 100644 --- a/.github/workflows/rerun-test.yml +++ b/.github/workflows/rerun-test.yml @@ -111,7 +111,7 @@ jobs: echo "All $total test(s) passed in ${total_elapsed}s" - uses: ./.github/actions/upload-cuda-coredumps - if: always() + if: failure() rerun-test-cpu: if: inputs.is_cpu == 'true' @@ -173,4 +173,4 @@ jobs: echo "" done total_elapsed=$(( SECONDS - suite_start )) - echo "All $total test(s) passed in ${total_elapsed}s" + echo "All $total test(s) passed in ${total_elapsed}s" \ No newline at end of file diff --git a/.gitignore b/.gitignore index b5917c299ecf..a8aa903e28f7 100644 --- a/.gitignore +++ b/.gitignore @@ -258,6 +258,7 @@ inputs/ # setuptools-scm generated version file python/sglang/_version.py +python/kernels.lock # MUSA section # Generated source files by torchada diff --git a/benchmark/kernels/bench_fused_temperature_softmax.py b/benchmark/kernels/bench_fused_temperature_softmax.py new file mode 100644 index 000000000000..fc624b721ecf --- /dev/null +++ b/benchmark/kernels/bench_fused_temperature_softmax.py @@ -0,0 +1,108 @@ +"""Benchmark: fused_temperature_softmax vs separate div_ + softmax vs flashinfer.sampling.softmax. + +Each path clones logits every iteration so timing is not skewed by in-place reuse. +Uses torch.cuda.Event timing; default 50 warmup, 200 timed iterations. + +Columns tri/base and fi/base are speedup vs PyTorch baseline; tri/fi is t_flashinfer/t_triton +(>1 means Triton is faster).
+""" + +import argparse + +import torch + + +def benchmark_fn(fn, warmup=50, iters=200): + """Time a zero-arg callable using CUDA events.""" + for _ in range(warmup): + fn() + torch.cuda.synchronize() + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(iters): + fn() + end.record() + torch.cuda.synchronize() + return start.elapsed_time(end) / iters * 1000 # microseconds + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--warmup", type=int, default=50) + parser.add_argument("--iters", type=int, default=200) + args = parser.parse_args() + + from flashinfer.sampling import softmax as flashinfer_softmax + + from sglang.srt.layers.fused_sampling import ( + fused_temperature_softmax, + fused_temperature_softmax_inplace, + ) + + configs = [ + # (batch_size, vocab_size, dtype) + (1, 32000, torch.bfloat16), + (1, 128256, torch.bfloat16), + (32, 32000, torch.bfloat16), + (32, 128256, torch.bfloat16), + (128, 32000, torch.bfloat16), + (128, 128256, torch.bfloat16), + (512, 32000, torch.bfloat16), + (512, 128256, torch.bfloat16), + ] + + header = ( + f"{'bs':>5} {'vocab':>7} {'dtype':>8} " + f"{'baseline (us)':>14} {'triton (us)':>12} {'inplace (us)':>13} {'flashinfer (us)':>16} " + f"{'tri/base':>9} {'fi/base':>8} {'tri/fi':>7}" + ) + print(header) + print("-" * len(header)) + + for bs, vocab, dtype in configs: + temps = torch.rand(bs, 1, dtype=torch.float32, device="cuda") * 1.5 + 0.1 + temps_1d = temps.view(-1) + logits_src = torch.randn(bs, vocab, dtype=dtype, device="cuda") + + # --- Baseline: div_ + softmax --- + def run_baseline(src=logits_src, t=temps): + l = src.clone() + l.div_(t) + l[:] = torch.softmax(l, dim=-1) + + t_base = benchmark_fn(run_baseline, args.warmup, args.iters) + + # --- Triton fused (out-of-place) --- + def run_triton(src=logits_src, t=temps): + fused_temperature_softmax(src.clone(), t) + + t_triton = benchmark_fn(run_triton, args.warmup, 
args.iters) + + # --- Triton fused (in-place) --- + def run_inplace(src=logits_src, t=temps): + l = src.clone() + fused_temperature_softmax_inplace(l, t) + + t_ip = benchmark_fn(run_inplace, args.warmup, args.iters) + + # --- FlashInfer (clone each iter, same as other paths) --- + def run_flashinfer(src=logits_src, t=temps_1d): + l = src.clone() + flashinfer_softmax(l, temperature=t) + + t_fi = benchmark_fn(run_flashinfer, args.warmup, args.iters) + + sp_triton = t_base / t_triton + sp_fi = t_base / t_fi + tri_vs_fi = t_fi / t_triton + print( + f"{bs:>5} {vocab:>7} {str(dtype):>8} " + f"{t_base:>14.1f} {t_triton:>12.1f} {t_ip:>13.1f} {t_fi:>16.1f} " + f"{sp_triton:>8.2f}x {sp_fi:>7.2f}x {tri_vs_fi:>6.2f}x" + ) + + +if __name__ == "__main__": + main() diff --git a/docker/Dockerfile b/docker/Dockerfile index d7f4ead4579c..57842c53564b 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -219,6 +219,9 @@ RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install flashinfer-jit-cache==${FLASHINFER_VERSION} --index-url https://flashinfer.ai/whl/cu${CUINDEX} ; \ fi \ - && FLASHINFER_CUBIN_DOWNLOAD_THREADS=${BUILD_AND_DOWNLOAD_PARALLEL} FLASHINFER_LOGGING_LEVEL=warning python3 -m flashinfer --download-cubin + && FLASHINFER_CUBIN_DOWNLOAD_THREADS=${BUILD_AND_DOWNLOAD_PARALLEL} FLASHINFER_LOGGING_LEVEL=warning python3 -m flashinfer --download-cubin \ + && kernels lock python \ + && kernels download python \ + && mkdir -p /root/.cache/sglang && mv python/kernels.lock /root/.cache/sglang/ # DeepEP # We use Tom's DeepEP fork for GB200 for now; the 1fd57b0276311d035d16176bb0076426166e52f3 commit is https://github.com/fzyzcjy/DeepEP/tree/gb200_blog_part_2 @@ -561,6 +564,10 @@ COPY --from=framework /usr/local/lib/python3.12/dist-packages /usr/local/lib/pyt # Copy SGLang workspace COPY --from=framework /sgl-workspace /sgl-workspace +# Copy cache for kernels from kernels community +COPY --from=framework /root/.cache/huggingface /root/.cache/huggingface +COPY --from=framework /root/.cache/sglang /root/.cache/sglang + # Fix Triton to use system ptxas for Blackwell (sm_103a) support (CUDA 13+ only) RUN if [ "${CUDA_VERSION%%.*}" = "13" ] && [ -d
/usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/bin ]; then \ rm -f /usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/bin/ptxas && \ diff --git a/docs/platforms/ascend_npu_glm4_7_flash_examples.md b/docs/platforms/ascend_npu_glm4_7_flash_examples.md new file mode 100644 index 000000000000..8604efb78d47 --- /dev/null +++ b/docs/platforms/ascend_npu_glm4_7_flash_examples.md @@ -0,0 +1,177 @@ +# GLM-4.7-Flash examples + +## Environment Preparation + +### Model Weight + +- `GLM-4.7-Flash`(BF16 version): [Download model weight](https://www.modelscope.cn/models/ZhipuAI/GLM-4.7-Flash). + +### Installation + +The dependencies required for the NPU runtime environment have been integrated into a Docker image and uploaded to the quay.io platform. You can directly pull it. + +```bash +#Atlas 800 A3 +docker pull quay.io/ascend/sglang:main-cann8.5.0-a3 +#Atlas 800 A2 +docker pull quay.io/ascend/sglang:main-cann8.5.0-910b + +#start container +docker run -itd --shm-size=16g --privileged=true --name ${NAME} \ +--privileged=true --net=host \ +-v /var/queue_schedule:/var/queue_schedule \ +-v /etc/ascend_install.info:/etc/ascend_install.info \ +-v /usr/local/sbin:/usr/local/sbin \ +-v /usr/local/Ascend/driver:/usr/local/Ascend/driver \ +-v /usr/local/Ascend/firmware:/usr/local/Ascend/firmware \ +--device=/dev/davinci0:/dev/davinci0 \ +--device=/dev/davinci1:/dev/davinci1 \ +--device=/dev/davinci2:/dev/davinci2 \ +--device=/dev/davinci3:/dev/davinci3 \ +--device=/dev/davinci4:/dev/davinci4 \ +--device=/dev/davinci5:/dev/davinci5 \ +--device=/dev/davinci6:/dev/davinci6 \ +--device=/dev/davinci7:/dev/davinci7 \ +--device=/dev/davinci8:/dev/davinci8 \ +--device=/dev/davinci9:/dev/davinci9 \ +--device=/dev/davinci10:/dev/davinci10 \ +--device=/dev/davinci11:/dev/davinci11 \ +--device=/dev/davinci12:/dev/davinci12 \ +--device=/dev/davinci13:/dev/davinci13 \ +--device=/dev/davinci14:/dev/davinci14 \ +--device=/dev/davinci15:/dev/davinci15 \ 
+--device=/dev/davinci_manager:/dev/davinci_manager \ +--device=/dev/hisi_hdc:/dev/hisi_hdc \ +--entrypoint=bash \ +quay.io/ascend/sglang:${tag} +``` + +Note: When using this image, you need to update Transformers to version 5.3.0. + +``` shell +# reinstall transformers +pip install transformers==5.3.0 +``` + +## Running GLM-4.7-Flash + +### Running GLM-4.7-Flash on 1 x Atlas 800I A3. + +Run the following script to execute online inference. + +```shell +# high performance cpu +echo performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor +sysctl -w vm.swappiness=0 +sysctl -w kernel.numa_balancing=0 +sysctl -w kernel.sched_migration_cost_ns=50000 +# bind cpu +export SGLANG_SET_CPU_AFFINITY=1 + +unset https_proxy +unset http_proxy +unset HTTPS_PROXY +unset HTTP_PROXY +unset ASCEND_LAUNCH_BLOCKING +# cann +source /usr/local/Ascend/ascend-toolkit/set_env.sh +source /usr/local/Ascend/nnal/atb/set_env.sh + +export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True +export STREAMS_PER_DEVICE=32 +export HCCL_BUFFSIZE=1000 +export HCCL_OP_EXPANSION_MODE=AIV +export HCCL_SOCKET_IFNAME=lo +export GLOO_SOCKET_IFNAME=lo + +python3 -m sglang.launch_server \ + --model-path $MODEL_PATH \ + --tp-size 2 \ + --attention-backend ascend \ + --device npu \ + --chunked-prefill-size 16384 \ + --max-prefill-tokens 150000 \ + --dtype bfloat16 \ + --max-running-requests 32 \ + --trust-remote-code \ + --host 127.0.0.1 \ + --mem-fraction-static 0.75 \ + --port 8000 \ + --cuda-graph-bs 1 2 4 8 16 32 \ + --watchdog-timeout 9000 +``` + +Note: TP size is currently limited to 2 or 4. + +### Running GLM-4.7-Flash on 1 x Atlas 800I A3 in slime-ascend. + +#### Preparation + +- [slime-ascend](https://gitcode.com/Ascend/slime-ascend) code + +#### Installation + +Run the following commands to install sglang. 
(Please replace `<slime_root>` with the path to the root directory of the slime codebase.) + +```bash +git clone -b v0.5.8 https://github.com/sgl-project/sglang.git +cd sglang +mv python/pyproject_other.toml python/pyproject.toml +pip install -e python[srt_npu] +git checkout . && git checkout sglang-slime +git am <slime_root>/docker/npu_patch/v0.2.2/sglang/* +``` + +Note: Make sure you are using Transformers 5.3.0. + +#### Execution + +Run the following script to execute online **inference**. + +```shell +# high performance cpu +echo performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor +sysctl -w vm.swappiness=0 +sysctl -w kernel.numa_balancing=0 +sysctl -w kernel.sched_migration_cost_ns=50000 +# bind cpu +export SGLANG_SET_CPU_AFFINITY=1 + +unset https_proxy +unset http_proxy +unset HTTPS_PROXY +unset HTTP_PROXY +unset ASCEND_LAUNCH_BLOCKING +# cann +source /usr/local/Ascend/ascend-toolkit/set_env.sh +source /usr/local/Ascend/nnal/atb/set_env.sh + +export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True +export STREAMS_PER_DEVICE=32 +export HCCL_BUFFSIZE=1000 +export HCCL_OP_EXPANSION_MODE=AIV +export HCCL_SOCKET_IFNAME=lo +export GLOO_SOCKET_IFNAME=lo + +python3 -m sglang.launch_server \ + --model-path $MODEL_PATH \ + --tp-size 2 \ + --attention-backend ascend \ + --device npu \ + --chunked-prefill-size 16384 \ + --max-prefill-tokens 150000 \ + --dtype bfloat16 \ + --max-running-requests 32 \ + --trust-remote-code \ + --host 127.0.0.1 \ + --mem-fraction-static 0.75 \ + --port 8000 \ + --cuda-graph-bs 1 2 4 8 16 32 \ + --watchdog-timeout 9000 +``` + +Refer to [Training and Deployment Example](https://gitcode.com/Ascend/slime-ascend/blob/main/docs/ascend_tutorial/examples/glm4.7-30B-A3B.md) for training and deployment. + +### Using Benchmark + +Refer to [Benchmark and Profiling](../developer_guide/benchmark_and_profiling.md) for details.
diff --git a/docs/references/environment_variables.md b/docs/references/environment_variables.md index e2e93b177b9c..b7ac94a71245 100644 --- a/docs/references/environment_variables.md +++ b/docs/references/environment_variables.md @@ -19,6 +19,7 @@ SGLang supports various environment variables that can be used to configure its | `SGLANG_FORWARD_UNKNOWN_TOOLS` | Forward unknown tool calls to clients instead of dropping them | `false` (drop unknown tools) | | `SGLANG_REQ_WAITING_TIMEOUT` | Timeout (in seconds) for requests waiting in the queue before being scheduled | `-1` | | `SGLANG_REQ_RUNNING_TIMEOUT` | Timeout (in seconds) for requests running in the decode batch | `-1` | +| `SGLANG_CACHE_DIR` | Cache directory for model weights and other data | `~/.cache/sglang` | ## Performance Tuning @@ -47,6 +48,7 @@ SGLang supports various environment variables that can be used to configure its | `SGLANG_CUSTOM_ALLREDUCE_ALGO` | The algorithm of custom all-reduce. Set to `oneshot` or `1stage` to force use one-shot. Set to `twoshot` or `2stage` to force use two-shot. | `` | | `SGLANG_SKIP_SOFTMAX_PREFILL_THRESHOLD_SCALE_FACTOR` | Skip-softmax threshold scale factor for TRT-LLM prefill attention in flashinfer. `None` means standard attention. See https://arxiv.org/abs/2512.12087 | `None` | | `SGLANG_SKIP_SOFTMAX_DECODE_THRESHOLD_SCALE_FACTOR` | Skip-softmax threshold scale factor for TRT-LLM decode attention in flashinfer. `None` means standard attention. 
See https://arxiv.org/abs/2512.12087 | `None` | +| `SGLANG_USE_SGL_FA3_KERNEL` | Use sgl-kernel implementation for FlashAttention v3 | `true` | ## DeepGEMM Configuration (Advanced Optimization) diff --git a/python/pyproject.toml b/python/pyproject.toml index 8e96b44afe3c..9fab1de2e1e2 100755 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -77,6 +77,7 @@ dependencies = [ "watchfiles", "xgrammar==0.1.32", "smg-grpc-servicer>=0.5.0", + "kernels", ] [[tool.uv.index]] @@ -129,6 +130,10 @@ tracing = [ "opentelemetry-sdk", ] +http2 = [ + "granian>=2.6.0", +] + test = [ "accelerate", "addict", @@ -146,6 +151,7 @@ test = [ "diff-cover", "sentence_transformers", "tabulate", + "granian>=2.6.0", ] dev = ["sglang[test]"] @@ -153,6 +159,7 @@ dev = ["sglang[test]"] all = [ "sglang[diffusion]", "sglang[tracing]", + "sglang[http2]", ] [tool.uv.extra-build-dependencies] @@ -201,3 +208,6 @@ version_file = "sglang/_version.py" git_describe_command = ["python3", "python/tools/get_version_tag.py", "--tag-only"] # Allow editable installs even when .git metadata is not available. fallback_version = "0.0.0.dev0" + +[tool.kernels.dependencies] +"kernels-community/sgl-flash-attn3" = 1 diff --git a/python/sglang/jit_kernel/csrc/gemm/marlin/marlin_template.h b/python/sglang/jit_kernel/csrc/gemm/marlin/marlin_template.h index 6c4112e633fd..651710a963f7 100644 --- a/python/sglang/jit_kernel/csrc/gemm/marlin/marlin_template.h +++ b/python/sglang/jit_kernel/csrc/gemm/marlin/marlin_template.h @@ -484,11 +484,11 @@ __global__ void Marlin( constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; // Scale sizes/strides without act_order - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks / (w_type == host::kFE2M1f ? 
2 : 1) - : 1; + // FP4 (kFE2M1f) uses FP8 scales (1 byte/element), others use FP16 (2 bytes) + int s_gl_stride = prob_n / (w_type == host::kFE2M1f ? 16 : 8); + constexpr int s_sh_stride = 16 * thread_n_blocks / (w_type == host::kFE2M1f ? 16 : 8); + constexpr int s_tb_groups = + !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks ? thread_k_blocks / group_blocks : 1; constexpr int s_sh_stage = s_tb_groups * s_sh_stride; int s_gl_rd_delta = s_gl_stride; @@ -540,8 +540,7 @@ __global__ void Marlin( if constexpr (group_blocks == -1) { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / (w_type == host::kFE2M1f ? 2 : 1) + - s_sh_stride * slice_col + threadIdx.x; + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x; } } auto s_sh_wr = threadIdx.x; @@ -563,15 +562,7 @@ __global__ void Marlin( // we scale a `half2` tile in column-major layout in the former and in // row-major in the latter case. int s_sh_rd; - if constexpr (group_blocks != -1 && w_type == host::kFE2M1f) { - auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - int warp_row = warp_id / n_warps; - - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; - s_sh_rd = s_sh_rd * 2 + warp_row % 2; - - } else if constexpr (group_blocks != -1) + if constexpr (group_blocks != -1) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; else if constexpr (group_blocks == -1 && (m_block_size_8 || (has_zp && !dequant_skip_flop))) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8; @@ -876,7 +867,7 @@ __global__ void Marlin( cur_k += k_iter_size * (k % b_sh_wr_iters); int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / (group_blocks * (w_type == host::kFE2M1f ? 
2 : 1)); + int cur_group_id = k_blocks / group_blocks; int4* sh_s_stage = sh_s + s_sh_stage * pipe; diff --git a/python/sglang/jit_kernel/csrc/gemm/marlin_moe/marlin_template.h b/python/sglang/jit_kernel/csrc/gemm/marlin_moe/marlin_template.h index bf7dcb202301..566fa5f59606 100644 --- a/python/sglang/jit_kernel/csrc/gemm/marlin_moe/marlin_template.h +++ b/python/sglang/jit_kernel/csrc/gemm/marlin_moe/marlin_template.h @@ -626,11 +626,10 @@ __global__ void Marlin( constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; // Scale sizes/strides without act_order - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks / (w_type == host::kFE2M1f ? 2 : 1) - : 1; + int s_gl_stride = prob_n / (w_type == host::kFE2M1f ? 16 : 8); + constexpr int s_sh_stride = 16 * thread_n_blocks / (w_type == host::kFE2M1f ? 16 : 8); + constexpr int s_tb_groups = + !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks ? thread_k_blocks / group_blocks : 1; constexpr int s_sh_stage = s_tb_groups * s_sh_stride; int s_gl_rd_delta = s_gl_stride; @@ -682,8 +681,7 @@ __global__ void Marlin( if constexpr (group_blocks == -1) { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / (w_type == host::kFE2M1f ? 2 : 1) + - s_sh_stride * slice_col + threadIdx.x; + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x; } } auto s_sh_wr = threadIdx.x; @@ -705,15 +703,7 @@ __global__ void Marlin( // we scale a `half2` tile in column-major layout in the former and in // row-major in the latter case. 
int s_sh_rd; - if constexpr (group_blocks != -1 && w_type == host::kFE2M1f) { - auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - int warp_row = warp_id / n_warps; - - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; - s_sh_rd = s_sh_rd * 2 + (warp_row / group_blocks) % 2; - - } else if constexpr (group_blocks != -1) + if constexpr (group_blocks != -1) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; else if constexpr (group_blocks == -1 && (m_block_size_8 || (has_zp && !dequant_skip_flop))) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8; @@ -1038,18 +1028,15 @@ __global__ void Marlin( cur_k += k_iter_size * (k % b_sh_wr_iters); int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / (group_blocks * (w_type == host::kFE2M1f ? 2 : 1)); + int cur_group_id = k_blocks / group_blocks; int4* sh_s_stage = sh_s + s_sh_stage * pipe; if constexpr (w_type_id != host::kFE2M1f.id()) { reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; - } else if constexpr (group_blocks == 1 || thread_k_blocks > 4) { - reinterpret_cast(&frag_s[k % 2])[0] = - reinterpret_cast(sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; } else { reinterpret_cast(&frag_s[k % 2])[0] = - reinterpret_cast(sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride) + k % 2]; + reinterpret_cast(sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; } } } @@ -1243,17 +1230,19 @@ __global__ void Marlin( } } - // Commented out FP4/FP8 scale dequantization since we don't generate - // kFE2M1f kernels to reduce compilation time - // if constexpr (w_type == host::kFE2M1f) { - // int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; - // int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; - // - // dequant_fp8_scales( - // s_quant_0, reinterpret_cast(&frag_s[k2])); - // dequant_fp8_scales( - // s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); - // } +#ifdef 
SGL_MOE_MARLIN_FP4 + // Convert FP8 per-group scales to BF16/FP16 before applying them. + // Required for kFE2M1f (NVFP4): frag_s holds raw float8_e4m3fn bytes; + // without this conversion scale would misinterpret them as + // BF16/FP16, producing NaN/Inf multipliers. + if constexpr (w_type == host::kFE2M1f) { + int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; + int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; + + dequant_fp8_scales(s_quant_0, reinterpret_cast(&frag_s[k2])); + dequant_fp8_scales(s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); + } +#endif // We have the m dimension as the inner loop in order to encourage overlapping // dequantization and matmul operations. diff --git a/python/sglang/jit_kernel/csrc/gemm/marlin_moe/moe_wna16_marlin.cuh b/python/sglang/jit_kernel/csrc/gemm/marlin_moe/moe_wna16_marlin.cuh index 81c021dc8ecc..d89954200c88 100644 --- a/python/sglang/jit_kernel/csrc/gemm/marlin_moe/moe_wna16_marlin.cuh +++ b/python/sglang/jit_kernel/csrc/gemm/marlin_moe/moe_wna16_marlin.cuh @@ -453,7 +453,9 @@ MarlinFuncPtr get_marlin_kernel( COMMON_GET_IF(host::kU4B8) COMMON_GET_IF(host::kU8B128) +#ifdef SGL_MOE_MARLIN_FP4 NVFP4_GET_IF(host::kFE2M1f) +#endif BIGGROUP_GET_IF(host::kFE4M3fn) diff --git a/python/sglang/jit_kernel/csrc/ngram_corpus/ngram.h b/python/sglang/jit_kernel/csrc/ngram_corpus/ngram.h index ee77d7d7c24a..e2202f54692d 100644 --- a/python/sglang/jit_kernel/csrc/ngram_corpus/ngram.h +++ b/python/sglang/jit_kernel/csrc/ngram_corpus/ngram.h @@ -83,4 +83,4 @@ class Ngram { void insertWorker(); }; -} // namespace ngram +} // namespace ngram \ No newline at end of file diff --git a/python/sglang/jit_kernel/csrc/ngram_corpus/result.cpp b/python/sglang/jit_kernel/csrc/ngram_corpus/result.cpp index 07138bf8d17e..404b7a3f22d3 100644 --- a/python/sglang/jit_kernel/csrc/ngram_corpus/result.cpp +++ b/python/sglang/jit_kernel/csrc/ngram_corpus/result.cpp @@ -1,6 +1,5 @@ #include "result.h" -#include #include #include #include @@ -48,81 +47,6 @@ 
Result fillResult(int last_token, int draft_token_num, std::vector& tree, return info; } -std::vector> extractLeafPaths_(const Result& result) { - const auto n = static_cast(result.token.size()); - if (n <= 1) { - return {}; - } - - std::vector parent(n, -1); - std::vector has_child(n, false); - for (int i = 1; i < n; ++i) { - for (int j = i - 1; j >= 0; --j) { - if (result.mask[i * n + j]) { - parent[i] = j; - has_child[j] = true; - break; - } - } - } - - std::vector> paths; - for (int leaf = 1; leaf < n; ++leaf) { - if (has_child[leaf]) { - continue; - } - std::vector path; - for (int cursor = leaf; cursor > 0; cursor = parent[cursor]) { - path.emplace_back(result.token[cursor]); - } - std::reverse(path.begin(), path.end()); - if (path.size() == 1 && path.front() == 0) { - continue; - } - paths.emplace_back(std::move(path)); - } - return paths; -} - -Result buildResultFromLeafPaths_(int last_token, int draft_token_num, const std::vector>& paths) { - std::vector tree(draft_token_num); - const int root = 0; - int cursor = 1; - for (const auto& path : paths) { - int parent = root; - for (const auto token : path) { - auto iter = tree[parent].next.find(token); - if (iter == tree[parent].next.end()) { - if (cursor >= draft_token_num) { - parent = -1; - break; - } - iter = tree[parent].next.insert({token, cursor++}).first; - } - parent = iter->second; - } - if (cursor >= draft_token_num) { - break; - } - } - return fillResult(last_token, draft_token_num, tree, root); -} - -Result combineRootResults_(int last_token, int draft_token_num, const Result& primary, const Result& secondary) { - auto primary_paths = extractLeafPaths_(primary); - auto secondary_paths = extractLeafPaths_(secondary); - std::vector> merged_paths = std::move(primary_paths); - merged_paths.reserve(merged_paths.size() + secondary_paths.size()); - for (const auto& path : secondary_paths) { - if (path.empty()) { - continue; - } - merged_paths.emplace_back(path); - } - - return 
buildResultFromLeafPaths_(last_token, draft_token_num, merged_paths); -} - void Result::truncate(size_t n) { if (n < token.size()) { int full_n = token.size(); diff --git a/python/sglang/jit_kernel/csrc/ngram_corpus/trie.h b/python/sglang/jit_kernel/csrc/ngram_corpus/trie.h index 76707eea1e89..bd555597dd46 100644 --- a/python/sglang/jit_kernel/csrc/ngram_corpus/trie.h +++ b/python/sglang/jit_kernel/csrc/ngram_corpus/trie.h @@ -22,11 +22,6 @@ struct TrieNode { TrieNode* parent; std::list lru; int32_t freq = 0; - // Logical generation of this TrieNode. retireNode() bumps it before the node - // goes back to the pool so stale NodeRefs fail validation after reuse. - // Starts at 1 so that a default-constructed NodeRef (version=0) never - // accidentally resolves to a live node. - uint64_t version = 1; struct CompareByFreq { bool operator()(TrieNode* a, TrieNode* b) const { @@ -36,23 +31,6 @@ struct TrieNode { std::multiset sorted_children; }; -// By-value handle to a logical trie location, cached in MatchState. -// We cannot cache TrieNode* alone across decode steps: squeeze() may evict a -// node, and getNode() may later recycle the same address for a different node. -struct NodeRef { - TrieNode* ptr = nullptr; - uint64_t version = 0; -}; - -// Per-request cached anchors. anchors[d - 1] caches the trie match for the -// length-d suffix ending at the current last token; processed_total_len records -// the full request length covered by those cached anchors. 
-struct MatchState { - uint64_t trie_epoch = 0; - size_t processed_total_len = 0; - std::vector anchors; -}; - class Trie { public: Trie(size_t capacity, const Param& param); @@ -60,72 +38,22 @@ class Trie { void insert(const int32_t* tokens, size_t len); Result buildRecency( - const int32_t* context, - size_t len, - int32_t last_token, - size_t draft_token_num, - const Param& param, - MatchState& state, - size_t total_len) const; + const int32_t* context, size_t len, int32_t last_token, size_t draft_token_num, const Param& param) const; Result buildFrequency( - const int32_t* context, - size_t len, - int32_t last_token, - size_t draft_token_num, - const Param& param, - MatchState& state, - size_t total_len) const; + const int32_t* context, size_t len, int32_t last_token, size_t draft_token_num, const Param& param) const; void squeeze(size_t count); void reset(); private: - // Stateful suffix matcher. If `state` still represents the previous step for - // this request, infer the newly appended suffix from (`context`, `total_len`) - // and advance anchors incrementally; otherwise rebuild the cached anchors from - // `context`. Returns only the suffix matches that are currently expandable. - std::vector> - match(const int32_t* context, size_t len, MatchState& state, size_t total_len) const; - // Recompute all cached anchors from the current tail. After this, for every - // d in [1, min(len, max_trie_depth)], anchors[d - 1] represents the suffix of - // length d ending at context[len - 1]. - void rebuildMatchState_(const int32_t* context, size_t len, MatchState& state, size_t total_len) const; - // Advance the cached anchors by consuming the newly appended suffix one - // token at a time, without re-walking all suffixes from root. 
- bool advanceMatchState_(MatchState& state, const int32_t* tokens, size_t len, size_t total_len) const; - // Check that every non-empty cached NodeRef in MatchState still resolves to - // the same logical trie node under the current trie_epoch_. - bool validateMatchState_(const MatchState& state) const; - // MatchState keeps all live suffix matches, including leaves. This helper - // filters the cached anchors down to the suffixes that currently have children and - // therefore can seed BFS / PROB draft construction. - std::vector> getExpandableAnchors_(const MatchState& state) const; - // Resolve a cached NodeRef back to a live trie node. nullptr means the - // cached location went stale and the caller should rebuild from context. - const TrieNode* resolve(const MatchState& state, const NodeRef& ref) const; - NodeRef rootRef() const { - return NodeRef{root_, root_->version}; - } - NodeRef capture(TrieNode* node) const { - if (node == nullptr) { - return {}; - } - return NodeRef{node, node->version}; - } - void retireNode(TrieNode* node) { - if (node != nullptr) { - ++node->version; - } - } + std::vector> match(const int32_t* context, size_t len) const; TrieNode* getNode() { auto node = node_pool_[--free_node_count_]; - auto version = node->version; node->~TrieNode(); new (node) TrieNode(); - node->version = version; return node; } @@ -136,7 +64,6 @@ class Trie { TrieNode* root_; std::vector path_; Param param_; - uint64_t trie_epoch_ = 1; }; } // namespace ngram diff --git a/python/sglang/jit_kernel/flash_attention.py b/python/sglang/jit_kernel/flash_attention.py new file mode 100644 index 000000000000..633863d0a648 --- /dev/null +++ b/python/sglang/jit_kernel/flash_attention.py @@ -0,0 +1,286 @@ +from typing import Optional, Union + +import torch + +from .flash_attention_v3 import flash_attn_varlen_func as fa3_flash_attn_varlen_func +from .flash_attention_v3 import flash_attn_with_kvcache as fa3_flash_attn_with_kvcache +from .flash_attention_v4 import 
flash_attn_varlen_func as fa4_flash_attn_varlen_func +from .flash_attention_v4 import flash_attn_with_kvcache as fa4_flash_attn_with_kvcache + + +def flash_attn_with_kvcache( + q, + k_cache, + v_cache, + k=None, + v=None, + qv=None, + rotary_cos=None, + rotary_sin=None, + cache_seqlens: Optional[Union[int, torch.Tensor]] = None, + cache_batch_idx: Optional[torch.Tensor] = None, + cache_leftpad: Optional[torch.Tensor] = None, + page_table: Optional[torch.Tensor] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + rotary_seqlens: Optional[torch.Tensor] = None, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + attention_chunk: Optional[int] = None, + softcap=0.0, # 0.0 means deactivated + rotary_interleaved=True, + scheduler_metadata=None, + num_splits=0, # Can be tuned for speed + pack_gqa=None, # Can be tuned for speed + sm_margin=0, # Can be tuned if some SMs are used for communication + return_softmax_lse=False, + sinks=None, + score_mod=None, + aux_tensors=None, + ver=3, +): + """ + If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from + k and v. This is useful for incremental decoding: you can pass in the cached keys/values from + the previous step, and update them with the new keys/values from the current step, and do + attention with the updated cache, all in 1 kernel. + + If you pass in k / v, you must make sure that the cache is large enough to hold the new values. + For example, the KV cache could be pre-allocated with the max sequence length, and you can use + cache_seqlens to keep track of the current sequence lengths of each sequence in the batch. + + Also apply rotary embedding if rotary_cos and rotary_sin are passed in. 
The key @k will be + rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. + If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos + and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. + If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at + indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens). + + See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function. + + Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between + [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. + + Note: Does not support backward pass. + + Arguments: + q: (batch_size, seqlen, nheads, headdim) + k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table, + or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache) + page_block_size must be a multiple of 256. 
+ v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table, + or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache) + k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate + k with k_cache, starting at the indices specified by cache_seqlens. + v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k. + qv [optional]: (batch_size, seqlen, nheads, headdim_v) + rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding + to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16. + rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos. + cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the + KV cache. + cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache. + If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1]. + If the indices are not distinct, and k and v are provided, the values updated in the cache + might come from any of the duplicate indices. + cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0. + page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + attention_chunk: Optional[int]. If not None, splits the query into chunks of this size to save memory. + softcap: float. Anything > 0 activates softcapping attention. + rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in. + If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. 
If False, + rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1 + (i.e. GPT-NeoX style). + num_splits: int. If > 1, split the key/value into this many chunks along the sequence. + If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic + to automatically determine the number of splits. + Don't change this unless you know what you are doing. + return_softmax_lse: bool. Whether to return the logsumexp of the attention scores. + score_mod [optional]: A callable that takes the attention scores and applies a modification. + aux_tensors [optional]: Some score_mods will want to read from global aux_tensors. This is how we thread them through to the inner kernel. + + Return: + out: (batch_size, seqlen, nheads, headdim). + softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + """ + + if ver == 3: + return fa3_flash_attn_with_kvcache( + q, + k_cache, + v_cache, + k=k, + v=v, + qv=qv, + rotary_cos=rotary_cos, + rotary_sin=rotary_sin, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_batch_idx, + cache_leftpad=cache_leftpad, + page_table=page_table, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k_new, + max_seqlen_q=max_seqlen_q, + rotary_seqlens=rotary_seqlens, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + rotary_interleaved=rotary_interleaved, + scheduler_metadata=scheduler_metadata, + num_splits=num_splits, + pack_gqa=pack_gqa, + sm_margin=sm_margin, + return_softmax_lse=return_softmax_lse, + sinks=sinks, + ) + elif ver == 4: + return fa4_flash_attn_with_kvcache( + q, + k_cache, + v_cache, + k=k, + v=v, + qv=qv, + rotary_cos=rotary_cos, + rotary_sin=rotary_sin, + cache_seqlens=cache_seqlens, + 
cache_batch_idx=cache_batch_idx, + cache_leftpad=cache_leftpad, + page_table=page_table, + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + rotary_seqlens=rotary_seqlens, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, + softcap=softcap, + num_splits=num_splits, + pack_gqa=pack_gqa, + sinks=sinks, + score_mod=score_mod, + aux_tensors=aux_tensors, + return_softmax_lse=return_softmax_lse, + ) + else: + raise RuntimeError(f"Unknown flash attention version {ver}") + + +def flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q=None, + max_seqlen_k=None, + seqused_q=None, + seqused_k=None, + page_table=None, + softmax_scale=None, + causal=False, + qv=None, + q_descale=None, + k_descale=None, + v_descale=None, + window_size=(-1, -1), + attention_chunk=0, + softcap=0.0, + num_splits=1, + pack_gqa=None, + sm_margin=0, + return_softmax_lse=False, + sinks=None, + score_mod=None, + aux_tensors=None, + ver=3, +): + + if ver == 3: + return fa3_flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + seqused_q=seqused_q, + seqused_k=seqused_k, + page_table=page_table, + softmax_scale=softmax_scale, + causal=causal, + qv=qv, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + num_splits=num_splits, + pack_gqa=pack_gqa, + sm_margin=sm_margin, + return_softmax_lse=return_softmax_lse, + sinks=sinks, + ) + elif ver == 4: + return fa4_flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + seqused_q=seqused_q, + seqused_k=seqused_k, + page_table=page_table, + softmax_scale=softmax_scale, + causal=causal, + softcap=softcap, + window_size=window_size, + sinks=sinks, + num_splits=num_splits, + 
pack_gqa=pack_gqa, + score_mod=score_mod, + aux_tensors=aux_tensors, + return_softmax_lse=return_softmax_lse, + ) + else: + raise RuntimeError(f"Unknown flash attention version {ver}") diff --git a/python/sglang/jit_kernel/flash_attention_v3.py b/python/sglang/jit_kernel/flash_attention_v3.py new file mode 100644 index 000000000000..23018961d998 --- /dev/null +++ b/python/sglang/jit_kernel/flash_attention_v3.py @@ -0,0 +1,222 @@ +import logging +import os +from typing import Optional, Union + +import torch + +from sglang.jit_kernel.utils import cache_once +from sglang.kernel_api_logging import debug_kernel_api +from sglang.srt.environ import envs + +logger = logging.getLogger(__name__) + +SGL_FA3_KERNEL_REPO = "kernels-community/sgl-flash-attn3" +SGL_FA3_KERNEL_REVISION = "v1" +DEFAULT_FA3_KERNEL_LOCKFILE = "kernels.lock" + + +@cache_once +def _load_fa3_kernels(): + # By default, we use the implementation from sgl-kernel, + # which is expected to be more stable and compatible + if envs.SGLANG_USE_SGL_FA3_KERNEL.get(): + logger.debug( + f"SGLANG_USE_SGL_FA3_KERNEL=True, use sgl-kernel implementation for FlashAttention v3 " + ) + return _load_fa3_kernel_from_sgl() + + # Otherwise, we try to load the kernels from the kernels community cache directory or kernels community repo + lockfile_path = os.path.join( + envs.SGLANG_CACHE_DIR.get(), DEFAULT_FA3_KERNEL_LOCKFILE + ) + + try: + from kernels import get_kernel, load_kernel + + # When the lock file provided, load from the kernel cache directory, + # otherwise, load from the repo, which require download from huggingface hub + # but always works as long as the repo is accessible. 
+ if os.path.exists(lockfile_path): + ops = load_kernel(SGL_FA3_KERNEL_REPO, lockfile_path) + else: + ops = get_kernel(SGL_FA3_KERNEL_REPO, revision=SGL_FA3_KERNEL_REVISION) + + return { + "flash_attn_with_kvcache": ops.flash_attn_with_kvcache, + "flash_attn_varlen_func": ops.flash_attn_varlen_func, + } + except Exception as e: + # When the kernels from the repo or the cache directory cannot be loaded + # we catch the exception and log a warning, and then fallback to the implementation + # from sgl-kernel, which is expected to be less efficient but more compatible. + logger.warning( + f"Rollback to implementation from sgl-kernel since loading FlashAttention v3 " + f"kernels from {SGL_FA3_KERNEL_REPO} with lockfile {lockfile_path} failed: {e}" + ) + return _load_fa3_kernel_from_sgl() + + +def _load_fa3_kernel_from_sgl(): + from sgl_kernel.flash_attn import ( + flash_attn_varlen_func, + flash_attn_with_kvcache, + ) + + return { + "flash_attn_with_kvcache": flash_attn_with_kvcache, + "flash_attn_varlen_func": flash_attn_varlen_func, + } + + +@cache_once +def _is_fa3_supported(device=None) -> bool: + # There some fa3 FYI + # FA3 can fail without a enough shared memory for a some shapes, such as higher + # hidden_dim or some special cases. + # Right now, fa3 is supported for sm80/sm87 and sm86/sm89. The main different + # Between sm80/sm87 and sm86/sm89 is the shared memory size. you can follow the link below for more information + # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x + # And for sgl-kernel right now, we can build fa3 on sm80/sm86/sm89/sm90a. + # That means if you use A100/A*0/L20/L40/L40s/4090 you can use fa3. 
+ return (torch.version.cuda >= "12.3") and ( + torch.cuda.get_device_capability(device)[0] == 9 + or torch.cuda.get_device_capability(device)[0] == 8 + ) + + +@debug_kernel_api +def flash_attn_with_kvcache( + q, + k_cache, + v_cache, + k=None, + v=None, + qv=None, + rotary_cos=None, + rotary_sin=None, + cache_seqlens: Optional[Union[int, torch.Tensor]] = None, + cache_batch_idx: Optional[torch.Tensor] = None, + cache_leftpad: Optional[torch.Tensor] = None, + page_table: Optional[torch.Tensor] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + rotary_seqlens: Optional[torch.Tensor] = None, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + attention_chunk: Optional[int] = None, + softcap=0.0, # 0.0 means deactivated + rotary_interleaved=True, + scheduler_metadata=None, + num_splits=0, # Can be tuned for speed + pack_gqa=None, # Can be tuned for speed + sm_margin=0, # Can be tuned if some SMs are used for communication + return_softmax_lse=False, + sinks=None, +): + if not _is_fa3_supported(): + raise NotImplementedError( + "flash_attn at sgl-kernel is only supported on sm90 and above" + ) + + assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension" + assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension" + + return _load_fa3_kernels()["flash_attn_with_kvcache"]( + q, + k_cache, + v_cache, + k, + v, + qv, + rotary_cos, + rotary_sin, + cache_seqlens, + cache_batch_idx, + cache_leftpad, + page_table, + cu_seqlens_q, + cu_seqlens_k_new, + max_seqlen_q, + rotary_seqlens, + q_descale, + k_descale, + v_descale, + softmax_scale, + causal, + window_size, + attention_chunk, + softcap, + rotary_interleaved, + scheduler_metadata, + num_splits, + pack_gqa, + 
sm_margin, + return_softmax_lse, + sinks, + ) + + +@debug_kernel_api +def flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q=None, + max_seqlen_k=None, + seqused_q=None, + seqused_k=None, + page_table=None, + softmax_scale=None, + causal=False, + qv=None, + q_descale=None, + k_descale=None, + v_descale=None, + window_size=(-1, -1), + attention_chunk=0, + softcap=0.0, + num_splits=1, + pack_gqa=None, + sm_margin=0, + return_softmax_lse=False, + sinks=None, +): + + if not _is_fa3_supported(): + raise NotImplementedError( + "flash_attn at sgl-kernel is only supported on sm90 and above" + ) + + return _load_fa3_kernels()["flash_attn_varlen_func"]( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + seqused_q, + seqused_k, + page_table, + softmax_scale, + causal, + qv, + q_descale, + k_descale, + v_descale, + window_size, + attention_chunk, + softcap, + num_splits, + pack_gqa, + sm_margin, + return_softmax_lse, + sinks, + ) diff --git a/python/sglang/jit_kernel/flash_attention_v4.py b/python/sglang/jit_kernel/flash_attention_v4.py index 0a79614ee075..46b49d177388 100644 --- a/python/sglang/jit_kernel/flash_attention_v4.py +++ b/python/sglang/jit_kernel/flash_attention_v4.py @@ -42,7 +42,6 @@ def flash_attn_varlen_func( score_mod: Optional[Callable] = None, aux_tensors: Optional[list] = None, return_softmax_lse: bool = False, - **_: object, ): if _flash_attn_varlen_func is None: # pragma: no cover raise ImportError( diff --git a/python/sglang/jit_kernel/moe_wna16_marlin.py b/python/sglang/jit_kernel/moe_wna16_marlin.py index e9a8cd25372b..0ddd6ef717d5 100644 --- a/python/sglang/jit_kernel/moe_wna16_marlin.py +++ b/python/sglang/jit_kernel/moe_wna16_marlin.py @@ -31,6 +31,24 @@ def _jit_moe_wna16_marlin_module(dtype: torch.dtype) -> Module: ) +@cache_once +def _jit_moe_wna16_marlin_fp4_module(dtype: torch.dtype) -> Module: + """Separate JIT module with NVFP4 (kFE2M1f) kernel instantiations enabled.""" + args 
= make_cpp_args(dtype) + return load_jit( + "moe_wna16_marlin_fp4", + *args, + cuda_files=["gemm/marlin_moe/moe_wna16_marlin.cuh"], + extra_cuda_cflags=["-DSGL_MOE_MARLIN_FP4"], + cuda_wrappers=[ + ( + "moe_wna16_marlin_gemm", + f"moe_wna16_marlin_gemm<{args}>", + ) + ], + ) + + def _or_empty( t: Optional[torch.Tensor], device: torch.device, dtype: torch.dtype ) -> torch.Tensor: @@ -134,7 +152,11 @@ def moe_wna16_marlin_gemm( b_bias_t = _or_empty(b_bias_or_none, device, a.dtype) global_scale_t = _or_empty(global_scale_or_none, device, a.dtype) - module = _jit_moe_wna16_marlin_module(a.dtype) + is_fp4 = global_scale_or_none is not None and global_scale_or_none.numel() > 0 + if is_fp4: + module = _jit_moe_wna16_marlin_fp4_module(a.dtype) + else: + module = _jit_moe_wna16_marlin_module(a.dtype) module.moe_wna16_marlin_gemm( a, c, diff --git a/python/sglang/jit_kernel/tests/test_flash_attention_3.py b/python/sglang/jit_kernel/tests/test_flash_attention_3.py new file mode 100644 index 000000000000..e4687da9c827 --- /dev/null +++ b/python/sglang/jit_kernel/tests/test_flash_attention_3.py @@ -0,0 +1,1373 @@ +# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/test_flash_attn.py +import itertools +import math +from typing import Optional + +import pytest +import torch +import torch.nn.functional as F +from einops import rearrange, repeat + +apply_rotary_emb = None + +from sglang.test.ci.ci_register import register_cuda_ci + +register_cuda_ci(est_time=120, suite="stage-b-kernel-unit-1-gpu-large") +register_cuda_ci(est_time=900, suite="nightly-kernel-1-gpu", nightly=True) + + +def is_hopper(): + # Only Hopper supports different V headdim + return torch.cuda.get_device_properties(0).major == 9 + + +def is_fa3_supported(device=None) -> bool: + # There some fa3 FYI + # FA3 can fail without a enough shared memory for a some shapes, such as higher + # hidden_dim or some special cases. + # Right now, fa3 is supported for sm80/sm87 and sm86/sm89. 
The main different + # Between sm80/sm87 and sm86/sm89 is the shared memory size. you can follow the link below for more information + # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x + # And for sgl-kernel right now, we can build fa3 on sm80/sm86/sm89/sm90a. + # That means if you use A100/A*0/L20/L40/L40s/4090 you can use fa3. + return (torch.version.cuda >= "12.3") and ( + torch.cuda.get_device_capability(device)[0] == 9 + or torch.cuda.get_device_capability(device)[0] == 8 + ) + + +DISABLE_BACKWARD = True +# For CI test, we close them to True. +# DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE" +# DISABLE_PAGEDKV = os.getenv("FLASH_ATTENTION_DISABLE_PAGEDKV", "FALSE") == "TRUE" +# DISABLE_APPENDKV = os.getenv("FLASH_ATTENTION_DISABLE_APPENDKV", "FALSE") == "TRUE" +# DISABLE_LOCAL = os.getenv("FLASH_ATTENTION_DISABLE_LOCAL", "FALSE") == "TRUE" +# DISABLE_SOFTCAP = os.getenv("FLASH_ATTENTION_DISABLE_SOFTCAP", "FALSE") == "TRUE" +# DISABLE_PACKGQA = os.getenv("FLASH_ATTENTION_DISABLE_PACKGQA", "FALSE") == "TRUE" +# DISABLE_FP16 = os.getenv("FLASH_ATTENTION_DISABLE_FP16", "FALSE") == "TRUE" +# DISABLE_FP8 = ( +# os.getenv("FLASH_ATTENTION_DISABLE_FP8", "FALSE") == "TRUE" +# or torch.cuda.get_device_capability("cuda")[0] < 9 +# ) + +DISABLE_SPLIT = False +DISABLE_PAGEDKV = True +DISABLE_APPENDKV = False +DISABLE_LOCAL = False +DISABLE_SOFTCAP = True +DISABLE_PACKGQA = False +DISABLE_FP16 = True +DISABLE_FP8 = True + + +# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/padding.py +def unpad_input(hidden_states, attention_mask, unused_mask=None): + """ + Arguments: + hidden_states: (batch, seqlen, ...) + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. 
def generate_random_padding_mask(
    max_seqlen, batch_size, device, mode="random", zero_lengths=False
):
    """Build a boolean padding mask with one random valid length per sequence.

    Row ``b`` of the result is True for positions ``0 .. length_b - 1`` and
    False afterwards.

    Arguments:
        max_seqlen: int, number of columns of the mask.
        batch_size: int, number of rows of the mask.
        device: device on which the mask is created.
        mode: how the per-row lengths are drawn:
            "full"   - every length equals max_seqlen,
            "random" - uniform in [max(lo, max_seqlen - 20), max_seqlen], where
                       lo is 0 if zero_lengths else 1,
            "third"  - uniform in [max_seqlen // 3, max_seqlen].
        zero_lengths: if True, force length 0 for every 5th row (0, 5, 10, ...)
            and for the last row.
    Return:
        padding_mask: (batch_size, max_seqlen), bool.
    """
    assert mode in ["full", "random", "third"]
    if mode == "full":
        lengths = torch.full(
            (batch_size, 1), max_seqlen, device=device, dtype=torch.int32
        )
    elif mode == "random":
        lengths = torch.randint(
            max(0 if zero_lengths else 1, max_seqlen - 20),
            max_seqlen + 1,
            (batch_size, 1),
            device=device,
        )
    elif mode == "third":
        lengths = torch.randint(
            max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device
        )

    # Guard on batch_size: indexing the last row of an empty tensor raises.
    if zero_lengths and batch_size > 0:
        # Generate zero-lengths every 5 batches and the last batch. A strided
        # slice does this in one write instead of one tensor write per row.
        lengths[::5] = 0
        lengths[-1] = 0

    # (1, max_seqlen) < (batch_size, 1) broadcasts to (batch_size, max_seqlen).
    positions = torch.arange(max_seqlen, device=device).unsqueeze(0)
    return positions < lengths
def attention_ref(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    key_leftpad=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    qv=None,
    q_descale=None,
    k_descale=None,
    v_descale=None,
    window_size=(-1, -1),  # -1 means infinite window size
    sink_token_length=0,
    sinks: Optional[torch.Tensor] = None,
    softcap=0.0,
    upcast=True,
    reorder_ops=False,
    intermediate_dtype=None,
):
    """
    Reference (pure PyTorch) attention used to check the fused kernels.

    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k: (batch_size, seqlen_k, nheads_k, head_dim); heads are repeated
            nheads // nheads_k times to match q (MQA/GQA).
        v: (batch_size, seqlen_k, nheads_k, head_dim_v); repeated like k.
        qv: (batch_size, seqlen_q, nheads, head_dim_v). If given, qv @ v^T is
            added to the scores and the softmax scale uses head_dim + head_dim_v.
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
        key_leftpad: forwarded to construct_local_mask; only used when a
            window (or causal) mask is active.
        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
        causal: whether to apply causal masking (forces window_size[1] = 0)
        q_descale, k_descale, v_descale: (batch_size, nheads_k) scales
            multiplied into q (and qv), k and v respectively before the scores
            are computed.
        window_size: (left, right) local attention window; -1 means infinite.
        sink_token_length: forwarded to construct_local_mask.
        sinks: (nheads,) per-head sink logits that take part in the softmax
            normalizer but contribute no value to the output.
        softcap: if > 0, scores are replaced by tanh(scores / softcap) * softcap.
        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
            output back to fp16/bf16.
        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
            without changing the math. This is to estimate the numerical error from operation
            reordering.
        intermediate_dtype: if given, the attention probabilities are cast to
            this dtype and back before being multiplied with v.
    Output:
        output: (batch_size, seqlen_q, nheads, head_dim_v)
        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax
            probabilities; the dropout mask is applied only on the output
            path, not to this returned tensor.
    """
    if causal:
        window_size = (window_size[0], 0)
    dtype_og = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
        qv = qv.float() if qv is not None else None
    if q_descale is not None:
        # q_descale is per KV head; expand it to one scale per query head.
        q_descale = repeat(q_descale, "b h -> b 1 (h g) 1", g=q.shape[2] // k.shape[2])
        q = (q.float() * q_descale).to(q.dtype)
        qv = (qv.float() * q_descale).to(qv.dtype) if qv is not None else None
    if k_descale is not None:
        k = (k.float() * rearrange(k_descale, "b h -> b 1 h 1")).to(dtype=k.dtype)
    if v_descale is not None:
        v = (v.float() * rearrange(v_descale, "b h -> b 1 h 1")).to(dtype=v.dtype)
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
    # Expand KV heads so that k/v have the same number of heads as q.
    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
    d = q.shape[-1]
    dv = v.shape[-1]
    softmax_scale = 1.0 / math.sqrt(d if qv is None else d + dv)
    if not reorder_ops:
        scores = torch.einsum("bthd,bshd->bhts", q * softmax_scale, k)
    else:
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
    if qv is not None:
        scores = scores + torch.einsum("bthd,bshd->bhts", qv * softmax_scale, v)
    if softcap > 0:
        scores = torch.tanh(scores / softcap) * softcap
    if key_padding_mask is not None:
        scores.masked_fill_(
            rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")
        )
    if window_size[0] >= 0 or window_size[1] >= 0:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            sink_token_length,
            query_padding_mask,
            key_padding_mask,
            key_leftpad=key_leftpad,
            device=q.device,
        )
        scores.masked_fill_(local_mask, float("-inf"))
    if attn_bias is not None:
        scores = scores + attn_bias
    if sinks is None:
        attention = torch.softmax(scores, dim=-1).to(v.dtype)
    else:
        # Softmax with one extra per-head "sink" logit: it enters the running
        # max and the normalizer, but has no column in the attention matrix.
        scores_fp32 = scores.to(torch.float32)
        logits_max = torch.amax(scores_fp32, dim=-1, keepdim=True)
        sinks = rearrange(sinks, "h -> h 1 1")
        logits_or_sinks_max = torch.maximum(sinks, logits_max)
        unnormalized_scores = torch.exp(scores_fp32 - logits_or_sinks_max)
        normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + torch.exp(
            sinks - logits_or_sinks_max
        )
        attention = (unnormalized_scores / normalizer).to(v.dtype)
    # We want to mask here so that the attention matrix doesn't have any NaNs
    # Otherwise we'll get NaN in dV
    if query_padding_mask is not None:
        attention = attention.masked_fill(
            rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0
        )
    # Without this we might get NaN in dv
    if key_padding_mask is not None:
        attention = attention.masked_fill(
            rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0
        )
    # Some rows might be completely masked out so we fill them with zero instead of NaN
    if window_size[0] >= 0 or window_size[1] >= 0:
        attention = attention.masked_fill(
            torch.all(local_mask, dim=-1, keepdim=True), 0.0
        )
    dropout_scaling = 1.0 / (1 - dropout_p)
    # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
    # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
    if dropout_mask is not None:
        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
    else:
        attention_drop = attention
    if intermediate_dtype is not None:
        # Round-trip through intermediate_dtype to mimic its rounding.
        attention_drop = attention_drop.to(intermediate_dtype).to(attention_drop.dtype)
    # The dropout rescale is applied to v rather than to the probabilities.
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
    if query_padding_mask is not None:
        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
v: (batch_size, seqlen_k, nheads_k, d) + query_padding_mask: (batch_size, seqlen), bool + key_padding_mask: (batch_size, seqlen), bool + """ + assert not (kvpacked and qkvpacked) + batch_size, seqlen_q, nheads, d = q.shape + _, seqlen_k, nheads_k, _ = k.shape + assert k.shape == (batch_size, seqlen_k, nheads_k, d) + assert v.shape == (batch_size, seqlen_k, nheads_k, d) + if query_unused_mask is not None or key_unused_mask is not None: + assert not kvpacked + assert not qkvpacked + + if query_padding_mask is not None: + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input( + q, + query_padding_mask, + query_unused_mask, + ) + output_pad_fn = lambda output_unpad: pad_input( + output_unpad, indices_q, batch_size, seqlen_q + ) + else: + q_unpad = rearrange(q, "b s h d -> (b s) h d") + cu_seqlens_q = torch.arange( + 0, + (batch_size + 1) * seqlen_q, + step=seqlen_q, + dtype=torch.int32, + device=q_unpad.device, + ) + seqused_q = None + max_seqlen_q = seqlen_q + output_pad_fn = lambda output_unpad: rearrange( + output_unpad, "(b s) h d -> b s h d", b=batch_size + ) + + if key_padding_mask is not None: + k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input( + k, key_padding_mask, key_unused_mask + ) + v_unpad, _, _, _, _ = unpad_input(v, key_padding_mask, key_unused_mask) + else: + k_unpad = rearrange(k, "b s h d -> (b s) h d") + v_unpad = rearrange(v, "b s h d -> (b s) h d") + cu_seqlens_k = torch.arange( + 0, + (batch_size + 1) * seqlen_k, + step=seqlen_k, + dtype=torch.int32, + device=k_unpad.device, + ) + seqused_k = None + max_seqlen_k = seqlen_k + + if qkvpacked: + assert (query_padding_mask == key_padding_mask).all() + assert nheads == nheads_k + qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) + qkv = torch.stack([q, k, v], dim=2) + if query_padding_mask is not None: + dqkv_pad_fn = lambda dqkv_unpad: pad_input( + dqkv_unpad, indices_q, batch_size, seqlen_q + ) + else: + dqkv_pad_fn = lambda dqkv_unpad: rearrange( + 
dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size + ) + return ( + qkv_unpad.detach().requires_grad_(), + cu_seqlens_q, + max_seqlen_q, + qkv.detach().requires_grad_(), + output_pad_fn, + dqkv_pad_fn, + ) + elif kvpacked: + kv_unpad = torch.stack([k_unpad, v_unpad], dim=1) + kv = torch.stack([k, v], dim=2) + dq_pad_fn = output_pad_fn + if key_padding_mask is not None: + dkv_pad_fn = lambda dkv_unpad: pad_input( + dkv_unpad, indices_k, batch_size, seqlen_k + ) + else: + dkv_pad_fn = lambda dkv_unpad: rearrange( + dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size + ) + return ( + q_unpad.detach().requires_grad_(), + kv_unpad.detach().requires_grad_(), + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q.detach().requires_grad_(), + kv.detach().requires_grad_(), + output_pad_fn, + dq_pad_fn, + dkv_pad_fn, + ) + else: + dq_pad_fn = output_pad_fn + if key_padding_mask is not None: + dk_pad_fn = lambda dk_unpad: pad_input( + dk_unpad, indices_k, batch_size, seqlen_k + ) + else: + dk_pad_fn = lambda dk_unpad: rearrange( + dk_unpad, "(b s) h d -> b s h d", b=batch_size + ) + return ( + q_unpad.detach().requires_grad_(), + k_unpad.detach().requires_grad_(), + v_unpad.detach().requires_grad_(), + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + max_seqlen_q, + max_seqlen_k, + q.detach().requires_grad_(), + k.detach().requires_grad_(), + v.detach().requires_grad_(), + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) + + +@pytest.mark.skipif( + not is_fa3_supported(), + reason="flash_attn at sgl-kernel is only supported on sm90 or sm80", +) +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.parametrize( + "dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else []) +) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", 
["mha"]) +@pytest.mark.parametrize("has_sink", [False, True]) +# @pytest.mark.parametrize("has_sink", [False]) +@pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else [])) +# @pytest.mark.parametrize("new_kv", [True]) +# @pytest.mark.parametrize( +# "causal,local", +# [(False, False), (True, False)] + ([(False, True)] if not DISABLE_LOCAL else []), +# ) +# @pytest.mark.parametrize("causal,local", [(False, False), (True, False)]) +@pytest.mark.parametrize("causal,local", [(False, False)]) +@pytest.mark.parametrize( + "seqlen_new_eq_seqlen_q", [True, False] if not DISABLE_APPENDKV else [True] +) +# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True]) +# @pytest.mark.parametrize("has_rotary_seqlens", [False, True]) +@pytest.mark.parametrize("has_rotary_seqlens", [False]) +@pytest.mark.parametrize( + "rotary_interleaved", [False, True] if not DISABLE_APPENDKV else [False] +) +# @pytest.mark.parametrize("rotary_interleaved", [True]) +@pytest.mark.parametrize( + "rotary_fraction", + ( + [0.0, 0.5, 1.0] + if (not DISABLE_APPENDKV) and (apply_rotary_emb is not None) + else [0.0] + ), +) +# @pytest.mark.parametrize("rotary_fraction", [0.0]) +@pytest.mark.parametrize( + "page_size", [None] + ([1, 4, 128] if not DISABLE_PAGEDKV else []) +) +# @pytest.mark.parametrize("page_size", [None]) +# @pytest.mark.parametrize("has_leftpad", [False, True]) +@pytest.mark.parametrize("has_leftpad", [False]) +# @pytest.mark.parametrize("has_batch_idx", [False, True]) +@pytest.mark.parametrize("has_batch_idx", [False]) +# @pytest.mark.parametrize("varlen_q", [False, True]) +@pytest.mark.parametrize("varlen_q", [False]) +# @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +@pytest.mark.parametrize("d", [64]) +# @pytest.mark.parametrize("d", [192]) 
+@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 128), + (1, 339), + (3, 1024), + (64, 800), + (64, 256), + (3, 799), + (64, 2048), + (16, 20000), + # (1, 128 * 1024), + # (16, 128 * 1024), + (128, 128), + (256, 512), # To test appending KV with more than 1 block + (2048, 3577), # Enough tile to test persistent scheduler + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) +def test_flash_attn_kvcache( + seqlen_q, + seqlen_k, + d, + varlen_q, + has_batch_idx, + has_leftpad, + page_size, + rotary_fraction, + rotary_interleaved, + has_rotary_seqlens, + seqlen_new_eq_seqlen_q, + causal, + local, + new_kv, + mha_type, + dtype, + has_sink, +): + from sgl_kernel.flash_attn import flash_attn_with_kvcache + + if page_size is not None and seqlen_k % page_size != 0: + pytest.skip() + if seqlen_q > seqlen_k and new_kv: + pytest.skip() + if not new_kv and rotary_fraction > 0.0: + pytest.skip() + if rotary_fraction == 0.0 and has_rotary_seqlens: + pytest.skip() + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 5 + # batch_size = 1 + batch_size_cache = batch_size if not has_batch_idx else batch_size * 2 + nheads = 6 + # nheads = 1 + # rotary_dim must be a multiple of 16, and must be <= d + rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16 + nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) + assert nheads % nheads_k == 0 + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + + if has_sink: + sinks = torch.randn(nheads, dtype=torch.bfloat16, device=device) + else: + sinks = None + + if dtype == torch.float8_e4m3fn or not is_hopper(): + # for fp8 and ampere arch, we not support v head dim != qk head dim + dv_vals = [d] + for dv in dv_vals: + has_qv = d == 64 and dv >= 256 + q = ( + torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + .to(dtype) + .to(dtype_ref) + 
) + if has_qv: + qv = ( + torch.randn( + batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + else: + qv = None + if varlen_q: + query_padding_mask = generate_random_padding_mask( + seqlen_q, batch_size, device, mode="random" + ) + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input( + q, query_padding_mask + ) + output_pad_fn = lambda output_unpad: pad_input( + output_unpad, indices_q, batch_size, seqlen_q + ) + qv_unpad = ( + rearrange(qv, "b s ... -> (b s) ...")[indices_q] if has_qv else None + ) + else: + query_padding_mask = None + q_unpad = q + qv_unpad = qv + cu_seqlens_q, max_seqlen_q = None, None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + + seqlen_new = ( + seqlen_q + if seqlen_new_eq_seqlen_q + else torch.randint(1, seqlen_q + 1, (1,)).item() + ) + cu_seqlens_k_new = None + key_new_padding_mask = None + if new_kv: + k = ( + torch.randn( + batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + v = ( + torch.randn( + batch_size, seqlen_new, nheads_k, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + if varlen_q: # k & v are also varlen + key_new_padding_mask = generate_random_padding_mask( + seqlen_new, batch_size, device, mode="random" + ) + k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input( + k, key_new_padding_mask + ) + v_unpad, *rest = unpad_input(v, key_new_padding_mask) + else: + k_unpad, v_unpad = k, v + else: + k, v, k_unpad, v_unpad = None, None, None, None + if page_size is None: + k_cache = ( + torch.randn( + batch_size_cache, + seqlen_k, + nheads_k, + d, + device=device, + dtype=dtype_ref, + ) + .to(dtype) + .to(dtype_ref) + ) + v_cache = ( + torch.randn( + batch_size_cache, + seqlen_k, + nheads_k, + dv, + device=device, + dtype=dtype_ref, + ) + .to(dtype) + .to(dtype_ref) + ) + 
page_table = None + else: + ( + k_cache, + v_cache, + page_table, + k_cache_paged, + v_cache_paged, + num_blocks, + ) = _generate_block_kvcache( + seqlen_k, + page_size, + batch_size_cache, + nheads_k, + d, + dv, + device, + dtype, + dtype_ref, + ) + cache_seqlens = torch.randint( + 0 if new_kv else 1, + # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough + ( + ( + seqlen_k + - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + + 1 + ) + if new_kv + else (seqlen_k + 1) + ), + (batch_size,), + dtype=torch.int32, + device=device, + ) + if has_leftpad: + cache_leftpad = torch.cat( + [ + ( + torch.randint( + 0, + cache_seqlens[i].item(), + (1,), + dtype=torch.int32, + device=device, + ) + if cache_seqlens[i].item() > 0 + else torch.zeros(1, dtype=torch.int32, device=device) + ) + for i in range(batch_size) + ] + ) + else: + cache_leftpad = None + if has_batch_idx: + cache_batch_idx = torch.randperm( + batch_size_cache, dtype=torch.int32, device=device + )[:batch_size] + else: + cache_batch_idx = None + arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s") + cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + if not new_kv: + key_padding_mask = arange < cache_seqlens_expanded + else: + k_new_seqlens = ( + key_new_padding_mask.sum(-1, keepdims=True) if varlen_q else seqlen_new + ) + key_padding_mask = arange < cache_seqlens_expanded + k_new_seqlens + if has_leftpad: + key_padding_mask = torch.logical_and( + key_padding_mask, + arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k), + ) + # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) + rotary_seqlens = cache_seqlens if not has_rotary_seqlens else cache_seqlens // 2 + if rotary_dim > 0: + angle = ( + torch.rand( + seqlen_k if page_size is None else num_blocks * page_size, + rotary_dim // 2, + device=device, + ) + * 2 + * math.pi + ) + cos = torch.cos(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) + sin = 
torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) + if causal or local: + q_ro = apply_rotary_emb( + q, + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, + ) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=seqlen_q, + ) + # q_ro = q + k_ro = apply_rotary_emb( + k, + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, + ) + else: + cos, sin = None, None + q_ro, k_ro = q, k + # k_cache[:, 64:] = -1 + k_cache_ref = ( + k_cache if not has_batch_idx else k_cache[cache_batch_idx] + ).clone() + v_cache_ref = ( + v_cache if not has_batch_idx else v_cache[cache_batch_idx] + ).clone() + if new_kv: + update_mask = torch.logical_and( + cache_seqlens_expanded <= arange, + arange < cache_seqlens_expanded + k_new_seqlens, + ) + k_to_update = rearrange(k_ro, "b s ... -> (b s) ...") + v_to_update = rearrange(v, "b s ... 
-> (b s) ...") + if varlen_q: + k_to_update = k_to_update[indices_k] + v_to_update = v_to_update[indices_k] + k_cache_ref[update_mask] = k_to_update + v_cache_ref[update_mask] = v_to_update + k_cache_rep = repeat( + k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k + ) + v_cache_rep = repeat( + v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k + ) + out_ref, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv, + window_size=window_size, + key_leftpad=cache_leftpad, + sinks=sinks, + ) + out_pt, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv, + window_size=window_size, + upcast=False, + reorder_ops=True, + key_leftpad=cache_leftpad, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + sinks=sinks, + ) + q = q.to(dtype) + q_unpad = q_unpad.to(dtype) if varlen_q else None + k_cache = k_cache.to(dtype) + v_cache = v_cache.to(dtype) + k_cache_paged = k_cache_paged.to(dtype) if page_size is not None else None + v_cache_paged = v_cache_paged.to(dtype) if page_size is not None else None + k = k.to(dtype) if k is not None else None + v = v.to(dtype) if v is not None else None + k_unpad = k_unpad.to(dtype) if k_unpad is not None else None + v_unpad = v_unpad.to(dtype) if v_unpad is not None else None + qv = qv.to(dtype) if qv is not None else None + qv_unpad = qv_unpad.to(dtype) if (varlen_q and qv is not None) else None + cos = cos.to(dtype) if cos is not None else None + sin = sin.to(dtype) if sin is not None else None + k_cache_saved = k_cache.clone() if page_size is None else k_cache_paged.clone() + v_cache_saved = v_cache.clone() if page_size is None else v_cache_paged.clone() + num_splits_vals = [1, 0] if not DISABLE_SPLIT else [1] + precompute_metadata_vals = [False] + for num_splits, precompute_metadata in itertools.product( + num_splits_vals, precompute_metadata_vals + ): + 
scheduler_metadata = None + # Repeat to test metadata reuse + for _ in range(1 if not precompute_metadata else 2): + if page_size is None: + k_cache.copy_(k_cache_saved) + v_cache.copy_(v_cache_saved) + else: + k_cache_paged.copy_(k_cache_saved) + v_cache_paged.copy_(v_cache_saved) + out, lse, *rest = flash_attn_with_kvcache( + q if not varlen_q else q_unpad, + k_cache if page_size is None else k_cache_paged, + v_cache if page_size is None else v_cache_paged, + k if not new_kv or not varlen_q else k_unpad, + v if not new_kv or not varlen_q else v_unpad, + qv=qv if not varlen_q else qv_unpad, + rotary_cos=cos, + rotary_sin=sin, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_batch_idx, + cache_leftpad=cache_leftpad, + page_table=page_table, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k_new, + max_seqlen_q=max_seqlen_q, + rotary_seqlens=rotary_seqlens, + causal=causal, + window_size=window_size, + rotary_interleaved=rotary_interleaved, + scheduler_metadata=scheduler_metadata, + num_splits=num_splits, + return_softmax_lse=True, + sinks=sinks, + ) + if varlen_q: + out = output_pad_fn(out) + # out = flash_attn_with_kvcache( + # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size + # ) + # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) + # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) + # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) + # probs = torch.softmax(qk, dim=-1) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most 
def _generate_block_kvcache(
    seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref
):
    """Create a random paged KV cache together with its contiguous view.

    The block pool holds three times as many pages as ``batch_size`` sequences
    of length ``seqlen_k`` need, and a random permutation of all page ids is
    split evenly across the batch to form the page table.

    Returns:
        (k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks)
        where ``k_cache`` / ``v_cache`` are the per-sequence caches gathered
        through ``page_table`` and truncated to ``seqlen_k`` tokens, and the
        ``*_paged`` tensors are the raw (num_blocks, page_size, nheads_k, dim)
        pools.
    """
    pages_per_sequence = math.ceil(seqlen_k / page_size)
    num_blocks = pages_per_sequence * batch_size * 3

    def _random_pool(head_dim):
        # Round-trip through `dtype` so the values are representable in it.
        pool = torch.randn(
            num_blocks, page_size, nheads_k, head_dim, device=device, dtype=dtype_ref
        )
        return pool.to(dtype).to(dtype_ref)

    # Draw order (K pool, V pool, permutation) is kept so seeded runs match.
    k_cache_paged = _random_pool(d)
    v_cache_paged = _random_pool(dv)
    page_ids = torch.randperm(num_blocks, dtype=torch.int32, device=device)
    page_table = rearrange(page_ids, "(b nblocks) -> b nblocks", b=batch_size)

    def _gather_contiguous(pool):
        # Look up each sequence's pages, then fuse (page, slot) into one token axis.
        pages = pool[page_table.flatten()]
        tokens = rearrange(
            pages,
            "(b nblocks) block_size ... -> b (nblocks block_size) ...",
            b=batch_size,
        )
        return tokens[:, :seqlen_k]

    k_cache = _gather_contiguous(k_cache_paged)
    v_cache = _gather_contiguous(v_cache_paged)
    return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks
[False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [False]) +@pytest.mark.parametrize("add_unused_qkv", [False, True]) +# @pytest.mark.parametrize("add_unused_qkv", [True]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) +# @pytest.mark.parametrize("d", [64, 96, 128]) +# @pytest.mark.parametrize("d", COMPILED_HDIMS) +@pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 1), + (1, 3), + (2, 1), + (511, 1), + (3, 513), + (64, 128), + (128, 128), + (256, 256), + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (307, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (2048, 2048), + ], +) +def test_flash_attn_varlen_output( + seqlen_q, + seqlen_k, + d, + add_unused_qkv, + causal, + local, + softcap, + deterministic, + has_qv, + mha_type, + dtype, + has_sink, +): + from sglang.jit_kernel.flash_attention import flash_attn_varlen_func + + device = "cuda" + # set seed + torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)) + # batch_size = 40 + # nheads = 16 + batch_size = 9 if seqlen_q <= 2048 else 2 + nheads = 6 + # batch_size = 2 + # nheads = 1 + nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + if dtype == torch.float8_e4m3fn: + dv_vals = [d] + for dv in dv_vals: + q_ref = torch.randn( + batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref + ) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. 
+ q_ref = (q_ref * softcap / 4).detach().requires_grad_() + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) + v_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) + if has_qv: + qv_ref = ( + torch.randn( + batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + else: + qv_ref = None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + + if has_sink: + sinks = torch.randn(nheads, dtype=torch.bfloat16, device=device) + else: + sinks = None + + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [ + torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) + * 2 + for _ in range(3) + ] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach() if has_qv else None + query_padding_mask = generate_random_padding_mask( + seqlen_q, batch_size, device, mode="random", zero_lengths=False + ) + key_padding_mask = generate_random_padding_mask( + seqlen_k, batch_size, device, mode="random", zero_lengths=True + ) + + def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): + if add_unused: + another_mask = generate_random_padding_mask(max_seq_len, bs, device) + attn_mask = torch.logical_and(padding_mask, another_mask) + unused_mask = torch.logical_xor( + torch.logical_or(padding_mask, another_mask), attn_mask + ) + else: + attn_mask = padding_mask + unused_mask = None + return attn_mask, unused_mask + + query_padding_mask, query_unused_mask = _gen_unused_masks( + query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device 
+ ) + key_padding_mask, key_unused_mask = _gen_unused_masks( + key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device + ) + + ( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) = generate_qkv( + q, + k, + v, + query_padding_mask, + key_padding_mask, + kvpacked=False, + query_unused_mask=query_unused_mask, + key_unused_mask=key_unused_mask, + ) + q_unpad, k_unpad, v_unpad = [ + x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad) + ] + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + softcap=softcap, + sinks=sinks, + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + softcap=softcap, + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + sinks=sinks, + ) + + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + + if query_unused_mask is not None: + q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + + pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] + num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + out_unpad, lse, *rest = flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + 
max_seqlen_q, + max_seqlen_k, + seqused_q=seqused_q, + seqused_k=seqused_k, + causal=causal, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + softcap=softcap, + return_softmax_lse=True, + sinks=sinks, + ) + out = output_pad_fn(out_unpad) + if query_unused_mask is not None: + out.masked_fill_(q_zero_masking, 0.0) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + + # Check that FlashAttention's numerical error is at most 3x the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= rtol * ( + out_pt - out_ref + ).abs().max().item() + fwd_atol + + if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not has_qv: + g_unpad = torch.randn_like(out_unpad) + do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2) + dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad( + out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad + ) + dq = dq_pad_fn(dq_unpad) + dk = dk_pad_fn(dk_unpad) + dv = dk_pad_fn(dv_unpad) + if key_unused_mask is not None: + k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1") + dk.masked_fill_(k_zero_masking, 0.0) + dv.masked_fill_(k_zero_masking, 0.0) + if query_unused_mask is not None: + dq.masked_fill_(q_zero_masking, 0.0) + # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") + # assert (softmax_d - do_o).abs().max().item() <= 1e-5 + # assert dq_accum.abs().max().item() == 0.0 + g = output_pad_fn(g_unpad) + + # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) + dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - 
dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + + if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not has_qv: + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dq - dq_ref).abs().max().item() <= rtol * ( + dq_pt - dq_ref + ).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dk - dk_ref).abs().max().item() <= rtol * ( + dk_pt - dk_ref + ).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dv - dv_ref).abs().max().item() <= rtol * ( + dv_pt - dv_ref + ).abs().max().item() + dv_atol + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/python/sglang/jit_kernel/tests/test_flash_attention_4.py b/python/sglang/jit_kernel/tests/test_flash_attention_4.py index e1453b8f2323..81b0f0b23d62 100644 --- a/python/sglang/jit_kernel/tests/test_flash_attention_4.py +++ b/python/sglang/jit_kernel/tests/test_flash_attention_4.py @@ -11,7 +11,7 @@ import torch.nn.functional as F from einops import rearrange, repeat -from sglang.jit_kernel.flash_attention_v4 import flash_attn_varlen_func +from sglang.jit_kernel.flash_attention import flash_attn_varlen_func from sglang.test.ci.ci_register import register_cuda_ci register_cuda_ci(est_time=120, suite="stage-b-kernel-unit-1-gpu-large") @@ 
-826,6 +826,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): sinks=learnable_sink, # FA4 uses learnable_sink, not sinks pack_gqa=pack_gqa, return_softmax_lse=True, + ver=4, ) out = output_pad_fn(out_unpad) if query_unused_mask is not None: @@ -1384,6 +1385,7 @@ def test_flash_attn_kvcache( softcap=0.0, pack_gqa=None, return_softmax_lse=True, + ver=4, ) if varlen_q: out = output_pad_fn(out) diff --git a/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py b/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py index 9c30a9798283..31372e2e16ce 100644 --- a/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py +++ b/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py @@ -5,27 +5,13 @@ import torch +from sglang.jit_kernel.flash_attention import flash_attn_varlen_func from sglang.multimodal_gen.runtime.layers.utils import register_custom_op from sglang.multimodal_gen.runtime.managers.forward_context import get_forward_context from sglang.multimodal_gen.runtime.platforms import ( AttentionBackendEnum, ) -try: - from sgl_kernel.flash_attn import flash_attn_varlen_func - - from sglang.jit_kernel.flash_attention_v4 import ( - flash_attn_varlen_func as flash_attn_varlen_func_fa4, - ) - - def flash_attn_func(*args, ver: int = 3, **kwargs): - if ver == 4: - return flash_attn_varlen_func_fa4(*args, **kwargs) - return flash_attn_varlen_func(*args, **kwargs) - -except ImportError as e: - raise e - def maybe_contiguous(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]: return x.contiguous() if x is not None and x.stride(-1) != 1 else x @@ -207,7 +193,7 @@ def flash_attn_varlen_func_op( "flash_attn_varlen_func_op is out-only op; return_softmax_lse must be False. " "Use flash_attn_varlen_func_op_lse for (out, lse)." 
) - return flash_attn_func( + return flash_attn_varlen_func( q, k, v, @@ -271,7 +257,7 @@ def flash_attn_varlen_func_op_lse( "flash_attn_varlen_func_op_lse is out+lse op; return_softmax_lse must be True. " "Use flash_attn_varlen_func_op for out-only." ) - return flash_attn_func( + return flash_attn_varlen_func( q, k, v, @@ -409,7 +395,7 @@ def forward( # - fa_ver == 3: call python function (can return Tensor or (Tensor, Tensor) depending on flag) # - fa_ver == 4: call custom ops with FIXED return schema if fa_ver == 3: - flash_attn_op = flash_attn_func + flash_attn_op = flash_attn_varlen_func output = flash_attn_op( q=query, k=key, diff --git a/python/sglang/multimodal_gen/test/run_suite.py b/python/sglang/multimodal_gen/test/run_suite.py index 700d4d6b8b18..a6fef42e8a69 100644 --- a/python/sglang/multimodal_gen/test/run_suite.py +++ b/python/sglang/multimodal_gen/test/run_suite.py @@ -174,12 +174,14 @@ def collect_test_items(files, filter_expr=None): return test_items -def run_pytest(files, filter_expr=None): +def run_pytest(files, filter_expr=None, exitfirst=False): if not files: print("No files to run.") return 0 base_cmd = [sys.executable, "-m", "pytest", "-s", "-v"] + if exitfirst: + base_cmd.append("-x") # Add pytest -k filter if provided if filter_expr: @@ -349,7 +351,8 @@ def main(): print(f"Running {len(my_items)} items in this shard: {', '.join(my_items)}") # 4. 
execute with the specific test items - exit_code = run_pytest(my_items) + # Fast-fail: stop on first failure unless --continue-on-error is set + exit_code = run_pytest(my_items, exitfirst=not args.continue_on_error) # Print tests again at the end for visibility msg = "\n" + tabulate.tabulate(rows, headers=headers, tablefmt="psql") + "\n" diff --git a/python/sglang/multimodal_gen/test/scripts/gen_diffusion_ci_outputs.py b/python/sglang/multimodal_gen/test/scripts/gen_diffusion_ci_outputs.py index 645a9cac5486..f36e803dd11e 100755 --- a/python/sglang/multimodal_gen/test/scripts/gen_diffusion_ci_outputs.py +++ b/python/sglang/multimodal_gen/test/scripts/gen_diffusion_ci_outputs.py @@ -17,7 +17,12 @@ from pathlib import Path from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger -from sglang.multimodal_gen.test.run_suite import SUITES, collect_test_items, run_pytest +from sglang.multimodal_gen.test.run_suite import ( + SUITES, + _maybe_pin_update_weights_model_pair, + collect_test_items, + run_pytest, +) logger = init_logger(__name__) @@ -95,6 +100,7 @@ def main(): # Get files from suite (same as run_suite.py) suite_files_rel = SUITES[args.suite] + _maybe_pin_update_weights_model_pair(suite_files_rel) suite_files_abs = [] for f_rel in suite_files_rel: f_abs = target_dir / f_rel diff --git a/python/sglang/multimodal_gen/test/server/ascend/testcase_configs_npu.py b/python/sglang/multimodal_gen/test/server/ascend/testcase_configs_npu.py index 4086dcd74d90..32dabef50a9d 100644 --- a/python/sglang/multimodal_gen/test/server/ascend/testcase_configs_npu.py +++ b/python/sglang/multimodal_gen/test/server/ascend/testcase_configs_npu.py @@ -58,6 +58,18 @@ T2I_sampling_params, run_consistency_check=False, ), + DiffusionTestCase( + "qwen_image_t2i_2npu", + DiffusionServerArgs( + model_path="/root/.cache/modelscope/hub/models/Qwen/Qwen-Image", + modality="image", + num_gpus=2, + # test ring attn + ulysses_degree=1, + ring_degree=2, + ), + T2I_sampling_params, + ), 
] EIGHT_NPU_CASES: list[DiffusionTestCase] = [ diff --git a/python/sglang/multimodal_gen/test/server/consistency_threshold.json b/python/sglang/multimodal_gen/test/server/consistency_threshold.json index 3795a9f6e28a..596e98166ee0 100644 --- a/python/sglang/multimodal_gen/test/server/consistency_threshold.json +++ b/python/sglang/multimodal_gen/test/server/consistency_threshold.json @@ -49,12 +49,6 @@ "psnr_threshold": 19.0, "mean_abs_diff_threshold": 10.0 }, - "sana_image_t2i": { - "clip_threshold": 0.91, - "ssim_threshold": 0.88, - "psnr_threshold": 21.0, - "mean_abs_diff_threshold": 8.4 - }, "qwen_image_edit_2509_ti2i": { "clip_threshold": 0.92, "ssim_threshold": 0.65, diff --git a/python/sglang/multimodal_gen/test/server/test_server_common.py b/python/sglang/multimodal_gen/test/server/test_server_common.py index dd48e7e0c7b8..f8ac02c2c761 100644 --- a/python/sglang/multimodal_gen/test/server/test_server_common.py +++ b/python/sglang/multimodal_gen/test/server/test_server_common.py @@ -51,14 +51,6 @@ logger = init_logger(__name__) -def _is_lora_case(case: DiffusionTestCase) -> bool: - return bool( - case.server_args.lora_path - or case.server_args.dynamic_lora_path - or case.server_args.second_lora_path - ) - - @pytest.fixture def diffusion_server(case: DiffusionTestCase) -> ServerContext: """Start a diffusion server for a single case and tear it down afterwards.""" @@ -81,11 +73,6 @@ def diffusion_server(case: DiffusionTestCase) -> ServerContext: sampling_params = case.sampling_params extra_args = os.environ.get("SGLANG_TEST_SERVE_ARGS", "") - # Keep LoRA GT on the normal backend path so adapter state matches CI. 
- if os.environ.get("SGLANG_GEN_GT", "0") == "1": - if not _is_lora_case(case) and "--backend" not in extra_args: - extra_args = "--backend diffusers " + extra_args.strip() - extra_args += f" --num-gpus {server_args.num_gpus}" if server_args.tp_size is not None: @@ -235,18 +222,21 @@ def run_and_collect( ctx: ServerContext, case_id: str, generate_fn: Callable[[str, openai.Client], tuple[str, bytes]], - ) -> tuple[RequestPerfRecord, bytes]: - """Run generation and collect performance records. + collect_perf: bool = True, + ) -> tuple[RequestPerfRecord | None, bytes]: + """Run generation and optionally collect performance records. Returns: Tuple of (performance_record, content_bytes) """ - log_path = ctx.perf_log_path - log_wait_timeout = 30 - client = self._client(ctx) rid, content = generate_fn(case_id, client) + if not collect_perf: + return None, content + + log_path = ctx.perf_log_path + log_wait_timeout = 30 req_perf_record = wait_for_req_perf_record( rid, log_path, @@ -1024,6 +1014,7 @@ def test_diffusion_generation( diffusion_server, case.id, generate_fn, + collect_perf=not is_gt_gen_mode, ) if is_gt_gen_mode: diff --git a/python/sglang/multimodal_gen/test/server/testcase_configs.py b/python/sglang/multimodal_gen/test/server/testcase_configs.py index d879adce616b..e1c837691fb9 100644 --- a/python/sglang/multimodal_gen/test/server/testcase_configs.py +++ b/python/sglang/multimodal_gen/test/server/testcase_configs.py @@ -500,15 +500,6 @@ def from_req_perf_record( run_lora_dynamic_switch_check=True, run_multi_lora_api_check=True, ), - DiffusionTestCase( - "sana_image_t2i", - DiffusionServerArgs( - model_path="Efficient-Large-Model/Sana_600M_1024px_diffusers", - modality="image", - ), - T2I_sampling_params, - run_perf_check=False, - ), # === Text and Image to Image (TI2I) === DiffusionTestCase( "qwen_image_edit_ti2i", @@ -804,7 +795,6 @@ def from_req_perf_record( modality="image", ), T2I_sampling_params, - run_consistency_check=False, ) ] @@ -945,7 +935,6 @@ def 
from_req_perf_record( extras=["--pipeline-class-name LTX2TwoStagePipeline"], ), T2V_sampling_params, - run_consistency_check=False, ), ] diff --git a/python/sglang/srt/compilation/backend.py b/python/sglang/srt/compilation/backend.py index f9d376e959be..201123324068 100644 --- a/python/sglang/srt/compilation/backend.py +++ b/python/sglang/srt/compilation/backend.py @@ -21,6 +21,7 @@ from sglang.srt.compilation.cuda_piecewise_backend import CUDAPiecewiseBackend from sglang.srt.compilation.npu_piecewise_backend import NPUPiecewiseBackend from sglang.srt.compilation.pass_manager import PostGradPassManager +from sglang.srt.environ import envs from sglang.srt.utils.common import is_npu logger = logging.getLogger(__name__) @@ -393,9 +394,7 @@ def configure_post_pass(self): self.inductor_config["post_grad_custom_post_pass"] = self.post_grad_pass_manager def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: - base_cache_dir = os.path.expanduser( - os.getenv("SGLANG_CACHE_DIR", "~/.cache/sglang/") - ) + base_cache_dir = envs.SGLANG_CACHE_DIR.get() cache_hash = self.compiler_manager.compute_hash() cache_dir = os.path.join( diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index f84353f6dc12..005d5b05c286 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -84,6 +84,8 @@ class KVArgsRegisterInfo: decode_tp_size: int decode_tp_rank: int dst_kv_item_len: int + dst_state_item_lens: list[int] = dataclasses.field(default_factory=list) + dst_state_dim_per_tensor: list[int] = dataclasses.field(default_factory=list) @classmethod def from_zmq(cls, msg: List[bytes]): @@ -93,6 +95,15 @@ def from_zmq(cls, msg: List[bytes]): else: dst_state_data_ptrs = [] + dst_state_item_lens = [] + dst_state_dim_per_tensor = [] + if len(msg) > 12 and len(msg[12]) > 0: + dst_state_item_lens = list(struct.unpack(f"{len(msg[12]) // 4}I", msg[12])) + if len(msg) > 13 and 
len(msg[13]) > 0: + dst_state_dim_per_tensor = list( + struct.unpack(f"{len(msg[13]) // 4}I", msg[13]) + ) + return cls( room=str(msg[0].decode("ascii")), endpoint=msg[1].decode("ascii"), @@ -106,6 +117,8 @@ def from_zmq(cls, msg: List[bytes]): decode_tp_size=int(msg[9].decode("ascii")), decode_tp_rank=int(msg[10].decode("ascii")), dst_kv_item_len=int(msg[11].decode("ascii")), + dst_state_item_lens=dst_state_item_lens, + dst_state_dim_per_tensor=dst_state_dim_per_tensor, ) @@ -681,6 +694,106 @@ def _send_mamba_state( raise Exception("Failed to post Mamba state transfer") return xfer_handle + def _send_mamba_state_slice( + self, + peer_name: str, + prefill_state_indices: List[int], + dst_state_data_ptrs: list[int], + dst_state_indices: List[int], + dst_gpu_id: int, + notif: str, + dst_state_item_lens: list[int], + dst_state_dim_per_tensor: list[int], + decode_tp_size: int, + decode_tp_rank: int, + ): + """Transfer Mamba states with TP slice support via RDMA. + + When prefill and decode have different attn_tp_size, we slice the + TP-sharded dimension (3rd dim) of conv_state and temporal_state + accordingly, mirroring Mooncake's _send_mamba_state_slice. + """ + logger.warning_once( + "Using Mamba state slice transfer for different TP sizes. " + f"Prefill attn_tp_size={self.attn_tp_size}, " + f"Decode attn_tp_size={decode_tp_size}." 
+ ) + assert len(prefill_state_indices) == 1, "Mamba should have single state index" + + prefill_state_data_ptrs = self.kv_args.state_data_ptrs + prefill_state_item_lens = self.kv_args.state_item_lens + src_state_dim_per_tensor = getattr(self.kv_args, "state_dim_per_tensor", []) + + if not src_state_dim_per_tensor or not dst_state_dim_per_tensor: + return self._send_mamba_state( + peer_name, + prefill_state_indices, + dst_state_data_ptrs, + dst_state_indices, + dst_gpu_id, + notif, + ) + + local_tp_rank_in_group = self.kv_args.engine_rank % self.attn_tp_size + dst_tp_rank_in_group = decode_tp_rank % decode_tp_size + + src_addrs = [] + dst_addrs = [] + + for i, dst_state_ptr in enumerate(dst_state_data_ptrs): + src_item_len = prefill_state_item_lens[i] + dst_item_len = dst_state_item_lens[i] + src_dim = src_state_dim_per_tensor[i] + dst_dim = dst_state_dim_per_tensor[i] + + src_bytes_per_dim = src_item_len // src_dim + dst_bytes_per_dim = dst_item_len // dst_dim + + if self.attn_tp_size > decode_tp_size: + src_dim_start = 0 + num_dims_to_send = src_dim + writers_per_decode = self.attn_tp_size // decode_tp_size + local_writer_idx = local_tp_rank_in_group % writers_per_decode + dst_dim_start = local_writer_idx * src_dim + else: + src_dim_start = (dst_tp_rank_in_group * dst_dim) % src_dim + num_dims_to_send = dst_dim + dst_dim_start = 0 + + src_dim_offset = src_dim_start * src_bytes_per_dim + dst_dim_offset = dst_dim_start * dst_bytes_per_dim + bytes_to_send = num_dims_to_send * src_bytes_per_dim + + src_addr = ( + prefill_state_data_ptrs[i] + + src_item_len * int(prefill_state_indices[0]) + + src_dim_offset + ) + dst_addr = ( + dst_state_ptr + + dst_item_len * int(dst_state_indices[0]) + + dst_dim_offset + ) + src_addrs.append((src_addr, bytes_to_send, self.kv_args.gpu_id)) + dst_addrs.append((dst_addr, bytes_to_send, dst_gpu_id)) + + src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM") + dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM") + + xfer_handle = 
self.agent.initialize_xfer( + "WRITE", + src_descs, + dst_descs, + peer_name, + notif.encode("ascii"), + ) + if not xfer_handle: + raise Exception("Failed to create Mamba state slice transfer") + state = self.agent.transfer(xfer_handle) + if state == "ERR": + raise Exception("Failed to post Mamba state slice transfer") + return xfer_handle + def maybe_send_extra( self, peer_name: str, @@ -690,14 +803,26 @@ def maybe_send_extra( dst_gpu_id: int, notif: str, decode_tp_size: int, + decode_tp_rank: int = 0, + dst_state_item_lens: list[int] | None = None, + dst_state_dim_per_tensor: list[int] | None = None, ): """Send state or extra pool data with type-specific handling.""" state_type = getattr(self.kv_args, "state_type", "none") if state_type == "mamba": if self.attn_tp_size != decode_tp_size: - raise RuntimeError( - "PD Disaggregation does NOT support PD different TP sizes for hybrid mamba models yet." + return self._send_mamba_state_slice( + peer_name, + prefill_state_indices, + dst_state_data_ptrs, + dst_state_indices, + dst_gpu_id, + notif, + dst_state_item_lens or [], + dst_state_dim_per_tensor or [], + decode_tp_size, + decode_tp_rank, ) return self._send_mamba_state( peer_name, @@ -803,6 +928,9 @@ def add_transfer_request( dst_info.gpu_id, f"{req.room}_state_{self.kv_args.engine_rank}", decode_tp_size, + decode_tp_rank=dst_info.decode_tp_rank, + dst_state_item_lens=dst_info.dst_state_item_lens, + dst_state_dim_per_tensor=dst_info.dst_state_dim_per_tensor, ) if state_xfer_handle is not None: handles.append(state_xfer_handle) @@ -1080,6 +1208,17 @@ def _register_kv_args(self): struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.state_data_ptrs ) + packed_state_item_lens = b"".join( + struct.pack("I", item_len) + for item_len in self.kv_mgr.kv_args.state_item_lens + ) + state_dim_per_tensor = getattr( + self.kv_mgr.kv_args, "state_dim_per_tensor", [] + ) + packed_state_dim_per_tensor = b"".join( + struct.pack("I", dim) for dim in state_dim_per_tensor + ) + with 
lock: sock.send_multipart( [ @@ -1096,6 +1235,8 @@ def _register_kv_args(self): str(self.kv_mgr.attn_tp_size).encode("ascii"), str(self.kv_mgr.kv_args.engine_rank).encode("ascii"), str(self.kv_mgr.kv_args.kv_item_lens[0]).encode("ascii"), + packed_state_item_lens, + packed_state_dim_per_tensor, ] ) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 61f7513f567f..43afc1577a56 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -206,6 +206,32 @@ def get_global_state() -> _GlobalState: return _global_state +async def _init_granian_worker() -> ServerArgs: + main_pid = get_main_process_id() + port_args, server_args, scheduler_info = read_from_shared_memory( + f"multi_tokenizer_args_{main_pid}" + ) + + tokenizer_manager = TokenizerManager(server_args, port_args) + template_manager = TemplateManager() + template_manager.initialize_templates( + tokenizer_manager=tokenizer_manager, + model_path=server_args.model_path, + chat_template=server_args.chat_template, + completion_template=server_args.completion_template, + ) + tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"] + + set_global_state( + _GlobalState( + tokenizer_manager=tokenizer_manager, + template_manager=template_manager, + scheduler_info=scheduler_info, + ) + ) + return server_args + + async def init_multi_tokenizer() -> ServerArgs: """ Initialization function for multi-process tokenizer mode. 
@@ -263,6 +289,10 @@ async def lifespan(fast_api_app: FastAPI): server_args = fast_api_app.server_args warmup_thread_kwargs = fast_api_app.warmup_thread_kwargs thread_label = "Tokenizer" + elif envs.SGLANG_GRANIAN_PARENT_PID.get() is not None: + server_args = await _init_granian_worker() + warmup_thread_kwargs = dict(server_args=server_args) + thread_label = "Tokenizer" else: # Initialize multi-tokenizer support for worker processes server_args = await init_multi_tokenizer() @@ -2017,6 +2047,53 @@ def _wait_weights_ready(): ) +def _close_main_process_sockets(): + """Close the main process's ZMQ sockets before spawning Granian workers. + + Granian workers create their own TokenizerManager with fresh ZMQ sockets. + The main process must release its sockets first to avoid binding conflicts + on the same IPC addresses. + """ + if _global_state is None or _global_state.tokenizer_manager is None: + return + tm = _global_state.tokenizer_manager + for attr in ("recv_from_detokenizer", "send_to_scheduler"): + sock = getattr(tm, attr, None) + if sock is None: + continue + inner = getattr(sock, "socket", None) + if inner is not None: + inner.close() + elif hasattr(sock, "close"): + sock.close() + setattr(tm, attr, None) + + +def _run_granian_server(server_args: ServerArgs): + """Launch Granian with HTTP/2 support""" + from granian import Granian + from granian.constants import HTTPModes, Interfaces, Loops + + granian_kwargs = dict( + target="sglang.srt.entrypoints.http_server:app", + address=server_args.host, + port=server_args.port, + interface=Interfaces.ASGI, + http=HTTPModes.auto, + loop=Loops.uvloop, + log_level=server_args.log_level_http or server_args.log_level or "info", + workers=1, + ) + + ssl_enabled = server_args.ssl_certfile and server_args.ssl_keyfile + if ssl_enabled: + granian_kwargs["ssl_cert"] = server_args.ssl_certfile + granian_kwargs["ssl_key"] = server_args.ssl_keyfile + + server = Granian(**granian_kwargs) + server.serve() + + def 
_setup_and_run_http_server( server_args: ServerArgs, tokenizer_manager, @@ -2047,6 +2124,35 @@ def _setup_and_run_http_server( if server_args.enable_metrics: add_prometheus_track_response_middleware(app) + # Use Granian for HTTP/2 server + if server_args.enable_http2: + # Reuse the multi-tokenizer shared memory mechanism to pass + # init args (port_args, server_args, scheduler_info) to + # Granian workers, which are independent processes. + multi_tokenizer_args_shm = write_data_for_multi_tokenizer( + port_args, server_args, scheduler_infos[0] + ) + try: + if server_args.ssl_certfile: + logger.info( + f"SSL enabled: certfile={server_args.ssl_certfile}, " + f"keyfile={server_args.ssl_keyfile}" + ) + logger.info( + f"Starting Granian HTTP/2 server on " + f"{server_args.host}:{server_args.port}" + ) + # Propagate the main process PID via os.environ so Granian + # workers (forked or spawned) can locate the shared memory + # segment created above. + envs.SGLANG_GRANIAN_PARENT_PID.set(os.getpid()) + _close_main_process_sockets() + _run_granian_server(server_args) + finally: + if multi_tokenizer_args_shm is not None: + multi_tokenizer_args_shm.unlink() + return + # Pass additional arguments to the lifespan function. # They will be used for additional initialization setups. 
if server_args.tokenizer_worker_num == 1: diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index dfc5507de0ba..e9716f1afd3f 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -336,6 +336,7 @@ class Envs: SGLANG_CPU_QUANTIZATION = EnvBool(False) SGLANG_USE_DYNAMIC_MXFP4_LINEAR = EnvBool(False) SGLANG_FORCE_FP8_MARLIN = EnvBool(False) + SGLANG_FORCE_NVFP4_MARLIN = EnvBool(False) SGLANG_MOE_NVFP4_DISPATCH = EnvBool(False) SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN = EnvBool(False) SGLANG_PER_TOKEN_GROUP_QUANT_8BIT_V2 = EnvBool(False) @@ -406,6 +407,9 @@ class Envs: # sgl-kernel SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK = EnvBool(False) + # Flash Attention + SGLANG_USE_SGL_FA3_KERNEL = EnvBool(True) + # vLLM dependencies (TODO: they have been deprecated, we can remove them safely) USE_VLLM_CUTLASS_W8A8_FP8_KERNEL = EnvBool(False) @@ -489,6 +493,9 @@ class Envs: # HTTP Server SGLANG_TIMEOUT_KEEP_ALIVE = EnvInt(5) + # HTTP/2 Server + SGLANG_GRANIAN_PARENT_PID = EnvInt(None) + # Health Check SGLANG_ENABLE_HEALTH_ENDPOINT_GENERATION = EnvBool(True) @@ -531,6 +538,9 @@ class Envs: # Elastic EP Backup Port SGLANG_BACKUP_PORT_BASE = EnvInt(10000) + # Sglang Cache Dir + SGLANG_CACHE_DIR = EnvStr(os.path.expanduser("~/.cache/sglang")) + envs = Envs() EnvField._allow_set_name = False diff --git a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py index d93cc5f62f1c..e7acdee86e90 100644 --- a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py +++ b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py @@ -860,7 +860,7 @@ def forward_extend( sinks: Optional[torch.Tensor] = None, slopes: Optional[torch.Tensor] = None, ): - if is_mla_preprocess_enabled(): + if is_mla_preprocess_enabled() and self.use_mla: # MLAPO and MLAPROLOG do save kv_cache save_kv_cache = False if self.is_dllm_model: @@ -1773,7 +1773,7 @@ def forward_decode( 
sinks: Optional[torch.Tensor] = None, slopes: Optional[torch.Tensor] = None, ): - if is_mla_preprocess_enabled(): + if is_mla_preprocess_enabled() and self.use_mla: # MLAPO does saving kv_cache save_kv_cache = False if topk_indices is not None: diff --git a/python/sglang/srt/hardware_backend/npu/attention/ascend_gdn_backend.py b/python/sglang/srt/hardware_backend/npu/attention/ascend_gdn_backend.py new file mode 100644 index 000000000000..00158319fb46 --- /dev/null +++ b/python/sglang/srt/hardware_backend/npu/attention/ascend_gdn_backend.py @@ -0,0 +1,429 @@ +from typing import Optional, Tuple, Union + +import torch +from sgl_kernel_npu.fla.fused_gdn_gating import fused_gdn_gating_npu +from sgl_kernel_npu.mamba.causal_conv1d import ( + causal_conv1d_fn_npu, + causal_conv1d_update_npu, +) + +from sglang.srt.layers.attention.fla.fused_gdn_gating import ( + fused_gdn_gating_kernel_without_sigmoid, +) +from sglang.srt.layers.attention.linear.gdn_backend import ( + GDNAttnBackend, + GDNKernelDispatcher, +) +from sglang.srt.layers.attention.linear.utils import ( + get_linear_attn_decode_backend, + get_linear_attn_prefill_backend, +) +from sglang.srt.layers.radix_linear_attention import RadixLinearAttention +from sglang.srt.mem_cache.memory_pool import MambaPool +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.model_executor.model_runner import ModelRunner +from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput +from sglang.srt.utils import is_cpu + +fused_gdn_gating = fused_gdn_gating_npu +causal_conv1d_fn = causal_conv1d_fn_npu +causal_conv1d_update = causal_conv1d_update_npu + + +class AscendGDNKernelDispatcher(GDNKernelDispatcher): + pass + + +class AscendGDNAttnBackend(GDNAttnBackend): + + def __init__(self, model_runner: ModelRunner): + super().__init__(model_runner) + # transpose last two dim for _init_npu_conv_state + self.conv_states_shape = torch.Size( + (*self.conv_states_shape[:-2], 
self.conv_states_shape[-1], self.conv_states_shape[-2]) + ) + decode_backend = get_linear_attn_decode_backend() + prefill_backend = get_linear_attn_prefill_backend() + self.kernel_dispatcher = AscendGDNKernelDispatcher( + decode_backend, prefill_backend + ) + + def prepare_gdn_inputs( + self, + bs: int, + forward_mode: ForwardMode, + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + ): + cache_indices = self.forward_metadata.mamba_cache_indices + self.num_accepted_tokens = torch.ones( + [bs], dtype=torch.int32, device=cache_indices.device + ) + self.actual_seq_lengths = torch.ones( + [bs], dtype=torch.int32, device=cache_indices.device + ) + if forward_mode.is_target_verify(): + seq_len = spec_info.draft_token_num + self.actual_seq_lengths = self.actual_seq_lengths * seq_len + # indices + self.ssm_state_indices = torch.arange( + cache_indices.shape[0] * seq_len, dtype=torch.int32, device=cache_indices.device + ) + else: + self.ssm_state_indices = cache_indices + + def init_forward_metadata(self, forward_batch: ForwardBatch): + if forward_batch.forward_mode.is_draft_extend(True): + return + super().init_forward_metadata(forward_batch) + self.prepare_gdn_inputs( + forward_batch.batch_size, + forward_batch.forward_mode, + forward_batch.spec_info, + ) + self.graph_mode = False + + def init_forward_metadata_capture_cuda_graph( + self, + bs: int, + num_tokens: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + ): + if forward_mode.is_draft_extend(True): + return + super().init_forward_metadata_capture_cuda_graph( + bs, + num_tokens, + req_pool_indices, + seq_lens, + encoder_lens, + forward_mode, + spec_info, + ) + self.prepare_gdn_inputs(bs, forward_mode, spec_info) + self.graph_mode = True + + def init_forward_metadata_replay_cuda_graph( + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: 
torch.Tensor, + seq_lens_sum: int, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + seq_lens_cpu: Optional[torch.Tensor], + ): + if forward_mode.is_draft_extend(True): + return + super().init_forward_metadata_replay_cuda_graph( + bs, + req_pool_indices, + seq_lens, + seq_lens_sum, + encoder_lens, + forward_mode, + spec_info, + seq_lens_cpu, + ) + self.prepare_gdn_inputs(bs, forward_mode, spec_info) + self.graph_mode = True + + def forward_decode( + self, + layer: RadixLinearAttention, + forward_batch: ForwardBatch, + mixed_qkv: Union[torch.Tensor, Tuple[torch.Tensor, ...]], + a: torch.Tensor, + b: torch.Tensor, + **kwargs, + ): + layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer.layer_id) + conv_states = layer_cache.conv[0] + ssm_states = layer_cache.temporal + query_start_loc = self.forward_metadata.query_start_loc + cache_indices = self.forward_metadata.mamba_cache_indices + + assert isinstance(mixed_qkv, torch.Tensor) + conv_states_tmp = conv_states.transpose(1, 2).clone() + mixed_qkv = causal_conv1d_update( + mixed_qkv, + conv_states_tmp, + layer.conv_weights, + layer.bias, + layer.activation, + conv_state_indices=cache_indices, + ) + conv_states[:] = conv_states_tmp.transpose(1, 2) + + query, key, value = torch.split( + mixed_qkv, + [layer.q_dim, layer.k_dim, layer.v_dim], + dim=-1, + ) + bs = forward_batch.batch_size + query = query.view(1, bs, layer.num_q_heads, layer.head_q_dim) + key = key.view(1, bs, layer.num_k_heads, layer.head_k_dim) + value = value.view(1, bs, layer.num_v_heads, layer.head_v_dim) + + core_attn_out = self.kernel_dispatcher.decode( + q=query, + k=key, + v=value, + a=a, + b=b, + A_log=layer.A_log, + dt_bias=layer.dt_bias, + ssm_states=ssm_states, + cache_indices=cache_indices, + query_start_loc=query_start_loc, + ) + + self._track_mamba_state_decode( + forward_batch, conv_states, ssm_states, cache_indices + ) + return core_attn_out + + def 
forward_extend( + self, + layer: RadixLinearAttention, + forward_batch: ForwardBatch, + mixed_qkv: Union[torch.Tensor, Tuple[torch.Tensor, ...]], + a: torch.Tensor, + b: torch.Tensor, + **kwargs, + ): + assert isinstance(mixed_qkv, torch.Tensor) + seq_len = mixed_qkv.shape[0] + is_target_verify = forward_batch.forward_mode.is_target_verify() + forward_metadata = self.forward_metadata + + query_start_loc = forward_metadata.query_start_loc + cache_indices = forward_metadata.mamba_cache_indices + retrieve_next_token = forward_metadata.retrieve_next_token + retrieve_next_sibling = forward_metadata.retrieve_next_sibling + retrieve_parent_token = forward_metadata.retrieve_parent_token + + mamba_cache_params = self.req_to_token_pool.mamba2_layer_cache(layer.layer_id) + conv_states = mamba_cache_params.conv[0] + ssm_states = mamba_cache_params.temporal + if is_target_verify: + assert isinstance(mamba_cache_params, MambaPool.SpeculativeState) + intermediate_state_cache = mamba_cache_params.intermediate_ssm + intermediate_conv_window_cache = ( + mamba_cache_params.intermediate_conv_window[0] + ) + has_initial_states = torch.ones( + seq_len // forward_batch.spec_info.draft_token_num, + dtype=torch.bool, + device=forward_batch.input_ids.device, + ) + else: + has_initial_states = forward_batch.extend_prefix_lens > 0 + if is_target_verify: + batch_size = seq_len // forward_batch.spec_info.draft_token_num + draft_token_num = forward_batch.spec_info.draft_token_num + num_token_padding = mixed_qkv.shape[0] + batch_size = cache_indices.shape[0] + if ( + not self.graph_mode + and forward_batch.num_token_non_padded_cpu != num_token_padding + ): + mixed_qkv = mixed_qkv[: forward_batch.num_token_non_padded_cpu] + a = a[: forward_batch.num_token_non_padded_cpu] + b = b[: forward_batch.num_token_non_padded_cpu] + seq_len = forward_batch.num_token_non_padded_cpu + + mixed_qkv_reshaped = mixed_qkv.view(batch_size, draft_token_num, -1) + num_accepted_tokens = torch.full( + (batch_size,), + 
draft_token_num, + dtype=torch.int32, + device=mixed_qkv.device, + ) + mixed_qkv = torch.ops.npu.causal_conv1d_update( + mixed_qkv_reshaped, + layer.conv_weights.transpose(0, 1).contiguous(), + conv_states, + cache_indices, + layer.bias, + num_accepted_tokens, + None, + layer.activation == "silu", + self.pad_slot_id, + ).view(seq_len, -1) + else: + mixed_qkv = mixed_qkv.transpose(0, 1) + if ( + forward_batch.mamba_track_mask is not None + and forward_batch.mamba_track_mask.any() + ): + conv_dst = forward_batch.mamba_track_indices + mixed_qkv_to_track = mixed_qkv[ + :, forward_metadata.track_conv_indices + ].transpose(0, 1) + mask_indices = forward_batch.mamba_track_mask.nonzero(as_tuple=True)[0] + conv_states.transpose(1, 2)[conv_dst[mask_indices]] = mixed_qkv_to_track + kernel_size = layer.conv_weights.shape[-1] + conv_states_for_prefill = conv_states[:, -(kernel_size - 1) :, :] + conv_states_tmp = conv_states_for_prefill.contiguous() + + x = mixed_qkv.transpose(0, 1).contiguous() + weight = layer.conv_weights.transpose(0, 1).contiguous() + activation_mode = layer.activation == "silu" + + mixed_qkv = torch.ops.npu.causal_conv1d( + x, + weight, + conv_states_tmp, + query_start_loc, + cache_indices, + has_initial_states, + layer.bias, + activation_mode, + self.pad_slot_id, + )[:seq_len] + + conv_states[:, -(kernel_size - 1) :, :] = conv_states_tmp + if is_target_verify: + g, beta = fused_gdn_gating_kernel_without_sigmoid( + layer.A_log, a, b, layer.dt_bias + ) + beta = beta.unsqueeze(0) + num_heads, head_k_dim = layer.num_q_heads, layer.head_q_dim + num_value_heads, head_v_dim = layer.num_v_heads, layer.head_v_dim + + mixed_qkv_last_dim = mixed_qkv.shape[-1] + + mixed_qkv = mixed_qkv.view(batch_size, -1, mixed_qkv_last_dim) + beta = beta.view(batch_size, -1, num_value_heads) + g = g.view(batch_size, -1, num_value_heads) + + core_attn_out = self.fused_recurrent_gated_delta_rule_update( + mixed_qkv, + num_heads, + num_value_heads, + head_k_dim, + head_v_dim, + 
recurrent_state=ssm_states, + beta=beta, + g=g, + cache_indices=cache_indices, + intermediate_state=intermediate_state_cache, + ) + core_attn_out = core_attn_out.view(-1, num_value_heads, head_v_dim) + if (not self.graph_mode) and core_attn_out.shape[0] < num_token_padding: + core_attn_out = torch.cat( + [ + core_attn_out, + core_attn_out.new_zeros( + num_token_padding - core_attn_out.shape[0], + *core_attn_out.shape[1:], + ), + ], + dim=0, + ) + else: + query, key, value = torch.split( + mixed_qkv, + [layer.q_dim, layer.k_dim, layer.v_dim], + dim=-1, + ) + + actual_seq_len = query.shape[0] + query = query.view(1, actual_seq_len, layer.num_q_heads, layer.head_q_dim) + key = key.view(1, actual_seq_len, layer.num_k_heads, layer.head_k_dim) + value = value.view(1, actual_seq_len, layer.num_v_heads, layer.head_v_dim) + + g, beta = fused_gdn_gating(layer.A_log, a, b, layer.dt_bias) + core_attn_out, last_recurrent_state, h = self.kernel_dispatcher.extend( + q=query, + k=key, + v=value, + g=g, + beta=beta, + ssm_states=ssm_states, + cache_indices=cache_indices, + query_start_loc=query_start_loc, + ) + if is_cpu() and last_recurrent_state is not None: + last_recurrent_state = last_recurrent_state.to( + ssm_states.dtype, copy=False + ) + ssm_states[cache_indices] = last_recurrent_state + if not forward_batch.spec_algorithm.is_none(): + last_recurrent_state = last_recurrent_state.transpose(-1, -2).to( + ssm_states.dtype, copy=False + ) + else: + last_recurrent_state = last_recurrent_state.to( + ssm_states.dtype, copy=False + ) + ssm_states[cache_indices] = last_recurrent_state + if h is not None: + self._track_mamba_state_extend( + forward_batch, h, ssm_states, forward_metadata + ) + + return core_attn_out + + def fused_recurrent_gated_delta_rule_update( + self, + mix_qkv: torch.Tensor, + num_heads, + num_value_heads, + head_k_dim, + head_v_dim, + recurrent_state: torch.Tensor, + beta: torch.Tensor, + g: torch.Tensor, + cache_indices: torch.Tensor, + intermediate_state: 
Optional[torch.Tensor] = None, + ): + beta = beta.to(torch.bfloat16) + g = g.to(torch.float32) + batch_size = mix_qkv.shape[0] + seq_len = mix_qkv.shape[1] + scale = 1 / (head_k_dim**0.5) + + if intermediate_state is not None: + intermediate_state = intermediate_state.view( + -1, num_value_heads, head_k_dim, head_v_dim + ) + + if self.graph_mode: + num_accepted_tokens = torch.full( + [batch_size], 1, dtype=torch.int32, device=cache_indices.device + ) + actual_seq_lengths = torch.full( + [batch_size], seq_len, dtype=torch.int32, device=cache_indices.device + ) + ssm_state_indices = self.forward_metadata.mamba_cache_indices_gdn + else: + num_accepted_tokens = self.num_accepted_tokens + actual_seq_lengths = self.actual_seq_lengths + ssm_state_indices = self.ssm_state_indices + + attn_core_out = torch.ops.npu.recurrent_gated_delta_rule( + mix_qkv, + recurrent_state, + beta=beta, + scale=scale, + actual_seq_lengths=actual_seq_lengths, + ssm_state_indices=ssm_state_indices.view(batch_size, seq_len), + nk=num_heads, + nv=num_value_heads, + intermediate_state=intermediate_state, + cache_indices=cache_indices, + num_accepted_tokens=num_accepted_tokens, + g=g, + ) + + if intermediate_state is not None: + intermediate_state = intermediate_state.view( + -1, seq_len, num_value_heads, head_k_dim, head_v_dim + ) + return attn_core_out diff --git a/python/sglang/srt/hardware_backend/npu/memory_pool_npu.py b/python/sglang/srt/hardware_backend/npu/memory_pool_npu.py index ea81f4e589e6..e4f319fa5859 100644 --- a/python/sglang/srt/hardware_backend/npu/memory_pool_npu.py +++ b/python/sglang/srt/hardware_backend/npu/memory_pool_npu.py @@ -15,6 +15,30 @@ from sglang.srt.layers.radix_attention import RadixAttention +def _init_npu_conv_state( + conv_state_in, conv_state_shape, speculative_num_draft_tokens: Optional[int] = None +): + extra_conv_len = 0 + if speculative_num_draft_tokens is not None: + extra_conv_len = speculative_num_draft_tokens - 1 + + # conv_state shape (layers, 
pool_size, conv_wind + draft_step, dim) for conv1d ascendc ops require dim as last dim + conv_state = [ + torch.zeros( + size=( + conv_state_in.shape[0], + conv_state_in.shape[1], + conv_shape[1] + extra_conv_len, + conv_shape[0], + ), + dtype=conv_state_in.dtype, + device=conv_state_in.device, + ) + for conv_shape in conv_state_shape + ] + return conv_state + + class NPUMHATokenToKVPool(MHATokenToKVPool): def __init__( diff --git a/python/sglang/srt/layers/attention/attention_registry.py b/python/sglang/srt/layers/attention/attention_registry.py index 0a5920575c0f..8a9a06e9590c 100644 --- a/python/sglang/srt/layers/attention/attention_registry.py +++ b/python/sglang/srt/layers/attention/attention_registry.py @@ -197,7 +197,6 @@ def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBac HybridLinearAttnBackend, Mamba2AttnBackend, ) - from sglang.srt.layers.attention.linear.gdn_backend import GDNAttnBackend from sglang.srt.layers.attention.linear.kda_backend import KDAAttnBackend from sglang.srt.layers.attention.linear.lightning_backend import ( LightningAttentionBackend, @@ -207,6 +206,13 @@ def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBac ) from sglang.srt.utils import is_blackwell, is_npu + if is_npu(): + from sglang.srt.hardware_backend.npu.attention.ascend_gdn_backend import ( + AscendGDNAttnBackend as GDNAttnBackend, + ) + else: + from sglang.srt.layers.attention.linear.gdn_backend import GDNAttnBackend + check_environments() initialize_linear_attn_config(runner.server_args) if runner.hybrid_gdn_config is not None: diff --git a/python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py b/python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py index e522fbe4a934..a84015a803f8 100644 --- a/python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +++ b/python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py @@ -9,13 +9,16 @@ import torch import 
torch.nn.functional as F -from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache from sgl_kernel.sparse_flash_attn import ( convert_vertical_slash_indexes, convert_vertical_slash_indexes_mergehead, sparse_attn_func, ) +from sglang.jit_kernel.flash_attention import ( + flash_attn_varlen_func, + flash_attn_with_kvcache, +) from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.layers.attention.flashattention_backend import FlashAttentionMetadata diff --git a/python/sglang/srt/layers/attention/fla/fused_gdn_gating.py b/python/sglang/srt/layers/attention/fla/fused_gdn_gating.py index 6e92208ec130..a82c18ad9abb 100644 --- a/python/sglang/srt/layers/attention/fla/fused_gdn_gating.py +++ b/python/sglang/srt/layers/attention/fla/fused_gdn_gating.py @@ -67,3 +67,69 @@ def fused_gdn_gating( num_warps=1, ) return g, beta_output + + +@triton.jit +def fused_gdn_gating_kernel_without_sigmoid_kernel( + g, + A_log, + a, + dt_bias, + batch, + seq_len, + NUM_HEADS: tl.constexpr, + beta: tl.constexpr, + threshold: tl.constexpr, + BLK_BATCHES: tl.constexpr, + BLK_HEADS: tl.constexpr, +): + i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2) + batch_off = i_b * BLK_BATCHES + tl.arange(0, BLK_BATCHES) + head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS) + head_mask = head_off < NUM_HEADS + a_off = ( + batch_off[:, None] * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off[None, :] + ) + a_mask = (batch_off[:, None] < batch) & head_mask[None, :] + blk_A_log = tl.load(A_log + head_off, mask=head_mask) + blk_bias = tl.load(dt_bias + head_off, mask=head_mask) + blk_a = tl.load(a + a_off, mask=a_mask) + x = blk_a.to(tl.float32) + blk_bias.to(tl.float32) + softplus_x = tl.where( + beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x + ) + blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x + tl.store(g + a_off, 
blk_g.to(g.dtype.element_ty), mask=a_mask) + + +def fused_gdn_gating_kernel_without_sigmoid( + A_log: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + dt_bias: torch.Tensor, + beta: float = 1.0, + threshold: float = 20.0, +) -> Tuple[torch.Tensor, torch.Tensor]: + batch, num_heads = a.shape + seq_len = 1 + g = torch.empty_like(a, dtype=torch.float32) + num_cores = 48 # num_vectorcore of NPU + NUM_BLK_BATCHES = triton.cdiv(num_cores, triton.cdiv(num_heads, 8)) + BLK_BATCHES = triton.cdiv(batch, NUM_BLK_BATCHES) + grid = (NUM_BLK_BATCHES, seq_len, triton.cdiv(num_heads, 8)) + fused_gdn_gating_kernel_without_sigmoid_kernel[grid]( + g, + A_log, + a, + dt_bias, + batch, + seq_len, + num_heads, + beta, + threshold, + BLK_BATCHES=BLK_BATCHES, + BLK_HEADS=8, + num_warps=1, + ) + g = g.unsqueeze(0) + return g, b diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index ff170c390838..ad7f59c0d539 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -27,6 +27,11 @@ from sgl_kernel import merge_state_v2 +from sglang.jit_kernel.flash_attention import ( + flash_attn_varlen_func, + flash_attn_with_kvcache, +) + @dataclass class FlashAttentionMetadata: @@ -616,9 +621,6 @@ def forward_extend( and not is_swa_layer ) - flash_attn_varlen_func = self.flash_attn_varlen_func - flash_attn_with_kvcache = self.flash_attn_with_kvcache - kwargs = {} if sinks is not None: kwargs["sinks"] = sinks @@ -696,6 +698,7 @@ def _fa_cp_attn( v_descale=v_descale, return_softmax_lse=use_cascade_attn, num_splits=self.num_splits, + ver=self.fa_impl_ver, **kwargs, ) @@ -723,6 +726,7 @@ def _fa_cp_attn( v_descale=v_descale, return_softmax_lse=use_cascade_attn, num_splits=self.num_splits, + ver=self.fa_impl_ver, **kwargs, ) @@ -750,6 +754,7 @@ def _fa_cp_attn( v_descale=v_descale, return_softmax_lse=True, num_splits=self.num_splits, 
+ ver=self.fa_impl_ver, **kwargs, ) o, _ = merge_state_v2_wrapper( @@ -789,6 +794,7 @@ def _fa_cp_attn( softmax_scale=layer.scaling, causal=False, return_softmax_lse=True, + ver=self.fa_impl_ver, **kwargs, ) else: @@ -814,6 +820,7 @@ def _fa_cp_attn( softmax_scale=layer.scaling, causal=True, return_softmax_lse=forward_batch.mha_return_lse, + ver=self.fa_impl_ver, **kwargs, ) if forward_batch.mha_return_lse: @@ -822,7 +829,7 @@ def _fa_cp_attn( return output, lse return output else: - assert self.fa_impl_ver in [3], "Only FA3 support here" + assert self.fa_impl_ver == 3, "Only FA3 support here" # Do absorbed multi-latent attention kv_cache = forward_batch.token_to_kv_pool.get_key_buffer( layer.layer_id @@ -865,6 +872,7 @@ def _fa_cp_attn( v_descale=v_descale, return_softmax_lse=use_cascade_attn, num_splits=self.num_splits, + ver=self.fa_impl_ver, ) if use_cascade_attn: o, softmax_lse, *rest = result @@ -887,6 +895,7 @@ def _fa_cp_attn( v_descale=v_descale, return_softmax_lse=True, num_splits=self.num_splits, + ver=self.fa_impl_ver, ) ) o, _ = merge_state_v2_wrapper( @@ -964,8 +973,6 @@ def forward_decode( if sinks is not None: kwargs["sinks"] = sinks - flash_attn_with_kvcache = self.flash_attn_with_kvcache - k_descale, v_descale = None, None # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention # has corresponding quantization method so that layer.k_scale is not None, @@ -1009,6 +1016,7 @@ def forward_decode( k_descale=k_descale, v_descale=v_descale, num_splits=self.num_splits, + ver=self.fa_impl_ver, **kwargs, ) elif use_local_attn: @@ -1029,6 +1037,7 @@ def forward_decode( k_descale=k_descale, v_descale=v_descale, num_splits=self.num_splits, + ver=self.fa_impl_ver, **kwargs, ) else: @@ -1066,6 +1075,7 @@ def forward_decode( v_descale=v_descale, return_softmax_lse=use_cascade_attn, num_splits=self.num_splits, + ver=self.fa_impl_ver, **kwargs, ) if use_cascade_attn: @@ -1088,6 +1098,7 @@ def forward_decode( v_descale=v_descale, 
return_softmax_lse=True, num_splits=self.num_splits, + ver=self.fa_impl_ver, **kwargs, ) ) @@ -1144,6 +1155,7 @@ def forward_decode( v_descale=v_descale, return_softmax_lse=use_cascade_attn, # softmax_lse is needed for merge states num_splits=self.num_splits, + ver=self.fa_impl_ver, ) if use_cascade_attn: o, softmax_lse, *rest = result @@ -1165,6 +1177,7 @@ def forward_decode( v_descale=v_descale, return_softmax_lse=True, num_splits=self.num_splits, + ver=self.fa_impl_ver, ) o, _ = merge_state_v2( o, diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 4fe8aec31301..c1e2ea4fcdab 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -596,8 +596,24 @@ def init_forward_metadata_capture_cuda_graph( fast_decode_plan, decode_wrappers[i] ) elif forward_mode.is_target_verify(): + # FlashInfer's prefill wrapper decides mask mode based on whether + # `custom_mask_buf` is initialized (not whether a custom mask is provided). + # For cases like DFLASH draft (ENCODER_ONLY / non-causal) we do NOT use a + # custom mask, so we must avoid initializing `custom_mask_buf`, otherwise + # FlashInfer will treat the (zero) buffer as a real mask and block attention. 
+ use_custom_mask = ( + spec_info is not None + and getattr(spec_info, "custom_mask", None) is not None + ) prefill_wrappers = [] for i in range(self.num_wrappers): + wrapper_kwargs = {} + if use_custom_mask: + wrapper_kwargs = { + "custom_mask_buf": self.cuda_graph_custom_mask, + "mask_indptr_buf": self.cuda_graph_qk_indptr[i][: bs + 1], + } + prefill_wrappers.append( BatchPrefillWithPagedKVCacheWrapper( self.workspace_buffer, @@ -608,8 +624,7 @@ def init_forward_metadata_capture_cuda_graph( paged_kv_indptr_buf=self.kv_indptr[i][: bs + 1], paged_kv_indices_buf=self.cuda_graph_kv_indices[i], paged_kv_last_page_len_buf=self.kv_last_page_len[:bs], - custom_mask_buf=self.cuda_graph_custom_mask, - mask_indptr_buf=self.cuda_graph_qk_indptr[i][: bs + 1], + **wrapper_kwargs, ) ) seq_lens_sum = seq_lens.sum().item() @@ -783,10 +798,14 @@ def forward_extend( layer, cache_loc, k, v, layer.k_scale, layer.v_scale ) + causal = ( + not layer.is_cross_attention + and layer.attn_type != AttentionType.ENCODER_ONLY + ) o = prefill_wrapper_paged.forward( q.view(-1, layer.tp_q_head_num, layer.head_dim), forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), - causal=not layer.is_cross_attention, + causal=causal, sm_scale=layer.scaling, # Disable sliding window attention for multi-item scoring: # - Sliding window could cut across item boundaries, breaking semantic coherence @@ -838,11 +857,6 @@ def forward_extend( ) else: - if not self.is_dllm_model: - # TODO: design a better interface - # For other models, use causal attention for the ragged part as previously - causal = True - o1, s1 = self.prefill_wrapper_ragged.forward_return_lse( q.view(-1, layer.tp_q_head_num, layer.head_dim), k.view(-1, layer.tp_k_head_num, layer.head_dim), diff --git a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py index 91194c494396..9d8b5f439561 100644 --- 
a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py +++ b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py @@ -22,13 +22,19 @@ from sglang.srt.server_args import get_global_server_args from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput from sglang.srt.speculative.spec_info import SpecInput -from sglang.srt.utils import is_cpu +from sglang.srt.utils import is_cpu, is_npu if not is_cpu(): from sglang.srt.layers.attention.fla.chunk_delta_h import ( CHUNK_SIZE as FLA_CHUNK_SIZE, ) +if is_npu(): + from sgl_kernel_npu.mamba.mamba_state_update_triton import ( + conv_state_rollback, + move_intermediate_cache, + ) + logger = logging.getLogger(__name__) @@ -142,6 +148,7 @@ def __init__(self, model_runner: ModelRunner): self.req_to_token_pool: HybridReqToTokenPool = model_runner.req_to_token_pool self.forward_metadata: ForwardMetadata = None self.state_indices_list = [] + self.state_indices_list_gdn = [] self.query_start_loc_list = [] self.retrieve_next_token_list = [] self.retrieve_next_sibling_list = [] @@ -217,6 +224,11 @@ def _forward_metadata(self, forward_batch: ForwardBatch): else: raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode=}") + has_mamba_track_mask = bool( + forward_batch.mamba_track_mask is not None + and forward_batch.mamba_track_mask.any() + ) + return ForwardMetadata( query_start_loc=query_start_loc, mamba_cache_indices=mamba_cache_indices, @@ -228,6 +240,7 @@ def _forward_metadata(self, forward_batch: ForwardBatch): track_ssm_h_dst=track_ssm_h_dst, track_ssm_final_src=track_ssm_final_src, track_ssm_final_dst=track_ssm_final_dst, + has_mamba_track_mask=has_mamba_track_mask, ) def init_forward_metadata(self, forward_batch: ForwardBatch): @@ -409,6 +422,14 @@ def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int): (i + 1,), self.pad_slot_id, dtype=torch.int32, device=self.device ) ) + self.state_indices_list_gdn.append( + torch.full( + ((i + 1) * draft_token_num,), + 
self.pad_slot_id, + dtype=torch.int32, + device=self.device, + ) + ) self.query_start_loc_list.append( torch.zeros((i + 2,), dtype=torch.int32, device=self.device) ) @@ -462,6 +483,8 @@ def _capture_metadata( forward_mode: ForwardMode, spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], ): + mamba_indices = self.req_to_token_pool.get_mamba_indices(req_pool_indices) + self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices) if forward_mode.is_decode_or_idle(): self.query_start_loc_list[bs - 1].copy_( self.cached_cuda_graph_decode_query_start_loc[: bs + 1] @@ -470,10 +493,14 @@ def _capture_metadata( self.query_start_loc_list[bs - 1].copy_( self.cached_cuda_graph_verify_query_start_loc[: bs + 1] ) + ssm_state_indices = torch.arange( + mamba_indices.shape[0] * spec_info.draft_token_num, dtype=torch.int32, device=mamba_indices.device + ) + self.state_indices_list_gdn[bs - 1][ + : len(mamba_indices) * spec_info.draft_token_num + ].copy_(ssm_state_indices) else: raise ValueError(f"Invalid forward mode: {forward_mode=}") - mamba_indices = self.req_to_token_pool.get_mamba_indices(req_pool_indices) - self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices) # If topk > 1, we need to use retrieve_next_token and retrieve_next_sibling to handle the eagle tree custom attention mask if forward_mode.is_target_verify() and spec_info.topk > 1: @@ -491,6 +518,7 @@ def _capture_metadata( return ForwardMetadata( query_start_loc=self.query_start_loc_list[bs - 1], mamba_cache_indices=self.state_indices_list[bs - 1], + mamba_cache_indices_gdn=self.state_indices_list_gdn[bs - 1], ) def _replay_metadata( @@ -507,7 +535,7 @@ def _replay_metadata( # Make sure forward metadata is correctly handled for padding reqs req_pool_indices[bs - num_padding :] = 0 mamba_indices = self.req_to_token_pool.get_mamba_indices(req_pool_indices) - mamba_indices[bs - num_padding :] = -1 + mamba_indices[bs - num_padding :] = 0 self.state_indices_list[bs - 1][: 
len(mamba_indices)].copy_(mamba_indices) if forward_mode.is_decode_or_idle(): if num_padding == 0: @@ -522,6 +550,16 @@ def _replay_metadata( bs - num_padding ) elif forward_mode.is_target_verify(): + ssm_state_indices = torch.arange( + len(mamba_indices[:bs - num_padding]) * spec_info.draft_token_num, + dtype=torch.int32, device=mamba_indices.device + ) + self.state_indices_list_gdn[bs - 1][ + : len(mamba_indices[: bs - num_padding]) * spec_info.draft_token_num + ].copy_(ssm_state_indices) + self.state_indices_list_gdn[bs - 1][ + len(mamba_indices[: bs - num_padding]) * spec_info.draft_token_num : + ] = 0 if num_padding == 0: self.query_start_loc_list[bs - 1].copy_( self.cached_cuda_graph_verify_query_start_loc[: bs + 1] @@ -556,10 +594,11 @@ def _replay_metadata( return ForwardMetadata( query_start_loc=self.query_start_loc_list[bs - 1], mamba_cache_indices=self.state_indices_list[bs - 1], + mamba_cache_indices_gdn=self.state_indices_list_gdn[bs - 1], ) def get_cuda_graph_seq_len_fill_value(self): - return 1 # Mamba attn does not use seq lens to index kv cache + return 0 # Mamba attn does not use seq lens to index kv cache def get_cpu_graph_seq_len_fill_value(self): return 1 @@ -613,10 +652,7 @@ def _track_mamba_state_extend( Note: Conv state tracking for extend is handled separately via gather operations using indices computed by `_init_track_conv_indices`. 
""" - if ( - forward_batch.mamba_track_mask is not None - and forward_batch.mamba_track_mask.any() - ): + if forward_metadata.has_mamba_track_mask: h = h.squeeze(0) if forward_metadata.track_ssm_h_src.numel() > 0: @@ -960,6 +996,26 @@ def update_mamba_state_after_mtp_verify( ssm_states = mamba_caches.temporal intermediate_state_cache = mamba_caches.intermediate_ssm intermediate_conv_window_cache = mamba_caches.intermediate_conv_window[0] + if is_npu(): + dst_indices_tensor = state_indices_tensor.to(torch.int64) # [N] + src_indices_tensor = torch.arange(dst_indices_tensor.shape[0], + device=dst_indices_tensor.device, + dtype=torch.int64) + last_steps = accepted_steps.to(torch.int64) # [N] + + move_intermediate_cache( + ssm_states, intermediate_state_cache, dst_indices_tensor, src_indices_tensor, last_steps + ) + + draft_token_num = intermediate_state_cache.shape[2] + if dst_indices_tensor.numel() > 0: + conv_state_rollback( + conv_states, + dst_indices_tensor, + last_steps, + draft_token_num, + ) + return # Use fully fused kernel that handles masking internally # This avoids separate nonzero() and index_select() calls @@ -992,3 +1048,8 @@ def update_mamba_state_after_mtp_verify( mamba_track_indices, mamba_steps_to_track, ) + + def update_verify_buffers_to_fill_after_draft( + self, spec_info: SpecInput, cuda_graph_bs: Optional[int] + ): + pass diff --git a/python/sglang/srt/layers/attention/linear/gdn_backend.py b/python/sglang/srt/layers/attention/linear/gdn_backend.py index 700ccfdf6aa3..4dad415b1c31 100644 --- a/python/sglang/srt/layers/attention/linear/gdn_backend.py +++ b/python/sglang/srt/layers/attention/linear/gdn_backend.py @@ -256,6 +256,18 @@ def __init__(self, model_runner: ModelRunner): prefill_backend = get_linear_attn_prefill_backend() self.kernel_dispatcher = GDNKernelDispatcher(decode_backend, prefill_backend) + def init_forward_metadata(self, forward_batch: ForwardBatch): + super().init_forward_metadata(forward_batch) + if 
self.forward_metadata.has_mamba_track_mask: + self.forward_metadata.mamba_track_mask_indices = ( + forward_batch.mamba_track_mask.nonzero(as_tuple=True)[0] + ) + self.forward_metadata.conv_states_mask_indices = ( + forward_batch.mamba_track_indices[ + self.forward_metadata.mamba_track_mask_indices + ] + ) + def forward_decode( self, layer: RadixLinearAttention, @@ -394,16 +406,13 @@ def forward_extend( mixed_qkv = mixed_qkv_processed.transpose(1, 2).view(seq_len, -1) else: mixed_qkv = mixed_qkv.transpose(0, 1) - if ( - forward_batch.mamba_track_mask is not None - and forward_batch.mamba_track_mask.any() - ): - conv_dst = forward_batch.mamba_track_indices + if forward_metadata.has_mamba_track_mask: mixed_qkv_to_track = mixed_qkv[ :, forward_metadata.track_conv_indices ].transpose(0, 1) - mask_indices = forward_batch.mamba_track_mask.nonzero(as_tuple=True)[0] - conv_states[conv_dst[mask_indices]] = mixed_qkv_to_track + conv_states[forward_metadata.conv_states_mask_indices] = ( + mixed_qkv_to_track + ) mixed_qkv = causal_conv1d_fn( mixed_qkv, diff --git a/python/sglang/srt/layers/attention/mamba/mamba2_metadata.py b/python/sglang/srt/layers/attention/mamba/mamba2_metadata.py index 5eeb2b65e307..35d8abaa826c 100644 --- a/python/sglang/srt/layers/attention/mamba/mamba2_metadata.py +++ b/python/sglang/srt/layers/attention/mamba/mamba2_metadata.py @@ -27,6 +27,7 @@ class ForwardMetadata: query_start_loc: torch.Tensor mamba_cache_indices: torch.Tensor + mamba_cache_indices_gdn: Optional[torch.Tensor] = None # For topk > 1 eagle retrieve_next_token: Optional[torch.Tensor] = None retrieve_next_sibling: Optional[torch.Tensor] = None @@ -41,6 +42,10 @@ class ForwardMetadata: is_target_verify: bool = False draft_token_num: int = 1 + has_mamba_track_mask: bool = False + mamba_track_mask_indices: Optional[torch.Tensor] = None + conv_states_mask_indices: Optional[torch.Tensor] = None + @dataclass(kw_only=True) class Mamba2Metadata(ForwardMetadata): diff --git 
a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py index 02ef4e2440cd..6bfcb3f66852 100644 --- a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py +++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py @@ -257,13 +257,13 @@ def _weights_proj_bf16_in_fp32_out(self, x: torch.Tensor) -> torch.Tensor: weights, _ = self.weights_proj(x) return weights.float() - @torch.compile(dynamic=True) if not _is_hip else lambda f: f + @torch.compile(dynamic=True) def _project_and_scale_head_gates(self, x: torch.Tensor): weights = self._weights_proj_bf16_in_fp32_out(x) weights = weights * self.n_heads**-0.5 return weights - @torch.compile(dynamic=True) if not _is_hip else lambda f: f + @torch.compile(dynamic=True) def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor): weights = self._weights_proj_bf16_in_fp32_out(x) weights = weights * self.n_heads**-0.5 @@ -318,8 +318,8 @@ def _get_q_k_bf16( q_rope, k_rope = self.rotary_emb(positions, q_rope, k_rope) - query[..., : self.rope_head_dim] = q_rope.clone() - key[..., : self.rope_head_dim] = k_rope.clone() + self._update_rope_guarded(query[..., : self.rope_head_dim], q_rope) + self._update_rope_guarded(key[..., : self.rope_head_dim], k_rope) if enable_dual_stream: current_stream = torch.cuda.current_stream() @@ -376,11 +376,19 @@ def _get_k_bf16( ) _, k_rope = self.rotary_emb(positions, k_rope, k_rope) - key[..., : self.rope_head_dim] = k_rope.clone() + self._update_rope_guarded(key[..., : self.rope_head_dim], k_rope) key = rotate_activation(key) return key + @staticmethod + def _update_rope_guarded(dst: torch.Tensor, src: torch.Tensor) -> None: + # On AMD with in-place RoPE kernels, self-aliasing can occur; + # skip write-back when src/dst tensors point to a single memory. 
+ if src.data_ptr() == dst.data_ptr(): + return + dst.copy_(src) + def _get_topk_paged( self, forward_batch: ForwardBatch, diff --git a/python/sglang/srt/layers/attention/nsa_backend.py b/python/sglang/srt/layers/attention/nsa_backend.py index 862488e5f918..314c897ab313 100644 --- a/python/sglang/srt/layers/attention/nsa_backend.py +++ b/python/sglang/srt/layers/attention/nsa_backend.py @@ -61,7 +61,10 @@ "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device." ) else: - from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache + from sglang.jit_kernel.flash_attention import ( + flash_attn_varlen_func, + flash_attn_with_kvcache, + ) # Reuse this workspace buffer across all NSA backend instances diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index 23dba24584e9..a624ad06e022 100644 --- a/python/sglang/srt/layers/attention/vision.py +++ b/python/sglang/srt/layers/attention/vision.py @@ -38,21 +38,9 @@ if _is_cuda: from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache - try: - from sgl_kernel.flash_attn import flash_attn_varlen_func - - def flash_attn_func(*args, ver: int = 3, **kwargs): - if ver == 4: - from sglang.jit_kernel.flash_attention_v4 import ( - flash_attn_varlen_func as flash_attn_varlen_func_fa4, - ) - - return flash_attn_varlen_func_fa4(*args, **kwargs) - return flash_attn_varlen_func(*args, **kwargs) - - except ImportError as e: - raise e - + from sglang.jit_kernel.flash_attention import ( + flash_attn_varlen_func, + ) if _is_npu: import torch_npu @@ -420,7 +408,7 @@ def forward( """ if envs.SGLANG_VIT_ENABLE_CUDA_GRAPH.get(): max_seqlen = cu_seqlens[1] - output = flash_attn_func( + output = flash_attn_varlen_func( q, k, v, @@ -436,7 +424,7 @@ def forward( seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] max_seqlen = seq_lens.max().item() - output = flash_attn_func( + output = flash_attn_varlen_func( q, k, v, @@ -489,7 +477,7 @@ 
def forward( seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] max_seqlen = seq_lens.max().item() - output = flash_attn_func( + output = flash_attn_varlen_func( q, k, v, diff --git a/python/sglang/srt/layers/attention/xpu_backend.py b/python/sglang/srt/layers/attention/xpu_backend.py index 4a40d25ee8c9..77e773d88d0c 100644 --- a/python/sglang/srt/layers/attention/xpu_backend.py +++ b/python/sglang/srt/layers/attention/xpu_backend.py @@ -20,7 +20,11 @@ from sglang.srt.model_executor.model_runner import ModelRunner from sgl_kernel import merge_state_v2 -from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache + +from sglang.jit_kernel.flash_attention import ( + flash_attn_varlen_func, + flash_attn_with_kvcache, +) class XPUAttentionBackend(AttentionBackend): diff --git a/python/sglang/srt/layers/fused_sampling.py b/python/sglang/srt/layers/fused_sampling.py new file mode 100644 index 000000000000..55852c2f0f34 --- /dev/null +++ b/python/sglang/srt/layers/fused_sampling.py @@ -0,0 +1,371 @@ +"""Fused Triton kernels for the sampling pipeline. + +Fuses temperature scaling + softmax into a single kernel to reduce +kernel launch overhead and global memory traffic during decode. + +Two kernel variants: + - Single-pass: vocab fits in one tile (1 read + 1 write). Used when + next_power_of_2(vocab) <= 32768. + - Multi-pass: 3-pass softmax (max, sum, normalize) with autotune (3 reads + 1 write). + Used for large vocabs (e.g. 128K+). +""" + +import logging + +import torch +import triton +import triton.language as tl + +logger = logging.getLogger(__name__) + +_MAX_SINGLE_PASS_BLOCK = 32768 + +# --------------------------------------------------------------------------- +# Single-pass kernel: entire vocab fits in one BLOCK_SIZE tile. +# Data stays in registers — only 1 global memory read + 1 write.
+# --------------------------------------------------------------------------- + + +@triton.jit +def _single_pass_temperature_softmax_kernel( + logits_ptr, + temperatures_ptr, + output_ptr, + vocab_size, + logits_stride, + output_stride, + BLOCK_SIZE: tl.constexpr, +): + row_idx = tl.program_id(0) + temp = tl.load(temperatures_ptr + row_idx) + + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < vocab_size + + x = tl.load( + logits_ptr + row_idx * logits_stride + offsets, + mask=mask, + other=float("-inf"), + ) + x = (x / temp).to(tl.float32) + + x_max = tl.max(x, axis=0) + exp_x = tl.exp(x - x_max) + prob = exp_x / tl.sum(exp_x, axis=0) + + tl.store(output_ptr + row_idx * output_stride + offsets, prob, mask=mask) + + +@triton.jit +def _single_pass_temperature_softmax_inplace_kernel( + logits_ptr, + temperatures_ptr, + vocab_size, + stride, + BLOCK_SIZE: tl.constexpr, +): + row_idx = tl.program_id(0) + temp = tl.load(temperatures_ptr + row_idx) + + row_start = logits_ptr + row_idx * stride + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < vocab_size + + x = tl.load(row_start + offsets, mask=mask, other=float("-inf")) + x = (x / temp).to(tl.float32) + + x_max = tl.max(x, axis=0) + exp_x = tl.exp(x - x_max) + prob = exp_x / tl.sum(exp_x, axis=0) + + tl.store(row_start + offsets, prob, mask=mask) + + +# --------------------------------------------------------------------------- +# Multi-pass kernel: vocab too large for one tile. +# 3-pass softmax (max, sum, normalize) with autotune over (BLOCK_SIZE, num_warps).
+# --------------------------------------------------------------------------- + +_MULTI_PASS_AUTOTUNE_CONFIGS = [ + triton.Config({"BLOCK_SIZE": 2048}, num_warps=8), + triton.Config({"BLOCK_SIZE": 2048}, num_warps=16), + triton.Config({"BLOCK_SIZE": 4096}, num_warps=8), + triton.Config({"BLOCK_SIZE": 4096}, num_warps=16), + triton.Config({"BLOCK_SIZE": 4096}, num_warps=16, num_stages=4), + triton.Config({"BLOCK_SIZE": 8192}, num_warps=16), + triton.Config({"BLOCK_SIZE": 8192}, num_warps=32), + triton.Config({"BLOCK_SIZE": 8192}, num_warps=32, num_stages=4), + triton.Config({"BLOCK_SIZE": 16384}, num_warps=16), + triton.Config({"BLOCK_SIZE": 16384}, num_warps=32), + triton.Config({"BLOCK_SIZE": 16384}, num_warps=32, num_stages=4), + triton.Config({"BLOCK_SIZE": 32768}, num_warps=32), + triton.Config({"BLOCK_SIZE": 32768}, num_warps=32, num_stages=4), +] + + +@triton.autotune(configs=_MULTI_PASS_AUTOTUNE_CONFIGS, key=["vocab_size"]) +@triton.jit +def _multi_pass_temperature_softmax_kernel( + logits_ptr, + temperatures_ptr, + output_ptr, + vocab_size, + logits_stride, + output_stride, + BLOCK_SIZE: tl.constexpr, +): + row_idx = tl.program_id(0) + temp = tl.load(temperatures_ptr + row_idx) + + logits_row = logits_ptr + row_idx * logits_stride + output_row = output_ptr + row_idx * output_stride + + # Pass 1: find global max (matches PyTorch's first reduction pass) + global_max = tl.full([], value=float("-inf"), dtype=tl.float32) + for start in range(0, vocab_size, BLOCK_SIZE): + offsets = start + tl.arange(0, BLOCK_SIZE) + mask = offsets < vocab_size + x = tl.load(logits_row + offsets, mask=mask, other=float("-inf")) + x = (x / temp).to(tl.float32) + global_max = tl.maximum(global_max, tl.max(x, axis=0)) + + # Pass 2: compute sum of exp(x - max) (matches PyTorch's second pass) + sum_exp = tl.full([], value=0.0, dtype=tl.float32) + for start in range(0, vocab_size, BLOCK_SIZE): + offsets = start + tl.arange(0, BLOCK_SIZE) + mask = offsets < vocab_size + x = 
tl.load(logits_row + offsets, mask=mask, other=float("-inf")) + x = (x / temp).to(tl.float32) + sum_exp += tl.sum(tl.exp(x - global_max), axis=0) + + # Pass 3: normalize (matches PyTorch's exp(x-max)/sum) + for start in range(0, vocab_size, BLOCK_SIZE): + offsets = start + tl.arange(0, BLOCK_SIZE) + mask = offsets < vocab_size + x = tl.load(logits_row + offsets, mask=mask, other=float("-inf")) + x = (x / temp).to(tl.float32) + + prob = tl.exp(x - global_max) / sum_exp + tl.store(output_row + offsets, prob, mask=mask) + + +@triton.jit +def _multi_pass_temperature_softmax_inplace_kernel( + logits_ptr, + temperatures_ptr, + vocab_size, + stride, + BLOCK_SIZE: tl.constexpr, +): + row_idx = tl.program_id(0) + temp = tl.load(temperatures_ptr + row_idx) + + row_start = logits_ptr + row_idx * stride + + # Pass 1: find global max (matches PyTorch's first reduction pass) + global_max = tl.full([], value=float("-inf"), dtype=tl.float32) + for start in range(0, vocab_size, BLOCK_SIZE): + offsets = start + tl.arange(0, BLOCK_SIZE) + mask = offsets < vocab_size + x = tl.load(row_start + offsets, mask=mask, other=float("-inf")) + x = (x / temp).to(tl.float32) + global_max = tl.maximum(global_max, tl.max(x, axis=0)) + + # Pass 2: compute sum of exp(x - max) (matches PyTorch's second pass) + sum_exp = tl.full([], value=0.0, dtype=tl.float32) + for start in range(0, vocab_size, BLOCK_SIZE): + offsets = start + tl.arange(0, BLOCK_SIZE) + mask = offsets < vocab_size + x = tl.load(row_start + offsets, mask=mask, other=float("-inf")) + x = (x / temp).to(tl.float32) + sum_exp += tl.sum(tl.exp(x - global_max), axis=0) + + # Pass 3: normalize (matches PyTorch's exp(x-max)/sum) + for start in range(0, vocab_size, BLOCK_SIZE): + offsets = start + tl.arange(0, BLOCK_SIZE) + mask = offsets < vocab_size + x = tl.load(row_start + offsets, mask=mask, other=float("-inf")) + x = (x / temp).to(tl.float32) + + prob = tl.exp(x - global_max) / sum_exp + tl.store(row_start + offsets, prob, mask=mask) + 
+ +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + +_DEFAULT_MULTI_PASS_CONFIG = {"BLOCK_SIZE": 4096, "num_warps": 16} + +# Populated by warmup from the out-of-place kernel's autotune result. +_multi_pass_inplace_config: dict | None = None + + +def _single_pass_num_warps(block_size: int) -> int: + return max(4, min(32, block_size // 256)) + + +def _get_multi_pass_inplace_config() -> dict: + """Return the launch config for the multi-pass in-place kernel.""" + if _multi_pass_inplace_config is not None: + return _multi_pass_inplace_config + return _DEFAULT_MULTI_PASS_CONFIG + + +def _dispatch_kernel( + logits: torch.Tensor, + temperatures_flat: torch.Tensor, + vocab_size: int, + batch_size: int, + output: torch.Tensor = None, +) -> None: + """Dispatch to single-pass or multi-pass kernel. output=None means in-place.""" + grid = (batch_size,) + block_size = triton.next_power_of_2(vocab_size) + inplace = output is None + + if block_size <= _MAX_SINGLE_PASS_BLOCK: + if inplace: + _single_pass_temperature_softmax_inplace_kernel[grid]( + logits, + temperatures_flat, + vocab_size, + logits.stride(0), + BLOCK_SIZE=block_size, + num_warps=_single_pass_num_warps(block_size), + ) + else: + _single_pass_temperature_softmax_kernel[grid]( + logits, + temperatures_flat, + output, + vocab_size, + logits.stride(0), + output.stride(0), + BLOCK_SIZE=block_size, + num_warps=_single_pass_num_warps(block_size), + ) + else: + if inplace: + cfg = _get_multi_pass_inplace_config() + _multi_pass_temperature_softmax_inplace_kernel[grid]( + logits, + temperatures_flat, + vocab_size, + logits.stride(0), + **cfg, + ) + else: + _multi_pass_temperature_softmax_kernel[grid]( + logits, + temperatures_flat, + output, + vocab_size, + logits.stride(0), + output.stride(0), + ) + + +def fused_temperature_softmax( + logits: torch.Tensor, + temperatures: torch.Tensor, +) -> 
torch.Tensor: + """Fused temperature scaling + softmax. Returns float32 probabilities.""" + batch_size, vocab_size = logits.shape + if batch_size == 0: + return torch.empty(0, vocab_size, dtype=torch.float32, device=logits.device) + + if not logits.is_contiguous(): + logits = logits.contiguous() + + output = torch.empty( + batch_size, vocab_size, dtype=torch.float32, device=logits.device + ) + temperatures_flat = temperatures.contiguous().view(-1) + _dispatch_kernel(logits, temperatures_flat, vocab_size, batch_size, output) + return output + + +def fused_temperature_softmax_inplace( + logits: torch.Tensor, + temperatures: torch.Tensor, +) -> None: + """In-place fused temperature scaling + softmax. Overwrites logits with probabilities.""" + batch_size, vocab_size = logits.shape + if batch_size == 0: + return + + if not logits.is_contiguous(): + work = logits.contiguous() + fused_temperature_softmax_inplace(work, temperatures) + logits.copy_(work) + return + + temperatures_flat = temperatures.contiguous().view(-1) + _dispatch_kernel(logits, temperatures_flat, vocab_size, batch_size) + + +def warmup_fused_temperature_softmax( + vocab_size: int, + device: torch.device | int | None = None, + logits_dtype: torch.dtype = torch.float32, +) -> None: + """Pre-compile and autotune kernels at startup so first request has no latency spike. + + For multi-pass kernels the out-of-place variant is autotuned (safe — separate + input/output buffers), and its winning config is reused for the in-place + variant so that no autotune ever runs on a live logits buffer. + + ``logits_dtype`` should match ``next_token_logits`` at inference (usually + ``model_config.dtype``) so Triton specializes the same way as in production. 
+ """ + global _multi_pass_inplace_config + + if device is None: + device = torch.cuda.current_device() + + block_size = triton.next_power_of_2(vocab_size) + is_multi_pass = block_size > _MAX_SINGLE_PASS_BLOCK + label = "multi-pass autotune" if is_multi_pass else "single-pass JIT" + logger.info( + "Warming up fused_temperature_softmax (%s, vocab_size=%d, logits_dtype=%s) ...", + label, + vocab_size, + logits_dtype, + ) + + dummy_logits = torch.randn(1, vocab_size, dtype=logits_dtype, device=device) + dummy_temps = torch.ones(1, 1, dtype=torch.float32, device=device) + + # 1. Out-of-place kernel: autotune runs here (safe, separate buffers). + fused_temperature_softmax(dummy_logits, dummy_temps) + + # 2. Propagate best config to the in-place kernel (no autotune needed). + if is_multi_pass: + best = getattr(_multi_pass_temperature_softmax_kernel, "best_config", None) + if best is not None: + _multi_pass_inplace_config = { + "BLOCK_SIZE": best.kwargs["BLOCK_SIZE"], + "num_warps": best.num_warps, + } + if best.num_stages is not None: + _multi_pass_inplace_config["num_stages"] = best.num_stages + ns = _multi_pass_inplace_config.get("num_stages", "default") + logger.info( + "Multi-pass autotune result: BLOCK_SIZE=%d, num_warps=%d, num_stages=%s", + _multi_pass_inplace_config["BLOCK_SIZE"], + _multi_pass_inplace_config["num_warps"], + ns, + ) + else: + _multi_pass_inplace_config = None + logger.warning( + "Multi-pass fused softmax: autotune did not set best_config; " + "using default launch config for in-place kernel." + ) + + # 3. In-place kernel: JIT compile only (uses the config from step 2). 
+ fused_temperature_softmax_inplace(dummy_logits.clone(), dummy_temps) + torch.cuda.synchronize(device) + + logger.info("fused_temperature_softmax warmup done (vocab_size=%d).", vocab_size) diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 0db6675e648f..db831f5805bf 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -84,6 +84,7 @@ if _is_npu: import torch_npu + from sgl_kernel_npu.norm.add_rmsnorm_bias import add_gemma_rms_norm def _forward_with_allreduce_fusion( @@ -580,11 +581,13 @@ def forward_npu( if residual is not None: if post_residual_addition is not None: residual = residual + post_residual_addition - x = x + residual - residual = x + norm_out, residual = add_gemma_rms_norm( + x, self.weight, residual, self.variance_epsilon + ) + return norm_out, residual x, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.variance_epsilon) - return x if residual is None else (x, residual) + return x def forward_xpu( self, diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index 7af9eb004008..ff959ada9a65 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -10,7 +10,6 @@ from torch import nn from torch.nn.parameter import Parameter, UninitializedParameter -from sglang.kernel_api_logging import wrap_method_with_debug_kernel_once from sglang.srt.distributed import ( divide, get_tensor_model_parallel_rank, @@ -177,13 +176,6 @@ def __init__( else: self.quant_method = quant_config.get_quant_method(self, prefix=prefix) - if self.quant_method is not None: - wrap_method_with_debug_kernel_once( - self.quant_method, - "apply", - op_name=f"sglang.quant_method.{self.quant_method.__class__.__name__}.apply", - ) - def forward(self, x: torch.Tensor) -> torch.Tensor: raise NotImplementedError @@ -539,15 +531,8 @@ def weight_loader( self, param: Parameter, loaded_weight: torch.Tensor, - loaded_shard_id: tuple[int, ...] 
| int | None = None, + loaded_shard_id: Optional[int] = None, ): - if isinstance(loaded_shard_id, tuple): - if hasattr(param, "load_merged_column_weight"): - return self.weight_loader_v2(param, loaded_weight, loaded_shard_id) - raise NotImplementedError( - "Shard id with multiple indices is not supported in weight_loader, " - "please use weight_loader_v2 instead." - ) # Special case for GGUF # initialize GGUF param after we know the quantize type @@ -714,10 +699,7 @@ def weight_loader( param_data.copy_(loaded_weight) def _load_fused_module_from_checkpoint( - self, - param: BasevLLMParameter, - loaded_weight: torch.Tensor, - output_sizes: list[int] | None = None, + self, param: BasevLLMParameter, loaded_weight: torch.Tensor ): """ Handle special case for models where MLP layers are already @@ -731,8 +713,7 @@ def _load_fused_module_from_checkpoint( current_shard_offset = 0 shard_offsets: List[Tuple[int, int, int]] = [] - output_sizes = output_sizes or self.output_sizes - for i, output_size in enumerate(output_sizes): + for i, output_size in enumerate(self.output_sizes): shard_offsets.append((i, current_shard_offset, output_size)) current_shard_offset += output_size @@ -802,9 +783,9 @@ def weight_loader_v2( self, param: BasevLLMParameter, loaded_weight: torch.Tensor, - loaded_shard_id: tuple[int, ...] 
| int | None = None, + loaded_shard_id: Optional[int] = None, ): - if loaded_shard_id is None or isinstance(loaded_shard_id, tuple): + if loaded_shard_id is None: if isinstance(param, PerTensorScaleParameter): param.load_merged_column_weight( loaded_weight=loaded_weight, @@ -823,15 +804,8 @@ def weight_loader_v2( tp_size=self.tp_size, ) return - output_sizes = ( - [self.output_sizes[idx] for idx in loaded_shard_id] - if loaded_shard_id - else None - ) # TODO: @dsikka - move to parameter.py - self._load_fused_module_from_checkpoint( - param, loaded_weight, output_sizes=output_sizes - ) + self._load_fused_module_from_checkpoint(param, loaded_weight) return assert loaded_shard_id < len(self.output_sizes) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py index 4410f07f327e..b1bb618ce5ce 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py @@ -23,6 +23,13 @@ def get_scalar_type(num_bits: int, has_zp: bool): return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128 +def _get_fp4_scalar_type(): + from sglang.srt.layers.quantization.utils import get_scalar_types + + _, scalar_types = get_scalar_types() + return scalar_types.float4_e2m1f + + @register_custom_op(out_shape="hidden_states") def fused_marlin_moe( hidden_states: torch.Tensor, @@ -46,6 +53,8 @@ def fused_marlin_moe( is_k_full: bool = True, inplace: bool = False, routed_scaling_factor: Optional[float] = None, + w1_global_scale: Optional[torch.Tensor] = None, + w2_global_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -76,6 +85,13 @@ def fused_marlin_moe( """ from sglang.srt.layers.moe.fused_moe_triton import moe_align_block_size + # Detect FP4 Marlin mode (when global scales are provided) + _is_fp4_marlin = 
w1_global_scale is not None + if _is_fp4_marlin: + assert ( + w2_global_scale is not None + ), "Both w1_global_scale and w2_global_scale must be provided for FP4 Marlin mode" + assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" assert hidden_states.shape[1] == w1.shape[1] * 16, "Hidden size mismatch w1" assert hidden_states.shape[1] == w2.shape[2] // ( @@ -85,12 +101,14 @@ def fused_marlin_moe( assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous" assert hidden_states.dtype in [torch.float16, torch.bfloat16] - assert ( - hidden_states.dtype == w1_scale.dtype - ), f"moe_wna16_marlin_gemm assumes hidden_states.dtype ({hidden_states.dtype}) == w1_scale.dtype ({w1_scale.dtype})" - assert ( - hidden_states.dtype == w2_scale.dtype - ), f"moe_wna16_marlin_gemm assumes hidden_states.dtype ({hidden_states.dtype}) == w2_scale.dtype ({w2_scale.dtype})" + # For FP4 Marlin, scales are in special float8_e4m3fn format (not input dtype) + if not _is_fp4_marlin: + assert ( + hidden_states.dtype == w1_scale.dtype + ), f"moe_wna16_marlin_gemm assumes hidden_states.dtype ({hidden_states.dtype}) == w1_scale.dtype ({w1_scale.dtype})" + assert ( + hidden_states.dtype == w2_scale.dtype + ), f"moe_wna16_marlin_gemm assumes hidden_states.dtype ({hidden_states.dtype}) == w2_scale.dtype ({w2_scale.dtype})" assert num_bits in [4, 8] M, K = hidden_states.shape @@ -121,8 +139,13 @@ def fused_marlin_moe( max_workspace_size, dtype=torch.int, device=device, requires_grad=False ) - scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None) - scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None) + # FP4 Marlin uses float4_e2m1f scalar type (not uint4b8/uint8b128) + if _is_fp4_marlin: + scalar_type1 = _get_fp4_scalar_type() + scalar_type2 = _get_fp4_scalar_type() + else: + scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None) + scalar_type2 = get_scalar_type(num_bits, 
w2_zeros is not None) intermediate_cache2 = torch.empty( (M * topk_ids.shape[1], N), @@ -150,7 +173,7 @@ def fused_marlin_moe( w1, None, # b_bias_or_none w1_scale, - None, # global_scale_or_none + w1_global_scale, # None for INT4/INT8, tensor for FP4 Marlin w1_zeros, g_idx1, sort_indices1, @@ -184,7 +207,7 @@ def fused_marlin_moe( w2, None, # b_bias_or_none w2_scale, - None, # global_scale_or_none + w2_global_scale, # None for INT4/INT8, tensor for FP4 Marlin w2_zeros, g_idx2, sort_indices2, diff --git a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py index e568579bb7e8..f4add35b391d 100644 --- a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py +++ b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py @@ -616,13 +616,17 @@ def fused_experts_none_to_flashinfer_trtllm_fp4( dispatch_output: StandardDispatchOutput, quant_info: FlashInferTrtllmFp4MoeQuantInfo, runner_config: MoeRunnerConfig, + use_routed_topk: bool = False, ) -> StandardCombineInput: """FlashInfer TRTLLM FP4 MoE forward pass. This function handles the FP4 TRTLLM MoE path that was previously in ModelOptNvFp4FusedMoEMethod.apply. 
""" - from flashinfer.fused_moe import trtllm_fp4_block_scale_moe + from flashinfer.fused_moe import ( + trtllm_fp4_block_scale_moe, + trtllm_fp4_block_scale_routed_moe, + ) from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput from sglang.srt.layers.moe.topk import TopKOutputChecker @@ -633,25 +637,13 @@ def fused_experts_none_to_flashinfer_trtllm_fp4( hidden_states = dispatch_output.hidden_states topk_output = dispatch_output.topk_output - assert TopKOutputChecker.format_is_bypassed(topk_output) - - router_logits = topk_output.router_logits - topk_config = topk_output.topk_config - routing_method_type = quant_info.routing_method_type # Quantize hidden states to FP4 hs_fp4, hs_scale_linear = quantize_hidden_states_fp4( hidden_states, quant_info.w13_input_scale_quant ) - - # DeepSeekV3 style routing requires float32 router logits - if routing_method_type == RoutingMethodType.DeepSeekV3: - router_logits = router_logits.to(torch.float32) - - correction_bias = ( - None - if topk_config.correction_bias is None - else topk_config.correction_bias.to(hidden_states.dtype) + hs_scale = hs_scale_linear.view(torch.float8_e4m3fn).reshape( + *hs_scale_linear.shape[:-1], -1 ) with use_symmetric_memory(get_tp_group(), disabled=not is_allocation_symmetric()): @@ -660,49 +652,103 @@ def fused_experts_none_to_flashinfer_trtllm_fp4( hs_fp4.shape[-1] * 2 if hs_fp4.dtype == torch.uint8 else hs_fp4.shape[-1] ) symm_output = torch.empty( - num_tokens, hidden_size, dtype=torch.bfloat16, device=hs_fp4.device + num_tokens, hidden_size, dtype=hidden_states.dtype, device=hs_fp4.device ) - result = trtllm_fp4_block_scale_moe( - routing_logits=router_logits, - routing_bias=correction_bias, - hidden_states=hs_fp4, - hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).reshape( - *hs_scale_linear.shape[:-1], -1 - ), - gemm1_weights=quant_info.gemm1_weights_fp4_shuffled, - gemm1_weights_scale=quant_info.gemm1_scales_fp4_shuffled.view( - torch.float8_e4m3fn - ), 
- gemm1_bias=None, - gemm1_alpha=None, - gemm1_beta=None, - gemm1_clamp_limit=None, - gemm2_weights=quant_info.gemm2_weights_fp4_shuffled, - gemm2_weights_scale=quant_info.gemm2_scales_fp4_shuffled.view( - torch.float8_e4m3fn - ), - gemm2_bias=None, - output1_scale_scalar=quant_info.g1_scale_c, - output1_scale_gate_scalar=quant_info.g1_alphas, - output2_scale_scalar=quant_info.g2_alphas, - num_experts=quant_info.global_num_experts, - top_k=topk_config.top_k, - n_group=topk_config.num_expert_group, - topk_group=topk_config.topk_group, - intermediate_size=quant_info.intermediate_size_per_partition, - local_expert_offset=quant_info.local_expert_offset, - local_num_experts=quant_info.local_num_experts, - routed_scaling_factor=runner_config.routed_scaling_factor, - routing_method_type=( - routing_method_type - if routing_method_type is not None - else RoutingMethodType.Default - ), - do_finalize=True, - tune_max_num_tokens=next_power_of_2(hs_fp4.shape[0]), - output=symm_output, - )[0] + if use_routed_topk: + assert TopKOutputChecker.format_is_standard(topk_output) + + packed_topk_ids = _pack_topk_for_flashinfer_routed( + topk_output.topk_ids, topk_output.topk_weights + ) + result = trtllm_fp4_block_scale_routed_moe( + topk_ids=packed_topk_ids, + routing_bias=None, + hidden_states=hs_fp4, + hidden_states_scale=hs_scale, + gemm1_weights=quant_info.gemm1_weights_fp4_shuffled, + gemm1_weights_scale=quant_info.gemm1_scales_fp4_shuffled.view( + torch.float8_e4m3fn + ), + gemm1_bias=None, + gemm1_alpha=None, + gemm1_beta=None, + gemm1_clamp_limit=None, + gemm2_weights=quant_info.gemm2_weights_fp4_shuffled, + gemm2_weights_scale=quant_info.gemm2_scales_fp4_shuffled.view( + torch.float8_e4m3fn + ), + gemm2_bias=None, + output1_scale_scalar=quant_info.g1_scale_c, + output1_scale_gate_scalar=quant_info.g1_alphas, + output2_scale_scalar=quant_info.g2_alphas, + num_experts=quant_info.global_num_experts, + top_k=topk_output.topk_ids.shape[1], + n_group=0, + topk_group=0, + 
intermediate_size=quant_info.intermediate_size_per_partition, + local_expert_offset=quant_info.local_expert_offset, + local_num_experts=quant_info.local_num_experts, + routed_scaling_factor=None, + routing_method_type=1, # Unused, but must be 1 to pass validation. + do_finalize=True, + tune_max_num_tokens=next_power_of_2(hs_fp4.shape[0]), + output=symm_output, + )[0] + else: + assert TopKOutputChecker.format_is_bypassed(topk_output) + + router_logits = topk_output.router_logits + topk_config = topk_output.topk_config + routing_method_type = quant_info.routing_method_type + + # DeepSeekV3 style routing requires float32 router logits + if routing_method_type == RoutingMethodType.DeepSeekV3: + router_logits = router_logits.to(torch.float32) + + correction_bias = ( + None + if topk_config.correction_bias is None + else topk_config.correction_bias.to(hidden_states.dtype) + ) + result = trtllm_fp4_block_scale_moe( + routing_logits=router_logits, + routing_bias=correction_bias, + hidden_states=hs_fp4, + hidden_states_scale=hs_scale, + gemm1_weights=quant_info.gemm1_weights_fp4_shuffled, + gemm1_weights_scale=quant_info.gemm1_scales_fp4_shuffled.view( + torch.float8_e4m3fn + ), + gemm1_bias=None, + gemm1_alpha=None, + gemm1_beta=None, + gemm1_clamp_limit=None, + gemm2_weights=quant_info.gemm2_weights_fp4_shuffled, + gemm2_weights_scale=quant_info.gemm2_scales_fp4_shuffled.view( + torch.float8_e4m3fn + ), + gemm2_bias=None, + output1_scale_scalar=quant_info.g1_scale_c, + output1_scale_gate_scalar=quant_info.g1_alphas, + output2_scale_scalar=quant_info.g2_alphas, + num_experts=quant_info.global_num_experts, + top_k=topk_config.top_k, + n_group=topk_config.num_expert_group, + topk_group=topk_config.topk_group, + intermediate_size=quant_info.intermediate_size_per_partition, + local_expert_offset=quant_info.local_expert_offset, + local_num_experts=quant_info.local_num_experts, + routed_scaling_factor=runner_config.routed_scaling_factor, + routing_method_type=( + 
routing_method_type + if routing_method_type is not None + else RoutingMethodType.Default + ), + do_finalize=True, + tune_max_num_tokens=next_power_of_2(hs_fp4.shape[0]), + output=symm_output, + )[0] return StandardCombineInput(hidden_states=result) @@ -858,6 +904,13 @@ def fused_experts_none_to_flashinfer_trtllm_routed( quant_info: MoeQuantInfo, runner_config: MoeRunnerConfig, ) -> StandardCombineInput: + if isinstance(quant_info, FlashInferTrtllmFp4MoeQuantInfo): + return fused_experts_none_to_flashinfer_trtllm_fp4( + dispatch_output, + quant_info, + runner_config, + use_routed_topk=True, + ) if isinstance(quant_info, FlashInferTrtllmFp8MoeQuantInfo): return fused_experts_none_to_flashinfer_trtllm_fp8( dispatch_output, diff --git a/python/sglang/srt/layers/moe/moe_runner/marlin.py b/python/sglang/srt/layers/moe/moe_runner/marlin.py index 45104dd27805..429b28697d23 100644 --- a/python/sglang/srt/layers/moe/moe_runner/marlin.py +++ b/python/sglang/srt/layers/moe/moe_runner/marlin.py @@ -69,8 +69,13 @@ class MarlinMoeQuantInfo(MoeQuantInfo): w13_qzeros: Optional[torch.Tensor] = None w2_qzeros: Optional[torch.Tensor] = None - # Optional + # FP4 Marlin specific (Optional) + w13_global_scale: Optional[torch.Tensor] = None + w2_global_scale: Optional[torch.Tensor] = None + + # EP support (Optional) expert_map: Optional[torch.Tensor] = None + global_num_experts: int = -1 @register_fused_func("none", "marlin") @@ -106,6 +111,7 @@ def fused_experts_none_to_marlin( gating_output=topk_output.router_logits, topk_weights=topk_output.topk_weights, topk_ids=topk_output.topk_ids, + global_num_experts=quant_info.global_num_experts, expert_map=quant_info.expert_map, g_idx1=quant_info.w13_g_idx, g_idx2=quant_info.w2_g_idx, @@ -118,6 +124,8 @@ def fused_experts_none_to_marlin( is_k_full=quant_info.is_k_full, inplace=runner_config.inplace, routed_scaling_factor=runner_config.routed_scaling_factor, + w1_global_scale=quant_info.w13_global_scale, + 
w2_global_scale=quant_info.w2_global_scale, ).to(hidden_states.dtype) return StandardCombineInput( diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py index 477339a54fa6..872daa191b18 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -16,6 +16,10 @@ CompressedTensorsLinearScheme, ) from sglang.srt.layers.quantization.fp4_utils import get_fp4_gemm_runner_backend +from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + prepare_fp4_layer_for_marlin, + should_use_fp4_marlin_fallback, +) from sglang.srt.layers.quantization.modelopt_quant import ( enable_flashinfer_fp4_gemm, fp4_gemm, @@ -34,7 +38,7 @@ def __init__(self): @classmethod def get_min_capability(cls) -> int: - return 100 + return 75 # SM75+ (Turing) supports Marlin FP4 fallback; SM100 for native FP4 def create_weights( self, @@ -47,6 +51,7 @@ def create_weights( ): output_size_per_partition = sum(output_partition_sizes) layer.logical_widths = output_partition_sizes + layer.params_dtype = params_dtype layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition @@ -91,6 +96,20 @@ def create_weights( layer.register_parameter("input_global_scale", input_global_scale) def process_weights_after_loading(self, layer) -> None: + if should_use_fp4_marlin_fallback(): + # Marlin FP4 fallback: consolidate global scale then repack weights + global_scale = layer.weight_global_scale.max().to(torch.float32) + layer.weight_global_scale = Parameter(global_scale, requires_grad=False) + prepare_fp4_layer_for_marlin( + layer, + weight_attr="weight_packed", + weight_scale_attr="weight_scale", + weight_global_scale_attr="weight_global_scale", + ) 
+ layer.use_marlin_fallback = True + return + + layer.use_marlin_fallback = False global_input_scale = layer.input_global_scale.max().to(torch.float32) layer.input_global_scale = Parameter(global_input_scale, requires_grad=False) @@ -136,6 +155,18 @@ def apply_weights( x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if layer.use_marlin_fallback: + return torch.ops.sglang.apply_fp4_marlin_linear( + input=x, + weight=layer.weight_packed, + weight_scale=layer.weight_scale, + weight_global_scale=layer.weight_global_scale, + workspace=layer.marlin_workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias, + ) + output_dtype = x.dtype w_n, _ = layer.weight_packed.shape output_shape = [x.shape[0], w_n] diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4_moe.py index 5898a078dbba..7824a3bcce60 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4_moe.py @@ -17,6 +17,10 @@ CompressedTensorsMoEScheme, ) from sglang.srt.layers.quantization.fp8_utils import is_blackwell_supported +from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + prepare_moe_fp4_layer_for_marlin, + should_use_fp4_marlin_fallback, +) from sglang.srt.layers.quantization.utils import ( prepare_static_weights_for_trtllm_fp4_moe, reorder_w1w3_to_w3w1, @@ -38,19 +42,27 @@ class CompressedTensorsW4A4Nvfp4MoE(CompressedTensorsMoEScheme): def __init__(self): - if not is_blackwell_supported(): + self.group_size = 16 + + if should_use_fp4_marlin_fallback(): + logger.warning_once( + "GPU is not Blackwell (SM100+). Using Marlin FP4 fallback kernel " + "for MoE layers. Weights remain compressed in FP4 format." 
+ ) + self.use_marlin_fallback = True + self.use_flashinfer_trtllm = False + elif not is_blackwell_supported(): raise ValueError( "Current platform does not support NVFP4" - " quantization. Please use Blackwell and" - " above." + " quantization. Please use SM75+ (Turing or newer)." ) - self.group_size = 16 - self.use_flashinfer_trtllm = get_moe_runner_backend().is_flashinfer_trtllm() + else: + self.use_marlin_fallback = False + self.use_flashinfer_trtllm = get_moe_runner_backend().is_flashinfer_trtllm() @classmethod def get_min_capability(cls) -> int: - # Requires sm100(blackwell) architecture - return 100 + return 75 # SM75+ (Turing) supports Marlin FP4 fallback; SM100 for native FP4 def create_weights( self, @@ -64,6 +76,7 @@ def create_weights( from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported layer.params_dtype = params_dtype + layer.intermediate_size_per_partition = intermediate_size_per_partition w13_weight = torch.nn.Parameter( torch.empty( @@ -175,6 +188,21 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ) delattr(layer, "w2_weight_packed") + if self.use_marlin_fallback: + # CompressedTensors checkpoint: global_scale is stored as the inverse. + # Actual dequant scale = 1 / stored_value. We create w*_weight_scale_2 + # with the actual scale before calling prepare_moe_fp4_layer_for_marlin(). 
+ layer.w13_weight_scale_2 = torch.nn.Parameter( + (1.0 / layer.w13_weight_global_scale).to(layer.params_dtype), + requires_grad=False, + ) # [E, 2] + layer.w2_weight_scale_2 = torch.nn.Parameter( + (1.0 / layer.w2_weight_global_scale).to(layer.params_dtype), + requires_grad=False, + ) # [E] + prepare_moe_fp4_layer_for_marlin(layer) + return + if self.use_flashinfer_trtllm: w, s = reorder_w1w3_to_w3w1( layer.w13_weight.data, layer.w13_weight_scale.data, dim=-2 @@ -303,7 +331,10 @@ def create_moe_runner( self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig ): self.moe_runner_config = moe_runner_config - self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config) + if self.use_marlin_fallback: + self.runner = MoeRunner(MoeRunnerBackend.MARLIN, moe_runner_config) + else: + self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config) def apply_weights( self, @@ -313,6 +344,33 @@ def apply_weights( from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput + if self.use_marlin_fallback: + from sglang.srt.layers.moe.moe_runner.marlin import MarlinMoeQuantInfo + + expert_map = None + global_num_experts = -1 + if hasattr(layer, "dispatcher") and hasattr( + layer.dispatcher, "local_expert_mapping" + ): + expert_map = layer.dispatcher.local_expert_mapping + if expert_map is not None: + global_num_experts = self.moe_runner_config.num_experts + + quant_info = MarlinMoeQuantInfo( + w13_qweight=layer.w13_weight, + w2_qweight=layer.w2_weight, + w13_scales=layer.w13_weight_scale, + w2_scales=layer.w2_weight_scale, + w13_g_idx_sort_indices=None, + w2_g_idx_sort_indices=None, + weight_bits=4, + w13_global_scale=layer.w13_weight_scale_2, + w2_global_scale=layer.w2_weight_scale_2, + expert_map=expert_map, + global_num_experts=global_num_experts, + ) + return self.runner.run(dispatch_output, quant_info) + x = dispatch_output.hidden_states topk_output = dispatch_output.topk_output diff --git 
a/python/sglang/srt/layers/quantization/marlin_utils_fp4.py b/python/sglang/srt/layers/quantization/marlin_utils_fp4.py new file mode 100644 index 000000000000..5a9fb3cef84b --- /dev/null +++ b/python/sglang/srt/layers/quantization/marlin_utils_fp4.py @@ -0,0 +1,320 @@ +# SPDX-License-Identifier: Apache-2.0 +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py + +"""NVFP4 Marlin fallback: run FP4-quantized models on non-Blackwell GPUs via Marlin kernel.""" + +import logging +from typing import Optional + +import torch + +from sglang.srt.layers.quantization.marlin_utils import ( + USE_FP32_REDUCE_DEFAULT, + marlin_make_workspace, + marlin_permute_bias, + marlin_permute_scales, + should_use_atomic_add_reduce, +) +from sglang.srt.layers.quantization.utils import get_scalar_types +from sglang.srt.utils import direct_register_custom_op, get_device_capability, is_cuda + +_is_cuda = is_cuda() +if _is_cuda: + from sglang.jit_kernel.gptq_marlin import gptq_marlin_gemm + from sglang.jit_kernel.gptq_marlin_repack import gptq_marlin_repack + +ScalarType, scalar_types = get_scalar_types() + +logger = logging.getLogger(__name__) + +# NVFP4 always uses group_size=16 +FP4_MARLIN_GROUP_SIZE = 16 + + +def is_fp4_marlin_supported() -> bool: + """Check if the current GPU supports FP4 Marlin fallback (CUDA SM >= 75).""" + if not _is_cuda: + return False + if torch.version.hip is not None: + return False + major, minor = get_device_capability() + if major is None or minor is None: + return False + return (major * 10 + minor) >= 75 + + +def should_use_fp4_marlin_fallback() -> bool: + """True if non-Blackwell (or forced) AND Marlin kernel available (SM >= 75).""" + from sglang.srt.environ import envs + from sglang.srt.layers.quantization.fp8_utils import is_blackwell_supported + + force = envs.SGLANG_FORCE_NVFP4_MARLIN.get() + return (force or not is_blackwell_supported()) and is_fp4_marlin_supported() + + +def 
nvfp4_marlin_process_scales(marlin_scales: torch.Tensor) -> torch.Tensor: + """Convert NVFP4 scales from FP8-S1E4M3 to FP8-S0E5M3 format for Marlin. + + The int16 <<1 may wrap for large scales (e.g. 448*128=57344), but the BIT + PATTERN is preserved correctly — the kernel reads raw bytes, not int16 values. + """ + marlin_scales = marlin_scales.to(torch.half) + + if not (marlin_scales >= 0).all(): + logger.warning_once( + "NVFP4 Marlin assumes scales >= 0, but encountered negative scales. " + "Accuracy may be degraded. The scales are converted from FP8-S1E4M3 " + "to a special FP8-S0E5M3 format to speed up dequantization." + ) + + marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( + marlin_scales.size(0), -1 + ) + + marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1 + marlin_scales = marlin_scales.view(torch.float8_e4m3fn) + marlin_scales = marlin_scales[:, 1::2].contiguous() + + return marlin_scales + + +def nvfp4_marlin_process_global_scale(global_scale: torch.Tensor) -> torch.Tensor: + """Pre-adjust global scale with FP4/FP16/BF16 exponent bias for Marlin kernel.""" + assert global_scale.dtype in [ + torch.half, + torch.bfloat16, + ], f"global_scale dtype must be half or bfloat16, got {global_scale.dtype}" + fp4_exponent = 2 + if global_scale.dtype == torch.half: + target_exponent = 5 + elif global_scale.dtype == torch.bfloat16: + target_exponent = 8 + exponent_bias = 2 ** (target_exponent - 1) - 2 ** (fp4_exponent - 1) + return global_scale * (2.0 ** (exponent_bias - 7)) + + +def apply_fp4_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_global_scale: Optional[torch.Tensor], + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: + """Apply FP4-quantized linear via Marlin kernel (non-Blackwell fallback).""" + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = 
input.shape[:-1] + (size_n,) + + use_atomic_add = should_use_atomic_add_reduce( + m=reshaped_x.size(0), + n=size_n, + k=size_k, + device=input.device, + dtype=input.dtype, + ) + + output = gptq_marlin_gemm( + a=reshaped_x, + c=None, + b_q_weight=weight, + b_scales=weight_scale, + global_scale=weight_global_scale.reshape(-1), + b_zeros=None, + g_idx=None, + perm=None, + workspace=workspace, + b_q_type=scalar_types.float4_e2m1f, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + ) + + if bias is not None: + output.add_(bias) + + return output.reshape(out_shape) + + +def fake_apply_fp4_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_global_scale: Optional[torch.Tensor], + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: + out_shape = input.shape[:-1] + (size_n,) + return torch.empty(out_shape, dtype=input.dtype, device=input.device) + + +direct_register_custom_op( + op_name="apply_fp4_marlin_linear", + op_func=apply_fp4_marlin_linear, + mutates_args=[], + fake_impl=fake_apply_fp4_marlin_linear, +) + + +def prepare_fp4_layer_for_marlin( + layer: torch.nn.Module, + weight_attr: str = "weight", + weight_scale_attr: str = "weight_scale", + weight_global_scale_attr: str = "weight_global_scale", +) -> None: + """Repack NVFP4 linear layer weights into Marlin format in-place.""" + logger.warning_once( + "Your GPU does not have native support for FP4 computation but " + "FP4 quantization is being used. Weight-only FP4 compression will " + "be used leveraging the Marlin kernel. This may degrade " + "performance for compute-heavy workloads." 
+ ) + + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + param_dtype = layer.params_dtype + + weight = getattr(layer, weight_attr) + assert weight.shape == (part_size_n, part_size_k // 2), ( + f"Expected {weight_attr} shape ({part_size_n}, {part_size_k // 2}), " + f"got {weight.shape}" + ) + + device = weight.device + + # WORKSPACE + layer.marlin_workspace = marlin_make_workspace(device) + + # WEIGHT: repack from NVFP4 native layout to Marlin tile layout + perm = torch.empty(0, dtype=torch.int, device=device) + qweight = weight.data.view(torch.int32).T.contiguous() + del weight + marlin_qweight = gptq_marlin_repack( + b_q_weight=qweight, + perm=perm, + size_k=part_size_k, + size_n=part_size_n, + num_bits=4, + ) + del qweight + setattr(layer, weight_attr, torch.nn.Parameter(marlin_qweight, requires_grad=False)) + + # WEIGHT SCALES: transpose, permute, convert to FP8-S0E5M3 + weight_scale = getattr(layer, weight_scale_attr) + weight_scale = weight_scale.data.T.contiguous().to(param_dtype) + weight_scale = marlin_permute_scales( + s=weight_scale, + size_k=part_size_k, + size_n=part_size_n, + group_size=FP4_MARLIN_GROUP_SIZE, + ) + weight_scale = nvfp4_marlin_process_scales(weight_scale) + setattr( + layer, weight_scale_attr, torch.nn.Parameter(weight_scale, requires_grad=False) + ) + + # GLOBAL SCALE: Pre-adjust exponent bias for Marlin kernel. 
+ weight_global_scale = getattr(layer, weight_global_scale_attr) + weight_global_scale = weight_global_scale.to(param_dtype) + weight_global_scale = nvfp4_marlin_process_global_scale(weight_global_scale) + setattr( + layer, + weight_global_scale_attr, + torch.nn.Parameter(weight_global_scale, requires_grad=False), + ) + + # BIAS (if present): Permute for Marlin's fast access pattern + if hasattr(layer, "bias") and layer.bias is not None: + assert layer.bias.shape == (part_size_n,) + bias = marlin_permute_bias(layer.bias) + layer.bias = torch.nn.Parameter(bias, requires_grad=False) + + +def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: + """Repack NVFP4 MoE weights into Marlin format in-place (per-expert).""" + logger.warning_once( + "Your GPU does not have native support for FP4 computation but " + "FP4 quantization is being used. Weight-only FP4 compression will " + "be used leveraging the Marlin kernel for MoE layers. This may " + "degrade performance for compute-heavy workloads." 
+ ) + + e = layer.num_local_experts + k = layer.w13_weight.shape[2] * 2 # hidden_size (packed: K//2 per uint8) + n = layer.intermediate_size_per_partition + param_dtype = layer.params_dtype + num_shards = 2 if layer.moe_runner_config.is_gated else 1 + + device = layer.w13_weight.device + perm = torch.empty(0, dtype=torch.int, device=device) + + # (size_n, size_k) for each projection + sizes = {"w13": (n * num_shards, k), "w2": (k, n)} + + # --- WEIGHT REPACKING --- + for name in ["w13_weight", "w2_weight"]: + prefix = name.split("_")[0] # "w13" or "w2" + size_n, size_k = sizes[prefix] + weight = getattr(layer, name) + + assert weight.shape == (e, size_n, size_k // 2), ( + f"Expected {name} shape ({e}, {size_n}, {size_k // 2}), " + f"got {weight.shape}" + ) + + repacked = [] + for i in range(e): + qweight = weight.data[i].view(torch.int32).T.contiguous() + repacked.append( + gptq_marlin_repack( + b_q_weight=qweight, + perm=perm, + size_k=size_k, + size_n=size_n, + num_bits=4, + ) + ) + + del weight + setattr( + layer, name, torch.nn.Parameter(torch.stack(repacked), requires_grad=False) + ) + + # --- WEIGHT SCALE PROCESSING --- + for prefix in ["w13", "w2"]: + size_n, size_k = sizes[prefix] + scales = getattr(layer, prefix + "_weight_scale").to(param_dtype) + global_scale = getattr(layer, prefix + "_weight_scale_2").to(param_dtype) + + processed = [] + for i in range(e): + s = marlin_permute_scales( + s=scales.data[i].T, + size_k=size_k, + size_n=size_n, + group_size=FP4_MARLIN_GROUP_SIZE, + ) + processed.append(nvfp4_marlin_process_scales(s)) + + del scales + setattr( + layer, + prefix + "_weight_scale", + torch.nn.Parameter(torch.stack(processed), requires_grad=False), + ) + + if global_scale.dim() > 1: + global_scale = global_scale.max(dim=-1).values + global_scale = nvfp4_marlin_process_global_scale(global_scale) + setattr( + layer, + prefix + "_weight_scale_2", + torch.nn.Parameter(global_scale, requires_grad=False), + ) diff --git 
a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index c0d9958e45ee..d26454778852 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -40,6 +40,11 @@ is_blackwell_supported, ) from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod +from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + prepare_fp4_layer_for_marlin, + prepare_moe_fp4_layer_for_marlin, + should_use_fp4_marlin_fallback, +) from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod from sglang.srt.layers.quantization.utils import ( convert_to_channelwise, @@ -1142,7 +1147,7 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]: @classmethod def get_min_capability(cls) -> int: - return 100 + return 75 # SM75+ (Turing) supports Marlin FP4 fallback; SM100 for native FP4 @staticmethod def common_group_size(cfg: dict) -> int: @@ -1316,6 +1321,7 @@ def create_weights( weight_loader = extra_weight_attrs.get("weight_loader") layer.logical_widths = output_partition_sizes + layer.params_dtype = params_dtype layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition @@ -1370,6 +1376,20 @@ def create_weights( layer.register_parameter("weight_scale", weight_scale) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if should_use_fp4_marlin_fallback(): + # Marlin FP4 fallback: consolidate global scale then repack weights + weight_scale_2 = layer.weight_scale_2.max().to(torch.float32) + layer.weight_scale_2_marlin = Parameter(weight_scale_2, requires_grad=False) + prepare_fp4_layer_for_marlin( + layer, + weight_attr="weight", + weight_scale_attr="weight_scale", + weight_global_scale_attr="weight_scale_2_marlin", + ) + layer.use_marlin_fallback = True + return + + layer.use_marlin_fallback = False input_scale_2 = 
layer.input_scale.max().to(torch.float32) weight_scale_2 = layer.weight_scale_2.max().to(torch.float32) @@ -1470,6 +1490,18 @@ def apply( x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if layer.use_marlin_fallback: + return torch.ops.sglang.apply_fp4_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + weight_global_scale=layer.weight_scale_2_marlin, + workspace=layer.marlin_workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias, + ) + output_dtype = x.dtype x_m, _ = x.shape @@ -1526,14 +1558,27 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): def __init__(self, quant_config: ModelOptFp4Config): self.quant_config = quant_config - if not is_blackwell_supported(): + + if should_use_fp4_marlin_fallback(): + logger.warning_once( + "GPU is not Blackwell (SM100+). Using Marlin FP4 fallback kernel " + "for MoE layers. Weights remain compressed in FP4 format." + ) + self.use_marlin_fallback = True + self.enable_flashinfer_trtllm_moe = False + elif not is_blackwell_supported(): raise ValueError( "Current platform does not support NVFP4" - " quantization. Please use Blackwell and" - " above." + " quantization. Please use SM75+ (Turing or newer)." + ) + else: + self.use_marlin_fallback = False + self.enable_flashinfer_trtllm_moe = ( + get_moe_runner_backend().is_flashinfer_trtllm() ) self.enable_flashinfer_trtllm_moe = ( get_moe_runner_backend().is_flashinfer_trtllm() + or get_moe_runner_backend().is_flashinfer_trtllm_routed() ) self._cache_permute_indices = {} @@ -1688,6 +1733,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: Only supports pre-quantized checkpoints with FP8 weights and scales. 
""" + if self.use_marlin_fallback: + prepare_moe_fp4_layer_for_marlin(layer) + return # GEMM 1 scale processing if layer.moe_runner_config.is_gated: @@ -1900,10 +1948,17 @@ def create_moe_runner( self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig ): self.moe_runner_config = moe_runner_config + if self.use_marlin_fallback: + self.runner = MoeRunner(MoeRunnerBackend.MARLIN, moe_runner_config) + return if get_moe_runner_backend().is_flashinfer_trtllm(): self.runner = MoeRunner( MoeRunnerBackend.FLASHINFER_TRTLLM, moe_runner_config ) + elif get_moe_runner_backend().is_flashinfer_trtllm_routed(): + self.runner = MoeRunner( + MoeRunnerBackend.FLASHINFER_TRTLLM_ROUTED, moe_runner_config + ) def apply( self, @@ -1922,6 +1977,34 @@ def apply( ), f"{activation=} missing from {ACT_STR_TO_TYPE_MAP.keys()=}" moe_runner_config = self.moe_runner_config + # Marlin FP4 fallback path for non-Blackwell GPUs (SM75-SM89) + if self.use_marlin_fallback: + from sglang.srt.layers.moe.moe_runner.marlin import MarlinMoeQuantInfo + + expert_map = None + global_num_experts = -1 + if hasattr(layer, "dispatcher") and hasattr( + layer.dispatcher, "local_expert_mapping" + ): + expert_map = layer.dispatcher.local_expert_mapping + if expert_map is not None: + global_num_experts = moe_runner_config.num_experts + + quant_info = MarlinMoeQuantInfo( + w13_qweight=layer.w13_weight, + w2_qweight=layer.w2_weight, + w13_scales=layer.w13_weight_scale, + w2_scales=layer.w2_weight_scale, + w13_g_idx_sort_indices=None, + w2_g_idx_sort_indices=None, + weight_bits=4, + w13_global_scale=layer.w13_weight_scale_2, + w2_global_scale=layer.w2_weight_scale_2, + expert_map=expert_map, + global_num_experts=global_num_experts, + ) + return self.runner.run(dispatch_output, quant_info) + # FlashInfer TRTLLM FP4 path - layer has shuffled weights only when # backend is flashinfer_trtllm if hasattr(layer, "gemm1_weights_fp4_shuffled"): diff --git a/python/sglang/srt/layers/sampler.py 
b/python/sglang/srt/layers/sampler.py index e947a48cbde8..4196787820f4 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -18,6 +18,7 @@ from sglang.srt.server_args import get_global_server_args from sglang.srt.utils.common import crash_on_warnings, get_bool_env_var, is_cuda, is_npu +_use_fused_sampling = False if is_cuda(): from flashinfer.sampling import ( min_p_sampling_from_probs, @@ -27,6 +28,15 @@ top_k_renorm_prob, top_p_renorm_prob, ) + + from sglang.srt.layers.fused_sampling import fused_temperature_softmax_inplace + + _use_fused_sampling = True + +# Batch size threshold for fused Triton kernel vs PyTorch softmax. +# Below this threshold, PyTorch's native div+softmax is faster. +# At and above this threshold, the fused Triton kernel wins. +_FUSED_SAMPLING_BATCH_THRESHOLD = 128 if is_npu(): import torch_npu @@ -152,11 +162,20 @@ def forward( logprobs = logprobs_via_logsoftmax_kernel else: # Standard path: do softmax and sample from probs. - logits.div_(sampling_info.temperatures) - - # In-place op to save memory - logits[:] = torch.softmax(logits, dim=-1) - probs = logits + # Use fused Triton kernel for large batches where it excels; + # fall back to PyTorch for small batches where launch overhead dominates. 
+ if ( + _use_fused_sampling + and logits.shape[0] >= _FUSED_SAMPLING_BATCH_THRESHOLD + ): + fused_temperature_softmax_inplace( + logits, sampling_info.temperatures + ) + probs = logits + else: + logits.div_(sampling_info.temperatures) + logits[:] = torch.softmax(logits, dim=-1) + probs = logits batch_next_token_ids = self._sample_from_probs( probs, sampling_info, positions, simple_sampling_case diff --git a/python/sglang/srt/managers/hisparse_coordinator.py b/python/sglang/srt/managers/hisparse_coordinator.py index 89740f73682e..9336571976ba 100644 --- a/python/sglang/srt/managers/hisparse_coordinator.py +++ b/python/sglang/srt/managers/hisparse_coordinator.py @@ -78,8 +78,11 @@ def __init__( ) self.write_staging_stream = device_module.Stream() + self.decode_backup_stream = device_module.Stream() self.ack_staging_queue: List[HiSparseAct] = [] self.decode_producer_stream = None + self._backup_done_event = device_module.Event() + self._has_pending_backup = False self.tp_group = tp_group self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group) @@ -391,9 +394,6 @@ def _eager_backup_previous_token( The only exception is the first decode step right after staging: all prefill tokens were already backed up during staging, so there is nothing new to save yet. """ - if self.decode_producer_stream is not None: - device_module.current_stream().wait_stream(self.decode_producer_stream) - # Build the list of batch positions that need a host backup. # Skip the first decode step after staging (prefill already backed up). 
backup_indices = [] @@ -431,12 +431,36 @@ def _eager_backup_previous_token( host_locs = host_locs.to(device=self.device) self.req_to_host_pool[backup_req_indices, actual_token_pos] = host_locs - self.mem_pool_host.backup_from_device_all_layer( - self.mem_pool_device, - host_locs, - device_locs.contiguous(), - io_backend="kernel", - ) + if self._has_pending_backup: + self._backup_done_event.wait(device_module.current_stream()) + self._has_pending_backup = False + schedule_stream = device_module.current_stream() + with device_module.stream(self.decode_backup_stream): + self.decode_backup_stream.wait_stream(schedule_stream) + if self.decode_producer_stream is not None: + self.decode_backup_stream.wait_stream(self.decode_producer_stream) + self.mem_pool_host.backup_from_device_all_layer( + self.mem_pool_device, + host_locs, + device_locs, + io_backend="kernel", + ) + self._backup_done_event.record() + if host_locs.is_cuda: + host_locs.record_stream(self.decode_backup_stream) + if backup_req_indices.is_cuda: + backup_req_indices.record_stream(self.decode_backup_stream) + if actual_token_pos.is_cuda: + actual_token_pos.record_stream(self.decode_backup_stream) + if device_locs.is_cuda: + device_locs.record_stream(self.decode_backup_stream) + self._has_pending_backup = True + + def wait_for_pending_backup(self) -> None: + if not self._has_pending_backup: + return + self._backup_done_event.wait(device_module.current_stream()) + self._has_pending_backup = False def get_front_topk_tokens( self, @@ -569,6 +593,9 @@ def request_finished(self, req: Req): # release resources only after the execution of a potential overlapped batch if self.decode_producer_stream is not None: device_module.current_stream().wait_stream(self.decode_producer_stream) + if self._has_pending_backup: + self._backup_done_event.wait(device_module.current_stream()) + self._has_pending_backup = False # release memory — only free actually-allocated buffer indices current_cap = 
int(self.req_device_buffer_size[req.req_pool_idx]) diff --git a/python/sglang/srt/managers/multi_tokenizer_mixin.py b/python/sglang/srt/managers/multi_tokenizer_mixin.py index e0a1669fb3e6..8da3d3b0de60 100644 --- a/python/sglang/srt/managers/multi_tokenizer_mixin.py +++ b/python/sglang/srt/managers/multi_tokenizer_mixin.py @@ -433,7 +433,16 @@ async def print_exception_wrapper(func): def get_main_process_id() -> int: - """Get the main process ID""" + """Get the main process ID. + + Supports override via SGLANG_GRANIAN_PARENT_PID for workers whose + multiprocessing parent PID differs from the shared-memory owner. + """ + from sglang.srt.environ import envs + + override = envs.SGLANG_GRANIAN_PARENT_PID.get() + if override is not None: + return override return multiprocessing.current_process()._parent_pid diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 36c55826d821..377a7ec749fb 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -276,6 +276,24 @@ def copy_to_cpu(self): self.copy_done.record() +def validate_dflash_request(req: Req) -> Optional[str]: + if req.return_logprob: + return "DFLASH speculative decoding does not support return_logprob yet." + + if ( + req.sampling_params.json_schema is not None + or req.sampling_params.regex is not None + or req.sampling_params.ebnf is not None + or req.sampling_params.structural_tag is not None + ): + return ( + "DFLASH speculative decoding does not support " + "grammar-constrained decoding yet." 
+ ) + + return None + + class Scheduler( SchedulerOutputProcessorMixin, SchedulerUpdateWeightsMixin, @@ -1861,6 +1879,14 @@ def handle_generate_request( self._add_request_to_queue(req) return + if self.spec_algorithm.is_dflash(): + error_msg = validate_dflash_request(req) + if error_msg is not None: + req.set_finish_with_abort(error_msg) + self.init_req_max_new_tokens(req) + self._add_request_to_queue(req) + return + # Handle multimodal inputs if recv_req.mm_inputs is not None: image_inputs = self._get_multimodal_inputs(recv_req.mm_inputs) diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index e4c158cda9f6..439b35f1d3a9 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -259,6 +259,15 @@ def __init__( for conv_shape in conv_state_shape ] + if _is_npu: + from sglang.srt.hardware_backend.npu.memory_pool_npu import ( + _init_npu_conv_state, + ) + + conv_state = _init_npu_conv_state( + conv_state[0], conv_state_shape, speculative_num_draft_tokens + ) + if _is_cpu and _cpu_has_amx_support: from sglang.srt.layers.amx_utils import _init_amx_conv_state diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index c7c7d6b5ec0b..69cb176efbdc 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -547,18 +547,15 @@ def __init__(self, model_runner: ModelRunner): self.capture_forward_mode = ForwardMode.DECODE self.capture_hidden_mode = CaptureHiddenMode.NULL self.num_tokens_per_bs = 1 - if ( - model_runner.spec_algorithm.is_eagle() - or model_runner.spec_algorithm.is_standalone() - or model_runner.spec_algorithm.is_ngram() - ): + if model_runner.spec_algorithm.is_speculative(): if self.model_runner.is_draft_worker: - raise RuntimeError("This should not happen") - else: - self.capture_forward_mode = ForwardMode.TARGET_VERIFY - 
self.num_tokens_per_bs = ( - self.model_runner.server_args.speculative_num_draft_tokens - ) + # DFLASH draft workers reuse this runner for TARGET_VERIFY mode. + if not self.model_runner.spec_algorithm.is_dflash(): + raise RuntimeError("This should not happen") + self.capture_forward_mode = ForwardMode.TARGET_VERIFY + self.num_tokens_per_bs = ( + self.model_runner.server_args.speculative_num_draft_tokens + ) elif self.is_dllm: self.capture_forward_mode = ForwardMode.DLLM_EXTEND self.num_tokens_per_bs = self.dllm_config.block_size @@ -646,6 +643,18 @@ def __init__(self, model_runner: ModelRunner): and model_runner.eagle_use_aux_hidden_state ): self.model_runner.model.set_eagle3_layers_to_capture() + if ( + model_runner.spec_algorithm.is_dflash() + and model_runner.dflash_use_aux_hidden_state + ): + if not hasattr(self.model_runner.model, "set_dflash_layers_to_capture"): + raise ValueError( + f"Model {self.model_runner.model.__class__.__name__} does not implement set_dflash_layers_to_capture, " + "which is required for DFLASH aux hidden capture." 
+ ) + self.model_runner.model.set_dflash_layers_to_capture( + self.model_runner.dflash_target_layer_ids + ) # Capture try: @@ -671,6 +680,7 @@ def can_run(self, forward_batch: ForwardBatch): max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs if self.model_runner.spec_algorithm.is_eagle() or self.model_runner.spec_algorithm.is_standalone() + or self.model_runner.spec_algorithm.is_dflash() else max(forward_batch.global_num_tokens_cpu) ) else: @@ -1007,6 +1017,12 @@ def run_once(): kwargs["pp_proxy_tensors"] = PPProxyTensors( {k: v.clone() for k, v in pp_proxy_tensors.tensors.items()} ) + if ( + self.model_runner.spec_algorithm.is_dflash() + and self.model_runner.is_draft_worker + and "input_embeds" in inspect.signature(forward).parameters + ): + kwargs["input_embeds"] = buffers.input_embeds[:num_tokens] logits_output_or_pp_proxy_tensors = forward( input_ids, @@ -1083,6 +1099,7 @@ def replay_prepare( max_num_tokens / self.num_tokens_per_bs if self.model_runner.spec_algorithm.is_eagle() or self.model_runner.spec_algorithm.is_standalone() + or self.model_runner.spec_algorithm.is_dflash() else max_num_tokens ) index = bisect.bisect_left(self.capture_bs, max_batch_size) @@ -1104,6 +1121,13 @@ def replay_prepare( ), pp_proxy_tensors=pp_proxy_tensors, ) + if ( + self.model_runner.spec_algorithm.is_dflash() + and self.model_runner.is_draft_worker + and forward_batch.input_embeds is not None + ): + buffers.input_embeds[:raw_num_token].copy_(forward_batch.input_embeds) + # Padded tokens aren't read, so skip zeroing them. if self.enable_two_batch_overlap: self.tbo_plugin.replay_prepare( forward_mode=self.capture_forward_mode, @@ -1152,6 +1176,14 @@ def replay( # In speculative decoding, these two fields are still needed. 
self.buffers.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids) self.buffers.positions[: self.raw_num_token].copy_(forward_batch.positions) + if ( + self.model_runner.spec_algorithm.is_dflash() + and self.model_runner.is_draft_worker + and forward_batch.input_embeds is not None + ): + self.buffers.input_embeds[: self.raw_num_token].copy_( + forward_batch.input_embeds + ) # Replay if self.enable_pdmux: @@ -1164,10 +1196,18 @@ def replay( if isinstance(output, LogitsProcessorOutput): if self.is_dllm: next_token_logits = None - full_logits = output.full_logits[: self.raw_num_token] + full_logits = ( + output.full_logits[: self.raw_num_token] + if output.full_logits is not None + else None + ) else: full_logits = None - next_token_logits = output.next_token_logits[: self.raw_num_token] + next_token_logits = ( + output.next_token_logits[: self.raw_num_token] + if output.next_token_logits is not None + else None + ) return LogitsProcessorOutput( next_token_logits=next_token_logits, @@ -1209,6 +1249,32 @@ def get_spec_info(self, num_tokens: int): seq_lens_sum=None, seq_lens_cpu=None, ) + elif self.model_runner.spec_algorithm.is_dflash(): + from sglang.srt.speculative.dflash_info import DFlashVerifyInput + from sglang.srt.speculative.dflash_utils import ( + resolve_dflash_verify_mask_policy, + ) + + # Avoid enabling custom-mask modes during graph capture for backends that + # can express DFLASH verify via their built-in causal path. 
+ _, build_custom_mask = resolve_dflash_verify_mask_policy( + self.model_runner.attn_backend + ) + spec_info = DFlashVerifyInput( + draft_token=None, + positions=None, + draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens, + custom_mask=( + None + if (self.model_runner.is_draft_worker or not build_custom_mask) + else self.buffers.custom_mask + ), + capture_hidden_mode=( + CaptureHiddenMode.NULL + if self.model_runner.is_draft_worker + else CaptureHiddenMode.FULL + ), + ) elif self.model_runner.spec_algorithm.is_ngram(): from sglang.srt.speculative.ngram_info import NgramVerifyInput diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 669cab133c49..74c2f4848e2d 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -354,6 +354,9 @@ def __init__( self.remote_instance_transfer_engine_weight_info = None # auxiliary hidden capture mode. TODO: expose this to server args? self.eagle_use_aux_hidden_state = False + self.dflash_use_aux_hidden_state = False + self.dflash_target_layer_ids = None + self.dflash_draft_num_layers = None if self.spec_algorithm.is_eagle3() and not self.is_draft_worker: # load draft config draft_model_config = ModelConfig.from_server_args( @@ -379,6 +382,52 @@ def __init__( # if there is no aux layer, set to None self.eagle_aux_hidden_state_layer_ids = None + if self.spec_algorithm.is_dflash() and not self.is_draft_worker: + from sglang.srt.speculative.dflash_utils import ( + parse_dflash_draft_config, + ) + + # Select target layers to capture for building DFlash context features. 
+ draft_model_config = ModelConfig.from_server_args( + server_args, + model_path=(server_args.speculative_draft_model_path), + model_revision=server_args.speculative_draft_model_revision, + is_draft_model=True, + ) + dflash_draft_config = parse_dflash_draft_config( + draft_hf_config=draft_model_config.hf_config + ) + draft_num_layers = dflash_draft_config.require_num_layers() + trained_target_layers = dflash_draft_config.num_target_layers + + target_num_layers = getattr( + self.model_config.hf_text_config, "num_hidden_layers", None + ) + if target_num_layers is None: + raise ValueError( + "DFLASH requires target num_hidden_layers in config. " + f"Got target={target_num_layers}." + ) + target_num_layers = int(target_num_layers) + + if ( + trained_target_layers is not None + and trained_target_layers != target_num_layers + ): + logger.warning( + "DFLASH draft config num_target_layers=%s differs from runtime target num_hidden_layers=%s; " + "selecting capture layers based on the runtime target model.", + trained_target_layers, + target_num_layers, + ) + + self.dflash_use_aux_hidden_state = True + self.dflash_draft_num_layers = int(draft_num_layers) + self.dflash_target_layer_ids = dflash_draft_config.resolve_target_layer_ids( + target_num_layers=int(target_num_layers), + draft_num_layers=int(draft_num_layers), + ) + # Apply the rank zero filter to logger if server_args.show_time_cost: enable_show_time_cost() @@ -670,6 +719,14 @@ def initialize(self, pre_model_load_memory: float): self.eagle_aux_hidden_state_layer_ids ) + if self.dflash_use_aux_hidden_state: + if not hasattr(self.model, "set_dflash_layers_to_capture"): + raise ValueError( + f"Model {self.model.__class__.__name__} does not implement set_dflash_layers_to_capture, " + "which is required for DFLASH." 
+ ) + self.model.set_dflash_layers_to_capture(self.dflash_target_layer_ids) + # Initialize piecewise CUDA graph self.init_piecewise_cuda_graphs() @@ -2077,6 +2134,22 @@ def kernel_warmup(self): if self._should_run_flashinfer_autotune(): self._flashinfer_autotune() + self._warmup_fused_sampling() + + def _warmup_fused_sampling(self): + """Pre-compile and autotune fused sampling Triton kernels.""" + if _is_hip: + return + from sglang.srt.layers.fused_sampling import warmup_fused_temperature_softmax + + logits_warmup_dtype = ( + torch.float32 if self.server_args.enable_fp32_lm_head else self.dtype + ) + warmup_fused_temperature_softmax( + self.model_config.vocab_size, + logits_dtype=logits_warmup_dtype, + ) + def _should_run_flashinfer_autotune(self) -> bool: """Check if flashinfer autotune should be run.""" if self.server_args.disable_flashinfer_autotune: @@ -2100,11 +2173,7 @@ def _should_run_flashinfer_autotune(self) -> bool: if major < 9: return False - if ( - self.spec_algorithm.is_eagle() - or self.spec_algorithm.is_standalone() - or self.spec_algorithm.is_ngram() - ): + if self.spec_algorithm.is_speculative(): return not self.is_draft_worker return True @@ -2134,16 +2203,12 @@ def _dummy_run(self, batch_size: int, run_ctx=None): capture_forward_mode = ForwardMode.EXTEND capture_hidden_mode = CaptureHiddenMode.NULL num_tokens_per_bs = 1 - if ( - self.spec_algorithm.is_eagle() - or self.spec_algorithm.is_standalone() - or self.spec_algorithm.is_ngram() - ): + if self.spec_algorithm.is_speculative(): if self.is_draft_worker: - raise RuntimeError("This should not happen") - else: - capture_forward_mode = ForwardMode.TARGET_VERIFY - num_tokens_per_bs = self.server_args.speculative_num_draft_tokens + if not self.spec_algorithm.is_dflash(): + raise RuntimeError("This should not happen") + capture_forward_mode = ForwardMode.TARGET_VERIFY + num_tokens_per_bs = self.server_args.speculative_num_draft_tokens if self.server_args.enable_return_hidden_states: 
capture_hidden_mode = CaptureHiddenMode.FULL @@ -2173,6 +2238,8 @@ def _dummy_run(self, batch_size: int, run_ctx=None): if self.eagle_use_aux_hidden_state: self.model.set_eagle3_layers_to_capture() + if self.dflash_use_aux_hidden_state: + self.model.set_dflash_layers_to_capture(self.dflash_target_layer_ids) require_mlp_tp_gather_ = require_mlp_tp_gather(self.server_args) if require_gathered_buffer(self.server_args): @@ -2286,6 +2353,21 @@ def get_spec_info(): seq_lens_sum=None, seq_lens_cpu=None, ) + elif self.spec_algorithm.is_dflash(): + from sglang.srt.speculative.dflash_info import DFlashVerifyInput + + # Dummy warmup only needs shape metadata; avoid forcing custom-mask mode. + spec_info = DFlashVerifyInput( + draft_token=None, + positions=None, + draft_token_num=self.server_args.speculative_num_draft_tokens, + custom_mask=None, + capture_hidden_mode=( + CaptureHiddenMode.NULL + if self.is_draft_worker + else CaptureHiddenMode.FULL + ), + ) elif self.spec_algorithm.is_ngram(): from sglang.srt.speculative.ngram_info import NgramVerifyInput @@ -2817,6 +2899,12 @@ def _forward_raw( and self.graph_runner.can_run(forward_batch) ) + if ( + self.hisparse_coordinator is not None + and forward_batch.forward_mode.is_decode() + ): + self.hisparse_coordinator.wait_for_pending_backup() + if can_run_graph: ret = self.graph_runner.replay( forward_batch, diff --git a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py index a6baa4817ace..bca2baca64f9 100644 --- a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py +++ b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py @@ -167,6 +167,22 @@ def profile_max_num_token(self: ModelRunner, pre_model_load_memory: int): num_layers = self.num_effective_layers cell_size = self.get_cell_size_per_token(num_layers) + if self.spec_algorithm.is_dflash() and not self.is_draft_worker: + from sglang.srt.speculative.dflash_utils import ( + 
scale_kv_cell_size_per_token_for_dflash, + ) + + draft_num_layers = getattr(self, "dflash_draft_num_layers", None) + if ( + draft_num_layers is not None + and int(draft_num_layers) > 0 + and int(num_layers) > 0 + ): + cell_size = scale_kv_cell_size_per_token_for_dflash( + target_cell_size_per_token=cell_size, + target_num_layers=int(num_layers), + draft_num_layers=int(draft_num_layers), + ) rest_memory = post_model_load_memory - pre_model_load_memory * ( 1 - self.mem_fraction_static diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py index ef672c5c0a7f..4c93176612a0 100644 --- a/python/sglang/srt/model_loader/weight_utils.py +++ b/python/sglang/srt/model_loader/weight_utils.py @@ -187,6 +187,16 @@ def get_quant_config( if not isinstance(hf_quant_config, dict): hf_quant_config = hf_quant_config.to_dict() hf_quant_config["packed_modules_mapping"] = packed_modules_mapping + # For modelopt, route to FP4 vs FP8 config based on quant_algo + if model_config.quantization.startswith("modelopt"): + quant_algo = hf_quant_config.get("quant_algo") + if quant_algo is None: + quant_algo = hf_quant_config.get("quantization", {}).get("quant_algo") + if quant_algo is not None: + if quant_algo == "FP8" or model_config.quantization == "modelopt_fp8": + return ModelOptFp8Config.from_config(hf_quant_config) + if "FP4" in quant_algo: + return ModelOptFp4Config.from_config(hf_quant_config) return quant_cls.from_config(hf_quant_config) # In case of bitsandbytes/QLoRA, get quant config from the adapter model. 
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 8f89bfd219c1..baf876a9574b 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -154,6 +154,9 @@ ) from sglang.srt.utils.custom_op import register_custom_op +if _use_aiter: + from sglang.srt.layers.rocm_linear_utils import aiter_dsv3_router_gemm + if _use_aiter: from sglang.srt.layers.rocm_linear_utils import aiter_dsv3_router_gemm @@ -339,7 +342,6 @@ def forward( logits = dsv3_router_gemm( hidden_states, self.weight, out_dtype=torch.float32 ) - elif _use_aiter: logits = aiter_dsv3_router_gemm(hidden_states, self.weight) else: diff --git a/python/sglang/srt/models/dflash.py b/python/sglang/srt/models/dflash.py new file mode 100644 index 000000000000..27f5cdbf539d --- /dev/null +++ b/python/sglang/srt/models/dflash.py @@ -0,0 +1,399 @@ +# Adapted from the DFlash reference implementation (HF) but implemented with +# SGLang primitives (RadixAttention + SGLang KV cache). This model intentionally +# does not include token embeddings or an LM head; DFlash uses the target model's +# embedding/lm_head. 
+ +from __future__ import annotations + +import logging +from typing import Iterable, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.layers.radix_attention import AttentionType, RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.utils import apply_qk_norm +from sglang.srt.speculative.dflash_utils import ( + can_dflash_slice_qkv_weight, + parse_dflash_draft_config, +) + +logger = logging.getLogger(__name__) + + +class DFlashAttention(nn.Module): + def __init__(self, config, layer_id: int) -> None: + super().__init__() + hidden_size = int(config.hidden_size) + tp_size = int(get_tensor_model_parallel_world_size()) + total_num_heads = int(config.num_attention_heads) + total_num_kv_heads = int( + getattr(config, "num_key_value_heads", total_num_heads) + ) + head_dim = int(getattr(config, "head_dim", hidden_size // total_num_heads)) + + self.hidden_size = hidden_size + self.total_num_heads = total_num_heads + self.total_num_kv_heads = total_num_kv_heads + assert self.total_num_heads % tp_size == 0, ( + f"DFlashAttention requires total_num_heads divisible by tp_size. " + f"total_num_heads={self.total_num_heads}, tp_size={tp_size}." + ) + self.num_heads = self.total_num_heads // tp_size + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0, ( + f"DFlashAttention requires total_num_kv_heads divisible by tp_size when >= tp_size. 
" + f"total_num_kv_heads={self.total_num_kv_heads}, tp_size={tp_size}." + ) + else: + assert tp_size % self.total_num_kv_heads == 0, ( + f"DFlashAttention requires tp_size divisible by total_num_kv_heads when total_num_kv_heads < tp_size. " + f"total_num_kv_heads={self.total_num_kv_heads}, tp_size={tp_size}." + ) + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim + self.q_size = self.num_heads * head_dim + self.kv_size = self.num_kv_heads * head_dim + + attention_bias = bool(getattr(config, "attention_bias", False)) + rms_norm_eps = float(getattr(config, "rms_norm_eps", 1e-6)) + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=attention_bias, + prefix="qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * head_dim, + hidden_size, + bias=attention_bias, + prefix="o_proj", + ) + + # Per-head Q/K RMSNorm, matching HF Qwen3. + self.q_norm = RMSNorm(head_dim, eps=rms_norm_eps) + self.k_norm = RMSNorm(head_dim, eps=rms_norm_eps) + + rope_theta = float(getattr(config, "rope_theta", 1000000)) + rope_scaling = getattr(config, "rope_scaling", None) + rope_is_neox_style = bool( + getattr( + config, "rope_is_neox_style", getattr(config, "is_neox_style", True) + ) + ) + max_position_embeddings = int(getattr(config, "max_position_embeddings", 32768)) + self.rotary_emb = get_rope( + head_dim, + rotary_dim=head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=rope_is_neox_style, + ) + + self.scaling = head_dim**-0.5 + # DFlash uses non-causal attention over the draft block. 
+ self.attn = RadixAttention( + num_heads=self.num_heads, + head_dim=head_dim, + scaling=self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + attn_type=AttentionType.ENCODER_ONLY, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = apply_qk_norm(q, k, self.q_norm, self.k_norm, self.head_dim) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, forward_batch) + output, _ = self.o_proj(attn_output) + return output + + def kv_proj_only( + self, hidden_states: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Project hidden_states to K/V only (skip Q). + + This is used by DFlash to materialize ctx tokens into the draft KV cache: + we only need K/V for the cached tokens; Q is never consumed. + """ + # Fast path for unquantized weights: slice the fused QKV weight and run one GEMM. + can_slice_qkv_weight, _ = can_dflash_slice_qkv_weight(self.qkv_proj) + if can_slice_qkv_weight: + kv_slice = slice(self.q_size, self.q_size + 2 * self.kv_size) + weight = self.qkv_proj.weight[kv_slice] + bias = ( + self.qkv_proj.bias[kv_slice] if self.qkv_proj.bias is not None else None + ) + kv = F.linear(hidden_states, weight, bias) + k, v = kv.split([self.kv_size, self.kv_size], dim=-1) + return k, v + + # Fallback: compute full QKV and discard Q (keeps compatibility with quantized weights). + qkv, _ = self.qkv_proj(hidden_states) + _, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + return k, v + + def apply_k_norm(self, k: torch.Tensor) -> torch.Tensor: + k_by_head = k.reshape(-1, self.head_dim) + k_by_head = self.k_norm(k_by_head) + return k_by_head.view_as(k) + + def apply_k_rope(self, positions: torch.Tensor, k: torch.Tensor) -> torch.Tensor: + # Use a minimal dummy query (1 head) to avoid doing full-Q work. 
+ dummy_q = k.new_empty((k.shape[0], self.head_dim)) + _, k = self.rotary_emb(positions, dummy_q, k) + return k + + +class DFlashMLP(nn.Module): + def __init__(self, config, quant_config=None, prefix: str = "") -> None: + super().__init__() + hidden_size = int(config.hidden_size) + intermediate_size = int(getattr(config, "intermediate_size", 0)) + if intermediate_size <= 0: + raise ValueError( + f"Invalid intermediate_size={intermediate_size} for DFlash MLP." + ) + + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix="gate_up_proj" if not prefix else f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix="down_proj" if not prefix else f"{prefix}.down_proj", + ) + hidden_act = getattr(config, "hidden_act", "silu") + if hidden_act != "silu": + raise ValueError( + f"Unsupported DFlash activation: {hidden_act}. Only silu is supported for now." + ) + self.act_fn = SiluAndMul() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class DFlashDecoderLayer(nn.Module): + def __init__(self, config, layer_id: int) -> None: + super().__init__() + hidden_size = int(config.hidden_size) + rms_norm_eps = float(getattr(config, "rms_norm_eps", 1e-6)) + + self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) + self.self_attn = DFlashAttention(config=config, layer_id=layer_id) + self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) + self.mlp = DFlashMLP(config=config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + if hidden_states.numel() == 0: + # Keep return types consistent for upstream callers. 
+ if residual is None: + residual = hidden_states + return hidden_states, residual + + # Pre-norm attention with fused residual+norm when possible (Qwen3-style). + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + attn_out = self.self_attn( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + hidden_states, residual = self.post_attention_layernorm(attn_out, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class DFlashDraftModel(nn.Module): + """SGLang DFlash draft model (no embedding / lm_head weights). + + The checkpoint provides: + - transformer weights for `layers.*` + - `fc.weight`, `hidden_norm.weight` for projecting target context features + - `norm.weight` for final normalization + """ + + def __init__(self, config, quant_config=None, prefix: str = "") -> None: + super().__init__() + self.config = config + + hidden_size = int(config.hidden_size) + num_layers = int(config.num_hidden_layers) + rms_norm_eps = float(getattr(config, "rms_norm_eps", 1e-6)) + + self.layers = nn.ModuleList( + [DFlashDecoderLayer(config=config, layer_id=i) for i in range(num_layers)] + ) + self.norm = RMSNorm(hidden_size, eps=rms_norm_eps) + + # Project per-token target context features: + # concat(K * hidden_size) -> hidden_size, where K is the number of target-layer + # feature tensors concatenated per token (not necessarily equal to num_layers). 
+ draft_config = parse_dflash_draft_config(draft_hf_config=config) + target_num_layers = ( + int(draft_config.num_target_layers) + if draft_config.num_target_layers is not None + else num_layers + ) + target_layer_ids = draft_config.resolve_target_layer_ids( + target_num_layers=target_num_layers, draft_num_layers=num_layers + ) + num_context_features = len(target_layer_ids) + + self.num_context_features = int(num_context_features) + self.fc = nn.Linear( + self.num_context_features * hidden_size, hidden_size, bias=False + ) + self.hidden_norm = RMSNorm(hidden_size, eps=rms_norm_eps) + + self.block_size = draft_config.resolve_block_size(default=16) + + def project_target_hidden(self, target_hidden: torch.Tensor) -> torch.Tensor: + """Project concatenated target-layer hidden states into draft hidden_size.""" + expected = int(self.fc.in_features) + if target_hidden.ndim != 2 or int(target_hidden.shape[-1]) != expected: + raise ValueError( + "DFLASH target_hidden feature dim mismatch. " + f"Expected shape [N, {expected}] " + f"(num_context_features={self.num_context_features}, hidden_size={int(self.config.hidden_size)}), " + f"but got shape={tuple(target_hidden.shape)}. " + "This usually means the target model is capturing a different number of layer features than " + "the draft checkpoint/config expects." + ) + return self.hidden_norm(self.fc(target_hidden)) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: Optional[torch.Tensor] = None, + get_embedding: bool = False, + pp_proxy_tensors=None, + ) -> LogitsProcessorOutput: + if input_embeds is None: + raise ValueError( + "DFlashDraftModel requires `input_embeds` (use the target embedding)." 
+ ) + hidden_states = input_embeds + residual: Optional[torch.Tensor] = None + + for layer in self.layers: + hidden_states, residual = layer( + positions, hidden_states, forward_batch, residual + ) + + if hidden_states.numel() != 0: + if residual is None: + hidden_states = self.norm(hidden_states) + else: + hidden_states, _ = self.norm(hidden_states, residual) + + return LogitsProcessorOutput( + next_token_logits=None, + hidden_states=hidden_states, + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, weight_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + + def resolve_param_name(name: str) -> Optional[str]: + if name in params_dict: + return name + if name.startswith("model."): + stripped_name = name[len("model.") :] + if stripped_name in params_dict: + return stripped_name + else: + prefixed_name = f"model.{name}" + if prefixed_name in params_dict: + return prefixed_name + return None + + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if f".{weight_name}." not in name: + continue + mapped_name = name.replace(weight_name, param_name) + resolved_name = resolve_param_name(mapped_name) + if resolved_name is None: + continue + param = params_dict[resolved_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight, shard_id) + break + else: + resolved_name = resolve_param_name(name) + if resolved_name is None: + # Ignore unexpected weights (e.g., HF rotary caches). + continue + param = params_dict[resolved_name] + if resolved_name.endswith("fc.weight") and tuple( + loaded_weight.shape + ) != tuple(param.shape): + raise ValueError( + "DFLASH fc.weight shape mismatch. 
This usually means the draft checkpoint's " + "number of context features (K) does not match this config. " + f"Expected fc.weight.shape={tuple(param.shape)} " + f"(num_context_features={self.num_context_features}, hidden_size={int(self.config.hidden_size)}), " + f"but got {tuple(loaded_weight.shape)} for weight '{name}'." + ) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +EntryClass = DFlashDraftModel diff --git a/python/sglang/srt/models/gemma3_causal.py b/python/sglang/srt/models/gemma3_causal.py index 6a38e7ebad9a..5be005df9f5b 100644 --- a/python/sglang/srt/models/gemma3_causal.py +++ b/python/sglang/srt/models/gemma3_causal.py @@ -42,9 +42,16 @@ default_weight_loader, maybe_remap_kv_scale_name, ) -from sglang.srt.utils import add_prefix, cpu_has_amx_support, is_cpu, make_layers +from sglang.srt.utils import ( + add_prefix, + cpu_has_amx_support, + is_cpu, + is_npu, + make_layers, +) _is_cpu = is_cpu() +_is_npu = is_npu() _is_cpu_amx_available = cpu_has_amx_support() @@ -574,10 +581,17 @@ def __init__( local_theta = getattr(config, "rope_local_base_freq", 10000.0) global_config = copy.deepcopy(config) - global_config.rope_parameters = { - "rope_type": "default", - "rope_theta": global_theta, - } + if not _is_npu: + global_config.rope_parameters = { + "rope_type": "default", + "rope_theta": global_theta, + } + else: + global_config.rope_parameters = { + "rope_theta": global_theta, + "factor": 8, + "rope_type": "linear", + } self.rotary_emb = Gemma3RotaryEmbedding(config=global_config) self.gradient_checkpointing = False diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index f955ac750d34..b8ad74015c6e 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -794,6 +794,18 @@ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None): # of the (i-1)th layer as aux hidden state self.model.layers_to_capture = 
[val + 1 for val in layer_ids] + def set_dflash_layers_to_capture(self, layer_ids: List[int]): + if not self.pp_group.is_last_rank: + return + + if layer_ids is None: + raise ValueError( + "DFLASH requires explicit layer_ids for aux hidden capture." + ) + + self.capture_aux_hidden_states = True + self.model.layers_to_capture = [val + 1 for val in layer_ids] + class Phi3ForCausalLM(LlamaForCausalLM): pass diff --git a/python/sglang/srt/models/minimax_m2.py b/python/sglang/srt/models/minimax_m2.py index d0d2d6c76c5b..5ade336780bf 100644 --- a/python/sglang/srt/models/minimax_m2.py +++ b/python/sglang/srt/models/minimax_m2.py @@ -45,6 +45,7 @@ get_attention_tp_rank, get_attention_tp_size, is_dp_attention_enabled, + get_attention_tp_group, ) from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -79,6 +80,12 @@ make_layers, ) from sglang.srt.utils.hf_transformers_utils import get_rope_config +from sglang.srt.utils import is_npu + +_is_npu = is_npu() + +if _is_npu: + from sgl_kernel_npu.norm.split_qkv_tp_rmsnorm_rope import split_qkv_tp_rmsnorm_rope logger = logging.getLogger(__name__) @@ -679,6 +686,42 @@ def forward_prepare( inner_state = q, k, v, forward_batch return None, forward_batch, inner_state + def forward_prepare_npu( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ): + qkv, _ = self.qkv_proj(hidden_states) + if self.use_qk_norm: + # q = self.q_norm(q.contiguous()) + # k = self.k_norm(k.contiguous()) + cos_sin = self.rotary_emb.cos_sin_cache.index_select( + 0, positions.flatten() + ) + cos, sin = cos_sin.chunk(2, dim=-1) + q, k, v = split_qkv_tp_rmsnorm_rope( + input=qkv, + cos=cos, + sin=sin, + q_weight=self.q_norm.weight, + k_weight=self.k_norm.weight, + q_hidden_size=self.q_size, + kv_hidden_size=self.kv_size, + head_dim=self.head_dim, + rotary_dim=self.rotary_dim, + eps=self.q_norm.variance_epsilon, + tp_world=self.q_norm.attn_tp_size, + 
tp_group=get_attention_tp_group().device_group, + ) + else: + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = q.contiguous(), k.contiguous() + q, k = self.rotary_emb(positions, q, k) + + inner_state = q, k, v, forward_batch + return None, forward_batch, inner_state + def forward_core(self, intermediate_state): _, _, inner_state = intermediate_state attn_output = self.attn(*inner_state) @@ -691,11 +734,18 @@ def forward( hidden_states: torch.Tensor, forward_batch: ForwardBatch, ) -> torch.Tensor: - s = self.forward_prepare( - positions=positions, - hidden_states=hidden_states, - forward_batch=forward_batch, - ) + if _is_npu: + s = self.forward_prepare_npu( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + else: + s = self.forward_prepare( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) return self.forward_core(s) def op_prepare(self, state): diff --git a/python/sglang/srt/models/qwen3_5.py b/python/sglang/srt/models/qwen3_5.py index 2f430c2b9b28..66a5a3b59732 100644 --- a/python/sglang/srt/models/qwen3_5.py +++ b/python/sglang/srt/models/qwen3_5.py @@ -102,7 +102,6 @@ _is_gfx95 = is_gfx95_supported() _is_amx_available = cpu_has_amx_support() - cached_get_processor = lru_cache(get_processor) @@ -132,6 +131,12 @@ def __init__( self.layer_id = layer_id self.activation = config.hidden_act self.layer_norm_epsilon = config.rms_norm_eps + packed_modules_mapping = { + "in_proj_qkvz": ["in_proj_qkv", "in_proj_z"], + "in_proj_ba": ["in_proj_b", "in_proj_a"], + } + if quant_config is not None and hasattr(quant_config, "packed_modules_mapping"): + quant_config.packed_modules_mapping["model"].update(packed_modules_mapping) # Conv1d layer self.conv_dim = self.key_dim * 2 + self.value_dim @@ -203,6 +208,7 @@ def __init__( conv_weights = self.conv1d.weight.view( self.conv1d.weight.size(0), self.conv1d.weight.size(2) ) + self.attn = RadixLinearAttention( 
layer_id=layer_id, num_q_heads=self.num_k_heads // self.attn_tp_size, @@ -439,7 +445,11 @@ def forward( hidden_states ) - if self.num_v_heads // self.num_k_heads in [1, 2, 4] and not _is_cpu: + if ( + self.num_v_heads // self.num_k_heads in [1, 2, 4] + and not _is_cpu + and not _is_npu + ): mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat_contiguous( projected_states_qkvz, projected_states_ba, @@ -448,6 +458,8 @@ def forward( self.head_k_dim, self.head_v_dim, ) + b = b.contiguous() + a = a.contiguous() elif _is_cpu and _is_amx_available: mixed_qkv, z, b, a = ( torch.ops.sgl_kernel.fused_qkvzba_split_reshape_cat_cpu( @@ -463,6 +475,8 @@ def forward( query, key, value, z, b, a = self.fix_query_key_value_ordering( projected_states_qkvz, projected_states_ba ) + b = b.contiguous() + a = a.contiguous() query, key, value = map( lambda x: x.reshape(x.shape[0], -1), (query, key, value) ) diff --git a/python/sglang/srt/models/qwen3_5_mtp.py b/python/sglang/srt/models/qwen3_5_mtp.py index 3fa89fcda0a9..037081431e95 100644 --- a/python/sglang/srt/models/qwen3_5_mtp.py +++ b/python/sglang/srt/models/qwen3_5_mtp.py @@ -15,6 +15,7 @@ """Inference-only Qwen3_5 MTP model.""" import logging +import os from typing import Iterable, Optional, Tuple import torch @@ -31,7 +32,8 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.qwen3_5 import Qwen3_5ForCausalLM -from sglang.srt.utils import add_prefix +from sglang.srt.server_args import get_global_server_args +from sglang.srt.utils import add_prefix, is_npu logger = logging.getLogger(__name__) @@ -53,6 +55,9 @@ def __init__( # The MTP model is unquantized in the nvfp4 checkpoint. 
if quant_config and quant_config.get_name() == "modelopt_fp4": quant_config = None + if get_global_server_args().speculative_draft_model_quantization is None: + quant_config = None + self.quant_config = quant_config self.config = config self.tp_size = get_tensor_model_parallel_world_size() @@ -118,6 +123,10 @@ def forward( input_embeds: Optional[torch.Tensor] = None, **kwargs, ): + if is_npu() and self.quant_config is None: + # ascend mtp unquant + os.environ["SGLANG_DEEPEP_BF16_DISPATCH"] = "1" + os.environ["DEEP_NORMAL_MODE_USE_INT8_QUANT"] = "0" assert input_embeds is None input_embeds = forward_batch.mm_input_embeds if ( @@ -149,6 +158,10 @@ def forward( forward_batch, hidden_states, ) + if is_npu() and self.quant_config is None: + # ascend mtp unquant + os.environ["SGLANG_DEEPEP_BF16_DISPATCH"] = "0" + os.environ["DEEP_NORMAL_MODE_USE_INT8_QUANT"] = "1" return self.logits_processor( input_ids, hidden_states, self.lm_head, forward_batch diff --git a/python/sglang/srt/models/qwen3_next.py b/python/sglang/srt/models/qwen3_next.py index 7e3862a8a61b..6b92daab7cab 100644 --- a/python/sglang/srt/models/qwen3_next.py +++ b/python/sglang/srt/models/qwen3_next.py @@ -3,7 +3,6 @@ from typing import Any, Iterable, Optional, Set, Tuple import torch -import triton from torch import nn from sglang.srt.configs.qwen3_next import Qwen3NextConfig @@ -21,7 +20,6 @@ from sglang.srt.layers.layernorm import GemmaRMSNorm from sglang.srt.layers.linear import ( ColumnParallelLinear, - MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear, ) @@ -56,7 +54,6 @@ logger = logging.getLogger(__name__) -from sglang.jit_kernel.triton.gdn_fused_proj import fused_qkvzba_split_reshape_cat from sglang.srt.layers.attention.fla.fused_norm_gate import FusedRMSNormGated _is_cuda = is_cuda() @@ -65,6 +62,147 @@ _is_amx_available = cpu_has_amx_support() +import triton +import triton.language as tl + + +@triton.jit +def fused_qkvzba_split_reshape_cat_kernel( + mixed_qkv, + z, + b, + a, + 
mixed_qkvz, + mixed_ba, + NUM_HEADS_QK: tl.constexpr, + NUM_HEADS_V: tl.constexpr, + HEAD_QK: tl.constexpr, + HEAD_V: tl.constexpr, +): + i_bs, i_qk = tl.program_id(0), tl.program_id(1) + QKVZ_DIM_T: tl.constexpr = HEAD_QK * 2 + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V * 2 + BA_DIM_T: tl.constexpr = NUM_HEADS_V // NUM_HEADS_QK * 2 + QKV_DIM_T: tl.constexpr = HEAD_QK * 2 + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V + q_end: tl.constexpr = HEAD_QK + blk_q_ptr = ( + mixed_qkvz + + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + + i_qk * QKVZ_DIM_T + + tl.arange(0, q_end) + ) + k_end: tl.constexpr = q_end + HEAD_QK + blk_k_ptr = ( + mixed_qkvz + + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + + i_qk * QKVZ_DIM_T + + tl.arange(q_end, k_end) + ) + v_end: tl.constexpr = k_end + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V + blk_v_ptr = ( + mixed_qkvz + + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + + i_qk * QKVZ_DIM_T + + tl.arange(k_end, v_end) + ) + z_end: tl.constexpr = v_end + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V + blk_z_ptr = ( + mixed_qkvz + + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + + i_qk * QKVZ_DIM_T + + tl.arange(v_end, z_end) + ) + blk_q_st_ptr = ( + mixed_qkv + + i_bs * NUM_HEADS_QK * QKV_DIM_T + + i_qk * HEAD_QK + + tl.arange(0, HEAD_QK) + ) + blk_k_st_ptr = ( + mixed_qkv + + i_bs * NUM_HEADS_QK * QKV_DIM_T + + NUM_HEADS_QK * HEAD_QK + + i_qk * HEAD_QK + + tl.arange(0, HEAD_QK) + ) + blk_v_st_ptr = ( + mixed_qkv + + i_bs * NUM_HEADS_QK * QKV_DIM_T + + NUM_HEADS_QK * HEAD_QK * 2 + + i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK + + tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK) + ) + blk_z_st_ptr = ( + z + + i_bs * NUM_HEADS_V * HEAD_V + + i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK + + tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK) + ) + tl.store(blk_q_st_ptr, tl.load(blk_q_ptr)) + tl.store(blk_k_st_ptr, tl.load(blk_k_ptr)) + tl.store(blk_v_st_ptr, tl.load(blk_v_ptr)) + tl.store(blk_z_st_ptr, tl.load(blk_z_ptr)) + b_end: tl.constexpr = NUM_HEADS_V // NUM_HEADS_QK + a_end: tl.constexpr = b_end + NUM_HEADS_V // 
NUM_HEADS_QK + for i in tl.static_range(b_end): + blk_b_ptr = mixed_ba + i_bs * NUM_HEADS_QK * BA_DIM_T + i_qk * BA_DIM_T + i + blk_b_st_ptr = b + i_bs * NUM_HEADS_V + i_qk * NUM_HEADS_V // NUM_HEADS_QK + i + tl.store(blk_b_st_ptr, tl.load(blk_b_ptr)) + for i in tl.static_range(b_end, a_end): + blk_a_ptr = mixed_ba + i_bs * NUM_HEADS_QK * BA_DIM_T + i_qk * BA_DIM_T + i + blk_a_st_ptr = ( + a + i_bs * NUM_HEADS_V + i_qk * NUM_HEADS_V // NUM_HEADS_QK + (i - b_end) + ) + tl.store(blk_a_st_ptr, tl.load(blk_a_ptr)) + + +def fused_qkvzba_split_reshape_cat( + mixed_qkvz, + mixed_ba, + num_heads_qk, + num_heads_v, + head_qk, + head_v, +): + batch, seq_len = mixed_qkvz.shape[0], 1 + qkv_dim_t = num_heads_qk * head_qk * 2 + num_heads_v * head_v + mixed_qkv = torch.empty( + [batch * seq_len, qkv_dim_t], + dtype=mixed_qkvz.dtype, + device=mixed_qkvz.device, + ) + z = torch.empty( + [batch * seq_len, num_heads_v, head_v], + dtype=mixed_qkvz.dtype, + device=mixed_qkvz.device, + ) + b = torch.empty( + [batch * seq_len, num_heads_v], + dtype=mixed_ba.dtype, + device=mixed_ba.device, + ) + a = torch.empty_like(b) + grid = (batch * seq_len, num_heads_qk) + fused_qkvzba_split_reshape_cat_kernel[grid]( + mixed_qkv, + z, + b, + a, + mixed_qkvz, + mixed_ba, + num_heads_qk, + num_heads_v, + head_qk, + head_v, + num_warps=1, + num_stages=3, + ) + return mixed_qkv, z, b, a + + +if _is_npu: + from sgl_kernel_npu.fla.utils import fused_qkvzba_split_reshape_cat as fused_qkvzba_split_reshape_cat_npu + fused_qkvzba_split_reshape_cat = fused_qkvzba_split_reshape_cat_npu + class Qwen3GatedDeltaNet(nn.Module): def __init__( self, @@ -111,26 +249,26 @@ def __init__( prefix=add_prefix("conv1d", prefix), ) self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2 + projection_size_ba = self.num_v_heads * 2 - # projection of the input hidden states - self.in_proj_qkvz = self.create_qkvz_proj( - hidden_size=self.hidden_size, - 
key_dim=self.key_dim, - value_dim=self.value_dim, + self.in_proj_qkvz = ColumnParallelLinear( + input_size=self.hidden_size, + output_size=projection_size_qkvz, + bias=False, quant_config=quant_config, - prefix=add_prefix("in_proj_qkvz", prefix), tp_rank=self.attn_tp_rank, tp_size=self.attn_tp_size, + prefix=add_prefix("in_proj_qkvz", prefix), ) - - self.in_proj_ba = MergedColumnParallelLinear( + self.in_proj_ba = ColumnParallelLinear( input_size=self.hidden_size, - output_sizes=[self.num_v_heads] * 2, + output_size=projection_size_ba, bias=False, quant_config=quant_config, - prefix=add_prefix("in_proj_ba", prefix), tp_rank=self.attn_tp_rank, tp_size=self.attn_tp_size, + prefix=add_prefix("in_proj_ba", prefix), ) # Override weight_loader for packed checkpoint format. @@ -975,18 +1113,11 @@ def load_weights( ) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) - # self attention ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), - # mlp ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), - # GDN - ("in_proj_qkvz.", "in_proj_qkv.", (0, 1, 2)), - ("in_proj_qkvz.", "in_proj_z.", 3), - ("in_proj_ba.", "in_proj_b.", 0), - ("in_proj_ba.", "in_proj_a.", 1), ] # Params for weights, fp8 weight scales, fp8 activation scales diff --git a/python/sglang/srt/models/qwen3_next_mtp.py b/python/sglang/srt/models/qwen3_next_mtp.py index b2bdbbbe8705..cc4f0f4715e9 100644 --- a/python/sglang/srt/models/qwen3_next_mtp.py +++ b/python/sglang/srt/models/qwen3_next_mtp.py @@ -15,6 +15,7 @@ """Inference-only Qwen3Next MTP Speculative Decoding.""" import logging +import os from typing import Iterable, Optional, Tuple import torch @@ -23,6 +24,7 @@ from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder +from sglang.srt.hardware_backend.npu.graph_runner.npu_graph_runner import is_npu from 
sglang.srt.layers.layernorm import GemmaRMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig @@ -30,7 +32,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.models.qwen3_next import Qwen3NextForCausalLM, Qwen3NextModel from sglang.srt.server_args import get_global_server_args -from sglang.srt.utils import add_prefix +from sglang.srt.utils import add_prefix, is_npu logger = logging.getLogger(__name__) @@ -46,6 +48,8 @@ def __init__( nn.Module.__init__(self) self.config = config self.tp_size = get_tensor_model_parallel_world_size() + if get_global_server_args().speculative_draft_model_quantization is None: + quant_config = None self.quant_config = quant_config # if not set, model load will be broken in Qwen3NextForCausalLM load_weights() self.pp_group = get_pp_group() @@ -86,6 +90,10 @@ def forward( input_embeds: Optional[torch.Tensor] = None, **kwargs, ): + if is_npu() and self.quant_config is None: + # ascend mtp unquant + os.environ["SGLANG_DEEPEP_BF16_DISPATCH"] = "1" + os.environ["DEEP_NORMAL_MODE_USE_INT8_QUANT"] = "0" if input_embeds is None: input_embeds = self.model.embed_tokens(input_ids) @@ -103,6 +111,10 @@ def forward( forward_batch, hidden_states, ) + if is_npu() and self.quant_config is None: + # ascend mtp unquant + os.environ["SGLANG_DEEPEP_BF16_DISPATCH"] = "0" + os.environ["DEEP_NORMAL_MODE_USE_INT8_QUANT"] = "1" return self.logits_processor( input_ids, hidden_states, self.lm_head, forward_batch diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 74445c9cd1f2..8ac38aaa0065 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -326,6 +326,7 @@ class ServerArgs: ssl_ca_certs: Optional[str] = None ssl_keyfile_password: Optional[str] = None enable_ssl_refresh: bool = False + enable_http2: bool = False # Quantization and data type dtype: str = "auto" @@ 
-498,6 +499,8 @@ class ServerArgs: speculative_num_steps: Optional[int] = None speculative_eagle_topk: Optional[int] = None speculative_num_draft_tokens: Optional[int] = None + speculative_dflash_block_size: Optional[int] = None + speculative_dflash_draft_window_size: Optional[int] = None speculative_accept_threshold_single: float = 1.0 speculative_accept_threshold_acc: float = 1.0 speculative_token_map: Optional[str] = None @@ -638,6 +641,7 @@ class ServerArgs: enable_single_batch_overlap: bool = False tbo_token_distribution_threshold: float = 0.48 enable_torch_compile: bool = False + enable_piecewise_cuda_graph: bool = False disable_piecewise_cuda_graph: bool = False enforce_piecewise_cuda_graph: bool = False enable_torch_compile_debug_mode: bool = False @@ -923,6 +927,26 @@ def _handle_ssl_validation(self): "to be specified." ) + if self.enable_http2: + try: + import granian # noqa: F401 + except ImportError: + raise ValueError( + "--enable-http2 requires the 'granian' package. " + 'Install it with: pip install "sglang[http2]"' + ) + if self.enable_ssl_refresh: + raise ValueError( + "--enable-ssl-refresh is not supported with --enable-http2. " + "Granian does not support SSL certificate hot-reloading. " + "Use Uvicorn (the default) or handle certificate rotation externally." + ) + if self.tokenizer_worker_num > 1: + raise ValueError( + "--enable-http2 does not yet support --tokenizer-worker-num > 1. " + "Multi-worker HTTP/2 support will be added in a future release." 
+ ) + def _handle_deprecated_args(self): # Handle deprecated tool call parsers deprecated_tool_call_parsers = {"qwen25": "qwen", "glm45": "glm"} @@ -1142,6 +1166,10 @@ def _handle_piecewise_cuda_graph(self): if self.attn_cp_size > 1: self.disable_piecewise_cuda_graph = True + # NPU can use this function when the piece cuda graph is explicitly declared + if self.enable_piecewise_cuda_graph: + self.disable_piecewise_cuda_graph = False + def _handle_gpu_memory_settings(self, gpu_mem): """ Configure GPU memory-dependent settings including @@ -2187,7 +2215,7 @@ def _handle_mamba_radix_cache( ) assert ( - is_cuda() + is_cuda(), is_npu() ), "Mamba extra_buffer is only supported on CUDA devices with FLA backend" if self.speculative_num_draft_tokens is not None: assert ( @@ -2204,6 +2232,13 @@ def _handle_mamba_radix_cache( == 0 ), f"For SSM models with extra buffer, either FLA_CHUNK_SIZE or page_size must be divisible by the other, got {FLA_CHUNK_SIZE=}, {self.page_size=}" elif not self.disable_radix_cache: # no_buffer + if self.page_size is not None and self.page_size != 1: + logger.warning( + f"{model_arch} with radix cache requires page_size=1 in the current " + f"Mamba scheduling mode (no_buffer), but got {self.page_size}. " + "Automatically setting page_size=1." + ) + self.page_size = 1 if self.speculative_algorithm is None: logger.warning( "Disabling overlap schedule since mamba no_buffer is not compatible with " @@ -2999,6 +3034,134 @@ def _handle_speculative_decoding(self): if self.speculative_algorithm == "NEXTN": self.speculative_algorithm = "EAGLE" + if self.speculative_algorithm == "DFLASH": + if self.enable_dp_attention: + raise ValueError( + "Currently DFLASH speculative decoding does not support dp attention." + ) + + if self.pp_size != 1: + raise ValueError( + "Currently DFLASH speculative decoding only supports pp_size == 1." 
+ ) + + if self.speculative_draft_model_path is None: + raise ValueError( + "DFLASH speculative decoding requires setting --speculative-draft-model-path." + ) + + # DFLASH does not use EAGLE-style `num_steps`/`topk`, but those fields still + # affect generic scheduler/KV-cache accounting (buffer sizing, KV freeing, + # RoPE reservation). Force them to 1 to avoid surprising memory behavior. + # + # For DFlash, the natural unit is `block_size` (verify window length). + if self.speculative_num_steps is None: + self.speculative_num_steps = 1 + elif int(self.speculative_num_steps) != 1: + logger.warning( + "DFLASH only supports speculative_num_steps == 1; overriding speculative_num_steps=%s to 1.", + self.speculative_num_steps, + ) + self.speculative_num_steps = 1 + + if self.speculative_eagle_topk is None: + self.speculative_eagle_topk = 1 + elif int(self.speculative_eagle_topk) != 1: + logger.warning( + "DFLASH only supports speculative_eagle_topk == 1; overriding speculative_eagle_topk=%s to 1.", + self.speculative_eagle_topk, + ) + self.speculative_eagle_topk = 1 + + if self.speculative_dflash_block_size is not None: + if int(self.speculative_dflash_block_size) <= 0: + raise ValueError( + "DFLASH requires --speculative-dflash-block-size to be positive, " + f"got {self.speculative_dflash_block_size}." + ) + if self.speculative_num_draft_tokens is not None and int( + self.speculative_num_draft_tokens + ) != int(self.speculative_dflash_block_size): + raise ValueError( + "Both --speculative-num-draft-tokens and --speculative-dflash-block-size are set " + "but they differ. For DFLASH they must match. " + f"speculative_num_draft_tokens={self.speculative_num_draft_tokens}, " + f"speculative_dflash_block_size={self.speculative_dflash_block_size}." 
+ ) + self.speculative_num_draft_tokens = int( + self.speculative_dflash_block_size + ) + + window_size = None + if self.speculative_dflash_draft_window_size is not None: + window_size = int(self.speculative_dflash_draft_window_size) + if window_size <= 0: + raise ValueError( + "DFLASH requires --speculative-dflash-draft-window-size " + f"to be positive, got {window_size}." + ) + self.speculative_dflash_draft_window_size = window_size + + if self.speculative_num_draft_tokens is None: + from sglang.srt.speculative.dflash_utils import ( + parse_dflash_draft_config, + ) + + model_override_args = json.loads(self.json_model_override_args) + inferred_block_size = None + try: + from sglang.srt.utils.hf_transformers_utils import get_config + + draft_hf_config = get_config( + self.speculative_draft_model_path, + trust_remote_code=self.trust_remote_code, + revision=self.speculative_draft_model_revision, + model_override_args=model_override_args, + ) + inferred_block_size = parse_dflash_draft_config( + draft_hf_config=draft_hf_config + ).resolve_block_size(default=None) + except Exception as e: + logger.warning( + "Failed to infer DFLASH block_size from draft model config; " + "defaulting speculative_num_draft_tokens to 16. Error: %s", + e, + ) + + if inferred_block_size is None: + inferred_block_size = 16 + logger.warning( + "speculative_num_draft_tokens is not set; defaulting to %d for DFLASH.", + inferred_block_size, + ) + self.speculative_num_draft_tokens = inferred_block_size + + if window_size is not None: + draft_tokens = int(self.speculative_num_draft_tokens) + if window_size < draft_tokens: + raise ValueError( + "DFLASH --speculative-dflash-draft-window-size must be >= " + "--speculative-num-draft-tokens (block_size). " + f"window_size={window_size}, block_size={draft_tokens}." + ) + + if self.max_running_requests is None: + self.max_running_requests = 48 + logger.warning( + "Max running requests is reset to 48 for speculative decoding. 
You can override this by explicitly setting --max-running-requests." + ) + + self.disable_overlap_schedule = True + logger.warning( + "Overlap scheduler is disabled when using DFLASH speculative decoding (spec v2 is not supported yet)." + ) + + if self.enable_mixed_chunk: + self.enable_mixed_chunk = False + logger.warning( + "Mixed chunked prefill is disabled because of using dflash speculative decoding." + ) + if self.speculative_algorithm in ("EAGLE", "EAGLE3", "STANDALONE"): if self.speculative_algorithm == "STANDALONE" and self.enable_dp_attention: # TODO: support dp attention for standalone speculative decoding @@ -3858,6 +4021,14 @@ def add_cli_args(parser: argparse.ArgumentParser): help="Enable automatic SSL certificate hot-reloading when cert/key " "files change on disk. Requires --ssl-certfile and --ssl-keyfile.", ) + parser.add_argument( + "--enable-http2", + action="store_true", + default=ServerArgs.enable_http2, + help="Use Granian instead of Uvicorn as the ASGI server, enabling HTTP/1.1 and " + "HTTP/2 auto-negotiation. Clients may use h2c (cleartext HTTP/2) or plain HTTP/1.1. " + "Requires 'pip install sglang[http2]'.", + ) # Quantization and data type parser.add_argument( @@ -4796,7 +4967,7 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--speculative-algorithm", type=str, - choices=["EAGLE", "EAGLE3", "NEXTN", "STANDALONE", "NGRAM"], + choices=["DFLASH", "EAGLE", "EAGLE3", "NEXTN", "STANDALONE", "NGRAM"], help="Speculative algorithm.", ) parser.add_argument( @@ -4840,6 +5011,21 @@ def add_cli_args(parser: argparse.ArgumentParser): help="The number of tokens sampled from the draft model in Speculative Decoding.", default=ServerArgs.speculative_num_draft_tokens, ) + parser.add_argument( + "--speculative-dflash-block-size", + type=int, + help="DFLASH only. Block size (verify window length). 
Alias of --speculative-num-draft-tokens for DFLASH.", + default=ServerArgs.speculative_dflash_block_size, + ) + parser.add_argument( + "--speculative-dflash-draft-window-size", + type=int, + help="DFLASH only. Sliding window size for the draft-model KV cache. " + "When set, the draft worker keeps a recent target-token window in its " + "local cache (paged backends may retain up to one extra page on the left " + "for alignment). Default is full context.", + default=ServerArgs.speculative_dflash_draft_window_size, + ) parser.add_argument( "--speculative-accept-threshold-single", type=float, @@ -5532,8 +5718,8 @@ def add_cli_args(parser: argparse.ArgumentParser): ) parser.add_argument( "--enable-piecewise-cuda-graph", - action=DeprecatedAction, - help="Deprecated: Piecewise cuda graph is enabled by default. Use --enforce-piecewise-cuda-graph to skip auto-disable conditions.", + action="store_true", + help="Optimize the model with piecewise cuda graph for extend/prefill only.", ) parser.add_argument( "--enforce-piecewise-cuda-graph", diff --git a/python/sglang/srt/speculative/dflash_info.py b/python/sglang/srt/speculative/dflash_info.py new file mode 100644 index 000000000000..fbb06cc70ee1 --- /dev/null +++ b/python/sglang/srt/speculative/dflash_info.py @@ -0,0 +1,501 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import List, Tuple + +import torch + +from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton +from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.layers.sampler import apply_custom_logit_processor +from sglang.srt.managers.schedule_batch import ScheduleBatch +from sglang.srt.mem_cache.common import ( + alloc_paged_token_slots_extend, + alloc_token_slots, + get_last_loc, +) +from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode +from sglang.srt.speculative.dflash_utils import ( + compute_dflash_accept_len_and_bonus, + 
compute_dflash_sampling_accept_len_and_bonus, + is_dflash_sampling_verify_available, +) +from sglang.srt.speculative.spec_info import SpecInput, SpecInputType +from sglang.srt.speculative.spec_utils import assign_req_to_token_pool_func + + +def _compute_paged_keep_slots( + *, + prefix_lens: torch.Tensor, + commit_lens: torch.Tensor, + draft_token_num: int, + page_size: int, +) -> torch.Tensor: + """Compute how many draft slots per request must remain allocated. + + The allocator frees at page granularity for paged mode, so we can only release + full pages from the tail after verify. + """ + + if page_size <= 1: + raise ValueError(f"Expected page_size > 1, got {page_size}.") + + seq_dtype = prefix_lens.dtype + extended_lens = prefix_lens + int(draft_token_num) + new_lens = prefix_lens + commit_lens.to(seq_dtype) + aligned_new_lens = ((new_lens + page_size - 1) // page_size) * page_size + keep_lens = torch.minimum(aligned_new_lens, extended_lens) + keep_slots = (keep_lens - prefix_lens).to(torch.int64) + keep_slots.clamp_(min=0, max=int(draft_token_num)) + return keep_slots + + +@dataclass +class DFlashDraftInput(SpecInput): + """Per-batch DFlash draft state for spec-v1 (non-overlap) scheduling. + + This object is stored on `ScheduleBatch.spec_info` between decode iterations. + It is NOT sent to model attention backends; the DFlash worker uses it to run + the draft model and to track draft-side cache progress. + + When draft windowing is disabled, `draft_seq_lens` matches the committed target + prefix length already materialized in the draft KV cache. When windowing is + enabled, `draft_seq_lens` is the logical resident length in the draft worker's + compact req-to-token mapping. In paged mode this may exceed the requested + window by up to `page_size - 1` so the local page table remains valid. `ctx_lens` + tracks newly committed target tokens that still need draft KV materialization. + """ + + # Current token to start the next DFlash block (one per request). 
+ verified_id: torch.Tensor + + # Flattened context features for tokens that need to be appended into the draft cache. + # Shape: [sum(ctx_lens), K * hidden_size], where K is the number of target-layer + # hidden-state features concatenated per token (len(dflash_config.target_layer_ids), + # or default K == draft_num_layers for existing checkpoints). + target_hidden: torch.Tensor + + # Context lengths per request, used to slice `target_hidden`. Device tensor (int32). + ctx_lens: torch.Tensor + + # How many committed tokens are visible to the draft worker per request. + draft_seq_lens: torch.Tensor + + def __post_init__(self): + super().__init__(spec_input_type=SpecInputType.DFLASH_DRAFT) + + def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]: + # Draft state does not change token accounting. + return (1, 1) + + def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True): + old_ctx_lens = self.ctx_lens + old_target_hidden = self.target_hidden + + self.verified_id = self.verified_id[new_indices] + self.ctx_lens = old_ctx_lens[new_indices] + self.draft_seq_lens = self.draft_seq_lens[new_indices] + + if old_target_hidden is None or old_target_hidden.numel() == 0: + self.target_hidden = old_target_hidden + return + + # Rebuild target_hidden for the filtered batch using vectorized indexing. 
+ old_bs = int(old_ctx_lens.shape[0]) + offsets = torch.zeros( + (old_bs + 1,), dtype=torch.int64, device=old_ctx_lens.device + ) + offsets[1:].copy_(old_ctx_lens.to(torch.int64).cumsum(0)) + + start = offsets[:-1] + seg_start = start[new_indices] + seg_lens = old_ctx_lens[new_indices].to(torch.int64) + + max_len = int(seg_lens.max().item()) if seg_lens.numel() > 0 else 0 + if max_len <= 0: + self.target_hidden = old_target_hidden[:0] + return + + r = torch.arange(max_len, device=old_ctx_lens.device, dtype=torch.int64)[ + None, : + ] + pos2d = seg_start[:, None] + r + mask = r < seg_lens[:, None] + flat_pos = pos2d[mask] + self.target_hidden = ( + old_target_hidden.index_select(0, flat_pos) + if flat_pos.numel() > 0 + else old_target_hidden[:0] + ) + + def merge_batch(self, spec_info: "DFlashDraftInput"): + self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], dim=0) + self.ctx_lens = torch.cat([self.ctx_lens, spec_info.ctx_lens], dim=0) + self.draft_seq_lens = torch.cat( + [self.draft_seq_lens, spec_info.draft_seq_lens], dim=0 + ) + if self.target_hidden is None or self.target_hidden.numel() == 0: + self.target_hidden = spec_info.target_hidden + elif ( + spec_info.target_hidden is not None and spec_info.target_hidden.numel() > 0 + ): + self.target_hidden = torch.cat( + [self.target_hidden, spec_info.target_hidden], dim=0 + ) + + +@dataclass +class DFlashVerifyInput(SpecInput): + """Inputs for a target-model verify forward in DFlash (spec-v1). + + The verify forward is run with `ForwardMode.TARGET_VERIFY` so that the target + model returns logits for all tokens in the block, enabling accept-length + computation. + """ + + draft_token: torch.Tensor + positions: torch.Tensor + draft_token_num: int + # Kept for compatibility with attention backends that gate tree metadata by `topk > 1`. + # DFLASH verify is linear (non-tree), so this is always 1. + topk: int = 1 + # Custom attention "allow mask" for TARGET_VERIFY in backends that require it (e.g. 
triton). + # Semantics follow SGLang speculative conventions: True means the (q, k) pair is allowed. + custom_mask: torch.Tensor | None = None + capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL + + # Shape info for padding (e.g., DP attention / CUDA graph). + num_tokens_per_batch: int = -1 + + def __post_init__(self): + super().__init__(spec_input_type=SpecInputType.DFLASH_VERIFY) + if self.num_tokens_per_batch == -1: + self.num_tokens_per_batch = int(self.draft_token_num) + + def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]: + return self.draft_token_num, self.draft_token_num + + def prepare_for_verify( + self, + batch: ScheduleBatch, + page_size: int, + *, + build_custom_mask: bool = True, + ): + if batch.forward_mode.is_idle(): + return + + batch.input_ids = self.draft_token + + if page_size == 1: + batch.out_cache_loc = alloc_token_slots( + batch.tree_cache, len(batch.input_ids) + ) + end_offset = batch.seq_lens + self.draft_token_num + else: + prefix_lens = batch.seq_lens + prefix_lens_cpu = batch.seq_lens_cpu + end_offset = prefix_lens + self.draft_token_num + end_offset_cpu = prefix_lens_cpu + self.draft_token_num + last_loc = get_last_loc( + batch.req_to_token_pool.req_to_token, + batch.req_pool_indices, + prefix_lens, + ) + batch.out_cache_loc = alloc_paged_token_slots_extend( + batch.tree_cache, + prefix_lens, + prefix_lens_cpu, + end_offset, + end_offset_cpu, + last_loc, + len(batch.input_ids), + ) + self.last_loc = last_loc + + bs = batch.batch_size() + assign_req_to_token_pool_func( + batch.req_pool_indices, + batch.req_to_token_pool.req_to_token, + batch.seq_lens, + end_offset, + batch.out_cache_loc, + bs, + ) + + if not build_custom_mask: + self.custom_mask = None + return + + if self.draft_token_num <= 0: + raise ValueError( + f"DFLASH draft_token_num must be positive, got {self.draft_token_num}." 
+ ) + mask_chunks: List[torch.Tensor] = [] + q_len = int(self.draft_token_num) + q_idx = torch.arange(q_len, device=batch.device, dtype=torch.int32).unsqueeze(1) + for prefix_len in batch.seq_lens_cpu.tolist(): + prefix_len_i = int(prefix_len) + kv_len = prefix_len_i + q_len + k_idx = torch.arange( + kv_len, device=batch.device, dtype=torch.int32 + ).unsqueeze(0) + # Allow attending to the full prefix and to tokens up to (and including) the + # current query position within the verify block (standard causal masking). + allow = k_idx <= (prefix_len_i + q_idx) + mask_chunks.append(allow.flatten()) + self.custom_mask = ( + torch.cat(mask_chunks, dim=0) + if mask_chunks + else torch.empty((0,), dtype=torch.bool, device=batch.device) + ) + + def generate_attn_arg_prefill( + self, + req_pool_indices: torch.Tensor, + paged_kernel_lens: torch.Tensor, + paged_kernel_lens_sum: int, + req_to_token: torch.Tensor, + ): + device = req_pool_indices.device + bs = len(req_pool_indices) + + qo_indptr = torch.arange( + 0, + (bs + 1) * self.draft_token_num, + step=self.draft_token_num, + dtype=torch.int32, + device=device, + ) + + cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device=device) + paged_kernel_lens = paged_kernel_lens + self.draft_token_num + cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0) + + kv_indices = torch.empty( + paged_kernel_lens_sum + self.draft_token_num * bs, + dtype=torch.int32, + device=device, + ) + create_flashinfer_kv_indices_triton[(bs,)]( + req_to_token, + req_pool_indices, + paged_kernel_lens, + cum_kv_seq_len, + None, + kv_indices, + req_to_token.size(1), + ) + mask = self.custom_mask + if mask is not None: + mask_numel = ( + paged_kernel_lens_sum * self.draft_token_num + + (self.draft_token_num**2) * bs + ) + if mask.numel() < mask_numel: + # FIXME(attn): temporary fix for custom mask padding with cuda graph + mask = torch.cat( + [ + mask, + torch.full( + (mask_numel - mask.numel(),), + True, + dtype=torch.bool, + 
device=device, + ), + ], + dim=0, + ) + self.custom_mask = mask + return kv_indices, cum_kv_seq_len, qo_indptr, mask + + def verify( + self, + *, + batch: ScheduleBatch, + logits_output: LogitsProcessorOutput, + page_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]: + """DFlash verification for greedy and non-greedy sampling. + + Returns: + new_verified_id: int64 tensor [bs] (the new current token per request) + commit_lens: int32 tensor [bs] (how many verify-input tokens are committed) + next_target_hidden: tensor [sum(commit_lens), feature_dim] + accept_length_per_req_cpu: list[int] (accepted draft tokens per request) + """ + if batch.forward_mode.is_idle(): + empty = torch.empty((0,), dtype=torch.int64, device=batch.device) + return empty, empty.to(torch.int32), empty, [] + + bs = batch.batch_size() + device = logits_output.next_token_logits.device + + sampling_info = batch.sampling_info + if sampling_info is not None: + if len(sampling_info) != bs: + raise RuntimeError( + "DFLASH verify sampling_info size mismatch: " + f"len(sampling_info)={len(sampling_info)}, bs={bs}." + ) + + # Keep speculative verify semantics consistent with normal sampling path. 
+ if sampling_info.has_custom_logit_processor: + apply_custom_logit_processor( + logits_output.next_token_logits, + sampling_info, + num_tokens_in_batch=self.draft_token_num, + ) + + if ( + sampling_info.penalizer_orchestrator.is_required + or sampling_info.logit_bias is not None + ): + linear_penalty = torch.zeros( + (bs, logits_output.next_token_logits.shape[1]), + dtype=torch.float32, + device=device, + ) + sampling_info.apply_logits_bias(linear_penalty) + logits_output.next_token_logits.add_( + torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0) + ) + + candidates = self.draft_token.view(bs, self.draft_token_num) + if ( + sampling_info is not None + and not sampling_info.is_all_greedy + and is_dflash_sampling_verify_available() + ): + accept_len, bonus = compute_dflash_sampling_accept_len_and_bonus( + candidates=candidates, + next_token_logits=logits_output.next_token_logits, + sampling_info=sampling_info, + ) + else: + target_predict = torch.argmax(logits_output.next_token_logits, dim=-1).view( + bs, self.draft_token_num + ) + accept_len, bonus = compute_dflash_accept_len_and_bonus( + candidates=candidates, + target_predict=target_predict, + ) + + # Single D2H transfer: candidates[1:] + accept_len + bonus + packed = torch.cat( + [candidates[:, 1:], accept_len.unsqueeze(1), bonus.unsqueeze(1)], dim=1 + ).cpu() + + max_acc = self.draft_token_num - 1 + accept_length_per_req_cpu: List[int] = [] + commit_lens_cpu: List[int] = [] + new_verified_list: List[int] = [] + + for i, req in enumerate(batch.reqs): + acc_len = int(packed[i, max_acc].item()) + proposed = packed[i, :acc_len].tolist() + [ + int(packed[i, max_acc + 1].item()) + ] + + appended = 0 + for token_id in proposed: + token_id = int(token_id) + req.output_ids.append(token_id) + appended += 1 + req.check_finished() + if req.finished(): + break + if req.grammar is not None: + req.grammar.accept_token(token_id) + + if req.output_ids: + new_verified_token = int(req.output_ids[-1]) + elif 
req.origin_input_ids: + # If no token was appended in this verify step, keep the current token unchanged. + new_verified_token = int(req.origin_input_ids[-1]) + else: + raise RuntimeError( + "DFLASH verify cannot determine current token: both output_ids and origin_input_ids are empty." + ) + + commit_lens_cpu.append(appended) + new_verified_list.append(new_verified_token) + accept_length_per_req_cpu.append(max(0, appended - 1)) + req.spec_verify_ct += 1 + req.spec_accepted_tokens += accept_length_per_req_cpu[-1] + + commit_lens = torch.tensor(commit_lens_cpu, dtype=torch.int32, device=device) + new_verified_id = torch.tensor( + new_verified_list, dtype=torch.int64, device=device + ) + + # Free uncommitted KV cache slots and compact out_cache_loc. + if page_size == 1: + out_cache_loc = batch.out_cache_loc.view(bs, self.draft_token_num) + keep_mask = ( + torch.arange(self.draft_token_num, device=device)[None, :] + < commit_lens[:, None] + ) + batch.token_to_kv_pool_allocator.free(out_cache_loc[~keep_mask]) + batch.out_cache_loc = out_cache_loc[keep_mask] + else: + out_cache_loc = batch.out_cache_loc.view(bs, self.draft_token_num) + row_offsets = torch.arange(self.draft_token_num, device=device)[None, :] + keep_slots = _compute_paged_keep_slots( + prefix_lens=batch.seq_lens, + commit_lens=commit_lens, + draft_token_num=self.draft_token_num, + page_size=page_size, + ) + free_mask = row_offsets >= keep_slots[:, None] + batch.token_to_kv_pool_allocator.free(out_cache_loc[free_mask]) + + keep_mask = row_offsets < commit_lens[:, None] + batch.out_cache_loc = out_cache_loc[keep_mask] + + # Update req-level KV cache accounting. + for req, commit_len in zip(batch.reqs, commit_lens_cpu, strict=True): + req.kv_committed_len += commit_len + req.kv_allocated_len = req.kv_committed_len + + # Update req_to_token pool mapping for newly committed tokens. 
+ end_offset = batch.seq_lens + commit_lens.to(batch.seq_lens.dtype) + assign_req_to_token_pool_func( + batch.req_pool_indices, + batch.req_to_token_pool.req_to_token, + batch.seq_lens, + end_offset, + batch.out_cache_loc, + bs, + ) + + # Update batch seq lens. + batch.seq_lens.add_(commit_lens.to(batch.seq_lens.dtype)) + batch.seq_lens_cpu.add_( + torch.tensor(commit_lens_cpu, dtype=batch.seq_lens_cpu.dtype) + ) + # Keep seq_lens_sum in sync; flashinfer indices updaters rely on this for buffer sizing. + batch.seq_lens_sum += sum(commit_lens_cpu) + + # Build next-step context features from the committed verify-input tokens. + hidden = logits_output.hidden_states + if hidden is None: + raise RuntimeError( + "DFLASH verify requires target hidden states, but got None." + ) + hidden = hidden.view(bs, self.draft_token_num, -1) + segments: List[torch.Tensor] = [] + for i, ln in enumerate(commit_lens_cpu): + if ln > 0: + segments.append(hidden[i, :ln, :]) + next_target_hidden = torch.cat(segments, dim=0) if segments else hidden[:0] + + # Avoid confusing downstream consumers (spec-v1 decode doesn't use this). 
+ logits_output.hidden_states = None + + return ( + new_verified_id, + commit_lens, + next_target_hidden, + accept_length_per_req_cpu, + ) diff --git a/python/sglang/srt/speculative/dflash_utils.py b/python/sglang/srt/speculative/dflash_utils.py new file mode 100644 index 000000000000..ddec049e0a24 --- /dev/null +++ b/python/sglang/srt/speculative/dflash_utils.py @@ -0,0 +1,637 @@ +from __future__ import annotations + +from dataclasses import dataclass +from numbers import Integral +from typing import Any, List, Optional, Tuple + +import torch +import torch.nn.functional as F + +from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod +from sglang.srt.utils import is_cuda + +DEFAULT_DFLASH_MASK_TOKEN = "<|MASK|>" + +_DFLASH_SAMPLING_VERIFY_AVAILABLE = False +_DFLASH_CHAIN_VERIFY_BUFFERS: dict[tuple[Optional[int], int], dict[str, Any]] = {} +_DFLASH_VERIFY_SKIP_CUSTOM_MASK_BACKENDS = frozenset( + { + "FlashInferAttnBackend", + "FlashInferMLAAttnBackend", + "FlashAttentionBackend", + "TRTLLMHAAttnBackend", + "TRTLLMMLABackend", + } +) + + +if is_cuda(): + try: + from sgl_kernel import ( + top_k_renorm_prob, + top_p_renorm_prob, + tree_speculative_sampling_target_only, + ) + + _DFLASH_SAMPLING_VERIFY_AVAILABLE = True + except Exception: + top_k_renorm_prob = None + top_p_renorm_prob = None + tree_speculative_sampling_target_only = None +else: + top_k_renorm_prob = None + top_p_renorm_prob = None + tree_speculative_sampling_target_only = None + + +def is_dflash_sampling_verify_available() -> bool: + return _DFLASH_SAMPLING_VERIFY_AVAILABLE + + +def scale_kv_cell_size_per_token_for_dflash( + *, + target_cell_size_per_token: int, + target_num_layers: int, + draft_num_layers: int, + draft_cell_size_per_token: Optional[int] = None, +) -> int: + """Compute bytes/token budget for combined target+draft KV pools (DFLASH). + + DFLASH runs a separate draft runner with its own KV pool. The target runner's + token capacity must fit both pools in aggregate. 
+ + Returns: + Approximate per-token bytes for (target KV + draft KV), expressed as a + scaled version of `target_cell_size_per_token`, unless an explicit + `draft_cell_size_per_token` is provided (in which case we sum them). + """ + if target_cell_size_per_token <= 0: + raise ValueError( + "target_cell_size_per_token must be positive, " + f"got {target_cell_size_per_token}." + ) + + if draft_cell_size_per_token is not None: + draft_cell_size_per_token = int(draft_cell_size_per_token) + if draft_cell_size_per_token <= 0: + raise ValueError( + "draft_cell_size_per_token must be positive when provided, " + f"got {draft_cell_size_per_token}." + ) + return int(target_cell_size_per_token) + int(draft_cell_size_per_token) + + if target_num_layers <= 0 or draft_num_layers <= 0: + return int(target_cell_size_per_token) + + total_layers = int(target_num_layers) + int(draft_num_layers) + return ( + int(target_cell_size_per_token) * int(total_layers) + int(target_num_layers) - 1 + ) // int(target_num_layers) + + +def resolve_dflash_verify_mask_policy(attn_backend: Any) -> tuple[str, bool]: + backend = attn_backend + for _ in range(4): + full_backend = getattr(backend, "full_attn_backend", None) + if full_backend is None: + break + backend = full_backend + backend_name = type(backend).__name__ + return backend_name, (backend_name not in _DFLASH_VERIFY_SKIP_CUSTOM_MASK_BACKENDS) + + +def _get_or_create_chain_verify_buffers( + *, + bs: int, + draft_token_num: int, + device: torch.device, +) -> tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor +]: + key = (device.index, int(draft_token_num)) + cached = _DFLASH_CHAIN_VERIFY_BUFFERS.get(key) + cap_bs = 0 if cached is None else int(cached["cap_bs"]) + if cap_bs < bs: + new_cap = max(int(bs), cap_bs * 2 if cap_bs > 0 else int(bs)) + retrieve_index = torch.arange( + new_cap * draft_token_num, dtype=torch.int64, device=device + ).view(new_cap, draft_token_num) + row_next = torch.arange( + 1, 
draft_token_num + 1, dtype=torch.int64, device=device + ) + row_next[-1] = -1 + retrieve_next_token = row_next.unsqueeze(0).expand(new_cap, -1).clone() + retrieve_next_sibling = torch.full( + (new_cap, draft_token_num), -1, dtype=torch.int64, device=device + ) + predicts = torch.empty( + (new_cap * draft_token_num,), dtype=torch.int32, device=device + ) + accept_index = torch.empty( + (new_cap, draft_token_num), dtype=torch.int32, device=device + ) + accept_token_num = torch.empty((new_cap,), dtype=torch.int32, device=device) + cached = { + "cap_bs": int(new_cap), + "retrieve_index": retrieve_index, + "retrieve_next_token": retrieve_next_token, + "retrieve_next_sibling": retrieve_next_sibling, + "predicts": predicts, + "accept_index": accept_index, + "accept_token_num": accept_token_num, + } + _DFLASH_CHAIN_VERIFY_BUFFERS[key] = cached + + assert cached is not None + retrieve_index = cached["retrieve_index"][:bs] + retrieve_next_token = cached["retrieve_next_token"][:bs] + retrieve_next_sibling = cached["retrieve_next_sibling"][:bs] + predicts = cached["predicts"][: bs * draft_token_num] + accept_index = cached["accept_index"][:bs] + accept_token_num = cached["accept_token_num"][:bs] + return ( + retrieve_index, + retrieve_next_token, + retrieve_next_sibling, + predicts, + accept_index, + accept_token_num, + ) + + +def build_target_layer_ids(num_target_layers: int, num_draft_layers: int) -> List[int]: + """Select target layer indices used to build DFlash context features. + + Args: + num_target_layers: Number of transformer layers in the runtime target model. + num_draft_layers: Number of layers in the DFlash draft model. + + Returns: + A list of 0-based target layer indices of length `num_draft_layers`. + + Notes: + - DFlash uses hidden states after each selected target layer (HF-style). + - SGLang captures "before layer i", so the model hook will typically add +1 + when mapping to capture points. 
+ """ + if num_target_layers <= 0: + raise ValueError( + f"num_target_layers must be positive, got {num_target_layers}." + ) + if num_draft_layers <= 0: + raise ValueError(f"num_draft_layers must be positive, got {num_draft_layers}.") + + if num_draft_layers == 1: + return [num_target_layers // 2] + + start = 1 + end = num_target_layers - 3 + if end < start: + raise ValueError( + "DFlash layer selection requires num_target_layers >= 4. " + f"Got num_target_layers={num_target_layers}." + ) + + span = end - start + return [ + int(round(start + (i * span) / (num_draft_layers - 1))) + for i in range(num_draft_layers) + ] + + +def _cfg_get(config: Any, key: str, default: Any = None) -> Any: + if isinstance(config, dict): + return config.get(key, default) + return getattr(config, key, default) + + +def _get_text_config(config: Any) -> Any: + if config is None: + return None + if isinstance(config, dict): + return config.get("text_config", config) + text_config = getattr(config, "text_config", None) + if text_config is not None: + return text_config + get_text_config = getattr(config, "get_text_config", None) + if callable(get_text_config): + try: + resolved = get_text_config() + if resolved is not None: + return resolved + except TypeError: + pass + return config + + +def _get_dflash_config(config: Any) -> dict: + if isinstance(config, dict): + cfg = config.get("dflash_config", None) + else: + cfg = getattr(config, "dflash_config", None) + if cfg is None: + return {} + if isinstance(cfg, dict): + return cfg + + try: + return dict(cfg) + except Exception: + return {} + + +def _parse_optional_int( + value: Any, + *, + field_name: str, + min_value: Optional[int] = None, +) -> Optional[int]: + if value is None: + return None + try: + parsed = int(value) + except Exception as e: + raise ValueError(f"Invalid {field_name}={value!r}.") from e + if min_value is not None and parsed < int(min_value): + comparator = "positive" if int(min_value) == 1 else f">= {int(min_value)}" + 
raise ValueError(f"{field_name} must be {comparator}, got {parsed}.") + return parsed + + +@dataclass(frozen=True) +class DFlashDraftConfig: + num_hidden_layers: Optional[int] + num_target_layers: Optional[int] + block_size: Optional[int] + target_layer_ids: Optional[List[int]] + mask_token: str + mask_token_id: Optional[int] + + def require_num_layers(self) -> int: + if self.num_hidden_layers is None: + raise ValueError( + "DFLASH requires draft num_hidden_layers in config. " + "Got config without num_hidden_layers." + ) + return int(self.num_hidden_layers) + + def resolve_block_size(self, *, default: Optional[int] = None) -> Optional[int]: + return self.block_size if self.block_size is not None else default + + def resolve_target_layer_ids( + self, + *, + target_num_layers: int, + draft_num_layers: Optional[int] = None, + ) -> List[int]: + target_num_layers = int(target_num_layers) + if target_num_layers <= 0: + raise ValueError( + f"target_num_layers must be positive, got {target_num_layers}." + ) + + if self.target_layer_ids is None: + if draft_num_layers is None: + draft_num_layers = self.require_num_layers() + return build_target_layer_ids(target_num_layers, int(draft_num_layers)) + + resolved = list(self.target_layer_ids) + if len(resolved) <= 0: + raise ValueError( + "DFLASH dflash_config.target_layer_ids must be non-empty. " + f"Got len(target_layer_ids)={len(resolved)}." + ) + for idx, val in enumerate(resolved): + if val < 0 or val >= target_num_layers: + raise ValueError( + "DFLASH target_layer_ids contains an out-of-range layer id. " + f"target_layer_ids[{idx}]={val}, target_num_layers={target_num_layers}." 
+ ) + return resolved + + +def parse_dflash_draft_config(*, draft_hf_config: Any) -> DFlashDraftConfig: + """Parse and validate DFLASH draft config fields from HF config/dict.""" + dflash_cfg = _get_dflash_config(draft_hf_config) + draft_text_config = _get_text_config(draft_hf_config) + + num_hidden_layers = _parse_optional_int( + _cfg_get(draft_text_config, "num_hidden_layers", None), + field_name="DFLASH draft num_hidden_layers", + min_value=1, + ) + raw_num_target_layers = dflash_cfg.get( + "num_target_layers", + _cfg_get(draft_hf_config, "num_target_layers", None), + ) + num_target_layers = _parse_optional_int( + raw_num_target_layers, + field_name="DFLASH draft num_target_layers", + min_value=1, + ) + + # Keep support for current checkpoints where block_size is top-level. + raw_block_size = dflash_cfg.get( + "block_size", + _cfg_get(draft_hf_config, "block_size", None), + ) + block_size = _parse_optional_int( + raw_block_size, + field_name="DFLASH block_size", + min_value=1, + ) + + layer_ids = dflash_cfg.get( + "target_layer_ids", + _cfg_get(draft_hf_config, "target_layer_ids", None), + ) + parsed_target_layer_ids: Optional[List[int]] + if layer_ids is None: + parsed_target_layer_ids = None + else: + if not isinstance(layer_ids, (list, tuple)): + raise ValueError( + "DFLASH dflash_config.target_layer_ids must be a list of ints, " + f"got type={type(layer_ids).__name__}." + ) + parsed_target_layer_ids = [int(x) for x in layer_ids] + if len(parsed_target_layer_ids) <= 0: + raise ValueError( + "DFLASH dflash_config.target_layer_ids must be non-empty. " + f"Got len(target_layer_ids)={len(parsed_target_layer_ids)}." + ) + + mask_token = dflash_cfg.get("mask_token", None) + if mask_token is None: + mask_token = DEFAULT_DFLASH_MASK_TOKEN + if not isinstance(mask_token, str) or not mask_token: + raise ValueError( + "DFLASH dflash_config.mask_token must be a non-empty string, " + f"got {mask_token!r}." 
+ ) + + mask_token_id = dflash_cfg.get("mask_token_id", None) + if mask_token_id is not None: + if not isinstance(mask_token_id, Integral) or isinstance(mask_token_id, bool): + raise ValueError( + "DFLASH dflash_config.mask_token_id must be an integer, " + f"got {mask_token_id!r} (type={type(mask_token_id).__name__})." + ) + mask_token_id = int(mask_token_id) + if mask_token_id < 0: + raise ValueError( + "DFLASH dflash_config.mask_token_id must be non-negative, " + f"got {mask_token_id}." + ) + + return DFlashDraftConfig( + num_hidden_layers=num_hidden_layers, + num_target_layers=num_target_layers, + block_size=block_size, + target_layer_ids=parsed_target_layer_ids, + mask_token=mask_token, + mask_token_id=mask_token_id, + ) + + +def can_dflash_slice_qkv_weight(qkv_proj: Any) -> Tuple[bool, str]: + """Validate whether DFlash can slice KV weights from a fused QKV linear layer.""" + quant_method = getattr(qkv_proj, "quant_method", None) + if not isinstance(quant_method, UnquantizedLinearMethod): + return ( + False, + "quantized qkv_proj is not supported for this path " + f"(quant_method={type(quant_method).__name__})", + ) + if not hasattr(qkv_proj, "weight"): + return False, "qkv weight tensor is missing" + return True, "" + + +def can_dflash_use_fused_qkv_proj(qkv_proj: Any) -> Tuple[bool, str]: + """Validate whether a QKV layer is eligible for DFlash fused KV materialization.""" + eligible, reason = can_dflash_slice_qkv_weight(qkv_proj) + if not eligible: + return False, reason + if getattr(qkv_proj, "bias", None) is not None: + return False, "qkv bias is not supported for fused KV path" + return True, "" + + +def compute_dflash_accept_len_and_bonus( + *, + candidates: torch.Tensor, + target_predict: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute DFlash accept lengths and bonus tokens (greedy verify rule). + + Args: + candidates: Token ids proposed by the DFlash draft, including the current token. + Shape: [bs, block_size]. 
candidates[:, 0] is the current token. + target_predict: Token ids predicted by the target model for each position in the block. + Shape: [bs, block_size]. target_predict[:, t] corresponds to argmax at position t. + + Returns: + accept_len: int32 tensor [bs], number of accepted *draft* tokens (excluding current token and bonus token). + bonus: int64 tensor [bs], the target-predicted token at index accept_len (the "bonus" token to append). + + Notes: + Matches the reference implementation rule: + accept while candidates[:, 1:] == target_predict[:, :-1] consecutively. + """ + if candidates.ndim != 2: + raise ValueError(f"candidates must be 2D, got shape={tuple(candidates.shape)}") + if target_predict.shape != candidates.shape: + raise ValueError( + "target_predict must have the same shape as candidates. " + f"candidates.shape={tuple(candidates.shape)}, target_predict.shape={tuple(target_predict.shape)}" + ) + + bs, block_size = candidates.shape + if bs <= 0: + raise ValueError(f"batch size must be positive, got {bs}.") + if block_size <= 0: + raise ValueError(f"block_size must be positive, got {block_size}.") + + matches = candidates[:, 1:] == target_predict[:, :-1] + accept_len = matches.to(torch.int32).cumprod(dim=1).sum(dim=1) + bonus = target_predict[torch.arange(bs, device=target_predict.device), accept_len] + return accept_len, bonus.to(torch.int64) + + +def compute_dflash_sampling_accept_len_and_bonus( + *, + candidates: torch.Tensor, + next_token_logits: torch.Tensor, + sampling_info: Any, + threshold_single: Optional[float] = None, + threshold_acc: Optional[float] = None, + uniform_samples: Optional[torch.Tensor] = None, + uniform_samples_for_final_sampling: Optional[torch.Tensor] = None, + use_sparse_topk: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute DFlash accept lengths and bonus tokens for non-greedy sampling. 
+ + This is a chain-specialized variant of speculative target-only verification: + - DFlash proposals are linear (topk == 1), so each verify level has at most one candidate. + - When a candidate is rejected at a level, the final token is sampled from + `relu(q - p)` where `p` has only the rejected candidate mass. + """ + if not _DFLASH_SAMPLING_VERIFY_AVAILABLE: + raise RuntimeError( + "DFLASH non-greedy verification is unavailable on this build/device." + ) + if candidates.ndim != 2: + raise ValueError(f"candidates must be 2D, got shape={tuple(candidates.shape)}") + if next_token_logits.ndim != 2: + raise ValueError( + "next_token_logits must be 2D, " + f"got shape={tuple(next_token_logits.shape)}." + ) + + bs, draft_token_num = candidates.shape + if bs <= 0: + raise ValueError(f"batch size must be positive, got {bs}.") + if draft_token_num <= 0: + raise ValueError(f"draft_token_num must be positive, got {draft_token_num}.") + if next_token_logits.shape[0] != bs * draft_token_num: + raise ValueError( + "next_token_logits row count mismatch. " + f"Expected {bs * draft_token_num}, got {next_token_logits.shape[0]}." + ) + if candidates.device != next_token_logits.device: + raise ValueError( + "candidates and next_token_logits must be on the same device, " + f"got {candidates.device} and {next_token_logits.device}." 
+ ) + + if threshold_single is None: + from sglang.srt.server_args import get_global_server_args + + threshold_single = get_global_server_args().speculative_accept_threshold_single + if threshold_acc is None: + from sglang.srt.server_args import get_global_server_args + + threshold_acc = get_global_server_args().speculative_accept_threshold_acc + threshold_single = float(threshold_single) + threshold_acc = max(float(threshold_acc), 1e-9) + + device = next_token_logits.device + + if uniform_samples is None: + uniform_samples = torch.rand( + (bs, draft_token_num), dtype=torch.float32, device=device + ) + else: + if uniform_samples.shape != (bs, draft_token_num): + raise ValueError( + "uniform_samples shape mismatch. " + f"Expected {(bs, draft_token_num)}, got {tuple(uniform_samples.shape)}." + ) + uniform_samples = uniform_samples.to(device=device, dtype=torch.float32) + + if uniform_samples_for_final_sampling is None: + uniform_samples_for_final_sampling = torch.rand( + (bs,), dtype=torch.float32, device=device + ) + else: + if uniform_samples_for_final_sampling.shape != (bs,): + raise ValueError( + "uniform_samples_for_final_sampling shape mismatch. " + f"Expected {(bs,)}, got {tuple(uniform_samples_for_final_sampling.shape)}." + ) + uniform_samples_for_final_sampling = uniform_samples_for_final_sampling.to( + device=device, + dtype=torch.float32, + ) + + need_top_k = bool(getattr(sampling_info, "need_top_k_sampling", True)) + need_top_p = bool(getattr(sampling_info, "need_top_p_sampling", False)) + # Build target distribution once over all verify rows. 
+ expanded_temperature = torch.repeat_interleave( + sampling_info.temperatures, draft_token_num, dim=0 + ) + scaled_logits = next_token_logits / expanded_temperature + sparse_topk_applied = False + + if use_sparse_topk and need_top_k: + repeated_top_ks = torch.repeat_interleave( + sampling_info.top_ks, draft_token_num, dim=0 + ).to(dtype=torch.int64) + vocab_size = int(scaled_logits.shape[-1]) + repeated_top_ks.clamp_(min=1, max=vocab_size) + max_top_k = int(repeated_top_ks.max().item()) + + # Sparse exact path for top-k/top-p (top-k-first semantics), then scatter to dense. + if 0 < max_top_k < vocab_size: + topk_logits, topk_indices = torch.topk(scaled_logits, k=max_top_k, dim=-1) + if not torch.all(repeated_top_ks == max_top_k): + ranks = torch.arange(max_top_k, device=device, dtype=torch.int64)[ + None, : + ] + valid = ranks < repeated_top_ks.unsqueeze(1) + topk_logits = topk_logits.masked_fill(~valid, float("-inf")) + + topk_probs = F.softmax(topk_logits, dim=-1) + if need_top_p: + repeated_top_ps = torch.repeat_interleave( + sampling_info.top_ps, draft_token_num, dim=0 + ) + topk_probs = top_p_renorm_prob(topk_probs, repeated_top_ps) + + target_probs = torch.zeros_like(scaled_logits, dtype=topk_probs.dtype) + target_probs.scatter_(1, topk_indices, topk_probs) + sparse_topk_applied = True + + if not sparse_topk_applied: + target_probs = F.softmax(scaled_logits, dim=-1) + if need_top_k: + target_probs = top_k_renorm_prob( + target_probs, + torch.repeat_interleave(sampling_info.top_ks, draft_token_num, dim=0), + ) + if need_top_p: + target_probs = top_p_renorm_prob( + target_probs, + torch.repeat_interleave(sampling_info.top_ps, draft_token_num, dim=0), + ) + target_probs = target_probs.view(bs, draft_token_num, -1).contiguous() + draft_probs = torch.zeros_like(target_probs) + + ( + retrieve_index, + retrieve_next_token, + retrieve_next_sibling, + predicts, + accept_index, + accept_token_num, + ) = _get_or_create_chain_verify_buffers( + bs=bs, + 
draft_token_num=draft_token_num, + device=device, + ) + candidates_i64 = ( + candidates if candidates.dtype == torch.int64 else candidates.to(torch.int64) + ) + tree_speculative_sampling_target_only( + predicts=predicts, + accept_index=accept_index, + accept_token_num=accept_token_num, + candidates=candidates_i64, + retrive_index=retrieve_index, + retrive_next_token=retrieve_next_token, + retrive_next_sibling=retrieve_next_sibling, + uniform_samples=uniform_samples, + uniform_samples_for_final_sampling=uniform_samples_for_final_sampling, + target_probs=target_probs, + draft_probs=draft_probs, + threshold_single=threshold_single, + threshold_acc=threshold_acc, + deterministic=True, + ) + + accept_len = accept_token_num + row_ids = torch.arange(bs, dtype=torch.long, device=device) + accept_pos = accept_index[row_ids, accept_len.to(torch.long)].to(torch.long) + bonus = predicts[accept_pos].to(torch.int64) + return accept_len, bonus diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py new file mode 100644 index 000000000000..030aa21e5b35 --- /dev/null +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -0,0 +1,1245 @@ +import logging +import math +from copy import deepcopy +from typing import Optional, Union + +import torch + +from sglang.srt.distributed import get_tp_group +from sglang.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch +from sglang.srt.managers.scheduler import GenerationBatchResult +from sglang.srt.managers.tp_worker import TpModelWorker +from sglang.srt.mem_cache.common import get_last_loc +from sglang.srt.model_executor.forward_batch_info import ( + CaptureHiddenMode, + ForwardBatch, + ForwardMode, +) +from sglang.srt.server_args import ( + ServerArgs, + get_global_server_args, + set_global_server_args_for_scheduler, +) +from sglang.srt.speculative.dflash_info import DFlashDraftInput, DFlashVerifyInput +from sglang.srt.speculative.dflash_utils import ( + 
can_dflash_use_fused_qkv_proj, + is_dflash_sampling_verify_available, + parse_dflash_draft_config, + resolve_dflash_verify_mask_policy, +) +from sglang.srt.speculative.spec_info import SpeculativeAlgorithm +from sglang.srt.speculative.spec_utils import assign_req_to_token_pool_func +from sglang.srt.utils import is_cuda + +logger = logging.getLogger(__name__) + +_FusedKVMaterializeHelper = None + + +def _get_fused_kv_materialize_helper(): + global _FusedKVMaterializeHelper + if _FusedKVMaterializeHelper is None: + from sglang.srt.speculative.triton_ops.fused_kv_materialize import ( + FusedKVMaterializeHelper, + ) + + _FusedKVMaterializeHelper = FusedKVMaterializeHelper + return _FusedKVMaterializeHelper + + +class DFlashWorker: + """DFlash speculative decoding worker (spec-v1, tp>=1/pp=1).""" + + def __init__( + self, + server_args: ServerArgs, + gpu_id: int, + tp_rank: int, + dp_rank: Optional[int], + moe_ep_rank: int, + attn_cp_rank: int, + moe_dp_rank: int, + nccl_port: int, + target_worker: TpModelWorker, + ): + self.server_args = server_args + self.gpu_id = gpu_id + self.tp_rank = tp_rank + self.dp_rank = dp_rank + self.moe_ep_rank = moe_ep_rank + self.attn_cp_rank = attn_cp_rank + self.moe_dp_rank = moe_dp_rank + self.nccl_port = nccl_port + self.target_worker = target_worker + self.model_runner = target_worker.model_runner + self.page_size = server_args.page_size + self.draft_window_size: Optional[int] = ( + int(server_args.speculative_dflash_draft_window_size) + if server_args.speculative_dflash_draft_window_size is not None + else None + ) + self.use_compact_draft_cache = self.draft_window_size is not None + self.device = target_worker.device + + self._warned_sampling_fallback = False + self._logged_first_verify = False + + # Draft runner (separate KV cache + attention backend). + # Without draft windowing, the draft worker aliases the target request->token + # mapping and allocation state. 
With draft windowing enabled, the draft worker + # keeps a private compact req->token table over the same global KV index space, + # so radix-cache/prefix-hit KV remains reusable while draft attention sees only + # the recent window. + target_req_to_token_pool, target_token_to_kv_pool_allocator = ( + target_worker.get_memory_pool() + ) + shared_req_to_token_pool = ( + None if self.use_compact_draft_cache else target_req_to_token_pool + ) + draft_server_args = deepcopy(server_args) + draft_server_args.skip_tokenizer_init = True + draft_backend = draft_server_args.speculative_draft_attention_backend + supported_draft_backends = ("flashinfer", "fa3", "fa4") + if draft_backend is None: + draft_backend, _ = draft_server_args.get_attention_backends() + if draft_backend is None: + draft_backend = "flashinfer" + elif draft_backend == "trtllm_mha": + logger.warning( + "DFLASH draft worker does not support 'trtllm_mha' because the " + "draft path requires non-causal attention. Falling back to " + "'flashinfer'." + ) + draft_backend = "flashinfer" + elif draft_backend not in supported_draft_backends: + logger.warning( + "DFLASH draft worker only supports attention_backend in %s for now, " + "but got %r. Falling back to 'flashinfer'.", + supported_draft_backends, + draft_backend, + ) + draft_backend = "flashinfer" + # Make the draft worker backend explicit and self-contained (no further overrides). + draft_server_args.speculative_draft_attention_backend = None + draft_server_args.prefill_attention_backend = None + draft_server_args.decode_attention_backend = None + draft_server_args.attention_backend = draft_backend + # Keep draft context length aligned with the target. 
+ draft_server_args.context_length = ( + target_worker.model_runner.model_config.context_len + ) + saved_server_args = get_global_server_args() + self.draft_worker = TpModelWorker( + server_args=draft_server_args, + gpu_id=gpu_id, + tp_rank=tp_rank, + moe_ep_rank=moe_ep_rank, + pp_rank=0, + attn_cp_rank=attn_cp_rank, + moe_dp_rank=moe_dp_rank, + dp_rank=dp_rank, + nccl_port=nccl_port, + is_draft_worker=True, + req_to_token_pool=shared_req_to_token_pool, + token_to_kv_pool_allocator=target_token_to_kv_pool_allocator, + memory_pool_config=target_worker.model_runner.memory_pool_config, + ) + set_global_server_args_for_scheduler(saved_server_args) + self.draft_model_runner = self.draft_worker.model_runner + self.draft_model = self.draft_model_runner.model + draft_config = parse_dflash_draft_config( + draft_hf_config=self.draft_model_runner.model_config.hf_config + ) + if server_args.speculative_num_draft_tokens is None: + # Should not happen (ServerArgs should have inferred it), but keep a fallback. + self.block_size = int(draft_config.resolve_block_size(default=16)) + else: + self.block_size = int(server_args.speculative_num_draft_tokens) + model_block_size = draft_config.block_size + if model_block_size is None: + model_block_size = getattr(self.draft_model, "block_size", None) + if model_block_size is not None and int(model_block_size) != int( + self.block_size + ): + logger.warning( + "DFLASH block size mismatch: using speculative_num_draft_tokens=%s but draft config block_size=%s.", + self.block_size, + model_block_size, + ) + + self._mask_token = draft_config.mask_token + self._mask_token_id_override = draft_config.mask_token_id + self._mask_token_id = self._resolve_mask_token_id( + mask_token=self._mask_token, + mask_token_id=self._mask_token_id_override, + ) + if self.tp_rank == 0: + logger.info( + "Initialized DFLASH draft runner. 
attention_backend=%s, model=%s, block_size=%s, draft_window_size=%s, compact_cache=%s", + getattr(draft_server_args, "attention_backend", None), + self.draft_model.__class__.__name__, + self.block_size, + self.draft_window_size, + self.use_compact_draft_cache, + ) + logger.info( + "DFLASH draft runner ready. mask_token=%s, mask_token_id=%s, mask_token_id_override=%s", + self._mask_token, + self._mask_token_id, + self._mask_token_id_override, + ) + + self._block_pos_offsets = torch.arange( + self.block_size, device=self.device, dtype=torch.int64 + ) + self._draft_block_ids_buf: Optional[torch.Tensor] = None # [cap_bs, block_size] + self._draft_block_positions_buf: Optional[torch.Tensor] = ( + None # [cap_bs, block_size] + ) + self._draft_block_tokens_buf: Optional[torch.Tensor] = ( + None # [cap_bs, block_size] + ) + self._draft_block_end_buf: Optional[torch.Tensor] = None # [cap_bs] + self._draft_seq_lens_cpu_buf: Optional[torch.Tensor] = None # [cap_bs] on CPU + self._draft_block_spec_info = DFlashVerifyInput( + draft_token=torch.empty((0,), dtype=torch.long, device=self.device), + positions=torch.empty((0,), dtype=torch.int64, device=self.device), + draft_token_num=int(self.block_size), + custom_mask=None, + capture_hidden_mode=CaptureHiddenMode.NULL, + ) + self._draft_greedy_gathered_max_buf: Optional[torch.Tensor] = None + self._draft_greedy_gathered_ids_buf: Optional[torch.Tensor] = None + self._draft_greedy_gather_cap: int = 0 + self._draft_greedy_best_rank_buf: Optional[torch.Tensor] = None + self._draft_greedy_rank_index_buf: Optional[torch.Tensor] = None + self._draft_greedy_selected_ids_buf: Optional[torch.Tensor] = None + self._draft_greedy_index_cap: int = 0 + + self._use_fused_kv_materialize = is_cuda() + self._fused_kv_helper: Optional[object] = None + if self._use_fused_kv_materialize: + self._init_fused_kv_helper() + + def _init_fused_kv_helper(self) -> None: + """Initialize the fused KV materialization helper with pre-stacked weights.""" + try: + 
layers = self.draft_model.layers + fused_disable_reason: Optional[str] = None + + if len(layers) == 0: + fused_disable_reason = "no layers found" + + for layer_idx, layer in enumerate(layers): + attn = layer.self_attn + eligible, reason = can_dflash_use_fused_qkv_proj(attn.qkv_proj) + if not eligible: + fused_disable_reason = f"{reason}: layer={layer_idx}" + break + + # Keep semantics aligned with set_kv_buffer scaling behavior. + k_scale = getattr(attn.attn, "k_scale", None) + v_scale = getattr(attn.attn, "v_scale", None) + if k_scale is not None and not math.isclose(float(k_scale), 1.0): + fused_disable_reason = ( + "non-unit k_scale is not supported for fused KV path: " + f"layer={layer_idx}, k_scale={k_scale}" + ) + break + if v_scale is not None and not math.isclose(float(v_scale), 1.0): + fused_disable_reason = ( + "non-unit v_scale is not supported for fused KV path: " + f"layer={layer_idx}, v_scale={v_scale}" + ) + break + + rope_is_neox_style = bool( + getattr(attn.rotary_emb, "is_neox_style", True) + ) + if not rope_is_neox_style: + fused_disable_reason = ( + "non-neox RoPE is not supported for fused KV path: " + f"layer={layer_idx}, rope_is_neox_style={rope_is_neox_style}" + ) + break + + if fused_disable_reason is not None: + if self.tp_rank == 0: + logger.info( + "DFLASH fused KV materialization disabled: %s", + fused_disable_reason, + ) + self._use_fused_kv_materialize = False + self._fused_kv_helper = None + return + + FusedKVMaterializeHelper = _get_fused_kv_materialize_helper() + first_attn = layers[0].self_attn + rotary_emb = first_attn.rotary_emb + + self._fused_kv_helper = FusedKVMaterializeHelper( + layers=layers, + rotary_emb=rotary_emb, + num_kv_heads=first_attn.num_kv_heads, + head_dim=first_attn.head_dim, + device=self.device, + ) + if self.tp_rank == 0: + logger.info( + "DFLASH fused KV materialization enabled. 
" + "n_layers=%d, num_kv_heads=%d, head_dim=%d", + len(layers), + first_attn.num_kv_heads, + first_attn.head_dim, + ) + except Exception as e: + logger.warning( + "DFLASH fused KV initialization failed, falling back to sequential path: %s", + e, + ) + self._use_fused_kv_materialize = False + self._fused_kv_helper = None + + def _ensure_draft_block_buffers(self, bs: int) -> None: + cap = ( + 0 + if self._draft_block_ids_buf is None + else int(self._draft_block_ids_buf.shape[0]) + ) + if cap >= int(bs): + return + + new_cap = max(int(bs), cap * 2 if cap > 0 else int(bs)) + device = self.device + block_size = int(self.block_size) + self._draft_block_ids_buf = torch.empty( + (new_cap, block_size), dtype=torch.long, device=device + ) + self._draft_block_positions_buf = torch.empty( + (new_cap, block_size), dtype=torch.int64, device=device + ) + self._draft_block_tokens_buf = torch.empty( + (new_cap, block_size), dtype=torch.long, device=device + ) + self._draft_block_end_buf = torch.empty( + (new_cap,), dtype=torch.int32, device=device + ) + self._draft_seq_lens_cpu_buf = torch.empty( + (new_cap,), dtype=torch.int32, device="cpu" + ) + + def __getattr__(self, name): + # Delegate anything not implemented yet to the target worker. + return getattr(self.target_worker, name) + + def clear_cache_pool(self): + # The target worker owns the shared KV allocator/cache. For the compact + # sliding-window path, the draft req->token view is rebuilt from committed + # target state before each draft forward, so there is nothing persistent + # to flush here. + pass + + def _gather_req_to_token_masked( + self, + *, + req_to_token: torch.Tensor, + req_pool_indices: torch.Tensor, + pos2d: torch.Tensor, + mask: torch.Tensor, + context: str, + ) -> torch.Tensor: + if pos2d.ndim != 2: + raise RuntimeError( + f"{context} expected 2D positions, got shape={tuple(pos2d.shape)}." 
+ ) + if mask.shape != pos2d.shape: + raise RuntimeError( + f"{context} mask/position shape mismatch: {tuple(mask.shape)} vs {tuple(pos2d.shape)}." + ) + + if req_pool_indices.dtype != torch.int64: + req_pool_indices = req_pool_indices.to(torch.int64) + if mask.dtype != torch.bool: + mask = mask.to(torch.bool) + + table_width = int(req_to_token.shape[1]) + if table_width <= 0: + if bool(mask.any().item()): + raise RuntimeError( + f"{context} req_to_token table is empty but gather mask is non-empty." + ) + return torch.empty((0,), dtype=torch.int64, device=self.device) + + # Only the masked-off rectangular padding can be out of range in the normal + # ragged-batch case. Replace those don't-care columns with a valid in-range + # position before the gather so the kernel only sees real positions. + safe_pos2d = pos2d.masked_fill(~mask, 0) + return req_to_token[req_pool_indices[:, None], safe_pos2d][mask].to(torch.int64) + + def _gather_req_to_token_segments( + self, + *, + req_to_token: torch.Tensor, + req_pool_indices: torch.Tensor, + start: torch.Tensor | None, + lengths: torch.Tensor, + ) -> torch.Tensor: + lengths = lengths.to(torch.int64) + if lengths.numel() == 0: + return torch.empty((0,), dtype=torch.int64, device=self.device) + max_len = int(lengths.max().item()) + if max_len <= 0: + return torch.empty((0,), dtype=torch.int64, device=self.device) + + if req_pool_indices.dtype != torch.int64: + req_pool_indices = req_pool_indices.to(torch.int64) + offsets = torch.arange( + max_len, device=self.device, dtype=torch.int64 + ).unsqueeze(0) + if start is None: + pos2d = offsets.expand(req_pool_indices.shape[0], -1) + else: + pos2d = start.to(torch.int64).unsqueeze(1) + offsets + mask = offsets < lengths.unsqueeze(1) + return self._gather_req_to_token_masked( + req_to_token=req_to_token, + req_pool_indices=req_pool_indices, + pos2d=pos2d, + mask=mask, + context="DFLASH req_to_token segment gather", + ) + + def _compute_compact_draft_seq_lens(self, seq_lens: 
torch.Tensor) -> torch.Tensor: + assert self.draft_window_size is not None + visible_lens = torch.clamp( + seq_lens.to(dtype=torch.int32, device=self.device), + max=int(self.draft_window_size), + ) + if self.page_size <= 1: + return visible_lens + + # Paged FA backends derive the page table from local token positions, so the + # compact suffix must start on a page boundary. Keep up to page_size - 1 extra + # tokens on the left to preserve valid local page structure. + seq_lens_i64 = seq_lens.to(torch.int64) + visible_lens_i64 = visible_lens.to(torch.int64) + visible_start = seq_lens_i64 - visible_lens_i64 + aligned_start = visible_start - torch.remainder(visible_start, self.page_size) + return (seq_lens_i64 - aligned_start).to(torch.int32) + + def _resolve_mask_token_id( + self, *, mask_token: str, mask_token_id: Optional[int] = None + ) -> int: + if not isinstance(mask_token, str) or not mask_token: + raise ValueError( + f"DFLASH mask_token must be a non-empty string, got {mask_token!r}." + ) + + vocab_size = int(self.target_worker.model_runner.model_config.vocab_size) + if mask_token_id is not None: + resolved_id = int(mask_token_id) + if resolved_id >= vocab_size: + raise ValueError( + "DFLASH mask_token_id is outside the target vocab size. " + f"mask_token_id={resolved_id}, vocab_size={vocab_size}. " + f"This likely means mask_token={mask_token!r} requires vocab expansion beyond the model's embedding size. " + "SGLang does not support resizing target embeddings for DFLASH yet." + ) + + tokenizer = getattr(self.target_worker, "tokenizer", None) + if tokenizer is not None: + token_id_from_vocab = tokenizer.get_vocab().get(mask_token, None) + if ( + token_id_from_vocab is not None + and int(token_id_from_vocab) != resolved_id + ): + raise ValueError( + "DFLASH config mismatch: dflash_config.mask_token_id conflicts with tokenizer vocab id " + f"for dflash_config.mask_token. 
mask_token={mask_token!r}, " + f"mask_token_id={resolved_id}, tokenizer_vocab_id={int(token_id_from_vocab)}." + ) + return resolved_id + + tokenizer = getattr(self.target_worker, "tokenizer", None) + if tokenizer is None: + raise RuntimeError( + "DFLASH requires tokenizer initialization when dflash_config.mask_token_id is not set " + "(skip_tokenizer_init is not supported in this mode)." + ) + + resolved_id = None + if getattr(tokenizer, "mask_token", None) == mask_token: + resolved_id = getattr(tokenizer, "mask_token_id", None) + + if resolved_id is None: + # Prefer checking the explicit vocab mapping first. + vocab = tokenizer.get_vocab() + resolved_id = vocab.get(mask_token, None) + + if resolved_id is None: + # Mirror the reference DFlash HF demo by adding the mask token to the tokenizer. + # This is safe only when the resulting id stays within the target model vocab size. + added = tokenizer.add_special_tokens({"mask_token": mask_token}) + resolved_id = getattr(tokenizer, "mask_token_id", None) + if resolved_id is None: + resolved_id = tokenizer.convert_tokens_to_ids(mask_token) + + if added and self.tp_rank == 0: + logger.info( + "Added DFLASH mask token to tokenizer. token=%s, mask_token_id=%s, tokenizer_len=%s, model_vocab_size=%s", + mask_token, + resolved_id, + len(tokenizer), + vocab_size, + ) + + if resolved_id is None or int(resolved_id) < 0: + raise ValueError( + "DFLASH requires resolving a mask token id, but it could not be resolved. " + f"mask_token={mask_token!r}." + ) + + if resolved_id >= vocab_size: + raise ValueError( + "DFLASH mask_token_id is outside the target vocab size. " + f"mask_token_id={resolved_id}, vocab_size={vocab_size}. " + f"This likely means mask_token={mask_token!r} requires vocab expansion beyond the model's embedding size. " + "SGLang does not support resizing target embeddings for DFLASH yet." 
+ ) + + return int(resolved_id) + + def _prepare_for_speculative_decoding( + self, batch: ScheduleBatch, draft_input: DFlashDraftInput + ): + if batch.forward_mode.is_extend() or batch.forward_mode.is_idle(): + return + + if batch.has_grammar: + raise RuntimeError( + "Invariant broken: DFLASH batch has grammar constraints, but scheduler should have rejected this request." + ) + if batch.sampling_info is not None and not batch.sampling_info.is_all_greedy: + if ( + not is_dflash_sampling_verify_available() + and not self._warned_sampling_fallback + and self.tp_rank == 0 + ): + logger.warning( + "DFLASH non-greedy verification is unavailable on this build/device; " + "falling back to greedy argmax verification." + ) + self._warned_sampling_fallback = True + + bs = batch.batch_size() + + # --- 1) Append any newly committed tokens into the draft KV cache. + self._append_target_hidden_to_draft_kv(batch, draft_input) + + target_model = self.target_worker.model_runner.model + embed_module = target_model.get_input_embeddings() + lm_head = getattr(target_model, "lm_head", None) + if ( + lm_head is None + or not hasattr(lm_head, "weight") + or not hasattr(lm_head, "shard_indices") + ): + raise RuntimeError( + "DFLASH requires the target model to expose a vocab-parallel `lm_head` with `weight` and " + "`shard_indices` attributes." + ) + + # --- 2) Draft a non-causal block with the draft model. 
+ self._ensure_draft_block_buffers(bs) + assert self._draft_block_ids_buf is not None + assert self._draft_block_positions_buf is not None + assert self._draft_block_tokens_buf is not None + assert self._draft_block_end_buf is not None + assert self._draft_seq_lens_cpu_buf is not None + + block_ids = self._draft_block_ids_buf[:bs] + block_ids.fill_(int(self._mask_token_id)) + block_ids[:, 0].copy_(draft_input.verified_id.to(torch.long)) + + noise_embedding = embed_module(block_ids) + input_embeds = noise_embedding.view(-1, noise_embedding.shape[-1]) + + # For spec-v1, the draft KV cache is always materialized before drafting the + # next block. `target_prefix_lens` stay absolute for RoPE; `draft_prefix_lens` + # are the logical resident lengths in the draft-local cache. + target_prefix_lens = batch.seq_lens # int32, device + draft_prefix_lens = draft_input.draft_seq_lens + if draft_prefix_lens.dtype != torch.int32: + draft_prefix_lens = draft_prefix_lens.to(torch.int32) + if draft_prefix_lens.device != self.device: + draft_prefix_lens = draft_prefix_lens.to(self.device, non_blocking=True) + + positions_2d = self._draft_block_positions_buf[:bs] + torch.add( + target_prefix_lens.unsqueeze(1), self._block_pos_offsets, out=positions_2d + ) + positions = positions_2d.reshape(-1) + + block_start = draft_prefix_lens + block_end = self._draft_block_end_buf[:bs] + torch.add(block_start, int(self.block_size), out=block_end) + + seq_lens_cpu = self._draft_seq_lens_cpu_buf[:bs] + seq_lens_cpu.copy_(draft_prefix_lens.to(device="cpu", dtype=torch.int32)) + allocator = self.draft_model_runner.token_to_kv_pool_allocator + token_to_kv_pool_state_backup = allocator.backup_state() + try: + if self.page_size == 1: + block_cache_loc = allocator.alloc(bs * self.block_size) + else: + block_end_cpu = seq_lens_cpu + int(self.block_size) + last_loc = get_last_loc( + self.draft_model_runner.req_to_token_pool.req_to_token, + batch.req_pool_indices, + block_start, + ) + block_cache_loc = 
allocator.alloc_extend( + block_start, + seq_lens_cpu, + block_end, + block_end_cpu, + last_loc, + bs * self.block_size, + ) + if block_cache_loc is None: + raise RuntimeError( + f"DFLASH draft OOM when allocating {bs * self.block_size} block tokens." + ) + + assign_req_to_token_pool_func( + batch.req_pool_indices, + self.draft_model_runner.req_to_token_pool.req_to_token, + block_start, + block_end, + block_cache_loc, + bs, + ) + + # Use TARGET_VERIFY mode (cuda-graphable) to run a fixed-size draft block. + # In this mode, `seq_lens` stores the prefix lengths; attention backends + # derive kv_len by adding `draft_token_num`. + draft_spec_info = self._draft_block_spec_info + seq_lens = draft_prefix_lens + seq_lens_sum = int(draft_prefix_lens.sum().item()) + forward_batch = ForwardBatch( + forward_mode=ForwardMode.TARGET_VERIFY, + batch_size=bs, + input_ids=block_ids.flatten(), + req_pool_indices=batch.req_pool_indices, + seq_lens=seq_lens, + out_cache_loc=block_cache_loc, + seq_lens_sum=seq_lens_sum, + seq_lens_cpu=seq_lens_cpu, + positions=positions, + req_to_token_pool=self.draft_model_runner.req_to_token_pool, + token_to_kv_pool=self.draft_model_runner.token_to_kv_pool, + attn_backend=self.draft_model_runner.attn_backend, + input_embeds=input_embeds, + spec_algorithm=SpeculativeAlgorithm.DFLASH, + spec_info=draft_spec_info, + capture_hidden_mode=CaptureHiddenMode.NULL, + ) + + with torch.inference_mode(): + draft_logits_output = self.draft_model_runner.forward( + forward_batch + ).logits_output + finally: + # Drop the speculative block from the shared allocator (EAGLE3-style). 
+ allocator.restore_state(token_to_kv_pool_state_backup) + + draft_hidden = draft_logits_output.hidden_states + if draft_hidden is None: + raise RuntimeError("DFLASH draft model returned no hidden states.") + draft_hidden = draft_hidden.view(bs, self.block_size, -1) + draft_next = self._greedy_sample_from_vocab_parallel_head( + hidden_states=draft_hidden[:, 1:, :].reshape(-1, draft_hidden.shape[-1]), + lm_head=lm_head, + ).view(bs, self.block_size - 1) + draft_tokens = self._draft_block_tokens_buf[:bs] + draft_tokens[:, 0].copy_(block_ids[:, 0]) + draft_tokens[:, 1:].copy_(draft_next) + positions = positions_2d.reshape(-1) + + verify_input = DFlashVerifyInput( + draft_token=draft_tokens.reshape(-1), + positions=positions, + draft_token_num=self.block_size, + ) + _, build_custom_mask = resolve_dflash_verify_mask_policy( + self.model_runner.attn_backend + ) + verify_input.prepare_for_verify( + batch, + self.page_size, + build_custom_mask=build_custom_mask, + ) + + batch.forward_mode = ( + ForwardMode.TARGET_VERIFY + if not batch.forward_mode.is_idle() + else ForwardMode.IDLE + ) + batch.spec_info = verify_input + batch.return_hidden_states = False + + def _greedy_sample_from_vocab_parallel_head( + self, + *, + hidden_states: torch.Tensor, + lm_head, + chunk_size: int = 256, + ) -> torch.Tensor: + """Greedy argmax over the target LM head in a TP-safe way. + + We cannot materialize full logits for large vocabularies efficiently, and with + TP>1 each rank only owns a shard of the LM head weight. This computes the + per-rank max, gathers candidates across TP ranks, and selects the global max. + """ + + if hidden_states.numel() == 0: + return torch.empty((0,), dtype=torch.long, device=hidden_states.device) + + tp_group = get_tp_group() + tp_size = int(tp_group.world_size) + + if not hasattr(lm_head, "weight") or not hasattr(lm_head, "shard_indices"): + raise RuntimeError( + "DFLASH greedy sampling requires a vocab-parallel head with `weight` and `shard_indices`." 
+ ) + + shard = lm_head.shard_indices + weight = lm_head.weight # [local_vocab_padded, hidden] + weight_dtype = weight.dtype + + # Valid ranges in the local shard (excluding padding): + # base vocab: [0, num_org) + # added vocab: [num_org_padded, num_org_padded + num_added) + num_org = int(shard.num_org_elements) + num_org_padded = int(shard.num_org_elements_padded) + num_added = int(shard.num_added_elements) + org_vocab_start = int(shard.org_vocab_start_index) + added_vocab_start = int(shard.added_vocab_start_index) + + num_tokens = int(hidden_states.shape[0]) + out_token_ids = torch.empty( + (num_tokens,), dtype=torch.long, device=hidden_states.device + ) + + def _cast_hs(x: torch.Tensor) -> torch.Tensor: + return x if x.dtype == weight_dtype else x.to(weight_dtype) + + # Fast path (common): single-rank greedy sampling over the base vocab shard. + # Avoids extra max/id bookkeeping that is only needed for TP sync or added vocab. + if tp_size == 1 and num_added == 0: + for start in range(0, num_tokens, int(chunk_size)): + end = min(num_tokens, start + int(chunk_size)) + hs = _cast_hs(hidden_states[start:end]) + if num_org > 0: + base_logits = torch.matmul(hs, weight[:num_org].T) + out_token_ids[start:end] = ( + torch.argmax(base_logits, dim=-1).to(torch.long) + + org_vocab_start + ) + else: + out_token_ids[start:end] = 0 + return out_token_ids + + for start in range(0, num_tokens, int(chunk_size)): + end = min(num_tokens, start + int(chunk_size)) + hs = _cast_hs(hidden_states[start:end]) + chunk_len = int(hs.shape[0]) + + # Base vocab logits. + if num_org > 0: + base_logits = torch.matmul(hs, weight[:num_org].T) + local_max, local_arg = torch.max(base_logits, dim=-1) + else: + local_max = torch.full( + (chunk_len,), + torch.finfo(weight_dtype).min, + dtype=weight_dtype, + device=hs.device, + ) + local_arg = torch.zeros( + (chunk_len,), dtype=torch.int64, device=hs.device + ) + + # Added vocab logits (e.g., LoRA-added embeddings), if present. 
+ if num_added > 0: + added_slice_start = num_org_padded + added_slice_end = num_org_padded + num_added + added_logits = torch.matmul( + hs, weight[added_slice_start:added_slice_end].T + ) + added_max, added_arg = torch.max(added_logits, dim=-1) + use_added = added_max > local_max + local_max = torch.where(use_added, added_max, local_max) + # For base/added conversion below, keep local_arg expressed in the full local + # weight index space (base + padding + added), matching `lm_head.weight`. + local_arg = torch.where( + use_added, added_arg.to(local_arg.dtype) + num_org_padded, local_arg + ) + + # Convert local argmax indices to global token ids. + if num_added == 0: + local_arg.add_(org_vocab_start) + global_ids = local_arg + else: + global_ids = torch.empty( + (chunk_len,), dtype=torch.int64, device=hs.device + ) + is_base = local_arg < num_org + global_ids[is_base] = org_vocab_start + local_arg[is_base] + global_ids[~is_base] = added_vocab_start + ( + local_arg[~is_base] - num_org_padded + ) + + if tp_size == 1: + out_token_ids[start:end] = global_ids.to(torch.long) + continue + + # Gather per-rank maxima and associated global ids, then select the global max. + needed = tp_size * chunk_len + chunk_cap = int(chunk_size) + if ( + self._draft_greedy_gather_cap < needed + or self._draft_greedy_gathered_max_buf is None + or self._draft_greedy_gathered_ids_buf is None + or self._draft_greedy_gathered_max_buf.dtype != local_max.dtype + or self._draft_greedy_gathered_max_buf.device != hs.device + ): + # Allocate enough space for the max chunk size to avoid reallocations. 
+ cap = tp_size * chunk_cap + self._draft_greedy_gathered_max_buf = torch.empty( + (cap,), dtype=local_max.dtype, device=hs.device + ) + self._draft_greedy_gathered_ids_buf = torch.empty( + (cap,), dtype=global_ids.dtype, device=hs.device + ) + self._draft_greedy_gather_cap = cap + + if ( + self._draft_greedy_index_cap < chunk_len + or self._draft_greedy_best_rank_buf is None + or self._draft_greedy_rank_index_buf is None + or self._draft_greedy_selected_ids_buf is None + or self._draft_greedy_best_rank_buf.device != hs.device + or self._draft_greedy_selected_ids_buf.device != hs.device + ): + self._draft_greedy_best_rank_buf = torch.empty( + (chunk_cap,), dtype=torch.int64, device=hs.device + ) + self._draft_greedy_rank_index_buf = torch.empty( + (1, chunk_cap), dtype=torch.int64, device=hs.device + ) + self._draft_greedy_selected_ids_buf = torch.empty( + (1, chunk_cap), dtype=torch.int64, device=hs.device + ) + self._draft_greedy_index_cap = chunk_cap + + gathered_max = self._draft_greedy_gathered_max_buf[:needed] + gathered_ids = self._draft_greedy_gathered_ids_buf[:needed] + + tp_group.all_gather_into_tensor(gathered_max, local_max.contiguous()) + tp_group.all_gather_into_tensor(gathered_ids, global_ids.contiguous()) + gathered_max = gathered_max.view(tp_size, chunk_len) + gathered_ids = gathered_ids.view(tp_size, chunk_len) + + best_rank = self._draft_greedy_best_rank_buf[:chunk_len] + torch.argmax(gathered_max, dim=0, out=best_rank) + + rank_index = self._draft_greedy_rank_index_buf[:, :chunk_len] + rank_index[0].copy_(best_rank) + selected_ids = self._draft_greedy_selected_ids_buf[:, :chunk_len] + torch.gather(gathered_ids, 0, rank_index, out=selected_ids) + out_token_ids[start:end].copy_(selected_ids.view(-1)) + + return out_token_ids + + def _append_target_hidden_to_draft_kv( + self, + batch: ScheduleBatch, + draft_input: DFlashDraftInput, + ) -> None: + """Materialize the target hidden-state features into the draft KV cache. 
+ + This must be run before exposing new tokens to radix cache (prefix hits), otherwise + another request could reuse target KV indices without having draft KV values. + """ + + bs = batch.batch_size() + device = self.model_runner.device + + if draft_input.target_hidden is None: + raise RuntimeError( + "DFLASH draft state missing target_hidden context features." + ) + if draft_input.ctx_lens.numel() != bs: + raise RuntimeError( + f"DFLASH ctx_lens length mismatch: got {draft_input.ctx_lens.numel()} for bs={bs}." + ) + if draft_input.draft_seq_lens.numel() != bs: + raise RuntimeError( + f"DFLASH draft_seq_lens length mismatch: got {draft_input.draft_seq_lens.numel()} for bs={bs}." + ) + + total_ctx = int(draft_input.target_hidden.shape[0]) + if total_ctx <= 0: + draft_input.ctx_lens = torch.zeros_like(draft_input.ctx_lens) + draft_input.target_hidden = draft_input.target_hidden[:0] + return + + target_req_to_token = batch.req_to_token_pool.req_to_token + draft_req_to_token = self.draft_model_runner.req_to_token_pool.req_to_token + + req_pool_indices = batch.req_pool_indices + if req_pool_indices.dtype != torch.int64: + req_pool_indices = req_pool_indices.to(torch.int64) + + ctx_lens = draft_input.ctx_lens + if ctx_lens.dtype != torch.int32: + ctx_lens = ctx_lens.to(torch.int32) + if ctx_lens.device != device: + ctx_lens = ctx_lens.to(device, non_blocking=True) + ctx_start = batch.seq_lens.to(torch.int64) - ctx_lens.to(torch.int64) + + if bs == 1: + # Fast path for single request. 
+ max_ctx = int(total_ctx) + if max_ctx <= self._block_pos_offsets.numel(): + r = self._block_pos_offsets[:max_ctx] + else: + r = torch.arange(max_ctx, device=device, dtype=torch.int64) + pos2d = ctx_start[:, None] + r[None, :] # [1, ctx] + cache2d = target_req_to_token[req_pool_indices[:, None], pos2d] # [1, ctx] + ctx_cache_loc = cache2d.reshape(-1).to(torch.int64) # [ctx] + ctx_positions = pos2d.reshape(-1) # [ctx] + else: + # In decode mode, ctx_lens <= block_size so we can skip the .item() sync. + if batch.forward_mode.is_extend() or batch.is_extend_in_batch: + max_ctx = int(ctx_lens.max().item()) + else: + max_ctx = int(self.block_size) + if max_ctx <= 0: + raise RuntimeError(f"DFLASH invalid max_ctx={max_ctx} for KV append.") + + if max_ctx <= self._block_pos_offsets.numel(): + r = self._block_pos_offsets[:max_ctx] + else: + r = torch.arange(max_ctx, device=device, dtype=torch.int64) + r = r[None, :] # [1, max_ctx] + pos2d = ctx_start[:, None] + r # [bs, max_ctx] + mask = r < ctx_lens[:, None] + + # Batched gather of cache locations and positions. + ctx_cache_loc = self._gather_req_to_token_masked( + req_to_token=target_req_to_token, + req_pool_indices=req_pool_indices, + pos2d=pos2d, + mask=mask, + context="DFLASH target hidden KV append", + ) # [sum(ctx_lens)] + ctx_positions = pos2d[mask] # [sum(ctx_lens)] + + with torch.inference_mode(): + ctx_hidden = self.draft_model.project_target_hidden( + draft_input.target_hidden + ) # [sum(ctx), hidden] + if ctx_hidden.shape[0] != ctx_cache_loc.numel(): + raise RuntimeError( + f"DFLASH ctx_hidden/cache_loc mismatch: {ctx_hidden.shape[0]} vs {ctx_cache_loc.numel()}." 
+ ) + + if self._use_fused_kv_materialize and self._fused_kv_helper is not None: + try: + self._append_target_hidden_fused( + ctx_hidden, ctx_positions, ctx_cache_loc + ) + except Exception as e: + logger.warning( + "DFLASH fused KV append failed; falling back to sequential path: %s", + e, + ) + self._use_fused_kv_materialize = False + self._fused_kv_helper = None + self._append_target_hidden_sequential( + ctx_hidden, ctx_positions, ctx_cache_loc + ) + else: + self._append_target_hidden_sequential( + ctx_hidden, ctx_positions, ctx_cache_loc + ) + + if self.use_compact_draft_cache: + new_draft_seq_lens = self._compute_compact_draft_seq_lens(batch.seq_lens) + suffix_start = batch.seq_lens.to(torch.int64) - new_draft_seq_lens.to( + torch.int64 + ) + suffix_cache_loc = self._gather_req_to_token_segments( + req_to_token=target_req_to_token, + req_pool_indices=req_pool_indices, + start=suffix_start, + lengths=new_draft_seq_lens, + ) + assign_req_to_token_pool_func( + batch.req_pool_indices, + draft_req_to_token, + torch.zeros_like(new_draft_seq_lens), + new_draft_seq_lens, + suffix_cache_loc, + bs, + ) + draft_input.draft_seq_lens = new_draft_seq_lens + else: + draft_input.draft_seq_lens = batch.seq_lens.to(dtype=torch.int32) + draft_input.ctx_lens = torch.zeros_like(ctx_lens) + draft_input.target_hidden = draft_input.target_hidden[:0] + + def _append_target_hidden_sequential( + self, + ctx_hidden: torch.Tensor, + ctx_positions: torch.Tensor, + ctx_cache_loc: torch.Tensor, + ) -> None: + for layer in self.draft_model.layers: + attn = layer.self_attn + k, v = attn.kv_proj_only(ctx_hidden) + k = attn.apply_k_norm(k) + k = attn.apply_k_rope(ctx_positions, k) + k = k.view(-1, attn.num_kv_heads, attn.head_dim) + v = v.view(-1, attn.num_kv_heads, attn.head_dim) + self.draft_model_runner.token_to_kv_pool.set_kv_buffer( + attn.attn, + ctx_cache_loc, + k, + v, + attn.attn.k_scale, + attn.attn.v_scale, + ) + + def _append_target_hidden_fused( + self, + ctx_hidden: torch.Tensor, + 
ctx_positions: torch.Tensor, + ctx_cache_loc: torch.Tensor, + ) -> None: + """Fused KV materialization using batched projection + Triton kernel.""" + token_to_kv_pool = self.draft_model_runner.token_to_kv_pool + layers = self.draft_model.layers + + def _write_layer_kv( + layer_idx: int, cache_k: torch.Tensor, cache_v: torch.Tensor + ) -> None: + attn = layers[layer_idx].self_attn.attn + token_to_kv_pool.set_kv_buffer( + attn, + ctx_cache_loc, + cache_k, + cache_v, + attn.k_scale, + attn.v_scale, + ) + + self._fused_kv_helper.materialize( + ctx_hidden=ctx_hidden, + positions=ctx_positions, + write_layer_kv=_write_layer_kv, + ) + + def _update_target_mamba_state_after_verify( + self, + *, + batch: ScheduleBatch, + seq_lens_pre_verify: torch.Tensor, + commit_lens: torch.Tensor, + ) -> None: + """Commit Mamba intermediate states for accepted verify steps. + + During TARGET_VERIFY, Mamba kernels run with `disable_state_update=True` and + cache per-step intermediate states. After acceptance, we need to commit the + state corresponding to each request's last accepted step. 
+ """ + attn_backend = self.target_worker.model_runner.attn_backend + if not hasattr(attn_backend, "update_mamba_state_after_mtp_verify"): + return + + accepted_steps = commit_lens.to(torch.int64) - 1 + mamba_steps_to_track = None + + if batch.mamba_track_indices is not None: + mamba_track_interval = self.server_args.mamba_track_interval + to_track_mask = ( + seq_lens_pre_verify // mamba_track_interval + != batch.seq_lens // mamba_track_interval + ) + tracking_point = ( + batch.seq_lens // mamba_track_interval * mamba_track_interval + ) + to_track_ith = torch.clamp(tracking_point - seq_lens_pre_verify - 1, min=0) + can_track_mask = to_track_mask & ( + to_track_ith < commit_lens.to(to_track_ith.dtype) + ) + mamba_steps_to_track = torch.where( + can_track_mask, + to_track_ith.to(torch.int64), + torch.full_like(to_track_ith, -1, dtype=torch.int64), + ) + + attn_backend.update_mamba_state_after_mtp_verify( + accepted_steps=accepted_steps, + mamba_track_indices=batch.mamba_track_indices, + mamba_steps_to_track=mamba_steps_to_track, + model=self.target_worker.model_runner.model, + ) + + def forward_batch_generation( + self, + batch: Union[ScheduleBatch, ModelWorkerBatch], + **kwargs, + ) -> GenerationBatchResult: + if getattr(batch, "return_logprob", False): + raise RuntimeError( + "Invariant broken: DFLASH batch requested return_logprob, but scheduler should have rejected this request." + ) + + if isinstance(batch, ModelWorkerBatch): + # Should not happen for spec-v1 (non-overlap) scheduling, but keep a sane fallback. 
+ return self.target_worker.forward_batch_generation(batch, **kwargs) + + if batch.forward_mode.is_extend() or batch.is_extend_in_batch: + model_worker_batch = batch.get_model_worker_batch() + model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL + + batch_result = self.target_worker.forward_batch_generation( + model_worker_batch, **kwargs + ) + logits_output, next_token_ids = ( + batch_result.logits_output, + batch_result.next_token_ids, + ) + if logits_output.hidden_states is None: + raise RuntimeError( + "DFLASH requires target aux hidden capture for prefill, but got None. " + "Make sure the target model has DFlash layers-to-capture configured." + ) + + if ( + model_worker_batch.extend_seq_lens is None + or model_worker_batch.extend_prefix_lens is None + ): + raise RuntimeError( + "DFLASH expected extend_seq_lens / extend_prefix_lens to be populated in extend mode, but got None." + ) + + # Materialize the prompt tokens into the draft KV cache immediately. This is required + # for radix cache support, since the scheduler may update radix after prefill returns. 
+ device = next_token_ids.device + + def _to_int32_device_tensor(x, *, device=device): + if isinstance(x, torch.Tensor): + if x.device != device: + x = x.to(device, non_blocking=True) + return x if x.dtype == torch.int32 else x.to(torch.int32) + return torch.tensor(x, dtype=torch.int32, device=device) + + extend_seq_lens = _to_int32_device_tensor( + model_worker_batch.extend_seq_lens + ) + draft_input = DFlashDraftInput( + verified_id=next_token_ids.to(torch.int64), + target_hidden=logits_output.hidden_states, + ctx_lens=extend_seq_lens, + draft_seq_lens=( + torch.zeros_like(extend_seq_lens) + if self.use_compact_draft_cache + else _to_int32_device_tensor(model_worker_batch.extend_prefix_lens) + ), + ) + self._append_target_hidden_to_draft_kv(batch, draft_input) + batch.spec_info = draft_input + + return GenerationBatchResult( + logits_output=logits_output, + next_token_ids=next_token_ids, + num_accepted_tokens=0, + can_run_cuda_graph=batch_result.can_run_cuda_graph, + ) + + # Decode / target-verify stage. + draft_input = batch.spec_info + if not isinstance(draft_input, DFlashDraftInput): + raise RuntimeError( + "DFLASH decode requires DFlashDraftInput state on the running batch. " + "This usually means the request did not complete the prefill stage." 
+ ) + + self._prepare_for_speculative_decoding(batch, draft_input) + + model_worker_batch = batch.get_model_worker_batch() + assert model_worker_batch.forward_mode.is_target_verify() + verify_input = model_worker_batch.spec_info + assert isinstance(verify_input, DFlashVerifyInput) + need_mamba_verify_commit = hasattr( + self.target_worker.model_runner.attn_backend, + "update_mamba_state_after_mtp_verify", + ) + seq_lens_pre_verify = ( + batch.seq_lens.clone() if need_mamba_verify_commit else None + ) + + batch_result = self.target_worker.forward_batch_generation( + model_worker_batch, is_verify=True, **kwargs + ) + logits_output, can_run_cuda_graph = ( + batch_result.logits_output, + batch_result.can_run_cuda_graph, + ) + + ( + new_verified_id, + commit_lens, + next_target_hidden, + accept_length_per_req_cpu, + ) = verify_input.verify( + batch=batch, + logits_output=logits_output, + page_size=self.page_size, + ) + if need_mamba_verify_commit: + assert seq_lens_pre_verify is not None + self._update_target_mamba_state_after_verify( + batch=batch, + seq_lens_pre_verify=seq_lens_pre_verify, + commit_lens=commit_lens, + ) + + # Update draft state for the next iteration. Also materialize the committed verify tokens + # into the draft KV cache immediately so radix cache entries are safe to reuse. + draft_input.verified_id = new_verified_id + draft_input.target_hidden = next_target_hidden + draft_input.ctx_lens = commit_lens + self._append_target_hidden_to_draft_kv(batch, draft_input) + batch.spec_info = draft_input + batch.forward_mode = ForwardMode.DECODE + + num_accepted_tokens = sum(accept_length_per_req_cpu) + if not self._logged_first_verify and self.tp_rank == 0: + logger.info( + "DFLASH verify completed. 
accept_length_per_req=%s", + accept_length_per_req_cpu, + ) + self._logged_first_verify = True + + return GenerationBatchResult( + logits_output=logits_output, + next_token_ids=new_verified_id, + num_accepted_tokens=num_accepted_tokens, + accept_length_per_req_cpu=accept_length_per_req_cpu, + can_run_cuda_graph=can_run_cuda_graph, + ) diff --git a/python/sglang/srt/speculative/spec_info.py b/python/sglang/srt/speculative/spec_info.py index a40a8aa0dc33..3e5727187572 100644 --- a/python/sglang/srt/speculative/spec_info.py +++ b/python/sglang/srt/speculative/spec_info.py @@ -15,6 +15,7 @@ class SpeculativeAlgorithm(Enum): """Enumeration of speculative decoding algorithms.""" + DFLASH = auto() EAGLE = auto() EAGLE3 = auto() STANDALONE = auto() @@ -33,6 +34,9 @@ def from_string(cls, name: Optional[str]) -> SpeculativeAlgorithm: def is_none(self) -> bool: return self == SpeculativeAlgorithm.NONE + def is_speculative(self) -> bool: + return self != SpeculativeAlgorithm.NONE + def is_eagle(self) -> bool: # NOTE: EAGLE3 is a variant of EAGLE return self == SpeculativeAlgorithm.EAGLE or self == SpeculativeAlgorithm.EAGLE3 @@ -40,6 +44,9 @@ def is_eagle(self) -> bool: def is_eagle3(self) -> bool: return self == SpeculativeAlgorithm.EAGLE3 + def is_dflash(self) -> bool: + return self == SpeculativeAlgorithm.DFLASH + def is_standalone(self) -> bool: return self == SpeculativeAlgorithm.STANDALONE @@ -57,6 +64,16 @@ def create_worker( ), "Cannot create worker for NONE speculative algorithm." enable_overlap = not server_args.disable_overlap_schedule + + if self.is_dflash(): + if enable_overlap: + raise ValueError( + "DFLASH does not support overlap scheduling (spec v2)." 
+ ) + from sglang.srt.speculative.dflash_worker import DFlashWorker + + return DFlashWorker + if self.is_eagle() and server_args.enable_multi_layer_eagle: # FIXME: migrate to EagleWorker if enable_overlap: @@ -110,6 +127,8 @@ class SpecInputType(IntEnum): # If all algorithms can share the same datastrucutre of draft_input and verify_input, consider simplify it EAGLE_DRAFT = auto() EAGLE_VERIFY = auto() + DFLASH_DRAFT = auto() + DFLASH_VERIFY = auto() NGRAM_VERIFY = auto() @@ -120,11 +139,15 @@ def __init__(self, spec_input_type: SpecInputType): def is_draft_input(self) -> bool: # FIXME: remove this function which is only used for assertion # or use another variable name like `draft_input` to substitute `spec_info` - return self.spec_input_type == SpecInputType.EAGLE_DRAFT + return self.spec_input_type in { + SpecInputType.EAGLE_DRAFT, + SpecInputType.DFLASH_DRAFT, + } def is_verify_input(self) -> bool: return self.spec_input_type in { SpecInputType.EAGLE_VERIFY, + SpecInputType.DFLASH_VERIFY, SpecInputType.NGRAM_VERIFY, } diff --git a/python/sglang/srt/speculative/triton_ops/__init__.py b/python/sglang/srt/speculative/triton_ops/__init__.py new file mode 100644 index 000000000000..a8ea8f4c704b --- /dev/null +++ b/python/sglang/srt/speculative/triton_ops/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Triton kernels for speculative decoding.""" + +from sglang.srt.speculative.triton_ops.fused_kv_materialize import ( + FusedKVMaterializeHelper, +) + +__all__ = ["FusedKVMaterializeHelper"] diff --git a/python/sglang/srt/speculative/triton_ops/fused_kv_materialize.py b/python/sglang/srt/speculative/triton_ops/fused_kv_materialize.py new file mode 100644 index 000000000000..e7dc4c05ddfc --- /dev/null +++ b/python/sglang/srt/speculative/triton_ops/fused_kv_materialize.py @@ -0,0 +1,303 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Fused Triton kernel for DFlash KV materialization. + +Combines: KV projection (cuBLAS) + RMSNorm + RoPE (Triton), then pool-managed KV writes. 
+""" + +from typing import Callable, List + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fused_norm_rope_kernel( + kv_ptr, # [total_ctx, kv_size * 2] + k_norm_weight_ptr, # [head_dim] + cos_sin_cache_ptr, # [max_pos, rotary_dim] + positions_ptr, # [total_ctx] + k_out_ptr, # [total_ctx, num_kv_heads, head_dim] + v_out_ptr, # [total_ctx, num_kv_heads, head_dim] + kv_stride_ctx, + cos_sin_stride_pos, + k_out_stride_ctx, + k_out_stride_head, + v_out_stride_ctx, + v_out_stride_head, + total_ctx, + num_kv_heads: tl.constexpr, + head_dim: tl.constexpr, + kv_size: tl.constexpr, + rotary_dim: tl.constexpr, + half_rotary_dim: tl.constexpr, + eps: tl.constexpr, + BLOCK_HD: tl.constexpr, +): + """Fused RMSNorm(K) + RoPE(K) materialization. Grid: (total_ctx, num_kv_heads).""" + ctx_id = tl.program_id(0) + head_id = tl.program_id(1) + if ctx_id >= total_ctx: + return + + # Load metadata + position = tl.load(positions_ptr + ctx_id) + + # Compute base pointers + kv_base = kv_ptr + ctx_id * kv_stride_ctx + k_base = kv_base + head_id * head_dim + v_base = kv_base + kv_size + head_id * head_dim + k_write = k_out_ptr + ctx_id * k_out_stride_ctx + head_id * k_out_stride_head + v_write = v_out_ptr + ctx_id * v_out_stride_ctx + head_id * v_out_stride_head + + # Load K and V + offs = tl.arange(0, BLOCK_HD) + mask_hd = offs < head_dim + mask_half = offs < half_rotary_dim + + k_raw = tl.load(k_base + offs, mask=mask_hd, other=0.0).to(tl.float32) + v_raw = tl.load(v_base + offs, mask=mask_hd, other=0.0) + + # RMSNorm on K + inv_rms = tl.rsqrt(tl.sum(k_raw * k_raw) / head_dim + eps) + norm_w = tl.load(k_norm_weight_ptr + offs, mask=mask_hd, other=1.0).to(tl.float32) + k_normed = k_raw * inv_rms * norm_w + + # RoPE (neox style): k_first, k_second -> rotated + cos_sin_base = cos_sin_cache_ptr + position * cos_sin_stride_pos + cos_v = tl.load(cos_sin_base + offs, mask=mask_half, other=1.0).to(tl.float32) + sin_v = tl.load( + cos_sin_base + half_rotary_dim + 
offs, mask=mask_half, other=0.0 + ).to(tl.float32) + + # Extract first/second halves of K for rotation + k_first = tl.where(mask_half, k_normed, 0.0) + k_second_raw = tl.load( + k_base + half_rotary_dim + offs, mask=mask_half, other=0.0 + ).to(tl.float32) + norm_w_second = tl.load( + k_norm_weight_ptr + half_rotary_dim + offs, mask=mask_half, other=1.0 + ).to(tl.float32) + k_second = k_second_raw * inv_rms * norm_w_second + + # Apply rotation + k_rot_first = k_first * cos_v - k_second * sin_v + k_rot_second = k_second * cos_v + k_first * sin_v + + # Store V (no transform) + tl.store(v_write + offs, v_raw, mask=mask_hd) + + # Store K: rotated halves + pass-through + tl.store(k_write + offs, k_rot_first.to(v_raw.dtype), mask=mask_half) + tl.store( + k_write + half_rotary_dim + offs, k_rot_second.to(v_raw.dtype), mask=mask_half + ) + mask_pass = (offs >= rotary_dim) & (offs < head_dim) + tl.store(k_write + offs, k_normed.to(v_raw.dtype), mask=mask_pass) + + +def _fused_norm_rope( + kv: torch.Tensor, # [total_ctx, kv_size*2] + k_norm_weight: torch.Tensor, # [head_dim] + cos_sin_cache: torch.Tensor, # [max_pos, rotary_dim] + positions: torch.Tensor, # [total_ctx] + num_kv_heads: int, + head_dim: int, + rotary_dim: int, + eps: float = 1e-6, +) -> tuple[torch.Tensor, torch.Tensor]: + """Fused RMSNorm + RoPE materialization for a single layer.""" + total_ctx = kv.shape[0] + if total_ctx == 0: + empty = torch.empty( + (0, num_kv_heads, head_dim), dtype=kv.dtype, device=kv.device + ) + return empty, empty + + kv_size = num_kv_heads * head_dim + if kv.shape[1] != kv_size * 2: + raise ValueError( + "Invalid fused KV projection shape: " + f"got {tuple(kv.shape)}, expected second dim {kv_size * 2}." + ) + if rotary_dim <= 0 or rotary_dim > head_dim or rotary_dim % 2 != 0: + raise ValueError( + "Invalid fused KV rotary/head dim pair: " + f"rotary_dim={rotary_dim}, head_dim={head_dim}." 
+ ) + + half_rotary_dim = rotary_dim // 2 + BLOCK_HD = triton.next_power_of_2(head_dim) + + # Ensure int64 for indexing + if positions.device != kv.device: + positions = positions.to(device=kv.device, dtype=torch.int64) + elif positions.dtype != torch.int64: + positions = positions.to(torch.int64) + + k_out = torch.empty( + (total_ctx, num_kv_heads, head_dim), dtype=kv.dtype, device=kv.device + ) + v_out = torch.empty_like(k_out) + + _fused_norm_rope_kernel[(total_ctx, num_kv_heads)]( + kv, + k_norm_weight, + cos_sin_cache, + positions, + k_out, + v_out, + kv.stride(0), + cos_sin_cache.stride(0), + k_out.stride(0), + k_out.stride(1), + v_out.stride(0), + v_out.stride(1), + total_ctx, + num_kv_heads, + head_dim, + kv_size, + rotary_dim, + half_rotary_dim, + eps, + BLOCK_HD, + ) + return k_out, v_out + + +class FusedKVMaterializeHelper: + """Fused KV materialization helper using batched projection. + + Uses torch.einsum for batched KV projection across all layers, + then a Triton kernel for fused RMSNorm + RoPE materialization per layer. + """ + + def __init__( + self, + layers: List, + rotary_emb, + num_kv_heads: int, + head_dim: int, + device: torch.device, + ): + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.rotary_emb = rotary_emb + self.n_layers = len(layers) + self.device = device + + self.rotary_dim = int(getattr(rotary_emb, "rotary_dim", head_dim)) + self.is_neox_style = bool(getattr(rotary_emb, "is_neox_style", True)) + + if not self.is_neox_style: + raise NotImplementedError("Only neox-style RoPE is supported.") + if self.rotary_dim <= 0 or self.rotary_dim > self.head_dim: + raise ValueError( + "Invalid fused KV rotary/head dim pair: " + f"rotary_dim={self.rotary_dim}, head_dim={self.head_dim}." + ) + + # Pre-extract and stack weights for batched projection. 
+ kv_weights = [] + self.k_norm_weights = [] + self.eps_values = [] + + for layer_id, layer in enumerate(layers): + attn = layer.self_attn + if int(attn.num_kv_heads) != self.num_kv_heads: + raise ValueError( + "num_kv_heads mismatch across layers for fused KV path: " + f"expected {self.num_kv_heads}, got {int(attn.num_kv_heads)} at layer {layer_id}." + ) + if int(attn.head_dim) != self.head_dim: + raise ValueError( + "head_dim mismatch across layers for fused KV path: " + f"expected {self.head_dim}, got {int(attn.head_dim)} at layer {layer_id}." + ) + layer_rotary_dim = int( + getattr(attn.rotary_emb, "rotary_dim", self.head_dim) + ) + layer_is_neox = bool(getattr(attn.rotary_emb, "is_neox_style", True)) + if ( + layer_rotary_dim != self.rotary_dim + or layer_is_neox != self.is_neox_style + ): + raise ValueError( + "RoPE config mismatch across layers for fused KV path: " + f"expected (rotary_dim={self.rotary_dim}, neox={self.is_neox_style}), " + f"got (rotary_dim={layer_rotary_dim}, neox={layer_is_neox}) at layer {layer_id}." + ) + + # Extract KV portion of QKV weight + qkv_w = attn.qkv_proj.weight + kv_weight = qkv_w[attn.q_size : attn.q_size + 2 * attn.kv_size] + kv_weights.append(kv_weight) + self.k_norm_weights.append(attn.k_norm.weight) + self.eps_values.append(attn.k_norm.variance_epsilon) + + # Stack for batched einsum: [n_layers, kv_size*2, hidden_size] + self.batched_kv_weight = torch.stack(kv_weights) + + def materialize( + self, + ctx_hidden: torch.Tensor, + positions: torch.Tensor, + write_layer_kv: Callable[[int, torch.Tensor, torch.Tensor], None], + ) -> None: + """Materialize KV cache for all layers using batched projection.""" + total_ctx = ctx_hidden.shape[0] + if total_ctx == 0: + return + + if positions.ndim != 1: + positions = positions.reshape(-1) + if positions.numel() != total_ctx: + raise ValueError( + "positions must match ctx_hidden token count for fused KV materialization: " + f"positions={positions.numel()}, total_ctx={total_ctx}." 
+ ) + + max_position = int(positions.max().item()) + ensure_cos_sin_cache_length = getattr( + self.rotary_emb, "_ensure_cos_sin_cache_length", None + ) + if callable(ensure_cos_sin_cache_length): + ensure_cos_sin_cache_length(max_position) + + cos_sin_cache = self.rotary_emb.cos_sin_cache + if max_position >= int(cos_sin_cache.shape[0]): + raise RuntimeError( + "RoPE cos/sin cache is too short for fused KV materialization: " + f"max_position={max_position}, cache_len={int(cos_sin_cache.shape[0])}." + ) + if cos_sin_cache.device != ctx_hidden.device: + cos_sin_cache = cos_sin_cache.to(ctx_hidden.device) + + # Batched KV projection: [n_layers, total_ctx, kv_size*2] + kv_all = torch.einsum("th,loh->lto", ctx_hidden, self.batched_kv_weight) + + # Per-layer fused norm/RoPE/materialize, then delegate writes to the KV pool. + for layer_id in range(self.n_layers): + cache_k, cache_v = _fused_norm_rope( + kv_all[layer_id], + self.k_norm_weights[layer_id], + cos_sin_cache, + positions, + self.num_kv_heads, + self.head_dim, + self.rotary_dim, + self.eps_values[layer_id], + ) + write_layer_kv(layer_id, cache_k, cache_v) diff --git a/python/sglang/srt/utils/runai_utils.py b/python/sglang/srt/utils/runai_utils.py index 0424a6371bde..dd74efb6626d 100644 --- a/python/sglang/srt/utils/runai_utils.py +++ b/python/sglang/srt/utils/runai_utils.py @@ -5,6 +5,8 @@ import os from pathlib import Path +from sglang.srt.environ import envs + logger = logging.getLogger(__name__) SUPPORTED_SCHEMES = ["s3://", "gs://", "az://"] @@ -26,12 +28,6 @@ # This avoids file locks, race conditions, and duplicate downloads -def get_cache_dir() -> str: - # Expand user path (~) to ensure absolute paths for locking - path = os.getenv("SGLANG_CACHE_DIR", "~/.cache/sglang/") - return os.path.expanduser(path) - - def list_safetensors(path: str = "") -> list[str]: """ List full file names from object path and filter by allow pattern. 
@@ -122,7 +118,7 @@ def get_path(cls, model_path: str) -> str: Returns the local directory path. """ model_hash = hashlib.sha256(str(model_path).encode()).hexdigest()[:16] - base_dir = get_cache_dir() + base_dir = envs.SGLANG_CACHE_DIR.get() # Ensure base cache dir exists os.makedirs(os.path.join(base_dir, "model_streamer"), exist_ok=True) diff --git a/python/sglang/test/kits/eval_accuracy_kit.py b/python/sglang/test/kits/eval_accuracy_kit.py index 25bf58151c9f..9757dc01523e 100644 --- a/python/sglang/test/kits/eval_accuracy_kit.py +++ b/python/sglang/test/kits/eval_accuracy_kit.py @@ -9,12 +9,16 @@ _THRESHOLD_NOT_SET = float("nan") -def _check_accept_length(test_case, base_url, threshold): - """Check speculative decoding accept length from server info.""" - server_info = requests.get(base_url + "/get_server_info").json() - avg_spec_accept_length = server_info["internal_states"][0]["avg_spec_accept_length"] - print(f"{avg_spec_accept_length=}") - test_case.assertGreater(avg_spec_accept_length, threshold) +def _check_accept_length(test_case, base_url, threshold=None): + """Print accept length; optionally assert it exceeds threshold.""" + try: + server_info = requests.get(base_url + "/server_info").json() + val = server_info["internal_states"][0]["avg_spec_accept_length"] + except (KeyError, IndexError, requests.RequestException): + return + print(f"avg_spec_accept_length={val:.4f}") + if threshold is not None: + test_case.assertGreater(val, threshold) class GSM8KMixin: @@ -57,8 +61,7 @@ def test_gsm8k(self): self.assertGreaterEqual(metrics["score"], self.gsm8k_accuracy_thres) - if self.gsm8k_accept_length_thres is not None: - _check_accept_length(self, self.base_url, self.gsm8k_accept_length_thres) + _check_accept_length(self, self.base_url, self.gsm8k_accept_length_thres) class MMLUMixin: @@ -95,8 +98,7 @@ def test_mmlu(self): self.assertGreaterEqual(metrics["score"], self.mmlu_score_threshold) - if self.mmlu_accept_length_thres is not None: - 
_check_accept_length(self, self.base_url, self.mmlu_accept_length_thres) + _check_accept_length(self, self.base_url, self.mmlu_accept_length_thres) class HumanEvalMixin: @@ -136,6 +138,8 @@ def test_human_eval(self): self.assertGreaterEqual(metrics["score"], threshold) + _check_accept_length(self, self.base_url) + class MGSMEnMixin: """Mixin for MGSM English evaluation. @@ -169,3 +173,5 @@ def test_mgsm_en(self): write_github_step_summary(f"### test_mgsm_en\n{metrics['score']=:.4f}\n") self.assertGreaterEqual(metrics["score"], self.mgsm_en_score_threshold) + + _check_accept_length(self, self.base_url) diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index c2d84ff2f85b..adbfcaf41d72 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -235,7 +235,7 @@ def _forward_gme_qwen2_vl( **kwargs, ) -> torch.Tensor: if inputs_embeds is None: - inputs_embeds = self.model.model.embed_tokens(input_ids) + inputs_embeds = self.model.model.get_input_embeddings()(input_ids) if pixel_values is not None: pixel_values = pixel_values.type(self.model.visual.get_dtype()) image_embeds = self.model.visual( diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 9cbd2e59dc90..6022f602c3f4 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -107,6 +107,10 @@ DEFAULT_TARGET_MODEL_EAGLE3 = "meta-llama/Llama-3.1-8B-Instruct" DEFAULT_DRAFT_MODEL_EAGLE3 = "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B" +# DFLASH model +DEFAULT_TARGET_MODEL_DFLASH = "meta-llama/Llama-3.1-8B-Instruct" +DEFAULT_DRAFT_MODEL_DFLASH = "z-lab/LLaMA3.1-8B-Instruct-DFlash-UltraChat" + # EAGLE2 with DP-Attention models DEFAULT_TARGET_MODEL_EAGLE_DP_ATTN = "Qwen/Qwen3-30B-A3B" DEFAULT_DRAFT_MODEL_EAGLE_DP_ATTN = "Tengyunw/qwen3_30b_moe_eagle3" diff --git a/scripts/ci/cuda/ci_install_dependency.sh b/scripts/ci/cuda/ci_install_dependency.sh index 5bfbea04ffeb..c10a79e62222 100755 --- 
a/scripts/ci/cuda/ci_install_dependency.sh +++ b/scripts/ci/cuda/ci_install_dependency.sh @@ -358,6 +358,10 @@ mark_step_done "Fix other dependencies" # can delete the .pth file without reliably recreating it (pip race condition). $PIP_CMD install "nvidia-cutlass-dsl>=4.4.1" "nvidia-cutlass-dsl-libs-base>=4.4.1" --no-deps --force-reinstall $PIP_INSTALL_SUFFIX || true +# Download kernels from kernels community +kernels download python || true +kernels lock python || true +mv python/kernels.lock ${HOME}/.cache/sglang || true # Install human-eval pip install "setuptools==70.0.0" diff --git a/sgl-kernel/csrc/gemm/marlin/marlin_template.h b/sgl-kernel/csrc/gemm/marlin/marlin_template.h index 01eb338782c4..19f5d5477c4e 100644 --- a/sgl-kernel/csrc/gemm/marlin/marlin_template.h +++ b/sgl-kernel/csrc/gemm/marlin/marlin_template.h @@ -487,11 +487,11 @@ __global__ void Marlin( constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; // Scale sizes/strides without act_order - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks / (w_type == sglang::kFE2M1f ? 2 : 1) - : 1; + // FP4 (kFE2M1f) uses FP8 scales (1 byte/element), others use FP16 (2 bytes) + int s_gl_stride = prob_n / (w_type == sglang::kFE2M1f ? 16 : 8); + constexpr int s_sh_stride = 16 * thread_n_blocks / (w_type == sglang::kFE2M1f ? 16 : 8); + constexpr int s_tb_groups = + !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks ? thread_k_blocks / group_blocks : 1; constexpr int s_sh_stage = s_tb_groups * s_sh_stride; int s_gl_rd_delta = s_gl_stride; @@ -543,8 +543,7 @@ __global__ void Marlin( if constexpr (group_blocks == -1) { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / (w_type == sglang::kFE2M1f ? 
2 : 1) + - s_sh_stride * slice_col + threadIdx.x; + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x; } } auto s_sh_wr = threadIdx.x; @@ -566,15 +565,7 @@ __global__ void Marlin( // we scale a `half2` tile in column-major layout in the former and in // row-major in the latter case. int s_sh_rd; - if constexpr (group_blocks != -1 && w_type == sglang::kFE2M1f) { - auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - int warp_row = warp_id / n_warps; - - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; - s_sh_rd = s_sh_rd * 2 + warp_row % 2; - - } else if constexpr (group_blocks != -1) + if constexpr (group_blocks != -1) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; else if constexpr (group_blocks == -1 && (m_block_size_8 || (has_zp && !dequant_skip_flop))) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8; @@ -879,7 +870,7 @@ __global__ void Marlin( cur_k += k_iter_size * (k % b_sh_wr_iters); int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / (group_blocks * (w_type == sglang::kFE2M1f ? 
2 : 1)); + int cur_group_id = k_blocks / group_blocks; int4* sh_s_stage = sh_s + s_sh_stage * pipe; diff --git a/test/registered/8-gpu-models/test_ring_2_5_1t.py b/test/registered/8-gpu-models/test_ring_2_5_1t.py index 71b2a4f2609e..29c64160cf96 100644 --- a/test/registered/8-gpu-models/test_ring_2_5_1t.py +++ b/test/registered/8-gpu-models/test_ring_2_5_1t.py @@ -5,8 +5,7 @@ from sglang.test.run_combined_tests import run_combined_tests from sglang.test.test_utils import ModelLaunchSettings -# register_cuda_ci(est_time=1000, suite="nightly-8-gpu-common", nightly=True) -register_cuda_ci(est_time=1000, suite="stage-c-test-8-gpu-h200") +register_cuda_ci(est_time=1800, suite="nightly-8-gpu-common", nightly=True) RING_2_5_1T_MODEL_PATH = "inclusionAI/Ring-2.5-1T" @@ -25,6 +24,8 @@ def test_ring_2_5_1t(self): '{"enable_multithread_load": true, "num_threads": 64}', "--watchdog-timeout", "1800", + "--soft-watchdog-timeout", + "1800", ] variants = [ diff --git a/test/registered/amd/accuracy/mi30x/test_gsm8k_eval_amd.py b/test/registered/amd/accuracy/mi30x/test_gsm8k_eval_amd.py index aa7813ee543a..9a37ed6d5315 100644 --- a/test/registered/amd/accuracy/mi30x/test_gsm8k_eval_amd.py +++ b/test/registered/amd/accuracy/mi30x/test_gsm8k_eval_amd.py @@ -1,7 +1,7 @@ """ AMD GSM8K Evaluation Test (Migrated from test/srt/nightly/) -This test evaluates instruction-tuned models on the mgsm_en benchmark using chat completions. +This test evaluates instruction-tuned models on the gsm8k benchmark using chat completions. Models are tested with various TP configurations on AMD GPUs. 
Registry: nightly-amd suite (2-GPU tests) @@ -35,34 +35,35 @@ register_amd_ci(est_time=3600, suite="nightly-amd", nightly=True) MODEL_SCORE_THRESHOLDS = { + # Thresholds set at 5% below reported GSM8K (5-shot/CoT) scores # Llama 3.1 series - "meta-llama/Llama-3.1-8B-Instruct": 0.82, - "meta-llama/Llama-3.1-70B-Instruct": 0.95, + "meta-llama/Llama-3.1-8B-Instruct": 0.80, # 84.5% - 5% + "meta-llama/Llama-3.1-70B-Instruct": 0.89, # 94.1% - 5% # Llama 3.2 series (smaller models) - "meta-llama/Llama-3.2-3B-Instruct": 0.55, + "meta-llama/Llama-3.2-3B-Instruct": 0.43, # 48.2% - 5% # Mistral series - "mistralai/Mistral-7B-Instruct-v0.3": 0.55, - "mistralai/Mixtral-8x7B-Instruct-v0.1": 0.58, + "mistralai/Mistral-7B-Instruct-v0.3": 0.47, # 52.1% - 5% + "mistralai/Mixtral-8x7B-Instruct-v0.1": 0.69, # 74.4% - 5% (lower if AMD scores differently) # DeepSeek series - "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": 0.85, + "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": 0.81, # 86.4% - 5% # Qwen2 series - "Qwen/Qwen2-57B-A14B-Instruct": 0.86, - "Qwen/Qwen2.5-7B-Instruct": 0.85, + "Qwen/Qwen2-57B-A14B-Instruct": 0.76, # 80.7% - 5% (official A14B score; 88.2% was the 72B) + "Qwen/Qwen2.5-7B-Instruct": 0.82, # 86.3% - 5% # Qwen3 series - "Qwen/Qwen3-30B-A3B-Thinking-2507": 0.84, # MoE model verified on MI300X - "Qwen/Qwen3-8B": 0.77, + "Qwen/Qwen3-30B-A3B-Thinking-2507": 0.86, # 91.4% - 5% (full attention mode; ensure sufficient max_tokens) + "Qwen/Qwen3-8B": 0.76, # ~81% - 5% # Google Gemma - "google/gemma-2-27b-it": 0.91, - "google/gemma-2-9b-it": 0.72, + "google/gemma-2-27b-it": 0.86, # 90.7% - 5% + "google/gemma-2-9b-it": 0.74, # 78.5% - 5% # "neuralmagic/gemma-2-2b-it-FP8": 0.4, # Small 2B model - OOM on single GPU # FP8 quantized models - "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8": 0.8, - "neuralmagic/Mistral-7B-Instruct-v0.3-FP8": 0.54, - "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8": 0.94, - "neuralmagic/Qwen2-72B-Instruct-FP8": 0.92, - 
"neuralmagic/Qwen2-57B-A14B-Instruct-FP8": 0.81, - "neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8": 0.57, - "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8": 0.84, + "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8": 0.80, # 84.5% - 5% + "neuralmagic/Mistral-7B-Instruct-v0.3-FP8": 0.46, # ~51% - 5% + "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8": 0.89, # 94.1% - 5% + "neuralmagic/Qwen2-72B-Instruct-FP8": 0.86, # 91.1% - 5% + "neuralmagic/Qwen2-57B-A14B-Instruct-FP8": 0.76, # 80.7% - 5% (official A14B score) + "neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8": 0.69, # 74.4% - 5% + "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8": 0.81, # 86.4% - 5% } failing_models = { @@ -185,7 +186,7 @@ def check_model_scores(results): summary += line print(f"\n{'='*60}") - print("SUMMARY - TP=2 Instruction Models (mgsm_en)") + print("SUMMARY - TP=2 Instruction Models (gsm8k)") print(f"{'='*60}") print(summary) print(f"\n📊 Final Statistics:") @@ -200,7 +201,7 @@ def check_model_scores(results): raise AssertionError(f"The following models failed:\n{failure_msg}") -# Do not use `CustomTestCase` since `test_mgsm_en_all_models` does not want retry +# Do not use `CustomTestCase` since `test_gsm8k_all_models` does not want retry class TestNightlyGsm8KEval(unittest.TestCase): @classmethod def setUpClass(cls): @@ -215,7 +216,7 @@ def setUpClass(cls): ] cls.base_url = DEFAULT_URL_FOR_TEST - def test_mgsm_en_all_models(self): + def test_gsm8k_all_models(self): warnings.filterwarnings( "ignore", category=ResourceWarning, message="unclosed.*socket" ) @@ -226,7 +227,7 @@ def test_mgsm_en_all_models(self): print(f"\n{'='*60}") print("AMD GSM8K Evaluation Test (TP=2 Instruction Models)") print(f"{'='*60}") - print(f"Benchmark: mgsm_en (chat completions)") + print(f"Benchmark: gsm8k (chat completions)") print(f"{'='*60}\n") for model_group, is_fp8, is_tp2 in self.model_groups: @@ -261,13 +262,13 @@ def test_mgsm_en_all_models(self): args = SimpleNamespace( base_url=self.base_url, model=model, - 
eval_name="mgsm_en", + eval_name="gsm8k", num_examples=None, num_threads=1024, ) # Run eval with timing and retries - print(f"📊 Running mgsm_en evaluation...") + print(f"📊 Running gsm8k evaluation...") eval_start = time.time() threshold = MODEL_SCORE_THRESHOLDS.get(model) metrics = None diff --git a/test/registered/amd/test_kimi_k25_mxfp4.py b/test/registered/amd/test_kimi_k25_mxfp4.py index a4ef774304c8..1ce83f8eb928 100644 --- a/test/registered/amd/test_kimi_k25_mxfp4.py +++ b/test/registered/amd/test_kimi_k25_mxfp4.py @@ -27,6 +27,7 @@ register_amd_ci(est_time=3600, suite="stage-c-test-large-8-gpu-amd-mi35x") KIMI_K25_MXFP4_MODEL_PATH = "amd/Kimi-K2.5-MXFP4" +KIMI_K25_MXFP4_REVISION = "b071bc6f8eb042e093e14f3b8bdbad71c18e09d3" SERVER_LAUNCH_TIMEOUT = 3600 @@ -36,6 +37,8 @@ def setUpClass(cls): cls.model = KIMI_K25_MXFP4_MODEL_PATH cls.base_url = DEFAULT_URL_FOR_TEST other_args = [ + "--revision", + KIMI_K25_MXFP4_REVISION, "--tp", "8", "--attention-backend", diff --git a/test/registered/distributed/test_dp_attention_large.py b/test/registered/distributed/test_dp_attention_large.py index 48cdee862f8a..3e1d65f747e3 100644 --- a/test/registered/distributed/test_dp_attention_large.py +++ b/test/registered/distributed/test_dp_attention_large.py @@ -56,11 +56,11 @@ def setUpClass(cls): def tearDownClass(cls): kill_process_tree(cls.process.pid) - def test_mgsm_en(self): + def test_gsm8k(self): args = SimpleNamespace( base_url=self.base_url, model=self.model, - eval_name="mgsm_en", + eval_name="gsm8k", num_examples=None, num_threads=1024, ) diff --git a/test/registered/distributed/test_pp_single_node.py b/test/registered/distributed/test_pp_single_node.py index 76e1c068d7f1..0dd5d4fe8277 100644 --- a/test/registered/distributed/test_pp_single_node.py +++ b/test/registered/distributed/test_pp_single_node.py @@ -128,11 +128,11 @@ def setUpClass(cls): def tearDownClass(cls): kill_process_tree(cls.process.pid) - def test_mgsm_en(self): + def test_gsm8k(self): args = 
SimpleNamespace( base_url=self.base_url, model=self.model, - eval_name="mgsm_en", + eval_name="gsm8k", num_examples=None, num_threads=1024, ) diff --git a/test/registered/eval/test_text_models_gsm8k_eval.py b/test/registered/eval/test_text_models_gsm8k_eval.py index 9436895422b7..c2974439c797 100644 --- a/test/registered/eval/test_text_models_gsm8k_eval.py +++ b/test/registered/eval/test_text_models_gsm8k_eval.py @@ -26,28 +26,29 @@ register_cuda_ci(est_time=3600, suite="nightly-eval-text-2-gpu", nightly=True) MODEL_SCORE_THRESHOLDS = { - "meta-llama/Llama-3.1-8B-Instruct": 0.82, - "mistralai/Mistral-7B-Instruct-v0.3": 0.58, - "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": 0.85, - "google/gemma-2-27b-it": 0.91, - "meta-llama/Llama-3.1-70B-Instruct": 0.95, - "mistralai/Mixtral-8x7B-Instruct-v0.1": 0.616, - "Qwen/Qwen2-57B-A14B-Instruct": 0.86, - "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8": 0.83, - "neuralmagic/Mistral-7B-Instruct-v0.3-FP8": 0.54, - "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8": 0.835, - "zai-org/GLM-4.5-Air-FP8": 0.75, - # The threshold of neuralmagic/gemma-2-2b-it-FP8 should be 0.6, but this model has some accuracy regression. - # The fix is tracked at https://github.com/sgl-project/sglang/issues/4324, we set it to 0.50, for now, to make CI green. 
- "neuralmagic/gemma-2-2b-it-FP8": 0.50, - "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8": 0.94, - "neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8": 0.65, - "neuralmagic/Qwen2-72B-Instruct-FP8": 0.94, - "neuralmagic/Qwen2-57B-A14B-Instruct-FP8": 0.82, + # Thresholds set at 5% below reported GSM8K (5-shot/CoT) scores + "meta-llama/Llama-3.1-8B-Instruct": 0.80, # 84.5% - 5% + "mistralai/Mistral-7B-Instruct-v0.3": 0.47, # 52.1% - 5% + "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": 0.81, # 86.4% - 5% + "google/gemma-2-27b-it": 0.86, # 90.7% - 5% + "meta-llama/Llama-3.1-70B-Instruct": 0.89, # 94.1% - 5% + "mistralai/Mixtral-8x7B-Instruct-v0.1": 0.69, # 74.4% - 5% + "Qwen/Qwen2-57B-A14B-Instruct": 0.76, # 80.7% - 5% (official A14B score; 88.2% was the 72B) + "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8": 0.80, # 84.5% - 5% + "neuralmagic/Mistral-7B-Instruct-v0.3-FP8": 0.47, # 52.1% - 5% + "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8": 0.81, # 86.4% - 5% + "zai-org/GLM-4.5-Air-FP8": 0.80, # ~85% - 5% + # GSM8K baseline for gemma-2-2b is ~40-45%; threshold set at 5% below. 
+ # (Previously 0.50 based on MGSM-EN; tracked regression: https://github.com/sgl-project/sglang/issues/4324) + "neuralmagic/gemma-2-2b-it-FP8": 0.38, # ~43% - 5% + "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8": 0.89, # 94.1% - 5% + "neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8": 0.69, # 74.4% - 5% + "neuralmagic/Qwen2-72B-Instruct-FP8": 0.86, # 91.1% - 5% + "neuralmagic/Qwen2-57B-A14B-Instruct-FP8": 0.76, # 80.7% - 5% (official A14B score) } -# Do not use `CustomTestCase` since `test_mgsm_en_all_models` does not want retry +# Do not use `CustomTestCase` since `test_gsm8k_all_models` does not want retry class TestNightlyGsm8KEval(unittest.TestCase): @classmethod def setUpClass(cls): @@ -66,7 +67,7 @@ def setUpClass(cls): cls.base_url = DEFAULT_URL_FOR_TEST - def test_mgsm_en_all_models(self): + def test_gsm8k_all_models(self): warnings.filterwarnings( "ignore", category=ResourceWarning, message="unclosed.*socket" ) @@ -91,7 +92,7 @@ def test_mgsm_en_all_models(self): args = SimpleNamespace( base_url=self.base_url, model=model_setup.model_path, - eval_name="mgsm_en", + eval_name="gsm8k", num_examples=None, num_threads=1024, ) diff --git a/test/registered/openai_server/basic/test_http2_server.py b/test/registered/openai_server/basic/test_http2_server.py new file mode 100644 index 000000000000..6cfc3ee7e7e0 --- /dev/null +++ b/test/registered/openai_server/basic/test_http2_server.py @@ -0,0 +1,112 @@ +""" +Test HTTP/2 server (Granian) with basic OpenAI-compatible endpoints. + +Verifies that --enable-http2 launches successfully and serves requests +via both HTTP/1.1 and HTTP/2 (h2c). 
+""" + +import subprocess +import unittest + +import requests + +from sglang.srt.utils import kill_process_tree +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.test_utils import ( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + +try: + import granian # noqa: F401 + + _HAS_GRANIAN = True +except ImportError: + _HAS_GRANIAN = False + +register_cuda_ci(est_time=120, suite="stage-b-test-1-gpu-small") + + +@unittest.skipUnless(_HAS_GRANIAN, "granian not installed (pip install sglang[http2])") +class TestHTTP2Server(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=["--enable-http2"], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_health(self): + resp = requests.get(f"{self.base_url}/health") + self.assertEqual(resp.status_code, 200) + + def test_get_model_info(self): + resp = requests.get(f"{self.base_url}/get_model_info") + self.assertEqual(resp.status_code, 200) + self.assertIn("model_path", resp.json()) + + def test_completion(self): + resp = requests.post( + f"{self.base_url}/v1/completions", + json={ + "model": self.model, + "prompt": "The capital of France is", + "max_tokens": 8, + "temperature": 0, + }, + ) + self.assertEqual(resp.status_code, 200) + data = resp.json() + self.assertIn("choices", data) + self.assertGreater(len(data["choices"][0]["text"]), 0) + + def test_chat_completion(self): + resp = requests.post( + f"{self.base_url}/v1/chat/completions", + json={ + "model": self.model, + "messages": [{"role": "user", "content": "Say hello"}], + "max_tokens": 16, + "temperature": 0, + }, + ) + self.assertEqual(resp.status_code, 200) + data = resp.json() + self.assertIn("choices", 
data) + self.assertGreater(len(data["choices"][0]["message"]["content"]), 0) + + def test_h2c_with_curl(self): + """Verify the server actually speaks HTTP/2 via h2c.""" + result = subprocess.run( + [ + "curl", + "--http2-prior-knowledge", + "-s", + "-o", + "/dev/null", + "-w", + "%{http_version}", + f"{self.base_url}/health", + ], + capture_output=True, + text=True, + timeout=10, + ) + self.assertEqual( + result.stdout.strip(), "2", "Server should respond with HTTP/2" + ) + + +if __name__ == "__main__": + unittest.main(verbosity=3) diff --git a/test/registered/piecewise_cuda_graph/test_piecewise_cuda_graph_support_1_gpu.py b/test/registered/piecewise_cuda_graph/test_piecewise_cuda_graph_support_1_gpu.py index e38b59f5b86b..ce6fe2291828 100644 --- a/test/registered/piecewise_cuda_graph/test_piecewise_cuda_graph_support_1_gpu.py +++ b/test/registered/piecewise_cuda_graph/test_piecewise_cuda_graph_support_1_gpu.py @@ -41,21 +41,19 @@ def setUpClass(cls): def tearDownClass(cls): kill_process_tree(cls.process.pid) - def test_mgsm_accuracy(self): - num_examples = 2000 - + def test_gsm8k_accuracy(self): args = SimpleNamespace( base_url=self.base_url, model=self.model, - eval_name="mgsm_en", - num_examples=num_examples, - num_threads=min(num_examples, 1024), + eval_name="gsm8k", + num_examples=None, + num_threads=1024, ) metrics = run_eval(args) - print(f"MGSM Accuracy: {metrics['score']:.3f}") + print(f"GSM8K Accuracy: {metrics['score']:.3f}") - self.assertGreaterEqual(metrics["score"], 0.70) + self.assertGreaterEqual(metrics["score"], 0.82) class TestPiecewiseCudaGraphInternVL25(CustomTestCase): @@ -79,21 +77,23 @@ def setUpClass(cls): def tearDownClass(cls): kill_process_tree(cls.process.pid) - def test_mgsm_accuracy(self): - num_examples = 2000 - + def test_gsm8k_accuracy(self): args = SimpleNamespace( base_url=self.base_url, model=self.model, - eval_name="mgsm_en", - num_examples=num_examples, - num_threads=min(num_examples, 1024), + eval_name="gsm8k", + 
num_examples=None, + num_threads=1024, ) metrics = run_eval(args) - print(f"MGSM Accuracy: {metrics['score']:.3f}") + print(f"GSM8K Accuracy: {metrics['score']:.3f}") - self.assertGreaterEqual(metrics["score"], 0.70) + # Baseline (no piecewise CUDA graph): 0.571 — this eval uses 5-shot + # concatenated text via chat API, which scores lower than reported + # benchmarks (~77.8%) that use proper CoT chat format. The threshold + # is set 5% below observed to catch catastrophic regressions. + self.assertGreaterEqual(metrics["score"], 0.54) class TestPiecewiseCudaGraphQwen25VLEmbedding(CustomTestCase): diff --git a/test/registered/quant/test_nvfp4_marlin_fallback.py b/test/registered/quant/test_nvfp4_marlin_fallback.py new file mode 100644 index 000000000000..348294d5e565 --- /dev/null +++ b/test/registered/quant/test_nvfp4_marlin_fallback.py @@ -0,0 +1,788 @@ +"""Tests for NVFP4 Marlin fallback on non-Blackwell GPUs (SM75+).""" + +import unittest + +import torch + +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.test_utils import CustomTestCase + +register_cuda_ci(est_time=480, suite="stage-b-test-1-gpu-large") + +_FP4_MARLIN_GROUP_SIZE = 16 + +_FP4_E2M1_LUT_VALUES = [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + 0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, +] + + +def _check_requirements(): + from sglang.srt.utils import is_cuda + + if not is_cuda(): + return False + from sglang.srt.layers.quantization.marlin_utils_fp4 import is_fp4_marlin_supported + + if not is_fp4_marlin_supported(): + return False + return True + + +def _dequant_fp4_weights( + raw_weight: torch.Tensor, device: torch.device +) -> torch.Tensor: + """Dequantize uint8-packed FP4 E2M1 weights to float32 via lookup table.""" + lut = torch.tensor(_FP4_E2M1_LUT_VALUES, dtype=torch.float32, device=device) + lo = (raw_weight.int() & 0x0F).long() + hi = ((raw_weight.int() >> 4) & 0x0F).long() + return torch.stack([lut[lo], lut[hi]], dim=-1).reshape( + 
raw_weight.shape[0], raw_weight.shape[1] * 2 + ) + + +class _FakeLayer(torch.nn.Module): + """Minimal stand-in for a quantized layer in unit tests.""" + + pass + + +# --------------------------------------------------------------------------- +# Linear (non-MoE) tests +# --------------------------------------------------------------------------- +class TestNvfp4MarlinLinear(CustomTestCase): + """Test the FP4 Marlin linear layer fallback (non-MoE).""" + + def setUp(self): + if not _check_requirements(): + self.skipTest("Requirements not met (CUDA unavailable or SM < 75)") + self.device = torch.device("cuda") + self.dtype = torch.bfloat16 + + # -- helpers ------------------------------------------------------------- + + def _make_fake_fp4_layer(self, N, K): + layer = _FakeLayer() + layer.params_dtype = self.dtype + layer.input_size_per_partition = K + layer.output_size_per_partition = N + + layer.weight = torch.nn.Parameter( + torch.randint(0, 256, (N, K // 2), dtype=torch.uint8, device=self.device), + requires_grad=False, + ) + layer.weight_scale = torch.nn.Parameter( + torch.ones( + N, + K // _FP4_MARLIN_GROUP_SIZE, + dtype=torch.float8_e4m3fn, + device=self.device, + ), + requires_grad=False, + ) + layer.weight_scale_2_marlin = torch.nn.Parameter( + torch.tensor(1.0, dtype=torch.float32, device=self.device), + requires_grad=False, + ) + return layer + + def _run_fp4_marlin_vs_reference(self, M, N, K): + """Prepare a layer, run the Marlin kernel, return (kernel_out, ref_out).""" + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + apply_fp4_marlin_linear, + prepare_fp4_layer_for_marlin, + ) + + raw_weight = torch.randint( + 0, 256, (N, K // 2), dtype=torch.uint8, device=self.device + ) + dq_weight = _dequant_fp4_weights(raw_weight, self.device) + + raw_scale = torch.full( + (N, K // _FP4_MARLIN_GROUP_SIZE), + 1.0, + dtype=torch.float8_e4m3fn, + device=self.device, + ) + global_scale_val = torch.tensor(1.0, dtype=torch.float32, device=self.device) + + x 
= torch.randn(M, K, dtype=self.dtype, device=self.device) + ref_output = (x.float() @ dq_weight.T).to(self.dtype) + + layer = self._make_fake_fp4_layer(N, K) + layer.weight = torch.nn.Parameter(raw_weight, requires_grad=False) + layer.weight_scale = torch.nn.Parameter(raw_scale, requires_grad=False) + layer.weight_scale_2_marlin = torch.nn.Parameter( + global_scale_val.to(self.dtype), requires_grad=False + ) + + prepare_fp4_layer_for_marlin( + layer, + weight_attr="weight", + weight_scale_attr="weight_scale", + weight_global_scale_attr="weight_scale_2_marlin", + ) + + marlin_output = apply_fp4_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + weight_global_scale=layer.weight_scale_2_marlin, + workspace=layer.marlin_workspace, + size_n=N, + size_k=K, + ) + return marlin_output, ref_output + + # -- tests --------------------------------------------------------------- + + def test_prepare_and_apply_fp4_marlin_linear(self): + """Smoke test: shape and dtype are correct after prepare + apply.""" + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + apply_fp4_marlin_linear, + prepare_fp4_layer_for_marlin, + ) + + N, K, M = 256, 128, 16 + layer = self._make_fake_fp4_layer(N, K) + prepare_fp4_layer_for_marlin( + layer, + weight_attr="weight", + weight_scale_attr="weight_scale", + weight_global_scale_attr="weight_scale_2_marlin", + ) + + self.assertTrue(hasattr(layer, "marlin_workspace")) + + x = torch.randn(M, K, dtype=self.dtype, device=self.device) + output = apply_fp4_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + weight_global_scale=layer.weight_scale_2_marlin, + workspace=layer.marlin_workspace, + size_n=N, + size_k=K, + ) + self.assertEqual(output.shape, (M, N)) + self.assertEqual(output.dtype, self.dtype) + + def test_fp4_marlin_numerical_correctness(self): + """Kernel output vs BF16 dequant reference (cosine sim, MAE, assert_close).""" + N, K, M = 256, 256, 32 + marlin_output, 
ref_output = self._run_fp4_marlin_vs_reference(M, N, K) + + self.assertEqual(marlin_output.shape, ref_output.shape) + self.assertEqual(marlin_output.dtype, ref_output.dtype) + + cos_sim = torch.nn.functional.cosine_similarity( + marlin_output.float().flatten(), ref_output.float().flatten(), dim=0 + ) + self.assertGreater( + cos_sim.item(), + 0.99, + f"Cosine similarity {cos_sim.item():.6f} too low", + ) + + rel_mae = torch.mean( + torch.abs(marlin_output.float() - ref_output.float()) + ) / torch.mean(torch.abs(ref_output.float())) + self.assertLess( + rel_mae.item(), + 0.04, + f"Relative MAE {rel_mae.item():.6f} >= 0.04", + ) + + torch.testing.assert_close(marlin_output, ref_output, atol=1e-1, rtol=1e-1) + + def test_fp4_marlin_multiple_shapes(self): + """Numerical correctness across various (M, N, K) dimensions.""" + shapes = [ + (1, 256, 256), + (16, 512, 128), + (64, 128, 512), + (32, 256, 256), + ] + for M, N, K in shapes: + with self.subTest(M=M, N=N, K=K): + marlin_out, ref_out = self._run_fp4_marlin_vs_reference(M, N, K) + rel_mae = torch.mean( + torch.abs(marlin_out.float() - ref_out.float()) + ) / torch.mean(torch.abs(ref_out.float())) + self.assertLess( + rel_mae.item(), + 0.04, + f"Shape ({M},{N},{K}): relative MAE {rel_mae.item():.6f} >= 0.04", + ) + + def test_fp4_marlin_linear_with_bias(self): + """Verify output_with_bias == output_no_bias + bias.""" + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + apply_fp4_marlin_linear, + prepare_fp4_layer_for_marlin, + ) + + N, K, M = 256, 128, 16 + layer = self._make_fake_fp4_layer(N, K) + prepare_fp4_layer_for_marlin( + layer, + weight_attr="weight", + weight_scale_attr="weight_scale", + weight_global_scale_attr="weight_scale_2_marlin", + ) + + x = torch.randn(M, K, dtype=self.dtype, device=self.device) + bias = torch.randn(N, dtype=self.dtype, device=self.device) + + common = dict( + weight=layer.weight, + weight_scale=layer.weight_scale, + weight_global_scale=layer.weight_scale_2_marlin, + 
workspace=layer.marlin_workspace, + size_n=N, + size_k=K, + ) + output_no_bias = apply_fp4_marlin_linear(input=x, **common) + output_with_bias = apply_fp4_marlin_linear(input=x, bias=bias, **common) + + torch.testing.assert_close( + output_with_bias, output_no_bias + bias, atol=1e-5, rtol=1e-5 + ) + + def test_fp4_marlin_registered_op_numerical(self): + """torch.ops.sglang.apply_fp4_marlin_linear matches the direct Python call.""" + import sglang.srt.layers.quantization.marlin_utils_fp4 # noqa: F401 + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + apply_fp4_marlin_linear, + prepare_fp4_layer_for_marlin, + ) + + N, K, M = 256, 128, 16 + layer = self._make_fake_fp4_layer(N, K) + prepare_fp4_layer_for_marlin( + layer, + weight_attr="weight", + weight_scale_attr="weight_scale", + weight_global_scale_attr="weight_scale_2_marlin", + ) + + x = torch.randn(M, K, dtype=self.dtype, device=self.device) + + common = dict( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + weight_global_scale=layer.weight_scale_2_marlin, + workspace=layer.marlin_workspace, + size_n=N, + size_k=K, + ) + + direct_out = apply_fp4_marlin_linear(**common) + op_out = torch.ops.sglang.apply_fp4_marlin_linear(**common) + + self.assertEqual(op_out.shape, direct_out.shape) + self.assertEqual(op_out.dtype, direct_out.dtype) + torch.testing.assert_close(op_out, direct_out, atol=0, rtol=0) + + def test_fp4_marlin_3d_input(self): + """Verify correct reshape for 3-D input (batch, seq_len, K).""" + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + apply_fp4_marlin_linear, + prepare_fp4_layer_for_marlin, + ) + + N, K = 256, 128 + batch, seq_len = 2, 8 + layer = self._make_fake_fp4_layer(N, K) + prepare_fp4_layer_for_marlin( + layer, + weight_attr="weight", + weight_scale_attr="weight_scale", + weight_global_scale_attr="weight_scale_2_marlin", + ) + + x_3d = torch.randn(batch, seq_len, K, dtype=self.dtype, device=self.device) + x_2d = x_3d.reshape(-1, K) + + 
common = dict( + weight=layer.weight, + weight_scale=layer.weight_scale, + weight_global_scale=layer.weight_scale_2_marlin, + workspace=layer.marlin_workspace, + size_n=N, + size_k=K, + ) + + out_3d = apply_fp4_marlin_linear(input=x_3d, **common) + out_2d = apply_fp4_marlin_linear(input=x_2d, **common) + + self.assertEqual(out_3d.shape, (batch, seq_len, N)) + self.assertEqual(out_3d.dtype, self.dtype) + torch.testing.assert_close(out_3d.reshape(-1, N), out_2d, atol=0, rtol=0) + + def test_fake_apply_fp4_marlin_linear(self): + """Fake impl for PCG tracing must return the correct shape and dtype.""" + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + fake_apply_fp4_marlin_linear, + ) + + N, K = 256, 128 + + for input_shape in [(16, K), (2, 8, K)]: + with self.subTest(input_shape=input_shape): + x = torch.randn(*input_shape, dtype=self.dtype, device=self.device) + out = fake_apply_fp4_marlin_linear( + input=x, + weight=torch.empty(0, device=self.device), + weight_scale=torch.empty(0, device=self.device), + weight_global_scale=torch.empty(0, device=self.device), + workspace=torch.empty(0, device=self.device), + size_n=N, + size_k=K, + ) + expected_shape = input_shape[:-1] + (N,) + self.assertEqual(out.shape, expected_shape) + self.assertEqual(out.dtype, self.dtype) + + def test_prepare_rejects_bad_weight_shape(self): + """prepare_fp4_layer_for_marlin must raise on mismatched weight shape.""" + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + prepare_fp4_layer_for_marlin, + ) + + N, K = 256, 128 + layer = _FakeLayer() + layer.params_dtype = self.dtype + layer.input_size_per_partition = K + layer.output_size_per_partition = N + + layer.weight = torch.nn.Parameter( + torch.randint( + 0, 256, (N + 1, K // 2), dtype=torch.uint8, device=self.device + ), + requires_grad=False, + ) + layer.weight_scale = torch.nn.Parameter( + torch.ones( + N, + K // _FP4_MARLIN_GROUP_SIZE, + dtype=torch.float8_e4m3fn, + device=self.device, + ), + 
requires_grad=False, + ) + layer.weight_scale_2_marlin = torch.nn.Parameter( + torch.tensor(1.0, dtype=torch.float32, device=self.device), + requires_grad=False, + ) + + with self.assertRaises(AssertionError): + prepare_fp4_layer_for_marlin( + layer, + weight_attr="weight", + weight_scale_attr="weight_scale", + weight_global_scale_attr="weight_scale_2_marlin", + ) + + def test_prepare_fp4_layer_permutes_bias(self): + """prepare_fp4_layer_for_marlin must permute layer.bias when present.""" + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + prepare_fp4_layer_for_marlin, + ) + + N, K = 256, 128 + layer = self._make_fake_fp4_layer(N, K) + original_bias = torch.randn(N, dtype=self.dtype, device=self.device) + layer.bias = torch.nn.Parameter(original_bias.clone(), requires_grad=False) + + prepare_fp4_layer_for_marlin( + layer, + weight_attr="weight", + weight_scale_attr="weight_scale", + weight_global_scale_attr="weight_scale_2_marlin", + ) + + self.assertEqual(layer.bias.shape, (N,)) + self.assertEqual(layer.bias.dtype, self.dtype) + self.assertFalse( + torch.equal(layer.bias.data, original_bias), + "Bias should be permuted by prepare_fp4_layer_for_marlin", + ) + + def test_fp4_marlin_custom_op_registration(self): + """apply_fp4_marlin_linear must be registered as torch.ops.sglang for PCG.""" + import sglang.srt.layers.quantization.marlin_utils_fp4 # noqa: F401 + + self.assertTrue( + hasattr(torch.ops.sglang, "apply_fp4_marlin_linear"), + "apply_fp4_marlin_linear not registered as a custom op", + ) + + def test_nvfp4_marlin_scale_values_correctness(self): + """Verify scale conversion produces analytically correct values.""" + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + nvfp4_marlin_process_global_scale, + nvfp4_marlin_process_scales, + ) + + # -- global scale: BF16 -- + # fp4_exp=2, target_exp=8 => bias = 2^7 - 2^1 = 126 + # result = 1.0 * 2^(126-7) = 2^119 + gs_bf16 = torch.tensor(1.0, dtype=torch.bfloat16, device=self.device) + 
result_bf16 = nvfp4_marlin_process_global_scale(gs_bf16) + expected_bf16 = torch.tensor(2.0**119, dtype=torch.bfloat16, device=self.device) + self.assertEqual( + result_bf16.item(), + expected_bf16.item(), + f"BF16 global_scale(1.0): expected 2^119, got {result_bf16.item()}", + ) + self.assertEqual(result_bf16.dtype, torch.bfloat16) + + # -- global scale: FP16 -- + # fp4_exp=2, target_exp=5 => bias = 2^4 - 2^1 = 14 + # result = 1.0 * 2^(14-7) = 128 + gs_fp16 = torch.tensor(1.0, dtype=torch.float16, device=self.device) + result_fp16 = nvfp4_marlin_process_global_scale(gs_fp16) + self.assertEqual( + result_fp16.item(), + 128.0, + f"FP16 global_scale(1.0): expected 128.0, got {result_fp16.item()}", + ) + self.assertEqual(result_fp16.dtype, torch.float16) + + # -- global scale: linearity -- + gs_2 = torch.tensor(2.0, dtype=torch.bfloat16, device=self.device) + result_2 = nvfp4_marlin_process_global_scale(gs_2) + self.assertAlmostEqual( + result_2.item(), + 2.0 * result_bf16.item(), + places=0, + msg="Global scale processing should be linear", + ) + + # -- per-group scales: structural properties -- + N, K_div_group = 64, 16 + raw_scale = torch.ones( + N, K_div_group, dtype=torch.float8_e4m3fn, device=self.device + ).to(self.dtype) + processed = nvfp4_marlin_process_scales(raw_scale) + + self.assertEqual(processed.dtype, torch.float8_e4m3fn) + self.assertEqual(processed.shape, (N, K_div_group)) + self.assertFalse(torch.isnan(processed.to(self.dtype)).any()) + + # Deterministic + self.assertTrue(torch.equal(processed, nvfp4_marlin_process_scales(raw_scale))) + + # Large scales (448 = FP8 E4M3 max) must not produce NaN + large_scale = torch.full( + (N, K_div_group), 448.0, dtype=self.dtype, device=self.device + ) + proc_large = nvfp4_marlin_process_scales(large_scale) + self.assertFalse(torch.isnan(proc_large.to(self.dtype)).any()) + self.assertEqual(proc_large.shape, (N, K_div_group)) + + +# --------------------------------------------------------------------------- +# 
MoE tests +# --------------------------------------------------------------------------- +class TestNvfp4MarlinMoe(CustomTestCase): + """Test the FP4 Marlin MoE fallback.""" + + def setUp(self): + if not _check_requirements(): + self.skipTest("Requirements not met (CUDA unavailable or SM < 75)") + self.device = torch.device("cuda") + self.dtype = torch.bfloat16 + try: + from sglang.jit_kernel.gptq_marlin_repack import gptq_marlin_repack + + self._gptq_marlin_repack = gptq_marlin_repack + except ImportError: + self.skipTest("gptq_marlin_repack JIT compilation not available") + self._perm = torch.empty(0, dtype=torch.int, device=self.device) + + # -- helpers ------------------------------------------------------------- + + def _repack_fp4_weight(self, raw_fp4, size_k, size_n): + """Repack raw uint8 FP4 weights into Marlin tile layout.""" + qw = raw_fp4.view(torch.int32).T.contiguous() + return self._gptq_marlin_repack(qw, self._perm, size_k, size_n, num_bits=4) + + def _make_marlin_scale(self, size_k, size_n): + from sglang.srt.layers.quantization.marlin_utils import marlin_permute_scales + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + nvfp4_marlin_process_scales, + ) + + raw = torch.ones( + size_k // _FP4_MARLIN_GROUP_SIZE, + size_n, + dtype=self.dtype, + device=self.device, + ) + permuted = marlin_permute_scales(raw, size_k, size_n, _FP4_MARLIN_GROUP_SIZE) + return nvfp4_marlin_process_scales(permuted) + + def _make_processed_global_scale(self): + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + nvfp4_marlin_process_global_scale, + ) + + return nvfp4_marlin_process_global_scale( + torch.tensor(1.0, dtype=self.dtype, device=self.device) + ) + + # -- tests --------------------------------------------------------------- + + def test_fused_marlin_moe_fp4(self): + """Smoke test: shape, dtype, no NaN for multi-expert MoE.""" + from sglang.srt.layers.moe.fused_moe_triton.fused_marlin_moe import ( + fused_marlin_moe, + ) + + E, K, N, topk, 
M = 4, 128, 64, 2, 8 + + def _rand_weight(size_k, size_n): + raw = torch.randint( + 0, 256, (size_n, size_k // 2), dtype=torch.uint8, device=self.device + ) + return self._repack_fp4_weight(raw, size_k, size_n) + + w1 = torch.stack([_rand_weight(K, 2 * N) for _ in range(E)]) + w2 = torch.stack([_rand_weight(N, K) for _ in range(E)]) + w1_scale = torch.stack([self._make_marlin_scale(K, 2 * N) for _ in range(E)]) + w2_scale = torch.stack([self._make_marlin_scale(N, K) for _ in range(E)]) + + gs = self._make_processed_global_scale() + w1_gs = gs.expand(E) + w2_gs = gs.expand(E) + + hidden = torch.randn(M, K, dtype=self.dtype, device=self.device) + gating = torch.randn(M, E, dtype=self.dtype, device=self.device) + topk_weights, topk_ids = torch.topk(torch.softmax(gating, dim=-1), topk, dim=-1) + + output = fused_marlin_moe( + hidden_states=hidden, + w1=w1, + w2=w2, + w1_scale=w1_scale, + w2_scale=w2_scale, + gating_output=gating, + topk_weights=topk_weights, + topk_ids=topk_ids, + num_bits=4, + w1_global_scale=w1_gs, + w2_global_scale=w2_gs, + ) + + self.assertEqual(output.shape, (M, K)) + self.assertEqual(output.dtype, self.dtype) + self.assertFalse(torch.isnan(output).any(), "Output contains NaN!") + + def test_fused_marlin_moe_fp4_numerical(self): + """E=1, topk=1 MoE output vs dequant reference (SiLU-gated).""" + from sglang.srt.layers.moe.fused_moe_triton.fused_marlin_moe import ( + fused_marlin_moe, + ) + + E, K, N, topk, M = 1, 128, 64, 1, 8 + + raw_w1 = torch.randint( + 0, 256, (2 * N, K // 2), dtype=torch.uint8, device=self.device + ) + raw_w2 = torch.randint( + 0, 256, (K, N // 2), dtype=torch.uint8, device=self.device + ) + dq_w1 = _dequant_fp4_weights(raw_w1, self.device) + dq_w2 = _dequant_fp4_weights(raw_w2, self.device) + + w1 = self._repack_fp4_weight(raw_w1, K, 2 * N).unsqueeze(0) + w2 = self._repack_fp4_weight(raw_w2, N, K).unsqueeze(0) + w1_scale = self._make_marlin_scale(K, 2 * N).unsqueeze(0) + w2_scale = self._make_marlin_scale(N, K).unsqueeze(0) 
+ + gs = self._make_processed_global_scale() + w1_gs = gs.unsqueeze(0) + w2_gs = gs.unsqueeze(0) + + x = torch.randn(M, K, dtype=self.dtype, device=self.device) * 0.1 + gating = torch.ones(M, E, dtype=self.dtype, device=self.device) + topk_weights = torch.ones(M, topk, dtype=self.dtype, device=self.device) + topk_ids = torch.zeros(M, topk, dtype=torch.int64, device=self.device) + + output = fused_marlin_moe( + hidden_states=x, + w1=w1, + w2=w2, + w1_scale=w1_scale, + w2_scale=w2_scale, + gating_output=gating, + topk_weights=topk_weights, + topk_ids=topk_ids, + num_bits=4, + w1_global_scale=w1_gs, + w2_global_scale=w2_gs, + ) + + gate_up = x.float() @ dq_w1.T + gate, up = gate_up[:, :N], gate_up[:, N:] + ref_output = ((torch.nn.functional.silu(gate) * up) @ dq_w2.T).to(self.dtype) + + self.assertEqual(output.shape, ref_output.shape) + self.assertFalse(torch.isinf(output).any(), "MoE output contains Inf") + self.assertFalse(torch.isnan(output).any(), "MoE output contains NaN") + + finite = torch.isfinite(ref_output) & torch.isfinite(output) + if finite.any(): + cos_sim = torch.nn.functional.cosine_similarity( + output[finite].float().flatten(), + ref_output[finite].float().flatten(), + dim=0, + ) + self.assertGreater( + cos_sim.item(), + 0.90, + f"MoE cosine similarity {cos_sim.item():.4f} too low", + ) + + def test_prepare_moe_fp4_layer_for_marlin(self): + """Weight repacking produces correct shapes for all expert tensors.""" + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + prepare_moe_fp4_layer_for_marlin, + ) + + E, K, N = 4, 128, 64 + + class _FakeMoeRunnerConfig: + is_gated = True + + layer = _FakeLayer() + layer.num_local_experts = E + layer.intermediate_size_per_partition = N + layer.params_dtype = self.dtype + layer.moe_runner_config = _FakeMoeRunnerConfig() + + layer.w13_weight = torch.nn.Parameter( + torch.randint( + 0, 256, (E, 2 * N, K // 2), dtype=torch.uint8, device=self.device + ), + requires_grad=False, + ) + layer.w2_weight = 
torch.nn.Parameter( + torch.randint( + 0, 256, (E, K, N // 2), dtype=torch.uint8, device=self.device + ), + requires_grad=False, + ) + layer.w13_weight_scale = torch.nn.Parameter( + torch.ones( + E, + 2 * N, + K // _FP4_MARLIN_GROUP_SIZE, + dtype=torch.float8_e4m3fn, + device=self.device, + ), + requires_grad=False, + ) + layer.w2_weight_scale = torch.nn.Parameter( + torch.ones( + E, + K, + N // _FP4_MARLIN_GROUP_SIZE, + dtype=torch.float8_e4m3fn, + device=self.device, + ), + requires_grad=False, + ) + layer.w13_weight_scale_2 = torch.nn.Parameter( + torch.ones(E, 2, dtype=torch.float32, device=self.device), + requires_grad=False, + ) + layer.w2_weight_scale_2 = torch.nn.Parameter( + torch.ones(E, dtype=torch.float32, device=self.device), + requires_grad=False, + ) + + prepare_moe_fp4_layer_for_marlin(layer) + + self.assertEqual(layer.w13_weight.shape[0], E) + self.assertEqual(layer.w2_weight.shape[0], E) + self.assertEqual(layer.w13_weight_scale_2.shape, (E,)) + self.assertEqual(layer.w2_weight_scale_2.shape, (E,)) + + +# --------------------------------------------------------------------------- +# Support / capability tests +# --------------------------------------------------------------------------- +class TestFp4MarlinSupport(CustomTestCase): + """Test the capability detection functions.""" + + def test_is_fp4_marlin_supported(self): + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + is_fp4_marlin_supported, + ) + + result = is_fp4_marlin_supported() + if torch.cuda.is_available() and torch.version.hip is None: + cap = torch.cuda.get_device_capability() + sm = cap[0] * 10 + cap[1] + expected = sm >= 75 + self.assertEqual(result, expected) + elif torch.version.hip is not None: + self.assertFalse(result, "FP4 Marlin should not be supported on ROCm/HIP") + + def test_min_capability_changed(self): + """get_min_capability() must return 75 (not 100).""" + from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config + + cap = 
ModelOptFp4Config.get_min_capability() + self.assertEqual(cap, 75, f"Expected 75, got {cap}") + + def test_should_use_fp4_marlin_fallback(self): + """should_use_fp4_marlin_fallback returns True on non-Blackwell SM>=75.""" + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + should_use_fp4_marlin_fallback, + ) + + result = should_use_fp4_marlin_fallback() + self.assertIsInstance(result, bool) + + if torch.cuda.is_available() and torch.version.hip is None: + cap = torch.cuda.get_device_capability() + sm = cap[0] * 10 + cap[1] + is_blackwell = sm >= 100 + if is_blackwell: + self.assertFalse( + result, + "Blackwell GPUs should NOT use Marlin fallback (native FP4)", + ) + elif sm >= 75: + self.assertTrue( + result, + f"SM{sm} should use Marlin fallback, but got False", + ) + else: + self.assertFalse( + result, + f"SM{sm} should not support FP4 Marlin at all", + ) + + +if __name__ == "__main__": + unittest.main(verbosity=3) diff --git a/test/registered/quant/test_quantization.py b/test/registered/quant/test_quantization.py index cdb1f0970ef4..cdf0b0e619bb 100644 --- a/test/registered/quant/test_quantization.py +++ b/test/registered/quant/test_quantization.py @@ -19,9 +19,12 @@ register_cuda_ci(est_time=370, suite="stage-b-test-1-gpu-large") MODEL_SCORE_THRESHOLDS = { - "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4": 0.825, - "hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4": 0.825, - "hugging-quants/Mixtral-8x7B-Instruct-v0.1-AWQ-INT4": 0.615, + # Baselines observed with gsm8k 5-shot concatenated format via chat API, + # which scores lower than reported benchmarks using proper CoT format. + # Thresholds set 5% below observed to catch catastrophic regressions. 
+ "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4": 0.74, # observed: 0.781 + "hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4": 0.74, # observed: 0.785 + "hugging-quants/Mixtral-8x7B-Instruct-v0.1-AWQ-INT4": 0.36, # observed: 0.380 } @@ -93,7 +96,7 @@ def setUpClass(cls): ] cls.base_url = DEFAULT_URL_FOR_TEST - def test_mgsm_en_all_models(self): + def test_gsm8k_all_models(self): warnings.filterwarnings( "ignore", category=ResourceWarning, message="unclosed.*socket" ) @@ -110,7 +113,7 @@ def test_mgsm_en_all_models(self): args = SimpleNamespace( base_url=self.base_url, model=model, - eval_name="mgsm_en", + eval_name="gsm8k", num_examples=None, num_threads=1024, ) diff --git a/test/registered/sampling/test_fused_temperature_softmax.py b/test/registered/sampling/test_fused_temperature_softmax.py new file mode 100644 index 000000000000..956d67c2cefe --- /dev/null +++ b/test/registered/sampling/test_fused_temperature_softmax.py @@ -0,0 +1,268 @@ +"""Correctness tests for fused_temperature_softmax Triton kernel.""" + +import unittest + +import torch +from flashinfer.sampling import softmax as flashinfer_softmax + +from sglang.srt.layers.fused_sampling import ( + fused_temperature_softmax, + fused_temperature_softmax_inplace, +) +from sglang.srt.utils import get_device +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.test_utils import CustomTestCase + +register_cuda_ci(est_time=15, suite="stage-b-test-small-1-gpu") + + +def reference_temperature_softmax(logits, temperatures): + """Reference implementation: div + softmax (separate kernels).""" + logits = logits.clone() + logits.div_(temperatures) + return torch.softmax(logits, dim=-1).float() + + +class TestFusedTemperatureSoftmax(CustomTestCase): + @classmethod + def setUpClass(cls): + torch.set_default_device(get_device()) + torch.manual_seed(42) + + def _check_close(self, fused, ref, atol=1e-5, rtol=1e-5): + """Assert outputs are close and both are valid probability distributions.""" 
+ self.assertEqual(fused.shape, ref.shape) + # Valid probabilities: non-negative, sum to ~1 + self.assertTrue((fused >= 0).all(), f"Negative probabilities in fused output") + row_sums = fused.sum(dim=-1) + torch.testing.assert_close( + row_sums, + torch.ones_like(row_sums), + atol=1e-4, + rtol=1e-4, + ) + torch.testing.assert_close(fused, ref, atol=atol, rtol=rtol) + + # --- out-of-place kernel --- + + def test_basic(self): + logits = torch.randn(4, 1024, dtype=torch.bfloat16) + temps = torch.tensor([0.7, 1.0, 1.5, 2.0], dtype=torch.float32).view(-1, 1) + ref = reference_temperature_softmax(logits, temps) + fused = fused_temperature_softmax(logits, temps) + self._check_close(fused, ref, atol=1e-4, rtol=1e-3) + + def test_large_vocab(self): + logits = torch.randn(8, 128256, dtype=torch.bfloat16) + temps = torch.full((8, 1), 0.6, dtype=torch.float32) + ref = reference_temperature_softmax(logits, temps) + fused = fused_temperature_softmax(logits, temps) + self._check_close(fused, ref, atol=1e-4, rtol=1e-3) + + def test_batch_sizes(self): + for bs in [1, 2, 16, 64, 128, 512]: + logits = torch.randn(bs, 32000, dtype=torch.bfloat16) + temps = torch.rand(bs, 1, dtype=torch.float32) * 1.5 + 0.1 + ref = reference_temperature_softmax(logits, temps) + fused = fused_temperature_softmax(logits, temps) + self._check_close(fused, ref, atol=1e-4, rtol=1e-3) + + def test_temperature_one(self): + """Temperature=1.0 should be equivalent to plain softmax.""" + logits = torch.randn(16, 32000, dtype=torch.bfloat16) + temps = torch.ones(16, 1, dtype=torch.float32) + ref = torch.softmax(logits.float(), dim=-1) + fused = fused_temperature_softmax(logits, temps) + self._check_close(fused, ref, atol=1e-4, rtol=1e-3) + + def test_very_low_temperature(self): + """Very low temperature should produce near-one-hot distribution.""" + logits = torch.randn(4, 1024, dtype=torch.bfloat16) + temps = torch.full((4, 1), 0.01, dtype=torch.float32) + fused = fused_temperature_softmax(logits, temps) + # Max 
probability should be very close to 1.0 + max_probs = fused.max(dim=-1).values + self.assertTrue((max_probs > 0.99).all()) + + def test_very_high_temperature(self): + """Very high temperature should produce near-uniform distribution.""" + logits = torch.randn(4, 1024, dtype=torch.bfloat16) + temps = torch.full((4, 1), 100.0, dtype=torch.float32) + fused = fused_temperature_softmax(logits, temps) + uniform = 1.0 / 1024 + self.assertTrue( + (fused - uniform).abs().max() < 0.01, + "High temperature should produce near-uniform distribution", + ) + + def test_fp16_input(self): + logits = torch.randn(8, 32000, dtype=torch.float16) + temps = torch.rand(8, 1, dtype=torch.float32) * 1.5 + 0.1 + ref = reference_temperature_softmax(logits, temps) + fused = fused_temperature_softmax(logits, temps) + self._check_close(fused, ref, atol=1e-3, rtol=1e-2) + + def test_fp32_input(self): + logits = torch.randn(8, 32000, dtype=torch.float32) + temps = torch.rand(8, 1, dtype=torch.float32) + 0.5 + ref = reference_temperature_softmax(logits, temps) + fused = fused_temperature_softmax(logits, temps) + self._check_close(fused, ref, atol=1e-5, rtol=1e-5) + + def test_mixed_temperatures(self): + """Each row has a different temperature.""" + logits = torch.randn(8, 32000, dtype=torch.bfloat16) + temps = torch.tensor( + [0.1, 0.5, 0.7, 1.0, 1.2, 1.5, 2.0, 5.0], dtype=torch.float32 + ).view(-1, 1) + ref = reference_temperature_softmax(logits, temps) + fused = fused_temperature_softmax(logits, temps) + self._check_close(fused, ref, atol=1e-4, rtol=1e-3) + + def test_empty_batch(self): + logits = torch.randn(0, 32000, dtype=torch.bfloat16) + temps = torch.ones(0, 1, dtype=torch.float32) + fused = fused_temperature_softmax(logits, temps) + self.assertEqual(fused.shape, (0, 32000)) + + # --- in-place kernel --- + + def test_inplace_basic(self): + logits = torch.randn(8, 32000, dtype=torch.float32) + temps = torch.rand(8, 1, dtype=torch.float32) * 1.5 + 0.1 + ref = 
reference_temperature_softmax(logits, temps) + fused_temperature_softmax_inplace(logits, temps) + # In-place writes back to logits in the original dtype + self._check_close(logits.float(), ref, atol=1e-5, rtol=1e-5) + + def test_inplace_bf16(self): + logits = torch.randn(8, 32000, dtype=torch.bfloat16) + temps = torch.rand(8, 1, dtype=torch.float32) + 0.5 + ref = reference_temperature_softmax(logits, temps) + fused_temperature_softmax_inplace(logits, temps) + self._check_close(logits.float(), ref, atol=2e-3, rtol=2e-3) + + def test_inplace_large_vocab(self): + logits = torch.randn(4, 128256, dtype=torch.bfloat16) + temps = torch.full((4, 1), 0.8, dtype=torch.float32) + ref = reference_temperature_softmax(logits, temps) + fused_temperature_softmax_inplace(logits, temps) + self._check_close(logits.float(), ref, atol=2e-3, rtol=2e-3) + + # --- exact known-value correctness --- + + def test_known_uniform_logits(self): + """Identical logits must produce uniform distribution regardless of temperature.""" + logits = torch.zeros(2, 5, dtype=torch.float32) + temps = torch.tensor([0.5, 2.0], dtype=torch.float32).view(-1, 1) + fused = fused_temperature_softmax(logits, temps) + expected = torch.full((2, 5), 0.2, dtype=torch.float32, device=fused.device) + torch.testing.assert_close(fused, expected, atol=1e-6, rtol=1e-6) + + def test_known_softmax_values(self): + """Verify against hand-computed softmax(logits / T).""" + logits = torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float32) + temps = torch.tensor([[1.0]], dtype=torch.float32) + fused = fused_temperature_softmax(logits, temps) + # softmax([1,2,3]) = exp([1,2,3]) / sum(exp([1,2,3])) + e = torch.exp(logits) + expected = (e / e.sum(dim=-1, keepdim=True)).to(fused.device) + torch.testing.assert_close(fused, expected, atol=1e-6, rtol=1e-6) + + def test_known_softmax_with_temperature(self): + """Verify softmax([1,2,3] / 0.5) against hand computation.""" + logits = torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float32) + temps = 
torch.tensor([[0.5]], dtype=torch.float32) + fused = fused_temperature_softmax(logits, temps) + scaled = logits / 0.5 + e = torch.exp(scaled) + expected = (e / e.sum(dim=-1, keepdim=True)).to(fused.device) + torch.testing.assert_close(fused, expected, atol=1e-6, rtol=1e-6) + + # --- argmax preservation --- + + def test_argmax_preserved(self): + """argmax must be invariant to temperature for finite T > 0.""" + logits = torch.randn(64, 32000, dtype=torch.bfloat16) + original_argmax = logits.float().argmax(dim=-1) + for t_val in [0.1, 0.5, 1.0, 2.0, 10.0]: + temps = torch.full((64, 1), t_val, dtype=torch.float32) + fused = fused_temperature_softmax(logits, temps) + fused_argmax = fused.argmax(dim=-1) + self.assertTrue( + (original_argmax == fused_argmax).all(), + f"argmax changed at temperature={t_val}", + ) + + # --- numerical stability --- + + def test_large_logits_no_nan(self): + """Extreme logit magnitudes must not produce NaN or Inf.""" + logits = torch.tensor( + [[1e6, -1e6, 0.0], [1e4, 1e4 + 1, 1e4 - 1]], dtype=torch.float32 + ) + temps = torch.tensor([[1.0], [0.01]], dtype=torch.float32) + fused = fused_temperature_softmax(logits, temps) + self.assertFalse(torch.isnan(fused).any(), "NaN in output") + self.assertFalse(torch.isinf(fused).any(), "Inf in output") + row_sums = fused.sum(dim=-1) + torch.testing.assert_close( + row_sums, + torch.ones_like(row_sums), + atol=1e-4, + rtol=1e-4, + ) + + def test_large_logits_inplace_no_nan(self): + """In-place variant: extreme logits must not produce NaN or Inf.""" + logits = torch.tensor( + [[1e6, -1e6, 0.0], [1e4, 1e4 + 1, 1e4 - 1]], dtype=torch.float32 + ) + temps = torch.tensor([[1.0], [0.01]], dtype=torch.float32) + fused_temperature_softmax_inplace(logits, temps) + self.assertFalse(torch.isnan(logits).any(), "NaN in output") + self.assertFalse(torch.isinf(logits).any(), "Inf in output") + + # --- comparison with flashinfer.sampling.softmax --- + + def test_vs_flashinfer_basic(self): + logits = torch.randn(4, 1024, 
dtype=torch.bfloat16) + temps = torch.tensor([0.7, 1.0, 1.5, 2.0], dtype=torch.float32).view(-1, 1) + fused = fused_temperature_softmax(logits, temps) + fi = flashinfer_softmax(logits, temperature=temps.view(-1)) + self._check_close(fused, fi, atol=1e-4, rtol=1e-3) + + def test_vs_flashinfer_large_vocab(self): + logits = torch.randn(8, 128256, dtype=torch.bfloat16) + temps = torch.full((8, 1), 0.6, dtype=torch.float32) + fused = fused_temperature_softmax(logits, temps) + fi = flashinfer_softmax(logits, temperature=temps.view(-1)) + self._check_close(fused, fi, atol=1e-4, rtol=1e-3) + + def test_vs_flashinfer_batch_sizes(self): + for bs in [1, 16, 64, 128, 512]: + logits = torch.randn(bs, 32000, dtype=torch.bfloat16) + temps = torch.rand(bs, 1, dtype=torch.float32) * 1.5 + 0.1 + fused = fused_temperature_softmax(logits, temps) + fi = flashinfer_softmax(logits, temperature=temps.view(-1)) + self._check_close(fused, fi, atol=1e-4, rtol=1e-3) + + def test_vs_flashinfer_scalar_temperature(self): + logits = torch.randn(16, 32000, dtype=torch.bfloat16) + temps_2d = torch.full((16, 1), 0.8, dtype=torch.float32) + fused = fused_temperature_softmax(logits, temps_2d) + fi = flashinfer_softmax(logits, temperature=0.8) + self._check_close(fused, fi, atol=1e-4, rtol=1e-3) + + def test_vs_flashinfer_mixed_temperatures(self): + logits = torch.randn(8, 32000, dtype=torch.bfloat16) + temps = torch.tensor( + [0.1, 0.5, 0.7, 1.0, 1.2, 1.5, 2.0, 5.0], dtype=torch.float32 + ).view(-1, 1) + fused = fused_temperature_softmax(logits, temps) + fi = flashinfer_softmax(logits, temperature=temps.view(-1)) + self._check_close(fused, fi, atol=1e-4, rtol=1e-3) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/scheduler/test_prefill_delayer.py b/test/registered/scheduler/test_prefill_delayer.py index 66ea497bf3c0..493346fda930 100644 --- a/test/registered/scheduler/test_prefill_delayer.py +++ b/test/registered/scheduler/test_prefill_delayer.py @@ -428,10 +428,10 @@ 
async def send_normal_request(dp_rank, req_idx): class TestPrefillDelayerAccuracy(CustomTestCase): - def test_1_mgsm_en_has_prefill_delayer(self): + def test_1_gsm8k_has_prefill_delayer(self): self._run_accuracy_test(prefill_delayer=True) - def test_2_mgsm_en_no_prefill_delayer(self): + def test_2_gsm8k_no_prefill_delayer(self): self._run_accuracy_test(prefill_delayer=False) def _run_accuracy_test(self, prefill_delayer: bool): @@ -454,14 +454,14 @@ def _run_accuracy_test(self, prefill_delayer: bool): args = SimpleNamespace( base_url=base_url, model=model, - eval_name="mgsm_en", + eval_name="gsm8k", num_examples=None, num_threads=1024, ) metrics = run_eval(args) - print(f"=== mgsm_en ({prefill_delayer=}) ===") + print(f"=== gsm8k ({prefill_delayer=}) ===") print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.87) + self.assertGreater(metrics["score"], 0.57) finally: kill_process_tree(process.pid) diff --git a/test/registered/spec/dflash/test_dflash.py b/test/registered/spec/dflash/test_dflash.py new file mode 100644 index 000000000000..aa9ee2327d21 --- /dev/null +++ b/test/registered/spec/dflash/test_dflash.py @@ -0,0 +1,152 @@ +import os +import unittest + +import openai + +from sglang.srt.environ import envs +from sglang.srt.utils import kill_process_tree +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.kits.eval_accuracy_kit import GSM8KMixin +from sglang.test.kits.matched_stop_kit import MatchedStopMixin +from sglang.test.kits.radix_cache_server_kit import gen_radix_tree +from sglang.test.test_utils import ( + DEFAULT_DRAFT_MODEL_DFLASH, + DEFAULT_TARGET_MODEL_DFLASH, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + +register_cuda_ci(est_time=300, suite="stage-b-test-1-gpu-small") + + +class TestDFlashServerBase(CustomTestCase, MatchedStopMixin, GSM8KMixin): + max_running_requests = 64 + attention_backend = "flashinfer" + page_size = 1 + other_launch_args = [] + model = 
DEFAULT_TARGET_MODEL_DFLASH + draft_model = DEFAULT_DRAFT_MODEL_DFLASH + gsm8k_accuracy_thres = 0.75 + gsm8k_accept_length_thres = 2.8 + + @classmethod + def setUpClass(cls): + cls.base_url = DEFAULT_URL_FOR_TEST + launch_args = [ + "--trust-remote-code", + "--attention-backend", + cls.attention_backend, + "--speculative-algorithm", + "DFLASH", + "--speculative-draft-model-path", + cls.draft_model, + "--page-size", + str(cls.page_size), + "--max-running-requests", + str(cls.max_running_requests), + "--cuda-graph-bs", + *[str(i) for i in range(1, cls.max_running_requests + 1)], + ] + launch_args.extend(cls.other_launch_args) + old_value = os.environ.get("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN") + os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1" + try: + with envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY.override( + 1 + ), envs.SGLANG_SPEC_NAN_DETECTION.override( + True + ), envs.SGLANG_SPEC_OOB_DETECTION.override( + True + ): + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=launch_args, + ) + finally: + if old_value is None: + del os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] + else: + os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = old_value + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_early_stop(self): + client = openai.Client(base_url=self.base_url + "/v1", api_key="EMPTY") + for i in range(8): + max_tokens = (i % 3) + 1 + response = client.completions.create( + model=self.model, + prompt=f"There are {i} apples on the table. 
How to divide them equally?", + max_tokens=max_tokens, + temperature=0, + ) + text = response.choices[0].text + print(f"early_stop: max_tokens={max_tokens}, text={text!r}") + assert self.process.poll() is None + + def test_eos_handling(self): + client = openai.Client(base_url=self.base_url + "/v1", api_key="EMPTY") + response = client.chat.completions.create( + model=self.model, + messages=[{"role": "user", "content": "Today is a sunny day and I like"}], + max_tokens=256, + temperature=0.1, + ) + text = response.choices[0].message.content + print(f"eos_handling: text={text!r}") + self.assertNotIn("<|eot_id|>", text) + self.assertNotIn("<|end_of_text|>", text) + assert self.process.poll() is None + + def test_greedy_determinism(self): + client = openai.Client(base_url=self.base_url + "/v1", api_key="EMPTY") + prompt = "The capital of France is" + outputs = [] + for _ in range(2): + response = client.completions.create( + model=self.model, + prompt=prompt, + max_tokens=32, + temperature=0, + ) + outputs.append(response.choices[0].text) + print(f"determinism: {outputs=}") + self.assertEqual(outputs[0], outputs[1]) + assert self.process.poll() is None + + +class TestDFlashServerPage256(TestDFlashServerBase): + page_size = 256 + + def test_radix_attention(self): + import requests + + nodes = gen_radix_tree(num_nodes=50) + data = { + "input_ids": [node["input_ids"] for node in nodes], + "sampling_params": [ + {"max_new_tokens": node["decode_len"], "temperature": 0} + for node in nodes + ], + } + res = requests.post(self.base_url + "/generate", json=data) + assert res.status_code == 200 + assert self.process.poll() is None + + +class TestDFlashServerChunkedPrefill(TestDFlashServerBase): + other_launch_args = ["--chunked-prefill-size", "4"] + + +class TestDFlashServerNoCudaGraph(TestDFlashServerBase): + other_launch_args = ["--disable-cuda-graph"] + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/spec/test_ngram_speculative_decoding.py 
b/test/registered/spec/test_ngram_speculative_decoding.py index f80b1e646dea..d8e0c467b6b4 100644 --- a/test/registered/spec/test_ngram_speculative_decoding.py +++ b/test/registered/spec/test_ngram_speculative_decoding.py @@ -111,7 +111,7 @@ def generate_batch(): return outputs def get_accept_length(): - info = requests.get(self.base_url + "/get_server_info").json() + info = requests.get(self.base_url + "/server_info").json() return info["internal_states"][0]["avg_spec_accept_length"] # Phase 1: baseline — no SAM corpus loaded, only trie diff --git a/test/registered/spec/utils/test_ngram_corpus.py b/test/registered/spec/utils/test_ngram_corpus.py new file mode 100644 index 000000000000..e8d9fc026beb --- /dev/null +++ b/test/registered/spec/utils/test_ngram_corpus.py @@ -0,0 +1,578 @@ +import unittest + +import numpy as np + +from sglang.srt.speculative.cpp_ngram.ngram_corpus import NgramCorpus +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.test_utils import CustomTestCase + +register_cuda_ci(est_time=8, suite="stage-b-test-1-gpu-small") + + +def _make_corpus(match_type="BFS", **kwargs): + defaults = dict( + max_trie_depth=12, + min_bfs_breadth=1, + max_bfs_breadth=8, + draft_token_num=8, + capacity=100000, + ) + defaults.update(kwargs) + defaults["match_type"] = match_type + return NgramCorpus(**defaults) + + +SEED_SEQUENCES = [ + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + [1, 2, 3, 44, 55, 66, 77, 88, 99, 100], +] + +QUERY_SEQUENCES = [[1, 2, 3], [3, 44], [3, 6, 999]] + +EXPECTED_BFS_IDS = [ + [3, 4, 44, 5, 55, 6, 66, 77], + [44, 55, 66, 77, 88, 99, 100, 0], + [999, 0, 0, 0, 0, 0, 0, 0], +] + +EXPECTED_PROB_IDS = [ + [3, 44, 4, 55, 5, 66, 6, 7], + [44, 55, 66, 77, 88, 99, 100, 0], + [999, 0, 0, 0, 0, 0, 0, 0], +] + +EXPECTED_BFS_MASKS = [ + [ + [1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0], + [1, 0, 1, 0, 0, 0, 0, 0], + [1, 1, 0, 1, 0, 0, 0, 0], + [1, 0, 1, 0, 1, 0, 0, 0], + [1, 1, 0, 1, 0, 1, 0, 0], + [1, 0, 1, 0, 1, 0, 1, 0], + [1, 0, 
1, 0, 1, 0, 1, 1], + ], + [ + [1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 0], + [1, 0, 0, 0, 0, 0, 0, 1], + ], + [ + [1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0], + [1, 0, 1, 0, 0, 0, 0, 0], + [1, 0, 0, 1, 0, 0, 0, 0], + [1, 0, 0, 0, 1, 0, 0, 0], + [1, 0, 0, 0, 0, 1, 0, 0], + [1, 0, 0, 0, 0, 0, 1, 0], + [1, 0, 0, 0, 0, 0, 0, 1], + ], +] + +EXPECTED_PROB_MASKS = [ + [ + [1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0], + [1, 0, 1, 0, 0, 0, 0, 0], + [1, 1, 0, 1, 0, 0, 0, 0], + [1, 0, 1, 0, 1, 0, 0, 0], + [1, 1, 0, 1, 0, 1, 0, 0], + [1, 0, 1, 0, 1, 0, 1, 0], + [1, 0, 1, 0, 1, 0, 1, 1], + ], + [ + [1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 0], + [1, 0, 0, 0, 0, 0, 0, 1], + ], + [ + [1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0], + [1, 0, 1, 0, 0, 0, 0, 0], + [1, 0, 0, 1, 0, 0, 0, 0], + [1, 0, 0, 0, 1, 0, 0, 0], + [1, 0, 0, 0, 0, 1, 0, 0], + [1, 0, 0, 0, 0, 0, 1, 0], + [1, 0, 0, 0, 0, 0, 0, 1], + ], +] + + +class TestNgramCorpusBFS(CustomTestCase): + """Golden-output tests for BFS matching mode.""" + + @classmethod + def setUpClass(cls): + cls.corpus = _make_corpus("BFS") + cls.corpus.batch_put(SEED_SEQUENCES) + cls.corpus.synchronize() + ids, masks = cls.corpus.batch_get(QUERY_SEQUENCES) + draft = 8 + cls.ids = ids.reshape(-1, draft) + cls.masks = masks.reshape(-1, draft, draft) + + def test_token_ids(self): + np.testing.assert_array_equal(self.ids.tolist(), EXPECTED_BFS_IDS) + + def test_masks(self): + np.testing.assert_array_equal(self.masks.tolist(), EXPECTED_BFS_MASKS) + + def test_output_shapes(self): + n_queries = len(QUERY_SEQUENCES) + draft = 8 + self.assertEqual(self.ids.shape, (n_queries, draft)) + self.assertEqual(self.masks.shape, (n_queries, draft, 
draft)) + + +class TestNgramCorpusProb(CustomTestCase): + """Golden-output tests for Prob matching mode.""" + + @classmethod + def setUpClass(cls): + cls.corpus = _make_corpus("PROB") + cls.corpus.batch_put(SEED_SEQUENCES) + cls.corpus.synchronize() + ids, masks = cls.corpus.batch_get(QUERY_SEQUENCES) + cls.ids = ids.reshape(-1, 8) + cls.masks = masks.reshape(-1, 8, 8) + + def test_token_ids(self): + np.testing.assert_array_equal(self.ids.tolist(), EXPECTED_PROB_IDS) + + def test_masks(self): + np.testing.assert_array_equal(self.masks.tolist(), EXPECTED_PROB_MASKS) + + def test_output_shapes(self): + n_queries = len(QUERY_SEQUENCES) + self.assertEqual(self.ids.shape, (n_queries, 8)) + self.assertEqual(self.masks.shape, (n_queries, 8, 8)) + + +class TestNgramCorpusReset(CustomTestCase): + """Verify reset clears all cached state.""" + + def test_reset_produces_empty_results(self): + corpus = _make_corpus("BFS") + corpus.batch_put(SEED_SEQUENCES) + corpus.synchronize() + + ids_before, _ = corpus.batch_get([[1, 2, 3]]) + self.assertTrue( + any(t != 0 for t in ids_before.tolist()[1:]), + "Expected non-trivial draft tokens before reset", + ) + + corpus.reset() + + ids_after, _ = corpus.batch_get([[1, 2, 3]]) + self.assertEqual( + ids_after.tolist(), + [3, 0, 0, 0, 0, 0, 0, 0], + "After reset, only last_token should be present (rest zero-padded)", + ) + + +class TestNgramCorpusNoMatch(CustomTestCase): + """Verify behavior when query has no match in the corpus.""" + + def test_unmatched_query(self): + corpus = _make_corpus("BFS") + corpus.batch_put([[10, 20, 30, 40, 50]]) + corpus.synchronize() + + ids, masks = corpus.batch_get([[999, 888, 777]]) + ids_list = ids.tolist() + self.assertEqual(ids_list[0], 777, "First token should be last context token") + self.assertTrue( + all(t == 0 for t in ids_list[1:]), + "No draft tokens expected when nothing matches", + ) + + def test_empty_corpus(self): + corpus = _make_corpus("BFS") + ids, masks = corpus.batch_get([[1, 2, 3]]) + 
ids_list = ids.tolist() + self.assertEqual(ids_list[0], 3) + self.assertTrue(all(t == 0 for t in ids_list[1:])) + + +class TestNgramCorpusMultipleInserts(CustomTestCase): + """Verify that multiple inserts accumulate correctly.""" + + def test_incremental_inserts(self): + corpus = _make_corpus("BFS") + corpus.batch_put([[1, 2, 3, 4, 5]]) + corpus.synchronize() + + corpus.batch_put([[1, 2, 3, 44, 55]]) + corpus.synchronize() + + ids, _ = corpus.batch_get([[1, 2, 3]]) + ids_list = ids.tolist() + + self.assertIn(4, ids_list, "Token 4 from first insert should still match") + self.assertIn(44, ids_list, "Token 44 from second insert should also match") + + +class TestNgramCorpusSqueeze(CustomTestCase): + """Verify cache eviction under memory pressure.""" + + def test_small_capacity_does_not_crash(self): + corpus = _make_corpus("BFS", capacity=200) + long_seq = list(range(1, 101)) + corpus.batch_put([long_seq]) + corpus.synchronize() + + ids, masks = corpus.batch_get([[50, 51, 52]]) + self.assertEqual(len(ids), 8, "Should still produce draft_token_num outputs") + + def test_eviction_preserves_recent(self): + corpus = _make_corpus("BFS", capacity=500, max_trie_depth=6) + + old_seq = list(range(1000, 1050)) + corpus.batch_put([old_seq]) + corpus.synchronize() + + recent_seq = list(range(2000, 2050)) + corpus.batch_put([recent_seq]) + corpus.synchronize() + + ids, _ = corpus.batch_get([[2000, 2001, 2002]]) + ids_list = ids.tolist() + self.assertEqual(ids_list[0], 2002, "Last context token should be first") + self.assertIn(2003, ids_list, "Recent sequence should still be matchable") + + +class TestNgramCorpusLeafPaths(CustomTestCase): + """Verify the leaf_paths_from_mask utility.""" + + def test_simple_tree(self): + corpus = _make_corpus("BFS") + tokens = [3, 4, 44, 5, 55] + mask = [ + [1, 0, 0, 0, 0], + [1, 1, 0, 0, 0], + [1, 0, 1, 0, 0], + [1, 1, 0, 1, 0], + [1, 0, 1, 0, 1], + ] + paths = corpus.leaf_paths_from_mask(tokens, mask) + + for path in paths: + self.assertIn(3, 
path, "Root token should be in every path") + + self.assertEqual(len(paths), 2, "Two leaf paths expected for a binary tree") + + def test_single_chain(self): + corpus = _make_corpus("BFS") + tokens = [10, 20, 30] + mask = [ + [1, 0, 0], + [1, 1, 0], + [1, 1, 1], + ] + paths = corpus.leaf_paths_from_mask(tokens, mask) + self.assertEqual(len(paths), 1) + self.assertEqual(paths[0], [10, 20, 30]) + + +class TestNgramCorpusBatchConsistency(CustomTestCase): + """Verify batch queries produce same results as individual queries.""" + + def test_batch_vs_individual(self): + corpus = _make_corpus("BFS") + corpus.batch_put(SEED_SEQUENCES) + corpus.synchronize() + + batch_ids, batch_masks = corpus.batch_get(QUERY_SEQUENCES) + draft = 8 + batch_ids = batch_ids.reshape(-1, draft) + batch_masks = batch_masks.reshape(-1, draft, draft) + + for i, query in enumerate(QUERY_SEQUENCES): + single_ids, single_masks = corpus.batch_get([query]) + single_ids = single_ids.reshape(-1, draft) + single_masks = single_masks.reshape(-1, draft, draft) + + np.testing.assert_array_equal( + batch_ids[i], + single_ids[0], + err_msg=f"Token mismatch for query {i}", + ) + np.testing.assert_array_equal( + batch_masks[i], + single_masks[0], + err_msg=f"Mask mismatch for query {i}", + ) + + +class TestMaskValidity(CustomTestCase): + """Verify structural invariants of the output mask for any draft tree.""" + + def _check_mask(self, masks_2d): + n = len(masks_2d) + for i in range(n): + self.assertEqual(masks_2d[i][i], 1, f"Diagonal must be 1 at row {i}") + self.assertEqual(masks_2d[0], [1] + [0] * (n - 1)) + + def test_bfs_mask_invariants(self): + corpus = _make_corpus("BFS") + corpus.batch_put(SEED_SEQUENCES) + corpus.synchronize() + _, masks = corpus.batch_get(QUERY_SEQUENCES) + masks = masks.reshape(-1, 8, 8) + for i in range(masks.shape[0]): + self._check_mask(masks[i].tolist()) + + def test_prob_mask_invariants(self): + corpus = _make_corpus("PROB") + corpus.batch_put(SEED_SEQUENCES) + 
corpus.synchronize() + _, masks = corpus.batch_get(QUERY_SEQUENCES) + masks = masks.reshape(-1, 8, 8) + for i in range(masks.shape[0]): + self._check_mask(masks[i].tolist()) + + +class TestFrequencyBoosting(CustomTestCase): + """Verify that repeated insertions change Prob-mode selection.""" + + def test_repeated_insert_promotes_token(self): + corpus = _make_corpus( + "PROB", + draft_token_num=2, + max_bfs_breadth=1, + min_bfs_breadth=1, + max_trie_depth=5, + ) + corpus.batch_put([[1, 2, 3, 10, 11]]) + corpus.synchronize() + + for _ in range(10): + corpus.batch_put([[1, 2, 3, 20, 21]]) + corpus.synchronize() + + ids, _ = corpus.batch_get([[1, 2, 3]]) + ids_list = ids.tolist() + + self.assertEqual( + ids_list[1], + 20, + f"Token 20 should be selected over 10 after frequency boost, got {ids_list}", + ) + + +class TestRecencyOrdering(CustomTestCase): + """Verify that BFS mode respects LRU recency.""" + + def test_most_recent_insert_selected(self): + corpus = _make_corpus( + "BFS", + draft_token_num=2, + max_bfs_breadth=1, + min_bfs_breadth=1, + max_trie_depth=5, + ) + corpus.batch_put([[1, 2, 3, 10, 11]]) + corpus.synchronize() + corpus.batch_put([[1, 2, 3, 20, 21]]) + corpus.synchronize() + + ids, _ = corpus.batch_get([[1, 2, 3]]) + ids_list = ids.tolist() + self.assertEqual( + ids_list[1], + 20, + f"Token 20 (recent) should be selected over 10 (old), got {ids_list}", + ) + + +class TestOverlappingSuffixes(CustomTestCase): + """Verify correct matching when sequences share suffixes.""" + + def test_shared_suffix_both_match(self): + corpus = _make_corpus("BFS") + corpus.batch_put([[100, 200, 7, 8, 9, 50, 51]]) + corpus.batch_put([[300, 400, 7, 8, 9, 60, 61]]) + corpus.synchronize() + + ids, _ = corpus.batch_get([[7, 8, 9]]) + ids_list = ids.tolist() + self.assertIn(50, ids_list, "Continuation from first sequence missing") + self.assertIn(60, ids_list, "Continuation from second sequence missing") + + +class TestSingleTokenContext(CustomTestCase): + """Verify behavior 
with minimum-length context.""" + + def test_single_token_query(self): + corpus = _make_corpus("BFS") + corpus.batch_put([[5, 10, 20, 30]]) + corpus.synchronize() + + ids, masks = corpus.batch_get([[5]]) + ids_list = ids.tolist() + self.assertEqual(ids_list[0], 5, "First token should be last context token") + self.assertIn(10, ids_list, "Should match continuation after single token 5") + + +class TestLongContext(CustomTestCase): + """Verify behavior when query context exceeds max_trie_depth.""" + + def test_context_longer_than_max_trie_depth(self): + corpus = _make_corpus("BFS", max_trie_depth=6) + seq = list(range(1, 20)) + corpus.batch_put([seq]) + corpus.synchronize() + + long_query = list(range(1, 16)) + ids, masks = corpus.batch_get([long_query]) + ids_list = ids.tolist() + self.assertEqual(ids_list[0], 15, "First token should be last context token") + self.assertIn(16, ids_list, "Should match via suffix despite long context") + + def test_matches_longest_stored_suffix(self): + corpus = _make_corpus("BFS", max_trie_depth=6, draft_token_num=4) + corpus.batch_put([[1, 2, 3, 4, 5, 6, 7]]) + corpus.batch_put([[99, 3, 4, 5, 6, 8]]) + corpus.synchronize() + + ids, _ = corpus.batch_get([[2, 3, 4, 5, 6]]) + ids_list = ids.tolist() + self.assertIn( + 7, ids_list, "Longest stored suffix should contribute a continuation" + ) + self.assertIn( + 8, + ids_list, + "Shorter matching suffixes should still contribute continuations", + ) + + +class TestDraftBudgetSaturation(CustomTestCase): + """Verify the draft tree uses exactly draft_token_num slots.""" + + def test_full_budget_used(self): + corpus = _make_corpus("BFS", draft_token_num=8) + seq = list(range(1, 30)) + corpus.batch_put([seq]) + corpus.synchronize() + + ids, _ = corpus.batch_get([[1, 2, 3]]) + ids_list = ids.tolist() + self.assertEqual(len(ids_list), 8) + non_zero = [t for t in ids_list[1:] if t != 0] + self.assertGreater( + len(non_zero), + 0, + "Draft budget should have non-zero tokens when cache has long 
chains", + ) + + +class TestTruncate(CustomTestCase): + """Verify truncation logic on batch_get output.""" + + def test_truncate_reduces_output(self): + corpus = _make_corpus("BFS", draft_token_num=8) + corpus.batch_put(SEED_SEQUENCES) + corpus.synchronize() + + ids, masks = corpus.batch_get([[1, 2, 3]]) + ids = ids.reshape(8) + self.assertEqual(len(ids), 8) + + # Simulate truncate to 4 + trunc_n = 4 + trunc_ids = ids[:trunc_n] + self.assertEqual(len(trunc_ids), trunc_n) + + def test_truncate_preserves_mask_structure(self): + corpus = _make_corpus("BFS", draft_token_num=8) + corpus.batch_put(SEED_SEQUENCES) + corpus.synchronize() + + ids, masks = corpus.batch_get([[1, 2, 3]]) + n = 8 + full_mask = masks.reshape(n, n) + + trunc_n = 4 + trunc_mask = full_mask[:trunc_n, :trunc_n] + + for i in range(trunc_n): + for j in range(trunc_n): + self.assertEqual( + trunc_mask[i, j], + full_mask[i, j], + f"Mask mismatch at ({i},{j})", + ) + + +class TestResetAndReinsert(CustomTestCase): + """Verify that reset followed by new inserts works correctly.""" + + def test_reset_then_reinsert(self): + corpus = _make_corpus("BFS") + corpus.batch_put([[1, 2, 3, 4, 5]]) + corpus.synchronize() + + corpus.reset() + + corpus.batch_put([[10, 20, 30, 40, 50]]) + corpus.synchronize() + + ids_old, _ = corpus.batch_get([[1, 2, 3]]) + ids_old_list = ids_old.tolist() + self.assertTrue( + all(t == 0 for t in ids_old_list[1:]), + f"Old data should not match after reset+reinsert, got {ids_old_list}", + ) + + ids_new, _ = corpus.batch_get([[10, 20, 30]]) + ids_new_list = ids_new.tolist() + self.assertEqual(ids_new_list[0], 30) + self.assertIn(40, ids_new_list, "New data should match after reset+reinsert") + + +class TestSqueezeEvictsOld(CustomTestCase): + """Verify that squeeze actually evicts old data, not just preserves recent.""" + + def test_old_data_evicted(self): + corpus = _make_corpus("BFS", capacity=150, max_trie_depth=6) + + old_seq = list(range(5000, 5030)) + corpus.batch_put([old_seq]) + 
corpus.synchronize() + + ids_before, _ = corpus.batch_get([[5000, 5001, 5002]]) + self.assertIn( + 5003, + ids_before.tolist(), + "Old data should match before eviction", + ) + + for i in range(5): + new_seq = list(range(6000 + i * 30, 6000 + i * 30 + 30)) + corpus.batch_put([new_seq]) + corpus.synchronize() + + ids_after, _ = corpus.batch_get([[5000, 5001, 5002]]) + ids_after_list = ids_after.tolist() + self.assertNotIn( + 5003, + ids_after_list, + f"Old data should be evicted after pressure, got {ids_after_list}", + ) + + +if __name__ == "__main__": + unittest.main(verbosity=3) diff --git a/test/registered/unit/model_loader/test_modelopt_loader.py b/test/registered/unit/model_loader/test_modelopt_loader.py index 7f9652c0e5db..9ad6183a0b0a 100644 --- a/test/registered/unit/model_loader/test_modelopt_loader.py +++ b/test/registered/unit/model_loader/test_modelopt_loader.py @@ -646,7 +646,11 @@ def test_mixed_precision_override_does_not_hijack_w4afp8(self): ) def test_mixed_precision_uses_nvfp4_min_capability(self): - self.assertEqual(ModelOptMixedPrecisionConfig.get_min_capability(), 100) + """NVFP4 supports SM75+ (Turing) via Marlin fallback; min_capability must be >= 75.""" + cap = ModelOptMixedPrecisionConfig.get_min_capability() + self.assertGreaterEqual( + cap, 75, f"NVFP4 requires SM75+ (Marlin fallback); got min_capability={cap}" + ) def test_mixed_precision_quant_layer_resolution_after_mapping(self): quant_config = ModelOptMixedPrecisionConfig.from_config( diff --git a/test/srt/cpu/test_flash_attn.py b/test/srt/cpu/test_flash_attn.py index 8b1faa98b5cb..4e1968fa06e7 100644 --- a/test/srt/cpu/test_flash_attn.py +++ b/test/srt/cpu/test_flash_attn.py @@ -1,15 +1,12 @@ import unittest -import sgl_kernel # noqa: F401 import torch import torch.nn.functional as F from utils import parametrize, precision +from sglang.jit_kernel.flash_attention import flash_attn_varlen_func from sglang.test.test_utils import CustomTestCase -flash_attn_varlen_func = 
torch.ops.sgl_kernel.flash_attn_varlen_func - - torch.manual_seed(1234)