Support Pipeline and Data Parallelism for MiniMax-M2 #33303
rogeryoungh wants to merge 1 commit into vllm-project:main
Conversation
Signed-off-by: rogeryoungh <rogeryoungh@foxmail.com>
Code Review
This pull request adds support for Pipeline and Data Parallelism to the MiniMax-M2 model. The changes are logical and well-structured, including updates to the attention mechanism, quantization support, and weight loading for expert parallelism. The model initialization is also correctly adapted for pipeline parallelism. I have one suggestion to improve the robustness of the code by ensuring an attribute is initialized on all pipeline-parallel ranks, preventing a potential AttributeError.
else:
    self.lm_head = PPMissingLayer()
On non-last pipeline-parallel ranks, self.logits_processor is not initialized. While compute_logits is guarded in ModelRunner, it is good practice to initialize all attributes in __init__ so that other code paths cannot raise an AttributeError. Other models in vLLM initialize this to PPMissingLayer on non-last ranks.
else:
    self.lm_head = PPMissingLayer()
    self.logits_processor = PPMissingLayer()

Add comprehensive performance analysis for MiniMax-M2.5-REAP-139B-A10B-NVFP4-GB10:

Architecture confirmed:
- Attention IS NVFP4 in this model (ignore list = only lm_head + MoE gates)
- 3 MTP modules present (layers 62-64) — biggest performance lever available
- Per-step weight load: ~6.15 GB → 36–44 tok/s theoretical ceiling on GB10

Performance gap analysis:
- Current: 24 tok/s on Strix Halo (AMD); GB10 expected similar baseline
- vLLM is 1.78x slower than SGLang at BS=1 for NVFP4 MoE (documented gap)
- Gap sources: activation quant overhead, kernel launch overhead, no fused shuffle+reduce in MoE, generic CUTLASS configs

Key new PRs to integrate:
- vllm-project#35041 (OPEN): MTP+NVFP4 weight shape mismatch — required for MTP+NVFP4
- vllm-project#35442 (OPEN): Non-blocking MTP token copy — 6 ms → 200 µs CPU-GPU sync
- vllm-project#33303 (OPEN): MiniMax PP+DP for multi-Spark scaling

Already-merged PRs confirmed in HEAD:
- vllm-project#34718 (act_quant_fusion.py): SiLU+FP4 fusion
- vllm-project#34899 (allreduce_rms_fusion.py): NVFP4 AR+Norm fusion
- vllm-project#30885: 8x4 SF tiling (not yet effective on GB10 — TRTLLM backend blocked)
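The suggested guard can be illustrated with a minimal sketch. PPMissingLayer, lm_head, and logits_processor follow vLLM naming; ToyModelForCausalLM and the is_last_pp_rank flag are illustrative stand-ins, not the actual MiniMax-M2 source.

```python
class PPMissingLayer:
    """Placeholder for layers owned by a different pipeline-parallel rank."""

    def __call__(self, *args, **kwargs):
        raise RuntimeError("layer is not materialized on this PP rank")


class ToyModelForCausalLM:
    def __init__(self, is_last_pp_rank: bool):
        if is_last_pp_rank:
            self.lm_head = lambda hidden: hidden           # stand-in for ParallelLMHead
            self.logits_processor = lambda hidden: hidden  # stand-in for LogitsProcessor
        else:
            # Initialize BOTH attributes so a stray code path fails with a
            # clear RuntimeError instead of an AttributeError.
            self.lm_head = PPMissingLayer()
            self.logits_processor = PPMissingLayer()
```

Every rank then has both attributes, and only the last PP rank carries real modules.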
Manual apply of PR vllm-project#33303 by rogeryoungh. Enables pipeline parallelism and data/expert parallelism for MiniMax-M2:
- Remove unused sliding-window attention params
- Propagate quant_config to embed_tokens and lm_head (+ prefix)
- Guard LogitsProcessor creation to last PP rank only
- Add params_dict existence checks for PP weight loading
- Add expert-weight tracking to skip non-local experts during EP/DP
- Fix return type annotation on MiniMaxM2DecoderLayer.forward

This enables dual-Spark deployment (2x GB10 via InfiniBand) with PP=2, where memory bandwidth doubles for ~1.8x throughput.
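The two weight-loading guards above can be sketched as follows. The name params_dict follows vLLM's loader convention, but load_weights, parse_expert_id, and local_expert_ids are hypothetical helpers for illustration, not the actual MiniMax-M2 loader.

```python
import re


def parse_expert_id(name: str):
    """Extract an expert index from a name like 'layers.0.experts.7.w1'."""
    m = re.search(r"experts\.(\d+)\.", name)
    return int(m.group(1)) if m else None


def load_weights(params_dict, weights, local_expert_ids):
    loaded = set()
    for name, tensor in weights:
        # PP guard: under pipeline parallelism, params_dict only contains the
        # layers assigned to this rank, so skip everything else.
        if name not in params_dict:
            continue
        # EP/DP guard: skip expert weights that are not local to this rank.
        expert_id = parse_expert_id(name)
        if expert_id is not None and expert_id not in local_expert_ids:
            continue
        params_dict[name] = tensor
        loaded.add(name)
    return loaded
```

Non-expert weights (expert_id is None) are loaded normally; checkpoint entries for other ranks' layers or experts are silently skipped instead of raising KeyError.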
Mark PR vllm-project#33303 as applied. Add additional MiniMax-specific PRs:
- vllm-project#34863: compressed-tensors FP8 scale propagation
- vllm-project#32232: structural_tag support
- vllm-project#35358: reasoning-end detection fix
Purpose
Adds Pipeline Parallelism (PP) and Data Parallelism (DP) support for minimax_m2. Enabling PP and DP simultaneously currently results in garbled output, but both PP and DP work correctly when used individually.

Test Plan
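A hedged example of exercising the two modes separately (model name and parallel sizes are illustrative; the flags are standard vLLM engine arguments):

```shell
# Pipeline parallelism only (e.g. two nodes over InfiniBand):
vllm serve MiniMaxAI/MiniMax-M2 --pipeline-parallel-size 2

# Data parallelism with expert parallelism for the MoE layers:
vllm serve MiniMaxAI/MiniMax-M2 --data-parallel-size 2 --enable-expert-parallel
```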
Test Result
We have validated the correctness of this change on MiniMax-M2.1: TP2+EP2+DP2 achieves an accuracy of 0.810 on AIME2025, and TP2+EP2+PP2 achieves 0.803. The official AIME2025 score is 0.83.
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.