diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py
index 4920b1f7aef..d0d8984616a 100644
--- a/vllm_ascend/patch/__init__.py
+++ b/vllm_ascend/patch/__init__.py
@@ -108,6 +108,35 @@
 # remove this patch once upstream no longer requires these global symbols or
 # provides a backend-safe initialization path.
 #
+# ** 7. File: platform/patch_minimax_m2_config.py**
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#    1. `vllm.config.model.ModelConfig._verify_quantization`
+#       Why:
+#          MiniMax-M2 fp8 checkpoints on NPU may fail upstream quantization validation.
+#          vllm-ascend needs to disable fp8 quantization and load bf16 dequantized
+#          weights in worker-side patches instead.
+#       How:
+#          Monkey-patch `_verify_quantization` and intercept platform quantization
+#          verification to force `cfg.quantization=None` for MiniMax-M2 fp8 on NPU.
+#       Related PR (if no, explain why):
+#          No, upstream behavior differs across versions and needs discussion.
+#       Future Plan:
+#          Remove this patch once upstream supports MiniMax-M2 fp8 on NPU or provides
+#          a backend-safe validation / override mechanism.
+#
+#    2. `vllm.config.model.ModelConfig._verify_cuda_graph`
+#       Why:
+#          For MiniMax-M2 on NPU with ACL graph capture enabled, HCCL op expansion
+#          mode affects graph shape coverage. Users may forget to set it.
+#       How:
+#          If the user doesn't set it, set `HCCL_OP_EXPANSION_MODE=AIV` for this model
+#          and log a warning when a different value is detected.
+#       Related PR (if no, explain why):
+#          No, this is an environment-specific tuning knob.
+#       Future Plan:
+#          Remove this patch if upstream provides an official NPU graph-capture
+#          guidance / auto-configuration path for HCCL.
+#
 # * Worker Patch:
 # ===============
 #
@@ -333,7 +362,73 @@
 #       Future Plan:
 #          Remove this patch when vLLM merges the PR.
 #
