
[LoRA] Initial EP support for LoRA #40867

Draft
jeejeelee wants to merge 42 commits into main from moe-lora-ep

Conversation

Collaborator

@jeejeelee commented Apr 25, 2026

Purpose

Depends on #40338

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
@jeejeelee marked this pull request as draft April 25, 2026 07:19
@jeejeelee removed their assignment Apr 25, 2026
@mergify (Bot) added the qwen (Related to Qwen models) and gpt-oss (Related to GPT-OSS models) labels Apr 25, 2026
Contributor

@gemini-code-assist (Bot) left a comment

Code Review

This pull request implements support for Expert Parallelism (EP) within Fused MoE LoRA layers, refactoring the implementation from a decorator-based approach to a more robust system using MoELoRAContext and LoRAExpertsMixin. The changes enable native LoRA handling in Triton and Marlin expert kernels and include updates to the Punica wrapper for handling rank-local token mappings and expert slicing. Feedback highlights a bug in token mapping logic when sequence parallelism is active and a potential TypeError in non-gated MoE models where specific LoRA weights might be None.
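
For orientation, here is a rough sketch (not the PR's actual definition) of the kind of context object this summary and the review below describe. Only `punica_wrapper` and `token_mapping_meta.token_lora_mapping` come from the discussion; the class name, field layout, and helper method are assumptions for illustration.

```python
from dataclasses import dataclass
from typing import Any

import torch


@dataclass
class MoELoRAContextSketch:
    """Bundles LoRA state so fused MoE expert kernels can reach it."""

    punica_wrapper: Any  # assumed to expose token_mapping_meta.token_lora_mapping

    def local_token_lora_mapping(self, num_tokens: int) -> torch.Tensor:
        # Per-token LoRA ids for this rank, trimmed to the current batch.
        mapping = self.punica_wrapper.token_mapping_meta.token_lora_mapping
        return mapping[:num_tokens]
```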

Comment thread vllm/lora/layers/fused_moe.py
Comment thread vllm/lora/model_manager.py
@github-project-automation (Bot) moved this from To Triage to In progress in gpt-oss Issues & Enhancements Apr 27, 2026
jeejeelee and others added 3 commits April 28, 2026 08:39
Co-authored-by: ℍ𝕠𝕝𝕝𝕠𝕨 𝕄𝕒𝕟 <hollowman@opensuse.org>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
        self._lora_context = None

    def set_lora_context(self, ctx) -> None:
        self._lora_context = ctx
Contributor

I would think it's cleaner if the PrepareFinalize doesn't have to take the entire context. With regards to LoRA, mk.FusedMoEPrepareAndFinalizeModular is only concerned with taking in lora_id in prepare and getting out the local_lora_id. This is easy to unit test.

If it takes the entire LoRA context, it now needs to be concerned with constructing it, constructing its punica wrapper, and using the punica wrapper correctly so that lora_ctx.punica_wrapper.token_mapping_meta.token_lora_mapping is correct. Four attributes before we can get to the input we want seems a bit much.
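
A minimal sketch of the narrower interface being suggested, assuming prepare only needs a plain id-to-local-slot lookup; the helper name and the dict-based mapping are illustrative, not the PR's API:

```python
def map_lora_id_to_local(lora_id: int, lora_id_to_local: dict[int, int]) -> int:
    """Translate a global LoRA id into this rank's local slot (-1 if not resident)."""
    return lora_id_to_local.get(lora_id, -1)


# Trivially unit-testable in isolation, without a punica wrapper or full context:
assert map_lora_id_to_local(7, {7: 0, 9: 1}) == 0
assert map_lora_id_to_local(3, {7: 0, 9: 1}) == -1
```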

Collaborator Author

For now, we want to capture the full LoRA context upfront for better extensibility, but we can consider only passing lora_id in the future.

Comment on lines +142 to +144
local_token_lora_mapping = (
    lora_ctx.punica_wrapper.token_mapping_meta.token_lora_mapping[
        : a1.shape[0]
Contributor

I'm not sure why this slicing is needed. The punica wrapper already slices token_lora_mapping to the seqlen, right? I think before #39107 it was necessary because MoE DP chunking would chunk the input without the punica wrapper knowing about it, but I would think it isn't necessary now that DP chunking is no longer supported.
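
For context, a toy illustration of the slice being questioned (shapes assumed, not vLLM code): if the wrapper has already sized the mapping to the batch, trimming it to `a1.shape[0]` changes nothing.

```python
import torch

# One LoRA id per token, already sized to the current batch by the wrapper.
token_lora_mapping = torch.tensor([0, 0, 1, 1, 1])
a1 = torch.randn(5, 16)  # activations for the same 5 tokens

local_token_lora_mapping = token_lora_mapping[: a1.shape[0]]
assert torch.equal(local_token_lora_mapping, token_lora_mapping)  # slice is a no-op
```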

Comment thread vllm/lora/layers/fused_moe.py
# EP on the expert dim, fully_sharded on the LoRA rank dim — with
# mutually contradictory assumptions about which rank holds which
# expert's rank-shard.
assert not (self.base_layer.use_ep and lora_config.fully_sharded_loras), (
Contributor

Out of curiosity, do you know of anyone using this fully_sharded_loras feature? At Prime, we had some weird bugs with it, so we never use it, and I would think it is basically made redundant by expert parallel. You'd never want to be using this feature.

Collaborator Author

@jeejeelee Apr 29, 2026

@HollowMan6 I know your team tried fully_sharded_loras, right?

Contributor

Yes. It generally works okay, except for this bug: #35077 (comment). But once LoRA + EP is supported, I don't think we need to support enabling both at the same time.

Comment on lines +302 to +317
# Under EP the adapter tensors carry all global experts; slice this
# rank's owned range so downstream shapes line up with local buffers.
global_num_experts = self.base_layer.global_num_experts
ep_rank = self.base_layer.ep_rank
if (
    w1_lora_a.shape[0] == global_num_experts
    and num_experts != global_num_experts
):
    expert_start = ep_rank * num_experts
    expert_end = expert_start + num_experts
    w1_lora_a = w1_lora_a[expert_start:expert_end]
    w2_lora_a = w2_lora_a[expert_start:expert_end]
    w3_lora_a = w3_lora_a[expert_start:expert_end]
    w1_lora_b = w1_lora_b[expert_start:expert_end]
    w2_lora_b = w2_lora_b[expert_start:expert_end]
    w3_lora_b = w3_lora_b[expert_start:expert_end]
Contributor

Should this slicing be moved to load instead? If it stays here in set, which does the CPU -> GPU copy, that means the CPU LoRAModels that LoRAModelManager holds carry the LoRAs for all experts. If it's moved to load, they'd be "pre-sliced" at load time.
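
A rough sketch of the "pre-slice at load time" idea, assuming a helper in WorkerLoRAManager's load path; the function and argument names are made up for illustration, not the PR's code:

```python
import torch


def slice_experts_for_rank(
    lora_weight: torch.Tensor,  # stacked per global expert: [global_num_experts, ...]
    ep_rank: int,
    num_local_experts: int,
) -> torch.Tensor:
    """Keep only the expert range owned by this EP rank."""
    expert_start = ep_rank * num_local_experts
    return lora_weight[expert_start : expert_start + num_local_experts]
```

set would then copy already rank-local tensors to the GPU without needing to know about expert ownership.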

Contributor

@Jackmin801 Apr 29, 2026

Just to keep us on the same page: I categorized the concerns in the LoRA code into load, add, and set. So by moving to load here, I mean the logic should be moved to the load step in WorkerLoRAManager.
[image attachment]

Collaborator Author

Makes sense

Comment thread vllm/model_executor/layers/fused_moe/oracle/unquantized.py
Comment on lines +565 to +568
if module.__class__.__name__ == "FusedMoEWithLoRA":
    replacements = replacements[
        : len(module.lora_a_stacked) // self.lora_slots
    ]
Contributor

I'm actually kind of lost as to what is happening here 😓; I'll read it in detail later. But just a quick question out of curiosity: why do we do this packing at add time? Can we pack at load time and make add and set simple?

Contributor

I'm trying to practice my Chinese writing, hehe. But for non-Chinese readers: I'm asking if this logic can be moved here.
[image attachment]

Contributor

One benefit of moving it is that it makes loading more efficient. We don't need to allocate all the small 2D MoE tensors at load time and then pack them into 3D at add time; we can instead just allocate in 3D and load the 2D slices into it with local expert subsetting!
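
A rough sketch of that load-time packing, with shapes and names assumed for illustration: allocate the stacked 3D buffer once and copy each locally owned expert's 2D LoRA matrix straight into its slot, skipping experts owned by other ranks.

```python
import torch

global_num_experts, num_local_experts = 8, 4
rank, hidden = 8, 128
ep_rank = 1  # this rank owns global experts 4..7

# Allocate the packed 3D buffer up front instead of many small 2D tensors.
w1_lora_a_stacked = torch.zeros(num_local_experts, rank, hidden)

# Stand-in for per-expert 2D LoRA-A matrices as they come off disk.
per_expert_weights = {e: torch.randn(rank, hidden) for e in range(global_num_experts)}

expert_start = ep_rank * num_local_experts
for local_id in range(num_local_experts):
    w1_lora_a_stacked[local_id].copy_(per_expert_weights[expert_start + local_id])
```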

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>

Labels

gpt-oss (Related to GPT-OSS models), qwen (Related to Qwen models)

Projects

Status: In progress


3 participants