
[Kernel] Add Sonic MoE integration for Hopper GPUs#31548

Open
clocksmith wants to merge 10 commits into vllm-project:main from clocksmith:feature/sonic-moe-integration

Conversation

@clocksmith

@clocksmith clocksmith commented Dec 30, 2025

## Purpose

Integrate Sonic MoE for Hopper GPUs. (paper) 🎸 🚀

sonic_moe_h100

Weights are permuted from vLLM's concatenated format to Sonic's interleaved format during model loading (no inference runtime cost).
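
The permutation above can be sketched as follows (a minimal illustration; the helper name and the exact interleaved row order are assumptions, not the PR's actual code):

```python
import torch

def permute_concat_to_interleaved(w13: torch.Tensor) -> torch.Tensor:
    """Permute fused gate/up weights [E, 2*I, H] from vLLM's concatenated
    row order [g0..gI-1, u0..uI-1] to an interleaved order
    [g0, u0, g1, u1, ...]. Hypothetical helper: the exact row order Sonic
    expects may differ from this sketch."""
    e, two_i, h = w13.shape
    i = two_i // 2
    gate, up = w13[:, :i, :], w13[:, i:, :]
    # Stack gate/up row pairs along a new axis, then flatten so rows alternate.
    return torch.stack((gate, up), dim=2).reshape(e, two_i, h)
```

Since this runs once per expert weight at load time, it adds no per-token cost at inference.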

Addresses #31039

## Key Changes

## Usage

VLLM_USE_SONIC_MOE=1 vllm serve <model>

Requirements (per Dao AI Labs): Hopper GPU (H100/H200), CUDA 12.9+, Python 3.12+

### Compatible Models

Most MoE models using SwiGLU activation should work out of the box:

- Mixtral: mistralai/Mixtral-8x7B-v0.1
- DeepSeek: deepseek-ai/DeepSeek-V2
- Qwen-MoE: Qwen/Qwen2-57B-A14B, Qwen/Qwen3-MoE
- Others: DBRX, Jamba, OLMoE, PhiMoE, Grok1, GLM4-MoE, GraniteMoE, EXAONE-MoE, Llama4-MoE, and more under https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models

Not compatible:
- GPT-OSS: biased MoE layers (has_bias=True)
- Nemotron-H: non-gated MoE (is_act_and_mul=False)

### Gating Logic

Sonic MoE auto-disables when:
- Expert parallelism (EP) enabled
- MoE biases present
- FlashInfer CUTLASS MoE active
- Non-Hopper GPU
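
A condensed sketch of these auto-disable conditions (hypothetical function; the PR's actual check also covers dtype, activation, topk, and sequence parallelism):

```python
def sonic_moe_enabled(use_ep: bool, has_bias: bool,
                      flashinfer_cutlass_active: bool,
                      is_hopper: bool) -> bool:
    """Return True only when none of the auto-disable conditions hold.
    Simplified sketch of the gating listed above, not the PR's full check."""
    return (is_hopper
            and not use_ep
            and not has_bias
            and not flashinfer_cutlass_active)
```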

## Test Plan

pytest tests/kernels/moe/test_sonic_moe.py -v

- Platform detection, weight permutation, shape validation
- E2E comparison vs TritonExperts (48 cases, <1% diff)
- apply_router_weight_on_input parity test
- Skips gracefully on non-Hopper hardware
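
The <1% threshold refers to an aggregate relative-error comparison of the kind below (illustrative metric; the exact norm and reduction used in test_sonic_moe.py may differ):

```python
import torch

def rel_err(ref: torch.Tensor, out: torch.Tensor) -> float:
    """Aggregate relative error between a reference kernel's output and the
    kernel under test; a parity test would assert this stays below 0.01."""
    return ((ref - out).abs().sum() / ref.abs().sum()).item()
```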

### Test results

#### Correctness

(venv312) ubuntu@192-222-54-80:~/vllm$ pytest -v -s tests/kernels/moe/test_sonic_moe.py
================================================================================================================================= test session starts ==================================================================================================================================
platform linux -- Python 3.12.12, pytest-9.0.2, pluggy-1.6.0 -- /home/ubuntu/venv312/bin/python3
cachedir: .pytest_cache
rootdir: /home/ubuntu/vllm
configfile: pyproject.toml
plugins: anyio-4.12.1
collecting ... INFO 02-11 19:41:33 [sonic_moe.py:58] Sonic MoE is available
WARNING 02-11 19:41:33 [interface.py:586] Current platform cuda does not have '__test__' attribute.
WARNING 02-11 19:41:33 [interface.py:586] Current platform cuda does not have '__bases__' attribute.
WARNING 02-11 19:41:33 [interface.py:586] Current platform cuda does not have '__test__' attribute.
collected 26 items                                                                                                                                                                                                                                                                     

tests/kernels/moe/test_sonic_moe.py::test_check_sonicmoe_available PASSED
tests/kernels/moe/test_sonic_moe.py::test_is_hopper_gpu PASSED
tests/kernels/moe/test_sonic_moe.py::test_is_sonic_moe_supported PASSED
tests/kernels/moe/test_sonic_moe.py::test_permute_weights_for_sonic PASSED
tests/kernels/moe/test_sonic_moe.py::test_sonic_moe_experts_init PASSED
tests/kernels/moe/test_sonic_moe.py::test_is_valid_sonic_moe_basic PASSED
tests/kernels/moe/test_sonic_moe.py::test_is_valid_sonic_moe_large_topk PASSED
tests/kernels/moe/test_sonic_moe.py::test_sonic_moe_forward_unsupported SKIPPED (Sonic MoE is supported on this system)
tests/kernels/moe/test_sonic_moe.py::test_import_from_fused_moe PASSED
tests/kernels/moe/test_sonic_moe.py::test_sonic_moe_vs_triton[dtype0-8-2-256-2048-512] WARNING 02-11 19:41:34 [fused_moe.py:1086] Using default MoE config. Performance might be sub-optimal! Config file not found at /home/ubuntu/vllm/vllm/model_executor/layers/fused_moe/configs/E=8,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json
PASSED
tests/kernels/moe/test_sonic_moe.py::test_sonic_moe_vs_triton[dtype0-8-2-512-4096-1024] INFO 02-11 19:41:40 [fused_moe.py:1073] Using configuration from /home/ubuntu/vllm/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json for MoE layer.
PASSED
tests/kernels/moe/test_sonic_moe.py::test_sonic_moe_vs_triton[dtype0-8-4-256-2048-512] PASSED
tests/kernels/moe/test_sonic_moe.py::test_sonic_moe_vs_triton[dtype0-8-4-512-4096-1024] PASSED
tests/kernels/moe/test_sonic_moe.py::test_sonic_moe_vs_triton[dtype0-16-2-256-2048-512] WARNING 02-11 19:41:47 [fused_moe.py:1086] Using default MoE config. Performance might be sub-optimal! Config file not found at /home/ubuntu/vllm/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json
PASSED
tests/kernels/moe/test_sonic_moe.py::test_sonic_moe_vs_triton[dtype0-16-2-512-4096-1024] WARNING 02-11 19:41:52 [fused_moe.py:1086] Using default MoE config. Performance might be sub-optimal! Config file not found at /home/ubuntu/vllm/vllm/model_executor/layers/fused_moe/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json
PASSED
tests/kernels/moe/test_sonic_moe.py::test_sonic_moe_vs_triton[dtype0-16-4-256-2048-512] PASSED
tests/kernels/moe/test_sonic_moe.py::test_sonic_moe_vs_triton[dtype0-16-4-512-4096-1024] PASSED
tests/kernels/moe/test_sonic_moe.py::test_sonic_moe_vs_triton[dtype1-8-2-256-2048-512] PASSED
tests/kernels/moe/test_sonic_moe.py::test_sonic_moe_vs_triton[dtype1-8-2-512-4096-1024] PASSED
tests/kernels/moe/test_sonic_moe.py::test_sonic_moe_vs_triton[dtype1-8-4-256-2048-512] PASSED
tests/kernels/moe/test_sonic_moe.py::test_sonic_moe_vs_triton[dtype1-8-4-512-4096-1024] PASSED
tests/kernels/moe/test_sonic_moe.py::test_sonic_moe_vs_triton[dtype1-16-2-256-2048-512] PASSED
tests/kernels/moe/test_sonic_moe.py::test_sonic_moe_vs_triton[dtype1-16-2-512-4096-1024] PASSED
tests/kernels/moe/test_sonic_moe.py::test_sonic_moe_vs_triton[dtype1-16-4-256-2048-512] PASSED
tests/kernels/moe/test_sonic_moe.py::test_sonic_moe_vs_triton[dtype1-16-4-512-4096-1024] PASSED
tests/kernels/moe/test_sonic_moe.py::test_sonic_moe_apply_router_weight_on_input PASSED

=================================================================================================================================== warnings summary ===================================================================================================================================
<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

../venv312/lib/python3.12/site-packages/torch/jit/_script.py:362: 14 warnings
  /home/ubuntu/venv312/lib/python3.12/site-packages/torch/jit/_script.py:362: DeprecationWarning: `torch.jit.script_method` is deprecated. Please switch to `torch.compile` or `torch.export`.
    warnings.warn(

tests/kernels/moe/test_sonic_moe.py: 48 warnings
  /home/ubuntu/venv312/lib/python3.12/site-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/_mlir_helpers/op.py:60: DeprecationWarning: `make_fragment` is deprecated, use `make_rmem_tensor` instead
    res_or_list = opFunc(*args, **kwargs, loc=loc)

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
===================================================================================================================== 25 passed, 1 skipped, 64 warnings in 46.44s ======================================================================================================================
sys:1: DeprecationWarning: builtin type swigvarlink has no __module__ attribute

#### Benchmarks

(venv312) ubuntu@192-222-54-80:~/vllm$ python benchmarks/kernels/benchmark_sonic_moe.py --warmup 20 --iters 100 --output-json bench_stable.json
INFO 02-11 19:37:31 [sonic_moe.py:58] Sonic MoE is available
INFO 02-11 19:37:31 [fused_moe.py:1073] Using configuration from /home/ubuntu/vllm/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json for MoE layer.
INFO 02-11 19:37:41 [fused_moe.py:1073] Using configuration from /home/ubuntu/vllm/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json for MoE layer.
dtype,e,topk,m,k,n,triton_us,sonic_us,speedup,mode,rel_err
bf16,8,2,256,512,4096,310.45,709.39,0.438,eager/eager,0.0045
fp16,8,2,256,512,4096,312.93,714.18,0.438,eager/eager,0.0005
bf16,8,4,512,1024,8192,4496.62,713.64,6.301,eager/eager,0.0121
fp16,8,4,512,1024,8192,4433.30,714.53,6.204,eager/eager,0.0014
{
  "device": "NVIDIA H100 80GB HBM3",
  "results": [
    {
      "dtype": "bf16",
      "m": 256,
      "k": 512,
      "n": 4096,
      "e": 8,
      "topk": 2,
      "triton_us": 310.45312881469727,
      "sonic_us": 709.3862152099609,
      "speedup": 0.43763625815989493,
      "mode": "eager/eager",
      "rel_err": 0.00445556640625
    },
    {
      "dtype": "fp16",
      "m": 256,
      "k": 512,
      "n": 4096,
      "e": 8,
      "topk": 2,
      "triton_us": 312.92640686035156,
      "sonic_us": 714.1766357421875,
      "speedup": 0.4381638815937346,
      "mode": "eager/eager",
      "rel_err": 0.0005464553833007812
    },
    {
      "dtype": "bf16",
      "m": 512,
      "k": 1024,
      "n": 8192,
      "e": 8,
      "topk": 4,
      "triton_us": 4496.620178222656,
      "sonic_us": 713.6374664306641,
      "speedup": 6.300986691061772,
      "mode": "eager/eager",
      "rel_err": 0.01214599609375
    },
    {
      "dtype": "fp16",
      "m": 512,
      "k": 1024,
      "n": 8192,
      "e": 8,
      "topk": 4,
      "triton_us": 4433.2989501953125,
      "sonic_us": 714.5299530029297,
      "speedup": 6.20449700052971,
      "mode": "eager/eager",
      "rel_err": 0.0013608932495117188
    }
  ]
}

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request integrates Sonic MoE for Hopper GPUs, which is a great performance enhancement. The implementation is well-structured, following existing patterns in vLLM for MoE layers, and includes necessary components like weight permutation and support-checking utilities. The accompanying tests cover the new logic, including platform detection and weight permutation correctness. I've identified a couple of areas for improvement: one is a performance optimization to cache the created MoE kernel, and the other is a potential logic issue in one of the new tests. Overall, this is a solid contribution.

@clocksmith clocksmith mentioned this pull request Dec 30, 2025
1 task
@clocksmith clocksmith force-pushed the feature/sonic-moe-integration branch from 4022ab9 to cdc35ec Compare December 30, 2025 22:19
@clocksmith clocksmith marked this pull request as ready for review December 30, 2025 23:08
@clocksmith
Author

Requested a review to run the sonicmoe-related tests in the buildkite/ci/pr CI item.

@robertgshaw2-redhat
Collaborator

robertgshaw2-redhat commented Dec 31, 2025

In this PR, the kernel is not hooked up to any MoeMethod, so it cannot be run by any users other than in the unit test

@clocksmith
Author

clocksmith commented Dec 31, 2025

In this PR, the kernel is not hooked up to any MoeMethod, so it cannot be run by any users other than in the unit test

EDIT:

Correct, this is intentional: kernel and unit tests first to get implementation reviewed.

You are right, though. I can add the wiring into UnquantizedFusedMoEMethod (behind VLLM_USE_SONIC_MOE) to this PR now. The blocker is that sonicmoe isn't in the CI deps yet, so the wiring can't be e2e tested until that's added.

Happy to add the wiring now, just need sonicmoe added to CI deps before this should merge so we can validate properly.

Can a maintainer add sonicmoe to CI, or should I open a separate PR for that?

UPDATE: @robertgshaw2-redhat added the wiring, PTAL. Still pending the sonicmoe dep for testing.

Member

@zyongye zyongye left a comment

Thank you for your contribution. I think the biggest issue is our current framework can't take in router logits directly and needs to do additional bookkeeping to generate the metadata (and that process is not optimized). I think we can leave that for now and wait for the MoE refactor to complete and change this PR accordingly.

The moe refactor issue tracker is here

_up_projection_forward(
    x=hidden_states,
    w1=w1_sonic,
    z=z,
Member

We probably don't need this since they fuse swiglu with up proj.

Author

@clocksmith clocksmith Dec 31, 2025

I think this is needed: _up_projection_forward appears to perform both the up projection and swiglu in the fused kernel (which is usually more efficient, and core to Sonic MoE on Hopper).

swiglu is not being applied separately; is this correct?
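
For reference, the unfused SwiGLU computation that the fused up-projection kernel is said to subsume (a plain PyTorch sketch under that assumption, not Sonic's implementation):

```python
import torch
import torch.nn.functional as F

def swiglu_unfused(x: torch.Tensor, w_gate: torch.Tensor,
                   w_up: torch.Tensor) -> torch.Tensor:
    """Unfused reference: silu(x @ w_gate^T) * (x @ w_up^T). A kernel that
    fuses SwiGLU with the up projection computes both matmuls and the
    activation in a single pass instead of three separate ops."""
    return F.silu(x @ w_gate.t()) * (x @ w_up.t())
```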

@clocksmith
Author

Thank you for your contribution. I think the biggest issue is our current framework can't take in router logits directly and needs to do additional bookkeeping to generate the metadata (and that process is not optimized). I think we can leave that for now and wait for the MoE refactor to complete and change this PR accordingly.

The moe refactor issue tracker is here

Thanks! Got it, is it common within vllm to use comments like TODO(#31578): Fix this. in all the additional bookkeeping areas in the current impl?

@zyongye
Member

zyongye commented Dec 31, 2025

Thank you for your contribution. I think the biggest issue is our current framework can't take in router logits directly and needs to do additional bookkeeping to generate the metadata (and that process is not optimized). I think we can leave that for now and wait for the MoE refactor to complete and change this PR accordingly.
The moe refactor issue tracker is here

Thanks! Got it, is it common within vllm to use comments like TODO(#31578): Fix this. in all the additional bookkeeping areas in the current impl?

Yeah I think it's fine for now

@clocksmith clocksmith force-pushed the feature/sonic-moe-integration branch from 8c5f5af to 1479564 Compare December 31, 2025 19:15
@mergify

mergify bot commented Dec 31, 2025

Hi @clocksmith, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

@clocksmith clocksmith force-pushed the feature/sonic-moe-integration branch from 1479564 to f4d8a42 Compare December 31, 2025 19:24

@clocksmith clocksmith force-pushed the feature/sonic-moe-integration branch 2 times, most recently from cf15b99 to b665cb3 Compare January 1, 2026 02:05
@clocksmith clocksmith changed the title [Kernel] Add Sonic MoE integration for Hopper GPUs with swiglu support [Kernel] Add Sonic MoE integration for Hopper GPUs Jan 1, 2026
@clocksmith
Author

This should simplify the testing issue so this PR can be tested in place (kernel), potentially with e2e following: #31606

@mergify

mergify bot commented Jan 3, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @clocksmith.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@clocksmith clocksmith force-pushed the feature/sonic-moe-integration branch from 27c689b to 00d9a35 Compare January 3, 2026 16:06
@mergify mergify bot removed the needs-rebase label Jan 3, 2026
@clocksmith
Author

clocksmith commented Jan 3, 2026

Merged and updated, with the CI build step added to this PR directly. Closed #31606 (review)

Can't wait to see how this benchmarks!




@clocksmith
Author

clocksmith commented Feb 10, 2026

Still a mismatch that I am debugging, probably in the activation, but getting closer:

FAILED tests/kernels/moe/test_sonic_moe.py::test_sonic_moe_vs_triton[dtype0-8-2-256-2048-512] - AssertionError: Diff exceeded 1%: 0.03598732530068127



@clocksmith
Author

WOOHOO!

dtype,e,topk,m,k,n,triton_us,sonic_us,speedup,mode,rel_err
  bf16,8,2,256,512,4096,313.71,710.66,0.441,eager/eager,0.0045
  fp16,8,2,256,512,4096,313.56,709.20,0.442,eager/eager,0.0005
  bf16,8,4,512,1024,8192,4510.59,716.38,6.296,eager/eager,0.0121
  fp16,8,4,512,1024,8192,4447.49,727.15,6.116,eager/eager,0.0014
  {
   "device": "NVIDIA H100 80GB HBM3",
   "results": [
    {
     "dtype": "bf16",
     "m": 256,
     "k": 512,
     "n": 4096,
     "e": 8,
     "topk": 2,
     "triton_us": 313.7094306945801,
     "sonic_us": 710.6578826904297,
     "speedup": 0.4414352367512334,
     "mode": "eager/eager",
     "rel_err": 0.00445556640625
    },
    {
     "dtype": "fp16",
     "m": 256,
     "k": 512,
     "n": 4096,
     "e": 8,
     "topk": 2,
     "triton_us": 313.56224060058594,
     "sonic_us": 709.2038726806641,
     "speedup": 0.4421327247063339,
     "mode": "eager/eager",
     "rel_err": 0.0005464553833007812
    },
    {
     "dtype": "bf16",
     "m": 512,
     "k": 1024,
     "n": 8192,
     "e": 8,
     "topk": 4,
     "triton_us": 4510.5926513671875,
     "sonic_us": 716.3763427734375,
     "speedup": 6.2964009027776005,
     "mode": "eager/eager",
     "rel_err": 0.01214599609375
    },
    {
     "dtype": "fp16",
     "m": 512,
     "k": 1024,
     "n": 8192,
     "e": 8,
     "topk": 4,
     "triton_us": 4447.4859619140625,
     "sonic_us": 727.1507263183594,
     "speedup": 6.116319218206866,
     "mode": "eager/eager",
     "rel_err": 0.0013608932495117188
    }
   ]
  }

H100 benchmark update (tuned Triton configs, eager/eager):

 - e=8, topk=2, m=256, k=512, n=4096
   - bf16: Triton 313.71us, Sonic 710.66us (0.44x)
   - fp16: Triton 313.56us, Sonic 709.20us (0.44x)
 - e=8, topk=4, m=512, k=1024, n=8192
   - bf16: Triton 4510.59us, Sonic 716.38us (6.30x)
   - fp16: Triton 4447.49us, Sonic 727.15us (6.12x)

X added 10 commits February 15, 2026 16:35
Signed-off-by: X <x@simulatte.world>
Signed-off-by: X <x@simulatte.world>
Signed-off-by: X <x@simulatte.world>
Signed-off-by: X <x@simulatte.world>
Signed-off-by: X <x@simulatte.world>
Signed-off-by: X <x@simulatte.world>
Signed-off-by: X <x@simulatte.world>
Signed-off-by: X <x@simulatte.world>
@clocksmith
Author

I cannot get into the vLLM Slack; what should be done with this PR? Can it be safely merged and left unused, with a decision on how to route to prefill-only to follow?

@mergify

mergify bot commented Mar 16, 2026

Hi @clocksmith, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

Member

@zyongye zyongye left a comment

A few minor comments, but the structure LGTM. cc @yzong-rh, who is working on the unquantized oracle refactor #36732. Can we try to merge this first?

Comment on lines +87 to +130
sonic_requested = envs.VLLM_USE_SONIC_MOE
sonic_supported = False
if sonic_requested:
    from vllm.model_executor.layers.fused_moe.sonic_moe import (
        is_sonic_moe_supported,
    )

    sonic_supported = is_sonic_moe_supported()
sonic_enabled = (
    sonic_supported
    and sonic_requested
    and is_act_and_mul
    and not has_bias
    and not use_ep
    and not moe_config.moe_parallel_config.is_sequence_parallel
    and moe_config.experts_per_token <= 16
    and moe_config.in_dtype in (torch.float16, torch.bfloat16)
    and moe_config.activation in ("silu", "silu_and_mul")
)
if sonic_requested and sonic_supported and not sonic_enabled:
    if use_ep:
        logger.debug_once(
            "Sonic MoE disabled because expert parallelism is enabled."
        )
    elif has_bias:
        logger.debug_once("Sonic MoE disabled because MoE biases are enabled.")
    elif not is_act_and_mul:
        logger.debug_once("Sonic MoE disabled because is_act_and_mul is False.")
    elif moe_config.moe_parallel_config.is_sequence_parallel:
        logger.debug_once(
            "Sonic MoE disabled because sequence parallelism is enabled."
        )
    elif moe_config.experts_per_token > 16:
        logger.debug_once("Sonic MoE disabled because topk > 16.")
    elif moe_config.in_dtype not in (torch.float16, torch.bfloat16):
        logger.debug_once(
            "Sonic MoE disabled because input dtype is unsupported: %s",
            moe_config.in_dtype,
        )
    elif moe_config.activation not in ("silu", "silu_and_mul"):
        logger.debug_once(
            "Sonic MoE disabled because activation is unsupported: %s",
            moe_config.activation,
        )
Member

Can we put all of this in sonic_moe.py?


Yeah, it would be great if you could move this into SonicMoeExperts::_supports_current_device, _supports_no_act_and_mul, etc. and rely on FusedMoEExperts::is_supported_config check for Sonic support

Comment on lines +95 to +101
- pip install -r /tmp/sonic-moe/requirements.txt
# Override SonicMoE pins: older CUTLASS DSL can generate invalid LLVM IR on H100;
# newer CUTLASS DSL triggers import-time annotation evaluation errors in SonicMoE.
- pip install -U nvidia-cutlass-dsl==4.3.5
- pip install -U --force-reinstall --no-deps git+https://github.com/Dao-AILab/quack.git@4210c0abcb20a7126775661640e79de425a55206
- pip install -e /tmp/sonic-moe --no-deps
- pytest -v -s kernels/moe/test_sonic_moe.py
Member

I don't know if this installation will mess up other packages. Since it's not enabled by default, maybe we don't need it in the CI?

Member

f"got {activation}"
)

w1_sonic, w2_sonic = self._ensure_weights_ready(w1, w2)
Member

Can we move it out from the critical path? It can be done completely at weight loading time

) from None


def sonic_moe_forward(
Member

Is this function solely used for test?

@zyongye
Member

zyongye commented Mar 18, 2026

Also please fix the pre-commit

@yzong-rh yzong-rh left a comment

Thank you for the contribution.

There has been some refactoring which caused those pre-commit complaints.

  • FusedMoEModularKernel -> FusedMoEKernel (now supports both modular and monolithic)
  • MoEPrepareAndFinalizeNoEP -> MoEPrepareAndFinalizeNoDPEP (more accurate naming)
  • FusedMoEPermuteExpertsUnpermute -> FusedMoEExpertsModular or FusedMoEExpertsMonolithic

I think a lot of it was done in #32564
FlashInferExperts is a good example of the new flow.



Labels

ci/build performance Performance-related issues

Projects

Status: No status
Status: To Triage

Development

Successfully merging this pull request may close these issues.

7 participants