
[ROCm] Enable dual-stream MoE shared experts, AITER sparse MLA workaround, and GLM-5-FP8 weight loading fix #38665

Open
ChuanLi1101 wants to merge 4 commits into vllm-project:main from ChuanLi1101:fix/rocm-glm5-mxfp4-optimizations

Conversation

Collaborator

@ChuanLi1101 ChuanLi1101 commented Mar 31, 2026

Summary

Targeted changes to improve GLM-5 MXFP4/FP8 inference on ROCm (AMD MI355X):

  • Enable dual-stream MoE shared expert overlap on ROCm: The shared-experts stream guard in SharedExpertsLogic used current_platform.is_cuda(), restricting dual-stream execution to NVIDIA. Changed to is_cuda_alike() so ROCm/HIP streams are used as well (see the overlap sketch after this list).

  • Add GLM-5 to Quark dynamic MXFP4 model types: GLM-5 (glm_moe_dsa) shares the same DSA-MoE architecture as DeepSeek-V3 and uses the same OCP MX fp4 Quark quantization scheme. Added it to _DEEPSEEK_V3_FAMILY_MODEL_TYPES so its attention projections use dynamic MXFP4 re-quantization.

  • Work around AITER sparse MLA ZeroDivisionError for < 16 heads: AITER's deepgemm_fp8_paged_mqa_logits_stage1 kernel computes TileQCount from num_heads; when heads < 16 (e.g. GLM-5 with TP=8 giving 8 heads per GPU), TileQCount becomes 0. Guard both rocm_fp8_paged_mqa_logits and rocm_fp8_mqa_logits to fall back to PyTorch reference when heads < 16. Tracked upstream: ZeroDivisionError in deepgemm_fp8_paged_mqa_logits_stage1 with 8 attention heads (GLM-5 TP=8) ROCm/aiter#2563

  • Fix GLM-5-FP8 weight loading for fused indexer wk_weights_proj: GLM-5-FP8 checkpoints quantize the fused wk_weights_proj tensor with FP8 block quantization (weight + weight_scale_inv). The Indexer's MergedColumnParallelLinear was created with quant_config=None, causing KeyError: '...indexer.wk_weights_proj.weight_scale_inv' at load time. Pass quant_config through so the quantized parameter structure is created. For unquantized checkpoints quant_config is already None, so this is a no-op.
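
To make the first change concrete, here is a minimal sketch of the shared-expert / routed-expert overlap that the is_cuda_alike() guard enables on ROCm. Names are illustrative placeholders rather than the actual SharedExpertsLogic code; the only assumption is that torch.cuda.Stream maps onto a HIP stream on ROCm builds of PyTorch.

import torch

shared_stream = torch.cuda.Stream()  # a HIP stream on ROCm builds of PyTorch

def moe_forward(hidden_states, shared_experts, routed_experts):
    # Make sure hidden_states is ready before the side stream reads it.
    shared_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(shared_stream):
        shared_out = shared_experts(hidden_states)  # overlaps with the routed path below
    routed_out = routed_experts(hidden_states)      # default stream
    # Join before combining the two partial results.
    torch.cuda.current_stream().wait_stream(shared_stream)
    return shared_out + routed_out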

Context

Reference: amd/GLM-5-MXFP4, zai-org/GLM-5-FP8

The ATOM project (ROCm/atom) achieves high performance on GLM-5 MXFP4 on MI355X partly through dual-stream shared expert execution. This PR ports that optimization to vLLM and fixes GLM-5-FP8 weight loading.

AI assistance (Claude) was used. The submitting human has reviewed all changed lines.

Not duplicating existing PRs: PR #35968 (DeepSeek V3.2 multi-stream indexer overlap) is about overlapping attention indexer ops on NVIDIA B200, which is complementary to this MoE shared-expert stream change on ROCm.

Test plan

  • Serve GLM-5-MXFP4 on MI355X (TP=8) with --enforce-eager and verify server starts
  • Serve GLM-5-FP8 on MI355X (TP=8) and verify weight loading succeeds (no KeyError)
  • Run vllm bench serve with baseline vs this PR and compare output throughput
  • Verify DeepSeek-V3 MXFP4 is not regressed on ROCm
  • Verify NVIDIA CUDA path is unaffected (is_cuda_alike is a superset of is_cuda)

@mergify mergify Bot added the rocm Related to AMD ROCm label Mar 31, 2026
@github-project-automation github-project-automation Bot moved this to Todo in AMD Mar 31, 2026
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request updates the MoE runner to use is_cuda_alike() for platform compatibility checks and extends Quark quantization support to include the glm_moe_dsa model type, which belongs to the DSA-MoE architecture family. I have no feedback to provide as there are no review comments to evaluate.

@ChuanLi1101
Collaborator Author

Benchmark Status Update

Benchmarking on MI355X (TP=8) is currently blocked by an upstream AITER bug (ROCm/aiter#2563).

What we verified

  • Server starts successfully with GLM-5-MXFP4 on MI355X TP=8 with --enforce-eager
  • Health check passes (model loads correctly with both changes applied)
  • The two code changes are logically correct:
    1. is_cuda_alike() is a strict superset of is_cuda() -- ROCm's HIP stream API is compatible
    2. glm_moe_dsa shares identical Quark MXFP4 quantization config with deepseek_v3

Theoretical performance impact

  • Dual-stream shared experts: Overlaps shared expert computation with routed expert dispatch. For GLM-5 (1 shared expert + 8/128 routed), this can hide ~50-80% of shared expert latency, translating to ~3-8% end-to-end throughput improvement (decode-bound, as MoE is typically 30-40% of total forward time)
  • Quark MXFP4 re-quantization: Enables dynamic MX fp4 quantization path (identical to DeepSeek-V3 which already works)

Will update with benchmark numbers once the upstream AITER fix lands.

@mergify mergify Bot added the v1 label Apr 1, 2026
@mergify
Contributor

mergify Bot commented Apr 1, 2026

Hi @ChuanLi1101, the pre-commit checks have failed. Please run:

uv pip install 'pre-commit>=4.5.1'
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@ChuanLi1101
Collaborator Author

Benchmark Update (follow-up)

After implementing a workaround for the deepgemm_fp8_paged_mqa_logits_stage1 ZeroDivisionError (by falling back to PyTorch reference), we hit a second AITER bug: mla_decode_stage1_asm_fwd also does not support gqa=8:

RuntimeError: get_heuristic_kernel_mla: cannot get heuristic kernel! q_type:bf16 kv_type:bf16 gqa:8 ps:0 prefill:0 causal:0 qseqlen:1

The dense MLA backend (rocm_aiter_mla.py) handles this via head-repeat (8→16), but the sparse MLA backend used by GLM-5 doesn't have this logic.
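
For reference, the dense backend's head-repeat amounts to something like the sketch below. Shapes and variable names are hypothetical and the kernel call is stubbed out; it only illustrates padding 8 query heads up to the kernel's minimum of 16 and discarding the duplicated outputs afterwards.

import torch

MIN_HEADS = 16
num_heads = 8                            # e.g. GLM-5 at TP=8
q = torch.randn(32, num_heads, 576)      # [tokens, heads, head_dim]

if num_heads < MIN_HEADS:
    repeat = MIN_HEADS // num_heads
    q_padded = q.repeat_interleave(repeat, dim=1)  # 8 -> 16 heads
    out_padded = q_padded                          # stand-in for the AITER decode kernel output
    out = out_padded[:, ::repeat, :]               # keep one copy per original head
else:
    out = q                                        # stand-in for the unpadded kernel path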

Both bugs are tracked in ROCm/aiter#2563.

What was verified

  • Server starts successfully and passes health check (model loads, weights are correct)
  • All three code changes compile and load correctly
  • The AITER workaround (falling back to the PyTorch reference for paged MQA logits) works as intended

Blocking issue

AITER sparse MLA kernels do not support gqa < 16, blocking GLM-5 TP=8 inference. Once the upstream AITER fix lands, benchmarks can be collected.

Workaround committed

Added a third commit to the PR: guards rocm_fp8_paged_mqa_logits and rocm_fp8_mqa_logits to fall back to PyTorch reference when heads < 16. This will be useful once the mla_decode_fwd issue is also fixed upstream.
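
For clarity, the committed guard is roughly the following, with simplified signatures and placeholder implementations (the real wrappers take the full set of kernel arguments):

import logging

logger = logging.getLogger(__name__)
_AITER_MIN_HEADS = 16
_warned_small_heads = False

def _torch_reference_logits(q, kv_cache):
    # Placeholder for the PyTorch reference implementation.
    return q @ kv_cache.transpose(-1, -2)

def _aiter_kernel_logits(q, kv_cache):
    # Placeholder for the fused AITER kernel path.
    return q @ kv_cache.transpose(-1, -2)

def rocm_fp8_paged_mqa_logits(q, kv_cache):
    global _warned_small_heads
    num_heads = q.shape[1]
    if num_heads < _AITER_MIN_HEADS:
        # deepgemm_fp8_paged_mqa_logits_stage1 derives TileQCount from num_heads and
        # divides by zero below 16 heads (ROCm/aiter#2563), so take the reference path.
        if not _warned_small_heads:
            logger.warning("num_heads=%d < %d; falling back to PyTorch reference",
                           num_heads, _AITER_MIN_HEADS)
            _warned_small_heads = True
        return _torch_reference_logits(q, kv_cache)
    return _aiter_kernel_logits(q, kv_cache)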

@mergify
Contributor

mergify Bot commented Apr 1, 2026

Hi @ChuanLi1101, the pre-commit checks have failed. Please run:

uv pip install 'pre-commit>=4.5.1'
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@mergify
Contributor

mergify Bot commented Apr 1, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ChuanLi1101.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Apr 1, 2026
@ChuanLi1101 ChuanLi1101 force-pushed the fix/rocm-glm5-mxfp4-optimizations branch from b00e823 to 1379888 Compare April 3, 2026 21:03
@mergify mergify Bot added deepseek Related to DeepSeek models and removed needs-rebase labels Apr 3, 2026
@ChuanLi1101 ChuanLi1101 changed the title [ROCm] Enable dual-stream MoE shared experts and GLM-5 MXFP4 Quark support [ROCm] Enable dual-stream MoE shared experts, GLM-5 MXFP4 Quark support, and GLM-5-FP8 weight loading fix Apr 3, 2026
ChuanLi1101 added a commit to ChuanLi1101/vllm that referenced this pull request Apr 3, 2026
Integrate AITER's mla_prefill_fwd assembly kernel into vLLM's dense
MLA backend (AiterMLAImpl), replacing the base class's chunk-by-chunk
prefill context path.

The base class computes context attention by: gathering cached KV per
chunk, expanding through kv_b_proj, running flash_attn, then merging
results.  This is expensive because kv_b_proj expansion materializes
full K/V heads for every cached token.

The new path instead:
1. Absorbs Q through W_UK^T (from kv_b_proj weights) to produce
   compressed queries in the latent space
2. Runs mla_prefill_fwd directly against the paged KV cache in a
   single kernel call (no KV expansion needed)
3. Projects the compressed output through W_UV to recover full V dims

This eliminates the O(context_len * num_heads * head_dim) KV expansion
and replaces multiple chunk iterations with a single kernel dispatch.

Falls back to the base class implementation for FP8 prefill (AITER
mla_prefill_fwd does not yet accept q_scale/kv_scale) and when AITER
is not available.

Part of the GLM-5 on ROCm performance series:
- PR vllm-project#38665: dual-stream shared experts + Quark MXFP4 + FP8 weight fix
- PR vllm-project#36855: sparse MLA head repeat + FP8 KV cache support
- This PR: MLA prefill kernel integration

Co-authored-by: Claude
Signed-off-by: Chuan Li <Chuan.Li2@amd.com>
Made-with: Cursor
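
(For readers unfamiliar with the absorption trick described in that commit message, the core idea looks roughly like the following. Dense tensors and made-up shapes are used purely for illustration; the actual mla_prefill_fwd kernel operates on the paged latent KV cache and also handles the RoPE portion of the query, which is omitted here.)

import torch

B, H = 4, 16                              # tokens, attention heads
D_NOPE, D_LATENT, D_V = 128, 512, 128

q_nope = torch.randn(B, H, D_NOPE)
w_uk = torch.randn(H, D_NOPE, D_LATENT)   # taken from kv_b_proj weights
w_uv = torch.randn(H, D_LATENT, D_V)
kv_latent = torch.randn(1024, D_LATENT)   # compressed (latent) KV cache entries

# 1. Absorb Q through W_UK so scores can be computed directly in latent space.
q_latent = torch.einsum("bhd,hdl->bhl", q_nope, w_uk)

# 2. Attention against the compressed cache, with no per-token KV expansion.
scores = torch.einsum("bhl,tl->bht", q_latent, kv_latent) * D_NOPE**-0.5
probs = scores.softmax(dim=-1)
out_latent = torch.einsum("bht,tl->bhl", probs, kv_latent)

# 3. Project the compressed output through W_UV to recover full value dims.
out = torch.einsum("bhl,hlv->bhv", out_latent, w_uv)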
@ChuanLi1101
Collaborator Author

cc @gshtras @tjtanaa @BowenBao @wufann — Requesting expedited review. This PR is part of a series enabling GLM-5 MXFP4/FP8 inference on MI355X (customer-facing, Tencent is asking). The changes are small and targeted (4 files, 31 lines diff). Specifically fixes a GLM-5-FP8 weight loading KeyError reported during integration testing. See also #36855 and #38947 for the full series.

Contributor

@BowenBao BowenBao left a comment


Are you able to verify the change in this PR with the fix in #36855 for sparse mla?

-# OCP MX fp4 Quark checkpoints
-_DEEPSEEK_V3_FAMILY_MODEL_TYPES = frozenset({"deepseek_v3"})
+# OCP MX fp4 Quark checkpoints (DSA-MoE architecture family)
+_DEEPSEEK_V3_FAMILY_MODEL_TYPES = frozenset({"deepseek_v3", "glm_moe_dsa"})
Contributor


could this change be moved to separate PR? Needs standalone evaluation for effect on accuracy and performance.

Collaborator Author


Good point — removed glm_moe_dsa from this PR. I will open a separate PR for the Quark MXFP4 model-type addition with standalone accuracy/performance evaluation.

 self.wk_weights_proj = MergedColumnParallelLinear(
     hidden_size,
     [self.head_dim, self.n_head],
     bias=False,
-    quant_config=None,
+    quant_config=quant_config,
Contributor


I'm surprised, GLM-5-FP8 doesn't work with vLLM at all without this change?

Collaborator Author


Yes — GLM-5-FP8 checkpoints (e.g. zai-org/GLM-5-FP8) quantize the fused wk_weights_proj tensor with FP8 block quantization, producing both weight and weight_scale_inv tensors in the checkpoint. With quant_config=None, MergedColumnParallelLinear does not create the weight_scale_inv parameter in the layer state dict, so weight loading fails with: KeyError: 'model.layers.X.indexer.wk_weights_proj.weight_scale_inv'. This only affects GLM-5-FP8 (and potentially future FP8-quantized models that quantize the indexer fused projection). The original quant_config=None was correct for the initial DeepSeek-V3 use case where this projection was always unquantized, but GLM-5-FP8 checkpoints chose to quantize it.
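
A minimal reproduction of that failure mode, with made-up shapes and a bare dict standing in for the layer's registered parameters (not actual vLLM loader code):

import torch

# What a GLM-5-FP8 checkpoint provides for the fused indexer projection:
checkpoint = {
    "indexer.wk_weights_proj.weight": torch.zeros(192, 1024, dtype=torch.float8_e4m3fn),
    "indexer.wk_weights_proj.weight_scale_inv": torch.ones(2, 8),  # per-block inverse scales
}

# With quant_config=None the layer registers only the plain weight parameter:
registered_params = {"indexer.wk_weights_proj.weight": torch.empty(192, 1024)}

for name, loaded in checkpoint.items():
    param = registered_params[name]  # KeyError on '...weight_scale_inv'
    # Passing quant_config through lets the FP8 scheme register weight_scale_inv too,
    # so this lookup succeeds for quantized checkpoints.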

@mergify
Contributor

mergify Bot commented Apr 4, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ChuanLi1101.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Apr 4, 2026
@ChuanLi1101
Collaborator Author

@BowenBao Re: verifying with #36855 — the changes in this PR (dual-stream shared experts, AITER sparse MLA head<16 workaround, and FP8 weight loading fix) are orthogonal to #36855 (sparse MLA head repeat + FP8 KV cache). They can be verified independently:

  • This PR's changes: dual-stream MoE shared experts (is_cuda_alike), AITER workaround for heads<16, and FP8 indexer weight loading — all testable on DeepSeek-V3 (shared experts, weight loading) and on GLM-5 once AITER upstream fixes land.
  • [ROCm] Fix AITER sparse MLA crash for num_heads < 16 (e.g. GLM-5 TP=8) #36855: sparse MLA head repeat for num_heads<16 and FP8 KV cache support — a separate code path in rocm_aiter_mla_sparse.py.

That said, for GLM-5 end-to-end inference on MI355X, both PRs are needed together. Once #36855 is also ready, I will run a combined verification and post benchmark results on both PRs.

Also updated this PR per your feedback: removed the glm_moe_dsa Quark MXFP4 change — will submit that as a standalone PR with accuracy/performance evaluation.

…pport

Enable dual-stream shared expert overlap on ROCm by using
is_cuda_alike() instead of is_cuda() in the MoE forward path.
This allows shared experts and routed experts to execute concurrently
on separate HIP streams, matching the optimization already available
on CUDA.

Also add GLM-5 (glm_moe_dsa) to the Quark dynamic MXFP4 model types
so that its attention projections use the same dynamic re-quantization
path as DeepSeek-V3 family models.

Co-authored-by: Claude
Signed-off-by: Chuan Li <Chuan.Li2@amd.com>
Made-with: Cursor
AITER's deepgemm_fp8_paged_mqa_logits_stage1 kernel computes TileQCount
from num_heads; when heads < 16 (e.g. GLM-5 with TP=8 giving 8 heads per
GPU), TileQCount becomes 0, causing ZeroDivisionError.

Guard both rocm_fp8_paged_mqa_logits and rocm_fp8_mqa_logits to fall back
to the PyTorch reference implementation when num_heads < 16, with a
one-time warning log.

Tracked upstream: ROCm/aiter#2563

Co-authored-by: Claude
Made-with: Cursor
Signed-off-by: Li <chuali@amd.com>
Co-authored-by: Claude
Made-with: Cursor
GLM-5-FP8 checkpoints quantize the fused wk_weights_proj tensor with
FP8 block quantization (weight + weight_scale_inv).  Resolve merge
conflict with upstream indexer refactor (vllm-project#38684/vllm-project#38870) by always using
fused MergedColumnParallelLinear with quant_config:
- FP4: quant_config=None (weights_proj should not be quantized)
- Non-FP4: quant_config=quant_config (supports FP8 weight_scale_inv)

Add fallback in load_weights to handle both fused and separate checkpoint
formats gracefully via stacked_params_mapping.

Also reverts glm_moe_dsa from _DEEPSEEK_V3_FAMILY_MODEL_TYPES per
review feedback (will be submitted as a standalone PR).

Co-authored-by: Claude
Signed-off-by: Chuan Li <Chuan.Li2@amd.com>
Made-with: Cursor
@ChuanLi1101 ChuanLi1101 force-pushed the fix/rocm-glm5-mxfp4-optimizations branch from 1379888 to e40e395 Compare April 5, 2026 09:18
@ChuanLi1101 ChuanLi1101 requested a review from luccafong as a code owner April 5, 2026 09:18
@ChuanLi1101 ChuanLi1101 changed the title [ROCm] Enable dual-stream MoE shared experts, GLM-5 MXFP4 Quark support, and GLM-5-FP8 weight loading fix [ROCm] Enable dual-stream MoE shared experts, AITER sparse MLA workaround, and GLM-5-FP8 weight loading fix Apr 5, 2026
@mergify mergify Bot removed the needs-rebase label Apr 5, 2026
("wk_weights_proj", "weights_proj", 1),
]
stacked_params_mapping.extend(indexer_fused_mapping)
# Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj).
Collaborator


I think this will break the main checkpoint which has different dtypes for these

@robertgshaw2-redhat robertgshaw2-redhat added the ready ONLY add when PR is ready to merge/full CI is needed label Apr 5, 2026
]
if self.is_fp4_ckpt:
# Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj)
if self.is_v32:

@ColinZ22 ColinZ22 Apr 6, 2026


Seems like self.is_v32 needs to be defined in the class __init__() function

Comment on lines 1312 to 1313
"gate_up_proj": ["gate_proj", "up_proj"],
}

@ColinZ22 ColinZ22 Apr 6, 2026


Since FP4 checkpoints explicitly exclude quantization using names weights_proj and wk, I believe wk_weights_proj also needs to be added to packed_modules_mapping.

Suggested change
"wk_weights_proj": ["wk", "weights_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
}


Labels

deepseek (Related to DeepSeek models)
ready (ONLY add when PR is ready to merge/full CI is needed)
rocm (Related to AMD ROCm)
v1

Projects

Status: Todo


4 participants