Skip to content

[ROCm][Perf] Add Fused Shared Expert (FSE) support for Qwen3-Next#39280

Merged
robertgshaw2-redhat merged 15 commits into
vllm-project:mainfrom
nholmber:pr/fse-qwen3next-v2
May 8, 2026
Merged

[ROCm][Perf] Add Fused Shared Expert (FSE) support for Qwen3-Next#39280
robertgshaw2-redhat merged 15 commits into
vllm-project:mainfrom
nholmber:pr/fse-qwen3next-v2

Conversation

@nholmber
Copy link
Copy Markdown
Contributor

@nholmber nholmber commented Apr 8, 2026

Purpose

Fuse shared expert into the AITER MoE kernel as an extra expert slot when VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS=1, eliminating the separate shared expert MLP forward pass and greatly improving decode throughput.

The router gate [num_experts, hidden] and shared expert gate [num_shared, hidden] weight matrices are fused into a single [num_experts + num_shared, hidden] matrix at init. One F.linear call produces combined logits, and the topk_softmax kernel applies routing softmax and shared expert activation (sigmoid) in a single launch; no extra kernel launches for the shared expert gate projection, activation, or buffer copy.

Changes:

  • qwen3_next.py: Model-level FSE wiring (init, weight loading, expert mapping, forward tuple unpack for SharedFusedMoE compatibility)
  • qwen3_next_mtp.py: MTP weight loading for fused expert slot
  • moe_runner_base.py: Lazy gate weight fusion in forward_dispatch(); thread num_fused_shared_experts through routing
  • _aiter_ops.py: Extend topk_softmax with num_shared_experts and shared_expert_scoring_func params; add runtime version check for graceful fallback with older AITER
  • fused_topk_router.py: Fused kernel dispatch path + non-fused fallback (separate softmax, sigmoid, inject)
  • base_router.py + router subclasses: Add num_fused_shared_experts param to _compute_routing() interface
  • rocm_aiter_fused_moe.py: inject_shared_expert_weights() for merging routed topk results with the shared expert buffer

Test Plan

  • Model: Qwen/Qwen3-Next-80B-A3B-Instruct-FP8
  • Container: vllm/vllm-openai-rocm:v0.19.0
  • Hardware: MI355X, ROCm 7.2.1
  • AITER: 02d8af55e (with 7-arg topk_softmax support), stock version that uses non-fused topk + sigmoid also tested
  • Accuracy: GSM8K 8-shot flexible-extract (FSE=0 baseline vs FSE=1)
  • Throughput: vllm bench serve, random 1k input / 1k output at c4/c8/c16/c32

Sample commands

docker run --name fse-test -d \
  --device /dev/dri --device /dev/kfd \
  --group-add video --ipc host --network host \
  --security-opt seccomp=unconfined --shm-size 64G \
  --entrypoint "" \
  -e HIP_VISIBLE_DEVICES=0 \
  -v $HOME/.cache/huggingface:/root/.cache/huggingface \
  vllm/vllm-openai-rocm:v0.19.0 sleep infinity
# Install this branch
docker exec -it fse-test bash
pip install git+https://github.com/nholmber/vllm.git@pr/fse-qwen3next-v2 --no-build-isolation
pip install lm_eval[api]

# Start the server
VLLM_ROCM_USE_AITER=1 \
VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS=<0|1> \
vllm serve Qwen/Qwen3-Next-80B-A3B-Instruct-FP8 \
  --gpu-memory-utilization 0.95 \
  --max-model-len 16384 \
  --max-num-seqs 256 \
  --attention-backend ROCM_AITER_FA \
  --compilation-config '{"cudagraph_mode":"FULL_AND_PIECEWISE"}'

Test Result

Accuracy (lm_eval GSM8K 8-shot, flexible-extract)

Config Score Stderr
TP1 FSE=0 0.8537 ±0.0097
TP1 FSE=1 0.8567 ±0.0097
TP2 FSE=0 0.8484 ±0.0099
TP2 FSE=1 0.8522 ±0.0098

Verdict: All deltas within standard error. No accuracy regression.

Throughput (output tok/s, 1k input / 1k output)

TP1:

