
[LoRA] MoE LoRA Refactor #40338

Merged
simon-mo merged 29 commits into main from moe-lora-refactor on Apr 26, 2026

Conversation

@jeejeelee (Collaborator) commented Apr 20, 2026

Motivation

Currently, MoE LoRA is wired in by monkey-patching methods on the modular kernel at construction time. FusedMoEWithLoRA._inject_lora_into_fused_moe wraps FusedMoEKernel.apply, TritonExperts.activation, and TritonExperts.moe_sum with fwd_decorator / act_decorator / moe_sum_decorator, and smuggles tensors between them through moe_state_dict. This has several concrete problems:
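
For illustration, a minimal self-contained sketch of this pattern (generic Python, not the real vLLM code; the names fwd_decorator, act_decorator, and moe_state_dict come from the description above, everything else is invented for the example):

# Methods are rebound at construction time and communicate through a shared dict.
moe_state_dict = {}

class Experts:
    def apply(self, x):
        return 2 * x

    def activation(self, x):
        return max(x, 0)

def fwd_decorator(fn):
    def wrapped(x):
        moe_state_dict["lora_delta"] = 1  # smuggled to the next wrapper
        return fn(x)
    return wrapped

def act_decorator(fn):
    def wrapped(x):
        # hidden dependency: only correct if fwd_decorator ran first
        return fn(x) + moe_state_dict.pop("lora_delta")
    return wrapped

experts = Experts()
experts.apply = fwd_decorator(experts.apply)            # monkey-patch
experts.activation = act_decorator(experts.activation)  # monkey-patch
print(experts.activation(experts.apply(3)))  # 7, via invisible cross-call state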

1. Hacky and hard to maintain

The LoRA contribution is hidden inside decorators on base-MoE methods. Someone reading the MoE experts code sees self.activation(...) and self.moe_sum(...) — they have no syntactic hint that these calls may actually run LoRA shrink/expand GEMMs, or that moe_state_dict carries cross-call state between them. Debugging requires holding the patching order in your head.

2. MoE changes don't see LoRA

Because the LoRA path lives outside the expert apply() functions, any refactor of TritonExperts / UnfusedOAITritonExperts / MarlinExperts has to re-discover the LoRA contract (what activation receives, what moe_sum receives, what shape intermediate_cache* has). Every MoE-side change risks breaking LoRA silently.

3. Hard to extend

Adding support for new features — EP (expert parallelism), additional quantized backends, new expert impls — means replicating or working around the decorator chain, and the state-dict plumbing assumes one specific control flow. There is no extension point for an expert that wants to apply LoRA at a different point.

Change

Treat LoRA as a first-class concern in the modular MoE kernel rather than an external patch.

1. MoELoRAContext: a single explicit payload

A new MoELoRAContext dataclass (vllm/lora/lora_context.py) packages all of the LoRA state a MoE forward pass needs (a sketch follows the list):

  • w13_lora_a_stacked / w13_lora_b_stacked / w2_lora_a_stacked /
    w2_lora_b_stacked
  • adapter_enabled, max_loras
  • routing/sharding info: top_k, w13_num_slices, fully_sharded, tp_rank,
    tp_size, local_num_experts
  • the active punica_wrapper
  • use_tuned_config (whether VLLM_TUNED_CONFIG_FOLDER is set)
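
A sketch of what this context could look like (field names follow the list above; the exact types are assumptions, e.g. the stacked weights are likely tuples of tensors, one per slice):

from dataclasses import dataclass

import torch

@dataclass
class MoELoRAContext:
    # stacked LoRA weights
    w13_lora_a_stacked: tuple[torch.Tensor, ...]
    w13_lora_b_stacked: tuple[torch.Tensor, ...]
    w2_lora_a_stacked: tuple[torch.Tensor, ...]
    w2_lora_b_stacked: tuple[torch.Tensor, ...]
    adapter_enabled: torch.Tensor
    max_loras: int
    # routing / sharding info
    top_k: int
    w13_num_slices: int
    fully_sharded: bool
    tp_rank: int
    tp_size: int
    local_num_experts: int
    punica_wrapper: object  # the active PunicaWrapper instance
    use_tuned_config: bool  # whether VLLM_TUNED_CONFIG_FOLDER is set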

FusedMoEWithLoRA.set_mapping builds this context once and stashes it on the base layer as FusedMoE._lora_context. MoERunnerBase forwards it into FusedMoEMethodBase.apply(..., lora_context=...), and FusedMoEModularMethod.apply / FusedMoEKernel.apply / FusedMoEKernelModularImpl.apply thread it through to FusedMoEExpertsModular.apply(..., lora_context=...).

The context is the only LoRA surface area seen by the MoE code path — there is no more hidden state passed between method wrappers.

2. LoRA compute inlined into the expert apply()

Expert implementations that support LoRA now call it directly inside their own apply() function, at the same logical point the decorators used to target (see the sketch after this list):

  • TritonExperts.apply (fused_moe.py): after the w13 GEMM and before
    activation, call self.apply_w13_lora(...) to add the LoRA delta to
    intermediate_cache1. After the w2 GEMM and before moe_sum, call
    self.apply_w2_lora(...) on intermediate_cache3, reusing the
    sorted_token_ids_lora tensors from the first call.
  • UnfusedOAITritonExperts.apply (gpt_oss_triton_kernels_moe.py):
    same pattern, adjusted for the gather/scatter layout that its two
    matmul_ogs calls produce.
  • MarlinExperts.apply (fused_marlin_moe.py): fused_marlin_moe
    consumes activation_func and moe_sum as callables, so the LoRA path
    wraps those two callables to inject apply_w13_lora / apply_w2_lora at
    the correct buffer state.
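
A self-contained sketch of the new control flow, using the TritonExperts case (the names apply_w13_lora, apply_w2_lora, and supports_lora come from this PR; shapes, the ctx layout, and the helper bodies are illustrative stand-ins, not vLLM's real API):

import torch
import torch.nn.functional as F

class TritonExpertsSketch:
    @staticmethod
    def supports_lora() -> bool:
        return True

    def apply(self, x, w13, w2, lora_context=None):
        cache1 = x @ w13                                    # w13 GEMM
        if lora_context is not None:
            cache1 += self.apply_w13_lora(lora_context, x)  # visible call site
        act = F.silu(cache1)                                # activation
        cache3 = act @ w2                                   # w2 GEMM
        if lora_context is not None:
            # the real code also reuses sorted_token_ids_lora from the
            # w13 call; omitted here for brevity
            cache3 += self.apply_w2_lora(lora_context, act)
        return cache3.sum(dim=0)                            # stand-in moe_sum

    def apply_w13_lora(self, ctx, x):
        # LoRA shrink (A) then expand (B): x @ A @ B
        return (x @ ctx["w13_a"]) @ ctx["w13_b"]

    def apply_w2_lora(self, ctx, act):
        return (act @ ctx["w2_a"]) @ ctx["w2_b"]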

FusedMoEExperts.supports_lora() defaults to False. Each expert impl that has a validated LoRA path overrides it to True (TritonExperts, UnfusedOAITritonExperts, MarlinExperts). FusedMoEWithLoRA.__init__ asserts that the selected expert impl reports supports_lora(), and oracle/unquantized.py::select_unquantized_moe_backend now filters the backend auto-selection by that flag, so unsupported backends (FlashInfer / AITER) are transparently skipped when LoRA is enabled instead of silently producing wrong output or crashing later.
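
A sketch of the selection guard (the function name comes from this PR; the candidate list and its structure are assumptions):

def select_unquantized_moe_backend(candidates, lora_enabled: bool):
    for backend in candidates:
        if lora_enabled and not backend.supports_lora():
            continue  # e.g. FlashInfer / AITER are skipped transparently
        return backend
    raise RuntimeError("no unquantized MoE backend supports this config")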

Because the LoRA shrink/expand is now visible in the expert source, anyone modifying TritonExperts.apply can see the LoRA call site and keep it correct; tests on the MoE path automatically cover the LoRA path as well.

3. LoRA computation stays in PunicaWrapper

MoE LoRA still respects the PunicaWrapper abstraction; the actual shrink/expand compute is not moved. Two new methods on PunicaWrapperBase — add_lora_w13 and add_lora_w2 — encapsulate config lookup (tuned vs. heuristic), moe_lora_align_block_size, and the add_lora_fused_moe call. PunicaWrapperGPU provides the concrete implementation. FusedMoEExpertsModular has thin helpers apply_w13_lora / apply_w2_lora that just forward the context fields to these methods.
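
A minimal sketch of this division of labor (only the names add_lora_w13, moe_lora_align_block_size, and add_lora_fused_moe come from this PR; the bodies are trivial stand-ins, not the real kernels):

import torch

class PunicaWrapperSketch:
    def add_lora_w13(self, y, x, a_stacked, b_stacked, topk_ids,
                     use_tuned_config: bool):
        # tuned vs. heuristic kernel config
        config = {"block_m": 64} if use_tuned_config else {"block_m": 16}
        sort_meta = self.moe_lora_align_block_size(topk_ids, config)
        self.add_lora_fused_moe(y, x, a_stacked, b_stacked, sort_meta, config)
        return sort_meta  # reused by the later add_lora_w2 call

    def moe_lora_align_block_size(self, topk_ids, config):
        # stand-in: the real helper sorts/pads token ids per expert block
        return torch.argsort(topk_ids.flatten())

    def add_lora_fused_moe(self, y, x, a, b, sort_meta, config):
        y += (x @ a) @ b  # stand-in for the fused shrink/expand kernel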

Test Plan

All LoRA and MoE tests on CI should pass.

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>


@jeejeelee jeejeelee marked this pull request as draft April 20, 2026 09:05
@mergify (Bot) commented Apr 20, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @jeejeelee.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@gemini-code-assist (Bot) left a comment

Code Review

This pull request refactors the LoRA implementation for Fused MoE layers by replacing the legacy decorator-based monkey-patching with an explicit MoELoRAContext propagation through the modular kernel path. It updates various quantization methods and MoE backends to handle this context. Feedback highlights potential AttributeError issues: one regarding the access of fused_experts on FusedMoEKernel, and another concerning the missing block_shape property in the FusedMoEExperts base class and its subclasses like MarlinExperts, which would cause runtime crashes when LoRA is enabled.

Comment thread on vllm/lora/layers/fused_moe.py (outdated):

-    fused_experts.moe_sum = moe_sum_decorator(
-        self.base_layer, fused_experts.moe_sum
+    assert (
+        isinstance(moe_kernel.fused_experts, FusedMoEExpertsModular)
gemini-code-assist (Bot) commented, severity: high

The check isinstance(moe_kernel.fused_experts, FusedMoEExpertsModular) might fail because FusedMoEKernel typically wraps the implementation in an impl attribute. You should likely check moe_kernel.impl.fused_experts instead, or ensure that FusedMoEKernel exposes fused_experts as a property.

Suggested change:
-    isinstance(moe_kernel.fused_experts, FusedMoEExpertsModular)
+    isinstance(moe_kernel.impl.fused_experts, FusedMoEExpertsModular)

@bnellnm (Collaborator) commented Apr 23, 2026

You could use the is_monolithic property instead of isinstance.

You might also need to add a supports_lora method to FusedMoEKernel.

Comment on lines +962 to +1001
def apply_w13_lora(
    self,
    lora_context: "MoELoRAContext",
    *,
    y: torch.Tensor,
    x: torch.Tensor,
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    expert_map: torch.Tensor | None,
    w1: torch.Tensor,
    w2: torch.Tensor,
    num_tokens: int,
    top_k_num: int,
) -> tuple[
    torch.Tensor | None,
    torch.Tensor | None,
    torch.Tensor | None,
    torch.Tensor | None,
]:
    return lora_context.punica_wrapper.add_lora_w13(
        y,
        x,
        lora_context.w13_lora_a_stacked,
        lora_context.w13_lora_b_stacked,
        topk_ids,
        topk_weights,
        expert_map,
        w1,
        w2,
        num_tokens,
        top_k_num,
        lora_context.max_loras,
        lora_context.adapter_enabled,
        lora_context.local_num_experts,
        lora_context.top_k,
        lora_context.w13_num_slices,
        lora_context.fully_sharded,
        lora_context.use_tuned_config,
        block_shape=self.block_shape,
    )
gemini-code-assist (Bot) commented, severity: high

The apply_w13_lora method (and apply_w2_lora below) calls self.block_shape, but block_shape is not defined as an abstract property in the FusedMoEExperts base class. This will cause an AttributeError at runtime for any expert implementation that does not explicitly define it (e.g., MarlinExperts or UnfusedOAITritonExperts). Please add block_shape as an abstract property to FusedMoEExperts and implement it in all subclasses that support LoRA.
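
For illustration, a minimal sketch of the suggested fix (the return type is an assumption, inferred from block_shape being forwarded into the Punica call above):

from abc import ABC, abstractmethod

class FusedMoEExperts(ABC):
    @property
    @abstractmethod
    def block_shape(self) -> list[int] | None: ...

class MarlinExperts(FusedMoEExperts):
    @property
    def block_shape(self) -> list[int] | None:
        return None  # e.g. no block-quantization shape for this backend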

Comment on lines +662 to +663
def supports_lora() -> bool:
    return True
gemini-code-assist (Bot) commented, severity: high

MarlinExperts claims to support LoRA but does not implement the block_shape property required by apply_w13_lora and apply_w2_lora in the base class. This will lead to a crash when LoRA is enabled with Marlin quantization.

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
@mergify mergify Bot removed the needs-rebase label Apr 20, 2026
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
@mergify mergify Bot added the nvidia, rocm (Related to AMD ROCm), and intel-gpu (Related to Intel GPU) labels Apr 20, 2026
@github-project-automation github-project-automation Bot moved this to Todo in AMD Apr 20, 2026
@jeejeelee jeejeelee marked this pull request as ready for review April 20, 2026 12:00
@jeejeelee jeejeelee added the ready label (ONLY add when PR is ready to merge/full CI is needed) Apr 20, 2026
Comment on lines +58 to 63
moe_kernel = FusedMoEKernel(
    prepare_finalize,
    self.base_layer.quant_method.select_gemm_impl(
        prepare_finalize, self.base_layer
    ),
)
Collaborator

Do we know if this case is ever hit now? Most methods have been switched over to the new MK (modular kernel) initialization pattern (_setup_kernel).

        index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2]
    ].copy_(sliced_w2_lora_b, non_blocking=True)

    def set_mapping(self, punica_wrapper):
Collaborator

Does this happen at runtime or is this part of the LoRA setup?

Collaborator (Author)

This is LoRA setup, not runtime. MoELoRAContext captures references to it, so the experts kernel sees fresh values without rebinding.

@bnellnm (Collaborator) commented Apr 24, 2026

It's probably too much for this PR but we could consider having separate subclasses for experts that support LoRA (so that the LoRA code could be completely isolated) and the setup in FusedMoEWithLoRA could construct the proper LoRA MK instead of rewriting or hijacking the existing MK.

@bnellnm (Collaborator) left a comment

I think at least one of the gemini comments (the one about isinstance) is relevant and should be fixed. Otherwise, I think it looks pretty good.

@github-project-automation github-project-automation Bot moved this to In review in NVIDIA Apr 24, 2026
@github-project-automation github-project-automation Bot moved this from To Triage to In progress in gpt-oss Issues & Enhancements Apr 24, 2026
@jeejeelee (Collaborator, Author) commented

It's probably too much for this PR but we could consider having separate subclasses for experts that support LoRA (so that the LoRA code could be completely isolated) and the setup in FusedMoEWithLoRA could construct the proper LoRA MK instead of rewriting or hijacking the existing MK.

Good point, I'll look into it further. Thanks!

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
@jeejeelee jeejeelee requested a review from bnellnm April 24, 2026 13:00
@bnellnm (Collaborator) left a comment

LGTM. Thanks for the good work! Btw, do you know if the select_gemm_impl codepath ever gets triggered? It should largely be defunct now and will be removed once everything is migrated over to _setup_kernel.

@jeejeelee (Collaborator, Author) commented

LGTM. Thanks for the good work! Btw, do you know if the select_gemm_impl codepath ever gets triggered? It should largely be defunct now and will be removed once everything is migrated over to _setup_kernel.

First of all, thank you for the thorough review.

Yes, it looks like it still gets triggered — specifically by the unmigrated quant methods (AWQ-Marlin, compressed_tensors_moe).

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
@github-project-automation github-project-automation Bot moved this from In progress to Ready in gpt-oss Issues & Enhancements Apr 25, 2026
@github-project-automation github-project-automation Bot moved this from In review to Ready in NVIDIA Apr 25, 2026
@simon-mo simon-mo enabled auto-merge (squash) April 25, 2026 17:15
@bnellnm (Collaborator) commented Apr 25, 2026

FYI, my guess is that #40794 is causing the LoRA failure. There was a similar issue when the truncate came before the reduce in a prior PR that was fixed by moving the trunc afterwards. I'm not sure what the best solution is here. Calling .contiguous() on the result of the truncation should "fix" the problem but feels like a bandaid.

@simon-mo simon-mo merged commit 8cd174f into main Apr 26, 2026
74 of 75 checks passed
@simon-mo simon-mo deleted the moe-lora-refactor branch April 26, 2026 01:55
@github-project-automation github-project-automation Bot moved this from Ready to Done in NVIDIA Apr 26, 2026
@github-project-automation github-project-automation Bot moved this from Todo to Done in AMD Apr 26, 2026
avinashsingh77 pushed a commit to avinashsingh77/vllm that referenced this pull request Apr 27, 2026
Signed-off-by: Avinash Singh <avinashsingh.rcoem@gmail.com>
jatseng-ai pushed a commit to jatseng-ai/vllm that referenced this pull request Apr 28, 2026
Lafunamor pushed a commit to Lafunamor/vllm that referenced this pull request May 1, 2026
Signed-off-by: Adrian <info@zzit.ch>

Labels

gpt-oss (Related to GPT-OSS models), intel-gpu (Related to Intel GPU), nvidia, ready (ONLY add when PR is ready to merge/full CI is needed), rocm (Related to AMD ROCm)
