diff --git a/tests/lora/test_gptoss_tp.py b/tests/lora/test_gptoss_tp.py
index 68dd87233ac0..648660734655 100644
--- a/tests/lora/test_gptoss_tp.py
+++ b/tests/lora/test_gptoss_tp.py
@@ -129,6 +129,7 @@ def test_gpt_oss_lora_tp2(
         tensor_parallel_size=2,
         gpu_memory_utilization=0.8,
         fully_sharded_loras=fully_sharded_loras,
+        enable_expert_parallel=not fully_sharded_loras,
         compilation_config=vllm.config.CompilationConfig(
             # Avoid OOM
             cudagraph_specialize_lora=False,
         ),
diff --git a/tests/lora/test_qwen3moe_tp.py b/tests/lora/test_qwen3moe_tp.py
index fcac4275cc40..9af142f6f388 100644
--- a/tests/lora/test_qwen3moe_tp.py
+++ b/tests/lora/test_qwen3moe_tp.py
@@ -5,6 +5,8 @@
 # NOTE To avoid overloading the CI pipeline, this test script will not
 # be triggered on CI and is primarily intended for local testing and verification.
+import pytest
+
 import vllm
 from vllm.lora.request import LoRARequest
@@ -82,15 +84,15 @@ def test_qwen3moe_lora(qwen3moe_lora_files):
 
 
 @multi_gpu_test(num_gpus=2)
-def test_qwen3moe_lora_tp2(qwen3moe_lora_files):
+@pytest.mark.parametrize("ep", [False, True])
+def test_qwen3moe_lora_tp2(ep, qwen3moe_lora_files):
     llm = vllm.LLM(
         MODEL_PATH,
         max_model_len=1024,
         enable_lora=True,
         max_loras=4,
-        enforce_eager=True,
         trust_remote_code=True,
-        enable_chunked_prefill=True,
+        enable_expert_parallel=ep,
         tensor_parallel_size=2,
     )
@@ -99,15 +101,15 @@ def test_qwen3moe_lora_tp2(qwen3moe_lora_files):
 
 
 @multi_gpu_test(num_gpus=4)
-def test_qwen3moe_lora_tp4(qwen3moe_lora_files):
+@pytest.mark.parametrize("ep", [False, True])
+def test_qwen3moe_lora_tp4(ep, qwen3moe_lora_files):
     llm = vllm.LLM(
         MODEL_PATH,
         max_model_len=1024,
         enable_lora=True,
         max_loras=4,
-        enforce_eager=True,
         trust_remote_code=True,
-        enable_chunked_prefill=True,
+        enable_expert_parallel=ep,
         tensor_parallel_size=4,
     )
diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py
index 284ac54997fb..2536fed94bd0 100644
--- a/vllm/lora/layers/fused_moe.py
+++ b/vllm/lora/layers/fused_moe.py
@@ -7,10 +7,6 @@
 from vllm import envs
 from vllm.config.lora import LoRAConfig
-from vllm.distributed.parallel_state import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
 from vllm.distributed.utils import divide
 from vllm.lora.layers.base import BaseLayerWithLoRA
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -30,15 +26,12 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
     def __init__(self, base_layer: FusedMoE) -> None:
         super().__init__()
         self.base_layer = base_layer
-
-        assert not self.base_layer.use_ep, (
-            "EP support for Fused MoE LoRA is not implemented yet."
-        )
-        assert not self.base_layer.quant_method.is_monolithic, (
-            "Monolithic kernels are not supported for Fused MoE LoRA."
-        )
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.tp_rank = get_tensor_model_parallel_rank()
+        self._ep_check()
+        # Use the MoE-aware TP rank/size: when EP is active, FusedMoE collapses
+        # moe_parallel_config.tp_size to 1 (experts are sharded across the
+        # TP group instead).
+        self.tp_size = self.base_layer.tp_size
+        self.tp_rank = self.base_layer.tp_rank
         self.device = _get_lora_device(base_layer)
         # For non-gated MoE (is_act_and_mul=False), only 1 slice is needed
         # since there's only up_proj (w1), not gate_proj + up_proj (w1 + w3)
@@ -65,7 +58,7 @@ def __init__(self, base_layer: FusedMoE) -> None:
             "For quantized MoE, mix LoRAExpertsMixin into the experts class "
             "and consume self._lora_context in apply()."
         )
-        self._fused_experts = moe_kernel.fused_experts
+        self._moe_kernel = moe_kernel
         self.base_layer._replace_quant_method(
             FusedMoEModularMethod(self.base_layer.quant_method, moe_kernel)
         )
@@ -150,6 +143,26 @@ def _create_lora_b_weights(self, max_loras: int, lora_config: LoRAConfig):
             ),
         )
 
+    def _ep_check(self):
+        if self.base_layer.use_ep:
+            moe_config = self.base_layer.moe_config
+            all2all_backend = moe_config.moe_parallel_config.all2all_backend
+            assert all2all_backend == "allgather_reducescatter", (
+                "Fused MoE LoRA with EP currently only supports "
+                f"all2all_backend='allgather_reducescatter', got '{all2all_backend}'."
+            )
+            assert not moe_config.moe_parallel_config.is_sequence_parallel
+
+    def _verify_ep_fs(self, lora_config: LoRAConfig):
+        # EP and fully_sharded LoRA both partition along the same TP group —
+        # EP on the expert dim, fully_sharded on the LoRA rank dim — with
+        # mutually contradictory assumptions about which rank holds which
+        # expert's rank-shard.
+        assert not (self.base_layer.use_ep and lora_config.fully_sharded_loras), (
+            "Fused MoE LoRA does not support enable_expert_parallel=True "
+            "together with fully_sharded_loras=True. Disable one of them."
+        )
+
     def create_lora_weights(
         self,
         max_loras: int,
@@ -157,6 +170,8 @@ def create_lora_weights(
         model_config: PretrainedConfig | None = None,
     ) -> None:
         """Initializes lora matrices."""
+
+        self._verify_ep_fs(lora_config)
         self.max_loras = lora_config.max_loras
         self.fully_sharded = lora_config.fully_sharded_loras
@@ -282,6 +297,24 @@ def set_lora(
         w1_lora_a, w2_lora_a, w3_lora_a = lora_a
         w1_lora_b, w2_lora_b, w3_lora_b = lora_b
+
+        # Under EP the adapter tensors carry all global experts; slice this
+        # rank's owned range so downstream shapes line up with local buffers.
+        global_num_experts = self.base_layer.global_num_experts
+        ep_rank = self.base_layer.ep_rank
+        if (
+            w1_lora_a.shape[0] == global_num_experts
+            and num_experts != global_num_experts
+        ):
+            expert_start = ep_rank * num_experts
+            expert_end = expert_start + num_experts
+            w1_lora_a = w1_lora_a[expert_start:expert_end]
+            w2_lora_a = w2_lora_a[expert_start:expert_end]
+            w3_lora_a = w3_lora_a[expert_start:expert_end]
+            w1_lora_b = w1_lora_b[expert_start:expert_end]
+            w2_lora_b = w2_lora_b[expert_start:expert_end]
+            w3_lora_b = w3_lora_b[expert_start:expert_end]
+
         assert (
             num_experts
             == w1_lora_a.shape[0]
@@ -326,7 +359,11 @@ def set_lora(
 
     def set_mapping(self, punica_wrapper):
         super().set_mapping(punica_wrapper)
-        self._fused_experts.set_lora_context(self._build_lora_context())
+        lora_context = self._build_lora_context()
+        self._moe_kernel.fused_experts.set_lora_context(lora_context)
+        prepare_finalize = self._moe_kernel.prepare_finalize
+        if hasattr(prepare_finalize, "set_lora_context"):
+            prepare_finalize.set_lora_context(lora_context)
 
     def forward(self, *args, **kwargs):
         return self.base_layer.forward(*args, **kwargs)
@@ -396,6 +433,7 @@ def create_lora_weights(
         """Initializes lora matrices."""
 
         assert isinstance(model_config, PretrainedConfig)
+        self._verify_ep_fs(lora_config)
         self._base_model = model_config.architectures[0]
         self.max_loras = lora_config.max_loras
         self.fully_sharded = lora_config.fully_sharded_loras
diff --git a/vllm/lora/model_manager.py b/vllm/lora/model_manager.py
index 52ff8ebc91f3..ca18c577557a 100644
--- a/vllm/lora/model_manager.py
+++ b/vllm/lora/model_manager.py
@@ -562,6 +562,10 @@ def create_dummy_lora(
             else:
                 parts = module_name.split(".")
                 replacements = self.packed_modules_mapping[parts[-1]]
+                if module.__class__.__name__ == "FusedMoEWithLoRA":
+                    replacements = replacements[
+                        : len(module.lora_a_stacked) // self.lora_slots
+                    ]
                 subloras: list[LoRALayerWeights | None] = []
                 for i, r in enumerate(replacements):
                     lora = LoRALayerWeights.create_dummy_lora_weights(
@@ -762,23 +766,33 @@ def _stack_moe_lora_weights(
             assert gate_up_proj_lora is not None
             assert down_proj_lora is not None
             if self._is_3d_moe_model:
-                num_experts = module.w13_lora_a_stacked[0].shape[1]
+                local_num_experts = module.w13_lora_a_stacked[0].shape[1]
+                # The checkpoint holds weights for all global experts, but
+                # each EP rank owns only local_num_experts. Reshape against
+                # the adapter's actual expert count, then slice this rank's
+                # owned expert range before it gets copied into the local
+                # stacked buffer. For non-EP (local == global) this is a
+                # no-op slice.
+                global_num_experts = module.base_layer.global_num_experts
+                ep_rank = module.base_layer.ep_rank
+                expert_start = ep_rank * local_num_experts
+                expert_end = expert_start + local_num_experts
                 # (num_experts,rank,input_size)
                 gate_up_proj_lora.lora_a = gate_up_proj_lora.lora_a.reshape(
-                    num_experts, -1, gate_up_proj_lora.lora_a.shape[-1]
-                )
+                    global_num_experts, -1, gate_up_proj_lora.lora_a.shape[-1]
+                )[expert_start:expert_end].contiguous()
                 down_proj_lora.lora_a = down_proj_lora.lora_a.reshape(
-                    num_experts, -1, down_proj_lora.lora_a.shape[-1]
-                )
+                    global_num_experts, -1, down_proj_lora.lora_a.shape[-1]
+                )[expert_start:expert_end].contiguous()
                 # (output_size,rank,num_experts)
                 gate_up_proj_lora.lora_b = gate_up_proj_lora.lora_b.reshape(
-                    gate_up_proj_lora.lora_b.shape[0], -1, num_experts
-                )
+                    gate_up_proj_lora.lora_b.shape[0], -1, global_num_experts
+                )[..., expert_start:expert_end]
                 down_proj_lora.lora_b = down_proj_lora.lora_b.reshape(
-                    down_proj_lora.lora_b.shape[0], -1, num_experts
-                )
+                    down_proj_lora.lora_b.shape[0], -1, global_num_experts
+                )[..., expert_start:expert_end]
                 # (num_experts,output_size,rank)
                 gate_up_proj_lora.lora_b = gate_up_proj_lora.lora_b.permute(
diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py
index 4ab66dccdc29..0448a6d00cda 100644
--- a/vllm/lora/punica_wrapper/punica_base.py
+++ b/vllm/lora/punica_wrapper/punica_base.py
@@ -514,6 +514,7 @@ def add_lora_w13(
         num_slices: int,
         fully_sharded: bool,
         use_tuned_config: bool,
+        token_lora_mapping: torch.Tensor | None = None,
     ) -> tuple[
         torch.Tensor | None,
         torch.Tensor | None,
@@ -522,6 +523,10 @@ def add_lora_w13(
     ]:
         """Apply w13 LoRA to y (intermediate_cache1) in-place before activation.
 
+        When `token_lora_mapping` is provided it overrides the punica_wrapper's
+        global mapping — used by EP+LoRA to pass the per-rank-local mapping
+        after all-to-all dispatch.
+
         Returns (sorted_token_ids_lora, expert_ids_lora,
         num_tokens_post_padded_lora, token_lora_mapping) for reuse by
         add_lora_w2.
diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py
index 44d1dbd50728..bf951e074949 100644
--- a/vllm/lora/punica_wrapper/punica_gpu.py
+++ b/vllm/lora/punica_wrapper/punica_gpu.py
@@ -335,25 +335,49 @@ def moe_lora_align_block_size(
         expert_map: torch.Tensor | None = None,
         pad_sorted_ids: bool = False,
         naive_block_assignment: bool = False,
+        token_lora_mapping: torch.Tensor | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Aligns tokens and experts into block-sized chunks for LoRA-based
         mixture-of-experts (MoE) execution.
+
+        When `token_lora_mapping` is provided, it overrides the global mapping
+        read from `self.token_mapping_meta`. This is how EP+LoRA injects the
+        per-rank-local token→LoRA map after all-to-all dispatch.
         """
-        (token_lora_mapping, _, _, _, lora_ids, _, _) = (
-            self.token_mapping_meta.meta_args(
-                num_tokens, self.lora_config.specialize_active_lora
-            )
+        (
+            token_lora_mapping_meta,
+            _,
+            _,
+            _,
+            lora_ids,
+            _,
+            _,
+        ) = self.token_mapping_meta.meta_args(
+            num_tokens, self.lora_config.specialize_active_lora
+        )
+        if token_lora_mapping is None:
+            token_lora_mapping = token_lora_mapping_meta
+        # Under EP the caller passes local_num_experts but topk_ids carries
+        # GLOBAL expert indices. The CUDA kernel uses num_experts to size
+        # its bucketing table; with EP we must size by global_num_experts
+        # so global topk_ids don't overflow. expert_map inside the kernel
+        # then translates global→local so the output expert_ids are local
+        # (mirrors the non-LoRA moe_align_block_size behavior).
+        kernel_num_experts = (
+            expert_map.numel() if expert_map is not None else num_experts
         )
         if naive_block_assignment:
             expert_ids = topk_ids.reshape(-1)
             sorted_ids = None
             num_tokens_post_pad = None
         else:
-            max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
+            max_num_tokens_padded = topk_ids.numel() + kernel_num_experts * (
+                block_size - 1
+            )
             if pad_sorted_ids:
                 max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
-            if topk_ids.numel() < num_experts:
+            if topk_ids.numel() < kernel_num_experts:
                 max_num_tokens_padded = topk_ids.numel() * block_size
             sorted_ids = torch.empty(
                 (max_loras * max_num_tokens_padded,),
@@ -361,9 +385,12 @@ def moe_lora_align_block_size(
                 device=topk_ids.device,
             )
             max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
-            # Expert ids must be set default to -1 to prevent a blank block
-            expert_ids = torch.empty(
+            # Expert ids are initialized to -1 so unused (lora, expert)
+            # slots don't drive the LoRA Triton kernel into the wrong bucket.
+            # The kernel overwrites only active slots.
+            expert_ids = torch.full(
                 (max_loras * max_num_m_blocks,),
+                -1,
                 dtype=torch.int32,
                 device=topk_ids.device,
             )
@@ -374,7 +401,7 @@ def moe_lora_align_block_size(
         ops.moe_lora_align_block_size(
             topk_ids,
             token_lora_mapping,
-            num_experts,
+            kernel_num_experts,
             block_size,
             max_loras,
             max_num_tokens_padded,
@@ -384,11 +411,10 @@ def moe_lora_align_block_size(
             num_tokens_post_pad,
             adapter_enabled,
             lora_ids,
+            expert_map,
         )
 
-        if expert_map is not None:
-            expert_ids = expert_map[expert_ids]
-        return None, sorted_ids, expert_ids, num_tokens_post_pad
+        return token_lora_mapping, sorted_ids, expert_ids, num_tokens_post_pad
 
     def add_lora_fused_moe(
         self,
@@ -480,6 +506,7 @@ def add_lora_w13(
         num_slices: int,
         fully_sharded: bool,
         use_tuned_config: bool,
+        token_lora_mapping: torch.Tensor | None = None,
     ) -> tuple[
         torch.Tensor | None,
         torch.Tensor | None,
@@ -558,6 +585,7 @@ def add_lora_w13(
             adapter_enabled,
             expert_map,
             naive_block_assignment=naive_block_assignment,
+            token_lora_mapping=token_lora_mapping,
         )
 
         _sorted = sorted_token_ids_lora
diff --git a/vllm/model_executor/layers/fused_moe/lora_context.py b/vllm/model_executor/layers/fused_moe/lora_context.py
index 92500a7bb47d..ab1f0bfc1476 100644
--- a/vllm/model_executor/layers/fused_moe/lora_context.py
+++ b/vllm/model_executor/layers/fused_moe/lora_context.py
@@ -42,3 +42,10 @@ class MoELoRAContext:
     # Whether VLLM_TUNED_CONFIG_FOLDER is set; selects get_lora_op_configs vs
     # try_get_optimal_moe_lora_config for Triton kernel tile configs.
     use_tuned_config: bool
+
+    # Per-rank token→LoRA mapping after EP dispatch. Set by
+    # FusedMoEPrepareAndFinalizeModular.prepare() when EP+LoRA is active, read
+    # by LoRAExpertsMixin helpers in place of punica_wrapper's global mapping.
+    # None means no dispatch happened (non-EP path), in which case callers
+    # fall back to punica_wrapper.token_mapping_meta.
+    local_token_lora_mapping: torch.Tensor | None = None
diff --git a/vllm/model_executor/layers/fused_moe/lora_experts_mixin.py b/vllm/model_executor/layers/fused_moe/lora_experts_mixin.py
index c609c5cf56b5..10707b91b70e 100644
--- a/vllm/model_executor/layers/fused_moe/lora_experts_mixin.py
+++ b/vllm/model_executor/layers/fused_moe/lora_experts_mixin.py
@@ -70,6 +70,7 @@ def apply_w13_lora(
             lora_context.w13_num_slices,
             lora_context.fully_sharded,
             lora_context.use_tuned_config,
+            token_lora_mapping=lora_context.local_token_lora_mapping,
         )
 
     def apply_w2_lora(
diff --git a/vllm/model_executor/layers/fused_moe/oracle/int8.py b/vllm/model_executor/layers/fused_moe/oracle/int8.py
index cdb1be108b5d..ebdd20d54dc9 100644
--- a/vllm/model_executor/layers/fused_moe/oracle/int8.py
+++ b/vllm/model_executor/layers/fused_moe/oracle/int8.py
@@ -79,9 +79,6 @@ def select_int8_moe_backend(
 
     Note: Shape-specific fallbacks may still occur at runtime.
     """
-    if config.is_lora_enabled:
-        return Int8MoeBackend.TRITON, backend_to_kernel_cls(Int8MoeBackend.TRITON)[0]
-
     AVAILABLE_BACKENDS = _get_priority_backends(config)
 
     activation_format = (
diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
index c1423362d737..6c540eba04cf 100644
--- a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
+++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
@@ -248,29 +248,6 @@ def select_gpt_oss_mxfp4_moe_backend(
     Select the primary MXFP4 MoE backend.
 
     Note: Shape-specific fallbacks may still occur at runtime.
""" - device_capability = current_platform.get_device_capability() - triton_kernels_supported = ( - has_triton_kernels() - and device_capability is not None - and (9, 0) <= device_capability < (11, 0) - ) - - # LoRA: separate experts backend path - if config.is_lora_enabled: - if not current_platform.is_cuda(): - # ROCm: Triton mxfp4 LoRA hits GPU memory faults due to - # triton_kernels.tensor.Tensor / HIP read-only page issues - # during weight swizzle and LoRA forward. Needs work from - # the triton_kernels/aiter side. - raise NotImplementedError("Mxfp4 LoRA is currently only supported on CUDA.") - if envs.VLLM_MXFP4_USE_MARLIN is False and triton_kernels_supported: - logger.info_once("Using Triton backend for mxfp4 lora") - return Mxfp4MoeBackend.TRITON_UNFUSED, backend_to_kernel_cls( - Mxfp4MoeBackend.TRITON_UNFUSED - )[0] - logger.info_once("Using Marlin backend for mxfp4 lora") - return Mxfp4MoeBackend.MARLIN, backend_to_kernel_cls(Mxfp4MoeBackend.MARLIN)[0] - activation_format = ( mk.FusedMoEActivationFormat.BatchedExperts if config.moe_parallel_config.use_batched_activation_format diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py index c67def149b9d..8133902d519b 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py @@ -61,8 +61,6 @@ def select_mxfp8_moe_backend( Returns: A tuple of (fp8_backend, experts_cls). """ - if config.is_lora_enabled: - raise NotImplementedError("LoRA is not supported for MXFP8 MoE.") runner_backend = config.moe_backend if runner_backend != "auto": diff --git a/vllm/model_executor/layers/fused_moe/oracle/unquantized.py b/vllm/model_executor/layers/fused_moe/oracle/unquantized.py index 8240a5e8c963..f4fa00495519 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/unquantized.py +++ b/vllm/model_executor/layers/fused_moe/oracle/unquantized.py @@ -163,11 +163,6 @@ def select_unquantized_moe_backend( if current_platform.is_out_of_tree(): return UnquantizedMoeBackend.OOT, None - if moe_config.is_lora_enabled: - return UnquantizedMoeBackend.TRITON, backend_to_kernel_cls( - UnquantizedMoeBackend.TRITON - ) - # NOTE: the kernels are selected in the following order. AVAILABLE_BACKENDS = _get_priority_backends(moe_config) diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py index 5b3325ad0195..ce9c3e3c1cfa 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py @@ -84,6 +84,14 @@ def __init__( super().__init__() self.is_sequence_parallel = is_sequence_parallel self._num_dispatchers = num_dispatchers + # Set by FusedMoEWithLoRA.set_mapping() when LoRA is active. When + # present, prepare() dispatches the per-token LoRA mapping alongside + # hidden_states and writes the gathered result back to the context so + # experts can use the per-rank-local mapping. + self._lora_context = None + + def set_lora_context(self, ctx) -> None: + self._lora_context = ctx @property def activation_format(self) -> mk.FusedMoEActivationFormat: @@ -124,22 +132,54 @@ def prepare( a1q, scales = _quantize_and_setup_dispatch(a1, quant_config, defer_input_quant) + # When LoRA is active, dispatch the per-token LoRA id along with + # hidden_states so every rank receives the correct mapping for the + # tokens it ends up processing. 
The punica_wrapper stores indices as + # int64 but the moe_lora_align_block_size kernel expects int32, so + # pull the pre-cast view from token_mapping_meta. + lora_ctx = self._lora_context + local_token_lora_mapping = None + if lora_ctx is not None: + local_token_lora_mapping = ( + lora_ctx.punica_wrapper.token_mapping_meta.token_lora_mapping[ + : a1.shape[0] + ] + ) + + extra_tensors: list[torch.Tensor] | None = None + if scales is not None: + extra_tensors = list(scales) + if local_token_lora_mapping is not None: + if extra_tensors is None: + extra_tensors = [] + extra_tensors.append(local_token_lora_mapping) + res = get_ep_group().dispatch( a1q, topk_weights, topk_ids, is_sequence_parallel=self.is_sequence_parallel, - extra_tensors=scales, + extra_tensors=extra_tensors, ) - if scales is None: + if extra_tensors is None: assert len(res) == 3 a1q, topk_weights, topk_ids = res a1q_scale = None else: assert len(res) == 4 - a1q, topk_weights, topk_ids, scales = res - a1q_scale = _unwrap_scale_and_prepare_for_moe(scales, quant_config) + a1q, topk_weights, topk_ids, gathered_extras = res + gathered_extras = list(gathered_extras) + if local_token_lora_mapping is not None: + dispatched_lora_mapping = gathered_extras.pop() + assert lora_ctx is not None + lora_ctx.local_token_lora_mapping = dispatched_lora_mapping + if scales is not None: + a1q_scale = _unwrap_scale_and_prepare_for_moe( + gathered_extras, quant_config + ) + else: + a1q_scale = None return a1q, a1q_scale, None, topk_ids, topk_weights
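Reviewer sketch (not part of the patch): the global-to-local expert slicing that set_lora() and _stack_moe_lora_weights() apply under EP reduces to the arithmetic below. All sizes and tensor names are invented for illustration; the only assumption carried over from the patch is that each EP rank owns a contiguous block of local_num_experts experts starting at ep_rank * local_num_experts, and that the adapter checkpoint stores weights for all global experts.

import torch

# Hypothetical sizes, chosen only for this example.
global_num_experts = 8                       # experts in the adapter checkpoint
ep_size = 2                                  # expert-parallel world size
local_num_experts = global_num_experts // ep_size
max_lora_rank, hidden_size = 16, 32          # illustrative LoRA rank / hidden dim

# The adapter carries LoRA-A weights for every global expert, i.e. shape
# (global_num_experts, rank, hidden), matching the reshape done in
# _stack_moe_lora_weights before slicing.
w1_lora_a = torch.randn(global_num_experts, max_lora_rank, hidden_size)

for ep_rank in range(ep_size):
    expert_start = ep_rank * local_num_experts
    expert_end = expert_start + local_num_experts
    # Only this rank's experts are copied into its local stacked buffer;
    # when EP is off, local == global and the slice is a no-op.
    local_w1_lora_a = w1_lora_a[expert_start:expert_end].contiguous()
    assert local_w1_lora_a.shape == (local_num_experts, max_lora_rank, hidden_size)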