Concurrency FSE=0 (tok/s) FSE=1 (tok/s) Speedup
4 458.5 557.0 +21.5%
8 854.9 1037.4 +21.3%
16 1535.8 1855.6 +20.8%
32 2632.4 3060.3 +16.3%

TP2:

Concurrency FSE=0 (tok/s) FSE=1 (tok/s) Speedup
4 463.5 568.0 +22.5%
8 855.6 1048.8 +22.6%
16 1638.0 2026.8 +23.7%
32 2930.0 3557.1 +21.4%

Verdict: FSE provides +16–24% output throughput improvement across
concurrency levels and TP configurations.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@mergify mergify Bot added qwen Related to Qwen models rocm Related to AMD ROCm labels Apr 8, 2026
@github-project-automation github-project-automation Bot moved this to Todo in AMD Apr 8, 2026
@vadiklyutiy vadiklyutiy removed their request for review April 8, 2026 08:33
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements Fused Shared Expert (FSE) support for ROCm AITER MoE kernels, specifically targeting Qwen3Next models. The changes include a new weight injection mechanism in the MoE runner and logic to remap shared expert weights to fused expert slots during model loading. Review feedback highlights two critical issues in qwen3_next.py: a crash-inducing logic error when unpacking the output of SharedFusedMoE in the default case, and a potential TypeError caused by passing None instead of 0 for the number of shared experts.

Comment on lines +196 to +199
if self.shared_expert is not None:
final_hidden_states = final_hidden_states[0] + final_hidden_states[1]
elif self.is_fse_enabled:
_, final_hidden_states = final_hidden_states
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The current logic for unpacking the result from SharedFusedMoE is broken for the default case where both shared_expert and is_fse_enabled are False. Since SharedFusedMoE.forward always returns a tuple (either (shared, fused) or (None, fused)), final_hidden_states will remain a tuple if both conditions are False, causing a crash in the subsequent .view() call or all_gather operation. The logic should be simplified to always unpack the second element when shared_expert is None.

Suggested change
if self.shared_expert is not None:
final_hidden_states = final_hidden_states[0] + final_hidden_states[1]
elif self.is_fse_enabled:
_, final_hidden_states = final_hidden_states
if self.shared_expert is not None:
final_hidden_states = final_hidden_states[0] + final_hidden_states[1]
else:
_, final_hidden_states = final_hidden_states

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, it's fixed

enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
n_shared_experts=1 if self.is_fse_enabled else None,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Passing None for n_shared_experts when FSE is disabled will cause a TypeError in moe_runner_base.py during the comparison if num_fused_shared > 0:. It should default to 0 instead of None to ensure compatibility with the runner's logic and the AITER metadata initialization.

Suggested change
n_shared_experts=1 if self.is_fse_enabled else None,
n_shared_experts=1 if self.is_fse_enabled else 0,

assert shared_experts_input is not None
self._shared_experts.apply(shared_experts_input, order)

