97 changes: 96 additions & 1 deletion vllm_ascend/patch/__init__.py
@@ -108,6 +108,35 @@
# remove this patch once upstream no longer requires these global symbols or
# provides a backend-safe initialization path.
#
# ** 7. File: platform/patch_minimax_m2_config.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.config.model.ModelConfig._verify_quantization`
# Why:
# MiniMax-M2 fp8 checkpoints on NPU may fail upstream quantization validation.
# vllm-ascend needs to disable fp8 quantization and load bf16 dequantized
# weights in worker-side patches instead.
# How:
# Monkey-patch `_verify_quantization` and intercept platform quantization
# verification to force `cfg.quantization=None` for MiniMax-M2 fp8 on NPU.
# Related PR (if no, explain why):
# No, upstream behavior differs across versions and needs discussion.
# Future Plan:
# Remove this patch once upstream supports MiniMax-M2 fp8 on NPU or provides
# a backend-safe validation / override mechanism.
#
# 2. `vllm.config.model.ModelConfig._verify_cuda_graph`
# Why:
# For MiniMax-M2 on NPU with ACL graph capture enabled, HCCL op expansion
# mode affects graph shape coverage. Users may forget to set it.
# How:
#    If the user hasn't set it, set `HCCL_OP_EXPANSION_MODE=AIV` for this
#    model; log a warning if a different value is already set.
# Related PR (if no, explain why):
# No, this is an environment-specific tuning knob.
# Future Plan:
# Remove this patch if upstream provides an official NPU graph-capture
# guidance / auto-configuration path for HCCL.
#
# * Worker Patch:
# ===============
#
@@ -333,7 +362,73 @@
# Future Plan:
# Remove this patch when vLLM merges the PR.
#
- # ** 17. File: worker/patch_qwen3_5.py**
+ # ** 17. File: worker/patch_minimax_m2.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.model_executor.models.minimax_m2.MiniMaxM2MoE.forward`
# Why:
# In TP mode, MiniMax-M2 MoE needs a backend-aware reduction path to avoid
# unnecessary communication / maintain correctness on NPU.
# How:
# Replace the forward to call `experts.maybe_all_reduce_tensor_model_parallel`
# when `tp_size > 1`.
# Related PR (if no, explain why):
# No, model-specific behavior.
# Future Plan:
# Move this behavior upstream once a generic MoE reduce hook exists.
#
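The worker-side code for this item is not shown in this diff. A minimal sketch of what the replacement forward might look like, assuming upstream's `gate`/`experts` attribute names and vLLM's `FusedMoE.maybe_all_reduce_tensor_model_parallel` method:

```python
# Hypothetical sketch only; the real patch lives in worker/patch_minimax_m2.py.
import torch


def patched_moe_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    # Standard MoE routing: the gate produces router logits, the fused
    # experts layer combines the expert outputs.
    router_logits, _ = self.gate(hidden_states)
    hidden_states = self.experts(hidden_states=hidden_states,
                                 router_logits=router_logits)
    if self.tp_size > 1:
        # Backend-aware reduction: let the experts layer decide whether an
        # all-reduce is actually required instead of always reducing.
        hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
            hidden_states)
    return hidden_states
```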
# 2. `vllm.model_executor.models.minimax_m2.MiniMaxM2Attention.__init__`
# Why:
# When total kv heads < TP world size, kv head replication happens and k_norm
# weights should be sharded to match the replication layout.
# How:
# Add `num_kv_head_replicas` and create sharded `k_norm` via
# `MiniMaxText01RMSNormTP(..., weight_shard_world_size=total_num_kv_heads, ...)`.
# Related PR (if no, explain why):
# No, depends on Ascend kernel behavior and TP layout.
# Future Plan:
# Remove this patch if upstream implements kv-head-aware norm sharding.
#
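A sketch of the kv-head-aware construction under the names quoted above; the replica arithmetic and the `head_dim`/`eps` plumbing are illustrative assumptions:

```python
# Hypothetical sketch; the real override is in worker/patch_minimax_m2.py.
# MiniMaxText01RMSNormTP gains `weight_shard_world_size` via patch 18 below.
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.mamba.linear_attn import MiniMaxText01RMSNormTP


def build_k_norm(self, total_num_kv_heads: int, head_dim: int, eps: float):
    tp_size = get_tensor_model_parallel_world_size()
    # With fewer kv heads than TP ranks, each kv head is replicated on
    # tp_size // total_num_kv_heads consecutive ranks.
    self.num_kv_head_replicas = max(1, tp_size // total_num_kv_heads)
    # Shard k_norm over the *real* kv heads so replicated ranks load the
    # same weight slice, matching the kv-head replication layout.
    self.k_norm = MiniMaxText01RMSNormTP(
        total_num_kv_heads * head_dim,
        eps=eps,
        weight_shard_world_size=total_num_kv_heads,
    )
```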
# 3. `vllm.model_executor.models.minimax_m2.MiniMaxM2Model.load_weights`
# Why:
# MiniMax-M2 fp8 checkpoints may store fp8 weights with per-block inverse
# scales. On NPU we load bf16 weights by dequantizing at load time.
# How:
# Inject fp8 dequant helpers and wrap `load_weights` to convert fp8 weight +
# `weight_scale_inv` pairs into bf16 blocks before delegating to upstream.
# Related PR (if no, explain why):
# No, fp8 load format and backend constraints are model/backend specific.
# Future Plan:
# Remove this patch when upstream supports MiniMax-M2 fp8 loading on NPU.
#
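The conversion itself reduces to multiplying each fp8 tile by its inverse scale. A sketch assuming the common 128x128 block layout (the block size is an assumption; this diff doesn't show the helper):

```python
# Hypothetical sketch of the load-time block dequantization helper.
import torch

BLOCK = 128  # assumed fp8 quantization block size


def dequant_fp8_to_bf16(weight: torch.Tensor,
                        weight_scale_inv: torch.Tensor) -> torch.Tensor:
    """Dequantize a (rows, cols) fp8 weight with one inverse scale per block."""
    rows, cols = weight.shape
    # Expand the per-tile inverse scales to the full weight shape, then
    # trim in case the weight is not an exact multiple of the block size.
    scales = weight_scale_inv.repeat_interleave(BLOCK, dim=0)
    scales = scales.repeat_interleave(BLOCK, dim=1)[:rows, :cols]
    return (weight.to(torch.float32) * scales).to(torch.bfloat16)
```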
# ** 18. File: worker/patch_minimax_m2_linear_attn.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.model_executor.layers.mamba.linear_attn.MiniMaxText01RMSNormTP.__init__`
# `vllm.model_executor.layers.mamba.linear_attn.MiniMaxText01RMSNormTP.weight_loader`
# Why:
# MiniMax-M2 linear attention RMSNorm needs weight sharding that can follow
# TP layout (and sometimes kv-head replication) on NPU.
# How:
# Override `__init__` to parameterize weight shard world/rank and install a
# sharded `weight_loader` implementation.
# Related PR (if no, explain why):
# No, upstream API surface differs across versions.
# Future Plan:
# Remove this patch when upstream exposes stable sharding hooks for this layer.
#
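A sketch of the parameterized loader, assuming `weight_shard_world_size`/`weight_shard_rank` attributes installed by the `__init__` override (names taken from the description above):

```python
# Hypothetical sketch; the real loader is in worker/patch_minimax_m2_linear_attn.py.
import torch


def sharded_weight_loader(self, param: torch.nn.Parameter,
                          loaded_weight: torch.Tensor) -> None:
    # The shard world/rank are set in __init__ and may differ from the TP
    # world/rank, e.g. when kv heads are replicated across TP ranks.
    shard_size = loaded_weight.shape[0] // self.weight_shard_world_size
    start = self.weight_shard_rank * shard_size
    param.data.copy_(loaded_weight[start:start + shard_size])
```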
# 2. `vllm.model_executor.layers.mamba.linear_attn.MiniMaxText01RMSNormTP.forward_qk`
# (or older `_normalize_qk`)
# Why:
# q/k norm for linear attention is performance-sensitive. On NPU, a fused
# rms_norm kernel is faster and TP needs a global rstd correction.
# How:
# Replace q/k normalization with NPU rms_norm fast path and TP-global rstd
# correction; fall back to upstream implementation on non-NPU.
# Related PR (if no, explain why):
# No, backend-specific optimization.
# Future Plan:
# Remove this patch when upstream adds a backend dispatch path for q/k norm.
#
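The correctness point is that rms must be computed over the full hidden dimension while each TP rank only holds a shard. A pure-torch sketch of the TP-global rstd math (the fused NPU rms_norm call is elided, and the actual patch may differ):

```python
# Hypothetical sketch of the TP-global rstd correction for sharded q/k norm.
import torch
from vllm.distributed import (get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)


def forward_qk(self, x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    tp_size = get_tensor_model_parallel_world_size()
    # Sum of squares over this rank's shard of the hidden dimension.
    sq_sum = x.float().pow(2).sum(dim=-1, keepdim=True)
    if tp_size > 1:
        # The local shard sees only 1/tp_size of the hidden dim, so the
        # squared sums must be reduced across ranks for a global rstd.
        sq_sum = tensor_model_parallel_all_reduce(sq_sum)
    full_hidden = x.shape[-1] * tp_size
    rstd = torch.rsqrt(sq_sum / full_hidden + eps)
    # On NPU, the elementwise normalization below is replaced by a fused
    # rms_norm kernel using this globally corrected rstd.
    return (x.float() * rstd).to(x.dtype) * self.weight
```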
# ** 19. File: worker/patch_qwen3_5.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.model_executor.models.qwen3_5.Qwen3_5GatedDeltaNet._forward_core`
# Why:
1 change: 1 addition & 0 deletions vllm_ascend/patch/platform/__init__.py
@@ -19,6 +19,7 @@
import vllm_ascend.patch.platform.patch_distributed # noqa
import vllm_ascend.patch.platform.patch_fusion_matcher_compat_ops # noqa
import vllm_ascend.patch.platform.patch_mamba_config # noqa
import vllm_ascend.patch.platform.patch_minimax_m2_config # noqa
import vllm_ascend.patch.platform.patch_sched_yield # noqa

if os.getenv("DYNAMIC_EPLB", "false").lower() in ("true", "1") or os.getenv("EXPERT_MAP_RECORD", "false") == "true":
138 changes: 138 additions & 0 deletions vllm_ascend/patch/platform/patch_minimax_m2_config.py
@@ -0,0 +1,138 @@
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Patch target: vllm/config/model.py
# - MiniMax-M2 fp8 checkpoint on NPU: disable fp8 quantization (load bf16
# dequantized weights in worker patch) instead of failing validation.
# - For ACL graph capture, set HCCL_OP_EXPANSION_MODE=AIV if user didn't set it.
#

import os

from vllm.config.model import ModelConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

_original_verify_quantization = getattr(ModelConfig, "_verify_quantization", None)
_original_verify_cuda_graph = getattr(ModelConfig, "_verify_cuda_graph", None)

_DISABLE_FP8_LOG = (
"Detected fp8 MiniMax-M2 checkpoint on NPU. "
"Disabling fp8 quantization and loading dequantized bf16 "
"weights instead."
)


def _get_model_type(cfg: ModelConfig) -> str | None:
# vLLM config fields have changed across versions; try multiple sources.
model_arch_cfg = getattr(cfg, "model_arch_config", None)
if model_arch_cfg is not None:
mt = getattr(model_arch_cfg, "model_type", None)
if mt:
return mt

hf_text_cfg = getattr(cfg, "hf_text_config", None)
if hf_text_cfg is not None:
mt = getattr(hf_text_cfg, "model_type", None)
if mt:
return mt

hf_cfg = getattr(cfg, "hf_config", None)
if hf_cfg is not None:
mt = getattr(hf_cfg, "model_type", None)
if mt:
return mt

return getattr(cfg, "model_type", None)


def _should_disable_fp8(cfg: ModelConfig, quant_method: str | None) -> bool:
return current_platform.device_name == "npu" and _get_model_type(cfg) == "minimax_m2" and quant_method == "fp8"


def _disable_fp8(cfg: ModelConfig, *, log: bool) -> bool:
if not _should_disable_fp8(cfg, getattr(cfg, "quantization", None)):
return False
if log:
logger.warning(_DISABLE_FP8_LOG)
cfg.quantization = None
return True


def _patched_verify_quantization(self: ModelConfig) -> None:
"""Inject mid-function behavior for ModelConfig._verify_quantization.

Upstream validates quantization inside this method via:
current_platform.verify_quantization(self.quantization)

We emulate a mid-function patch without copying upstream code by temporarily
overriding current_platform.verify_quantization while the original verifier
executes.
"""
assert _original_verify_quantization is not None

orig_platform_verify = getattr(current_platform, "verify_quantization", None)

def _platform_verify_hook(quant_method: str | None) -> None:
if _should_disable_fp8(self, quant_method):
# This is the effective "middle of _verify_quantization" interception.
_disable_fp8(self, log=True)
return
assert orig_platform_verify is not None
return orig_platform_verify(quant_method)

# Some versions may read self.quantization before calling platform verifier.
_disable_fp8(self, log=True)
Comment on lines +98 to +99 (Contributor review, severity: high)

There is a logical conflict in `_patched_verify_quantization` that makes the patch's behavior unclear and potentially incorrect. The call to `_disable_fp8` on line 103 nullifies `self.quantization` before `_original_verify_quantization` is called, justified by the comment on line 102 as covering some vLLM versions.

However, this directly conflicts with the `_platform_verify_hook` mechanism: the hook will receive `None` as `quant_method`, so the check `_should_disable_fp8(self, quant_method)` always fails, making the core interception logic within the hook (lines 95-98) unreachable.

This implementation cannot simultaneously handle versions that require pre-emptive nullification of `self.quantization` and versions that rely on the hook to intercept the quantization method. Please clarify the intended logic and resolve the conflict; one approach might be to remove line 103 if hook-based interception is the primary mechanism desired.


try:
if orig_platform_verify is not None:
current_platform.verify_quantization = _platform_verify_hook
return _original_verify_quantization(self)
finally:
if orig_platform_verify is not None:
current_platform.verify_quantization = orig_platform_verify
# Ensure fp8 isn't restored by upstream logic.
_disable_fp8(self, log=False)


def _patched_verify_cuda_graph(self: ModelConfig) -> None:
assert _original_verify_cuda_graph is not None

if (
current_platform.device_name == "npu"
and _get_model_type(self) == "minimax_m2"
and not getattr(self, "enforce_eager", True)
):
expansion_mode = os.environ.get("HCCL_OP_EXPANSION_MODE")
if expansion_mode is None:
os.environ["HCCL_OP_EXPANSION_MODE"] = "AIV"
logger.info("Set HCCL_OP_EXPANSION_MODE=AIV for MiniMax-M2 ACL graph capture on NPU.")
elif expansion_mode != "AIV":
logger.warning(
"HCCL_OP_EXPANSION_MODE=%s may reduce ACL graph shape "
"coverage for MiniMax-M2 on NPU. Recommended value: AIV.",
expansion_mode,
)

return _original_verify_cuda_graph(self)


if _original_verify_quantization is not None:
ModelConfig._verify_quantization = _patched_verify_quantization

if _original_verify_cuda_graph is not None:
ModelConfig._verify_cuda_graph = _patched_verify_cuda_graph
2 changes: 2 additions & 0 deletions vllm_ascend/patch/worker/__init__.py
@@ -30,6 +30,8 @@
import vllm_ascend.patch.worker.patch_unquantized_gemm # noqa
import vllm_ascend.patch.worker.patch_bert # noqa
import vllm_ascend.patch.worker.patch_distributed # noqa
import vllm_ascend.patch.worker.patch_minimax_m2 # noqa
import vllm_ascend.patch.worker.patch_minimax_m2_linear_attn # noqa
import vllm_ascend.patch.worker.patch_multimodal_merge # noqa
import vllm_ascend.patch.worker.patch_qwen3_next # noqa
import vllm_ascend.patch.worker.patch_qwen3_next_mtp # noqa