From 85cf592f9b305a5cb6a9a725af963e83711ffef2 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Sun, 19 Apr 2026 15:44:47 +0000 Subject: [PATCH 01/24] Init Signed-off-by: Jee Jee Li --- vllm/lora/layers/fused_moe.py | 332 ++---------------- .../layers/fused_moe/fused_marlin_moe.py | 114 +++++- .../layers/fused_moe/fused_moe.py | 59 +++- .../layers/fused_moe/fused_moe_method_base.py | 5 + .../fused_moe/fused_moe_modular_method.py | 7 + .../fused_moe/gpt_oss_triton_kernels_moe.py | 61 +++- vllm/model_executor/layers/fused_moe/layer.py | 11 +- .../layers/fused_moe/lora_context.py | 70 ++++ .../layers/fused_moe/modular_kernel.py | 281 ++++++++++++++- .../layers/fused_moe/oracle/unquantized.py | 36 +- .../fused_moe/runner/moe_runner_base.py | 1 + .../fused_moe/unquantized_fused_moe_method.py | 18 +- .../layers/quantization/awq_marlin.py | 2 + .../layers/quantization/bitsandbytes.py | 6 +- .../layers/quantization/experts_int8.py | 6 +- .../model_executor/layers/quantization/fp8.py | 2 + .../layers/quantization/gguf.py | 2 + .../layers/quantization/gptq_marlin.py | 6 +- .../layers/quantization/modelopt.py | 4 + .../layers/quantization/moe_wna16.py | 6 +- .../layers/quantization/mxfp4.py | 7 + .../layers/quantization/online/fp8.py | 2 + .../layers/quantization/quark/quark_moe.py | 9 +- 23 files changed, 735 insertions(+), 312 deletions(-) create mode 100644 vllm/model_executor/layers/fused_moe/lora_context.py diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index 01efe3e47310..dc969c71d02a 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import functools import torch import torch.nn as nn @@ -14,31 +13,20 @@ ) from vllm.distributed.utils import divide from vllm.lora.layers.base import BaseLayerWithLoRA -from vllm.lora.ops.triton_ops.utils import get_lora_op_configs from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.fused_moe.config import ( - _get_config_dtype_str, -) -from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( - MarlinExperts, -) -from vllm.model_executor.layers.fused_moe.fused_moe import ( - TritonExperts, -) from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import ( FusedMoEModularMethod, ) -from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( - UnfusedOAITritonExperts, -) +from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext from vllm.model_executor.layers.fused_moe.modular_kernel import ( + FusedMoEExpertsModular, FusedMoEKernel, ) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoDPEPModular, ) -from .utils import _get_lora_device, try_get_optimal_moe_lora_config +from .utils import _get_lora_device class FusedMoEWithLoRA(BaseLayerWithLoRA): @@ -58,299 +46,51 @@ def __init__(self, base_layer: FusedMoE) -> None: # For non-gated MoE (is_act_and_mul=False), only 1 slice is needed # since there's only up_proj (w1), not gate_proj + up_proj (w1 + w3) self._w13_slices = 2 if base_layer.moe_config.is_act_and_mul else 1 - self._inject_lora_into_fused_moe() - - def _normalize_keys(self, config: dict[str, int | None]) -> dict[str, int | None]: - normalized_config = {} - for key, value in config.items(): - if key.islower(): - if key.startswith("block_"): - normalized_key = "BLOCK_SIZE_" + key.split("_")[-1].upper() - else: - normalized_key = key.upper() - 
else: - normalized_key = key - normalized_config[normalized_key] = value - return normalized_config - - def _get_lora_moe_configs( - self, - op_prefix: str, - num_loras: int, - rank: int, - num_slices: int, - M: int, - layer: FusedMoE, - top_k: int, - config_dtype: str, - ): - if envs.VLLM_TUNED_CONFIG_FOLDER: - hidden_size = layer.hidden_size - intermediate_size = ( - self.w2_lora_a_stacked[0].shape[-1] - if op_prefix == "w2" - else self.w13_lora_b_stacked[0].shape[-2] - ) - shrink_config = get_lora_op_configs( - op_type=f"fused_moe_lora_{op_prefix}_shrink", - max_loras=num_loras, - batch=M, - hidden_size=hidden_size, - rank=rank, - num_slices=num_slices, - moe_intermediate_size=intermediate_size, - ) - expand_config = get_lora_op_configs( - op_type=f"fused_moe_lora_{op_prefix}_expand", - max_loras=num_loras, - batch=M, - hidden_size=hidden_size, # lora_a_stacked.shape[-1], - rank=rank, - num_slices=num_slices, - moe_intermediate_size=intermediate_size, # lora_b_stacked.shape[-2], - ) - else: # fall back to the default config - get_config_func = functools.partial( - try_get_optimal_moe_lora_config, - w1_shape=layer.w13_weight.shape, - w2_shape=layer.w2_weight.shape, - rank=rank, - top_k=top_k, - dtype=config_dtype, - M=M, - block_shape=layer.quant_method.moe_quant_config.block_shape, - ) - shrink_config = get_config_func( - op_type=f"fused_moe_lora_{op_prefix}_shrink" - ) - expand_config = get_config_func( - op_type=f"fused_moe_lora_{op_prefix}_expand" - ) - shrink_config = self._normalize_keys(shrink_config) - expand_config = self._normalize_keys(expand_config) - return shrink_config, expand_config - - def _inject_lora_into_fused_moe(self): - moe_state_dict = {} - top_k = self.base_layer.top_k self.base_layer.ensure_moe_quant_config_init() - quant_config = self.base_layer.quant_method.moe_quant_config - if getattr(self.base_layer.quant_method, "supports_internal_mk", False): - # Use the existing modular kernel from the quant method - m_fused_moe_fn = self.base_layer.quant_method.moe_kernel + moe_kernel = self.base_layer.quant_method.moe_kernel # Don't let the kernel own shared experts so the runner can # overlap them with routed experts via a separate CUDA stream. - m_fused_moe_fn.shared_experts = None + moe_kernel.shared_experts = None else: - # Create a new modular kernel via select_gemm_impl. - # Don't pass shared_experts to the kernel so the runner can - # overlap them with routed experts via a separate CUDA stream. 
prepare_finalize = MoEPrepareAndFinalizeNoDPEPModular() - m_fused_moe_fn = FusedMoEKernel( + moe_kernel = FusedMoEKernel( prepare_finalize, self.base_layer.quant_method.select_gemm_impl( prepare_finalize, self.base_layer ), ) - - if quant_config.use_mxfp4_w4a16: - assert isinstance( - m_fused_moe_fn.impl.fused_experts, - (MarlinExperts, UnfusedOAITritonExperts), - ) - else: - assert isinstance(m_fused_moe_fn.impl.fused_experts, TritonExperts) - - def fwd_decorator(layer, func): - def wrapper(*args, **kwargs): - moe_state_dict["hidden_states"] = kwargs["hidden_states"] - moe_state_dict["topk_ids"] = kwargs["topk_ids"] - moe_state_dict["topk_weights"] = kwargs["topk_weights"] - moe_state_dict["expert_map"] = kwargs["expert_map"] - moe_state_dict["apply_router_weight_on_input"] = kwargs[ - "apply_router_weight_on_input" - ] - result = func(*args, **kwargs) - return result - - return wrapper - - def act_decorator(layer, func): - def wrapper(*args, **kwargs): - _, output, input = args - - hidden_states = moe_state_dict["hidden_states"] - topk_weights = moe_state_dict["topk_weights"] - curr_topk_ids = moe_state_dict["topk_ids"] - - expert_map = moe_state_dict["expert_map"] - - config_dtype = _get_config_dtype_str( - dtype=hidden_states.dtype, - use_fp8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - ) - num_tokens = hidden_states.size(0) - M = num_tokens - max_lora_rank = self.w13_lora_a_stacked[0].shape[-2] - shrink_config, expand_config = self._get_lora_moe_configs( - op_prefix="w13", - num_loras=self.max_loras, - rank=max_lora_rank, - num_slices=self._w13_slices, - M=M, - layer=layer, - top_k=top_k, - config_dtype=config_dtype, - ) - - # SPARSITY_FACTOR is a heuristic margin ensuring tokens * top_k - # activates only a small fraction of total experts * loras. 
- SPARSITY_FACTOR = 8 - naive_block_assignment = ( - expert_map is None - and num_tokens * top_k * SPARSITY_FACTOR - <= self.base_layer.local_num_experts * self.max_loras - ) - - # get the block size of m from customized config or default config - ( - token_lora_mapping, - sorted_token_ids_lora, - expert_ids_lora, - num_tokens_post_padded_lora, - ) = self.punica_wrapper.moe_lora_align_block_size( - curr_topk_ids, - num_tokens, - shrink_config["BLOCK_SIZE_M"], - self.base_layer.local_num_experts, - self.max_loras, - self.adapter_enabled, - expert_map, - naive_block_assignment=naive_block_assignment, - ) - - moe_state_dict["sorted_token_ids_lora"] = sorted_token_ids_lora - moe_state_dict["expert_ids_lora"] = expert_ids_lora - moe_state_dict["num_tokens_post_padded_lora"] = ( - num_tokens_post_padded_lora - ) - moe_state_dict["token_lora_mapping"] = token_lora_mapping - - if sorted_token_ids_lora is not None: - expert_ids_lora = expert_ids_lora.view(self.max_loras, -1) - sorted_token_ids_lora = sorted_token_ids_lora.view( - self.max_loras, -1 - ) - # - - self.punica_wrapper.add_lora_fused_moe( - input.view(-1, top_k, input.shape[-1]), - hidden_states, - self.w13_lora_a_stacked, - self.w13_lora_b_stacked, - topk_weights, - sorted_token_ids_lora, - expert_ids_lora, - num_tokens_post_padded_lora, - max_lora_rank, - top_k, - shrink_config, ## pass the shrink config - expand_config, ## pass the expand config - self.adapter_enabled, - fully_sharded=self.fully_sharded, - token_lora_mapping=token_lora_mapping, - ) - - result = func(*args, **kwargs) - - moe_state_dict["intermediate_cache2"] = output - return result - - return wrapper - - def moe_sum_decorator(layer, func): - def wrapper(*args, **kwargs): - hidden_states = moe_state_dict["hidden_states"] - topk_weights = moe_state_dict["topk_weights"] - - config_dtype = _get_config_dtype_str( - dtype=hidden_states.dtype, - use_fp8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - ) - num_tokens = hidden_states.size(0) - M = num_tokens - max_lora_rank = self.w2_lora_a_stacked[0].shape[-2] - shrink_config, expand_config = self._get_lora_moe_configs( - op_prefix="w2", - num_loras=self.max_loras, - rank=max_lora_rank, - num_slices=1, - M=M, - layer=layer, - top_k=top_k, - config_dtype=config_dtype, - ) - - sorted_token_ids_lora = moe_state_dict["sorted_token_ids_lora"] - expert_ids_lora = moe_state_dict["expert_ids_lora"] - num_tokens_post_padded_lora = moe_state_dict[ - "num_tokens_post_padded_lora" - ] - token_lora_mapping = moe_state_dict.get("token_lora_mapping") - - if sorted_token_ids_lora is not None: - expert_ids_lora = expert_ids_lora.view(self.max_loras, -1) - sorted_token_ids_lora = sorted_token_ids_lora.view( - self.max_loras, -1 - ) - intermediate_cache2 = moe_state_dict["intermediate_cache2"] - intermediate_cache3 = args[0] - - shard_size_w2 = divide(self.base_layer.hidden_size, self.tp_size) - - self.punica_wrapper.add_lora_fused_moe( - intermediate_cache3, - intermediate_cache2, - self.w2_lora_a_stacked, - self.w2_lora_b_stacked, - topk_weights, - sorted_token_ids_lora, - expert_ids_lora, - num_tokens_post_padded_lora, - max_lora_rank, - top_k, - shrink_config, ## pass the shrink config - expand_config, ## pass the expand config - self.adapter_enabled, - True, - fully_sharded=self.fully_sharded, - offset=shard_size_w2 * self.tp_rank if self.fully_sharded else 0, - token_lora_mapping=token_lora_mapping, - ) - - result = func(*args, **kwargs) - return result - - return wrapper - - fused_experts = m_fused_moe_fn.impl.fused_experts 
- - m_fused_moe_fn.apply = fwd_decorator(self.base_layer, m_fused_moe_fn.apply) - fused_experts.activation = act_decorator( - self.base_layer, fused_experts.activation - ) - fused_experts.moe_sum = moe_sum_decorator( - self.base_layer, fused_experts.moe_sum + assert ( + isinstance(moe_kernel.fused_experts, FusedMoEExpertsModular) + and moe_kernel.fused_experts.supports_lora() + ), ( + f"{type(moe_kernel.fused_experts).__name__} does not support LoRA. " + "For unquantized MoE, set moe_backend='triton' or moe_backend='auto' " + "(auto selects Triton automatically when LoRA is enabled). " + "For quantized MoE, implement supports_lora() -> True and handle " + "lora_context in apply()." ) - # TODO(bnell): find a less intrusive way to handle this. self.base_layer._replace_quant_method( - FusedMoEModularMethod(self.base_layer.quant_method, m_fused_moe_fn) + FusedMoEModularMethod(self.base_layer.quant_method, moe_kernel) + ) + + def _build_lora_context(self): + return MoELoRAContext( + w13_lora_a_stacked=self.w13_lora_a_stacked, + w13_lora_b_stacked=self.w13_lora_b_stacked, + w2_lora_a_stacked=self.w2_lora_a_stacked, + w2_lora_b_stacked=self.w2_lora_b_stacked, + adapter_enabled=self.adapter_enabled, + max_loras=self.max_loras, + top_k=self.base_layer.top_k, + w13_num_slices=self._w13_slices, + fully_sharded=self.fully_sharded, + tp_rank=self.tp_rank, + tp_size=self.tp_size, + local_num_experts=self.base_layer.local_num_experts, + punica_wrapper=self.punica_wrapper, + use_tuned_config=bool(envs.VLLM_TUNED_CONFIG_FOLDER), ) def _create_lora_a_weights( @@ -589,6 +329,10 @@ def set_lora( index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2] ].copy_(sliced_w2_lora_b, non_blocking=True) + def set_mapping(self, punica_wrapper): + super().set_mapping(punica_wrapper) + self.base_layer._lora_context = self._build_lora_context() + def forward(self, *args, **kwargs): return self.base_layer.forward(*args, **kwargs) diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 6c916cf3cb66..1ab906d3ca39 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -3,9 +3,13 @@ """Fused MoE utilities for GPTQ.""" from collections.abc import Callable +from typing import TYPE_CHECKING import torch +if TYPE_CHECKING: + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext + import vllm._custom_ops as ops import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.activation import ( @@ -654,6 +658,10 @@ def moe_problem_size( class MarlinExperts(MarlinExpertsBase): """Marlin-based fused MoE expert implementation.""" + @staticmethod + def supports_lora() -> bool: + return True + def supports_expert_map(self) -> bool: return True @@ -713,9 +721,109 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, + lora_context: "MoELoRAContext | None" = None, ): assert self.w1_scale is not None assert self.w2_scale is not None + + if lora_context is None: + return fused_marlin_moe( + hidden_states=hidden_states, + w1=w1, + w2=w2, + bias1=self.w1_bias, + bias2=self.w2_bias, + w1_scale=self.w1_scale, + w2_scale=self.w2_scale, + topk_weights=topk_weights, + topk_ids=topk_ids, + global_scale1=self.g1_alphas, + global_scale2=self.g2_alphas, + quant_type_id=self.quant_type_id, + 
apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + activation=activation, + activation_func=self.activation, + moe_sum=self.moe_sum, + expert_map=expert_map, + output=output, + # Workspaces are swapped in workspace_shapes() to account for proper + # output buffer allocation. Please refer to workspace_shapes(). + intermediate_cache13=workspace2, + intermediate_cache2=workspace13, + g_idx1=self.w13_g_idx, + g_idx2=self.w2_g_idx, + sort_indices1=self.w13_g_idx_sort_indices, + sort_indices2=self.w2_g_idx_sort_indices, + is_k_full=self.is_k_full, + input_dtype=self.input_dtype, + ) + + # LoRA path: wrap activation_func and moe_sum to inject LoRA at the + # two natural injection points. + # + # Marlin uses moe_align_block_size (same as TritonExperts) so + # intermediate_cache1 is indexed by flat (token, expert) pair index, + # which is compatible with add_lora_fused_moe's scatter mechanism. + ctx = lora_context + M = hidden_states.size(0) + top_k_num = topk_ids.size(1) + lora_state: dict = {} + + def activation_with_lora( + act_enum: MoEActivation, + act_output: torch.Tensor, + act_input: torch.Tensor, + ) -> None: + # act_input = intermediate_cache1 (M*topk, 2N for gated) + # act_output = intermediate_cache2 (M*topk, N) + ( + sorted_token_ids_lora, + expert_ids_lora, + num_tokens_post_padded_lora, + token_lora_mapping, + *_, + ) = self._apply_w13_lora( + ctx, + hidden_states, + act_input, + topk_ids, + topk_weights, + expert_map, + w1, + w2, + M, + top_k_num, + ) + lora_state.update( + { + "sorted": sorted_token_ids_lora, + "eids": expert_ids_lora, + "npad": num_tokens_post_padded_lora, + "tlm": token_lora_mapping, + } + ) + self.activation(act_enum, act_output, act_input) + lora_state["cache2"] = act_output + + def moe_sum_with_lora(moe_out: torch.Tensor, out: torch.Tensor) -> None: + # moe_out shape: (M, topk, K) — _apply_w2_lora expects (M, topk, K) + self._apply_w2_lora( + ctx, + lora_state["cache2"], + moe_out, + topk_weights, + lora_state["sorted"], + lora_state["eids"], + lora_state["npad"], + lora_state["tlm"], + M, + w1, + w2, + top_k_num, + ) + self.moe_sum(moe_out, out) + return fused_marlin_moe( hidden_states=hidden_states, w1=w1, @@ -732,12 +840,10 @@ def apply( apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, activation=activation, - activation_func=self.activation, - moe_sum=self.moe_sum, + activation_func=activation_with_lora, + moe_sum=moe_sum_with_lora, expert_map=expert_map, output=output, - # Workspaces are swapped in workspace_shapes() to account for proper - # output buffer allocation. Please refer to workspace_shapes(). 
intermediate_cache13=workspace2, intermediate_cache2=workspace13, g_idx1=self.w13_g_idx, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 6de25da051ad..f7303f3dc45e 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -6,7 +6,10 @@ import json import os from collections.abc import Callable -from typing import Any +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext import torch @@ -1974,6 +1977,10 @@ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bo def _supports_batch_invariance(): return True + @staticmethod + def supports_lora() -> bool: + return True + def supports_expert_map(self) -> bool: return True @@ -2014,6 +2021,7 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, + lora_context: "MoELoRAContext | None" = None, ): # Check constraints. if self.quant_config.use_int4_w4a16: @@ -2100,6 +2108,35 @@ def apply( B_bias=self.w1_bias, ) + # LoRA w13: applied to intermediate_cache1 before activation, using + # hidden_states as the lora_a input. moe_lora_align_block_size is + # called once here and results reused for the w2 LoRA below. + sorted_token_ids_lora = None + expert_ids_lora = None + num_tokens_post_padded_lora = None + token_lora_mapping = None + if lora_context is not None: + ( + sorted_token_ids_lora, + expert_ids_lora, + num_tokens_post_padded_lora, + token_lora_mapping, + shrink_config_w13, + expand_config_w13, + ) = self._apply_w13_lora( + lora_context, + hidden_states, + intermediate_cache1, + topk_ids, + topk_weights, + expert_map, + w1, + w2, + num_tokens, + top_k_num, + block_shape=self.block_shape, + ) + self.activation( activation, intermediate_cache2, intermediate_cache1.view(-1, N) ) @@ -2137,6 +2174,26 @@ def apply( B_bias=self.w2_bias, ) + # LoRA w2: applied to intermediate_cache3 before moe_sum, using the + # unquantized intermediate_cache2 as the lora_a input. Reuses the + # sorted_token_ids_lora computed above. 
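+        # Because _apply_w2_lora accumulates the LoRA delta into
+        # intermediate_cache3 in place, the moe_sum reduction below folds the
+        # LoRA contribution into the final output without an extra pass over
+        # the tokens.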
+ if lora_context is not None: + self._apply_w2_lora( + lora_context, + intermediate_cache2, + intermediate_cache3, + topk_weights, + sorted_token_ids_lora, + expert_ids_lora, + num_tokens_post_padded_lora, + token_lora_mapping, + num_tokens, + w1, + w2, + top_k_num, + block_shape=self.block_shape, + ) + # separate function is required for MoE + LoRA self.moe_sum(intermediate_cache3, output) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py index a239dfea92e4..af71ca6bf254 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py @@ -2,10 +2,14 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import abstractmethod +from typing import TYPE_CHECKING import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk + +if TYPE_CHECKING: + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, @@ -161,6 +165,7 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, + lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py index 142e180786c6..0f4f7bdab316 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py @@ -2,9 +2,14 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING + import torch from vllm.logger import init_logger + +if TYPE_CHECKING: + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, @@ -91,6 +96,7 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, + lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: assert self.moe_kernel is not None return self.moe_kernel.apply( @@ -104,4 +110,5 @@ def apply( apply_router_weight_on_input=layer.apply_router_weight_on_input, expert_map=None if self.disable_expert_map else layer.expert_map, shared_experts_input=shared_experts_input, + lora_context=lora_context, ) diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py index a21ddaba0755..058b47401439 100644 --- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -1,9 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING import torch +if TYPE_CHECKING: + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext + import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm._aiter_ops import rocm_aiter_ops @@ -664,6 +668,10 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts): One use case for it is to inject LoRA modules on the activation and moe_sum. 
""" + @staticmethod + def supports_lora() -> bool: + return True + @staticmethod def _supports_activation(activation: MoEActivation) -> bool: return activation in [ @@ -715,6 +723,7 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, + lora_context: "MoELoRAContext | None" = None, ): # Use local variable to help mypy narrow the type after None check quant_config = self.quant_config @@ -775,10 +784,41 @@ def apply( y=intermediate_cache1, ) + # w13 LoRA: gather the activation input from expert-sorted + # intermediate_cache1, then add the LoRA delta in-place on that copy + # before passing it to activation — exactly mirroring the old + # decorator approach which modified the gathered tensor in-place. + act_input = intermediate_cache1.view(-1, N)[gather_indx.dst_indx] + + sorted_token_ids_lora = None + expert_ids_lora = None + num_tokens_post_padded_lora = None + token_lora_mapping = None + if lora_context is not None: + ( + sorted_token_ids_lora, + expert_ids_lora, + num_tokens_post_padded_lora, + token_lora_mapping, + _, + _, + ) = self._apply_w13_lora( + lora_context, + hidden_states, + act_input, + topk_ids, + topk_weights, + None, # expert_map already applied above + w1, + w2, + M, + topk, + ) + self.activation( activation, intermediate_cache2, - intermediate_cache1.view(-1, N)[gather_indx.dst_indx], + act_input, ) # matmul_ogs grouped reduction fuse sum across multiple experts: @@ -797,6 +837,25 @@ def apply( y=intermediate_cache3, ) + # w2 LoRA: after matmul_ogs with scatter_indx, intermediate_cache3 is + # in token-topk order, matching the standard (M, topk, K) layout that + # _apply_w2_lora expects. + if lora_context is not None: + self._apply_w2_lora( + lora_context, + intermediate_cache2, + intermediate_cache3.view(-1, topk, K), + topk_weights, + sorted_token_ids_lora, + expert_ids_lora, + num_tokens_post_padded_lora, + token_lora_mapping, + M, + w1, + w2, + topk, + ) + self.moe_sum(intermediate_cache3.view(-1, topk, K), output) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 422f9e427620..bbda944ad1bb 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -3,7 +3,10 @@ from collections.abc import Callable, Iterable from enum import Enum -from typing import Literal, cast, get_args, overload +from typing import TYPE_CHECKING, Literal, cast, get_args, overload + +if TYPE_CHECKING: + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext import torch from torch.nn.parameter import UninitializedParameter @@ -584,6 +587,12 @@ def _get_quant_method() -> FusedMoEMethodBase: enable_dbo=self.vllm_config.parallel_config.enable_dbo, ) + # Set permanently by FusedMoEWithLoRA.set_mapping() when the active + # kernel supports native LoRA (supports_lora() is True). + # FusedMoEModularMethod.apply() reads this and passes it down to + # FusedMoEKernel → FusedMoEExpertsModular.apply(). + self._lora_context: MoELoRAContext | None = None + # TODO(bnell): This method is provided as a hook so vllm/lora/layers/fused_moe.py # can safely swap out the quant_method. We should figure out a less # intrusive way to do this. 
diff --git a/vllm/model_executor/layers/fused_moe/lora_context.py b/vllm/model_executor/layers/fused_moe/lora_context.py new file mode 100644 index 000000000000..18402123ade2 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/lora_context.py @@ -0,0 +1,70 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from typing import Any + +import torch + + +def _normalize_lora_config_keys( + config: dict[str, int | None], +) -> dict[str, int | None]: + """Normalize Triton config dict keys to uppercase BLOCK_SIZE_* format.""" + out: dict[str, int | None] = {} + for key, val in config.items(): + if key.islower(): + if key.startswith("block_"): + nk = "BLOCK_SIZE_" + key.split("_")[-1].upper() + else: + nk = key.upper() + else: + nk = key + out[nk] = val + return out + + +@dataclass +class MoELoRAContext: + """ + Carries all LoRA state for MoE forward passes. + + Built once by FusedMoEWithLoRA.set_mapping() and stored permanently on the + base FusedMoE layer as _lora_context. Propagated explicitly through the + modular kernel path (FusedMoEKernel -> FusedMoEExpertsModular.apply) so + that TritonExperts.apply() can compute the LoRA contribution inline. + + All tensor fields are stored by reference so in-place updates (set_lora, + reset_lora, adapter_enabled) and punica_wrapper state updates are visible + without rebuilding the context. + + Typed as Any for punica_wrapper to avoid a circular import at module load + time: vllm.lora imports vllm.model_executor.layers.fused_moe, so the + reverse at module level would be circular. The actual type is + PunicaWrapperBase from vllm.lora.punica_wrapper. + """ + + # LoRA weight tensors (same shapes as FusedMoEWithLoRA attributes) + w13_lora_a_stacked: tuple[torch.Tensor, ...] + w13_lora_b_stacked: tuple[torch.Tensor, ...] + w2_lora_a_stacked: tuple[torch.Tensor, ...] + w2_lora_b_stacked: tuple[torch.Tensor, ...] + + # (max_loras + 1,) int32; slot 0 is the "no-adapter" sentinel + adapter_enabled: torch.Tensor + + # Metadata + max_loras: int + top_k: int + w13_num_slices: int # 2 = gated (gate + up), 1 = non-gated or 3D-fused + fully_sharded: bool + tp_rank: int + tp_size: int + local_num_experts: int + + # PunicaWrapperBase instance (typed Any to avoid circular import) + punica_wrapper: Any + + # Whether VLLM_TUNED_CONFIG_FOLDER is set; selects get_lora_op_configs vs + # try_get_optimal_moe_lora_config for Triton kernel tile configs. + use_tuned_config: bool diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index f2e6e2560e70..0738d4861323 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -5,10 +5,13 @@ from dataclasses import dataclass from enum import Enum from math import prod -from typing import final +from typing import TYPE_CHECKING, final import torch +if TYPE_CHECKING: + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext + import vllm.envs as envs from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.activation import ( @@ -254,6 +257,16 @@ class FusedMoEPrepareAndFinalizeModular(FusedMoEPrepareAndFinalize): described above for the Modular case. """ + @staticmethod + def supports_lora() -> bool: + """Return True if this prepare/finalize impl can propagate LoRA context. + + Non-EP implementations support LoRA by default. 
EP-aware + implementations must override this to True once they can dispatch + lora_ids alongside hidden states in their all2all call (Task 2). + """ + return True + @abstractmethod def prepare( self, @@ -734,6 +747,17 @@ def g1_alphas(self) -> torch.Tensor | None: def g2_alphas(self) -> torch.Tensor | None: return self.quant_config.g2_alphas + @staticmethod + def supports_lora() -> bool: + """Return True if this expert impl natively handles MoELoRAContext. + + When True, FusedMoEWithLoRA will propagate a MoELoRAContext through + FusedMoEKernel.apply() instead of using the legacy decorator injection. + Subclasses that inline the LoRA computation inside apply() must override + this to return True. + """ + return False + @abstractmethod def supports_expert_map(self) -> bool: """ @@ -896,6 +920,7 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: ExpertTokensMetadata | None, apply_router_weight_on_input: bool, + lora_context: "MoELoRAContext | None" = None, ) -> None: """ This function computes the intermediate result of a Mixture of Experts @@ -934,6 +959,252 @@ def apply( """ raise NotImplementedError + # ------------------------------------------------------------------------- + # Shared LoRA helpers + # + # The LoRA operators (add_lora_fused_moe / moe_lora_align_block_size) are + # Triton kernels that operate on fp16/bf16 activation tensors and are + # independent of the main GEMM quantization backend. Any modular expert + # backend can call these helpers from its apply() at the two natural + # injection points: + # • after the w13 GEMM, before activation (_apply_w13_lora) + # • after the w2 GEMM, before moe_sum (_apply_w2_lora) + # + # block_shape is the only backend-specific knob: pass self.block_shape for + # quantized backends; omit (None) for unquantized ones. 
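+    #
+    # A minimal usage sketch (mirrors TritonExperts.apply() in this patch;
+    # the local buffer names are the caller's own, not part of the helper
+    # API):
+    #
+    #     if lora_context is not None:
+    #         (sorted_ids, expert_ids, n_post_pad, token_lora_mapping,
+    #          _, _) = self._apply_w13_lora(
+    #             lora_context, hidden_states, intermediate_cache1,
+    #             topk_ids, topk_weights, expert_map, w1, w2,
+    #             num_tokens, top_k_num, block_shape=self.block_shape)
+    #     self.activation(activation, intermediate_cache2,
+    #                     intermediate_cache1.view(-1, N))
+    #     ... w2 GEMM writes intermediate_cache3 ...
+    #     if lora_context is not None:
+    #         self._apply_w2_lora(
+    #             lora_context, intermediate_cache2, intermediate_cache3,
+    #             topk_weights, sorted_ids, expert_ids, n_post_pad,
+    #             token_lora_mapping, num_tokens, w1, w2, top_k_num,
+    #             block_shape=self.block_shape)
+    #     self.moe_sum(intermediate_cache3, output)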
+ # ------------------------------------------------------------------------- + + def _apply_w13_lora( + self, + lora_context: "MoELoRAContext", + hidden_states: torch.Tensor, + intermediate_cache1: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + expert_map: torch.Tensor | None, + w1: torch.Tensor, + w2: torch.Tensor, + num_tokens: int, + top_k_num: int, + *, + block_shape: list[int] | None = None, + ): + import functools + + from vllm.lora.layers.utils import try_get_optimal_moe_lora_config + from vllm.lora.ops.triton_ops.utils import get_lora_op_configs + from vllm.model_executor.layers.fused_moe.config import _get_config_dtype_str + from vllm.model_executor.layers.fused_moe.lora_context import ( + _normalize_lora_config_keys, + ) + + ctx = lora_context + config_dtype = _get_config_dtype_str( + dtype=hidden_states.dtype, + use_fp8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + ) + max_lora_rank = ctx.w13_lora_a_stacked[0].shape[-2] + + if ctx.use_tuned_config: + hidden_size = hidden_states.shape[-1] + intermediate_size = ctx.w13_lora_b_stacked[0].shape[-2] + shrink_config = get_lora_op_configs( + op_type="fused_moe_lora_w13_shrink", + max_loras=ctx.max_loras, + batch=num_tokens, + hidden_size=hidden_size, + rank=max_lora_rank, + num_slices=ctx.w13_num_slices, + moe_intermediate_size=intermediate_size, + ) + expand_config = get_lora_op_configs( + op_type="fused_moe_lora_w13_expand", + max_loras=ctx.max_loras, + batch=num_tokens, + hidden_size=hidden_size, + rank=max_lora_rank, + num_slices=ctx.w13_num_slices, + moe_intermediate_size=intermediate_size, + ) + else: + get_config_func = functools.partial( + try_get_optimal_moe_lora_config, + w1_shape=w1.shape, + w2_shape=w2.shape, + rank=max_lora_rank, + top_k=ctx.top_k, + dtype=config_dtype, + M=num_tokens, + block_shape=block_shape, + ) + shrink_config = get_config_func(op_type="fused_moe_lora_w13_shrink") + expand_config = get_config_func(op_type="fused_moe_lora_w13_expand") + + shrink_config = _normalize_lora_config_keys(shrink_config) + expand_config = _normalize_lora_config_keys(expand_config) + + SPARSITY_FACTOR = 8 + naive_block_assignment = ( + expert_map is None + and num_tokens * ctx.top_k * SPARSITY_FACTOR + <= ctx.local_num_experts * ctx.max_loras + ) + + ( + token_lora_mapping, + sorted_token_ids_lora, + expert_ids_lora, + num_tokens_post_padded_lora, + ) = ctx.punica_wrapper.moe_lora_align_block_size( + topk_ids, + num_tokens, + shrink_config["BLOCK_SIZE_M"], + ctx.local_num_experts, + ctx.max_loras, + ctx.adapter_enabled, + expert_map, + naive_block_assignment=naive_block_assignment, + ) + + _sorted_token_ids_lora = sorted_token_ids_lora + _expert_ids_lora = expert_ids_lora + if _sorted_token_ids_lora is not None: + _expert_ids_lora = _expert_ids_lora.view(ctx.max_loras, -1) + _sorted_token_ids_lora = _sorted_token_ids_lora.view(ctx.max_loras, -1) + + ctx.punica_wrapper.add_lora_fused_moe( + intermediate_cache1.view(-1, top_k_num, intermediate_cache1.shape[-1]), + hidden_states, + ctx.w13_lora_a_stacked, + ctx.w13_lora_b_stacked, + topk_weights, + _sorted_token_ids_lora, + _expert_ids_lora, + num_tokens_post_padded_lora, + max_lora_rank, + ctx.top_k, + shrink_config, + expand_config, + ctx.adapter_enabled, + fully_sharded=ctx.fully_sharded, + token_lora_mapping=token_lora_mapping, + ) + + return ( + sorted_token_ids_lora, + expert_ids_lora, + num_tokens_post_padded_lora, + token_lora_mapping, + shrink_config, + expand_config, + ) + + def _apply_w2_lora( + self, + lora_context: 
"MoELoRAContext", + intermediate_cache2: torch.Tensor, + intermediate_cache3: torch.Tensor, + topk_weights: torch.Tensor, + sorted_token_ids_lora: torch.Tensor | None, + expert_ids_lora: torch.Tensor | None, + num_tokens_post_padded_lora: torch.Tensor | None, + token_lora_mapping: torch.Tensor | None, + num_tokens: int, + w1: torch.Tensor, + w2: torch.Tensor, + top_k_num: int, + *, + block_shape: list[int] | None = None, + ): + import functools + + from vllm.lora.layers.utils import try_get_optimal_moe_lora_config + from vllm.lora.ops.triton_ops.utils import get_lora_op_configs + from vllm.model_executor.layers.fused_moe.config import _get_config_dtype_str + from vllm.model_executor.layers.fused_moe.lora_context import ( + _normalize_lora_config_keys, + ) + + ctx = lora_context + config_dtype = _get_config_dtype_str( + dtype=intermediate_cache2.dtype, + use_fp8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + ) + max_lora_rank = ctx.w2_lora_a_stacked[0].shape[-2] + + if ctx.use_tuned_config: + hidden_size = intermediate_cache3.shape[-1] + intermediate_size = ctx.w2_lora_a_stacked[0].shape[-1] + shrink_config = get_lora_op_configs( + op_type="fused_moe_lora_w2_shrink", + max_loras=ctx.max_loras, + batch=num_tokens, + hidden_size=hidden_size, + rank=max_lora_rank, + num_slices=1, + moe_intermediate_size=intermediate_size, + ) + expand_config = get_lora_op_configs( + op_type="fused_moe_lora_w2_expand", + max_loras=ctx.max_loras, + batch=num_tokens, + hidden_size=hidden_size, + rank=max_lora_rank, + num_slices=1, + moe_intermediate_size=intermediate_size, + ) + else: + get_config_func = functools.partial( + try_get_optimal_moe_lora_config, + w1_shape=w1.shape, + w2_shape=w2.shape, + rank=max_lora_rank, + top_k=ctx.top_k, + dtype=config_dtype, + M=num_tokens, + block_shape=block_shape, + ) + shrink_config = get_config_func(op_type="fused_moe_lora_w2_shrink") + expand_config = get_config_func(op_type="fused_moe_lora_w2_expand") + + shrink_config = _normalize_lora_config_keys(shrink_config) + expand_config = _normalize_lora_config_keys(expand_config) + + _sorted_token_ids_lora = sorted_token_ids_lora + _expert_ids_lora = expert_ids_lora + if _sorted_token_ids_lora is not None: + _expert_ids_lora = _expert_ids_lora.view(ctx.max_loras, -1) + _sorted_token_ids_lora = _sorted_token_ids_lora.view(ctx.max_loras, -1) + + # w2_lora_b shape[-2] is hidden_size // tp_size when fully_sharded, + # matching divide(base_layer.hidden_size, tp_size). 
+ shard_size_w2 = ctx.w2_lora_b_stacked[0].shape[-2] + offset = shard_size_w2 * ctx.tp_rank if ctx.fully_sharded else 0 + + ctx.punica_wrapper.add_lora_fused_moe( + intermediate_cache3, + intermediate_cache2, + ctx.w2_lora_a_stacked, + ctx.w2_lora_b_stacked, + topk_weights, + _sorted_token_ids_lora, + _expert_ids_lora, + num_tokens_post_padded_lora, + max_lora_rank, + ctx.top_k, + shrink_config, + expand_config, + ctx.adapter_enabled, + True, # is_w2 + fully_sharded=ctx.fully_sharded, + offset=offset, + token_lora_mapping=token_lora_mapping, + ) + class FusedMoEExpertsMonolithic(FusedMoEExperts): """ @@ -1204,6 +1475,7 @@ def _fused_experts( expert_map: torch.Tensor | None, apply_router_weight_on_input: bool, expert_tokens_meta: ExpertTokensMetadata | None, + lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: _, M_full, N, K, top_k = self.fused_experts.moe_problem_size( a1q, w1, w2, topk_ids @@ -1248,6 +1520,7 @@ def _fused_experts( workspace2=workspace2, expert_tokens_meta=expert_tokens_meta, apply_router_weight_on_input=apply_router_weight_on_input, + lora_context=lora_context, ) return fused_out @@ -1330,6 +1603,7 @@ def apply( expert_map: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, shared_experts_input: torch.Tensor | None = None, + lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets @@ -1354,6 +1628,8 @@ def apply( - shared_experts_input (Optional[torch.Tensor]): Optional separate input for shared experts. For latent MoE, this is the original hidden_states before latent projection. + - lora_context (Optional[MoELoRAContext]): LoRA context to propagate to + fused_experts.apply() when the expert backend supports native LoRA. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -1392,6 +1668,7 @@ def apply( expert_map=expert_map, apply_router_weight_on_input=apply_router_weight_on_input, expert_tokens_meta=expert_tokens_meta, + lora_context=lora_context, ) return self._finalize( @@ -1595,6 +1872,7 @@ def apply( expert_map: torch.Tensor | None, apply_router_weight_on_input: bool, shared_experts_input: torch.Tensor | None = None, + lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: assert isinstance(self.impl, FusedMoEKernelModularImpl) return self.impl.apply( @@ -1608,4 +1886,5 @@ def apply( expert_map=expert_map, apply_router_weight_on_input=apply_router_weight_on_input, shared_experts_input=shared_experts_input, + lora_context=lora_context, ) diff --git a/vllm/model_executor/layers/fused_moe/oracle/unquantized.py b/vllm/model_executor/layers/fused_moe/oracle/unquantized.py index 33bf7a0c75f6..ddc49a7e3bcd 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/unquantized.py +++ b/vllm/model_executor/layers/fused_moe/oracle/unquantized.py @@ -209,6 +209,8 @@ def _return_or_raise( return backend, k_cls raise ValueError(_make_log_unsupported(backend, reason)) + require_lora = moe_config.is_lora_enabled + runner_backend = moe_config.moe_backend if runner_backend != "auto": requested_backend = map_unquantized_backend(runner_backend) @@ -218,10 +220,20 @@ def _return_or_raise( ): requested_backend = UnquantizedMoeBackend.BATCHED_TRITON + if require_lora: + k_cls = backend_to_kernel_cls(requested_backend) + if not k_cls.supports_lora(): + raise ValueError( + f"moe_backend='{runner_backend}' does not support LoRA. " + "Use moe_backend='triton' or moe_backend='auto'." 
+ ) return _return_or_raise(requested_backend, moe_config, activation_format) # Handle explicit FlashInfer FP16 configuration. - if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP16"): + # When LoRA is enabled, FlashInfer backends don't support it — skip this + # block entirely and fall through to the generic auto-selection loop which + # will filter by supports_lora(). + if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP16") and not require_lora: if not envs.VLLM_USE_FLASHINFER_MOE_FP16: if UnquantizedMoeBackend.FLASHINFER_TRTLLM in AVAILABLE_BACKENDS: AVAILABLE_BACKENDS.remove(UnquantizedMoeBackend.FLASHINFER_TRTLLM) @@ -272,10 +284,27 @@ def _return_or_raise( AVAILABLE_BACKENDS.remove(UnquantizedMoeBackend.AITER) else: backend = UnquantizedMoeBackend.AITER - return _return_or_raise(backend, moe_config, activation_format) + if require_lora: + k_cls = backend_to_kernel_cls(backend) + if not k_cls.supports_lora(): + logger.warning_once( + "VLLM_ROCM_USE_AITER_MOE=1 but AiterExperts does not " + "support LoRA; falling back to generic backend selection.", + scope="local", + ) + else: + return _return_or_raise(backend, moe_config, activation_format) + else: + return _return_or_raise(backend, moe_config, activation_format) for backend in AVAILABLE_BACKENDS: k_cls = backend_to_kernel_cls(backend) + if require_lora and not k_cls.supports_lora(): + logger.debug_once( + f"Skipping MoE backend {backend.value}: does not support LoRA.", + scope="local", + ) + continue supported, reason = k_cls.is_supported_config( k_cls, moe_config, None, None, activation_format ) @@ -286,7 +315,8 @@ def _return_or_raise( logger.debug_once(_make_log_unsupported(backend, reason), scope="local") raise NotImplementedError( - "No Unquantized MoE backend supports the deployment configuration." + "No Unquantized MoE backend supports the deployment configuration" + + (" with LoRA enabled." 
if require_lora else ".") ) diff --git a/vllm/model_executor/layers/fused_moe/runner/moe_runner_base.py b/vllm/model_executor/layers/fused_moe/runner/moe_runner_base.py index a881c27d542c..c848e1aea4e5 100644 --- a/vllm/model_executor/layers/fused_moe/runner/moe_runner_base.py +++ b/vllm/model_executor/layers/fused_moe/runner/moe_runner_base.py @@ -412,6 +412,7 @@ def _apply_quant_method( topk_weights=topk_weights, topk_ids=topk_ids, shared_experts_input=shared_experts_input, + lora_context=getattr(layer, "_lora_context", None), ) self._maybe_apply_shared_experts( diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py index 190821562130..2c18a7527fe5 100644 --- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py +++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py @@ -2,8 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Callable +from typing import TYPE_CHECKING import torch + +if TYPE_CHECKING: + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext import torch.nn.functional as F from torch.nn import Module @@ -257,6 +261,7 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, + lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: return self.forward( layer=layer, @@ -264,7 +269,8 @@ def apply( topk_weights=topk_weights, topk_ids=topk_ids, shared_experts_input=shared_experts_input, - ) + lora_context=lora_context, + ) # CustomOp.forward uses *args/**kwargs, lora_context passes through def forward_native( self, @@ -273,6 +279,7 @@ def forward_native( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, + lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: assert self.moe_kernel is not None return self.moe_kernel.apply( @@ -286,6 +293,7 @@ def forward_native( global_num_experts=layer.global_num_experts, expert_map=layer.expert_map, shared_experts_input=shared_experts_input, + lora_context=lora_context, ) def forward_cuda( @@ -295,9 +303,15 @@ def forward_cuda( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, + lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: return self.forward_native( - layer, x, topk_weights, topk_ids, shared_experts_input + layer, + x, + topk_weights, + topk_ids, + shared_experts_input, + lora_context=lora_context, ) def apply_monolithic( diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index cfad1f86faa2..f533a031a18f 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -59,6 +59,7 @@ from vllm.transformers_utils.config import get_safetensors_params_metadata if TYPE_CHECKING: + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.models.utils import WeightsMapper @@ -817,6 +818,7 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, + lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: return fused_marlin_moe( x, diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py 
index 729924663646..5e86fd8a1c06 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -1,7 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Union +from typing import TYPE_CHECKING, Any, Union + +if TYPE_CHECKING: + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext import torch from packaging import version @@ -483,6 +486,7 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, + lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe import fused_experts diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index 301441ff019d..63a1d5a00cd1 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -1,7 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext import torch @@ -141,6 +144,7 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, + lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe import fused_experts diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index d7920462e613..cbcbbce0eb2b 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -86,6 +86,7 @@ ) if TYPE_CHECKING: + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext from vllm.model_executor.models.utils import WeightsMapper ACTIVATION_SCHEMES = ["static", "dynamic"] @@ -899,6 +900,7 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, + lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: assert not self.is_monolithic assert self.moe_kernel is not None diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index 61eb6c912a11..3dbbfabd8ad9 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext from vllm.model_executor.layers.quantization import QuantizationMethods import gguf @@ -650,6 +651,7 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, + lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: if layer.apply_router_weight_on_input: raise NotImplementedError( diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 1ca551d6351b..196437afcfd1 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -2,7 +2,10 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from copy import deepcopy -from typing import Any +from typing import TYPE_CHECKING, 
Any + +if TYPE_CHECKING: + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext import torch from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE @@ -907,6 +910,7 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, + lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: return fused_marlin_moe( x, diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 0b8ad0cbc1ed..2cbc180144fc 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -96,6 +96,7 @@ from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe if TYPE_CHECKING: + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext from vllm.model_executor.models.utils import WeightsMapper logger = init_logger(__name__) @@ -973,6 +974,7 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, + lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: assert not self.is_monolithic assert self.moe_kernel is not None @@ -1465,6 +1467,7 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, + lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: assert not self.is_monolithic assert self.moe_kernel is not None @@ -2003,6 +2006,7 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, + lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert not self.is_monolithic raise NotImplementedError( diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index e5ef3f4c3168..56e47a27f1ac 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -1,7 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext import torch @@ -369,6 +372,7 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, + lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe import fused_experts diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 019bb45d65dc..4aa780fb513b 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -1,8 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING + import torch +if TYPE_CHECKING: + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext + from vllm.config import get_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.attention import Attention @@ -418,6 +423,7 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, + lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: assert not self.is_monolithic assert self.moe_kernel is not None @@ -432,6 +438,7 
@@ def apply( apply_router_weight_on_input=layer.apply_router_weight_on_input, expert_map=layer.expert_map, shared_experts_input=shared_experts_input, + lora_context=lora_context, ) def apply_monolithic( diff --git a/vllm/model_executor/layers/quantization/online/fp8.py b/vllm/model_executor/layers/quantization/online/fp8.py index fa8cf240627b..de2a97b7745f 100644 --- a/vllm/model_executor/layers/quantization/online/fp8.py +++ b/vllm/model_executor/layers/quantization/online/fp8.py @@ -13,6 +13,7 @@ FusedMoEConfig, FusedMoEQuantConfig, ) + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext from vllm.model_executor.layers.fused_moe.oracle.fp8 import Fp8MoeBackend import vllm.envs as envs @@ -506,6 +507,7 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, + lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert not self.is_monolithic assert self.moe_kernel is not None diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index d4db929eaeb6..d3885316a4a6 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -1,7 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext import torch @@ -455,6 +458,7 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, + lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: if self.rocm_aiter_moe_enabled: from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( @@ -769,6 +773,7 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, + lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: from vllm.model_executor.layers.fused_moe import fused_experts @@ -921,6 +926,7 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, + lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( rocm_aiter_fused_experts, @@ -1413,6 +1419,7 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, + lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: # For w_mxfp4 with oracle kernel if self.moe_kernel is not None: From 9c64777220e9ed8eaae78e5f4f64aeb65cb97f35 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Sun, 19 Apr 2026 17:55:33 +0000 Subject: [PATCH 02/24] Move Signed-off-by: Jee Jee Li --- vllm/lora/layers/fused_moe.py | 2 +- .../layers/fused_moe => lora}/lora_context.py | 14 +- vllm/lora/punica_wrapper/punica_base.py | 66 +++++ vllm/lora/punica_wrapper/punica_gpu.py | 241 ++++++++++++++++ .../layers/fused_moe/fused_marlin_moe.py | 49 ++-- .../layers/fused_moe/fused_moe.py | 50 ++-- .../layers/fused_moe/fused_moe_method_base.py | 2 +- .../fused_moe/fused_moe_modular_method.py | 2 +- .../fused_moe/gpt_oss_triton_kernels_moe.py | 51 ++-- vllm/model_executor/layers/fused_moe/layer.py | 2 +- .../layers/fused_moe/modular_kernel.py | 272 ++++-------------- 
.../fused_moe/unquantized_fused_moe_method.py | 2 +- .../layers/quantization/awq_marlin.py | 2 +- .../layers/quantization/bitsandbytes.py | 2 +- .../layers/quantization/experts_int8.py | 2 +- .../model_executor/layers/quantization/fp8.py | 2 +- .../layers/quantization/gguf.py | 2 +- .../layers/quantization/gptq_marlin.py | 2 +- .../layers/quantization/modelopt.py | 2 +- .../layers/quantization/moe_wna16.py | 2 +- .../layers/quantization/mxfp4.py | 2 +- .../layers/quantization/online/fp8.py | 2 +- .../layers/quantization/quark/quark_moe.py | 2 +- 23 files changed, 452 insertions(+), 323 deletions(-) rename vllm/{model_executor/layers/fused_moe => lora}/lora_context.py (79%) diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index dc969c71d02a..c44430f3315b 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -13,11 +13,11 @@ ) from vllm.distributed.utils import divide from vllm.lora.layers.base import BaseLayerWithLoRA +from vllm.lora.lora_context import MoELoRAContext from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import ( FusedMoEModularMethod, ) -from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEExpertsModular, FusedMoEKernel, diff --git a/vllm/model_executor/layers/fused_moe/lora_context.py b/vllm/lora/lora_context.py similarity index 79% rename from vllm/model_executor/layers/fused_moe/lora_context.py rename to vllm/lora/lora_context.py index 18402123ade2..07d8ef3f4db4 100644 --- a/vllm/model_executor/layers/fused_moe/lora_context.py +++ b/vllm/lora/lora_context.py @@ -27,16 +27,12 @@ def _normalize_lora_config_keys( @dataclass class MoELoRAContext: """ - Carries all LoRA state for MoE forward passes. + Carries all LoRA state for one MoE forward pass. - Built once by FusedMoEWithLoRA.set_mapping() and stored permanently on the - base FusedMoE layer as _lora_context. Propagated explicitly through the + Built by FusedMoEWithLoRA.forward() and propagated explicitly through the modular kernel path (FusedMoEKernel -> FusedMoEExpertsModular.apply) so - that TritonExperts.apply() can compute the LoRA contribution inline. - - All tensor fields are stored by reference so in-place updates (set_lora, - reset_lora, adapter_enabled) and punica_wrapper state updates are visible - without rebuilding the context. + that TritonExperts.apply() can compute the LoRA contribution inline, + replacing the decorator-based monkey-patch approach. 
Typed as Any for punica_wrapper to avoid a circular import at module load time: vllm.lora imports vllm.model_executor.layers.fused_moe, so the @@ -56,7 +52,7 @@ class MoELoRAContext: # Metadata max_loras: int top_k: int - w13_num_slices: int # 2 = gated (gate + up), 1 = non-gated or 3D-fused + w13_num_slices: int # 2 = gated (gate + up), 1 = non-gated or 3D-fused fully_sharded: bool tp_rank: int tp_size: int diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py index facbd681a09a..546d82fa285b 100644 --- a/vllm/lora/punica_wrapper/punica_base.py +++ b/vllm/lora/punica_wrapper/punica_base.py @@ -493,3 +493,69 @@ def add_lora_fused_moe( """ # TODO: implement it based on torch ops raise NotImplementedError + + def add_lora_w13( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + expert_map: torch.Tensor | None, + w1: torch.Tensor, + w2: torch.Tensor, + num_tokens: int, + top_k_num: int, + max_loras: int, + adapter_enabled: torch.Tensor, + local_num_experts: int, + top_k: int, + num_slices: int, + fully_sharded: bool, + use_tuned_config: bool, + *, + block_shape: list[int] | None = None, + ) -> tuple[ + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + ]: + """Apply w13 LoRA to y (intermediate_cache1) in-place before activation. + + Returns (sorted_token_ids_lora, expert_ids_lora, + num_tokens_post_padded_lora, token_lora_mapping) + for reuse by add_lora_w2. + """ + raise NotImplementedError + + def add_lora_w2( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + topk_weights: torch.Tensor, + sorted_token_ids_lora: torch.Tensor | None, + expert_ids_lora: torch.Tensor | None, + num_tokens_post_padded_lora: torch.Tensor | None, + token_lora_mapping: torch.Tensor | None, + num_tokens: int, + w1: torch.Tensor, + w2: torch.Tensor, + top_k_num: int, + max_loras: int, + adapter_enabled: torch.Tensor, + top_k: int, + fully_sharded: bool, + tp_rank: int, + use_tuned_config: bool, + *, + block_shape: list[int] | None = None, + ) -> None: + """Apply w2 LoRA to y (intermediate_cache3) in-place before moe_sum. + + Reuses routing tensors returned by add_lora_w13. 
+ """ + raise NotImplementedError diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index 321cbfcab7cd..f521e1013e2d 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -459,3 +459,244 @@ def add_lora_fused_moe( fully_sharded, offset, ) + + def add_lora_w13( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + expert_map: torch.Tensor | None, + w1: torch.Tensor, + w2: torch.Tensor, + num_tokens: int, + top_k_num: int, + max_loras: int, + adapter_enabled: torch.Tensor, + local_num_experts: int, + top_k: int, + num_slices: int, + fully_sharded: bool, + use_tuned_config: bool, + *, + block_shape: list[int] | None = None, + ) -> tuple[ + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + ]: + import functools + + from vllm.lora.layers.utils import try_get_optimal_moe_lora_config + from vllm.lora.lora_context import ( + _normalize_lora_config_keys, + ) + from vllm.lora.ops.triton_ops.utils import get_lora_op_configs + from vllm.model_executor.layers.fused_moe.config import _get_config_dtype_str + + config_dtype = _get_config_dtype_str( + dtype=x.dtype, + use_fp8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + ) + max_lora_rank = lora_a_stacked[0].shape[-2] + + if use_tuned_config: + shrink_config = get_lora_op_configs( + op_type="fused_moe_lora_w13_shrink", + max_loras=max_loras, + batch=num_tokens, + hidden_size=x.shape[-1], + rank=max_lora_rank, + num_slices=num_slices, + moe_intermediate_size=lora_b_stacked[0].shape[-2], + ) + expand_config = get_lora_op_configs( + op_type="fused_moe_lora_w13_expand", + max_loras=max_loras, + batch=num_tokens, + hidden_size=x.shape[-1], + rank=max_lora_rank, + num_slices=num_slices, + moe_intermediate_size=lora_b_stacked[0].shape[-2], + ) + else: + get_config = functools.partial( + try_get_optimal_moe_lora_config, + w1_shape=w1.shape, + w2_shape=w2.shape, + rank=max_lora_rank, + top_k=top_k, + dtype=config_dtype, + M=num_tokens, + block_shape=block_shape, + ) + shrink_config = get_config(op_type="fused_moe_lora_w13_shrink") + expand_config = get_config(op_type="fused_moe_lora_w13_expand") + + shrink_config = _normalize_lora_config_keys(shrink_config) + expand_config = _normalize_lora_config_keys(expand_config) + + SPARSITY_FACTOR = 8 + naive_block_assignment = ( + expert_map is None + and num_tokens * top_k * SPARSITY_FACTOR <= local_num_experts * max_loras + ) + + ( + token_lora_mapping, + sorted_token_ids_lora, + expert_ids_lora, + num_tokens_post_padded_lora, + ) = self.moe_lora_align_block_size( + topk_ids, + num_tokens, + shrink_config["BLOCK_SIZE_M"], + local_num_experts, + max_loras, + adapter_enabled, + expert_map, + naive_block_assignment=naive_block_assignment, + ) + + _sorted = sorted_token_ids_lora + _eids = expert_ids_lora + if _sorted is not None: + _eids = _eids.view(max_loras, -1) + _sorted = _sorted.view(max_loras, -1) + + self.add_lora_fused_moe( + y.view(-1, top_k_num, y.shape[-1]), + x, + lora_a_stacked, + lora_b_stacked, + topk_weights, + _sorted, + _eids, + num_tokens_post_padded_lora, + max_lora_rank, + top_k, + shrink_config, + expand_config, + adapter_enabled, + fully_sharded=fully_sharded, + token_lora_mapping=token_lora_mapping, + ) + + return ( + sorted_token_ids_lora, + expert_ids_lora, + num_tokens_post_padded_lora, + token_lora_mapping, + ) + + def 
add_lora_w2( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + topk_weights: torch.Tensor, + sorted_token_ids_lora: torch.Tensor | None, + expert_ids_lora: torch.Tensor | None, + num_tokens_post_padded_lora: torch.Tensor | None, + token_lora_mapping: torch.Tensor | None, + num_tokens: int, + w1: torch.Tensor, + w2: torch.Tensor, + top_k_num: int, + max_loras: int, + adapter_enabled: torch.Tensor, + top_k: int, + fully_sharded: bool, + tp_rank: int, + use_tuned_config: bool, + *, + block_shape: list[int] | None = None, + ) -> None: + import functools + + from vllm.lora.layers.utils import try_get_optimal_moe_lora_config + from vllm.lora.lora_context import ( + _normalize_lora_config_keys, + ) + from vllm.lora.ops.triton_ops.utils import get_lora_op_configs + from vllm.model_executor.layers.fused_moe.config import _get_config_dtype_str + + config_dtype = _get_config_dtype_str( + dtype=x.dtype, + use_fp8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + ) + max_lora_rank = lora_a_stacked[0].shape[-2] + + if use_tuned_config: + shrink_config = get_lora_op_configs( + op_type="fused_moe_lora_w2_shrink", + max_loras=max_loras, + batch=num_tokens, + hidden_size=y.shape[-1], + rank=max_lora_rank, + num_slices=1, + moe_intermediate_size=lora_a_stacked[0].shape[-1], + ) + expand_config = get_lora_op_configs( + op_type="fused_moe_lora_w2_expand", + max_loras=max_loras, + batch=num_tokens, + hidden_size=y.shape[-1], + rank=max_lora_rank, + num_slices=1, + moe_intermediate_size=lora_a_stacked[0].shape[-1], + ) + else: + get_config = functools.partial( + try_get_optimal_moe_lora_config, + w1_shape=w1.shape, + w2_shape=w2.shape, + rank=max_lora_rank, + top_k=top_k, + dtype=config_dtype, + M=num_tokens, + block_shape=block_shape, + ) + shrink_config = get_config(op_type="fused_moe_lora_w2_shrink") + expand_config = get_config(op_type="fused_moe_lora_w2_expand") + + shrink_config = _normalize_lora_config_keys(shrink_config) + expand_config = _normalize_lora_config_keys(expand_config) + + _sorted = sorted_token_ids_lora + _eids = expert_ids_lora + if _sorted is not None: + _eids = _eids.view(max_loras, -1) + _sorted = _sorted.view(max_loras, -1) + + # w2_lora_b shape[-2] is hidden_size // tp_size when fully_sharded + shard_size = lora_b_stacked[0].shape[-2] + offset = shard_size * tp_rank if fully_sharded else 0 + + self.add_lora_fused_moe( + y, + x, + lora_a_stacked, + lora_b_stacked, + topk_weights, + _sorted, + _eids, + num_tokens_post_padded_lora, + max_lora_rank, + top_k, + shrink_config, + expand_config, + adapter_enabled, + True, # mul_routed_weight + fully_sharded=fully_sharded, + offset=offset, + token_lora_mapping=token_lora_mapping, + ) diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 1ab906d3ca39..4964f598a767 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -8,7 +8,7 @@ import torch if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext + from vllm.lora.lora_context import MoELoRAContext import vllm._custom_ops as ops import vllm.model_executor.layers.fused_moe.modular_kernel as mk @@ -782,18 +782,17 @@ def activation_with_lora( expert_ids_lora, num_tokens_post_padded_lora, token_lora_mapping, - *_, - ) = self._apply_w13_lora( + ) = self.apply_w13_lora( ctx, - hidden_states, - act_input, 
- topk_ids, - topk_weights, - expert_map, - w1, - w2, - M, - top_k_num, + y=act_input, + x=hidden_states, + topk_ids=topk_ids, + topk_weights=topk_weights, + expert_map=expert_map, + w1=w1, + w2=w2, + num_tokens=M, + top_k_num=top_k_num, ) lora_state.update( { @@ -807,20 +806,20 @@ def activation_with_lora( lora_state["cache2"] = act_output def moe_sum_with_lora(moe_out: torch.Tensor, out: torch.Tensor) -> None: - # moe_out shape: (M, topk, K) — _apply_w2_lora expects (M, topk, K) - self._apply_w2_lora( + # moe_out shape: (M, topk, K) + self.apply_w2_lora( ctx, - lora_state["cache2"], - moe_out, - topk_weights, - lora_state["sorted"], - lora_state["eids"], - lora_state["npad"], - lora_state["tlm"], - M, - w1, - w2, - top_k_num, + y=moe_out, + x=lora_state["cache2"], + topk_weights=topk_weights, + sorted_token_ids_lora=lora_state["sorted"], + expert_ids_lora=lora_state["eids"], + num_tokens_post_padded_lora=lora_state["npad"], + token_lora_mapping=lora_state["tlm"], + num_tokens=M, + w1=w1, + w2=w2, + top_k_num=top_k_num, ) self.moe_sum(moe_out, out) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index f7303f3dc45e..b39279dc2cd1 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext + from vllm.lora.lora_context import MoELoRAContext import torch @@ -2121,20 +2121,17 @@ def apply( expert_ids_lora, num_tokens_post_padded_lora, token_lora_mapping, - shrink_config_w13, - expand_config_w13, - ) = self._apply_w13_lora( + ) = self.apply_w13_lora( lora_context, - hidden_states, - intermediate_cache1, - topk_ids, - topk_weights, - expert_map, - w1, - w2, - num_tokens, - top_k_num, - block_shape=self.block_shape, + y=intermediate_cache1, + x=hidden_states, + topk_ids=topk_ids, + topk_weights=topk_weights, + expert_map=expert_map, + w1=w1, + w2=w2, + num_tokens=num_tokens, + top_k_num=top_k_num, ) self.activation( @@ -2178,20 +2175,19 @@ def apply( # unquantized intermediate_cache2 as the lora_a input. Reuses the # sorted_token_ids_lora computed above. 
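Taken together, the two hooks bracket the routed-expert GEMMs. A minimal sketch of the sequence a backend's apply() follows (schematic, not runnable on its own; the intermediate-cache names vary per backend, while the keyword names are the ones introduced in this series):

    if lora_context is not None:
        # w13 LoRA: adjust the first GEMM's output in place and keep the
        # routing tensors so the w2 hook can reuse them.
        (
            sorted_token_ids_lora,
            expert_ids_lora,
            num_tokens_post_padded_lora,
            token_lora_mapping,
        ) = self.apply_w13_lora(
            lora_context,
            y=intermediate_cache1,  # w13 GEMM output, updated in place
            x=hidden_states,        # unquantized input feeds lora_a
            topk_ids=topk_ids,
            topk_weights=topk_weights,
            expert_map=expert_map,
            w1=w1,
            w2=w2,
            num_tokens=num_tokens,
            top_k_num=top_k_num,
        )
    # ... activation and the w2 GEMM run unchanged ...
    if lora_context is not None:
        # w2 LoRA: adjust the second GEMM's output in place before moe_sum,
        # reusing the routing tensors returned above.
        self.apply_w2_lora(
            lora_context,
            y=intermediate_cache3,
            x=intermediate_cache2,
            topk_weights=topk_weights,
            sorted_token_ids_lora=sorted_token_ids_lora,
            expert_ids_lora=expert_ids_lora,
            num_tokens_post_padded_lora=num_tokens_post_padded_lora,
            token_lora_mapping=token_lora_mapping,
            num_tokens=num_tokens,
            w1=w1,
            w2=w2,
            top_k_num=top_k_num,
        )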
if lora_context is not None: - self._apply_w2_lora( + self.apply_w2_lora( lora_context, - intermediate_cache2, - intermediate_cache3, - topk_weights, - sorted_token_ids_lora, - expert_ids_lora, - num_tokens_post_padded_lora, - token_lora_mapping, - num_tokens, - w1, - w2, - top_k_num, - block_shape=self.block_shape, + y=intermediate_cache3, + x=intermediate_cache2, + topk_weights=topk_weights, + sorted_token_ids_lora=sorted_token_ids_lora, + expert_ids_lora=expert_ids_lora, + num_tokens_post_padded_lora=num_tokens_post_padded_lora, + token_lora_mapping=token_lora_mapping, + num_tokens=num_tokens, + w1=w1, + w2=w2, + top_k_num=top_k_num, ) # separate function is required for MoE + LoRA diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py index af71ca6bf254..642e7b497f19 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py @@ -9,7 +9,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext + from vllm.lora.lora_context import MoELoRAContext from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py index 0f4f7bdab316..9b281a23dad4 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py @@ -9,7 +9,7 @@ from vllm.logger import init_logger if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext + from vllm.lora.lora_context import MoELoRAContext from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py index 058b47401439..5b31b4216fa9 100644 --- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -6,7 +6,7 @@ import torch if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext + from vllm.lora.lora_context import MoELoRAContext import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops @@ -800,19 +800,17 @@ def apply( expert_ids_lora, num_tokens_post_padded_lora, token_lora_mapping, - _, - _, - ) = self._apply_w13_lora( + ) = self.apply_w13_lora( lora_context, - hidden_states, - act_input, - topk_ids, - topk_weights, - None, # expert_map already applied above - w1, - w2, - M, - topk, + y=act_input, + x=hidden_states, + topk_ids=topk_ids, + topk_weights=topk_weights, + expert_map=None, # already applied above + w1=w1, + w2=w2, + num_tokens=M, + top_k_num=topk, ) self.activation( @@ -838,22 +836,21 @@ def apply( ) # w2 LoRA: after matmul_ogs with scatter_indx, intermediate_cache3 is - # in token-topk order, matching the standard (M, topk, K) layout that - # _apply_w2_lora expects. + # in token-topk order, matching the (M, topk, K) layout add_lora_w2 expects. 
if lora_context is not None: - self._apply_w2_lora( + self.apply_w2_lora( lora_context, - intermediate_cache2, - intermediate_cache3.view(-1, topk, K), - topk_weights, - sorted_token_ids_lora, - expert_ids_lora, - num_tokens_post_padded_lora, - token_lora_mapping, - M, - w1, - w2, - topk, + y=intermediate_cache3.view(-1, topk, K), + x=intermediate_cache2, + topk_weights=topk_weights, + sorted_token_ids_lora=sorted_token_ids_lora, + expert_ids_lora=expert_ids_lora, + num_tokens_post_padded_lora=num_tokens_post_padded_lora, + token_lora_mapping=token_lora_mapping, + num_tokens=M, + w1=w1, + w2=w2, + top_k_num=topk, ) self.moe_sum(intermediate_cache3.view(-1, topk, K), output) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index bbda944ad1bb..0eace3327092 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Literal, cast, get_args, overload if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext + from vllm.lora.lora_context import MoELoRAContext import torch from torch.nn.parameter import UninitializedParameter diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 0738d4861323..0a70391ff478 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -10,7 +10,7 @@ import torch if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext + from vllm.lora.lora_context import MoELoRAContext import vllm.envs as envs from vllm.logger import init_logger @@ -959,26 +959,12 @@ def apply( """ raise NotImplementedError - # ------------------------------------------------------------------------- - # Shared LoRA helpers - # - # The LoRA operators (add_lora_fused_moe / moe_lora_align_block_size) are - # Triton kernels that operate on fp16/bf16 activation tensors and are - # independent of the main GEMM quantization backend. Any modular expert - # backend can call these helpers from its apply() at the two natural - # injection points: - # • after the w13 GEMM, before activation (_apply_w13_lora) - # • after the w2 GEMM, before moe_sum (_apply_w2_lora) - # - # block_shape is the only backend-specific knob: pass self.block_shape for - # quantized backends; omit (None) for unquantized ones. 
- # ------------------------------------------------------------------------- - - def _apply_w13_lora( + def apply_w13_lora( self, lora_context: "MoELoRAContext", - hidden_states: torch.Tensor, - intermediate_cache1: torch.Tensor, + *, + y: torch.Tensor, + x: torch.Tensor, topk_ids: torch.Tensor, topk_weights: torch.Tensor, expert_map: torch.Tensor | None, @@ -986,126 +972,40 @@ def _apply_w13_lora( w2: torch.Tensor, num_tokens: int, top_k_num: int, - *, - block_shape: list[int] | None = None, - ): - import functools - - from vllm.lora.layers.utils import try_get_optimal_moe_lora_config - from vllm.lora.ops.triton_ops.utils import get_lora_op_configs - from vllm.model_executor.layers.fused_moe.config import _get_config_dtype_str - from vllm.model_executor.layers.fused_moe.lora_context import ( - _normalize_lora_config_keys, - ) - - ctx = lora_context - config_dtype = _get_config_dtype_str( - dtype=hidden_states.dtype, - use_fp8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - ) - max_lora_rank = ctx.w13_lora_a_stacked[0].shape[-2] - - if ctx.use_tuned_config: - hidden_size = hidden_states.shape[-1] - intermediate_size = ctx.w13_lora_b_stacked[0].shape[-2] - shrink_config = get_lora_op_configs( - op_type="fused_moe_lora_w13_shrink", - max_loras=ctx.max_loras, - batch=num_tokens, - hidden_size=hidden_size, - rank=max_lora_rank, - num_slices=ctx.w13_num_slices, - moe_intermediate_size=intermediate_size, - ) - expand_config = get_lora_op_configs( - op_type="fused_moe_lora_w13_expand", - max_loras=ctx.max_loras, - batch=num_tokens, - hidden_size=hidden_size, - rank=max_lora_rank, - num_slices=ctx.w13_num_slices, - moe_intermediate_size=intermediate_size, - ) - else: - get_config_func = functools.partial( - try_get_optimal_moe_lora_config, - w1_shape=w1.shape, - w2_shape=w2.shape, - rank=max_lora_rank, - top_k=ctx.top_k, - dtype=config_dtype, - M=num_tokens, - block_shape=block_shape, - ) - shrink_config = get_config_func(op_type="fused_moe_lora_w13_shrink") - expand_config = get_config_func(op_type="fused_moe_lora_w13_expand") - - shrink_config = _normalize_lora_config_keys(shrink_config) - expand_config = _normalize_lora_config_keys(expand_config) - - SPARSITY_FACTOR = 8 - naive_block_assignment = ( - expert_map is None - and num_tokens * ctx.top_k * SPARSITY_FACTOR - <= ctx.local_num_experts * ctx.max_loras - ) - - ( - token_lora_mapping, - sorted_token_ids_lora, - expert_ids_lora, - num_tokens_post_padded_lora, - ) = ctx.punica_wrapper.moe_lora_align_block_size( + ) -> tuple[ + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + ]: + return lora_context.punica_wrapper.add_lora_w13( + y, + x, + lora_context.w13_lora_a_stacked, + lora_context.w13_lora_b_stacked, topk_ids, - num_tokens, - shrink_config["BLOCK_SIZE_M"], - ctx.local_num_experts, - ctx.max_loras, - ctx.adapter_enabled, - expert_map, - naive_block_assignment=naive_block_assignment, - ) - - _sorted_token_ids_lora = sorted_token_ids_lora - _expert_ids_lora = expert_ids_lora - if _sorted_token_ids_lora is not None: - _expert_ids_lora = _expert_ids_lora.view(ctx.max_loras, -1) - _sorted_token_ids_lora = _sorted_token_ids_lora.view(ctx.max_loras, -1) - - ctx.punica_wrapper.add_lora_fused_moe( - intermediate_cache1.view(-1, top_k_num, intermediate_cache1.shape[-1]), - hidden_states, - ctx.w13_lora_a_stacked, - ctx.w13_lora_b_stacked, topk_weights, - _sorted_token_ids_lora, - _expert_ids_lora, - num_tokens_post_padded_lora, - max_lora_rank, - ctx.top_k, - shrink_config, - expand_config, 
- ctx.adapter_enabled, - fully_sharded=ctx.fully_sharded, - token_lora_mapping=token_lora_mapping, - ) - - return ( - sorted_token_ids_lora, - expert_ids_lora, - num_tokens_post_padded_lora, - token_lora_mapping, - shrink_config, - expand_config, + expert_map, + w1, + w2, + num_tokens, + top_k_num, + lora_context.max_loras, + lora_context.adapter_enabled, + lora_context.local_num_experts, + lora_context.top_k, + lora_context.w13_num_slices, + lora_context.fully_sharded, + lora_context.use_tuned_config, + block_shape=self.block_shape, ) - def _apply_w2_lora( + def apply_w2_lora( self, lora_context: "MoELoRAContext", - intermediate_cache2: torch.Tensor, - intermediate_cache3: torch.Tensor, + *, + y: torch.Tensor, + x: torch.Tensor, topk_weights: torch.Tensor, sorted_token_ids_lora: torch.Tensor | None, expert_ids_lora: torch.Tensor | None, @@ -1115,94 +1015,28 @@ def _apply_w2_lora( w1: torch.Tensor, w2: torch.Tensor, top_k_num: int, - *, - block_shape: list[int] | None = None, - ): - import functools - - from vllm.lora.layers.utils import try_get_optimal_moe_lora_config - from vllm.lora.ops.triton_ops.utils import get_lora_op_configs - from vllm.model_executor.layers.fused_moe.config import _get_config_dtype_str - from vllm.model_executor.layers.fused_moe.lora_context import ( - _normalize_lora_config_keys, - ) - - ctx = lora_context - config_dtype = _get_config_dtype_str( - dtype=intermediate_cache2.dtype, - use_fp8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - ) - max_lora_rank = ctx.w2_lora_a_stacked[0].shape[-2] - - if ctx.use_tuned_config: - hidden_size = intermediate_cache3.shape[-1] - intermediate_size = ctx.w2_lora_a_stacked[0].shape[-1] - shrink_config = get_lora_op_configs( - op_type="fused_moe_lora_w2_shrink", - max_loras=ctx.max_loras, - batch=num_tokens, - hidden_size=hidden_size, - rank=max_lora_rank, - num_slices=1, - moe_intermediate_size=intermediate_size, - ) - expand_config = get_lora_op_configs( - op_type="fused_moe_lora_w2_expand", - max_loras=ctx.max_loras, - batch=num_tokens, - hidden_size=hidden_size, - rank=max_lora_rank, - num_slices=1, - moe_intermediate_size=intermediate_size, - ) - else: - get_config_func = functools.partial( - try_get_optimal_moe_lora_config, - w1_shape=w1.shape, - w2_shape=w2.shape, - rank=max_lora_rank, - top_k=ctx.top_k, - dtype=config_dtype, - M=num_tokens, - block_shape=block_shape, - ) - shrink_config = get_config_func(op_type="fused_moe_lora_w2_shrink") - expand_config = get_config_func(op_type="fused_moe_lora_w2_expand") - - shrink_config = _normalize_lora_config_keys(shrink_config) - expand_config = _normalize_lora_config_keys(expand_config) - - _sorted_token_ids_lora = sorted_token_ids_lora - _expert_ids_lora = expert_ids_lora - if _sorted_token_ids_lora is not None: - _expert_ids_lora = _expert_ids_lora.view(ctx.max_loras, -1) - _sorted_token_ids_lora = _sorted_token_ids_lora.view(ctx.max_loras, -1) - - # w2_lora_b shape[-2] is hidden_size // tp_size when fully_sharded, - # matching divide(base_layer.hidden_size, tp_size). 
- shard_size_w2 = ctx.w2_lora_b_stacked[0].shape[-2] - offset = shard_size_w2 * ctx.tp_rank if ctx.fully_sharded else 0 - - ctx.punica_wrapper.add_lora_fused_moe( - intermediate_cache3, - intermediate_cache2, - ctx.w2_lora_a_stacked, - ctx.w2_lora_b_stacked, + ) -> None: + lora_context.punica_wrapper.add_lora_w2( + y, + x, + lora_context.w2_lora_a_stacked, + lora_context.w2_lora_b_stacked, topk_weights, - _sorted_token_ids_lora, - _expert_ids_lora, + sorted_token_ids_lora, + expert_ids_lora, num_tokens_post_padded_lora, - max_lora_rank, - ctx.top_k, - shrink_config, - expand_config, - ctx.adapter_enabled, - True, # is_w2 - fully_sharded=ctx.fully_sharded, - offset=offset, - token_lora_mapping=token_lora_mapping, + token_lora_mapping, + num_tokens, + w1, + w2, + top_k_num, + lora_context.max_loras, + lora_context.adapter_enabled, + lora_context.top_k, + lora_context.fully_sharded, + lora_context.tp_rank, + lora_context.use_tuned_config, + block_shape=self.block_shape, ) diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py index 2c18a7527fe5..ce744c613f15 100644 --- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py +++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py @@ -7,7 +7,7 @@ import torch if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext + from vllm.lora.lora_context import MoELoRAContext import torch.nn.functional as F from torch.nn import Module diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index f533a031a18f..7d40825f1bbf 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -59,7 +59,7 @@ from vllm.transformers_utils.config import get_safetensors_params_metadata if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext + from vllm.lora.lora_context import MoELoRAContext from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.models.utils import WeightsMapper diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index 5e86fd8a1c06..a0c61b1bc187 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any, Union if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext + from vllm.lora.lora_context import MoELoRAContext import torch from packaging import version diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index 63a1d5a00cd1..401cf79ef368 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext + from vllm.lora.lora_context import MoELoRAContext import torch diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index cbcbbce0eb2b..b39e7613b739 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -86,7 +86,7 @@ ) if 
TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext + from vllm.lora.lora_context import MoELoRAContext from vllm.model_executor.models.utils import WeightsMapper ACTIVATION_SCHEMES = ["static", "dynamic"] diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index 3dbbfabd8ad9..1e4fc22160f7 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext + from vllm.lora.lora_context import MoELoRAContext from vllm.model_executor.layers.quantization import QuantizationMethods import gguf diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 196437afcfd1..75abfed1fc44 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext + from vllm.lora.lora_context import MoELoRAContext import torch from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 2cbc180144fc..9906ecdbd54a 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -96,7 +96,7 @@ from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext + from vllm.lora.lora_context import MoELoRAContext from vllm.model_executor.models.utils import WeightsMapper logger = init_logger(__name__) diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index 56e47a27f1ac..f4b00c8074a6 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext + from vllm.lora.lora_context import MoELoRAContext import torch diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 4aa780fb513b..1f9b6c8a12ee 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -6,7 +6,7 @@ import torch if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext + from vllm.lora.lora_context import MoELoRAContext from vllm.config import get_current_vllm_config from vllm.logger import init_logger diff --git a/vllm/model_executor/layers/quantization/online/fp8.py b/vllm/model_executor/layers/quantization/online/fp8.py index de2a97b7745f..8b4f6ecde9e0 100644 --- a/vllm/model_executor/layers/quantization/online/fp8.py +++ b/vllm/model_executor/layers/quantization/online/fp8.py @@ -8,12 +8,12 @@ if TYPE_CHECKING: import vllm.model_executor.layers.fused_moe.modular_kernel as mk + from vllm.lora.lora_context import MoELoRAContext from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe.config import ( 
FusedMoEConfig, FusedMoEQuantConfig, ) - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext from vllm.model_executor.layers.fused_moe.oracle.fp8 import Fp8MoeBackend import vllm.envs as envs diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index d3885316a4a6..23fa1a6be4cb 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext + from vllm.lora.lora_context import MoELoRAContext import torch From 9350a67324e01299f289531d3485fb5e8dacc01d Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Mon, 20 Apr 2026 10:42:29 +0000 Subject: [PATCH 03/24] Fix Signed-off-by: Jee Jee Li --- vllm/model_executor/layers/quantization/online/fp8.py | 1 - vllm/model_executor/layers/quantization/online/moe_base.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/online/fp8.py b/vllm/model_executor/layers/quantization/online/fp8.py index 63c3f4cf7420..9cb697289d7e 100644 --- a/vllm/model_executor/layers/quantization/online/fp8.py +++ b/vllm/model_executor/layers/quantization/online/fp8.py @@ -8,7 +8,6 @@ if TYPE_CHECKING: import vllm.model_executor.layers.fused_moe.modular_kernel as mk - from vllm.lora.lora_context import MoELoRAContext from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, diff --git a/vllm/model_executor/layers/quantization/online/moe_base.py b/vllm/model_executor/layers/quantization/online/moe_base.py index 7340544edf02..add721f8ae69 100644 --- a/vllm/model_executor/layers/quantization/online/moe_base.py +++ b/vllm/model_executor/layers/quantization/online/moe_base.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import abstractmethod +from typing import TYPE_CHECKING import torch From b3d1ea62a9172c088d52b619217047928078417e Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Mon, 20 Apr 2026 11:49:56 +0000 Subject: [PATCH 04/24] Fix Signed-off-by: Jee Jee Li --- vllm/lora/punica_wrapper/punica_gpu.py | 3 ++- vllm/model_executor/layers/fused_moe/cutlass_moe.py | 9 +++++++++ .../layers/fused_moe/experts/batched_deep_gemm_moe.py | 6 ++++++ .../layers/fused_moe/experts/deep_gemm_moe.py | 6 ++++++ .../fused_moe/experts/flashinfer_cutedsl_batched_moe.py | 6 ++++++ .../layers/fused_moe/experts/flashinfer_cutedsl_moe.py | 6 ++++++ .../fused_moe/experts/gpt_oss_triton_kernels_moe.py | 6 ++++-- .../layers/fused_moe/experts/trtllm_fp8_moe.py | 6 ++++++ .../layers/fused_moe/experts/trtllm_mxfp4_moe.py | 6 ++++++ .../layers/fused_moe/experts/trtllm_nvfp4_moe.py | 6 ++++++ vllm/model_executor/layers/fused_moe/fallback.py | 6 ++++++ .../layers/fused_moe/flashinfer_cutlass_moe.py | 6 ++++++ .../model_executor/layers/fused_moe/fused_batched_moe.py | 7 +++++++ vllm/model_executor/layers/fused_moe/fused_marlin_moe.py | 1 + vllm/model_executor/layers/fused_moe/fused_moe.py | 1 + .../layers/fused_moe/rocm_aiter_fused_moe.py | 5 +++++ vllm/model_executor/layers/fused_moe/xpu_fused_moe.py | 6 ++++++ .../compressed_tensors_moe_w4a4_mxfp4.py | 6 ++++++ .../compressed_tensors_moe_w4a4_nvfp4.py | 6 ++++++ .../compressed_tensors_moe_w4a8_fp8.py | 6 ++++++ .../compressed_tensors_moe_w8a8_fp8.py | 6 ++++++ 
.../compressed_tensors_moe_w8a8_int8.py | 6 ++++++ .../compressed_tensors_moe_w8a8_mxfp8.py | 6 ++++++ .../compressed_tensors_moe_wna16.py | 6 ++++++ .../compressed_tensors_moe_wna16_marlin.py | 5 +++++ 25 files changed, 136 insertions(+), 3 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index f521e1013e2d..d018b64c29a1 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -555,7 +555,7 @@ def add_lora_w13( ) = self.moe_lora_align_block_size( topk_ids, num_tokens, - shrink_config["BLOCK_SIZE_M"], + int(shrink_config.get("BLOCK_SIZE_M") or 64), local_num_experts, max_loras, adapter_enabled, @@ -674,6 +674,7 @@ def add_lora_w2( _sorted = sorted_token_ids_lora _eids = expert_ids_lora if _sorted is not None: + assert _eids is not None _eids = _eids.view(max_loras, -1) _sorted = _sorted.view(max_loras, -1) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index fdd802e7da3a..d99feb6763ec 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -2,8 +2,13 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """CUTLASS based Fused MoE kernels.""" +from typing import TYPE_CHECKING + import torch +if TYPE_CHECKING: + from vllm.lora.lora_context import MoELoRAContext + import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.logger import init_logger @@ -343,6 +348,7 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, + lora_context: "MoELoRAContext | None" = None, ): assert self.w1_zp is None, "w1_zp is not supported in CUTLASS MoE" assert self.w2_zp is None, "w2_zp is not supported in CUTLASS MoE" @@ -768,6 +774,7 @@ def apply( workspace2: torch.Tensor | None, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, + lora_context: "MoELoRAContext | None" = None, ): e, m, n, k, _ = self.moe_problem_size(hidden_states, w1, w2, topk_ids) n = w2.shape[2] * 2 @@ -1065,6 +1072,7 @@ def apply( workspace2: torch.Tensor | None, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, + lora_context: "MoELoRAContext | None" = None, ): e, m, n, k, _ = self.moe_problem_size(hidden_states, w1, w2, topk_ids) n = w2.shape[2] * 2 @@ -1348,6 +1356,7 @@ def apply( workspace2: torch.Tensor | None, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, + lora_context: "MoELoRAContext | None" = None, ): assert self.w1_zp is None, "w1_zp is not supported in CUTLASS MoE" assert self.w2_zp is None, "w2_zp is not supported in CUTLASS MoE" diff --git a/vllm/model_executor/layers/fused_moe/experts/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/experts/batched_deep_gemm_moe.py index 2cb0bd7649f5..70053f043fd2 100644 --- a/vllm/model_executor/layers/fused_moe/experts/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/batched_deep_gemm_moe.py @@ -2,8 +2,13 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING + import torch +if TYPE_CHECKING: + from vllm.lora.lora_context import MoELoRAContext + import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.forward_context import get_forward_context, is_forward_context_available from 
vllm.logger import init_logger @@ -401,6 +406,7 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, + lora_context: "MoELoRAContext | None" = None, ): assert expert_tokens_meta is not None expert_num_tokens = expert_tokens_meta.expert_num_tokens diff --git a/vllm/model_executor/layers/fused_moe/experts/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/experts/deep_gemm_moe.py index 03341378a13c..1f06f43127f6 100644 --- a/vllm/model_executor/layers/fused_moe/experts/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/deep_gemm_moe.py @@ -1,8 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING + import torch +if TYPE_CHECKING: + from vllm.lora.lora_context import MoELoRAContext + import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.activation import MoEActivation @@ -245,6 +250,7 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, + lora_context: "MoELoRAContext | None" = None, ): assert a1q_scale is not None assert a2_scale is None diff --git a/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_batched_moe.py b/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_batched_moe.py index 5eaaf46739fc..65eb6627b457 100644 --- a/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_batched_moe.py @@ -1,8 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING + import torch +if TYPE_CHECKING: + from vllm.lora.lora_context import MoELoRAContext + import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import envs from vllm.logger import init_logger @@ -149,6 +154,7 @@ def apply( workspace2: torch.Tensor | None, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool | None, + lora_context: "MoELoRAContext | None" = None, ): assert self.quant_dtype == "nvfp4", ( "Only nvfp4 quantization are currently supported." 
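The hunks that follow repeat one mechanical pattern across every modular experts backend. A hedged sketch of the minimal per-backend change (the class name and the trimmed parameter list below are placeholders, not the real signature):

    from typing import TYPE_CHECKING

    import torch

    if TYPE_CHECKING:
        from vllm.lora.lora_context import MoELoRAContext


    class SomeExperts:  # placeholder for any modular experts backend
        def apply(
            self,
            hidden_states: torch.Tensor,
            w1: torch.Tensor,
            w2: torch.Tensor,
            # ... the existing parameters stay exactly as they are ...
            apply_router_weight_on_input: bool,
            lora_context: "MoELoRAContext | None" = None,  # new trailing kwarg
        ):
            # Backends that cannot apply LoRA inline simply accept the argument
            # and ignore it; delegating wrappers (e.g. the fallback experts)
            # forward it on with lora_context=lora_context.
            ...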
diff --git a/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_moe.py b/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_moe.py index 5ce58220b073..b1928e9cd161 100644 --- a/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_moe.py @@ -1,8 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING + import torch +if TYPE_CHECKING: + from vllm.lora.lora_context import MoELoRAContext + import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( @@ -139,6 +144,7 @@ def apply( workspace2: torch.Tensor | None, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool | None, + lora_context: "MoELoRAContext | None" = None, ): assert self.quant_dtype == "nvfp4" assert a1q_scale is not None diff --git a/vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py index 5b31b4216fa9..e6b5abe39e34 100644 --- a/vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py @@ -623,6 +623,7 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, + lora_context: "MoELoRAContext | None" = None, ): if self.quant_config is None: self.quant_config: FusedMoEQuantConfig = FUSED_MOE_UNQUANTIZED_CONFIG @@ -730,6 +731,7 @@ def apply( if quant_config is None: quant_config = FUSED_MOE_UNQUANTIZED_CONFIG + global_topk_ids = topk_ids if expert_map is not None: topk_ids = expert_map[topk_ids] @@ -804,9 +806,9 @@ def apply( lora_context, y=act_input, x=hidden_states, - topk_ids=topk_ids, + topk_ids=global_topk_ids, topk_weights=topk_weights, - expert_map=None, # already applied above + expert_map=expert_map, w1=w1, w2=w2, num_tokens=M, diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py index 1f0258fb657f..b63fbd20135a 100644 --- a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py @@ -1,8 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING + import torch +if TYPE_CHECKING: + from vllm.lora.lora_context import MoELoRAContext + import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.activation import MoEActivation @@ -168,6 +173,7 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, + lora_context: "MoELoRAContext | None" = None, ): import flashinfer from flashinfer.fused_moe import Fp8QuantizationType, WeightLayout diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py index d084283360c4..06d5dbdd7fda 100644 --- a/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py @@ -1,8 +1,13 @@ # 
SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING + import torch +if TYPE_CHECKING: + from vllm.lora.lora_context import MoELoRAContext + import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( @@ -279,6 +284,7 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, + lora_context: "MoELoRAContext | None" = None, ): topk = topk_ids.size(-1) local_num_experts = w1.size(0) diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py index fc30815f7191..5a7d38bed366 100644 --- a/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py @@ -1,9 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING + import flashinfer import torch +if TYPE_CHECKING: + from vllm.lora.lora_context import MoELoRAContext + import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.activation import MoEActivation @@ -187,6 +192,7 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, + lora_context: "MoELoRAContext | None" = None, ): assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL] assert a1q_scale is not None diff --git a/vllm/model_executor/layers/fused_moe/fallback.py b/vllm/model_executor/layers/fused_moe/fallback.py index 40741d52af50..8ff78cc79d66 100644 --- a/vllm/model_executor/layers/fused_moe/fallback.py +++ b/vllm/model_executor/layers/fused_moe/fallback.py @@ -2,9 +2,13 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod +from typing import TYPE_CHECKING import torch +if TYPE_CHECKING: + from vllm.lora.lora_context import MoELoRAContext + import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig @@ -159,6 +163,7 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, + lora_context: "MoELoRAContext | None" = None, ): experts = self._select_experts_impl(hidden_states, w1, w2) experts.apply( @@ -177,4 +182,5 @@ def apply( workspace2, expert_tokens_meta, apply_router_weight_on_input, + lora_context=lora_context, ) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 0d47b0f31748..39134f1ba330 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -1,8 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING + import torch +if TYPE_CHECKING: + from vllm.lora.lora_context import MoELoRAContext + import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.config import get_current_vllm_config from vllm.logger 
import init_logger @@ -265,6 +270,7 @@ def apply( workspace2: torch.Tensor | None, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool | None, + lora_context: "MoELoRAContext | None" = None, ): from flashinfer.fused_moe.core import ActivationType diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index e2b5a8f6764e..b721ac9af468 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -2,8 +2,13 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Fused batched MoE kernel.""" +from typing import TYPE_CHECKING + import torch +if TYPE_CHECKING: + from vllm.lora.lora_context import MoELoRAContext + import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( @@ -764,6 +769,7 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, + lora_context: "MoELoRAContext | None" = None, ): assert hidden_states.dim() == 3 assert expert_tokens_meta is not None @@ -1000,6 +1006,7 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, + lora_context: "MoELoRAContext | None" = None, ): # Check constraints. if self.quant_config.use_int4_w4a16: diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 4964f598a767..b7e3ada123c6 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -932,6 +932,7 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, + lora_context: "MoELoRAContext | None" = None, ): assert expert_tokens_meta is not None, "Num valid tokens per batch is required" return batched_fused_marlin_moe( diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index b39279dc2cd1..1ad53f0de9dd 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -2253,6 +2253,7 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, + lora_context: "MoELoRAContext | None" = None, ): # Check constraints. 
if self.quant_config.use_int4_w4a16: diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index d24bda101ffa..b21a4418fe71 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -2,9 +2,13 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from enum import IntEnum from functools import lru_cache +from typing import TYPE_CHECKING import torch +if TYPE_CHECKING: + from vllm.lora.lora_context import MoELoRAContext + import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm._aiter_ops import rocm_aiter_ops from vllm.model_executor.layers.fused_moe.activation import MoEActivation @@ -412,6 +416,7 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, + lora_context: "MoELoRAContext | None" = None, ): # TODO(rob): rocm_aiter_fused_experts uses self.quant_config's # a_scales for static quantization. Update this to fit better diff --git a/vllm/model_executor/layers/fused_moe/xpu_fused_moe.py b/vllm/model_executor/layers/fused_moe/xpu_fused_moe.py index 9cc0ade288c7..741d5dc2a1ca 100644 --- a/vllm/model_executor/layers/fused_moe/xpu_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/xpu_fused_moe.py @@ -1,7 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING + import torch +if TYPE_CHECKING: + from vllm.lora.lora_context import MoELoRAContext + import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( @@ -120,6 +125,7 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, + lora_context: "MoELoRAContext | None" = None, ): topk = topk_ids.size(-1) xpu_fused_moe( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_mxfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_mxfp4.py index 57ebb961d487..4bcf675af2da 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_mxfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_mxfp4.py @@ -2,8 +2,13 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING + import torch +if TYPE_CHECKING: + from vllm.lora.lora_context import MoELoRAContext + import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( @@ -204,6 +209,7 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, + lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: assert self.moe_kernel is not None return self.moe_kernel.apply( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_nvfp4.py index 09a216fd2cb1..f09f3bcc762d 100644 --- 
a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_nvfp4.py @@ -2,8 +2,13 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING + import torch +if TYPE_CHECKING: + from vllm.lora.lora_context import MoELoRAContext + import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( @@ -290,6 +295,7 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, + lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: assert self.moe_kernel is not None return self.moe_kernel.apply( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a8_fp8.py index 74cb0b4f6e1f..68f29140d971 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a8_fp8.py @@ -2,7 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING + import torch + +if TYPE_CHECKING: + from vllm.lora.lora_context import MoELoRAContext from compressed_tensors.quantization import ( QuantizationArgs, ) @@ -304,6 +309,7 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, + lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: if layer.enable_eplb: raise NotImplementedError( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_fp8.py index ed8ed79c50c6..a5ef925e06f5 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_fp8.py @@ -2,7 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING + import torch + +if TYPE_CHECKING: + from vllm.lora.lora_context import MoELoRAContext from compressed_tensors.quantization import ( QuantizationArgs, QuantizationStrategy, @@ -391,6 +396,7 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, + lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: assert not self.is_monolithic assert self.moe_kernel is not None diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_int8.py index de155f9e179f..107abceaf00a 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_int8.py @@ -2,7 +2,12 
@@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING + import torch + +if TYPE_CHECKING: + from vllm.lora.lora_context import MoELoRAContext from compressed_tensors.quantization import ( QuantizationArgs, QuantizationStrategy, @@ -143,6 +148,7 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, + lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe import fused_experts diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_mxfp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_mxfp8.py index 02e946b1b61e..84a6dd6ab792 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_mxfp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_mxfp8.py @@ -1,8 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING + import torch +if TYPE_CHECKING: + from vllm.lora.lora_context import MoELoRAContext + import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe import ( FusedMoE, @@ -192,6 +197,7 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, + lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: assert not self.is_monolithic assert self.moe_kernel is not None diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_wna16.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_wna16.py index f530a1a1df2b..0ccfaa124aab 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_wna16.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_wna16.py @@ -2,7 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING + import torch + +if TYPE_CHECKING: + from vllm.lora.lora_context import MoELoRAContext from compressed_tensors.quantization import ( QuantizationArgs, ) @@ -245,6 +250,7 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, + lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe import fused_experts diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_wna16_marlin.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_wna16_marlin.py index 216eed6372a9..d4aad87ed6cd 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_wna16_marlin.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_wna16_marlin.py @@ -3,8 +3,12 @@ import enum from enum import Enum +from typing import TYPE_CHECKING import torch + +if TYPE_CHECKING: + from vllm.lora.lora_context import MoELoRAContext from compressed_tensors.quantization import ( QuantizationArgs, ) @@ 
-545,6 +549,7 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, + lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: assert self.kernel_backend == "Marlin" return fused_marlin_moe( From c2dbb14c31f3f25bce1af377566e2e161ca8f987 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Wed, 22 Apr 2026 13:41:57 +0000 Subject: [PATCH 05/24] Move --- .../layers/fused_moe/modular_kernel.py | 2 + .../layers/fused_moe/oracle/fp8.py | 3 -- .../layers/fused_moe/oracle/unquantized.py | 41 ++----------------- 3 files changed, 5 insertions(+), 41 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 29d152604c04..f784372af74b 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -583,6 +583,8 @@ def _make_reason(reason: str) -> str: return False, _make_reason(f"{activation_format.value} activation format") elif envs.VLLM_BATCH_INVARIANT and not cls._supports_batch_invariance(): return False, _make_reason("batch invariance") + elif moe_config.is_lora_enabled and not cls.supports_lora(): + return False, _make_reason("LoRA") return True, None @staticmethod diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py index 4420bb38731a..823ebc63621b 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py @@ -220,9 +220,6 @@ def select_fp8_moe_backend( Note: Shape-specific fallbacks may still occur at runtime. """ - if config.is_lora_enabled: - return Fp8MoeBackend.TRITON, backend_to_kernel_cls(Fp8MoeBackend.TRITON)[0] - # NOTE: the kernels are selected in the following order. AVAILABLE_BACKENDS = _get_priority_backends(config, weight_key, activation_key) diff --git a/vllm/model_executor/layers/fused_moe/oracle/unquantized.py b/vllm/model_executor/layers/fused_moe/oracle/unquantized.py index cecaae798156..33bf7a0c75f6 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/unquantized.py +++ b/vllm/model_executor/layers/fused_moe/oracle/unquantized.py @@ -163,11 +163,6 @@ def select_unquantized_moe_backend( if current_platform.is_out_of_tree(): return UnquantizedMoeBackend.OOT, None - if moe_config.is_lora_enabled: - return UnquantizedMoeBackend.TRITON, backend_to_kernel_cls( - UnquantizedMoeBackend.TRITON - ) - # NOTE: the kernels are selected in the following order. AVAILABLE_BACKENDS = _get_priority_backends(moe_config) @@ -214,8 +209,6 @@ def _return_or_raise( return backend, k_cls raise ValueError(_make_log_unsupported(backend, reason)) - require_lora = moe_config.is_lora_enabled - runner_backend = moe_config.moe_backend if runner_backend != "auto": requested_backend = map_unquantized_backend(runner_backend) @@ -225,20 +218,10 @@ def _return_or_raise( ): requested_backend = UnquantizedMoeBackend.BATCHED_TRITON - if require_lora: - k_cls = backend_to_kernel_cls(requested_backend) - if not k_cls.supports_lora(): - raise ValueError( - f"moe_backend='{runner_backend}' does not support LoRA. " - "Use moe_backend='triton' or moe_backend='auto'." - ) return _return_or_raise(requested_backend, moe_config, activation_format) # Handle explicit FlashInfer FP16 configuration. - # When LoRA is enabled, FlashInfer backends don't support it — skip this - # block entirely and fall through to the generic auto-selection loop which - # will filter by supports_lora(). 
- if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP16") and not require_lora: + if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP16"): if not envs.VLLM_USE_FLASHINFER_MOE_FP16: if UnquantizedMoeBackend.FLASHINFER_TRTLLM in AVAILABLE_BACKENDS: AVAILABLE_BACKENDS.remove(UnquantizedMoeBackend.FLASHINFER_TRTLLM) @@ -289,27 +272,10 @@ def _return_or_raise( AVAILABLE_BACKENDS.remove(UnquantizedMoeBackend.AITER) else: backend = UnquantizedMoeBackend.AITER - if require_lora: - k_cls = backend_to_kernel_cls(backend) - if not k_cls.supports_lora(): - logger.warning_once( - "VLLM_ROCM_USE_AITER_MOE=1 but AiterExperts does not " - "support LoRA; falling back to generic backend selection.", - scope="local", - ) - else: - return _return_or_raise(backend, moe_config, activation_format) - else: - return _return_or_raise(backend, moe_config, activation_format) + return _return_or_raise(backend, moe_config, activation_format) for backend in AVAILABLE_BACKENDS: k_cls = backend_to_kernel_cls(backend) - if require_lora and not k_cls.supports_lora(): - logger.debug_once( - f"Skipping MoE backend {backend.value}: does not support LoRA.", - scope="local", - ) - continue supported, reason = k_cls.is_supported_config( k_cls, moe_config, None, None, activation_format ) @@ -320,8 +286,7 @@ def _return_or_raise( logger.debug_once(_make_log_unsupported(backend, reason), scope="local") raise NotImplementedError( - "No Unquantized MoE backend supports the deployment configuration" - + (" with LoRA enabled." if require_lora else ".") + "No Unquantized MoE backend supports the deployment configuration." ) From 550e19d83e9a74767cf2add51b466a63b5290063 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Wed, 22 Apr 2026 16:08:16 +0000 Subject: [PATCH 06/24] Move --- vllm/model_executor/layers/fused_moe/fused_marlin_moe.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 95bc26cb6481..63e471dc3e36 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -758,6 +758,7 @@ def apply( is_k_full=self.is_k_full, input_dtype=self.input_dtype, ) + return # LoRA path: wrap activation_func and moe_sum to inject LoRA at the # two natural injection points. @@ -765,6 +766,7 @@ def apply( # Marlin uses moe_align_block_size (same as TritonExperts) so # intermediate_cache1 is indexed by flat (token, expert) pair index, # which is compatible with add_lora_fused_moe's scatter mechanism. 
+ ctx = lora_context M = hidden_states.size(0) top_k_num = topk_ids.size(1) @@ -777,6 +779,7 @@ def activation_with_lora( ) -> None: # act_input = intermediate_cache1 (M*topk, 2N for gated) # act_output = intermediate_cache2 (M*topk, N) + ( sorted_token_ids_lora, expert_ids_lora, From d1ae8080ac91c89c30ed420f00b34d652451c0fd Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Wed, 22 Apr 2026 16:41:18 +0000 Subject: [PATCH 07/24] Fix Signed-off-by: Jee Jee Li --- .../layers/fused_moe/experts/nvfp4_emulation_moe.py | 7 +++++++ .../layers/fused_moe/experts/ocp_mx_emulation_moe.py | 8 ++++++++ 2 files changed, 15 insertions(+) diff --git a/vllm/model_executor/layers/fused_moe/experts/nvfp4_emulation_moe.py b/vllm/model_executor/layers/fused_moe/experts/nvfp4_emulation_moe.py index f1a0ee7ac52d..18bd684ff2c5 100644 --- a/vllm/model_executor/layers/fused_moe/experts/nvfp4_emulation_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/nvfp4_emulation_moe.py @@ -11,6 +11,8 @@ is applied on `a13`, `a2`. """ +from typing import TYPE_CHECKING + import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk @@ -20,6 +22,9 @@ FusedMoEConfig, FusedMoEQuantConfig, ) + +if TYPE_CHECKING: + from vllm.lora.lora_context import MoELoRAContext from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( @@ -97,6 +102,7 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, + lora_context: "MoELoRAContext | None" = None, ): """ Apply emulated quantized MoE computation. @@ -161,4 +167,5 @@ def apply( workspace2=workspace2, expert_tokens_meta=expert_tokens_meta, apply_router_weight_on_input=apply_router_weight_on_input, + lora_context=lora_context, ) diff --git a/vllm/model_executor/layers/fused_moe/experts/ocp_mx_emulation_moe.py b/vllm/model_executor/layers/fused_moe/experts/ocp_mx_emulation_moe.py index 9fb163ef42af..a8d8b015c834 100644 --- a/vllm/model_executor/layers/fused_moe/experts/ocp_mx_emulation_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/ocp_mx_emulation_moe.py @@ -11,8 +11,14 @@ is applied on activations via `moe_kernel_quantize_input`. """ +from typing import TYPE_CHECKING + import torch +if TYPE_CHECKING: + from vllm.lora.lora_context import MoELoRAContext + + import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.activation import MoEActivation @@ -137,6 +143,7 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, + lora_context: "MoELoRAContext | None" = None, ): """ Apply emulated quantized MoE computation. 
@@ -183,4 +190,5 @@ def apply( workspace2=workspace2, expert_tokens_meta=expert_tokens_meta, apply_router_weight_on_input=apply_router_weight_on_input, + lora_context=lora_context, ) From 61d7746a7d8b95550fd42aac423ef3ff3d0e762c Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 23 Apr 2026 00:11:59 +0000 Subject: [PATCH 08/24] Remove unrelated change Signed-off-by: Jee Jee Li --- vllm/model_executor/layers/fused_moe/modular_kernel.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index f784372af74b..db80fc8c158c 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -257,16 +257,6 @@ class FusedMoEPrepareAndFinalizeModular(FusedMoEPrepareAndFinalize): described above for the Modular case. """ - @staticmethod - def supports_lora() -> bool: - """Return True if this prepare/finalize impl can propagate LoRA context. - - Non-EP implementations support LoRA by default. EP-aware - implementations must override this to True once they can dispatch - lora_ids alongside hidden states in their all2all call (Task 2). - """ - return True - @abstractmethod def prepare( self, From ea4a8fd1d1dc57f0b9ffa427c54a2e492ce965ee Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 23 Apr 2026 00:28:13 +0000 Subject: [PATCH 09/24] Move Signed-off-by: Jee Jee Li --- vllm/lora/layers/fused_moe.py | 2 +- vllm/lora/lora_context.py | 66 ------------------- vllm/lora/punica_wrapper/punica_gpu.py | 12 ++-- .../layers/fused_moe/cutlass_moe.py | 2 +- .../experts/batched_deep_gemm_moe.py | 2 +- .../layers/fused_moe/experts/deep_gemm_moe.py | 2 +- .../experts/flashinfer_cutedsl_batched_moe.py | 2 +- .../experts/flashinfer_cutedsl_moe.py | 2 +- .../experts/gpt_oss_triton_kernels_moe.py | 2 +- .../fused_moe/experts/nvfp4_emulation_moe.py | 2 +- .../fused_moe/experts/ocp_mx_emulation_moe.py | 2 +- .../fused_moe/experts/trtllm_fp8_moe.py | 2 +- .../fused_moe/experts/trtllm_mxfp4_moe.py | 2 +- .../fused_moe/experts/trtllm_nvfp4_moe.py | 2 +- .../layers/fused_moe/fallback.py | 2 +- .../fused_moe/flashinfer_cutlass_moe.py | 2 +- .../layers/fused_moe/fused_batched_moe.py | 2 +- .../layers/fused_moe/fused_marlin_moe.py | 2 +- .../layers/fused_moe/fused_moe.py | 2 +- .../layers/fused_moe/fused_moe_method_base.py | 2 +- .../fused_moe/fused_moe_modular_method.py | 2 +- vllm/model_executor/layers/fused_moe/layer.py | 2 +- .../layers/fused_moe/modular_kernel.py | 2 +- .../layers/fused_moe/rocm_aiter_fused_moe.py | 2 +- .../fused_moe/unquantized_fused_moe_method.py | 2 +- .../layers/fused_moe/xpu_fused_moe.py | 2 +- .../layers/quantization/awq_marlin.py | 2 +- .../layers/quantization/bitsandbytes.py | 2 +- .../compressed_tensors_moe_w4a4_mxfp4.py | 2 +- .../compressed_tensors_moe_w4a4_nvfp4.py | 2 +- .../compressed_tensors_moe_w4a8_fp8.py | 2 +- .../compressed_tensors_moe_w8a8_fp8.py | 2 +- .../compressed_tensors_moe_w8a8_int8.py | 2 +- .../compressed_tensors_moe_w8a8_mxfp8.py | 2 +- .../compressed_tensors_moe_wna16.py | 2 +- .../compressed_tensors_moe_wna16_marlin.py | 2 +- .../model_executor/layers/quantization/fp8.py | 2 +- .../layers/quantization/gguf.py | 2 +- .../layers/quantization/gptq_marlin.py | 2 +- .../layers/quantization/modelopt.py | 2 +- .../layers/quantization/moe_wna16.py | 2 +- .../layers/quantization/mxfp4.py | 2 +- .../layers/quantization/online/moe_base.py | 2 +- .../layers/quantization/quark/quark_moe.py | 2 +- 44 
files changed, 48 insertions(+), 114 deletions(-) delete mode 100644 vllm/lora/lora_context.py diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index 274299325018..51c764ac9a5a 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -13,11 +13,11 @@ ) from vllm.distributed.utils import divide from vllm.lora.layers.base import BaseLayerWithLoRA -from vllm.lora.lora_context import MoELoRAContext from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import ( FusedMoEModularMethod, ) +from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEExpertsModular, FusedMoEKernel, diff --git a/vllm/lora/lora_context.py b/vllm/lora/lora_context.py deleted file mode 100644 index bb63538daabb..000000000000 --- a/vllm/lora/lora_context.py +++ /dev/null @@ -1,66 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from dataclasses import dataclass -from typing import Any - -import torch - - -def _normalize_lora_config_keys( - config: dict[str, int | None], -) -> dict[str, int | None]: - """Normalize Triton config dict keys to uppercase BLOCK_SIZE_* format.""" - out: dict[str, int | None] = {} - for key, val in config.items(): - if key.islower(): - if key.startswith("block_"): - nk = "BLOCK_SIZE_" + key.split("_")[-1].upper() - else: - nk = key.upper() - else: - nk = key - out[nk] = val - return out - - -@dataclass -class MoELoRAContext: - """ - Carries all LoRA state for one MoE forward pass. - - Built by FusedMoEWithLoRA.forward() and propagated explicitly through the - modular kernel path (FusedMoEKernel -> FusedMoEExpertsModular.apply) so - that TritonExperts.apply() can compute the LoRA contribution inline, - replacing the decorator-based monkey-patch approach. - - Typed as Any for punica_wrapper to avoid a circular import at module load - time: vllm.lora imports vllm.model_executor.layers.fused_moe, so the - reverse at module level would be circular. The actual type is - PunicaWrapperBase from vllm.lora.punica_wrapper. - """ - - # LoRA weight tensors (same shapes as FusedMoEWithLoRA attributes) - w13_lora_a_stacked: tuple[torch.Tensor, ...] - w13_lora_b_stacked: tuple[torch.Tensor, ...] - w2_lora_a_stacked: tuple[torch.Tensor, ...] - w2_lora_b_stacked: tuple[torch.Tensor, ...] - - # (max_loras + 1,) int32; slot 0 is the "no-adapter" sentinel - adapter_enabled: torch.Tensor - - # Metadata - max_loras: int - top_k: int - w13_num_slices: int # 2 = gated (gate + up), 1 = non-gated or 3D-fused - fully_sharded: bool - tp_rank: int - tp_size: int - local_num_experts: int - - # PunicaWrapperBase instance (typed Any to avoid circular import) - punica_wrapper: Any - - # Whether VLLM_TUNED_CONFIG_FOLDER is set; selects get_lora_op_configs vs - # try_get_optimal_moe_lora_config for Triton kernel tile configs. 
- use_tuned_config: bool diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index d018b64c29a1..a8f94e67ea2e 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -491,11 +491,11 @@ def add_lora_w13( import functools from vllm.lora.layers.utils import try_get_optimal_moe_lora_config - from vllm.lora.lora_context import ( - _normalize_lora_config_keys, - ) from vllm.lora.ops.triton_ops.utils import get_lora_op_configs from vllm.model_executor.layers.fused_moe.config import _get_config_dtype_str + from vllm.model_executor.layers.fused_moe.lora_context import ( + _normalize_lora_config_keys, + ) config_dtype = _get_config_dtype_str( dtype=x.dtype, @@ -621,11 +621,11 @@ def add_lora_w2( import functools from vllm.lora.layers.utils import try_get_optimal_moe_lora_config - from vllm.lora.lora_context import ( - _normalize_lora_config_keys, - ) from vllm.lora.ops.triton_ops.utils import get_lora_op_configs from vllm.model_executor.layers.fused_moe.config import _get_config_dtype_str + from vllm.model_executor.layers.fused_moe.lora_context import ( + _normalize_lora_config_keys, + ) config_dtype = _get_config_dtype_str( dtype=x.dtype, diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index d99feb6763ec..292847b06541 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -7,7 +7,7 @@ import torch if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops diff --git a/vllm/model_executor/layers/fused_moe/experts/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/experts/batched_deep_gemm_moe.py index 28724a280df4..4166050e7a05 100644 --- a/vllm/model_executor/layers/fused_moe/experts/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/batched_deep_gemm_moe.py @@ -7,7 +7,7 @@ import torch if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.forward_context import get_forward_context, is_forward_context_available diff --git a/vllm/model_executor/layers/fused_moe/experts/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/experts/deep_gemm_moe.py index 1f06f43127f6..d0d632752d34 100644 --- a/vllm/model_executor/layers/fused_moe/experts/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/deep_gemm_moe.py @@ -6,7 +6,7 @@ import torch if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger diff --git a/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_batched_moe.py b/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_batched_moe.py index 65eb6627b457..346cfaa7e8bd 100644 --- a/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_batched_moe.py @@ -6,7 +6,7 @@ import torch if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from 
vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import envs diff --git a/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_moe.py b/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_moe.py index b1928e9cd161..da3ed5cba10b 100644 --- a/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_moe.py @@ -6,7 +6,7 @@ import torch if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.activation import MoEActivation diff --git a/vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py index e6b5abe39e34..e8c70da6840f 100644 --- a/vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py @@ -6,7 +6,7 @@ import torch if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops diff --git a/vllm/model_executor/layers/fused_moe/experts/nvfp4_emulation_moe.py b/vllm/model_executor/layers/fused_moe/experts/nvfp4_emulation_moe.py index 18bd684ff2c5..16d6b5b401ca 100644 --- a/vllm/model_executor/layers/fused_moe/experts/nvfp4_emulation_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/nvfp4_emulation_moe.py @@ -24,7 +24,7 @@ ) if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( diff --git a/vllm/model_executor/layers/fused_moe/experts/ocp_mx_emulation_moe.py b/vllm/model_executor/layers/fused_moe/experts/ocp_mx_emulation_moe.py index a8d8b015c834..c3b4ec121f7a 100644 --- a/vllm/model_executor/layers/fused_moe/experts/ocp_mx_emulation_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/ocp_mx_emulation_moe.py @@ -16,7 +16,7 @@ import torch if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext import vllm.model_executor.layers.fused_moe.modular_kernel as mk diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py index b63fbd20135a..5011117d005f 100644 --- a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py @@ -6,7 +6,7 @@ import torch if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py index 
06d5dbdd7fda..6d80bd9973e6 100644 --- a/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py @@ -6,7 +6,7 @@ import torch if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.activation import MoEActivation diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py index aff4040d8bf6..a2d2211fcd45 100644 --- a/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py @@ -6,7 +6,7 @@ import torch if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger diff --git a/vllm/model_executor/layers/fused_moe/fallback.py b/vllm/model_executor/layers/fused_moe/fallback.py index 8ff78cc79d66..1b4e62c85b79 100644 --- a/vllm/model_executor/layers/fused_moe/fallback.py +++ b/vllm/model_executor/layers/fused_moe/fallback.py @@ -7,7 +7,7 @@ import torch if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.activation import MoEActivation diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 65c188d9f8ba..554847eb8785 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -6,7 +6,7 @@ import torch if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.config import get_current_vllm_config diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 60deab5ce4a8..8ce462a0256c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -7,7 +7,7 @@ import torch if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.activation import MoEActivation diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 63e471dc3e36..04dd4f3b7b52 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -8,7 +8,7 @@ import torch if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext import vllm._custom_ops as ops import vllm.model_executor.layers.fused_moe.modular_kernel as mk diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py 
b/vllm/model_executor/layers/fused_moe/fused_moe.py index bb82e3f5dc87..1a9ca1e014a1 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext import torch diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py index 642e7b497f19..af71ca6bf254 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py @@ -9,7 +9,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py index 9b281a23dad4..0f4f7bdab316 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py @@ -9,7 +9,7 @@ from vllm.logger import init_logger if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index f1b82c7aa310..d4f9efba0a8a 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Literal, cast, get_args, overload if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext import torch from torch.nn.parameter import UninitializedParameter diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index db80fc8c158c..c312925bbbbb 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -10,7 +10,7 @@ import torch if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext import vllm.envs as envs from vllm.logger import init_logger diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index b21a4418fe71..ce9da37cfc91 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -7,7 +7,7 @@ import torch if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm._aiter_ops import rocm_aiter_ops diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py 
b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py index ce744c613f15..2c18a7527fe5 100644 --- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py +++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py @@ -7,7 +7,7 @@ import torch if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext import torch.nn.functional as F from torch.nn import Module diff --git a/vllm/model_executor/layers/fused_moe/xpu_fused_moe.py b/vllm/model_executor/layers/fused_moe/xpu_fused_moe.py index d5224324f364..cc00bca9a169 100644 --- a/vllm/model_executor/layers/fused_moe/xpu_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/xpu_fused_moe.py @@ -5,7 +5,7 @@ import torch if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.activation import MoEActivation diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 7d40825f1bbf..f533a031a18f 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -59,7 +59,7 @@ from vllm.transformers_utils.config import get_safetensors_params_metadata if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.models.utils import WeightsMapper diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index a0c61b1bc187..5e86fd8a1c06 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any, Union if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext import torch from packaging import version diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_mxfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_mxfp4.py index 4bcf675af2da..d1b78a86130c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_mxfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_mxfp4.py @@ -7,7 +7,7 @@ import torch if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_nvfp4.py index f09f3bcc762d..f85c37852eac 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_nvfp4.py +++ 
b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_nvfp4.py @@ -7,7 +7,7 @@ import torch if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a8_fp8.py index 646b7b95c0b6..f4075780b8b6 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a8_fp8.py @@ -7,7 +7,7 @@ import torch if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext from compressed_tensors.quantization import ( QuantizationArgs, ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_fp8.py index a5ef925e06f5..c0b96b9b0fc5 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_fp8.py @@ -7,7 +7,7 @@ import torch if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext from compressed_tensors.quantization import ( QuantizationArgs, QuantizationStrategy, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_int8.py index ab003ad6ef05..3835a813d813 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_int8.py @@ -7,7 +7,7 @@ import torch if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext from compressed_tensors.quantization import ( QuantizationArgs, QuantizationStrategy, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_mxfp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_mxfp8.py index 84a6dd6ab792..34172049120e 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_mxfp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_mxfp8.py @@ -6,7 +6,7 @@ import torch if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext import 
vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe import ( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_wna16.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_wna16.py index 0ccfaa124aab..f0d79599e7b1 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_wna16.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_wna16.py @@ -7,7 +7,7 @@ import torch if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext from compressed_tensors.quantization import ( QuantizationArgs, ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_wna16_marlin.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_wna16_marlin.py index d4aad87ed6cd..df8dca28bdbd 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_wna16_marlin.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_wna16_marlin.py @@ -8,7 +8,7 @@ import torch if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext from compressed_tensors.quantization import ( QuantizationArgs, ) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 9264c956eba0..2b1e2bf8e2a7 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -86,7 +86,7 @@ ) if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext from vllm.model_executor.models.utils import WeightsMapper ACTIVATION_SCHEMES = ["static", "dynamic"] diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index 1e4fc22160f7..3dbbfabd8ad9 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext from vllm.model_executor.layers.quantization import QuantizationMethods import gguf diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 75abfed1fc44..196437afcfd1 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext import torch from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 056db87e9038..7c2280474bd2 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ 
b/vllm/model_executor/layers/quantization/modelopt.py @@ -96,7 +96,7 @@ from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext from vllm.model_executor.models.utils import WeightsMapper logger = init_logger(__name__) diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index f4b00c8074a6..56e47a27f1ac 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext import torch diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 1f9b6c8a12ee..4aa780fb513b 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -6,7 +6,7 @@ import torch if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext from vllm.config import get_current_vllm_config from vllm.logger import init_logger diff --git a/vllm/model_executor/layers/quantization/online/moe_base.py b/vllm/model_executor/layers/quantization/online/moe_base.py index add721f8ae69..7d3f63637900 100644 --- a/vllm/model_executor/layers/quantization/online/moe_base.py +++ b/vllm/model_executor/layers/quantization/online/moe_base.py @@ -7,7 +7,7 @@ import torch if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 3f02ffbb313a..61cbd6781e6b 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from vllm.lora.lora_context import MoELoRAContext + from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext import torch From 56002215dccb82df421c563fd635f65ff7d59823 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 23 Apr 2026 00:44:17 +0000 Subject: [PATCH 10/24] Move Signed-off-by: Jee Jee Li --- vllm/lora/ops/triton_ops/utils.py | 17 +++++++++++++++++ vllm/lora/punica_wrapper/punica_gpu.py | 12 ++++++------ 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/vllm/lora/ops/triton_ops/utils.py b/vllm/lora/ops/triton_ops/utils.py index 0ab52e698318..af0a3157f02a 100644 --- a/vllm/lora/ops/triton_ops/utils.py +++ b/vllm/lora/ops/triton_ops/utils.py @@ -321,3 +321,20 @@ def supports_pdl(device: torch.device | None = None) -> bool: def supports_tma(device: torch.device | None = None) -> bool: # TMA requires compute capability SM90 or above return current_platform.is_cuda() and current_platform.has_device_capability(90) + + +def _normalize_lora_config_keys( + config: dict[str, int | None], +) -> dict[str, int | None]: + """Normalize 
Triton config dict keys to uppercase BLOCK_SIZE_* format.""" + out: dict[str, int | None] = {} + for key, val in config.items(): + if key.islower(): + if key.startswith("block_"): + nk = "BLOCK_SIZE_" + key.split("_")[-1].upper() + else: + nk = key.upper() + else: + nk = key + out[nk] = val + return out diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index a8f94e67ea2e..0b28a557e260 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -491,11 +491,11 @@ def add_lora_w13( import functools from vllm.lora.layers.utils import try_get_optimal_moe_lora_config - from vllm.lora.ops.triton_ops.utils import get_lora_op_configs - from vllm.model_executor.layers.fused_moe.config import _get_config_dtype_str - from vllm.model_executor.layers.fused_moe.lora_context import ( + from vllm.lora.ops.triton_ops.utils import ( _normalize_lora_config_keys, + get_lora_op_configs, ) + from vllm.model_executor.layers.fused_moe.config import _get_config_dtype_str config_dtype = _get_config_dtype_str( dtype=x.dtype, @@ -621,11 +621,11 @@ def add_lora_w2( import functools from vllm.lora.layers.utils import try_get_optimal_moe_lora_config - from vllm.lora.ops.triton_ops.utils import get_lora_op_configs - from vllm.model_executor.layers.fused_moe.config import _get_config_dtype_str - from vllm.model_executor.layers.fused_moe.lora_context import ( + from vllm.lora.ops.triton_ops.utils import ( _normalize_lora_config_keys, + get_lora_op_configs, ) + from vllm.model_executor.layers.fused_moe.config import _get_config_dtype_str config_dtype = _get_config_dtype_str( dtype=x.dtype, From cd29a490a7c8fe08e9df8a972be800f166c32c9e Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 23 Apr 2026 00:45:00 +0000 Subject: [PATCH 11/24] Move Signed-off-by: Jee Jee Li --- .../layers/fused_moe/lora_context.py | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 vllm/model_executor/layers/fused_moe/lora_context.py diff --git a/vllm/model_executor/layers/fused_moe/lora_context.py b/vllm/model_executor/layers/fused_moe/lora_context.py new file mode 100644 index 000000000000..92500a7bb47d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/lora_context.py @@ -0,0 +1,44 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass + +import torch + +from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase + + +@dataclass +class MoELoRAContext: + """ + Carries all LoRA state for one MoE forward pass. + + Built by FusedMoEWithLoRA.forward() and propagated explicitly through the + modular kernel path (FusedMoEKernel -> FusedMoEExpertsModular.apply) so + that TritonExperts.apply() can compute the LoRA contribution inline, + replacing the decorator-based monkey-patch approach. + """ + + # LoRA weight tensors (same shapes as FusedMoEWithLoRA attributes) + w13_lora_a_stacked: tuple[torch.Tensor, ...] + w13_lora_b_stacked: tuple[torch.Tensor, ...] + w2_lora_a_stacked: tuple[torch.Tensor, ...] + w2_lora_b_stacked: tuple[torch.Tensor, ...] 
+ + # (max_loras + 1,) int32; slot 0 is the "no-adapter" sentinel + adapter_enabled: torch.Tensor + + # Metadata + max_loras: int + top_k: int + w13_num_slices: int # 2 = gated (gate + up), 1 = non-gated or 3D-fused + fully_sharded: bool + tp_rank: int + tp_size: int + local_num_experts: int + + punica_wrapper: PunicaWrapperBase + + # Whether VLLM_TUNED_CONFIG_FOLDER is set; selects get_lora_op_configs vs + # try_get_optimal_moe_lora_config for Triton kernel tile configs. + use_tuned_config: bool From 7707bf33234c924eabff11280625909f5ef4bae0 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 23 Apr 2026 01:27:14 +0000 Subject: [PATCH 12/24] Add lora experts mixin Signed-off-by: Jee Jee Li --- .../experts/gpt_oss_triton_kernels_moe.py | 7 +- .../layers/fused_moe/fused_marlin_moe.py | 7 +- .../layers/fused_moe/fused_moe.py | 7 +- .../layers/fused_moe/lora_experts_mixin.py | 106 ++++++++++++++++++ .../layers/fused_moe/modular_kernel.py | 80 ------------- 5 files changed, 112 insertions(+), 95 deletions(-) create mode 100644 vllm/model_executor/layers/fused_moe/lora_experts_mixin.py diff --git a/vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py index e8c70da6840f..b438dcfc1157 100644 --- a/vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py @@ -20,6 +20,7 @@ FusedMoEQuantConfig, RoutingMethodType, ) +from vllm.model_executor.layers.fused_moe.lora_experts_mixin import LoRAExpertsMixin from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP, ) @@ -659,7 +660,7 @@ def apply( ) -class UnfusedOAITritonExperts(BaseOAITritonExperts): +class UnfusedOAITritonExperts(LoRAExpertsMixin, BaseOAITritonExperts): """ A Triton based MoE expert class that operates on expert standard format and explicitly keeps the activation and reduction (moe_sum) steps @@ -669,10 +670,6 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts): One use case for it is to inject LoRA modules on the activation and moe_sum. 
""" - @staticmethod - def supports_lora() -> bool: - return True - @staticmethod def _supports_activation(activation: MoEActivation) -> bool: return activation in [ diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 04dd4f3b7b52..70dd29639e67 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -21,6 +21,7 @@ FusedMoEParallelConfig, FusedMoEQuantConfig, ) +from vllm.model_executor.layers.fused_moe.lora_experts_mixin import LoRAExpertsMixin from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( batched_moe_align_block_size, moe_align_block_size, @@ -655,13 +656,9 @@ def moe_problem_size( return E, M, N, K, topk -class MarlinExperts(MarlinExpertsBase): +class MarlinExperts(LoRAExpertsMixin, MarlinExpertsBase): """Marlin-based fused MoE expert implementation.""" - @staticmethod - def supports_lora() -> bool: - return True - def supports_expert_map(self) -> bool: return True diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 1a9ca1e014a1..2b27853a8bd8 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -28,6 +28,7 @@ FusedMoEQuantConfig, _get_config_dtype_str, ) +from vllm.model_executor.layers.fused_moe.lora_experts_mixin import LoRAExpertsMixin from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size, ) @@ -1890,7 +1891,7 @@ def fused_experts_impl( return out_hidden_states -class TritonExperts(mk.FusedMoEExpertsModular): +class TritonExperts(LoRAExpertsMixin, mk.FusedMoEExpertsModular): """Triton-based fused MoE expert implementation.""" def __init__( @@ -1962,10 +1963,6 @@ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bo def _supports_batch_invariance(): return True - @staticmethod - def supports_lora() -> bool: - return True - def supports_expert_map(self) -> bool: return True diff --git a/vllm/model_executor/layers/fused_moe/lora_experts_mixin.py b/vllm/model_executor/layers/fused_moe/lora_experts_mixin.py new file mode 100644 index 000000000000..3501251fd4a8 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/lora_experts_mixin.py @@ -0,0 +1,106 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext + + +class LoRAExpertsMixin: + """ + Mixin for FusedMoEExpertsModular subclasses that natively handle + MoELoRAContext inside their apply() implementation. + + Mixing this class in: + - Flips supports_lora() to True so _can_fused_experts_support lets + LoRA through the gate check. + - Provides apply_w13_lora / apply_w2_lora helpers that dispatch to + the PunicaWrapper kernels. + + Mixin callers rely on self.block_shape from FusedMoEExperts, so this + must be mixed into a FusedMoEExperts subclass. 
+ """ + + @staticmethod + def supports_lora() -> bool: + return True + + def apply_w13_lora( + self, + lora_context: MoELoRAContext, + *, + y: torch.Tensor, + x: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + expert_map: torch.Tensor | None, + w1: torch.Tensor, + w2: torch.Tensor, + num_tokens: int, + top_k_num: int, + ) -> tuple[ + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + ]: + return lora_context.punica_wrapper.add_lora_w13( + y, + x, + lora_context.w13_lora_a_stacked, + lora_context.w13_lora_b_stacked, + topk_ids, + topk_weights, + expert_map, + w1, + w2, + num_tokens, + top_k_num, + lora_context.max_loras, + lora_context.adapter_enabled, + lora_context.local_num_experts, + lora_context.top_k, + lora_context.w13_num_slices, + lora_context.fully_sharded, + lora_context.use_tuned_config, + block_shape=self.block_shape, + ) + + def apply_w2_lora( + self, + lora_context: MoELoRAContext, + *, + y: torch.Tensor, + x: torch.Tensor, + topk_weights: torch.Tensor, + sorted_token_ids_lora: torch.Tensor | None, + expert_ids_lora: torch.Tensor | None, + num_tokens_post_padded_lora: torch.Tensor | None, + token_lora_mapping: torch.Tensor | None, + num_tokens: int, + w1: torch.Tensor, + w2: torch.Tensor, + top_k_num: int, + ) -> None: + lora_context.punica_wrapper.add_lora_w2( + y, + x, + lora_context.w2_lora_a_stacked, + lora_context.w2_lora_b_stacked, + topk_weights, + sorted_token_ids_lora, + expert_ids_lora, + num_tokens_post_padded_lora, + token_lora_mapping, + num_tokens, + w1, + w2, + top_k_num, + lora_context.max_loras, + lora_context.adapter_enabled, + lora_context.top_k, + lora_context.fully_sharded, + lora_context.tp_rank, + lora_context.use_tuned_config, + block_shape=self.block_shape, + ) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index c312925bbbbb..94322ed29aa4 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -951,86 +951,6 @@ def apply( """ raise NotImplementedError - def apply_w13_lora( - self, - lora_context: "MoELoRAContext", - *, - y: torch.Tensor, - x: torch.Tensor, - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, - expert_map: torch.Tensor | None, - w1: torch.Tensor, - w2: torch.Tensor, - num_tokens: int, - top_k_num: int, - ) -> tuple[ - torch.Tensor | None, - torch.Tensor | None, - torch.Tensor | None, - torch.Tensor | None, - ]: - return lora_context.punica_wrapper.add_lora_w13( - y, - x, - lora_context.w13_lora_a_stacked, - lora_context.w13_lora_b_stacked, - topk_ids, - topk_weights, - expert_map, - w1, - w2, - num_tokens, - top_k_num, - lora_context.max_loras, - lora_context.adapter_enabled, - lora_context.local_num_experts, - lora_context.top_k, - lora_context.w13_num_slices, - lora_context.fully_sharded, - lora_context.use_tuned_config, - block_shape=self.block_shape, - ) - - def apply_w2_lora( - self, - lora_context: "MoELoRAContext", - *, - y: torch.Tensor, - x: torch.Tensor, - topk_weights: torch.Tensor, - sorted_token_ids_lora: torch.Tensor | None, - expert_ids_lora: torch.Tensor | None, - num_tokens_post_padded_lora: torch.Tensor | None, - token_lora_mapping: torch.Tensor | None, - num_tokens: int, - w1: torch.Tensor, - w2: torch.Tensor, - top_k_num: int, - ) -> None: - lora_context.punica_wrapper.add_lora_w2( - y, - x, - lora_context.w2_lora_a_stacked, - lora_context.w2_lora_b_stacked, - topk_weights, - 
sorted_token_ids_lora, - expert_ids_lora, - num_tokens_post_padded_lora, - token_lora_mapping, - num_tokens, - w1, - w2, - top_k_num, - lora_context.max_loras, - lora_context.adapter_enabled, - lora_context.top_k, - lora_context.fully_sharded, - lora_context.tp_rank, - lora_context.use_tuned_config, - block_shape=self.block_shape, - ) - class FusedMoEExpertsMonolithic(FusedMoEExperts): """ From 166386e3824e84a2c380757afeb30be6ea61ca66 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 23 Apr 2026 02:07:36 +0000 Subject: [PATCH 13/24] OPT Signed-off-by: Jee Jee Li --- vllm/lora/layers/fused_moe.py | 7 +++--- .../layers/fused_moe/cutlass_moe.py | 9 -------- .../experts/batched_deep_gemm_moe.py | 6 ----- .../layers/fused_moe/experts/deep_gemm_moe.py | 5 ----- .../experts/flashinfer_cutedsl_batched_moe.py | 5 ----- .../experts/flashinfer_cutedsl_moe.py | 5 ----- .../experts/gpt_oss_triton_kernels_moe.py | 8 +------ .../fused_moe/experts/nvfp4_emulation_moe.py | 7 ------ .../fused_moe/experts/ocp_mx_emulation_moe.py | 8 ------- .../fused_moe/experts/trtllm_fp8_moe.py | 5 ----- .../fused_moe/experts/trtllm_mxfp4_moe.py | 5 ----- .../fused_moe/experts/trtllm_nvfp4_moe.py | 5 ----- .../layers/fused_moe/fallback.py | 6 ----- .../fused_moe/flashinfer_cutlass_moe.py | 5 ----- .../layers/fused_moe/fused_batched_moe.py | 7 ------ .../layers/fused_moe/fused_marlin_moe.py | 10 ++------- .../layers/fused_moe/fused_moe.py | 8 ++----- .../layers/fused_moe/fused_moe_method_base.py | 5 ----- .../fused_moe/fused_moe_modular_method.py | 7 ------ vllm/model_executor/layers/fused_moe/layer.py | 11 +--------- .../layers/fused_moe/lora_experts_mixin.py | 7 ++++++ .../layers/fused_moe/modular_kernel.py | 22 ++++--------------- .../layers/fused_moe/rocm_aiter_fused_moe.py | 5 ----- .../layers/fused_moe/runner/moe_runner.py | 1 - .../fused_moe/unquantized_fused_moe_method.py | 12 +--------- .../layers/fused_moe/xpu_fused_moe.py | 5 ----- .../layers/quantization/awq_marlin.py | 2 -- .../layers/quantization/bitsandbytes.py | 6 +---- .../compressed_tensors_moe_w4a4_mxfp4.py | 6 ----- .../compressed_tensors_moe_w4a4_nvfp4.py | 6 ----- .../compressed_tensors_moe_w4a8_fp8.py | 6 ----- .../compressed_tensors_moe_w8a8_fp8.py | 6 ----- .../compressed_tensors_moe_w8a8_int8.py | 6 ----- .../compressed_tensors_moe_w8a8_mxfp8.py | 5 ----- .../compressed_tensors_moe_wna16.py | 6 ----- .../compressed_tensors_moe_wna16_marlin.py | 5 ----- .../model_executor/layers/quantization/fp8.py | 2 -- .../layers/quantization/gguf.py | 2 -- .../layers/quantization/gptq_marlin.py | 6 +---- .../layers/quantization/modelopt.py | 4 ---- .../layers/quantization/moe_wna16.py | 6 +---- .../layers/quantization/mxfp4.py | 6 ----- .../layers/quantization/online/moe_base.py | 4 ---- .../layers/quantization/quark/quark_moe.py | 9 +------- 44 files changed, 26 insertions(+), 253 deletions(-) diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index 51c764ac9a5a..c1542f1466cf 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -68,9 +68,10 @@ def __init__(self, base_layer: FusedMoE) -> None: f"{type(moe_kernel.fused_experts).__name__} does not support LoRA. " "For unquantized MoE, set moe_backend='triton' or moe_backend='auto' " "(auto selects Triton automatically when LoRA is enabled). " - "For quantized MoE, implement supports_lora() -> True and handle " - "lora_context in apply()." + "For quantized MoE, mix LoRAExpertsMixin into the experts class " + "and consume self._lora_context in apply()." 
) + self._fused_experts = moe_kernel.fused_experts self.base_layer._replace_quant_method( FusedMoEModularMethod(self.base_layer.quant_method, moe_kernel) ) @@ -331,7 +332,7 @@ def set_lora( def set_mapping(self, punica_wrapper): super().set_mapping(punica_wrapper) - self.base_layer._lora_context = self._build_lora_context() + self._fused_experts.set_lora_context(self._build_lora_context()) def forward(self, *args, **kwargs): return self.base_layer.forward(*args, **kwargs) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 292847b06541..fdd802e7da3a 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -2,13 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """CUTLASS based Fused MoE kernels.""" -from typing import TYPE_CHECKING - import torch -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext - import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.logger import init_logger @@ -348,7 +343,6 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, - lora_context: "MoELoRAContext | None" = None, ): assert self.w1_zp is None, "w1_zp is not supported in CUTLASS MoE" assert self.w2_zp is None, "w2_zp is not supported in CUTLASS MoE" @@ -774,7 +768,6 @@ def apply( workspace2: torch.Tensor | None, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, - lora_context: "MoELoRAContext | None" = None, ): e, m, n, k, _ = self.moe_problem_size(hidden_states, w1, w2, topk_ids) n = w2.shape[2] * 2 @@ -1072,7 +1065,6 @@ def apply( workspace2: torch.Tensor | None, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, - lora_context: "MoELoRAContext | None" = None, ): e, m, n, k, _ = self.moe_problem_size(hidden_states, w1, w2, topk_ids) n = w2.shape[2] * 2 @@ -1356,7 +1348,6 @@ def apply( workspace2: torch.Tensor | None, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, - lora_context: "MoELoRAContext | None" = None, ): assert self.w1_zp is None, "w1_zp is not supported in CUTLASS MoE" assert self.w2_zp is None, "w2_zp is not supported in CUTLASS MoE" diff --git a/vllm/model_executor/layers/fused_moe/experts/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/experts/batched_deep_gemm_moe.py index 4166050e7a05..fad39b3e9d4a 100644 --- a/vllm/model_executor/layers/fused_moe/experts/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/batched_deep_gemm_moe.py @@ -2,13 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING - import torch -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext - import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.forward_context import get_forward_context, is_forward_context_available from vllm.logger import init_logger @@ -406,7 +401,6 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, - lora_context: "MoELoRAContext | None" = None, ): assert expert_tokens_meta is not None expert_num_tokens = expert_tokens_meta.expert_num_tokens diff --git a/vllm/model_executor/layers/fused_moe/experts/deep_gemm_moe.py 
b/vllm/model_executor/layers/fused_moe/experts/deep_gemm_moe.py index d0d632752d34..ac233f33f4f3 100644 --- a/vllm/model_executor/layers/fused_moe/experts/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/deep_gemm_moe.py @@ -1,13 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING import torch -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext - import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.activation import MoEActivation @@ -250,7 +246,6 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, - lora_context: "MoELoRAContext | None" = None, ): assert a1q_scale is not None assert a2_scale is None diff --git a/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_batched_moe.py b/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_batched_moe.py index 346cfaa7e8bd..3ca4c2e2892f 100644 --- a/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_batched_moe.py @@ -1,13 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING import torch -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext - import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import envs from vllm.logger import init_logger @@ -154,7 +150,6 @@ def apply( workspace2: torch.Tensor | None, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool | None, - lora_context: "MoELoRAContext | None" = None, ): assert self.quant_dtype == "nvfp4", ( "Only nvfp4 quantization are currently supported." 
diff --git a/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_moe.py b/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_moe.py index da3ed5cba10b..f1c1ce2e1958 100644 --- a/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_moe.py @@ -1,13 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING import torch -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext - import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( @@ -144,7 +140,6 @@ def apply( workspace2: torch.Tensor | None, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool | None, - lora_context: "MoELoRAContext | None" = None, ): assert self.quant_dtype == "nvfp4" assert a1q_scale is not None diff --git a/vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py index b438dcfc1157..41898060dc3c 100644 --- a/vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py @@ -1,13 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING - import torch -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext - import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm._aiter_ops import rocm_aiter_ops @@ -624,7 +619,6 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, - lora_context: "MoELoRAContext | None" = None, ): if self.quant_config is None: self.quant_config: FusedMoEQuantConfig = FUSED_MOE_UNQUANTIZED_CONFIG @@ -721,7 +715,6 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, - lora_context: "MoELoRAContext | None" = None, ): # Use local variable to help mypy narrow the type after None check quant_config = self.quant_config @@ -793,6 +786,7 @@ def apply( expert_ids_lora = None num_tokens_post_padded_lora = None token_lora_mapping = None + lora_context = self._lora_context if lora_context is not None: ( sorted_token_ids_lora, diff --git a/vllm/model_executor/layers/fused_moe/experts/nvfp4_emulation_moe.py b/vllm/model_executor/layers/fused_moe/experts/nvfp4_emulation_moe.py index 16d6b5b401ca..f1a0ee7ac52d 100644 --- a/vllm/model_executor/layers/fused_moe/experts/nvfp4_emulation_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/nvfp4_emulation_moe.py @@ -11,8 +11,6 @@ is applied on `a13`, `a2`. 
""" -from typing import TYPE_CHECKING - import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk @@ -22,9 +20,6 @@ FusedMoEConfig, FusedMoEQuantConfig, ) - -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( @@ -102,7 +97,6 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, - lora_context: "MoELoRAContext | None" = None, ): """ Apply emulated quantized MoE computation. @@ -167,5 +161,4 @@ def apply( workspace2=workspace2, expert_tokens_meta=expert_tokens_meta, apply_router_weight_on_input=apply_router_weight_on_input, - lora_context=lora_context, ) diff --git a/vllm/model_executor/layers/fused_moe/experts/ocp_mx_emulation_moe.py b/vllm/model_executor/layers/fused_moe/experts/ocp_mx_emulation_moe.py index c3b4ec121f7a..9fb163ef42af 100644 --- a/vllm/model_executor/layers/fused_moe/experts/ocp_mx_emulation_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/ocp_mx_emulation_moe.py @@ -11,14 +11,8 @@ is applied on activations via `moe_kernel_quantize_input`. """ -from typing import TYPE_CHECKING - import torch -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext - - import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.activation import MoEActivation @@ -143,7 +137,6 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, - lora_context: "MoELoRAContext | None" = None, ): """ Apply emulated quantized MoE computation. 
@@ -190,5 +183,4 @@ def apply( workspace2=workspace2, expert_tokens_meta=expert_tokens_meta, apply_router_weight_on_input=apply_router_weight_on_input, - lora_context=lora_context, ) diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py index 5011117d005f..a34cf0add9d4 100644 --- a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py @@ -1,13 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING import torch -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext - import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.activation import MoEActivation @@ -173,7 +169,6 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, - lora_context: "MoELoRAContext | None" = None, ): import flashinfer from flashinfer.fused_moe import Fp8QuantizationType, WeightLayout diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py index 6d80bd9973e6..8722cbbdf4c7 100644 --- a/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py @@ -1,13 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING import torch -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext - import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( @@ -284,7 +280,6 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, - lora_context: "MoELoRAContext | None" = None, ): topk = topk_ids.size(-1) local_num_experts = w1.size(0) diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py index a2d2211fcd45..c6689bf9fed7 100644 --- a/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py @@ -1,13 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING import torch -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext - import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.activation import MoEActivation @@ -191,7 +187,6 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, - lora_context: "MoELoRAContext | None" = None, ): import flashinfer diff --git a/vllm/model_executor/layers/fused_moe/fallback.py b/vllm/model_executor/layers/fused_moe/fallback.py index 1b4e62c85b79..40741d52af50 100644 --- a/vllm/model_executor/layers/fused_moe/fallback.py +++ b/vllm/model_executor/layers/fused_moe/fallback.py @@ -2,13 +2,9 @@ # 
SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod -from typing import TYPE_CHECKING import torch -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext - import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig @@ -163,7 +159,6 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, - lora_context: "MoELoRAContext | None" = None, ): experts = self._select_experts_impl(hidden_states, w1, w2) experts.apply( @@ -182,5 +177,4 @@ def apply( workspace2, expert_tokens_meta, apply_router_weight_on_input, - lora_context=lora_context, ) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 554847eb8785..370bb4536fde 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -1,13 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING import torch -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext - import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.config import get_current_vllm_config from vllm.logger import init_logger @@ -263,7 +259,6 @@ def apply( workspace2: torch.Tensor | None, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool | None, - lora_context: "MoELoRAContext | None" = None, ): from flashinfer.fused_moe.core import ActivationType diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 8ce462a0256c..5554298bd090 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -2,13 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Fused batched MoE kernel.""" -from typing import TYPE_CHECKING - import torch -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext - import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( @@ -769,7 +764,6 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, - lora_context: "MoELoRAContext | None" = None, ): assert hidden_states.dim() == 3 assert expert_tokens_meta is not None @@ -1006,7 +1000,6 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, - lora_context: "MoELoRAContext | None" = None, ): # Check constraints. 
if self.quant_config.use_int4_w4a16: diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 70dd29639e67..925020d8b6b0 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -3,13 +3,9 @@ """Fused MoE utilities for GPTQ.""" from collections.abc import Callable -from typing import TYPE_CHECKING import torch -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext - import vllm._custom_ops as ops import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.activation import ( @@ -718,12 +714,12 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, - lora_context: "MoELoRAContext | None" = None, ): assert self.w1_scale is not None assert self.w2_scale is not None - if lora_context is None: + ctx = self._lora_context + if ctx is None: fused_marlin_moe( hidden_states=hidden_states, w1=w1, @@ -764,7 +760,6 @@ def apply( # intermediate_cache1 is indexed by flat (token, expert) pair index, # which is compatible with add_lora_fused_moe's scatter mechanism. - ctx = lora_context M = hidden_states.size(0) top_k_num = topk_ids.size(1) lora_state: dict = {} @@ -932,7 +927,6 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, - lora_context: "MoELoRAContext | None" = None, ): assert expert_tokens_meta is not None, "Num valid tokens per batch is required" return batched_fused_marlin_moe( diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 2b27853a8bd8..64b460aff55d 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -6,10 +6,7 @@ import json import os from collections.abc import Callable -from typing import TYPE_CHECKING, Any - -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext +from typing import Any import torch @@ -2003,7 +2000,6 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, - lora_context: "MoELoRAContext | None" = None, ): # Check constraints. if self.quant_config.use_int4_w4a16: @@ -2107,6 +2103,7 @@ def apply( expert_ids_lora = None num_tokens_post_padded_lora = None token_lora_mapping = None + lora_context = self._lora_context if lora_context is not None: ( sorted_token_ids_lora, @@ -2246,7 +2243,6 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, - lora_context: "MoELoRAContext | None" = None, ): # Check constraints. 
if self.quant_config.use_int4_w4a16: diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py index af71ca6bf254..a239dfea92e4 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py @@ -2,14 +2,10 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import abstractmethod -from typing import TYPE_CHECKING import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk - -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, @@ -165,7 +161,6 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py index 0f4f7bdab316..142e180786c6 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py @@ -2,14 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING - import torch from vllm.logger import init_logger - -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, @@ -96,7 +91,6 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: assert self.moe_kernel is not None return self.moe_kernel.apply( @@ -110,5 +104,4 @@ def apply( apply_router_weight_on_input=layer.apply_router_weight_on_input, expert_map=None if self.disable_expert_map else layer.expert_map, shared_experts_input=shared_experts_input, - lora_context=lora_context, ) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index d4f9efba0a8a..7adac0374cf8 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -3,10 +3,7 @@ from collections.abc import Callable, Iterable from enum import Enum -from typing import TYPE_CHECKING, Literal, cast, get_args, overload - -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext +from typing import Literal, cast, get_args, overload import torch from torch.nn.parameter import UninitializedParameter @@ -610,12 +607,6 @@ def _get_quant_method() -> FusedMoEMethodBase: else 1.0, ) - # Set permanently by FusedMoEWithLoRA.set_mapping() when the active - # kernel supports native LoRA (supports_lora() is True). - # FusedMoEModularMethod.apply() reads this and passes it down to - # FusedMoEKernel → FusedMoEExpertsModular.apply(). - self._lora_context: MoELoRAContext | None = None - # TODO(bnell): This method is provided as a hook so vllm/lora/layers/fused_moe.py # can safely swap out the quant_method. We should figure out a less # intrusive way to do this. 
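Usage sketch (illustrative only, not part of the patch): the snippet below shows how an experts backend could opt into native MoE LoRA after this change, assuming the lora_experts_mixin.py module added above. `MyExperts` and `_expert_gemms` are hypothetical names; real backends subclass mk.FusedMoEExpertsModular and implement the full apply() signature from modular_kernel.py, as TritonExperts and MarlinExperts do in this series.

# Hedged sketch: how a backend adopts LoRAExpertsMixin and consumes the
# per-forward MoELoRAContext.  Names marked "hypothetical" are not in the patch.
import torch

from vllm.model_executor.layers.fused_moe.lora_experts_mixin import LoRAExpertsMixin


class MyExperts(LoRAExpertsMixin):  # hypothetical backend, for illustration
    """Mixing in LoRAExpertsMixin flips supports_lora() to True and adds
    set_lora_context() / self._lora_context, plus the apply_w13_lora /
    apply_w2_lora helpers that dispatch to the PunicaWrapper kernels."""

    def apply(self, hidden_states: torch.Tensor) -> torch.Tensor:
        out = self._expert_gemms(hidden_states)  # placeholder for routed-expert GEMMs
        ctx = self._lora_context                 # None unless LoRA is active
        if ctx is not None:
            # A real backend would call self.apply_w13_lora(...) and
            # self.apply_w2_lora(...) here, passing its intermediate caches
            # and routing tensors alongside the stacked LoRA weights in ctx.
            pass
        return out

    def _expert_gemms(self, x: torch.Tensor) -> torch.Tensor:
        return x  # stand-in for the real fused-expert computation


# Wiring, per FusedMoEWithLoRA.set_mapping() in this patch: once the punica
# mapping is known, the LoRA layer pushes the context onto the experts object,
#     moe_kernel.fused_experts.set_lora_context(self._build_lora_context())
# and each apply() call then reads self._lora_context instead of receiving a
# lora_context argument through the modular-kernel call chain.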
diff --git a/vllm/model_executor/layers/fused_moe/lora_experts_mixin.py b/vllm/model_executor/layers/fused_moe/lora_experts_mixin.py index 3501251fd4a8..859a7db00fc6 100644 --- a/vllm/model_executor/layers/fused_moe/lora_experts_mixin.py +++ b/vllm/model_executor/layers/fused_moe/lora_experts_mixin.py @@ -14,6 +14,8 @@ class LoRAExpertsMixin: Mixing this class in: - Flips supports_lora() to True so _can_fused_experts_support lets LoRA through the gate check. + - Stashes a MoELoRAContext on the experts instance via + set_lora_context(), which apply() consumes from self._lora_context. - Provides apply_w13_lora / apply_w2_lora helpers that dispatch to the PunicaWrapper kernels. @@ -21,6 +23,11 @@ class LoRAExpertsMixin: must be mixed into a FusedMoEExperts subclass. """ + _lora_context: MoELoRAContext | None = None + + def set_lora_context(self, ctx: MoELoRAContext) -> None: + self._lora_context = ctx + @staticmethod def supports_lora() -> bool: return True diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 94322ed29aa4..56bdcf70ab95 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -5,13 +5,10 @@ from dataclasses import dataclass from enum import Enum from math import prod -from typing import TYPE_CHECKING, final +from typing import final import torch -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext - import vllm.envs as envs from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.activation import ( @@ -741,12 +738,10 @@ def g2_alphas(self) -> torch.Tensor | None: @staticmethod def supports_lora() -> bool: - """Return True if this expert impl natively handles MoELoRAContext. + """Return True if this expert impl natively handles LoRA. - When True, FusedMoEWithLoRA will propagate a MoELoRAContext through - FusedMoEKernel.apply() instead of using the legacy decorator injection. - Subclasses that inline the LoRA computation inside apply() must override - this to return True. + LoRA-aware experts should mix in LoRAExpertsMixin, which flips this + to True and provides the per-forward LoRA state plumbing. """ return False @@ -912,7 +907,6 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: ExpertTokensMetadata | None, apply_router_weight_on_input: bool, - lora_context: "MoELoRAContext | None" = None, ) -> None: """ This function computes the intermediate result of a Mixture of Experts @@ -1221,7 +1215,6 @@ def _fused_experts( expert_map: torch.Tensor | None, apply_router_weight_on_input: bool, expert_tokens_meta: ExpertTokensMetadata | None, - lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: _, M_full, N, K, top_k = self.fused_experts.moe_problem_size( a1q, w1, w2, topk_ids @@ -1266,7 +1259,6 @@ def _fused_experts( workspace2=workspace2, expert_tokens_meta=expert_tokens_meta, apply_router_weight_on_input=apply_router_weight_on_input, - lora_context=lora_context, ) return fused_out @@ -1349,7 +1341,6 @@ def apply( expert_map: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, shared_experts_input: torch.Tensor | None = None, - lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets @@ -1374,8 +1365,6 @@ def apply( - shared_experts_input (Optional[torch.Tensor]): Optional separate input for shared experts. 
For latent MoE, this is the original hidden_states before latent projection. - - lora_context (Optional[MoELoRAContext]): LoRA context to propagate to - fused_experts.apply() when the expert backend supports native LoRA. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -1414,7 +1403,6 @@ def apply( expert_map=expert_map, apply_router_weight_on_input=apply_router_weight_on_input, expert_tokens_meta=expert_tokens_meta, - lora_context=lora_context, ) return self._finalize( @@ -1618,7 +1606,6 @@ def apply( expert_map: torch.Tensor | None, apply_router_weight_on_input: bool, shared_experts_input: torch.Tensor | None = None, - lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: assert isinstance(self.impl, FusedMoEKernelModularImpl) return self.impl.apply( @@ -1632,5 +1619,4 @@ def apply( expert_map=expert_map, apply_router_weight_on_input=apply_router_weight_on_input, shared_experts_input=shared_experts_input, - lora_context=lora_context, ) diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index ce9da37cfc91..d24bda101ffa 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -2,13 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from enum import IntEnum from functools import lru_cache -from typing import TYPE_CHECKING import torch -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext - import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm._aiter_ops import rocm_aiter_ops from vllm.model_executor.layers.fused_moe.activation import MoEActivation @@ -416,7 +412,6 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, - lora_context: "MoELoRAContext | None" = None, ): # TODO(rob): rocm_aiter_fused_experts uses self.quant_config's # a_scales for static quantization. 
Update this to fit better diff --git a/vllm/model_executor/layers/fused_moe/runner/moe_runner.py b/vllm/model_executor/layers/fused_moe/runner/moe_runner.py index 98b0821eb4a6..00be12780a16 100644 --- a/vllm/model_executor/layers/fused_moe/runner/moe_runner.py +++ b/vllm/model_executor/layers/fused_moe/runner/moe_runner.py @@ -459,7 +459,6 @@ def _apply_quant_method( topk_weights=topk_weights, topk_ids=topk_ids, shared_experts_input=shared_experts_input, - lora_context=getattr(layer, "_lora_context", None), ) self._maybe_apply_shared_experts( diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py index 2c18a7527fe5..e33111aa0ab2 100644 --- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py +++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py @@ -2,12 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Callable -from typing import TYPE_CHECKING import torch - -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext import torch.nn.functional as F from torch.nn import Module @@ -261,7 +257,6 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: return self.forward( layer=layer, @@ -269,8 +264,7 @@ def apply( topk_weights=topk_weights, topk_ids=topk_ids, shared_experts_input=shared_experts_input, - lora_context=lora_context, - ) # CustomOp.forward uses *args/**kwargs, lora_context passes through + ) def forward_native( self, @@ -279,7 +273,6 @@ def forward_native( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: assert self.moe_kernel is not None return self.moe_kernel.apply( @@ -293,7 +286,6 @@ def forward_native( global_num_experts=layer.global_num_experts, expert_map=layer.expert_map, shared_experts_input=shared_experts_input, - lora_context=lora_context, ) def forward_cuda( @@ -303,7 +295,6 @@ def forward_cuda( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: return self.forward_native( layer, @@ -311,7 +302,6 @@ def forward_cuda( topk_weights, topk_ids, shared_experts_input, - lora_context=lora_context, ) def apply_monolithic( diff --git a/vllm/model_executor/layers/fused_moe/xpu_fused_moe.py b/vllm/model_executor/layers/fused_moe/xpu_fused_moe.py index cc00bca9a169..1ce3cd5083c0 100644 --- a/vllm/model_executor/layers/fused_moe/xpu_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/xpu_fused_moe.py @@ -1,12 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING import torch -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext - import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( @@ -132,7 +128,6 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, - lora_context: "MoELoRAContext | None" = None, ): topk = topk_ids.size(-1) xpu_fused_moe( diff --git 
a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index f533a031a18f..cfad1f86faa2 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -59,7 +59,6 @@ from vllm.transformers_utils.config import get_safetensors_params_metadata if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.models.utils import WeightsMapper @@ -818,7 +817,6 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: return fused_marlin_moe( x, diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index 5e86fd8a1c06..729924663646 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -1,10 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, Any, Union - -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext +from typing import Any, Union import torch from packaging import version @@ -486,7 +483,6 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe import fused_experts diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_mxfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_mxfp4.py index d1b78a86130c..57ebb961d487 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_mxfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_mxfp4.py @@ -2,13 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING - import torch -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext - import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( @@ -209,7 +204,6 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: assert self.moe_kernel is not None return self.moe_kernel.apply( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_nvfp4.py index f85c37852eac..09a216fd2cb1 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_nvfp4.py @@ -2,13 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import 
TYPE_CHECKING - import torch -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext - import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( @@ -295,7 +290,6 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: assert self.moe_kernel is not None return self.moe_kernel.apply( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a8_fp8.py index f4075780b8b6..ab805591deee 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a8_fp8.py @@ -2,12 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING - import torch - -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext from compressed_tensors.quantization import ( QuantizationArgs, ) @@ -313,7 +308,6 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: if layer.enable_eplb: raise NotImplementedError( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_fp8.py index c0b96b9b0fc5..ed8ed79c50c6 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_fp8.py @@ -2,12 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING - import torch - -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext from compressed_tensors.quantization import ( QuantizationArgs, QuantizationStrategy, @@ -396,7 +391,6 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: assert not self.is_monolithic assert self.moe_kernel is not None diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_int8.py index 3835a813d813..bad5b3895b8f 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_int8.py @@ -2,12 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING - import torch - -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext from compressed_tensors.quantization import ( 
QuantizationArgs, QuantizationStrategy, @@ -181,7 +176,6 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: assert not self.is_monolithic assert self.moe_kernel is not None diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_mxfp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_mxfp8.py index 34172049120e..97970592c8ab 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_mxfp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_mxfp8.py @@ -1,13 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING import torch -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext - import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe import ( FusedMoE, @@ -197,7 +193,6 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: assert not self.is_monolithic assert self.moe_kernel is not None diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_wna16.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_wna16.py index f0d79599e7b1..f530a1a1df2b 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_wna16.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_wna16.py @@ -2,12 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING - import torch - -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext from compressed_tensors.quantization import ( QuantizationArgs, ) @@ -250,7 +245,6 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe import fused_experts diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_wna16_marlin.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_wna16_marlin.py index df8dca28bdbd..216eed6372a9 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_wna16_marlin.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_wna16_marlin.py @@ -3,12 +3,8 @@ import enum from enum import Enum -from typing import TYPE_CHECKING import torch - -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext from compressed_tensors.quantization import ( QuantizationArgs, ) @@ -549,7 +545,6 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - lora_context: "MoELoRAContext | 
None" = None, ) -> torch.Tensor: assert self.kernel_backend == "Marlin" return fused_marlin_moe( diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 2b1e2bf8e2a7..0dc8907248ef 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -86,7 +86,6 @@ ) if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext from vllm.model_executor.models.utils import WeightsMapper ACTIVATION_SCHEMES = ["static", "dynamic"] @@ -892,7 +891,6 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: assert not self.is_monolithic assert self.moe_kernel is not None diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index 3dbbfabd8ad9..61eb6c912a11 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext from vllm.model_executor.layers.quantization import QuantizationMethods import gguf @@ -651,7 +650,6 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: if layer.apply_router_weight_on_input: raise NotImplementedError( diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 196437afcfd1..1ca551d6351b 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -2,10 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from copy import deepcopy -from typing import TYPE_CHECKING, Any - -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext +from typing import Any import torch from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE @@ -910,7 +907,6 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: return fused_marlin_moe( x, diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 7c2280474bd2..852ed1a10a34 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -96,7 +96,6 @@ from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext from vllm.model_executor.models.utils import WeightsMapper logger = init_logger(__name__) @@ -976,7 +975,6 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: assert not self.is_monolithic assert self.moe_kernel is not None @@ -1469,7 +1467,6 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: assert not self.is_monolithic assert self.moe_kernel is not None @@ -2008,7 +2005,6 @@ def apply( 
topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert not self.is_monolithic raise NotImplementedError( diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index 56e47a27f1ac..e5ef3f4c3168 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -1,10 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, Any - -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext +from typing import Any import torch @@ -372,7 +369,6 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe import fused_experts diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 4aa780fb513b..395158c911a8 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -1,13 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING import torch -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext - from vllm.config import get_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.attention import Attention @@ -423,7 +419,6 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: assert not self.is_monolithic assert self.moe_kernel is not None @@ -438,7 +433,6 @@ def apply( apply_router_weight_on_input=layer.apply_router_weight_on_input, expert_map=layer.expert_map, shared_experts_input=shared_experts_input, - lora_context=lora_context, ) def apply_monolithic( diff --git a/vllm/model_executor/layers/quantization/online/moe_base.py b/vllm/model_executor/layers/quantization/online/moe_base.py index 7d3f63637900..25c3359ee8be 100644 --- a/vllm/model_executor/layers/quantization/online/moe_base.py +++ b/vllm/model_executor/layers/quantization/online/moe_base.py @@ -2,12 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import abstractmethod -from typing import TYPE_CHECKING import torch -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig @@ -158,7 +155,6 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert not self.is_monolithic assert self.moe_kernel is not None diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 61cbd6781e6b..64753a173dfe 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py 
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -1,10 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, Any - -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext +from typing import Any import torch @@ -459,7 +456,6 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: if self.rocm_aiter_moe_enabled: from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( @@ -774,7 +770,6 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: from vllm.model_executor.layers.fused_moe import fused_experts @@ -927,7 +922,6 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( rocm_aiter_fused_experts, @@ -1424,7 +1418,6 @@ def apply( topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - lora_context: "MoELoRAContext | None" = None, ) -> torch.Tensor: # For oracle kernel or emulation kernel if self.moe_kernel is not None: From 6fe6601188f8553b0a52cd5e262a58fcbd3db288 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 23 Apr 2026 02:23:10 +0000 Subject: [PATCH 14/24] FMT Signed-off-by: Jee Jee Li --- vllm/model_executor/layers/fused_moe/experts/deep_gemm_moe.py | 1 - .../layers/fused_moe/experts/flashinfer_cutedsl_batched_moe.py | 1 - .../layers/fused_moe/experts/flashinfer_cutedsl_moe.py | 1 - vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py | 1 - vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py | 1 - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py | 1 - vllm/model_executor/layers/fused_moe/xpu_fused_moe.py | 1 - .../compressed_tensors_moe/compressed_tensors_moe_w8a8_mxfp8.py | 1 - vllm/model_executor/layers/quantization/mxfp4.py | 1 - 9 files changed, 9 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/experts/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/experts/deep_gemm_moe.py index ac233f33f4f3..03341378a13c 100644 --- a/vllm/model_executor/layers/fused_moe/experts/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/deep_gemm_moe.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk diff --git a/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_batched_moe.py b/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_batched_moe.py index 3ca4c2e2892f..5eaaf46739fc 100644 --- a/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_batched_moe.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk diff --git a/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_moe.py 
b/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_moe.py index f1c1ce2e1958..5ce58220b073 100644 --- a/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_moe.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py index a34cf0add9d4..1f0258fb657f 100644 --- a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py index 8722cbbdf4c7..d084283360c4 100644 --- a/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 370bb4536fde..26409804c48d 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk diff --git a/vllm/model_executor/layers/fused_moe/xpu_fused_moe.py b/vllm/model_executor/layers/fused_moe/xpu_fused_moe.py index 1ce3cd5083c0..e10be4af8680 100644 --- a/vllm/model_executor/layers/fused_moe/xpu_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/xpu_fused_moe.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_mxfp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_mxfp8.py index 97970592c8ab..02e946b1b61e 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_mxfp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_mxfp8.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 395158c911a8..019bb45d65dc 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -1,7 +1,6 @@ 
# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - import torch from vllm.config import get_current_vllm_config From 99c00a248d87b6b88be07d8493b80d64248153cb Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 23 Apr 2026 03:32:51 +0000 Subject: [PATCH 15/24] FMT Signed-off-by: Jee Jee Li --- vllm/lora/layers/utils.py | 9 +++++---- vllm/lora/punica_wrapper/punica_base.py | 4 ---- vllm/lora/punica_wrapper/punica_gpu.py | 6 ------ .../layers/fused_moe/lora_experts_mixin.py | 6 ++---- 4 files changed, 7 insertions(+), 18 deletions(-) diff --git a/vllm/lora/layers/utils.py b/vllm/lora/layers/utils.py index c19b097586f5..1b8083f5c4d1 100644 --- a/vllm/lora/layers/utils.py +++ b/vllm/lora/layers/utils.py @@ -90,11 +90,12 @@ def try_get_optimal_moe_lora_config( top_k: int, dtype: str | None, M: int, - block_shape: list[int] | None = None, ) -> dict[str, int | None]: - config = try_get_optimal_moe_config( - w1_shape, w2_shape, top_k, dtype, M, block_shape - ).copy() + # LoRA shrink/expand operates on bf16/fp16 adapters regardless of the + # base MoE weight's block-wise quantization, so block_shape is omitted + # from the config lookup — the non-quantized branch in get_default_config + # ignores it anyway. + config = try_get_optimal_moe_config(w1_shape, w2_shape, top_k, dtype, M).copy() if op_type in [ "fused_moe_lora_w13_shrink", "fused_moe_lora_w2_shrink", diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py index 546d82fa285b..4ab66dccdc29 100644 --- a/vllm/lora/punica_wrapper/punica_base.py +++ b/vllm/lora/punica_wrapper/punica_base.py @@ -514,8 +514,6 @@ def add_lora_w13( num_slices: int, fully_sharded: bool, use_tuned_config: bool, - *, - block_shape: list[int] | None = None, ) -> tuple[ torch.Tensor | None, torch.Tensor | None, @@ -551,8 +549,6 @@ def add_lora_w2( fully_sharded: bool, tp_rank: int, use_tuned_config: bool, - *, - block_shape: list[int] | None = None, ) -> None: """Apply w2 LoRA to y (intermediate_cache3) in-place before moe_sum. 
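
The comment added to try_get_optimal_moe_lora_config above argues that block_shape can be dropped because the LoRA shrink/expand GEMMs run on bf16/fp16 adapter weights, so the non-quantized config branch never reads it. A minimal sketch of that argument, using a hypothetical pick_tile_config standing in for the non-quantized branch of get_default_config (names and tile values are illustrative, not vLLM's actual tables):

```python
# Editor's illustration, not part of the patch: the non-quantized config
# branch accepts block_shape but never reads it, so omitting the argument
# from the LoRA lookup cannot change the selected tile sizes.
def pick_tile_config(M: int, top_k: int, rank: int,
                     block_shape: list[int] | None = None) -> dict[str, int]:
    # bf16/fp16 adapter GEMM: tiles depend only on the problem shape.
    if M * top_k <= 256:
        return {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64}
    return {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}


# Same result with or without a block-quantization shape.
assert pick_tile_config(128, 8, 16) == pick_tile_config(128, 8, 16, [128, 128])
```
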
diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index 0b28a557e260..44d1dbd50728 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -480,8 +480,6 @@ def add_lora_w13( num_slices: int, fully_sharded: bool, use_tuned_config: bool, - *, - block_shape: list[int] | None = None, ) -> tuple[ torch.Tensor | None, torch.Tensor | None, @@ -533,7 +531,6 @@ def add_lora_w13( top_k=top_k, dtype=config_dtype, M=num_tokens, - block_shape=block_shape, ) shrink_config = get_config(op_type="fused_moe_lora_w13_shrink") expand_config = get_config(op_type="fused_moe_lora_w13_expand") @@ -615,8 +612,6 @@ def add_lora_w2( fully_sharded: bool, tp_rank: int, use_tuned_config: bool, - *, - block_shape: list[int] | None = None, ) -> None: import functools @@ -663,7 +658,6 @@ def add_lora_w2( top_k=top_k, dtype=config_dtype, M=num_tokens, - block_shape=block_shape, ) shrink_config = get_config(op_type="fused_moe_lora_w2_shrink") expand_config = get_config(op_type="fused_moe_lora_w2_expand") diff --git a/vllm/model_executor/layers/fused_moe/lora_experts_mixin.py b/vllm/model_executor/layers/fused_moe/lora_experts_mixin.py index 859a7db00fc6..c609c5cf56b5 100644 --- a/vllm/model_executor/layers/fused_moe/lora_experts_mixin.py +++ b/vllm/model_executor/layers/fused_moe/lora_experts_mixin.py @@ -19,8 +19,8 @@ class LoRAExpertsMixin: - Provides apply_w13_lora / apply_w2_lora helpers that dispatch to the PunicaWrapper kernels. - Mixin callers rely on self.block_shape from FusedMoEExperts, so this - must be mixed into a FusedMoEExperts subclass. + The helper methods are pure functions of their inputs; all required + state is on lora_context or passed as arguments. """ _lora_context: MoELoRAContext | None = None @@ -70,7 +70,6 @@ def apply_w13_lora( lora_context.w13_num_slices, lora_context.fully_sharded, lora_context.use_tuned_config, - block_shape=self.block_shape, ) def apply_w2_lora( @@ -109,5 +108,4 @@ def apply_w2_lora( lora_context.fully_sharded, lora_context.tp_rank, lora_context.use_tuned_config, - block_shape=self.block_shape, ) From 94958721a46624521fe2058317d7ffbf7dad397e Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 23 Apr 2026 16:43:20 +0000 Subject: [PATCH 16/24] Init Signed-off-by: Jee Jee Li --- tests/lora/test_gptoss_tp.py | 1 + tests/lora/test_qwen3moe_tp.py | 14 ++--- vllm/lora/layers/fused_moe.py | 48 +++++++++++++++-- vllm/lora/model_manager.py | 28 ++++++---- vllm/lora/punica_wrapper/punica_base.py | 5 ++ vllm/lora/punica_wrapper/punica_gpu.py | 52 ++++++++++++++----- .../layers/fused_moe/lora_context.py | 7 +++ .../layers/fused_moe/lora_experts_mixin.py | 1 + .../fused_moe/prepare_finalize/naive_dp_ep.py | 48 +++++++++++++++-- 9 files changed, 169 insertions(+), 35 deletions(-) diff --git a/tests/lora/test_gptoss_tp.py b/tests/lora/test_gptoss_tp.py index 68dd87233ac0..648660734655 100644 --- a/tests/lora/test_gptoss_tp.py +++ b/tests/lora/test_gptoss_tp.py @@ -129,6 +129,7 @@ def test_gpt_oss_lora_tp2( tensor_parallel_size=2, gpu_memory_utilization=0.8, fully_sharded_loras=fully_sharded_loras, + enable_expert_parallel=not fully_sharded_loras, compilation_config=vllm.config.CompilationConfig( # Avoid OOM cudagraph_specialize_lora=False, ), diff --git a/tests/lora/test_qwen3moe_tp.py b/tests/lora/test_qwen3moe_tp.py index fcac4275cc40..9af142f6f388 100644 --- a/tests/lora/test_qwen3moe_tp.py +++ b/tests/lora/test_qwen3moe_tp.py @@ -5,6 +5,8 @@ # NOTE To avoid overloading the CI pipeline, this test 
script will not # be triggered on CI and is primarily intended for local testing and verification. +import pytest + import vllm from vllm.lora.request import LoRARequest @@ -82,15 +84,15 @@ def test_qwen3moe_lora(qwen3moe_lora_files): @multi_gpu_test(num_gpus=2) -def test_qwen3moe_lora_tp2(qwen3moe_lora_files): +@pytest.mark.parametrize("ep", [False, True]) +def test_qwen3moe_lora_tp2(ep, qwen3moe_lora_files): llm = vllm.LLM( MODEL_PATH, max_model_len=1024, enable_lora=True, max_loras=4, - enforce_eager=True, trust_remote_code=True, - enable_chunked_prefill=True, + enable_expert_parallel=ep, tensor_parallel_size=2, ) @@ -99,15 +101,15 @@ def test_qwen3moe_lora_tp2(qwen3moe_lora_files): @multi_gpu_test(num_gpus=4) -def test_qwen3moe_lora_tp4(qwen3moe_lora_files): +@pytest.mark.parametrize("ep", [False, True]) +def test_qwen3moe_lora_tp4(ep, qwen3moe_lora_files): llm = vllm.LLM( MODEL_PATH, max_model_len=1024, enable_lora=True, max_loras=4, - enforce_eager=True, trust_remote_code=True, - enable_chunked_prefill=True, + enable_expert_parallel=ep, tensor_parallel_size=4, ) diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index c1542f1466cf..7598a029bc5d 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -34,9 +34,13 @@ def __init__(self, base_layer: FusedMoE) -> None: super().__init__() self.base_layer = base_layer - assert not self.base_layer.use_ep, ( - "EP support for Fused MoE LoRA is not implemented yet." - ) + if self.base_layer.use_ep: + moe_config = self.base_layer.moe_config + all2all_backend = moe_config.moe_parallel_config.all2all_backend + assert all2all_backend == "allgather_reducescatter", ( + "Fused MoE LoRA with EP currently only supports " + f"all2all_backend='allgather_reducescatter', got '{all2all_backend}'." + ) assert not self.base_layer.quant_method.is_monolithic, ( "Monolithic kernels are not supported for Fused MoE LoRA." ) @@ -72,6 +76,7 @@ def __init__(self, base_layer: FusedMoE) -> None: "and consume self._lora_context in apply()." ) self._fused_experts = moe_kernel.fused_experts + self._moe_kernel = moe_kernel self.base_layer._replace_quant_method( FusedMoEModularMethod(self.base_layer.quant_method, moe_kernel) ) @@ -156,6 +161,16 @@ def _create_lora_b_weights(self, max_loras: int, lora_config: LoRAConfig): ), ) + def _verify_ep_fs(self, lora_config): + # EP and fully_sharded LoRA both partition along the same TP group — + # EP on the expert dim, fully_sharded on the LoRA rank dim — with + # mutually contradictory assumptions about which rank holds which + # expert's rank-shard. + assert not (self.base_layer.use_ep and lora_config.fully_sharded_loras), ( + "Fused MoE LoRA does not support enable_expert_parallel=True " + "together with fully_sharded_loras=True. Disable one of them." + ) + def create_lora_weights( self, max_loras: int, @@ -163,6 +178,8 @@ def create_lora_weights( model_config: PretrainedConfig | None = None, ) -> None: """Initializes lora matrices.""" + + self._verify_ep_fs(self, lora_config) self.max_loras = lora_config.max_loras self.fully_sharded = lora_config.fully_sharded_loras @@ -288,6 +305,24 @@ def set_lora( w1_lora_a, w2_lora_a, w3_lora_a = lora_a w1_lora_b, w2_lora_b, w3_lora_b = lora_b + + # Under EP the adapter tensors carry all global experts; slice this + # rank's owned range so downstream shapes line up with local buffers. 
+ global_num_experts = self.base_layer.global_num_experts + ep_rank = self.base_layer.ep_rank + if ( + w1_lora_a.shape[0] == global_num_experts + and num_experts != global_num_experts + ): + expert_start = ep_rank * num_experts + expert_end = expert_start + num_experts + w1_lora_a = w1_lora_a[expert_start:expert_end] + w2_lora_a = w2_lora_a[expert_start:expert_end] + w3_lora_a = w3_lora_a[expert_start:expert_end] + w1_lora_b = w1_lora_b[expert_start:expert_end] + w2_lora_b = w2_lora_b[expert_start:expert_end] + w3_lora_b = w3_lora_b[expert_start:expert_end] + assert ( num_experts == w1_lora_a.shape[0] @@ -332,7 +367,11 @@ def set_lora( def set_mapping(self, punica_wrapper): super().set_mapping(punica_wrapper) - self._fused_experts.set_lora_context(self._build_lora_context()) + lora_context = self._build_lora_context() + self._fused_experts.set_lora_context(lora_context) + prepare_finalize = self._moe_kernel.prepare_finalize + if hasattr(prepare_finalize, "set_lora_context"): + prepare_finalize.set_lora_context(lora_context) def forward(self, *args, **kwargs): return self.base_layer.forward(*args, **kwargs) @@ -402,6 +441,7 @@ def create_lora_weights( """Initializes lora matrices.""" assert isinstance(model_config, PretrainedConfig) + self._verify_ep_fs(self, lora_config) self._base_model = model_config.architectures[0] self.max_loras = lora_config.max_loras self.fully_sharded = lora_config.fully_sharded_loras diff --git a/vllm/lora/model_manager.py b/vllm/lora/model_manager.py index 52ff8ebc91f3..0d399d1904d2 100644 --- a/vllm/lora/model_manager.py +++ b/vllm/lora/model_manager.py @@ -762,23 +762,33 @@ def _stack_moe_lora_weights( assert gate_up_proj_lora is not None assert down_proj_lora is not None if self._is_3d_moe_model: - num_experts = module.w13_lora_a_stacked[0].shape[1] + local_num_experts = module.w13_lora_a_stacked[0].shape[1] + # The checkpoint holds weights for all global experts, but + # each EP rank owns only local_num_experts. Reshape against + # the adapter's actual expert count, then slice this rank's + # owned expert range before it gets copied into the local + # stacked buffer. For non-EP (local == global) this is a + # no-op slice. 
+ global_num_experts = module.base_layer.global_num_experts + ep_rank = module.base_layer.ep_rank + expert_start = ep_rank * local_num_experts + expert_end = expert_start + local_num_experts # (num_experts,rank,input_size) gate_up_proj_lora.lora_a = gate_up_proj_lora.lora_a.reshape( - num_experts, -1, gate_up_proj_lora.lora_a.shape[-1] - ) + global_num_experts, -1, gate_up_proj_lora.lora_a.shape[-1] + )[expert_start:expert_end].contiguous() down_proj_lora.lora_a = down_proj_lora.lora_a.reshape( - num_experts, -1, down_proj_lora.lora_a.shape[-1] - ) + global_num_experts, -1, down_proj_lora.lora_a.shape[-1] + )[expert_start:expert_end].contiguous() # (output_size,rank,num_experts) gate_up_proj_lora.lora_b = gate_up_proj_lora.lora_b.reshape( - gate_up_proj_lora.lora_b.shape[0], -1, num_experts - ) + gate_up_proj_lora.lora_b.shape[0], -1, global_num_experts + )[..., expert_start:expert_end] down_proj_lora.lora_b = down_proj_lora.lora_b.reshape( - down_proj_lora.lora_b.shape[0], -1, num_experts - ) + down_proj_lora.lora_b.shape[0], -1, global_num_experts + )[..., expert_start:expert_end] # (num_experts,output_size,rank) gate_up_proj_lora.lora_b = gate_up_proj_lora.lora_b.permute( diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py index 4ab66dccdc29..0448a6d00cda 100644 --- a/vllm/lora/punica_wrapper/punica_base.py +++ b/vllm/lora/punica_wrapper/punica_base.py @@ -514,6 +514,7 @@ def add_lora_w13( num_slices: int, fully_sharded: bool, use_tuned_config: bool, + token_lora_mapping: torch.Tensor | None = None, ) -> tuple[ torch.Tensor | None, torch.Tensor | None, @@ -522,6 +523,10 @@ def add_lora_w13( ]: """Apply w13 LoRA to y (intermediate_cache1) in-place before activation. + When `token_lora_mapping` is provided it overrides the punica_wrapper's + global mapping — used by EP+LoRA to pass the per-rank-local mapping + after all-to-all dispatch. + Returns (sorted_token_ids_lora, expert_ids_lora, num_tokens_post_padded_lora, token_lora_mapping) for reuse by add_lora_w2. diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index 44d1dbd50728..bf951e074949 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -335,25 +335,49 @@ def moe_lora_align_block_size( expert_map: torch.Tensor | None = None, pad_sorted_ids: bool = False, naive_block_assignment: bool = False, + token_lora_mapping: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Aligns tokens and experts into block-sized chunks for LoRA-based mixture-of-experts (MoE) execution. + + When `token_lora_mapping` is provided, it overrides the global mapping + read from `self.token_mapping_meta`. This is how EP+LoRA injects the + per-rank-local token→LoRA map after all-to-all dispatch. """ - (token_lora_mapping, _, _, _, lora_ids, _, _) = ( - self.token_mapping_meta.meta_args( - num_tokens, self.lora_config.specialize_active_lora - ) + ( + token_lora_mapping_meta, + _, + _, + _, + lora_ids, + _, + _, + ) = self.token_mapping_meta.meta_args( + num_tokens, self.lora_config.specialize_active_lora + ) + if token_lora_mapping is None: + token_lora_mapping = token_lora_mapping_meta + # Under EP the caller passes local_num_experts but topk_ids carries + # GLOBAL expert indices. The CUDA kernel uses num_experts to size + # its bucketing table; with EP we must size by global_num_experts + # so global topk_ids don't overflow. 
expert_map inside the kernel + # then translates global→local so the output expert_ids are local + # (mirrors the non-LoRA moe_align_block_size behavior). + kernel_num_experts = ( + expert_map.numel() if expert_map is not None else num_experts ) if naive_block_assignment: expert_ids = topk_ids.reshape(-1) sorted_ids = None num_tokens_post_pad = None else: - max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + max_num_tokens_padded = topk_ids.numel() + kernel_num_experts * ( + block_size - 1 + ) if pad_sorted_ids: max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) - if topk_ids.numel() < num_experts: + if topk_ids.numel() < kernel_num_experts: max_num_tokens_padded = topk_ids.numel() * block_size sorted_ids = torch.empty( (max_loras * max_num_tokens_padded,), @@ -361,9 +385,12 @@ def moe_lora_align_block_size( device=topk_ids.device, ) max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) - # Expert ids must be set default to -1 to prevent a blank block - expert_ids = torch.empty( + # Expert ids are initialized to -1 so unused (lora, expert) + # slots don't drive the LoRA Triton kernel into the wrong bucket. + # The kernel overwrites only active slots. + expert_ids = torch.full( (max_loras * max_num_m_blocks,), + -1, dtype=torch.int32, device=topk_ids.device, ) @@ -374,7 +401,7 @@ def moe_lora_align_block_size( ops.moe_lora_align_block_size( topk_ids, token_lora_mapping, - num_experts, + kernel_num_experts, block_size, max_loras, max_num_tokens_padded, @@ -384,11 +411,10 @@ def moe_lora_align_block_size( num_tokens_post_pad, adapter_enabled, lora_ids, + expert_map, ) - if expert_map is not None: - expert_ids = expert_map[expert_ids] - return None, sorted_ids, expert_ids, num_tokens_post_pad + return token_lora_mapping, sorted_ids, expert_ids, num_tokens_post_pad def add_lora_fused_moe( self, @@ -480,6 +506,7 @@ def add_lora_w13( num_slices: int, fully_sharded: bool, use_tuned_config: bool, + token_lora_mapping: torch.Tensor | None = None, ) -> tuple[ torch.Tensor | None, torch.Tensor | None, @@ -558,6 +585,7 @@ def add_lora_w13( adapter_enabled, expert_map, naive_block_assignment=naive_block_assignment, + token_lora_mapping=token_lora_mapping, ) _sorted = sorted_token_ids_lora diff --git a/vllm/model_executor/layers/fused_moe/lora_context.py b/vllm/model_executor/layers/fused_moe/lora_context.py index 92500a7bb47d..ab1f0bfc1476 100644 --- a/vllm/model_executor/layers/fused_moe/lora_context.py +++ b/vllm/model_executor/layers/fused_moe/lora_context.py @@ -42,3 +42,10 @@ class MoELoRAContext: # Whether VLLM_TUNED_CONFIG_FOLDER is set; selects get_lora_op_configs vs # try_get_optimal_moe_lora_config for Triton kernel tile configs. use_tuned_config: bool + + # Per-rank token→LoRA mapping after EP dispatch. Set by + # FusedMoEPrepareAndFinalizeModular.prepare() when EP+LoRA is active, read + # by LoRAExpertsMixin helpers in place of punica_wrapper's global mapping. + # None means no dispatch happened (non-EP path), in which case callers + # fall back to punica_wrapper.token_mapping_meta. 
+ local_token_lora_mapping: torch.Tensor | None = None diff --git a/vllm/model_executor/layers/fused_moe/lora_experts_mixin.py b/vllm/model_executor/layers/fused_moe/lora_experts_mixin.py index c609c5cf56b5..10707b91b70e 100644 --- a/vllm/model_executor/layers/fused_moe/lora_experts_mixin.py +++ b/vllm/model_executor/layers/fused_moe/lora_experts_mixin.py @@ -70,6 +70,7 @@ def apply_w13_lora( lora_context.w13_num_slices, lora_context.fully_sharded, lora_context.use_tuned_config, + token_lora_mapping=lora_context.local_token_lora_mapping, ) def apply_w2_lora( diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py index 2b21e2db9f68..27e09c267905 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py @@ -83,6 +83,14 @@ def __init__( super().__init__() self.is_sequence_parallel = is_sequence_parallel self._num_dispatchers = num_dispatchers + # Set by FusedMoEWithLoRA.set_mapping() when LoRA is active. When + # present, prepare() dispatches the per-token LoRA mapping alongside + # hidden_states and writes the gathered result back to the context so + # experts can use the per-rank-local mapping. + self._lora_context = None + + def set_lora_context(self, ctx) -> None: + self._lora_context = ctx @property def activation_format(self) -> mk.FusedMoEActivationFormat: @@ -123,22 +131,54 @@ def prepare( a1q, scales = _quantize_and_setup_dispatch(a1, quant_config, defer_input_quant) + # When LoRA is active, dispatch the per-token LoRA id along with + # hidden_states so every rank receives the correct mapping for the + # tokens it ends up processing. The punica_wrapper stores indices as + # int64 but the moe_lora_align_block_size kernel expects int32, so + # pull the pre-cast view from token_mapping_meta. 
+ lora_ctx = self._lora_context + local_token_lora_mapping = None + if lora_ctx is not None: + local_token_lora_mapping = ( + lora_ctx.punica_wrapper.token_mapping_meta.token_lora_mapping[ + : a1.shape[0] + ] + ) + + extra_tensors: list[torch.Tensor] | None = None + if scales is not None: + extra_tensors = list(scales) + if local_token_lora_mapping is not None: + if extra_tensors is None: + extra_tensors = [] + extra_tensors.append(local_token_lora_mapping) + res = get_ep_group().dispatch( a1q, topk_weights, topk_ids, is_sequence_parallel=self.is_sequence_parallel, - extra_tensors=scales, + extra_tensors=extra_tensors, ) - if scales is None: + if extra_tensors is None: assert len(res) == 3 a1q, topk_weights, topk_ids = res a1q_scale = None else: assert len(res) == 4 - a1q, topk_weights, topk_ids, scales = res - a1q_scale = _unwrap_scale_and_prepare_for_moe(scales, quant_config) + a1q, topk_weights, topk_ids, gathered_extras = res + gathered_extras = list(gathered_extras) + if local_token_lora_mapping is not None: + dispatched_lora_mapping = gathered_extras.pop() + assert lora_ctx is not None + lora_ctx.local_token_lora_mapping = dispatched_lora_mapping + if scales is not None: + a1q_scale = _unwrap_scale_and_prepare_for_moe( + gathered_extras, quant_config + ) + else: + a1q_scale = None return a1q, a1q_scale, None, topk_ids, topk_weights From 5c1fe18a222a4becc58daa10913dd461ebeb6fff Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Fri, 24 Apr 2026 12:52:20 +0000 Subject: [PATCH 17/24] Move Signed-off-by: Jee Jee Li --- vllm/lora/layers/fused_moe.py | 10 ++-------- vllm/model_executor/layers/fused_moe/modular_kernel.py | 3 +++ 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index c1542f1466cf..284ac54997fb 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -18,10 +18,7 @@ FusedMoEModularMethod, ) from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext -from vllm.model_executor.layers.fused_moe.modular_kernel import ( - FusedMoEExpertsModular, - FusedMoEKernel, -) +from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoDPEPModular, ) @@ -61,10 +58,7 @@ def __init__(self, base_layer: FusedMoE) -> None: prepare_finalize, self.base_layer ), ) - assert ( - isinstance(moe_kernel.fused_experts, FusedMoEExpertsModular) - and moe_kernel.fused_experts.supports_lora() - ), ( + assert moe_kernel.supports_lora(), ( f"{type(moe_kernel.fused_experts).__name__} does not support LoRA. " "For unquantized MoE, set moe_backend='triton' or moe_backend='auto' " "(auto selects Triton automatically when LoRA is enabled). 
" diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 56bdcf70ab95..b0f967085ae4 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -1538,6 +1538,9 @@ def prepare_finalize(self) -> FusedMoEPrepareAndFinalize: def fused_experts(self) -> FusedMoEExperts: return self.impl.fused_experts + def supports_lora(self) -> bool: + return self.fused_experts.supports_lora() + def _post_init_setup(self): """ Resolve any leftover setup dependencies between self.prepare_finalize From fe00d8ceac3adbb1d57a9321a1e6cdd395a666ef Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Fri, 24 Apr 2026 14:36:34 +0000 Subject: [PATCH 18/24] Move Signed-off-by: Jee Jee Li --- vllm/lora/layers/fused_moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index dafbba721de8..98434fb4c58f 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -173,7 +173,7 @@ def create_lora_weights( ) -> None: """Initializes lora matrices.""" - self._verify_ep_fs(self, lora_config) + self._verify_ep_fs(lora_config) self.max_loras = lora_config.max_loras self.fully_sharded = lora_config.fully_sharded_loras @@ -435,7 +435,7 @@ def create_lora_weights( """Initializes lora matrices.""" assert isinstance(model_config, PretrainedConfig) - self._verify_ep_fs(self, lora_config) + self._verify_ep_fs(lora_config) self._base_model = model_config.architectures[0] self.max_loras = lora_config.max_loras self.fully_sharded = lora_config.fully_sharded_loras From 400d6cd0d488043475f7dbd27856010df6b6a965 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Fri, 24 Apr 2026 16:19:44 +0000 Subject: [PATCH 19/24] Move Signed-off-by: Jee Jee Li --- vllm/lora/layers/fused_moe.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index 98434fb4c58f..de801580d670 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -7,10 +7,6 @@ from vllm import envs from vllm.config.lora import LoRAConfig -from vllm.distributed.parallel_state import ( - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, -) from vllm.distributed.utils import divide from vllm.lora.layers.base import BaseLayerWithLoRA from vllm.model_executor.layers.fused_moe import FusedMoE @@ -41,8 +37,11 @@ def __init__(self, base_layer: FusedMoE) -> None: assert not self.base_layer.quant_method.is_monolithic, ( "Monolithic kernels are not supported for Fused MoE LoRA." ) - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() + # Use the MoE-aware TP rank/size: when EP is active, FusedMoE collapses + # moe_parallel_config.tp_size to 1 (experts are sharded across the + # TP group instead). 
+ self.tp_size = self.base_layer.tp_size + self.tp_rank = self.base_layer.tp_rank self.device = _get_lora_device(base_layer) # For non-gated MoE (is_act_and_mul=False), only 1 slice is needed # since there's only up_proj (w1), not gate_proj + up_proj (w1 + w3) From 9d244b3ed3bc0af725a94bdad463739c60dcadd6 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Tue, 28 Apr 2026 08:39:03 +0800 Subject: [PATCH 20/24] Update vllm/lora/model_manager.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: ℍ𝕠𝕝𝕝𝕠𝕨 𝕄𝕒𝕟 Signed-off-by: Jee Jee Li --- vllm/lora/model_manager.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/lora/model_manager.py b/vllm/lora/model_manager.py index 0d399d1904d2..7f527cc8b2e7 100644 --- a/vllm/lora/model_manager.py +++ b/vllm/lora/model_manager.py @@ -562,6 +562,10 @@ def create_dummy_lora( else: parts = module_name.split(".") replacements = self.packed_modules_mapping[parts[-1]] + if module.__class__.__name__ == "FusedMoEWithLoRA": + replacements = replacements[ + : len(module.lora_a_stacked) // self.lora_slots + ] subloras: list[LoRALayerWeights | None] = [] for i, r in enumerate(replacements): lora = LoRALayerWeights.create_dummy_lora_weights( From a2640798bb53a134627ab53f10fc070086bc45d9 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Tue, 28 Apr 2026 00:52:04 +0000 Subject: [PATCH 21/24] FMT Signed-off-by: Jee Jee Li --- vllm/lora/model_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/lora/model_manager.py b/vllm/lora/model_manager.py index 7f527cc8b2e7..ca18c577557a 100644 --- a/vllm/lora/model_manager.py +++ b/vllm/lora/model_manager.py @@ -562,7 +562,7 @@ def create_dummy_lora( else: parts = module_name.split(".") replacements = self.packed_modules_mapping[parts[-1]] - if module.__class__.__name__ == "FusedMoEWithLoRA": + if module.__class__.__name__ == "FusedMoEWithLoRA": replacements = replacements[ : len(module.lora_a_stacked) // self.lora_slots ] From 6fa9268c7c522d7f9e836147e2a9d255f9c85859 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Wed, 29 Apr 2026 15:07:44 +0000 Subject: [PATCH 22/24] Move Signed-off-by: Jee Jee Li --- vllm/lora/layers/fused_moe.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index de801580d670..d48b8d3e1f8b 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -26,17 +26,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): def __init__(self, base_layer: FusedMoE) -> None: super().__init__() self.base_layer = base_layer - - if self.base_layer.use_ep: - moe_config = self.base_layer.moe_config - all2all_backend = moe_config.moe_parallel_config.all2all_backend - assert all2all_backend == "allgather_reducescatter", ( - "Fused MoE LoRA with EP currently only supports " - f"all2all_backend='allgather_reducescatter', got '{all2all_backend}'." - ) - assert not self.base_layer.quant_method.is_monolithic, ( - "Monolithic kernels are not supported for Fused MoE LoRA." - ) + self._ep_check() # Use the MoE-aware TP rank/size: when EP is active, FusedMoE collapses # moe_parallel_config.tp_size to 1 (experts are sharded across the # TP group instead). 
@@ -154,7 +144,17 @@ def _create_lora_b_weights(self, max_loras: int, lora_config: LoRAConfig): ), ) - def _verify_ep_fs(self, lora_config): + def _ep_check(self): + if self.base_layer.use_ep: + moe_config = self.base_layer.moe_config + all2all_backend = moe_config.moe_parallel_config.all2all_backend + assert all2all_backend == "allgather_reducescatter", ( + "Fused MoE LoRA with EP currently only supports " + f"all2all_backend='allgather_reducescatter', got '{all2all_backend}'." + ) + assert not moe_config.moe_parallel_config.is_sequence_parallel + + def _verify_ep_fs(self, lora_config: LoRAConfig): # EP and fully_sharded LoRA both partition along the same TP group — # EP on the expert dim, fully_sharded on the LoRA rank dim — with # mutually contradictory assumptions about which rank holds which From 4172be466812c854f53c7441af1709966183c0f2 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 30 Apr 2026 08:21:30 +0000 Subject: [PATCH 23/24] Move Signed-off-by: Jee Jee Li --- .../layers/fused_moe/oracle/int8.py | 3 --- .../layers/fused_moe/oracle/mxfp4.py | 23 ------------------- .../layers/fused_moe/oracle/mxfp8.py | 2 -- 3 files changed, 28 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/oracle/int8.py b/vllm/model_executor/layers/fused_moe/oracle/int8.py index cdb1be108b5d..ebdd20d54dc9 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/int8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/int8.py @@ -79,9 +79,6 @@ def select_int8_moe_backend( Note: Shape-specific fallbacks may still occur at runtime. """ - if config.is_lora_enabled: - return Int8MoeBackend.TRITON, backend_to_kernel_cls(Int8MoeBackend.TRITON)[0] - AVAILABLE_BACKENDS = _get_priority_backends(config) activation_format = ( diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py index f476d980d555..16a617f6c5a9 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py @@ -248,29 +248,6 @@ def select_gpt_oss_mxfp4_moe_backend( Select the primary MXFP4 MoE backend. Note: Shape-specific fallbacks may still occur at runtime. """ - device_capability = current_platform.get_device_capability() - triton_kernels_supported = ( - has_triton_kernels() - and device_capability is not None - and (9, 0) <= device_capability < (11, 0) - ) - - # LoRA: separate experts backend path - if config.is_lora_enabled: - if not current_platform.is_cuda(): - # ROCm: Triton mxfp4 LoRA hits GPU memory faults due to - # triton_kernels.tensor.Tensor / HIP read-only page issues - # during weight swizzle and LoRA forward. Needs work from - # the triton_kernels/aiter side. 
- raise NotImplementedError("Mxfp4 LoRA is currently only supported on CUDA.") - if envs.VLLM_MXFP4_USE_MARLIN is False and triton_kernels_supported: - logger.info_once("Using Triton backend for mxfp4 lora") - return Mxfp4MoeBackend.TRITON_UNFUSED, backend_to_kernel_cls( - Mxfp4MoeBackend.TRITON_UNFUSED - )[0] - logger.info_once("Using Marlin backend for mxfp4 lora") - return Mxfp4MoeBackend.MARLIN, backend_to_kernel_cls(Mxfp4MoeBackend.MARLIN)[0] - activation_format = ( mk.FusedMoEActivationFormat.BatchedExperts if config.moe_parallel_config.use_batched_activation_format diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py index c67def149b9d..8133902d519b 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py @@ -61,8 +61,6 @@ def select_mxfp8_moe_backend( Returns: A tuple of (fp8_backend, experts_cls). """ - if config.is_lora_enabled: - raise NotImplementedError("LoRA is not supported for MXFP8 MoE.") runner_backend = config.moe_backend if runner_backend != "auto": From bf3d2a8fc33b31e9158691cfdc5f3c72ba651037 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 30 Apr 2026 12:56:05 +0000 Subject: [PATCH 24/24] Move Signed-off-by: Jee Jee Li --- vllm/lora/layers/fused_moe.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index d48b8d3e1f8b..2536fed94bd0 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -58,7 +58,6 @@ def __init__(self, base_layer: FusedMoE) -> None: "For quantized MoE, mix LoRAExpertsMixin into the experts class " "and consume self._lora_context in apply()." ) - self._fused_experts = moe_kernel.fused_experts self._moe_kernel = moe_kernel self.base_layer._replace_quant_method( FusedMoEModularMethod(self.base_layer.quant_method, moe_kernel) @@ -361,7 +360,7 @@ def set_lora( def set_mapping(self, punica_wrapper): super().set_mapping(punica_wrapper) lora_context = self._build_lora_context() - self._fused_experts.set_lora_context(lora_context) + self._moe_kernel.fused_experts.set_lora_context(lora_context) prepare_finalize = self._moe_kernel.prepare_finalize if hasattr(prepare_finalize, "set_lora_context"): prepare_finalize.set_lora_context(lora_context)
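
The expert-parallel LoRA path introduced in this series (FusedMoEWithLoRA.set_lora and _stack_moe_lora_weights) repeatedly applies one rule: the adapter checkpoint carries every global expert, and each EP rank keeps only its contiguous slice so its stacked buffers line up with the locally resident expert weights. A minimal sketch of that slicing rule, with shapes and the helper name chosen for illustration rather than taken from the code:

```python
# Editor's sketch, not part of the patch series: the per-rank expert slice
# used when EP is enabled. For non-EP runs (local == global) it is a no-op.
import torch


def slice_local_experts(
    lora_a: torch.Tensor,  # (global_num_experts, rank, in_features)
    ep_rank: int,
    local_num_experts: int,
) -> torch.Tensor:
    """Return only the experts owned by this EP rank."""
    start = ep_rank * local_num_experts
    end = start + local_num_experts
    return lora_a[start:end].contiguous()


if __name__ == "__main__":
    global_experts, rank, hidden = 8, 4, 16
    lora_a = torch.arange(global_experts).view(-1, 1, 1).expand(
        global_experts, rank, hidden
    ).float()
    local = slice_local_experts(lora_a, ep_rank=1, local_num_experts=4)
    assert local.shape == (4, rank, hidden)
    assert local[0, 0, 0].item() == 4.0  # rank 1 owns experts 4..7
```
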
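
The EP prepare() change ships the per-token LoRA id through the same all-to-all as the quantization scales by appending it to extra_tensors before dispatch and popping it off the gathered list afterwards. A toy sketch of that ordering contract, with hypothetical pack_extras/unpack_extras helpers standing in for the bookkeeping around get_ep_group().dispatch():

```python
# Editor's sketch, not part of the patch series: the LoRA mapping is always
# appended last, so the receiving side pops it off first and what remains is
# the original scales list (or nothing).
import torch


def pack_extras(scales, token_lora_mapping):
    """Build the extra_tensors list for dispatch; mapping goes last."""
    extras = list(scales) if scales is not None else None
    if token_lora_mapping is not None:
        extras = (extras or []) + [token_lora_mapping]
    return extras


def unpack_extras(gathered, had_scales, had_mapping):
    """Invert pack_extras on the gathered tensors."""
    gathered = list(gathered)
    mapping = gathered.pop() if had_mapping else None
    scales = gathered if had_scales else None
    return scales, mapping


scales = [torch.ones(4)]
mapping = torch.tensor([0, 0, 1, 1], dtype=torch.int32)
extras = pack_extras(scales, mapping)  # [scale, mapping]
recv_scales, recv_mapping = unpack_extras(extras, had_scales=True, had_mapping=True)
assert torch.equal(recv_mapping, mapping) and len(recv_scales) == 1
```
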