-# ** 17. File: worker/patch_qwen3_5.py**
+# ** 17. File: worker/patch_minimax_m2.py**
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#    1. `vllm.model_executor.models.minimax_m2.MiniMaxM2MoE.forward`
+#       Why:
+#          In TP mode, MiniMax-M2 MoE needs a backend-aware reduction path to avoid
+#          unnecessary communication and maintain correctness on NPU.
+#       How:
+#          Replace the forward to call `experts.maybe_all_reduce_tensor_model_parallel`
+#          when `tp_size > 1`.
+#       Related PR (if no, explain why):
+#          No, model-specific behavior.
+#       Future Plan:
+#          Move this behavior upstream once a generic MoE reduce hook exists.
+#
+#    2. `vllm.model_executor.models.minimax_m2.MiniMaxM2Attention.__init__`
+#       Why:
+#          When total kv heads < TP world size, kv head replication happens and k_norm
+#          weights should be sharded to match the replication layout.
+#       How:
+#          Add `num_kv_head_replicas` and create a sharded `k_norm` via
+#          `MiniMaxText01RMSNormTP(..., weight_shard_world_size=total_num_kv_heads, ...)`.
+#       Related PR (if no, explain why):
+#          No, depends on Ascend kernel behavior and TP layout.
+#       Future Plan:
+#          Remove this patch if upstream implements kv-head-aware norm sharding.
+#
+#    3. `vllm.model_executor.models.minimax_m2.MiniMaxM2Model.load_weights`
+#       Why:
+#          MiniMax-M2 fp8 checkpoints may store fp8 weights with per-block inverse
+#          scales. On NPU we load bf16 weights by dequantizing at load time.
+#       How:
+#          Inject fp8 dequant helpers and wrap `load_weights` to convert fp8 weight +
+#          `weight_scale_inv` pairs into bf16 blocks before delegating to upstream.
+#       Related PR (if no, explain why):
+#          No, fp8 load format and backend constraints are model/backend specific.
+#       Future Plan:
+#          Remove this patch when upstream supports MiniMax-M2 fp8 loading on NPU.
+#
+# ** 18. File: worker/patch_minimax_m2_linear_attn.py**
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#    1. `vllm.model_executor.layers.mamba.linear_attn.MiniMaxText01RMSNormTP.__init__`
+#       `vllm.model_executor.layers.mamba.linear_attn.MiniMaxText01RMSNormTP.weight_loader`
+#       Why:
+#          MiniMax-M2 linear attention RMSNorm needs weight sharding that can follow
+#          the TP layout (and sometimes kv-head replication) on NPU.
+#       How:
+#          Override `__init__` to parameterize the weight shard world/rank and install a
+#          sharded `weight_loader` implementation.
+#       Related PR (if no, explain why):
+#          No, upstream API surface differs across versions.
+#       Future Plan:
+#          Remove this patch when upstream exposes stable sharding hooks for this layer.
+#
+#    2. `vllm.model_executor.layers.mamba.linear_attn.MiniMaxText01RMSNormTP.forward_qk`
+#       (or the older `_normalize_qk`)
+#       Why:
+#          q/k norm for linear attention is performance-sensitive. On NPU, a fused
+#          rms_norm kernel is faster and TP needs a global rstd correction.
+#       How:
+#          Replace q/k normalization with the NPU rms_norm fast path and a TP-global rstd
+#          correction; fall back to the upstream implementation on non-NPU.
+#       Related PR (if no, explain why):
+#          No, backend-specific optimization.
+#       Future Plan:
+#          Remove this patch when upstream adds a backend dispatch path for q/k norm.
+#
+# ** 19. File: worker/patch_qwen3_5.py**
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 #    1. `vllm.model_executor.models.qwen3_5.Qwen3_5GatedDeltaNet._forward_core`
 #       Why:
diff --git a/vllm_ascend/patch/platform/__init__.py b/vllm_ascend/patch/platform/__init__.py
index d6397eace9d..4b8fc9d256f 100644
--- a/vllm_ascend/patch/platform/__init__.py
+++ b/vllm_ascend/patch/platform/__init__.py
@@ -19,6 +19,7 @@
 import vllm_ascend.patch.platform.patch_distributed  # noqa
 import vllm_ascend.patch.platform.patch_fusion_matcher_compat_ops  # noqa
 import vllm_ascend.patch.platform.patch_mamba_config  # noqa
+import vllm_ascend.patch.platform.patch_minimax_m2_config  # noqa
 import vllm_ascend.patch.platform.patch_sched_yield  # noqa
 
 if os.getenv("DYNAMIC_EPLB", "false").lower() in ("true", "1") or os.getenv("EXPERT_MAP_RECORD", "false") == "true":
