[ROCm] Enable dual-stream MoE shared experts, AITER sparse MLA workaround, and GLM-5-FP8 weight loading fix#38665
Conversation
Code Review
This pull request updates the MoE runner to use is_cuda_alike() for platform compatibility checks and extends Quark quantization support to include the glm_moe_dsa model type, which belongs to the DSA-MoE architecture family. I have no feedback to provide as there are no review comments to evaluate.
**Benchmark Status Update**

Benchmarking on MI355X (TP=8) is currently blocked by an upstream AITER bug:
**What we verified**

**Theoretical performance impact**

Will update with benchmark numbers once the upstream AITER fix lands.
Hi @ChuanLi1101, the pre-commit checks have failed. Please run:

```bash
uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.
**Benchmark Update (follow-up)**

After implementing a workaround for the sparse MLA `ZeroDivisionError`, the dense MLA backend (`AiterMLAImpl`) hit a second AITER bug. Both bugs are tracked in ROCm/aiter#2563.

**What was verified**
**Blocking issue**

AITER sparse MLA kernels do not support num_heads < 16.

**Workaround committed**

Added a third commit to the PR: guards `rocm_fp8_paged_mqa_logits` and `rocm_fp8_mqa_logits` to fall back to the PyTorch reference implementation when num_heads < 16.
Hi @ChuanLi1101, the pre-commit checks have failed. Please run:

```bash
uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.
This pull request has merge conflicts that must be resolved before it can be merged.
Force-pushed from b00e823 to 1379888
Integrate AITER's `mla_prefill_fwd` assembly kernel into vLLM's dense MLA backend (`AiterMLAImpl`), replacing the base class's chunk-by-chunk prefill context path.

The base class computes context attention by gathering cached KV per chunk, expanding through `kv_b_proj`, running flash_attn, then merging results. This is expensive because the `kv_b_proj` expansion materializes full K/V heads for every cached token.

The new path instead:
1. Absorbs Q through W_UK^T (from `kv_b_proj` weights) to produce compressed queries in the latent space
2. Runs `mla_prefill_fwd` directly against the paged KV cache in a single kernel call (no KV expansion needed)
3. Projects the compressed output through W_UV to recover full V dims

This eliminates the O(context_len * num_heads * head_dim) KV expansion and replaces multiple chunk iterations with a single kernel dispatch.

Falls back to the base class implementation for FP8 prefill (AITER `mla_prefill_fwd` does not yet accept q_scale/kv_scale) and when AITER is not available.

Part of the GLM-5 on ROCm performance series:
- PR vllm-project#38665: dual-stream shared experts + Quark MXFP4 + FP8 weight fix
- PR vllm-project#36855: sparse MLA head repeat + FP8 KV cache support
- This PR: MLA prefill kernel integration

Co-authored-by: Claude
Signed-off-by: Chuan Li <Chuan.Li2@amd.com>
Made-with: Cursor
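For readers less familiar with the absorbed-MLA trick, here is a minimal sketch of steps 1 and 3 around the kernel dispatch, assuming illustrative shapes and names (`w_uk`, `w_uv` stand in for the splits of `kv_b_proj`; this is not the actual `AiterMLAImpl` code):

```python
import torch

# Illustrative dims: 8 heads (GLM-5 at TP=8), 128-dim heads, 512-dim latent KV.
num_heads, head_dim, latent_dim, v_dim, n_tok = 8, 128, 512, 128, 4

w_uk = torch.randn(num_heads, head_dim, latent_dim)  # split from kv_b_proj
w_uv = torch.randn(num_heads, latent_dim, v_dim)     # split from kv_b_proj
q = torch.randn(n_tok, num_heads, head_dim)

# (1) Absorb Q through W_UK^T: queries move into the compressed latent space,
# so attention can run against the cached latent KV without expansion.
q_latent = torch.einsum("thd,hdl->thl", q, w_uk)

# (2) Placeholder for the single mla_prefill_fwd dispatch against the paged
# latent KV cache; the kernel returns latent-space attention outputs.
out_latent = q_latent

# (3) Project the compressed output through W_UV to recover full V dims.
out = torch.einsum("thl,hlv->thv", out_latent, w_uv)
print(out.shape)  # torch.Size([4, 8, 128])
```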
cc @gshtras @tjtanaa @BowenBao @wufann, requesting expedited review. This PR is part of a series enabling GLM-5 MXFP4/FP8 inference on MI355X (customer-facing; Tencent is asking). The changes are small and targeted (4 files, 31-line diff), and specifically fix a GLM-5-FP8 weight loading `KeyError` reported during integration testing. See also #36855 and #38947 for the full series.
```diff
-# OCP MX fp4 Quark checkpoints
-_DEEPSEEK_V3_FAMILY_MODEL_TYPES = frozenset({"deepseek_v3"})
+# OCP MX fp4 Quark checkpoints (DSA-MoE architecture family)
+_DEEPSEEK_V3_FAMILY_MODEL_TYPES = frozenset({"deepseek_v3", "glm_moe_dsa"})
```
Could this change be moved to a separate PR? It needs standalone evaluation for its effect on accuracy and performance.
Good point; removed `glm_moe_dsa` from this PR. I will open a separate PR for the Quark MXFP4 model-type addition with standalone accuracy/performance evaluation.
```diff
 self.wk_weights_proj = MergedColumnParallelLinear(
     hidden_size,
     [self.head_dim, self.n_head],
     bias=False,
-    quant_config=None,
+    quant_config=quant_config,
```
I'm surprised; does GLM-5-FP8 not work with vLLM at all without this change?
Yes. GLM-5-FP8 checkpoints (e.g. zai-org/GLM-5-FP8) quantize the fused `wk_weights_proj` tensor with FP8 block quantization, producing both `weight` and `weight_scale_inv` tensors in the checkpoint. With `quant_config=None`, `MergedColumnParallelLinear` does not create the `weight_scale_inv` parameter in the layer state dict, so weight loading fails with:

```
KeyError: 'model.layers.X.indexer.wk_weights_proj.weight_scale_inv'
```

This only affects GLM-5-FP8 (and potentially future FP8-quantized models that quantize the indexer fused projection). The original `quant_config=None` was correct for the initial DeepSeek-V3 use case, where this projection was always unquantized, but GLM-5-FP8 checkpoints chose to quantize it.
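A minimal repro of the mismatch, assuming a plain `nn.Linear` stand-in for the indexer projection (not vLLM's loader code): the checkpoint ships a scale tensor that the unquantized layer never registered, so the loader's name lookup raises.

```python
import torch
from torch import nn

def build_layer(quantized: bool) -> nn.Module:
    # Stand-in for MergedColumnParallelLinear; with a real quant_config,
    # vLLM's quant method registers the scale parameter as well.
    layer = nn.Linear(64, 64, bias=False)
    if quantized:
        layer.register_parameter(
            "weight_scale_inv", nn.Parameter(torch.ones(1, 1)))
    return layer

# What a GLM-5-FP8 block-quantized checkpoint ships for this projection.
checkpoint = {"weight": torch.zeros(64, 64), "weight_scale_inv": torch.ones(1, 1)}

for quantized in (False, True):
    params = dict(build_layer(quantized).named_parameters())
    try:
        for name, tensor in checkpoint.items():
            params[name].data.copy_(tensor)  # name lookup, like vLLM's loader
        print("loaded OK (quant_config passed through)")
    except KeyError as e:
        print(f"KeyError: {e} (quant_config=None)")
```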
This pull request has merge conflicts that must be resolved before it can be merged.
@BowenBao Re: verifying with #36855: the changes in this PR (dual-stream shared experts, AITER sparse MLA head<16 workaround, and FP8 weight loading fix) are orthogonal to #36855 (sparse MLA head repeat + FP8 KV cache), so they can be verified independently.
That said, for GLM-5 end-to-end inference on MI355X, both PRs are needed together. Once #36855 is also ready, I will run a combined verification and post benchmark results on both PRs. Also updated this PR per your feedback: removed the `glm_moe_dsa` Quark MXFP4 change; it will be submitted as a standalone PR with accuracy/performance evaluation.
Enable dual-stream shared expert overlap on ROCm by using `is_cuda_alike()` instead of `is_cuda()` in the MoE forward path. This allows shared experts and routed experts to execute concurrently on separate HIP streams, matching the optimization already available on CUDA.

Also add GLM-5 (`glm_moe_dsa`) to the Quark dynamic MXFP4 model types so that its attention projections use the same dynamic re-quantization path as DeepSeek-V3 family models.

Co-authored-by: Claude
Signed-off-by: Chuan Li <Chuan.Li2@amd.com>
Made-with: Cursor
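The overlap itself is ordinary PyTorch stream plumbing, sketched below under assumed names (`routed_experts`/`shared_experts` are placeholder callables, not vLLM's `SharedExpertsLogic`); `torch.cuda` maps onto HIP streams in ROCm builds, which is why the `is_cuda_alike()` check is sufficient:

```python
import torch

shared_stream = torch.cuda.Stream()  # a HIP stream on ROCm builds of PyTorch

def moe_forward(hidden_states, routed_experts, shared_experts):
    # The side stream must see hidden_states fully written before reading it.
    shared_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(shared_stream):
        shared_out = shared_experts(hidden_states)
    # Routed experts run on the default stream, concurrently with the above.
    routed_out = routed_experts(hidden_states)
    # Re-join: the default stream waits for shared-expert work to finish.
    torch.cuda.current_stream().wait_stream(shared_stream)
    return routed_out + shared_out
```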
AITER's `deepgemm_fp8_paged_mqa_logits_stage1` kernel computes TileQCount from num_heads; when heads < 16 (e.g. GLM-5 with TP=8 giving 8 heads per GPU), TileQCount becomes 0, causing a ZeroDivisionError.

Guard both `rocm_fp8_paged_mqa_logits` and `rocm_fp8_mqa_logits` to fall back to the PyTorch reference implementation when num_heads < 16, with a one-time warning log.

Tracked upstream: ROCm/aiter#2563

Co-authored-by: Claude
Made-with: Cursor
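The guard boils down to a dispatch wrapper like this sketch; the kernel and reference implementations are passed in as callables to keep it self-contained (the real vLLM symbols differ):

```python
import logging

logger = logging.getLogger(__name__)

_MIN_AITER_HEADS = 16  # below this, the AITER kernel's TileQCount becomes 0
_warned = False

def fp8_mqa_logits(q, kv, num_heads, aiter_kernel, torch_reference):
    global _warned
    if num_heads < _MIN_AITER_HEADS:
        if not _warned:  # one-time warning, as in the committed workaround
            logger.warning(
                "num_heads=%d < %d: AITER fp8 MQA kernel unsupported "
                "(ROCm/aiter#2563); using PyTorch reference path.",
                num_heads, _MIN_AITER_HEADS)
            _warned = True
        return torch_reference(q, kv)
    return aiter_kernel(q, kv)
```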
Signed-off-by: Li <chuali@amd.com>
Co-authored-by: Claude
Made-with: Cursor
GLM-5-FP8 checkpoints quantize the fused `wk_weights_proj` tensor with FP8 block quantization (`weight` + `weight_scale_inv`). Resolve the merge conflict with the upstream indexer refactor (vllm-project#38684 / vllm-project#38870) by always using the fused `MergedColumnParallelLinear` with `quant_config`:

- FP4: `quant_config=None` (weights_proj should not be quantized)
- Non-FP4: `quant_config=quant_config` (supports FP8 `weight_scale_inv`)

Add a fallback in `load_weights` to handle both fused and separate checkpoint formats gracefully via `stacked_params_mapping`.

Also reverts `glm_moe_dsa` from `_DEEPSEEK_V3_FAMILY_MODEL_TYPES` per review feedback (will be submitted as a standalone PR).

Co-authored-by: Claude
Signed-off-by: Chuan Li <Chuan.Li2@amd.com>
Made-with: Cursor
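Conceptually, the dual-format fallback keys off `stacked_params_mapping`: checkpoint names from the separate format are remapped onto a shard of the fused parameter, while fused names load directly. A simplified sketch of the name resolution (not the actual `load_weights` code, which routes through each parameter's weight_loader):

```python
# (fused fragment, separate-checkpoint fragment, shard id)
stacked_params_mapping = [
    (".wk_weights_proj.", ".wk.", 0),
    (".wk_weights_proj.", ".weights_proj.", 1),
]

def resolve(ckpt_name: str):
    """Map a checkpoint tensor name to (model param name, shard id)."""
    for fused, separate, shard_id in stacked_params_mapping:
        if separate in ckpt_name:
            # Separate format: route into one shard of the fused parameter.
            return ckpt_name.replace(separate, fused), shard_id
    # Fused format (or any unmapped tensor): load as-is.
    return ckpt_name, None

print(resolve("layers.0.indexer.wk.weight"))               # fused name, shard 0
print(resolve("layers.0.indexer.wk_weights_proj.weight"))  # unchanged, None
```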
Force-pushed from 1379888 to e40e395
| ("wk_weights_proj", "weights_proj", 1), | ||
| ] | ||
| stacked_params_mapping.extend(indexer_fused_mapping) | ||
| # Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj). |
I think this will break the main checkpoint, which has different dtypes for these.
| ] | ||
| if self.is_fp4_ckpt: | ||
| # Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj) | ||
| if self.is_v32: |
Seems like `self.is_v32` needs to be defined in the class `__init__()` function.
| "gate_up_proj": ["gate_proj", "up_proj"], | ||
| } |
Since FP4 checkpoints explicitly exclude quantization using the names `weights_proj` and `wk`, I believe `wk_weights_proj` also needs to be added to `packed_modules_mapping`:

```python
    "wk_weights_proj": ["wk", "weights_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}
```
Summary
Targeted changes to improve GLM-5 MXFP4/FP8 inference on ROCm (AMD MI355X):
1. **Enable dual-stream MoE shared expert overlap on ROCm**: The shared experts stream guard in `SharedExpertsLogic` used `current_platform.is_cuda()`, restricting dual-stream execution to NVIDIA only. Changed to `is_cuda_alike()` so ROCm/HIP streams are used as well.

2. **Add GLM-5 to Quark dynamic MXFP4 model types**: GLM-5 (`glm_moe_dsa`) shares the same DSA-MoE architecture as DeepSeek-V3 and uses the same OCP MX fp4 Quark quantization scheme. Added it to `_DEEPSEEK_V3_FAMILY_MODEL_TYPES` so its attention projections use dynamic MXFP4 re-quantization.

3. **Work around AITER sparse MLA ZeroDivisionError for < 16 heads**: AITER's `deepgemm_fp8_paged_mqa_logits_stage1` kernel computes TileQCount from num_heads; when heads < 16 (e.g. GLM-5 with TP=8 giving 8 heads per GPU), TileQCount becomes 0. Guard both `rocm_fp8_paged_mqa_logits` and `rocm_fp8_mqa_logits` to fall back to the PyTorch reference when heads < 16. Tracked upstream: ZeroDivisionError in deepgemm_fp8_paged_mqa_logits_stage1 with 8 attention heads (GLM-5 TP=8), ROCm/aiter#2563.

4. **Fix GLM-5-FP8 weight loading for fused indexer wk_weights_proj**: GLM-5-FP8 checkpoints quantize the fused `wk_weights_proj` tensor with FP8 block quantization (`weight` + `weight_scale_inv`). The `Indexer`'s `MergedColumnParallelLinear` was created with `quant_config=None`, causing `KeyError: '...indexer.wk_weights_proj.weight_scale_inv'` at load time. Pass `quant_config` through so the quantized parameter structure is created. For unquantized checkpoints `quant_config` is already `None`, so this is a no-op.

Context
Reference: amd/GLM-5-MXFP4, zai-org/GLM-5-FP8
The ATOM project (ROCm/atom) achieves high performance on GLM-5 MXFP4 on MI355X partly through dual-stream shared expert execution. This PR ports that optimization to vLLM and fixes GLM-5-FP8 weight loading.
AI assistance (Claude) was used. The submitting human has reviewed all changed lines.
Not duplicating existing PRs: PR #35968 (DeepSeek V3.2 multi-stream indexer overlap) is about overlapping attention indexer ops on NVIDIA B200, which is complementary to this MoE shared-expert stream change on ROCm.
Test plan
- Launch the server with `--enforce-eager` and verify the server starts
- Run `vllm bench serve` with baseline vs this PR and compare output throughput
- CUDA behavior is unchanged (`is_cuda_alike` is a superset of `is_cuda`)