
[perf] enable SwapAB for bf16 moe triton kernel #20861

Open
ZelinMa557 wants to merge 2 commits into sgl-project:main from ZelinMa557:bf16_swap

Conversation

@ZelinMa557

Motivation

Inspired by #15712, sglang has already enabled the SwapAB trick for the fp8 MoE Triton kernel; the same trick can also be used for bf16 on Hopper.
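
For context, SwapAB computes the product with the operands swapped so that the small token dimension M lands on the other side of the tensor-core MMA tile instead of being padded up on the M side. A minimal PyTorch sketch of the identity the trick relies on (illustration only, not the kernel change itself):

import torch

M, K, N = 8, 4096, 768  # small M (decode-style batches) is where SwapAB helps
A = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
B = torch.randn(K, N, dtype=torch.bfloat16, device="cuda")

# SwapAB relies on A @ B == (B^T @ A^T)^T; swapping the operands moves the
# tiny M dimension to the other side of the MMA tile, so less of the tile
# is wasted on padding.
ref = A @ B
swapped = (B.t() @ A.t()).t()
torch.testing.assert_close(ref, swapped, rtol=2e-2, atol=1e-2)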

Modifications

Modified python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py to enable SwapAB when the input dtype is bf16.
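
The gist of the change, paraphrased from the two-line diff quoted in the review comment further down (only use_bf16 and the widened condition come from the actual patch; the wrapper function here is illustrative):

import torch

def swap_ab_enabled(A: torch.Tensor, use_fp8_w8a8: bool) -> bool:
    # Previously the SwapAB path in invoke_fused_moe_kernel was only reachable
    # for fp8 (use_fp8_w8a8); this PR also lets bf16 activations take it.
    # Any additional gating (e.g. on M or the tuned block sizes) stays in the
    # real kernel-invocation code and is not reproduced here.
    use_bf16 = A.dtype == torch.bfloat16
    return use_fp8_w8a8 or use_bf16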

Accuracy Tests

start command:

MODEL=/model_load/Qwen3_5
python -m sglang.launch_server --model-path $MODEL --port 8000 --tp-size 2 --mem-fraction-static 0.8 --context-length 262144 #--reasoning-parser qwen3

result:
(accuracy screenshot attached to the PR)
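
Once the server above is up, a quick manual sanity check against its OpenAI-compatible endpoint looks like this (a minimal sketch; this is only a smoke test, not the evaluation that produced the screenshot above):

import requests

# Minimal smoke test against the OpenAI-compatible endpoint exposed by
# sglang.launch_server above.
resp = requests.post(
    "http://127.0.0.1:8000/v1/chat/completions",
    json={
        "model": "/model_load/Qwen3_5",
        "messages": [{"role": "user", "content": "What is 17 * 24?"}],
        "max_tokens": 64,
    },
    timeout=60,
)
print(resp.json()["choices"][0]["message"]["content"])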

Benchmarking and Profiling

I used qwen3.5-35b-a3b for benchmarking.

Kernel benchmark: the SwapAB kernel shows a significant improvement when 256 <= M <= 2048.
(kernel benchmark chart attached to the PR)
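
If you want to reproduce a per-M sweep yourself, a bare-bones harness looks roughly like the following (run_fused_moe is a hypothetical wrapper around whichever fused-MoE entry point you time; the shapes are placeholders, not the exact qwen3.5-35b-a3b dimensions):

import torch
import triton

def run_fused_moe(x, w1, w2, topk_weights, topk_ids):
    # Hypothetical wrapper: call the fused MoE Triton kernel here (e.g. via the
    # entry point in sglang.srt.layers.moe.fused_moe_triton). Not a real API name.
    raise NotImplementedError

def bench(M, K=2048, N=768, E=128, topk=8):
    x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    w1 = torch.randn(E, 2 * N, K, device="cuda", dtype=torch.bfloat16)
    w2 = torch.randn(E, K, N, device="cuda", dtype=torch.bfloat16)
    topk_weights = torch.rand(M, topk, device="cuda", dtype=torch.float32)
    topk_ids = torch.randint(0, E, (M, topk), device="cuda", dtype=torch.int32)
    # do_bench handles warmup and repetition and returns a time in milliseconds.
    return triton.testing.do_bench(
        lambda: run_fused_moe(x, w1, w2, topk_weights, topk_ids)
    )

for M in (16, 64, 256, 512, 1024, 2048, 4096):
    print(f"M={M}: {bench(M):.3f} ms")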

End-to-end benchmark command:

MODEL=/model_load/Qwen3_5
python3 -m sglang.bench_serving \
  --backend sglang-oai \
  --host 127.0.0.1 --port 8000 \
  --num-prompts 200 \
  --model $MODEL --max-concurrency 16

Result before this PR: (throughput screenshot attached to the PR)
Result after this PR: about 2.5% speedup in input/output token throughput (screenshot attached to the PR)

Appendix

I used the following command to tune the MoE kernel:

python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
    --model /model_load/Qwen3_5 \
    --tp-size 2 \
    --tune

Config file for the SwapAB kernel:

{
   "1": {
       "BLOCK_SIZE_M": 16,
       "BLOCK_SIZE_N": 64,
       "BLOCK_SIZE_K": 64,
       "GROUP_SIZE_M": 1,
       "num_warps": 4,
       "num_stages": 5
   },
   "2": {
       "BLOCK_SIZE_M": 16,
       "BLOCK_SIZE_N": 64,
       "BLOCK_SIZE_K": 64,
       "GROUP_SIZE_M": 1,
       "num_warps": 4,
       "num_stages": 5
   },
   "4": {
       "BLOCK_SIZE_M": 16,
       "BLOCK_SIZE_N": 64,
       "BLOCK_SIZE_K": 64,
       "GROUP_SIZE_M": 1,
       "num_warps": 4,
       "num_stages": 5
   },
   "8": {
       "BLOCK_SIZE_M": 16,
       "BLOCK_SIZE_N": 64,
       "BLOCK_SIZE_K": 64,
       "GROUP_SIZE_M": 1,
       "num_warps": 4,
       "num_stages": 3
   },
   "16": {
       "BLOCK_SIZE_M": 16,
       "BLOCK_SIZE_N": 64,
       "BLOCK_SIZE_K": 64,
       "GROUP_SIZE_M": 16,
       "num_warps": 4,
       "num_stages": 3
   },
   "24": {
       "BLOCK_SIZE_M": 16,
       "BLOCK_SIZE_N": 128,
       "BLOCK_SIZE_K": 64,
       "GROUP_SIZE_M": 32,
       "num_warps": 8,
       "num_stages": 3
   },
   "32": {
       "BLOCK_SIZE_M": 16,
       "BLOCK_SIZE_N": 64,
       "BLOCK_SIZE_K": 64,
       "GROUP_SIZE_M": 1,
       "num_warps": 4,
       "num_stages": 3
   },
   "48": {
       "BLOCK_SIZE_M": 16,
       "BLOCK_SIZE_N": 64,
       "BLOCK_SIZE_K": 64,
       "GROUP_SIZE_M": 16,
       "num_warps": 4,
       "num_stages": 3
   },
   "64": {
       "BLOCK_SIZE_M": 16,
       "BLOCK_SIZE_N": 128,
       "BLOCK_SIZE_K": 64,
       "GROUP_SIZE_M": 1,
       "num_warps": 8,
       "num_stages": 3
   },
   "96": {
       "BLOCK_SIZE_M": 16,
       "BLOCK_SIZE_N": 64,
       "BLOCK_SIZE_K": 64,
       "GROUP_SIZE_M": 1,
       "num_warps": 4,
       "num_stages": 3
   },
   "128": {
       "BLOCK_SIZE_M": 16,
       "BLOCK_SIZE_N": 64,
       "BLOCK_SIZE_K": 64,
       "GROUP_SIZE_M": 64,
       "num_warps": 4,
       "num_stages": 3
   },
   "256": {
       "BLOCK_SIZE_M": 16,
       "BLOCK_SIZE_N": 64,
       "BLOCK_SIZE_K": 64,
       "GROUP_SIZE_M": 16,
       "num_warps": 4,
       "num_stages": 3
   },
   "512": {
       "BLOCK_SIZE_M": 32,
       "BLOCK_SIZE_N": 64,
       "BLOCK_SIZE_K": 64,
       "GROUP_SIZE_M": 16,
       "num_warps": 4,
       "num_stages": 2
   },
   "1024": {
       "BLOCK_SIZE_M": 32,
       "BLOCK_SIZE_N": 64,
       "BLOCK_SIZE_K": 64,
       "GROUP_SIZE_M": 64,
       "num_warps": 4,
       "num_stages": 2
   },
   "1536": {
       "BLOCK_SIZE_M": 64,
       "BLOCK_SIZE_N": 128,
       "BLOCK_SIZE_K": 64,
       "GROUP_SIZE_M": 1,
       "num_warps": 8,
       "num_stages": 3
   },
   "2048": {
       "BLOCK_SIZE_M": 32,
       "BLOCK_SIZE_N": 64,
       "BLOCK_SIZE_K": 64,
       "GROUP_SIZE_M": 1,
       "num_warps": 4,
       "num_stages": 2
   },
   "3072": {
       "BLOCK_SIZE_M": 32,
       "BLOCK_SIZE_N": 64,
       "BLOCK_SIZE_K": 64,
       "GROUP_SIZE_M": 1,
       "num_warps": 4,
       "num_stages": 2
   },
   "4096": {
       "BLOCK_SIZE_M": 32,
       "BLOCK_SIZE_N": 64,
       "BLOCK_SIZE_K": 64,
       "GROUP_SIZE_M": 16,
       "num_warps": 4,
       "num_stages": 2
   }
}

Config file for the normal (non-SwapAB) kernel:

{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 5
    },
    "24": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 3
    },
    "32": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 2
    },
    "48": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "64": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "96": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "128": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "256": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "512": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "1024": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "1536": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 3
    },
    "2048": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 3
    },
    "3072": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2
    },
    "4096": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 2
    }
}
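
Both JSON files are keyed by batch size M; at run time the tuned entry whose key is closest to the observed M is used, which is why only a handful of sizes need to be tuned. A minimal sketch of that lookup (assumed behavior based on the usual fused-MoE config handling; the exact selection logic lives in the sglang code, not here):

import json

def load_moe_config(path: str) -> dict[int, dict]:
    # The JSON keys are strings ("1", "2", ..., "4096"); convert them to ints.
    with open(path) as f:
        return {int(m): cfg for m, cfg in json.load(f).items()}

def pick_config(configs: dict[int, dict], M: int) -> dict:
    # Take the tuned entry whose batch-size key is closest to the actual M.
    nearest = min(configs, key=lambda m: abs(m - M))
    return configs[nearest]

# Example: with the SwapAB config above, M=300 would map to the "256" entry
# (BLOCK_SIZE_M=16, BLOCK_SIZE_N=64, BLOCK_SIZE_K=64, GROUP_SIZE_M=16).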

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

Signed-off-by: ZelinMa557 <3388706467@qq.com>
@gemini-code-assist

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a performance optimization for the Mixture-of-Experts (MoE) Triton kernel by enabling the "SwapAB" trick for bfloat16 data types on Hopper GPUs. This change aims to improve the efficiency of MoE computations, resulting in faster kernel execution for specific matrix sizes and an overall increase in token throughput for the serving backend.

Highlights

  • Performance Optimization: Enabled the SwapAB optimization for bf16 Mixture-of-Experts (MoE) Triton kernels on Hopper architectures.
  • Kernel Benchmark Improvement: Achieved significant performance improvements for kernel benchmarks when the matrix dimension M is between 256 and 2048.
  • End-to-End Speedup: Demonstrated an approximate 2.5% speedup in end-to-end input/output token throughput.

@gemini-code-assist bot left a comment

Code Review

The pull request successfully introduces the logic to enable the SwapAB trick for bf16 tensors, which is a valuable performance optimization. The invoke_fused_moe_kernel function correctly identifies when swap_ab should be enabled for bf16 inputs. However, a potential issue exists in the fused_moe_kernel where the actual matrix swap (tl.trans) is only applied in a very specific execution path, which might prevent the SwapAB optimization from being fully utilized or correctly applied across all relevant bf16 scenarios.

Comment on lines +725 to +726
use_bf16 = A.dtype == torch.bfloat16
if use_fp8_w8a8 or use_bf16:
Severity: medium

The introduction of use_bf16 and its inclusion in the if condition correctly extends the swap_ab enablement logic to bfloat16 tensors, aligning with the pull request's goal. However, please refer to the comment on lines 579-580 in fused_moe_kernel regarding the limited application scope of swap_ab. While this change correctly sets the swap_ab flag, the kernel might not always apply the actual matrix swap due to its current placement.
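
To make the swap itself concrete, here is a self-contained toy Triton matmul with a SWAP_AB switch. It only illustrates the general technique the comment refers to; tile shapes, operand layout, and where the transpose actually happens all differ in fused_moe_triton_kernels.py, so treat it as a sketch rather than the kernel under review:

import torch
import triton
import triton.language as tl


@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr, M, N, K,
    stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    SWAP_AB: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn

    if SWAP_AB:
        # Accumulate the transposed (BLOCK_N, BLOCK_M) tile: the small M
        # dimension now sits on the narrow side of the MMA instruction.
        acc = tl.zeros((BLOCK_N, BLOCK_M), dtype=tl.float32)
    else:
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k in range(0, K, BLOCK_K):
        a_mask = (offs_m[:, None] < M) & ((offs_k[None, :] + k) < K)
        b_mask = ((offs_k[:, None] + k) < K) & (offs_n[None, :] < N)
        a = tl.load(a_ptrs, mask=a_mask, other=0.0)
        b = tl.load(b_ptrs, mask=b_mask, other=0.0)
        if SWAP_AB:
            acc += tl.dot(tl.trans(b), tl.trans(a))  # (B tile)^T @ (A tile)^T
        else:
            acc += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    if SWAP_AB:
        acc = tl.trans(acc)  # back to (BLOCK_M, BLOCK_N) before storing

    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, acc.to(tl.bfloat16), mask=c_mask)


def matmul(a: torch.Tensor, b: torch.Tensor, swap_ab: bool) -> torch.Tensor:
    M, K = a.shape
    _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
    block_m, block_n, block_k = 16, 64, 64
    grid = (triton.cdiv(M, block_m), triton.cdiv(N, block_n))
    matmul_kernel[grid](
        a, b, c, M, N, K,
        a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),
        BLOCK_M=block_m, BLOCK_N=block_n, BLOCK_K=block_k, SWAP_AB=swap_ab,
    )
    return c


if __name__ == "__main__":
    a = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(1024, 768, device="cuda", dtype=torch.bfloat16)
    torch.testing.assert_close(matmul(a, b, True), matmul(a, b, False))

A production kernel would typically avoid the explicit tl.trans in the inner loop, for example by loading B already transposed, which is the kind of placement detail the review comment above is pointing at.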

Signed-off-by: ZelinMa557 <3388706467@qq.com>
@ZelinMa557 (Author)

@BBuf Hi, sorry for the ping. This PR enables SwapAB for the bf16 MoE Triton kernel with only a few lines of change. Would you mind taking a quick look when you have a moment? Thanks!