diff --git a/vllm_ascend/patch/platform/patch_minimax_m2_config.py b/vllm_ascend/patch/platform/patch_minimax_m2_config.py
new file mode 100644
index 00000000000..5c4ff99b30e
--- /dev/null
+++ b/vllm_ascend/patch/platform/patch_minimax_m2_config.py
@@ -0,0 +1,138 @@
+# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Patch target: vllm/config/model.py
+#   - MiniMax-M2 fp8 checkpoint on NPU: disable fp8 quantization (load bf16
+#     dequantized weights in the worker patch) instead of failing validation.
+#   - For ACL graph capture, set HCCL_OP_EXPANSION_MODE=AIV if the user didn't set it.
+#
+
+import os
+
+from vllm.config.model import ModelConfig
+from vllm.logger import init_logger
+from vllm.platforms import current_platform
+
+logger = init_logger(__name__)
+
+_original_verify_quantization = getattr(ModelConfig, "_verify_quantization", None)
+_original_verify_cuda_graph = getattr(ModelConfig, "_verify_cuda_graph", None)
+
+_DISABLE_FP8_LOG = (
+    "Detected fp8 MiniMax-M2 checkpoint on NPU. "
+    "Disabling fp8 quantization and loading dequantized bf16 "
+    "weights instead."
+)
+
+
+def _get_model_type(cfg: ModelConfig) -> str | None:
+    # vLLM config fields have changed across versions; try multiple sources.
+    model_arch_cfg = getattr(cfg, "model_arch_config", None)
+    if model_arch_cfg is not None:
+        mt = getattr(model_arch_cfg, "model_type", None)
+        if mt:
+            return mt
+
+    hf_text_cfg = getattr(cfg, "hf_text_config", None)
+    if hf_text_cfg is not None:
+        mt = getattr(hf_text_cfg, "model_type", None)
+        if mt:
+            return mt
+
+    hf_cfg = getattr(cfg, "hf_config", None)
+    if hf_cfg is not None:
+        mt = getattr(hf_cfg, "model_type", None)
+        if mt:
+            return mt
+
+    return getattr(cfg, "model_type", None)
+
+
+def _should_disable_fp8(cfg: ModelConfig, quant_method: str | None) -> bool:
+    return current_platform.device_name == "npu" and _get_model_type(cfg) == "minimax_m2" and quant_method == "fp8"
+
+
+def _disable_fp8(cfg: ModelConfig, *, log: bool) -> bool:
+    if not _should_disable_fp8(cfg, getattr(cfg, "quantization", None)):
+        return False
+    if log:
+        logger.warning(_DISABLE_FP8_LOG)
+    cfg.quantization = None
+    return True
+
+
+def _patched_verify_quantization(self: ModelConfig) -> None:
+    """Inject mid-function behavior for ModelConfig._verify_quantization.
+
+    Upstream validates quantization inside this method via:
+        current_platform.verify_quantization(self.quantization)
+
+    We emulate a mid-function patch without copying upstream code by temporarily
+    overriding current_platform.verify_quantization while the original verifier
+    executes.
+    """
+    assert _original_verify_quantization is not None
+
+    orig_platform_verify = getattr(current_platform, "verify_quantization", None)
+
+    def _platform_verify_hook(quant_method: str | None) -> None:
+        if _should_disable_fp8(self, quant_method):
+            # This is the effective "middle of _verify_quantization" interception.
+            _disable_fp8(self, log=True)
+            return
+        assert orig_platform_verify is not None
+        return orig_platform_verify(quant_method)
+
+    # Some versions may read self.quantization before calling the platform verifier.
+    _disable_fp8(self, log=True)
+
+    try:
+        if orig_platform_verify is not None:
+            current_platform.verify_quantization = _platform_verify_hook
+        return _original_verify_quantization(self)
+    finally:
+        if orig_platform_verify is not None:
+            current_platform.verify_quantization = orig_platform_verify
+        # Ensure fp8 isn't restored by upstream logic.
+        _disable_fp8(self, log=False)
+
+
+def _patched_verify_cuda_graph(self: ModelConfig) -> None:
+    assert _original_verify_cuda_graph is not None
+
+    if (
+        current_platform.device_name == "npu"
+        and _get_model_type(self) == "minimax_m2"
+        and not getattr(self, "enforce_eager", True)
+    ):
+        expansion_mode = os.environ.get("HCCL_OP_EXPANSION_MODE")
+        if expansion_mode is None:
+            os.environ["HCCL_OP_EXPANSION_MODE"] = "AIV"
+            logger.info("Set HCCL_OP_EXPANSION_MODE=AIV for MiniMax-M2 ACL graph capture on NPU.")
+        elif expansion_mode != "AIV":
+            logger.warning(
+                "HCCL_OP_EXPANSION_MODE=%s may reduce ACL graph shape "
+                "coverage for MiniMax-M2 on NPU. Recommended value: AIV.",
+                expansion_mode,
+            )
+
+    return _original_verify_cuda_graph(self)
+
+
+if _original_verify_quantization is not None:
+    ModelConfig._verify_quantization = _patched_verify_quantization
+
+if _original_verify_cuda_graph is not None:
+    ModelConfig._verify_cuda_graph = _patched_verify_cuda_graph
diff --git a/vllm_ascend/patch/worker/__init__.py b/vllm_ascend/patch/worker/__init__.py
index 5f2d9a2afbd..2493ecb2238 100644
--- a/vllm_ascend/patch/worker/__init__.py
+++ b/vllm_ascend/patch/worker/__init__.py
@@ -30,6 +30,8 @@
 import vllm_ascend.patch.worker.patch_unquantized_gemm  # noqa
 import vllm_ascend.patch.worker.patch_bert  # noqa
 import vllm_ascend.patch.worker.patch_distributed  # noqa
