Integrate MixedComm from FlashInfer to fuse communication kernels when attention TP and DP are both used #23713

Open
jinyangyuan-nvidia wants to merge 2 commits into sgl-project:main from jinyangyuan-nvidia:dev/mixed_comm
@jinyangyuan-nvidia jinyangyuan-nvidia commented Apr 25, 2026

Motivation

When attention TP and DP are both used, the two communication patterns allreduce + allgather and reducescatter + allreduce currently each take two kernels (implemented via a reducescatter kernel + an allgather kernel). Fusing each pattern into a single kernel can improve performance in low-latency scenarios.
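The data movement behind the allreduce + allgather pattern can be illustrated with a toy, pure-Python model. This is only a sketch under stated assumptions: the rank layout (rank = dp_rank * tp_size + tp_rank), the zero-padding scheme, and all function names are hypothetical and are not taken from SGLang or FlashInfer. It shows that reducing inside each TP group and then gathering across DP groups gives the same result as a reducescatter followed by an allgather over all ranks.

```python
# Toy model of collectives on Python lists; one list per rank.

def all_reduce(bufs):
    """Every rank ends up with the element-wise sum of all inputs."""
    total = [sum(vals) for vals in zip(*bufs)]
    return [list(total) for _ in bufs]

def all_gather(bufs):
    """Every rank ends up with the concatenation of all inputs, in rank order."""
    gathered = [x for buf in bufs for x in buf]
    return [list(gathered) for _ in bufs]

def reduce_scatter(bufs):
    """Rank i ends up with the i-th equal shard of the element-wise sum."""
    n = len(bufs)
    total = [sum(vals) for vals in zip(*bufs)]
    shard = len(total) // n
    return [total[i * shard:(i + 1) * shard] for i in range(n)]

def two_step(partials, tp_size, dp_size):
    """allreduce inside each TP group, then allgather across DP groups."""
    reduced = []
    for d in range(dp_size):
        group = partials[d * tp_size:(d + 1) * tp_size]
        reduced.append(all_reduce(group)[0])  # one reduced copy per DP group
    full = all_gather(reduced)[0]
    return [list(full) for _ in partials]

def padded_rs_ag(partials, tp_size, dp_size):
    """Zero-pad to the global token layout, then reduce_scatter + all_gather."""
    tokens = len(partials[0])
    padded = []
    for rank, buf in enumerate(partials):
        d = rank // tp_size  # assumed layout: consecutive ranks share a DP group
        row = [0] * (tokens * dp_size)
        row[d * tokens:(d + 1) * tokens] = buf
        padded.append(row)
    return all_gather(reduce_scatter(padded))

# 2 DP groups x 2 TP ranks, 2 values per rank (partial sums within a TP group).
partials = [[1, 2], [10, 20], [3, 4], [30, 40]]
result_two_step = two_step(partials, tp_size=2, dp_size=2)
result_rs_ag = padded_rs_ag(partials, tp_size=2, dp_size=2)
```

Both paths leave every rank holding the same fully reduced, fully gathered buffer; a fused kernel would aim to produce that result in one launch instead of two.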

A small bug was found while testing the following case, for which a single reducescatter is not sufficient. This PR also fixes that bug.

  • attention_dp_size = moe_ep_size > 1
  • attention_tp_size = moe_tp_size > 1
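As a rough check of when this case arises, the sketch below derives the group sizes from launch-style parameters. It assumes attention_tp_size = tp_size / dp_size and moe_tp_size = tp_size / ep_size; that derivation is an assumption for illustration and is not stated in this PR, and the helper name is hypothetical.

```python
def hits_fixed_case(tp_size, dp_size, ep_size):
    """Return True when both conditions listed above hold.

    Assumed derivation (not confirmed by this PR):
      attention_tp_size = tp_size // dp_size
      moe_tp_size       = tp_size // ep_size
    """
    attention_dp_size, moe_ep_size = dp_size, ep_size
    attention_tp_size = tp_size // dp_size
    moe_tp_size = tp_size // ep_size
    same_dp_ep = attention_dp_size == moe_ep_size and moe_ep_size > 1
    same_tp = attention_tp_size == moe_tp_size and moe_tp_size > 1
    return same_dp_ep and same_tp

# 8 GPUs with MoE EP size 4, for attention DP sizes 2, 4, and 8.
matches = {dp: hits_fixed_case(tp_size=8, dp_size=dp, ep_size=4) for dp in (2, 4, 8)}
```

Under these assumptions, only the attention DP size 4 setting would satisfy both conditions for an 8-GPU run with EP size 4.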

Modifications

python/sglang/srt/layers/dp_attention.py is modified to use MixedComm from FlashInfer when the environment variable SGLANG_USE_MIXED_COMM is set to 1 or true.
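A minimal sketch of such an environment-variable gate is shown below. The helper name and the case-insensitive, whitespace-tolerant comparison are assumptions for illustration; the actual parsing in SGLang may differ.

```python
import os

def mixed_comm_enabled(environ=None):
    """Return True if SGLANG_USE_MIXED_COMM is set to "1" or "true".

    Hypothetical helper: lowercasing and stripping are assumptions,
    not confirmed behavior of the SGLang implementation.
    """
    if environ is None:
        environ = os.environ
    value = environ.get("SGLANG_USE_MIXED_COMM", "")
    return value.strip().lower() in ("1", "true")
```

Passing a plain dict instead of `os.environ` keeps the check easy to exercise without touching the process environment.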

python/sglang/srt/layers/communicator.py is modified to prefer reduce_scatter over reduce_scatterv when the numbers of tokens are identical on all ranks.
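The selection rule can be sketched as follows. The function name and its string return values are hypothetical and only mirror the two operations named above; the real code in communicator.py is not reproduced here.

```python
def pick_reduce_scatter_op(num_tokens_per_rank):
    """Choose the fixed-size collective when every rank holds the same token count.

    Hypothetical helper for illustration only.
    """
    if len(set(num_tokens_per_rank)) == 1:
        return "reduce_scatter"   # equal shards on all ranks
    return "reduce_scatterv"      # variable-sized shards

balanced_choice = pick_reduce_scatter_op([16, 16, 16, 16])
imbalanced_choice = pick_reduce_scatter_op([16, 12, 16, 20])
```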

python/sglang/srt/layers/moe/utils.py is modified to fix the bug described above.

docs/advanced_features/dp_dpa_smg_guide.md is modified to update documentation.

test/registered/distributed/test_dp_attention.py and test/registered/distributed/test_dp_attention_large.py are modified to add unit tests.

Accuracy Tests

Accuracy was verified by measuring the GSM8K and GPQA scores of Qwen3.5 on H200. The tested cases are:

  • Number of GPUs: 8
  • MoE EP size: 4
  • Attention DP size: 2, 4, and 8

Scores:

  • Attention DP size = 2
    • GSM8K: 0.980
    • GPQA: 0.865
  • Attention DP size = 4
    • GSM8K: 0.980
    • GPQA: 0.873
  • Attention DP size = 8
    • GSM8K: 0.990
    • GPQA: 0.869

Scripts to reproduce the results:

#!/bin/bash

export SGLANG_USE_MIXED_COMM=1

host=0.0.0.0
port=8000

model_path=../Qwen3.5-397B-A17B-FP8
quantization=fp8
kv_cache_dtype=fp8_e4m3

mem_fraction_static=0.83
max_running_requests=128
prefill_max_requests=1
max_prefill_tokens=$(( 128 * 1024 ))
chunked_prefill_size=$(( 64 * 1024 ))
cuda_graph_bs=(1 2 4 8 12 16 24 32 48 64 96 128)

prefill_attention_backend=fa3
decode_attention_backend=trtllm_mha
linear_attn_prefill_backend=triton
linear_attn_decode_backend=triton
fp8_gemm_backend=deep_gemm
moe_runner_backend=triton

function run_eval() {
    pkill -9 -f "sglang.launch_server" 2>/dev/null || true
    pkill -9 -f "sglang.srt.managers" 2>/dev/null || true
    sleep 10

    setsid python3 -m sglang.launch_server "${command_args[@]}" \
        > >(tee server_gpqa_${mode}.log) 2>&1 &
    server_pid=$!

    while ! curl -s http://${host}:${port}/health > /dev/null 2>&1; do
        sleep 10
    done

    python3 -m sglang.test.run_eval \
        --port ${port} \
        --eval-name gsm8k \
        --num-examples 100 \
        --max-tokens 10240 \
        --repeat 1 \
        --num-threads 100 \
        --num-shots 8 \
        --temperature 0.6 \
        --top-p 0.95 \
        --top-k 20 \
    2>&1 | tee out_gsm8k_${mode}.log

    python3 -m sglang.test.run_eval \
        --port ${port} \
        --eval-name gpqa \
        --num-examples 198 \
        --max-tokens 81920 \
        --repeat 8 \
        --temperature 0.6 \
        --top-p 0.95 \
        --top-k 20 \
    2>&1 | tee out_gpqa_${mode}.log

    if kill -0 ${server_pid} 2>/dev/null; then
        kill -9 -- -${server_pid} 2>/dev/null || kill -9 ${server_pid} 2>/dev/null
    fi
    wait ${server_pid} 2>/dev/null
    sleep 10
    echo "Evaluation completed for ${mode}"
}

num_gpus=8
ep_size=4
for dp_size in 2 4 8; do
    if (( dp_size == 1)); then
        dp_args="--dp-size ${dp_size}"
    else
        dp_args="--dp-size ${dp_size} --enable-dp-attention"
    fi
    a2a_args="--moe-a2a-backend none"

    command_args=(
        --host ${host}
        --port ${port}
        --model-path ${model_path}
        --quantization ${quantization}
        --kv-cache-dtype ${kv_cache_dtype}
        --mem-fraction-static ${mem_fraction_static}
        --max-running-requests ${max_running_requests}
        --prefill-max-requests ${prefill_max_requests}
        --max-prefill-tokens ${max_prefill_tokens}
        --chunked-prefill-size ${chunked_prefill_size}
        --cuda-graph-bs ${cuda_graph_bs[@]}
        --tp-size ${num_gpus}
        --ep-size ${ep_size}
        ${dp_args}
        ${a2a_args}
        --prefill-attention-backend ${prefill_attention_backend}
        --decode-attention-backend ${decode_attention_backend}
        --linear-attn-prefill-backend ${linear_attn_prefill_backend}
        --linear-attn-decode-backend ${linear_attn_decode_backend}
        --fp8-gemm-backend ${fp8_gemm_backend}
        --moe-runner-backend ${moe_runner_backend}
        --mamba-ssm-dtype bfloat16
        --random-seed 0
        --watchdog-timeout 3600
        --skip-server-warmup
        --model-loader-extra-config '{"enable_multithread_load": true}'
        --enable-multimodal
        --disable-radix-cache
        --incremental-streaming-output
        --enable-layerwise-nvtx-marker
    )
    mode=dp${dp_size}
    run_eval
done

Speed Tests and Profiling

The performance of Qwen3.5 is tested on H200 using the following cases:

  • Number of GPUs: 8
  • MoE EP size: 4
  • Attention DP size: 2, 4, and 8
  • Input length: 8192
  • Output length: 24
  • Concurrency: 16

Results of median ITL are obtained from the outputs of sglang.bench_serving. Results of kernel execution time are obtained by analyzing nsys timelines.


| Attn DP Size = 2 | NCCL Baseline | NCCL Symmetric | MixedComm | Perf Gain w.r.t. NCCL Symmetric |
| --- | --- | --- | --- | --- |
| Median ITL (ms) | 14.889 | 14.386 | 12.002 | 1.20x |
| Comm Kern Time w/ load imbalance (ms) | 3.817 | 3.631 | 2.145 | 1.69x |
| Comm Kern Time w/o load imbalance (ms) | 1.582 | 1.323 | 0.399 | 3.32x |

| Attn DP Size = 4 | NCCL Baseline | NCCL Symmetric | MixedComm | Perf Gain w.r.t. NCCL Symmetric |
| --- | --- | --- | --- | --- |
| Median ITL (ms) | 15.155 | 14.824 | 12.924 | 1.15x |
| Comm Kern Time w/ load imbalance (ms) | 3.595 | 3.261 | 2.286 | 1.43x |
| Comm Kern Time w/o load imbalance (ms) | 1.274 | 1.267 | 0.405 | 3.13x |

| Attn DP Size = 8 | NCCL Baseline | NCCL Symmetric | MixedComm | Perf Gain w.r.t. NCCL Symmetric |
| --- | --- | --- | --- | --- |
| Median ITL (ms) | 15.043 | 14.262 | 13.560 | 1.05x |
| Comm Kern Time w/ load imbalance (ms) | 2.492 | 1.776 | 1.435 | 1.24x |
| Comm Kern Time w/o load imbalance (ms) | 1.168 | 0.673 | 0.340 | 1.98x |
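
The "Perf Gain" column is the NCCL Symmetric time divided by the MixedComm time. The short sketch below recomputes it from the numbers reported above; the dictionary keys and the helper name are made up for illustration.

```python
# (NCCL Symmetric, MixedComm) pairs in ms, copied from the results above.
results = {
    2: {"median_itl": (14.386, 12.002),
        "comm_with_imbalance": (3.631, 2.145),
        "comm_without_imbalance": (1.323, 0.399)},
    4: {"median_itl": (14.824, 12.924),
        "comm_with_imbalance": (3.261, 2.286),
        "comm_without_imbalance": (1.267, 0.405)},
    8: {"median_itl": (14.262, 13.560),
        "comm_with_imbalance": (1.776, 1.435),
        "comm_without_imbalance": (0.673, 0.340)},
}

def perf_gain(nccl_symmetric_ms, mixed_comm_ms):
    """Speedup of MixedComm over NCCL Symmetric, rounded to two decimals."""
    return round(nccl_symmetric_ms / mixed_comm_ms, 2)

gains = {
    dp: {name: perf_gain(*pair) for name, pair in rows.items()}
    for dp, rows in results.items()
}

for dp, rows in gains.items():
    for name, gain in rows.items():
        print(f"dp={dp} {name}: {gain:.2f}x")
```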

The performance is tested using the following commands:

#!/bin/bash

# Comment out the following line to disable MixedComm
export SGLANG_USE_MIXED_COMM=1

host=0.0.0.0
port=8000

model_path=../Qwen3.5-397B-A17B-FP8
quantization=fp8
kv_cache_dtype=fp8_e4m3

mem_fraction_static=0.84
max_running_requests=128
prefill_max_requests=1
max_prefill_tokens=$(( 128 * 1024 ))
chunked_prefill_size=$(( 64 * 1024 ))
cuda_graph_bs=(1 2 4 8 12 16 24 32 48 64 96 128)

prefill_attention_backend=fa3
decode_attention_backend=trtllm_mha
linear_attn_prefill_backend=triton
linear_attn_decode_backend=triton
fp8_gemm_backend=deep_gemm
moe_runner_backend=triton

folder_out=timelines
if [[ ! -d ${folder_out} ]]; then mkdir ${folder_out}; fi

num_gpus=8
ep_size=4
for dp_size in 2 4 8; do
    if (( dp_size == 1)); then
        dp_args="--dp-size ${dp_size}"
    else
        dp_args="--dp-size ${dp_size} --enable-dp-attention"
    fi
    a2a_args="--moe-a2a-backend none"
    # Add --enable-symm-mem and --enable-nccl-nvls to use NCCL symmetric kernels
    command_args=(
        --host ${host}
        --port ${port}
        --model-path ${model_path}
        --quantization ${quantization}
        --kv-cache-dtype ${kv_cache_dtype}
        --mem-fraction-static ${mem_fraction_static}
        --max-running-requests ${max_running_requests}
        --prefill-max-requests ${prefill_max_requests}
        --max-prefill-tokens ${max_prefill_tokens}
        --chunked-prefill-size ${chunked_prefill_size}
        --cuda-graph-bs ${cuda_graph_bs[@]}
        --tp-size ${num_gpus}
        --ep-size ${ep_size}
        ${dp_args}
        ${a2a_args}
        --prefill-attention-backend ${prefill_attention_backend}
        --decode-attention-backend ${decode_attention_backend}
        --linear-attn-prefill-backend ${linear_attn_prefill_backend}
        --linear-attn-decode-backend ${linear_attn_decode_backend}
        --fp8-gemm-backend ${fp8_gemm_backend}
        --moe-runner-backend ${moe_runner_backend}
        --mamba-ssm-dtype bfloat16
        --random-seed 0
        --watchdog-timeout 3600
        --skip-server-warmup
        --model-loader-extra-config '{"enable_multithread_load": true}'
        --enable-multimodal
        --disable-radix-cache
        --incremental-streaming-output
        --enable-layerwise-nvtx-marker
    )

    pkill -9 -f "sglang.launch_server" 2>/dev/null || true
    pkill -9 -f "sglang.srt.managers" 2>/dev/null || true
    sleep 10

    setsid nsys launch \
        --session-new profile_session \
        --trace-fork-before-exec=true \
        --cuda-graph-trace=node \
        --trace cuda,nvtx,mpi \
        python3 -m sglang.launch_server "${command_args[@]}" \
        > >(tee server_dp${dp_size}_ep${ep_size}.log) 2>&1 &
    server_pid=$!

    while ! curl -s http://${host}:${port}/health > /dev/null 2>&1; do
        sleep 10
    done

    for input_len_base in 8; do
        input_len=$(( input_len_base * 1024 ))
        for concurrency in 16; do
            output_len=$(( concurrency + 8 ))
            name=${folder_out}/qwen_isl${input_len}_bs${concurrency}_dp${dp_size}_ep${ep_size}
            command_args=(
                --backend sglang
                --dataset-name random
                --host ${host}
                --port ${port}
                --model ${model_path}
                --tokenizer ${model_path}
                --num-prompts ${concurrency}
                --random-input ${input_len}
                --random-output ${output_len}
                --max-concurrency ${concurrency}
                --output-file ${name}.jsonl
                --random-range-ratio 1.0
                --warmup-requests 1
                --seed 0
            )
            nsys start \
                --session profile_session \
                --backtrace none \
                --sample none \
                --cpuctxsw none \
                --force-overwrite true \
                --output ${name} || echo "ERROR: nsys start failed for ${name}"
            python3 -m sglang.bench_serving "${command_args[@]}"
            nsys stop --session profile_session || echo "ERROR: nsys stop failed for ${name}"
        done
    done

    nsys shutdown --session profile_session || true
    # If nsys shutdown failed to terminate the server, kill it explicitly
    if kill -0 ${server_pid} 2>/dev/null; then
        kill -9 -- -${server_pid} 2>/dev/null || kill -9 ${server_pid} 2>/dev/null || true
    fi
    wait ${server_pid} 2>/dev/null || true
    sleep 10
done

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.

@github-actions github-actions Bot added the documentation Improvements or additions to documentation label Apr 25, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces support for fused communication using FlashInfer's MixedComm kernels, which optimizes performance by fusing reduce-scatter and all-gather operations when data parallelism is combined with attention tensor parallelism. The changes include the implementation of a MixedCommHandler for initialization, integration of fused kernels into the communication layers, and updated documentation. Additionally, the logic for MoE reduce-scatter operations was refined to ensure correctness when both tensor and data parallelism are active. I have no feedback to provide as the review comments were either explanatory or did not identify actionable issues in the current implementation.

…n attention TP and DP are both used

Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
@jinyangyuan-nvidia

Hi @ch-wan, it seems that the bot failed to assign this PR to a Merge Oncall. Since this PR is mainly related to attention DP, can you help review what modifications are needed and ping other relevant experts if necessary?