
[LoRA][II] Add fused MOE LoRA Triton kernel and tests#19711

Merged
Fridge003 merged 6 commits into sgl-project:main from yushengsu-thu:moe-lora-triton-kernel
Mar 19, 2026
Conversation

@yushengsu-thu (Collaborator) commented Mar 2, 2026

This PR splits #14105 into three parts; this is Part II.

Add a Triton-based fused MoE LoRA kernel for combined expert routing and LoRA computation:

  • fused_moe_lora_kernel.py: Triton kernels for fused_moe_lora_shrink and fused_moe_lora_expand
  • test_fused_moe_lora_kernel.py: Unit tests validating kernel correctness against a PyTorch reference
  • Update triton_ops/__init__.py to export fused_moe_lora
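For readers unfamiliar with the fused path, a per-token PyTorch sketch of what the shrink/expand pair computes might look like the following. This is only an illustration of the math the kernels fuse; the shapes, argument names, and signature are assumptions for the example, not the PR's actual API:

```python
import torch


def moe_lora_reference(x, lora_a, lora_b, topk_ids, topk_weights,
                       token_lora_mapping, mul_routed_weight=False):
    """Per-token reference for a fused MoE LoRA path (illustrative shapes):
      x:                  (num_tokens, hidden)
      lora_a:             (num_loras, num_experts, rank, hidden)
      lora_b:             (num_loras, num_experts, out_dim, rank)
      topk_ids:           (num_tokens, top_k) routed expert ids per token
      topk_weights:       (num_tokens, top_k) routing weights
      token_lora_mapping: (num_tokens,) LoRA adapter id per token
    Returns a (num_tokens, top_k, out_dim) tensor.
    """
    num_tokens, top_k = topk_ids.shape
    out_dim = lora_b.shape[2]
    out = torch.zeros(num_tokens, top_k, out_dim, dtype=x.dtype)
    for i in range(num_tokens):
        l = int(token_lora_mapping[i])        # index as a plain Python int
        for j in range(top_k):
            e = int(topk_ids[i, j])
            shrink = lora_a[l, e] @ x[i]      # (rank,): the "shrink" GEMM
            expand = lora_b[l, e] @ shrink    # (out_dim,): the "expand" GEMM
            if mul_routed_weight:
                expand = expand * topk_weights[i, j]
            out[i, j] = expand
    return out


# Tiny smoke example with all-ones weights: shrink sums hidden (=4),
# expand sums rank (=3), so every output element is 4 * 3 = 12.
x = torch.ones(2, 4)
lora_a = torch.ones(1, 2, 3, 4)   # (loras, experts, rank, hidden)
lora_b = torch.ones(1, 2, 5, 3)   # (loras, experts, out_dim, rank)
topk_ids = torch.zeros(2, 2, dtype=torch.int32)
topk_weights = torch.full((2, 2), 0.5)
mapping = torch.zeros(2, dtype=torch.int32)
out = moe_lora_reference(x, lora_a, lora_b, topk_ids, topk_weights, mapping)
```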


Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

Made-with: Cursor
Copilot AI review requested due to automatic review settings March 2, 2026 18:58

@github-actions bot added the lora label Mar 2, 2026
@yushengsu-thu changed the title from [Lora] Add fused MOE LoRA Triton kernel and tests to [Lora][II] Add fused MOE LoRA Triton kernel and tests Mar 2, 2026
@yushengsu-thu changed the title from [Lora][II] Add fused MOE LoRA Triton kernel and tests to [LoRA][II] Add fused MOE LoRA Triton kernel and tests Mar 2, 2026
@yushengsu-thu (Collaborator, Author)

Co-authored-by: Jonah Bernard <jb2528@cornell.edu>
Co-authored-by: cursor[bot] <noreply@cursor.sh>

Copilot AI (Contributor) left a comment

Pull request overview

Adds a Triton-based fused MoE LoRA kernel to SGLang (combining expert routing-aligned token ordering with LoRA A/B computations) and introduces a CUDA unit test to validate kernel correctness against a PyTorch reference implementation.

Changes:

  • Added fused_moe_lora_kernel.py implementing fused shrink/expand Triton kernels and registering them as SGLang custom ops.
  • Added test_fused_moe_lora_kernel.py to validate fused kernel outputs vs a PyTorch reference.
  • Exported fused_moe_lora via python/sglang/srt/lora/triton_ops/__init__.py.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| test/registered/lora/test_fused_moe_lora_kernel.py | New CUDA unit test comparing fused kernel output to a PyTorch reference. |
| python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py | New Triton fused MoE LoRA implementation (shrink/expand + custom-op registration). |
| python/sglang/srt/lora/triton_ops/__init__.py | Exports fused_moe_lora from the new kernel module. |


Comment on lines +235 to +236

```python
lora_idx = token_lora_mapping[i]
expert_ids = topk_ids[i]
```

Copilot AI Mar 2, 2026

use_torch indexes LoRA tensors using torch.int32 tensors (token_lora_mapping / topk_ids). PyTorch advanced indexing requires integer indices to be torch.long (int64) or Python ints, so this reference path is likely to raise an indexing error on CUDA and make the test fail. Convert the indices with .long() (or use .item() for scalars) before indexing so the reference computation runs reliably.

Suggested change:

```diff
- lora_idx = token_lora_mapping[i]
- expert_ids = topk_ids[i]
+ lora_idx = token_lora_mapping[i].long()
+ expert_ids = topk_ids[i].long()
```
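As a minimal illustration of the portable indexing pattern the comment recommends, converting index tensors to long before advanced indexing (tensor names here are invented for the example):

```python
import torch

# int32 index tensors are not accepted by advanced indexing on all PyTorch
# versions, so .long() is the portable choice before fancy indexing.
token_lora_mapping = torch.tensor([1, 0, 1], dtype=torch.int32)
lora_scales = torch.tensor([0.5, 2.0])
scales = lora_scales[token_lora_mapping.long()]   # gathers rows -> shape (3,)
```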
Comment on lines +16 to +37
_LORA_PTR_DICT: dict[tuple[int, ...], torch.Tensor] = {}


def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device):
"""
`_LORA_PTR_DICT` collects the required information during `profile_run`,
After this, it remains constant and subsequent usage is through LUT.
Refer to:
https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py
"""
key = tuple(lora_weight.data_ptr() for lora_weight in lora_weights)

if (ptr_tensor := _LORA_PTR_DICT.get(key)) is not None:
return ptr_tensor

tensor_ptrs = []
for lora_weight in lora_weights:
tensor_ptrs.append(lora_weight.data_ptr())
ptr_tensor = torch.tensor(tensor_ptrs, device=device, dtype=torch.uint64)

_LORA_PTR_DICT[key] = ptr_tensor
return _LORA_PTR_DICT.get(key)
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_LORA_PTR_DICT is a process-global cache that never evicts entries. If LoRA weights are created/destroyed dynamically (e.g., loading many adapters over time), this dict will grow without bound and retain device memory for the pointer tensors. Consider bounding/clearing this cache (e.g., LRU by size, or clearing when adapters are unloaded) or keying by a stable adapter ID rather than raw data_ptr() values.

Copilot uses AI. Check for mistakes.
Contributor

Is there logic to skip loras that are not currently being used?

Comment on lines +295 to +296

```python
top_k=1 if mul_routed_weight else top_k_num,
MUL_ROUTED_WEIGHT=False,
```

Copilot AI Mar 2, 2026

In _fused_moe_lora_shrink, top_k is set to 1 when mul_routed_weight=True, but the shrink kernel's A operand (qcurr_hidden_states) is shaped by token (M) and relies on offs_token // top_k_num to map routed-token indices back to token indices. With top_k=1, the kernel will index qcurr_hidden_states out of bounds when mul_routed_weight=True, producing incorrect results. Keep top_k equal to top_k_num for shrink regardless of weighting, and control weighting solely via MUL_ROUTED_WEIGHT.

Suggested change:

```diff
- top_k=1 if mul_routed_weight else top_k_num,
- MUL_ROUTED_WEIGHT=False,
+ top_k=top_k_num,
+ MUL_ROUTED_WEIGHT=mul_routed_weight,
```
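The index arithmetic behind this comment can be seen in plain Python (variable names are illustrative, not the kernel's code): dividing flattened routed-token ids by top_k recovers token indices, but only if the divisor matches the actual top_k used to flatten them.

```python
# Flattened routed-token ids cover num_tokens * top_k entries; integer
# division by top_k maps each back to its source token index.
top_k = 2
num_tokens = 3
routed = list(range(num_tokens * top_k))   # flattened (token, slot) ids
token_of = [r // top_k for r in routed]    # valid token indices

# If top_k were forced to 1 while routed ids still span num_tokens * top_k,
# the division no longer shrinks the range, so indices exceed num_tokens - 1.
bad_token_of = [r // 1 for r in routed]
```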
Comment on lines +357 to +367
"SPLIT_K": split_k, # Set split_k = 1 for expand calls
"USE_GDC": use_gdc,
"launch_pdl": use_gdc, # triton kernel metadata
}

grid = lambda META: (
triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
len(lora_b_stacked),
lora_b_stacked[0].shape[0],
)
_fused_moe_lora_kernel[grid](
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_fused_moe_lora_expand accepts split_k, but the launch grid does not multiply by split_k (unlike shrink). If split_k > 1, pid_sk will never cover all partitions and the kernel will compute only a fraction of K, silently producing wrong outputs. Either enforce split_k == 1 for expand (assert/validation) or include split_k in the grid and rely on atomic accumulation as in shrink.

Copilot uses AI. Check for mistakes.
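A minimal sketch (ordinary Python, not Triton) of the second option the comment describes: folding split_k into the first grid axis so every K partition gets a program id. All names and block sizes here are illustrative assumptions:

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division, as triton.cdiv computes."""
    return (a + b - 1) // b


def expand_grid(em, n, num_slices, max_loras, split_k, block_m, block_n):
    """Launch grid with split_k folded into axis 0, mirroring the shrink
    launch, so pid_sk can cover every K partition (illustrative sketch)."""
    return (
        cdiv(em, block_m) * cdiv(n, block_n) * split_k,
        num_slices,
        max_loras,
    )


# With split_k=2 the first axis doubles: 4 M-tiles * 4 N-tiles * 2 = 32.
grid = expand_grid(64, 128, num_slices=2, max_loras=4,
                   split_k=2, block_m=16, block_n=32)
```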
@yushengsu-thu (Collaborator, Author)

/tag-and-rerun-ci

Resolved conflict in triton_ops/__init__.py by keeping both the fused_moe_lora and chunked_embedding_lora_a_forward exports.

Made-with: Cursor
@sshleifer (Contributor) left a comment

LGTM, we can try to speed it up in followup

@Fridge003 Fridge003 merged commit 7f6f1a3 into sgl-project:main Mar 19, 2026
139 of 149 checks passed
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026
0-693 pushed a commit to 0-693/sglang that referenced this pull request Mar 25, 2026
dutsc pushed a commit to dutsc/sglang that referenced this pull request Mar 30, 2026