def _inject_fse_weights(
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't quite prefer injection. And this fused expert is not a new feature, it was first introduced in DeepSeekV3

Can you try to implement following the approach taken by DeepSeek
https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/deepseek_v2.py
?

Another thing, I would also like @robertgshaw2-redhat feedback regarding to this PR.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the feedback. Looking into refactoring this to use the same approach taken in Deepseek

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't quite prefer injection. And this fused expert is not a new feature, it was first introduced in DeepSeekV3

Can you try to implement following the approach taken by DeepSeek https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/deepseek_v2.py ?

Another thing, I would also like @robertgshaw2-redhat feedback regarding to this PR.

I'd vote for the DeepSeek approach as well.

@nholmber nholmber force-pushed the pr/fse-qwen3next-v2 branch 2 times, most recently from 6493060 to 554600d Compare April 13, 2026 10:56
@nholmber
Copy link
Copy Markdown
Contributor Author

@tjtanaa the PR has been revised and description updated. Could you review it again? @ChuanLi1101 @dllehr-amd could you also take a look?

The PR now covers:

  • Fuse gate projection for shared and routed experts
  • Fuse shared expert scoring function into routed expert topk-softmax (new AITER kernel with fallback)
  • Fuse shared expert into routed experts for MoE

Coming back to your question about re-using the DeepSeekV3.2 shared expert fusion, the main difference is that Qwen3-Next has a learned shared_expert_gate (a per-token sigmoid gate on the shared expert output), whereas DeepSeek always includes the shared expert with weight 1.0. This gate is why we need the first two optimizations: fusing the gate projection into the router matmul and fusing the sigmoid activation into the topk kernel.

Note on code placement. The changes follow the existing runner/router separation rather than living in FusedMoE.apply():

  • Gate fusion → runner (moe_runner_base.py): the runner already owns the gate modules
  • Fused scoring → router (fused_topk_router.py): routing/expert selection is the router's responsibility
  • Expert computation (apply()) is untouched: it receives the same (topk_weights, topk_ids) interface regardless of whether they came from fused or separate kernels

@nholmber nholmber requested a review from tjtanaa April 13, 2026 11:10
@nholmber nholmber force-pushed the pr/fse-qwen3next-v2 branch from 554600d to c49851e Compare April 23, 2026 14:41
@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented Apr 23, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @nholmber.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Apr 23, 2026
@nholmber nholmber force-pushed the pr/fse-qwen3next-v2 branch from c49851e to 9110fd9 Compare April 23, 2026 21:50
@mergify mergify Bot removed the needs-rebase label Apr 23, 2026
@nholmber nholmber force-pushed the pr/fse-qwen3next-v2 branch from 9110fd9 to 3344962 Compare April 23, 2026 21:52
@tjtanaa
Copy link
Copy Markdown
Member

tjtanaa commented May 4, 2026

@nholmber can you help to rebase the PR. Thanks.

@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented May 4, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @nholmber.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label May 4, 2026
@nholmber nholmber force-pushed the pr/fse-qwen3next-v2 branch from 3344962 to 5e26cf4 Compare May 4, 2026 14:18
@mergify mergify Bot removed the needs-rebase label May 4, 2026
@nholmber
Copy link
Copy Markdown
Contributor Author

nholmber commented May 4, 2026

Rebased

@nholmber nholmber force-pushed the pr/fse-qwen3next-v2 branch from 2293eae to 66fe572 Compare May 4, 2026 22:01
Signed-off-by: Doug Lehr <douglehr@amd.com>
@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented May 8, 2026

Hi @nholmber, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

…outer

Move the import of aiter_topK_meta_data from module level into the
_compute_routing method body. The module-level import captured the
initial None value and never saw the reassignment by
init_aiter_topK_meta_data, causing shared expert weights to be silently
dropped and a ~33 point accuracy regression on gsm8k.

Also remove unused fse_fuse_gate variable in layer.py and fix E501 line
length in router_factory.py.

Signed-off-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented May 8, 2026

Hi @nholmber, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

Signed-off-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
n_shared_experts=1,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tpopp We are binding the n_shared_experts and shared_expert_gate to FusedMoE now without any checks. This may be the cause of the fail?

When FSE is disabled (non-ROCm or env var off), the shared expert
is handled by the model's own MLP. Passing shared_expert_gate to
FusedMoE in that case caused _fse_fuse_gate to activate, fusing
gate weights into [num_experts+1, hidden] and corrupting routing.

Set shared_expert_gate=None and n_shared_experts=None in the
non-FSE path so FusedMoE does not attempt gate fusion.

Fixes test_hybrid[tiny-random/qwen3-next-moe] regression.

Signed-off-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
if n_shared_experts is not None and self.aiter_fmoe_shared_expert_enabled
else 0
)
self.shared_expert_gate = shared_expert_gate
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems unnessrary to have this attribute?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm going to wait for CI to finish before pushing anything else. I'm happy to remove it. This is consistent with some other attributes that aren't used elsewhere and that was the reason for this. I thought there might be debugging or other reasons that most construction args are saved as attributes.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i know, i hate all those old attrs since it makes it hard to tell what "owns" the object

)

shared_weights = torch.sigmoid(shared_logits)
topk_weights, topk_ids = inject_shared_expert_weights(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems ot me this inject_shared_experts_weight function should be defined in this file

)