+import vllm_ascend.patch.worker.patch_minimax_m2  # noqa
+import vllm_ascend.patch.worker.patch_minimax_m2_linear_attn  # noqa
 import vllm_ascend.patch.worker.patch_multimodal_merge  # noqa
 import vllm_ascend.patch.worker.patch_qwen3_next  # noqa
 import vllm_ascend.patch.worker.patch_qwen3_next_mtp  # noqa
diff --git a/vllm_ascend/patch/worker/patch_minimax_m2.py b/vllm_ascend/patch/worker/patch_minimax_m2.py
new file mode 100644
index 00000000000..bff94bedb74
--- /dev/null
+++ b/vllm_ascend/patch/worker/patch_minimax_m2.py
@@ -0,0 +1,174 @@
+#
+# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# MiniMax-M2 on Ascend: MoE all_reduce, k_norm weight sharding, fp8 load dequant.
+#
+
+from collections.abc import Iterable
+
+import torch
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+from vllm.model_executor.layers.mamba.linear_attn import MiniMaxText01RMSNormTP
+from vllm.model_executor.models.minimax_m2 import MiniMaxM2Attention, MiniMaxM2Model, MiniMaxM2MoE
+from vllm.platforms import current_platform
+
+FP8_DTYPES = tuple(
+    getattr(torch, dtype_name)
+    for dtype_name in (
+        "float8_e4m3fn",
+        "float8_e4m3fnuz",
+        "float8_e5m2",
+        "float8_e5m2fnuz",
+        "float8_e8m0fnu",
+    )
+    if hasattr(torch, dtype_name)
+)
+
+
+# ---------------------------------------------------------------------------
+# MiniMaxM2MoE.forward: use maybe_all_reduce_tensor_model_parallel
+# ---------------------------------------------------------------------------
+def _patched_moe_forward(
+    self,
+    hidden_states: torch.Tensor,
+) -> torch.Tensor:
+    num_tokens, hidden_dim = hidden_states.shape
+    hidden_states = hidden_states.view(-1, hidden_dim)
+
+    # router_logits: (num_tokens, n_experts)
+    router_logits, _ = self.gate(hidden_states.to(torch.float32))
+    final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=router_logits)
+    if self.tp_size > 1:
+        final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(final_hidden_states)
+    return final_hidden_states.view(num_tokens, hidden_dim)
+
+
+MiniMaxM2MoE.forward = _patched_moe_forward
+
+
+# ---------------------------------------------------------------------------
+# MiniMaxM2Attention: num_kv_head_replicas and k_norm weight sharding
+# ---------------------------------------------------------------------------
+_original_attention_init = MiniMaxM2Attention.__init__
+
+
+def _patched_attention_init(self, *args, **kwargs) -> None:
+    _original_attention_init(self, *args, **kwargs)
+    tp_size = get_tensor_model_parallel_world_size()
+    self.num_kv_head_replicas = max(1, tp_size // self.total_num_kv_heads)
+    if self.total_num_kv_heads < tp_size:
+        rms_norm_eps = getattr(getattr(self, "q_norm", None), "variance_epsilon", 1e-6)
+        self.k_norm = MiniMaxText01RMSNormTP(
+            self.head_dim * self.total_num_kv_heads,
+            eps=rms_norm_eps,
+            weight_shard_world_size=self.total_num_kv_heads,
+            weight_shard_rank=get_tensor_model_parallel_rank() // self.num_kv_head_replicas,
+        )
+
+
+MiniMaxM2Attention.__init__ = _patched_attention_init
+
+
+# ---------------------------------------------------------------------------
+# MiniMaxM2Model: fp8 dequant helpers and load_weights wrapper
+# ---------------------------------------------------------------------------
+def _need_dequantize_fp8_weights(self) -> bool:
+    quant_cfg = getattr(self.config, "quantization_config", None)
+    return (
+        isinstance(quant_cfg, dict) and quant_cfg.get("quant_method") == "fp8" and current_platform.device_name == "npu"
+    )
+
+
+def _dequantize_fp8_block_weight(
+    fp8_weight: torch.Tensor,
+    weight_scale_inv: torch.Tensor,
+    block_size: tuple[int, int],
+) -> torch.Tensor:
+    block_n, block_k = block_size
+    n, k = fp8_weight.shape
+    n_tiles = (n + block_n - 1) // block_n
+    k_tiles = (k + block_k - 1) // block_k
+    if tuple(weight_scale_inv.shape) != (n_tiles, k_tiles):
+        raise ValueError(
+            "Unexpected fp8 scale shape: "
+            f"weight={tuple(fp8_weight.shape)}, "
+            f"scale={tuple(weight_scale_inv.shape)}, "
+            f"block_size={block_size}"
+        )
+    expanded_scale = weight_scale_inv.repeat_interleave(block_n, dim=0).repeat_interleave(block_k, dim=1)
+    expanded_scale = expanded_scale[:n, :k].to(dtype=torch.bfloat16)
+    return fp8_weight.to(dtype=torch.bfloat16) * expanded_scale
+
+
+def _fp8_dequant_weight_iter(
+    self: "MiniMaxM2Model",
+    weights: Iterable[tuple[str, torch.Tensor]],
+) -> Iterable[tuple[str, torch.Tensor]]:
+    quant_cfg = getattr(self.config, "quantization_config", {})
+    block_cfg = quant_cfg.get("weight_block_size", [128, 128])
+    weight_block_size: tuple[int, int] = (128, 128)
+    if isinstance(block_cfg, list) and len(block_cfg) == 2:
+        weight_block_size = (int(block_cfg[0]), int(block_cfg[1]))
+
+    pending_fp8_weights: dict[str, torch.Tensor] = {}
+    pending_fp8_scales: dict[str, torch.Tensor] = {}
+
+    for name, loaded_weight in weights:
+        if name.endswith(".weight_scale_inv"):
+            paired_weight_name = name[: -len("_scale_inv")]
+            pending_weight = pending_fp8_weights.pop(paired_weight_name, None)
+            if pending_weight is None:
+                pending_fp8_scales[name] = loaded_weight
+                continue
+            loaded_weight = self._dequantize_fp8_block_weight(pending_weight, loaded_weight, weight_block_size)
+            name = paired_weight_name
+        elif loaded_weight.dtype in FP8_DTYPES and name.endswith(".weight"):
+            scale_name = f"{name}_scale_inv"
+            pending_scale = pending_fp8_scales.pop(scale_name, None)
+            if pending_scale is None:
+                pending_fp8_weights[name] = loaded_weight
+                continue
+            loaded_weight = self._dequantize_fp8_block_weight(loaded_weight, pending_scale, weight_block_size)
+        yield name, loaded_weight
+
+    if pending_fp8_weights or pending_fp8_scales:
+        raise ValueError(
+            "Unpaired fp8 MiniMax-M2 weight/scale tensors detected: "
+            f"pending_weights={len(pending_fp8_weights)}, "
+            f"pending_scales={len(pending_fp8_scales)}"
+        )
+
+
+MiniMaxM2Model._need_dequantize_fp8_weights = _need_dequantize_fp8_weights
+MiniMaxM2Model._dequantize_fp8_block_weight = staticmethod(_dequantize_fp8_block_weight)
+MiniMaxM2Model._fp8_dequant_weight_iter = _fp8_dequant_weight_iter
+
+_original_load_weights = MiniMaxM2Model.load_weights
+
+
+def _patched_load_weights(
+    self: "MiniMaxM2Model",
+    weights: Iterable[tuple[str, torch.Tensor]],
+) -> set[str]:
+    if self._need_dequantize_fp8_weights():
+        weights = self._fp8_dequant_weight_iter(weights)
+    return _original_load_weights(self, weights)
+
+
+MiniMaxM2Model.load_weights = _patched_load_weights
diff --git a/vllm_ascend/patch/worker/patch_minimax_m2_linear_attn.py b/vllm_ascend/patch/worker/patch_minimax_m2_linear_attn.py
new file mode 100644
index 00000000000..71c55af5397
--- /dev/null
+++ b/vllm_ascend/patch/worker/patch_minimax_m2_linear_attn.py
@@ -0,0 +1,145 @@
+#
+# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# MiniMax-M2 linear attention: MiniMaxText01RMSNormTP weight sharding and NPU q/k norm path.
+#
+
+from functools import partial
+
+import torch
+import torch.nn as nn
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
+from vllm.model_executor.layers.mamba.linear_attn import (
+    CustomOp,
+    MiniMaxText01RMSNormTP,
+)
+from vllm.platforms import current_platform
+
+_ORIG_QK_METHOD_NAME: str | None = None
+_original_qk_method = None
+_qk_is_staticmethod = False
+
+if hasattr(MiniMaxText01RMSNormTP, "forward_qk"):
+    _ORIG_QK_METHOD_NAME = "forward_qk"
+    _original_qk_method = getattr(MiniMaxText01RMSNormTP, _ORIG_QK_METHOD_NAME)
+elif hasattr(MiniMaxText01RMSNormTP, "_normalize_qk"):
+    # Older vLLM versions
+    _ORIG_QK_METHOD_NAME = "_normalize_qk"
+    _original_qk_method = getattr(MiniMaxText01RMSNormTP, _ORIG_QK_METHOD_NAME)
+
+if _ORIG_QK_METHOD_NAME is not None:
+    # Detect whether upstream defined it as a staticmethod (some versions do).
+    _orig_desc = MiniMaxText01RMSNormTP.__dict__.get(_ORIG_QK_METHOD_NAME)
+    _qk_is_staticmethod = isinstance(_orig_desc, staticmethod)
+
+
+def _patched_qk(
+    q_norm: "MiniMaxText01RMSNormTP",
+    k_norm: "MiniMaxText01RMSNormTP",
+    q: torch.Tensor,
+    k: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    # NPU fast path: kernelized local RMSNorm for q/k, then a TP-global rstd correction.
+    if current_platform.device_name == "npu":
+        q, q_inv_rms = torch.ops.npu.npu_rms_norm(q, q_norm.weight, q_norm.variance_epsilon)
+        k, k_inv_rms = torch.ops.npu.npu_rms_norm(k, k_norm.weight, k_norm.variance_epsilon)
+
+        if q_norm.tp_world > 1:
+            q_local_inv_rms = q_inv_rms.to(torch.float32)
+            if q_local_inv_rms.shape[-1] != 1:
+                q_local_inv_rms = q_local_inv_rms.mean(dim=-1, keepdim=True)
+            q_local_var = (q_local_inv_rms.reciprocal().pow(2) - q_norm.variance_epsilon).clamp_min_(0.0)
+
+            k_local_inv_rms = k_inv_rms.to(torch.float32)
+            if k_local_inv_rms.shape[-1] != 1:
+                k_local_inv_rms = k_local_inv_rms.mean(dim=-1, keepdim=True)
+            k_local_var = (k_local_inv_rms.reciprocal().pow(2) - k_norm.variance_epsilon).clamp_min_(0.0)
+
+            qk_var = torch.cat([q_local_var, k_local_var], dim=-1)
+            qk_var = tensor_model_parallel_all_reduce(qk_var) / q_norm.tp_world
+            q_global_var, k_global_var = qk_var.chunk(2, dim=-1)
+
+            q_local_rstd = torch.rsqrt(q_local_var + q_norm.variance_epsilon)
+            k_local_rstd = torch.rsqrt(k_local_var + k_norm.variance_epsilon)
+            q_global_rstd = torch.rsqrt(q_global_var + q_norm.variance_epsilon)
+            k_global_rstd = torch.rsqrt(k_global_var + k_norm.variance_epsilon)
+
+            q = q * (q_global_rstd / q_local_rstd).to(q.dtype)
+            k = k * (k_global_rstd / k_local_rstd).to(k.dtype)
+
+        return q, k
+
+    assert _original_qk_method is not None
+    # We install the patch as a staticmethod below, so prefer the static calling
+    # convention for the original as well.
+    return _original_qk_method(q_norm, k_norm, q, k)
+
+
+def _patched_weight_loader(
+    param: nn.Parameter,
+    loaded_weight: torch.Tensor,
+    shard_world_size: int | None = None,
+    shard_rank: int | None = None,
+) -> None:
+    if shard_world_size is None:
+        shard_world_size = get_tensor_model_parallel_world_size()
+    if shard_rank is None:
+        shard_rank = get_tensor_model_parallel_rank()
+    shard_size = loaded_weight.shape[0] // shard_world_size
+    shard = slice(shard_rank * shard_size, (shard_rank + 1) * shard_size)
+    param.data.copy_(loaded_weight[shard])
+
+
+def _patched_init(
+    self: "MiniMaxText01RMSNormTP",
+    hidden_size: int,
+    eps: float = 1e-6,
+    *,
+    weight_shard_world_size: int | None = None,
+    weight_shard_rank: int | None = None,
+) -> None:
+    CustomOp.__init__(self)
+    self.tp_world = get_tensor_model_parallel_world_size()
+    self.tp_rank = get_tensor_model_parallel_rank()
+    self.weight_shard_world = weight_shard_world_size or self.tp_world
+    self.weight_shard_rank = self.tp_rank if weight_shard_rank is None else weight_shard_rank
+
+    if hidden_size % self.weight_shard_world != 0:
+        raise ValueError(
+            "MiniMaxText01RMSNormTP hidden_size must be divisible by "
+            f"weight_shard_world_size, got hidden_size={hidden_size}, "
+            f"weight_shard_world_size={self.weight_shard_world}"
+        )
+
+    self.weight = nn.Parameter(torch.ones(int(hidden_size / self.weight_shard_world)))
+    self.weight.weight_loader = partial(
+        _patched_weight_loader,
+        shard_world_size=self.weight_shard_world,
+        shard_rank=self.weight_shard_rank,
+    )
+    self.variance_epsilon = eps
+
+
+MiniMaxText01RMSNormTP.__init__ = _patched_init
+MiniMaxText01RMSNormTP.weight_loader = staticmethod(_patched_weight_loader)
+
+if _ORIG_QK_METHOD_NAME is not None:
+    # Install as a staticmethod so the (q_norm, k_norm, q, k) calling convention of
+    # _patched_qk applies regardless of how upstream defined the method.
+    setattr(MiniMaxText01RMSNormTP, _ORIG_QK_METHOD_NAME, staticmethod(_patched_qk))