Draft
Changes from all commits (42 commits)
All commits by jeejeelee:

85cf592  Init  (Apr 19, 2026)
9c64777  Move  (Apr 19, 2026)
7b81c9e  Address conflict  (Apr 20, 2026)
80d0188  Merge branch 'main' into moe-lora-refactor  (Apr 20, 2026)
9350a67  Fix  (Apr 20, 2026)
b3d1ea6  Fix  (Apr 20, 2026)
13caeb4  Merge branch 'main' into moe-lora-refactor  (Apr 20, 2026)
1455838  fix conflict  (Apr 22, 2026)
51be12a  fix conflict  (Apr 22, 2026)
c2dbb14  Move  (Apr 22, 2026)
4f3b7f9  Move  (Apr 22, 2026)
550e19d  Move  (Apr 22, 2026)
019cfa1  Fix  (Apr 22, 2026)
d1ae808  Fix  (Apr 22, 2026)
61d7746  Remove unrelated change  (Apr 23, 2026)
ea4a8fd  Move  (Apr 23, 2026)
5600221  Move  (Apr 23, 2026)
cd29a49  Move  (Apr 23, 2026)
7707bf3  Add lora experts mixin  (Apr 23, 2026)
166386e  OPT  (Apr 23, 2026)
6fe6601  FMT  (Apr 23, 2026)
99c00a2  FMT  (Apr 23, 2026)
7c855e4  Merge branch 'main' into moe-lora-refactor  (Apr 23, 2026)
ce0f6c3  Merge branch 'main' into moe-lora-refactor  (Apr 23, 2026)
9495872  Init  (Apr 23, 2026)
5c1fe18  Move  (Apr 24, 2026)
3efd6c5  Merge branch 'main' into moe-lora-refactor  (Apr 24, 2026)
bc9b997  Merge branch 'moe-lora-refactor' into moe-lora-ep  (Apr 24, 2026)
fe00d8c  Move  (Apr 24, 2026)
400d6cd  Move  (Apr 24, 2026)
57cab35  Merge branch 'main' into moe-lora-ep  (Apr 25, 2026)
381c45c  FIX  (Apr 26, 2026)
08e33a6  Merge remote-tracking branch 'origin/main' into moe-lora-ep  (Apr 27, 2026)
fbc1a61  Merge branch 'main' into moe-lora-ep  (Apr 27, 2026)
9d244b3  Update vllm/lora/model_manager.py  (Apr 28, 2026)
a264079  FMT  (Apr 28, 2026)
5e93587  Merge branch 'main' into moe-lora-ep  (Apr 29, 2026)
6fa9268  Move  (Apr 29, 2026)
4172be4  Move  (Apr 30, 2026)
1cea06b  Merge remote-tracking branch 'origin/main' into moe-lora-ep  (Apr 30, 2026)
bf3d2a8  Move  (Apr 30, 2026)
a8311e7  Merge branch 'main' into moe-lora-ep  (May 1, 2026)
1 change: 1 addition & 0 deletions tests/lora/test_gptoss_tp.py
@@ -129,6 +129,7 @@ def test_gpt_oss_lora_tp2(
         tensor_parallel_size=2,
         gpu_memory_utilization=0.8,
         fully_sharded_loras=fully_sharded_loras,
+        enable_expert_parallel=not fully_sharded_loras,
         compilation_config=vllm.config.CompilationConfig(  # Avoid OOM
             cudagraph_specialize_lora=False,
         ),
14 changes: 8 additions & 6 deletions tests/lora/test_qwen3moe_tp.py
@@ -5,6 +5,8 @@
 # NOTE To avoid overloading the CI pipeline, this test script will not
 # be triggered on CI and is primarily intended for local testing and verification.

+import pytest
+
 import vllm
 from vllm.lora.request import LoRARequest

@@ -82,15 +84,15 @@ def test_qwen3moe_lora(qwen3moe_lora_files):


 @multi_gpu_test(num_gpus=2)
-def test_qwen3moe_lora_tp2(qwen3moe_lora_files):
+@pytest.mark.parametrize("ep", [False, True])
+def test_qwen3moe_lora_tp2(ep, qwen3moe_lora_files):
     llm = vllm.LLM(
         MODEL_PATH,
         max_model_len=1024,
         enable_lora=True,
         max_loras=4,
         enforce_eager=True,
-        trust_remote_code=True,
-        enable_chunked_prefill=True,
+        enable_expert_parallel=ep,
         tensor_parallel_size=2,
     )

@@ -99,15 +101,15 @@ def test_qwen3moe_lora_tp2(qwen3moe_lora_files):


 @multi_gpu_test(num_gpus=4)
-def test_qwen3moe_lora_tp4(qwen3moe_lora_files):
+@pytest.mark.parametrize("ep", [False, True])
+def test_qwen3moe_lora_tp4(ep, qwen3moe_lora_files):
     llm = vllm.LLM(
         MODEL_PATH,
         max_model_len=1024,
         enable_lora=True,
         max_loras=4,
         enforce_eager=True,
-        trust_remote_code=True,
-        enable_chunked_prefill=True,
+        enable_expert_parallel=ep,
         tensor_parallel_size=4,
     )
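The parametrized tests above double as the minimal user-facing recipe for the new combination. A rough end-to-end sketch along the same lines; the model name and adapter path are placeholders, not values from this PR, and two GPUs are assumed:

# Sketch: run a MoE model with LoRA adapters while expert parallelism is enabled.
import vllm
from vllm.lora.request import LoRARequest

MODEL_PATH = "Qwen/Qwen3-30B-A3B"        # assumption: any Qwen3-MoE checkpoint
ADAPTER_PATH = "/path/to/qwen3moe-lora"  # assumption: a local LoRA adapter

llm = vllm.LLM(
    MODEL_PATH,
    max_model_len=1024,
    enable_lora=True,
    max_loras=4,
    enforce_eager=True,
    enable_expert_parallel=True,  # the combination this PR enables with LoRA
    tensor_parallel_size=2,
)

outputs = llm.generate(
    ["Give me a short introduction to large language models."],
    vllm.SamplingParams(temperature=0.0, max_tokens=64),
    lora_request=LoRARequest("moe-adapter", 1, ADAPTER_PATH),
)
print(outputs[0].outputs[0].text)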
68 changes: 53 additions & 15 deletions vllm/lora/layers/fused_moe.py
@@ -7,10 +7,6 @@

 from vllm import envs
 from vllm.config.lora import LoRAConfig
-from vllm.distributed.parallel_state import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
 from vllm.distributed.utils import divide
 from vllm.lora.layers.base import BaseLayerWithLoRA
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -30,15 +26,12 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
     def __init__(self, base_layer: FusedMoE) -> None:
         super().__init__()
         self.base_layer = base_layer

-        assert not self.base_layer.use_ep, (
-            "EP support for Fused MoE LoRA is not implemented yet."
-        )
         assert not self.base_layer.quant_method.is_monolithic, (
             "Monolithic kernels are not supported for Fused MoE LoRA."
         )
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.tp_rank = get_tensor_model_parallel_rank()
+        self._ep_check()
+        # Use the MoE-aware TP rank/size: when EP is active, FusedMoE collapses
+        # moe_parallel_config.tp_size to 1 (experts are sharded across the
+        # TP group instead).
+        self.tp_size = self.base_layer.tp_size
+        self.tp_rank = self.base_layer.tp_rank
         self.device = _get_lora_device(base_layer)
         # For non-gated MoE (is_act_and_mul=False), only 1 slice is needed
         # since there's only up_proj (w1), not gate_proj + up_proj (w1 + w3)
@@ -65,7 +58,7 @@ def __init__(self, base_layer: FusedMoE) -> None:
             "For quantized MoE, mix LoRAExpertsMixin into the experts class "
             "and consume self._lora_context in apply()."
         )
-        self._fused_experts = moe_kernel.fused_experts
+        self._moe_kernel = moe_kernel
         self.base_layer._replace_quant_method(
             FusedMoEModularMethod(self.base_layer.quant_method, moe_kernel)
         )
@@ -150,13 +143,35 @@ def _create_lora_b_weights(self, max_loras: int, lora_config: LoRAConfig):
             ),
         )

+    def _ep_check(self):
+        if self.base_layer.use_ep:
+            moe_config = self.base_layer.moe_config
+            all2all_backend = moe_config.moe_parallel_config.all2all_backend
+            assert all2all_backend == "allgather_reducescatter", (
+                "Fused MoE LoRA with EP currently only supports "
+                f"all2all_backend='allgather_reducescatter', got '{all2all_backend}'."
+            )
+            assert not moe_config.moe_parallel_config.is_sequence_parallel
+
+    def _verify_ep_fs(self, lora_config: LoRAConfig):
+        # EP and fully_sharded LoRA both partition along the same TP group —
+        # EP on the expert dim, fully_sharded on the LoRA rank dim — with
+        # mutually contradictory assumptions about which rank holds which
+        # expert's rank-shard.
+        assert not (self.base_layer.use_ep and lora_config.fully_sharded_loras), (
+            "Fused MoE LoRA does not support enable_expert_parallel=True "
+            "together with fully_sharded_loras=True. Disable one of them."
+        )
+
     def create_lora_weights(
         self,
         max_loras: int,
         lora_config: LoRAConfig,
         model_config: PretrainedConfig | None = None,
     ) -> None:
         """Initializes lora matrices."""

+        self._verify_ep_fs(lora_config)
         self.max_loras = lora_config.max_loras
         self.fully_sharded = lora_config.fully_sharded_loras

Review thread on the _verify_ep_fs assert:

Contributor: Out of curiosity, do you know anyone using this fully_sharded_loras feature? At prime, we had some weird bugs with it so we never use it and would think that it is basically solved with expert parallel. You'd never want to be using this feature.

jeejeelee (Collaborator, Author, Apr 29, 2026): @HollowMan6 I know your team tried fully_sharded_loras, right?

Contributor: Yes. It generally works okay, except this bug #35077 (comment). But once LoRA + EP is supported, I don't think we need to have support for it to be enabled at the same time.
@@ -282,6 +297,24 @@ def set_lora(

         w1_lora_a, w2_lora_a, w3_lora_a = lora_a
         w1_lora_b, w2_lora_b, w3_lora_b = lora_b
+
+        # Under EP the adapter tensors carry all global experts; slice this
+        # rank's owned range so downstream shapes line up with local buffers.
+        global_num_experts = self.base_layer.global_num_experts
+        ep_rank = self.base_layer.ep_rank
+        if (
+            w1_lora_a.shape[0] == global_num_experts
+            and num_experts != global_num_experts
+        ):
+            expert_start = ep_rank * num_experts
+            expert_end = expert_start + num_experts
+            w1_lora_a = w1_lora_a[expert_start:expert_end]
+            w2_lora_a = w2_lora_a[expert_start:expert_end]
+            w3_lora_a = w3_lora_a[expert_start:expert_end]
+            w1_lora_b = w1_lora_b[expert_start:expert_end]
+            w2_lora_b = w2_lora_b[expert_start:expert_end]
+            w3_lora_b = w3_lora_b[expert_start:expert_end]
         assert (
             num_experts
             == w1_lora_a.shape[0]

Review thread on lines +301 to +316:

Contributor: Should this slicing be moved to load instead? If it's here in set, which does the CPU -> GPU copy, that means the CPU LoRAModels that LoRAModelManager holds have all the LoRAs? If it's moved to load then it's "pre-sliced" at load time.

Jackmin801 (Contributor, Apr 29, 2026): Just to keep us on the same page: I categorized the concerns that happen in the LoRA code into load, add and set. So by moving to load here I mean that the logic should be moved to the load step in WorkerLoRAManager. [image attached]

jeejeelee (Collaborator, Author): Makes sense
@@ -326,7 +359,11 @@ def set_lora(

     def set_mapping(self, punica_wrapper):
         super().set_mapping(punica_wrapper)
-        self._fused_experts.set_lora_context(self._build_lora_context())
+        lora_context = self._build_lora_context()
+        self._moe_kernel.fused_experts.set_lora_context(lora_context)
+        prepare_finalize = self._moe_kernel.prepare_finalize
+        if hasattr(prepare_finalize, "set_lora_context"):
+            prepare_finalize.set_lora_context(lora_context)

     def forward(self, *args, **kwargs):
         return self.base_layer.forward(*args, **kwargs)
@@ -396,6 +433,7 @@ def create_lora_weights(
         """Initializes lora matrices."""

         assert isinstance(model_config, PretrainedConfig)
+        self._verify_ep_fs(lora_config)
         self._base_model = model_config.architectures[0]
         self.max_loras = lora_config.max_loras
         self.fully_sharded = lora_config.fully_sharded_loras
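The expert-range slicing that set_lora now performs is easier to follow on toy tensors. A standalone sketch of the same indexing with made-up sizes (not vLLM code):

# Sketch of the EP slicing used in set_lora: an adapter checkpoint carries all
# global experts, and each EP rank keeps only its contiguous range of experts.
import torch

global_num_experts = 8
ep_size = 2
local_num_experts = global_num_experts // ep_size  # 4 experts per rank
rank, hidden = 16, 64

# (global_num_experts, rank, hidden) LoRA-A stack from the adapter
w1_lora_a = torch.randn(global_num_experts, rank, hidden)

for ep_rank in range(ep_size):
    expert_start = ep_rank * local_num_experts
    expert_end = expert_start + local_num_experts
    local_w1_lora_a = w1_lora_a[expert_start:expert_end]
    # Each rank's slice now matches its local expert buffers.
    assert local_w1_lora_a.shape == (local_num_experts, rank, hidden)

When local_num_experts equals global_num_experts (no EP), the slice is the identity, which is why the same code path also serves the TP-only case.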
32 changes: 23 additions & 9 deletions vllm/lora/model_manager.py
@@ -562,6 +562,10 @@ def create_dummy_lora(
         else:
             parts = module_name.split(".")
             replacements = self.packed_modules_mapping[parts[-1]]
+            if module.__class__.__name__ == "FusedMoEWithLoRA":
+                replacements = replacements[
+                    : len(module.lora_a_stacked) // self.lora_slots
+                ]
             subloras: list[LoRALayerWeights | None] = []
             for i, r in enumerate(replacements):
                 lora = LoRALayerWeights.create_dummy_lora_weights(

Review thread on lines +565 to +568:

Contributor: I'm actually kind of lost as to what is happening here 😓, will read in detail later. But just a quick question out of curiosity: why do we do this packing at add time? Can we pack at load time and make the add and set simple?

Contributor: I'm trying to practice my Chinese writing hehe, but for non-Chinese readers: I'm asking if this logic can be moved here. [image attached]

Contributor: One benefit of moving it is that it makes the loading more efficient. We don't need to allocate all the small 2D MoE tensors at load time and then pack them into 3D at add time. We can instead just allocate in 3D and load the 2D slices into it with local expert subsetting!
@@ -762,23 +766,33 @@ def _stack_moe_lora_weights(
         assert gate_up_proj_lora is not None
         assert down_proj_lora is not None
         if self._is_3d_moe_model:
-            num_experts = module.w13_lora_a_stacked[0].shape[1]
+            local_num_experts = module.w13_lora_a_stacked[0].shape[1]
+            # The checkpoint holds weights for all global experts, but
+            # each EP rank owns only local_num_experts. Reshape against
+            # the adapter's actual expert count, then slice this rank's
+            # owned expert range before it gets copied into the local
+            # stacked buffer. For non-EP (local == global) this is a
+            # no-op slice.
+            global_num_experts = module.base_layer.global_num_experts
+            ep_rank = module.base_layer.ep_rank
+            expert_start = ep_rank * local_num_experts
+            expert_end = expert_start + local_num_experts

             # (num_experts,rank,input_size)
             gate_up_proj_lora.lora_a = gate_up_proj_lora.lora_a.reshape(
-                num_experts, -1, gate_up_proj_lora.lora_a.shape[-1]
-            )
+                global_num_experts, -1, gate_up_proj_lora.lora_a.shape[-1]
+            )[expert_start:expert_end].contiguous()
             down_proj_lora.lora_a = down_proj_lora.lora_a.reshape(
-                num_experts, -1, down_proj_lora.lora_a.shape[-1]
-            )
+                global_num_experts, -1, down_proj_lora.lora_a.shape[-1]
+            )[expert_start:expert_end].contiguous()

             # (output_size,rank,num_experts)
             gate_up_proj_lora.lora_b = gate_up_proj_lora.lora_b.reshape(
-                gate_up_proj_lora.lora_b.shape[0], -1, num_experts
-            )
+                gate_up_proj_lora.lora_b.shape[0], -1, global_num_experts
+            )[..., expert_start:expert_end]
             down_proj_lora.lora_b = down_proj_lora.lora_b.reshape(
-                down_proj_lora.lora_b.shape[0], -1, num_experts
-            )
+                down_proj_lora.lora_b.shape[0], -1, global_num_experts
+            )[..., expert_start:expert_end]

             # (num_experts,output_size,rank)
             gate_up_proj_lora.lora_b = gate_up_proj_lora.lora_b.permute(
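_stack_moe_lora_weights applies the same expert subsetting one step earlier, on the flattened 2D checkpoint layout. A rough sketch with assumed shapes; the real layouts come from the adapter format, not from this snippet:

# Sketch of reshaping a flattened gate_up LoRA-A weight into
# (global_num_experts, rank, input_size) and slicing one EP rank's experts.
import torch

global_num_experts, local_num_experts = 8, 4
rank, input_size = 16, 64
ep_rank = 1

# Flattened checkpoint layout: experts stacked along dim 0.
lora_a_flat = torch.randn(global_num_experts * rank, input_size)

lora_a = lora_a_flat.reshape(global_num_experts, -1, input_size)
expert_start = ep_rank * local_num_experts
lora_a_local = lora_a[expert_start:expert_start + local_num_experts].contiguous()
assert lora_a_local.shape == (local_num_experts, rank, input_size)

# lora_b is reshaped to (output_size, rank, num_experts) and sliced on the
# last dim instead, mirroring the )[..., expert_start:expert_end] above.
output_size = 32
lora_b_flat = torch.randn(output_size, rank * global_num_experts)
lora_b = lora_b_flat.reshape(output_size, -1, global_num_experts)
lora_b_local = lora_b[..., expert_start:expert_start + local_num_experts]
assert lora_b_local.shape == (output_size, rank, local_num_experts)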
5 changes: 5 additions & 0 deletions vllm/lora/punica_wrapper/punica_base.py
@@ -514,6 +514,7 @@ def add_lora_w13(
         num_slices: int,
         fully_sharded: bool,
         use_tuned_config: bool,
+        token_lora_mapping: torch.Tensor | None = None,
     ) -> tuple[
         torch.Tensor | None,
         torch.Tensor | None,
@@ -522,6 +523,10 @@
     ]:
         """Apply w13 LoRA to y (intermediate_cache1) in-place before activation.

+        When `token_lora_mapping` is provided it overrides the punica_wrapper's
+        global mapping — used by EP+LoRA to pass the per-rank-local mapping
+        after all-to-all dispatch.
+
         Returns (sorted_token_ids_lora, expert_ids_lora,
         num_tokens_post_padded_lora, token_lora_mapping)
         for reuse by add_lora_w2.
52 changes: 40 additions & 12 deletions vllm/lora/punica_wrapper/punica_gpu.py
@@ -335,35 +335,62 @@ def moe_lora_align_block_size(
         expert_map: torch.Tensor | None = None,
         pad_sorted_ids: bool = False,
         naive_block_assignment: bool = False,
+        token_lora_mapping: torch.Tensor | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Aligns tokens and experts into block-sized chunks for LoRA-based
         mixture-of-experts (MoE) execution.
+
+        When `token_lora_mapping` is provided, it overrides the global mapping
+        read from `self.token_mapping_meta`. This is how EP+LoRA injects the
+        per-rank-local token→LoRA map after all-to-all dispatch.
         """
-        (token_lora_mapping, _, _, _, lora_ids, _, _) = (
-            self.token_mapping_meta.meta_args(
-                num_tokens, self.lora_config.specialize_active_lora
-            )
+        (
+            token_lora_mapping_meta,
+            _,
+            _,
+            _,
+            lora_ids,
+            _,
+            _,
+        ) = self.token_mapping_meta.meta_args(
+            num_tokens, self.lora_config.specialize_active_lora
         )
+        if token_lora_mapping is None:
+            token_lora_mapping = token_lora_mapping_meta
+        # Under EP the caller passes local_num_experts but topk_ids carries
+        # GLOBAL expert indices. The CUDA kernel uses num_experts to size
+        # its bucketing table; with EP we must size by global_num_experts
+        # so global topk_ids don't overflow. expert_map inside the kernel
+        # then translates global→local so the output expert_ids are local
+        # (mirrors the non-LoRA moe_align_block_size behavior).
+        kernel_num_experts = (
+            expert_map.numel() if expert_map is not None else num_experts
+        )
         if naive_block_assignment:
             expert_ids = topk_ids.reshape(-1)
             sorted_ids = None
             num_tokens_post_pad = None
         else:
-            max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
+            max_num_tokens_padded = topk_ids.numel() + kernel_num_experts * (
+                block_size - 1
+            )
             if pad_sorted_ids:
                 max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
-            if topk_ids.numel() < num_experts:
+            if topk_ids.numel() < kernel_num_experts:
                 max_num_tokens_padded = topk_ids.numel() * block_size
             sorted_ids = torch.empty(
                 (max_loras * max_num_tokens_padded,),
                 dtype=torch.int32,
                 device=topk_ids.device,
             )
             max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
-            # Expert ids must be set default to -1 to prevent a blank block
-            expert_ids = torch.empty(
+            # Expert ids are initialized to -1 so unused (lora, expert)
+            # slots don't drive the LoRA Triton kernel into the wrong bucket.
+            # The kernel overwrites only active slots.
+            expert_ids = torch.full(
                 (max_loras * max_num_m_blocks,),
+                -1,
                 dtype=torch.int32,
                 device=topk_ids.device,
             )
@@ -374,7 +401,7 @@ def moe_lora_align_block_size(
             ops.moe_lora_align_block_size(
                 topk_ids,
                 token_lora_mapping,
-                num_experts,
+                kernel_num_experts,
                 block_size,
                 max_loras,
                 max_num_tokens_padded,
@@ -384,11 +411,10 @@
                 num_tokens_post_pad,
                 adapter_enabled,
                 lora_ids,
+                expert_map,
             )
-            if expert_map is not None:
-                expert_ids = expert_map[expert_ids]

-        return None, sorted_ids, expert_ids, num_tokens_post_pad
+        return token_lora_mapping, sorted_ids, expert_ids, num_tokens_post_pad

     def add_lora_fused_moe(
         self,
@@ -480,6 +506,7 @@ def add_lora_w13(
         num_slices: int,
         fully_sharded: bool,
         use_tuned_config: bool,
+        token_lora_mapping: torch.Tensor | None = None,
     ) -> tuple[
         torch.Tensor | None,
         torch.Tensor | None,
@@ -558,6 +585,7 @@ def add_lora_w13(
             adapter_enabled,
             expert_map,
             naive_block_assignment=naive_block_assignment,
+            token_lora_mapping=token_lora_mapping,
         )

         _sorted = sorted_token_ids_lora
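The kernel_num_experts change exists because topk_ids carries global expert indices while each EP rank only allocates local expert buffers. A small PyTorch sketch, standing in for the CUDA kernel, of the global-to-local translation that expert_map performs and why unowned slots end up as -1:

# Sketch: translate global expert ids in topk_ids to local ids with an
# expert_map, marking experts owned by other EP ranks as -1.
import torch

global_num_experts, local_num_experts, ep_rank = 8, 4, 1

# expert_map[g] = local index if this rank owns global expert g, else -1.
expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32)
start = ep_rank * local_num_experts
expert_map[start:start + local_num_experts] = torch.arange(
    local_num_experts, dtype=torch.int32
)

topk_ids = torch.tensor([[0, 5], [6, 2], [7, 4]])  # global ids from routing
local_ids = expert_map[topk_ids]
print(local_ids)
# tensor([[-1,  1],
#         [ 2, -1],
#         [ 3,  0]], dtype=torch.int32)
# Blocks whose expert id stays -1 are skipped by the MoE kernel, which is why
# the LoRA expert_ids buffer is pre-filled with -1 rather than left undefined.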
7 changes: 7 additions & 0 deletions vllm/model_executor/layers/fused_moe/lora_context.py
@@ -42,3 +42,10 @@ class MoELoRAContext:
     # Whether VLLM_TUNED_CONFIG_FOLDER is set; selects get_lora_op_configs vs
     # try_get_optimal_moe_lora_config for Triton kernel tile configs.
     use_tuned_config: bool
+
+    # Per-rank token→LoRA mapping after EP dispatch. Set by
+    # FusedMoEPrepareAndFinalizeModular.prepare() when EP+LoRA is active, read
+    # by LoRAExpertsMixin helpers in place of punica_wrapper's global mapping.
+    # None means no dispatch happened (non-EP path), in which case callers
+    # fall back to punica_wrapper.token_mapping_meta.
+    local_token_lora_mapping: torch.Tensor | None = None
Changes to an additional file (path not captured):
@@ -70,6 +70,7 @@ def apply_w13_lora(
             lora_context.w13_num_slices,
             lora_context.fully_sharded,
             lora_context.use_tuned_config,
+            token_lora_mapping=lora_context.local_token_lora_mapping,
         )

     def apply_w2_lora(
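local_token_lora_mapping exists because, after EP dispatch, the experts kernel sees a different token set and ordering than the one the punica wrapper's global mapping was built for. A conceptual sketch of deriving the per-rank mapping from the global one; the dispatch bookkeeping here (recv_src_token_idx) is a hypothetical name, not the vLLM API:

# Sketch: after dispatch, each received token still needs to know which LoRA
# adapter it belongs to; gather that from the pre-dispatch global mapping.
import torch

# Global, pre-dispatch state on this rank.
global_token_lora_mapping = torch.tensor([0, 0, 1, 2, 1])  # token -> LoRA id

# Hypothetical output of the all-to-all dispatch: for every token this rank
# received, the index of the source token in the pre-dispatch order.
recv_src_token_idx = torch.tensor([4, 1, 3, 3, 0])

local_token_lora_mapping = global_token_lora_mapping[recv_src_token_idx]
print(local_token_lora_mapping)  # tensor([1, 0, 2, 2, 0])
# A per-rank mapping like this is what gets stashed in the LoRA context and
# consumed in place of the wrapper's global mapping.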
3 changes: 0 additions & 3 deletions vllm/model_executor/layers/fused_moe/oracle/int8.py
@@ -79,9 +79,6 @@ def select_int8_moe_backend(
     Note: Shape-specific fallbacks may still occur at runtime.
     """

-    if config.is_lora_enabled:
-        return Int8MoeBackend.TRITON, backend_to_kernel_cls(Int8MoeBackend.TRITON)[0]
-
     AVAILABLE_BACKENDS = _get_priority_backends(config)

     activation_format = (