if (
num_fused_shared_experts > 0
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what happens if num_fused_shared_experts > 0 and either scoring_func != softmax or is not aiter?

should we just reject?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

currently we take FusedTopKRouter. Which is what happened prior as well. So I think we're okay on that front. It's not a change in behavior in the router unless the specific 3 conditions here are met

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please open a github issue to audit and guard this for future so we have a clear view of what does and does not work

@tpopp
Copy link
Copy Markdown
Contributor

tpopp commented May 8, 2026

@robertgshaw2-redhat I've created #42088. Can you or @dllehr-amd assign it to me?

@robertgshaw2-redhat
Copy link
Copy Markdown
Collaborator

test failures unreleated. passes all key moe tests

@robertgshaw2-redhat robertgshaw2-redhat merged commit 2c6b59b into vllm-project:main May 8, 2026
28 of 80 checks passed
@github-project-automation github-project-automation Bot moved this from Todo to Done in AMD May 8, 2026
weifang231 pushed a commit to weifang231/eb-vllm that referenced this pull request May 13, 2026
…lm-project#39280)

Signed-off-by: nholmber <nholmber@users.noreply.github.com>
Signed-off-by: Tres Popp <tres.popp@amd.com>
Signed-off-by: Doug Lehr <douglehr@amd.com>
Co-authored-by: nholmber <nholmber@users.noreply.github.com>
Co-authored-by: Tres <tpopp@users.noreply.github.com>
Co-authored-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Doug Lehr <douglehr@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Douglas Lehr <91553416+dllehr-amd@users.noreply.github.com>
mfylcek pushed a commit to mfylcek/vllm that referenced this pull request May 19, 2026
…lm-project#39280)

Signed-off-by: nholmber <nholmber@users.noreply.github.com>
Signed-off-by: Tres Popp <tres.popp@amd.com>
Signed-off-by: Doug Lehr <douglehr@amd.com>
Co-authored-by: nholmber <nholmber@users.noreply.github.com>
Co-authored-by: Tres <tpopp@users.noreply.github.com>
Co-authored-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Doug Lehr <douglehr@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Douglas Lehr <91553416+dllehr-amd@users.noreply.github.com>
jhu960213 pushed a commit to jhu960213/vllm that referenced this pull request May 20, 2026
…lm-project#39280)

Signed-off-by: nholmber <nholmber@users.noreply.github.com>
Signed-off-by: Tres Popp <tres.popp@amd.com>
Signed-off-by: Doug Lehr <douglehr@amd.com>
Co-authored-by: nholmber <nholmber@users.noreply.github.com>
Co-authored-by: Tres <tpopp@users.noreply.github.com>
Co-authored-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Doug Lehr <douglehr@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Douglas Lehr <91553416+dllehr-amd@users.noreply.github.com>
mvanhorn pushed a commit to mvanhorn/vllm that referenced this pull request Jun 4, 2026
…lm-project#39280)

Signed-off-by: nholmber <nholmber@users.noreply.github.com>
Signed-off-by: Tres Popp <tres.popp@amd.com>
Signed-off-by: Doug Lehr <douglehr@amd.com>
Co-authored-by: nholmber <nholmber@users.noreply.github.com>
Co-authored-by: Tres <tpopp@users.noreply.github.com>
Co-authored-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Doug Lehr <douglehr@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Douglas Lehr <91553416+dllehr-amd@users.noreply.github.com>
Signed-off-by: Matt Van Horn <455140+mvanhorn@users.noreply.github.com>
omirosh added a commit to omirosh/vllm that referenced this pull request Jun 5, 2026
## Purpose

Extend the AITER Fused Shared Expert (FSE) path - originally added for
DeepSeek-V2/V3 (vllm-project#28540) and Qwen3-Next (vllm-project#39280) - to the GLM-4 MoE family
(GLM-4.5, GLM-4.6, GLM-4.7). When `VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS=1`
the shared expert is folded into the AITER FusedMoE kernel as
`n_shared_experts` extra expert slots, eliminating the separate shared-expert
MLP forward pass at low/medium concurrency.

## Changes

Single-file model wiring in `vllm/model_executor/models/glm4_moe.py`, mirroring
the canonical `deepseek_v2.py` FSE pattern:

* `Glm4MoE.__init__`
  - Cache `is_rocm_aiter_moe_enabled` and `is_fusion_moe_shared_experts_enabled`
    from `rocm_aiter_ops`.
  - When FSE is enabled, skip building the separate `shared_experts` MLP and
    pass `n_shared_experts=config.n_shared_experts` to `FusedMoE` so the
    AITER kernel routes the shared expert(s) as extra slots in the routed
    tensor.
  - Switch `apply_routed_scale_to_output` to
    `not self.is_rocm_aiter_moe_enabled`. AITER applies `routed_scaling_factor`
    internally, per routed slot; applying it again post-fusion would also
    scale the FSE shared-expert slot (which the kernel inserts with unit
    weight), producing a structural magnitude error in every MoE layer.
    This matches `deepseek_v2.py`. (`routed_scaling_factor=2.5` for GLM-4.7,
    so the unfixed path showed a ~48 pp gsm8k regression.)

* `Glm4MoeModel.get_expert_mapping`
  - Widen `num_experts` by `config.n_shared_experts` when FSE is on so the
    weight loader enumerates the appended slots.

* `Glm4MoeModel.load_weights`
  - Treat `mlp.shared_experts.{gate,up,down}_proj.*` as expert-style tensors
    when FSE is on (skip the stacked QKV/gate_up linear path).
  - Split each widened shared-expert tensor into `n_shared_experts` chunks
    along the intermediate-size axis (dim 0 for ColumnParallel
    gate/up_proj, dim 1 for RowParallel down_proj) and route each chunk to
    `mlp.experts.{n_routed_experts + j}.*` via the FusedMoE expert-aware
    weight loader.

No changes to FusedMoE / AITER plumbing - all of that landed earlier with
vllm-project#39280 (Qwen3-Next FSE).

## Test Plan

* Model: `zai-org/GLM-4.7-FP8`
* Hardware: 1x MI355X node, TP=4
* Container: ROCm vLLM image (AITER >= v0.1.13.post1, PR vllm-project#44265)
* Accuracy: `lm_eval --tasks gsm8k --num_fewshot 5`
* Throughput: `vllm bench serve --dataset-name random` sweep over
  (ISL, OSL, MC) in {1000/100, 5000/500, 10000/1000} x {4, 16, 64}

Server launch:

```
VLLM_ROCM_USE_AITER=1 \
VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS=<0|1> \
vllm serve zai-org/GLM-4.7-FP8 \
  --tensor-parallel-size 4 \
  --gpu-memory-utilization 0.92 \
  --max-model-len 32768 \
  --max-num-seqs 256
```

## Test Result

### Accuracy (gsm8k, 5-shot, exact_match)

| Config              | flexible-extract | strict-match     |
|---------------------|-----------------:|-----------------:|
| FSE=0 (baseline)    | 0.9469 ± 0.0062  | 0.9439 ± 0.0063  |
| FSE=1               | 0.9439 ± 0.0063  | 0.9416 ± 0.0065  |

All deltas within standard error. No accuracy regression.

### Throughput (`vllm bench serve`, random)

| ISL  | OSL  | MC | TPOT mean (ms) FSE=0 -> FSE=1 (Δ) | TPOT p99 (ms) FSE=0 -> FSE=1 (Δ) | Output tok/s FSE=0 -> FSE=1 (Δ) | Total tok/s FSE=0 -> FSE=1 (Δ) |
|-----:|-----:|---:|----------------------------------:|---------------------------------:|--------------------------------:|-------------------------------:|
|  1000|   100|   4| 17.76 -> 14.36  (**-19.2%**)      | 19.43 -> 15.93 (**-18.0%**)      | 199.4 -> 243.6  (**+22.1%**)    | 2193.7 -> 2679.1 (**+22.1%**)  |
|  1000|   100|  16| 20.96 -> 18.48  (**-11.9%**)      | 24.29 -> 22.77 (-6.3%)           | 631.0 -> 673.4  (**+6.7%**)     | 6940.6 -> 7407.9 (**+6.7%**)   |
|  1000|   100|  64| 30.74 -> 30.23  (-1.7%)           | 42.85 -> 43.44 (+1.4%)           | 1452.7 -> 1424.3 (-2.0%)        | 15980.1 -> 15667.6 (-2.0%)     |
|  5000|   500|   4| 17.82 -> 14.50  (**-18.7%**)      | 18.63 -> 15.50 (**-16.8%**)      | 211.5 -> 253.5  (**+19.9%**)    | 2326.1 -> 2788.7 (**+19.9%**)  |
|  5000|   500|  16| 22.73 -> 20.76  (**-8.7%**)       | 25.38 -> 23.07 (**-9.1%**)       | 619.1 -> 657.7  (**+6.2%**)     | 6810.4 -> 7234.6 (**+6.2%**)   |
|  5000|   500|  64| 39.79 -> 40.15  (+0.9%)           | 46.15 -> 46.78 (+1.4%)           | 1363.8 -> 1339.1 (-1.8%)        | 15001.9 -> 14730.4 (-1.8%)     |
| 10000|  1000|   4| 18.00 -> 14.70  (**-18.3%**)      | 18.68 -> 15.50 (**-17.0%**)      | 210.3 -> 251.8  (**+19.7%**)    | 2313.5 -> 2769.4 (**+19.7%**)  |
| 10000|  1000|  16| 24.47 -> 22.87  (-6.5%)           | 26.66 -> 25.56 (-4.1%)           | 589.6 -> 615.1  (**+4.3%**)     | 6485.6 -> 6766.2 (**+4.3%**)   |
| 10000|  1000|  64| 46.37 -> 46.33  (-0.1%)           | 51.14 -> 51.78 (+1.3%)           | 1233.6 -> 1211.9 (-1.8%)        | 13570.0 -> 13330.7 (-1.8%)     |

Verdict: FSE delivers +20-22% output throughput and -18-19% TPOT at low
concurrency (MC=4), modest gains at MC=16, and is roughly break-even
(<2% regression) at MC=64. No accuracy regression.

Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Olga Miroshnichenko <olga.miroshnichenko@amd.com>
omirosh added a commit to omirosh/vllm that referenced this pull request Jun 5, 2026
## Purpose

Extend the AITER Fused Shared Expert (FSE) path - originally added for
DeepSeek-V2/V3 (vllm-project#28540) and Qwen3-Next (vllm-project#39280) - to the GLM-4 MoE family
(GLM-4.5, GLM-4.6, GLM-4.7). When `VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS=1`
the shared expert is folded into the AITER FusedMoE kernel as
`n_shared_experts` extra expert slots, eliminating the separate shared-expert
MLP forward pass at low/medium concurrency.

## Changes

Single-file model wiring in `vllm/model_executor/models/glm4_moe.py`, mirroring
the canonical `deepseek_v2.py` FSE pattern:

* `Glm4MoE.__init__`
  - Cache `is_rocm_aiter_moe_enabled` and `is_fusion_moe_shared_experts_enabled`
    from `rocm_aiter_ops`.
  - When FSE is enabled, skip building the separate `shared_experts` MLP and
    pass `n_shared_experts=config.n_shared_experts` to `FusedMoE` so the
    AITER kernel routes the shared expert(s) as extra slots in the routed
    tensor.
  - Switch `apply_routed_scale_to_output` to
    `not self.is_rocm_aiter_moe_enabled`. AITER applies `routed_scaling_factor`
    internally, per routed slot; applying it again post-fusion would also
    scale the FSE shared-expert slot (which the kernel inserts with unit
    weight), producing a structural magnitude error in every MoE layer.
    This matches `deepseek_v2.py`. (`routed_scaling_factor=2.5` for GLM-4.7,
    so the unfixed path showed a ~48 pp gsm8k regression.)

* `Glm4MoeModel.get_expert_mapping`
  - Widen `num_experts` by `config.n_shared_experts` when FSE is on so the
    weight loader enumerates the appended slots.

* `Glm4MoeModel.load_weights`
  - Treat `mlp.shared_experts.{gate,up,down}_proj.*` as expert-style tensors
    when FSE is on (skip the stacked QKV/gate_up linear path).
  - Split each widened shared-expert tensor into `n_shared_experts` chunks
    along the intermediate-size axis (dim 0 for ColumnParallel
    gate/up_proj, dim 1 for RowParallel down_proj) and route each chunk to
    `mlp.experts.{n_routed_experts + j}.*` via the FusedMoE expert-aware
    weight loader.

No changes to FusedMoE / AITER plumbing - all of that landed earlier with
vllm-project#39280 (Qwen3-Next FSE).

## Test Plan

* Model: `zai-org/GLM-4.7-FP8`
* Hardware: 1x MI355X node, TP=4
* Container: ROCm vLLM image (AITER >= v0.1.13.post1, PR vllm-project#44265)
* Accuracy: `lm_eval --tasks gsm8k --num_fewshot 5`
* Throughput: `vllm bench serve --dataset-name random` sweep over
  (ISL, OSL, MC) in {1000/100, 5000/500, 10000/1000} x {4, 16, 64}

Server launch:

```
VLLM_ROCM_USE_AITER=1 \
VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS=<0|1> \
vllm serve zai-org/GLM-4.7-FP8 \
  --tensor-parallel-size 4 \
  --gpu-memory-utilization 0.92 \
  --max-model-len 32768 \
  --max-num-seqs 256
```

## Test Result

### Accuracy (gsm8k, 5-shot, exact_match)

| Config              | flexible-extract | strict-match     |
|---------------------|-----------------:|-----------------:|
| FSE=0 (baseline)    | 0.9469 ± 0.0062  | 0.9439 ± 0.0063  |
| FSE=1               | 0.9439 ± 0.0063  | 0.9416 ± 0.0065  |

All deltas within standard error. No accuracy regression.

### Throughput (`vllm bench serve`, random)

| ISL  | OSL  | MC | TPOT mean (ms) FSE=0 -> FSE=1 (Δ) | TPOT p99 (ms) FSE=0 -> FSE=1 (Δ) | Output tok/s FSE=0 -> FSE=1 (Δ) | Total tok/s FSE=0 -> FSE=1 (Δ) |
|-----:|-----:|---:|----------------------------------:|---------------------------------:|--------------------------------:|-------------------------------:|
|  1000|   100|   4| 17.76 -> 14.36  (**-19.2%**)      | 19.43 -> 15.93 (**-18.0%**)      | 199.4 -> 243.6  (**+22.1%**)    | 2193.7 -> 2679.1 (**+22.1%**)  |
|  1000|   100|  16| 20.96 -> 18.48  (**-11.9%**)      | 24.29 -> 22.77 (-6.3%)           | 631.0 -> 673.4  (**+6.7%**)     | 6940.6 -> 7407.9 (**+6.7%**)   |
|  1000|   100|  64| 30.74 -> 30.23  (-1.7%)           | 42.85 -> 43.44 (+1.4%)           | 1452.7 -> 1424.3 (-2.0%)        | 15980.1 -> 15667.6 (-2.0%)     |
|  5000|   500|   4| 17.82 -> 14.50  (**-18.7%**)      | 18.63 -> 15.50 (**-16.8%**)      | 211.5 -> 253.5  (**+19.9%**)    | 2326.1 -> 2788.7 (**+19.9%**)  |
|  5000|   500|  16| 22.73 -> 20.76  (**-8.7%**)       | 25.38 -> 23.07 (**-9.1%**)       | 619.1 -> 657.7  (**+6.2%**)     | 6810.4 -> 7234.6 (**+6.2%**)   |
|  5000|   500|  64| 39.79 -> 40.15  (+0.9%)           | 46.15 -> 46.78 (+1.4%)           | 1363.8 -> 1339.1 (-1.8%)        | 15001.9 -> 14730.4 (-1.8%)     |
| 10000|  1000|   4| 18.00 -> 14.70  (**-18.3%**)      | 18.68 -> 15.50 (**-17.0%**)      | 210.3 -> 251.8  (**+19.7%**)    | 2313.5 -> 2769.4 (**+19.7%**)  |
| 10000|  1000|  16| 24.47 -> 22.87  (-6.5%)           | 26.66 -> 25.56 (-4.1%)           | 589.6 -> 615.1  (**+4.3%**)     | 6485.6 -> 6766.2 (**+4.3%**)   |
| 10000|  1000|  64| 46.37 -> 46.33  (-0.1%)           | 51.14 -> 51.78 (+1.3%)           | 1233.6 -> 1211.9 (-1.8%)        | 13570.0 -> 13330.7 (-1.8%)     |

Verdict: FSE delivers +20-22% output throughput and -18-19% TPOT at low
concurrency (MC=4), modest gains at MC=16, and is roughly break-even
(<2% regression) at MC=64. No accuracy regression.

Signed-off-by: Olga Miroshnichenko <olga.miroshnichenko@amd.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

qwen Related to Qwen models ready ONLY add when PR is ready to merge/full CI is needed rocm Related to AMD ROCm

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

7 participants