[LoRA][II] Add fused MOE LoRA Triton kernel and tests #19711
Fridge003 merged 6 commits into sgl-project:main from
Conversation
Add Triton-based fused MoE LoRA kernel for combined expert routing and LoRA computation:
- `fused_moe_lora_kernel.py`: Triton kernels for fused_moe_lora_shrink and fused_moe_lora_expand
- `test_fused_moe_lora_kernel.py`: Unit tests validating kernel correctness against a PyTorch reference
- Update `triton_ops/__init__.py` to export `fused_moe_lora`

Made-with: Cursor
Warning: You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!
Co-authored-by: Jonah Bernard <jb2528@cornell.edu>
Pull request overview
Adds a Triton-based fused MoE LoRA kernel to SGLang (combining expert routing-aligned token ordering with LoRA A/B computations) and introduces a CUDA unit test to validate kernel correctness against a PyTorch reference implementation.
Changes:
- Added `fused_moe_lora_kernel.py` implementing fused shrink/expand Triton kernels and registering them as SGLang custom ops.
- Added `test_fused_moe_lora_kernel.py` to validate fused kernel outputs against a PyTorch reference.
- Exported `fused_moe_lora` via `python/sglang/srt/lora/triton_ops/__init__.py`.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| test/registered/lora/test_fused_moe_lora_kernel.py | New CUDA unit test comparing fused kernel output to a PyTorch reference. |
| python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py | New Triton fused MoE LoRA implementation (shrink/expand + custom-op registration). |
| python/sglang/srt/lora/triton_ops/__init__.py | Exports fused_moe_lora from the new kernel module. |
```python
lora_idx = token_lora_mapping[i]
expert_ids = topk_ids[i]
```
`use_torch` indexes LoRA tensors using `torch.int32` tensors (`token_lora_mapping` / `topk_ids`). PyTorch advanced indexing requires integer indices to be `torch.long` (int64) or Python ints, so this reference path is likely to raise an indexing error on CUDA and make the test fail. Convert the indices with `.long()` (or use `.item()` for scalars) before indexing so the reference computation runs reliably.
Suggested change:

```python
lora_idx = token_lora_mapping[i].long()
expert_ids = topk_ids[i].long()
```
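A minimal CPU-side sketch of the reference-path indexing this comment describes (tensor shapes and names are illustrative, not the PR's actual test code):

```python
import torch

# Illustrative routing metadata, stored as int32 as in the kernel test harness.
token_lora_mapping = torch.tensor([0, 1, 0], dtype=torch.int32)  # per-token adapter id
lora_a_stacked = torch.randn(2, 4, 8)  # [num_loras, rank, hidden]

rows = []
for i in range(token_lora_mapping.shape[0]):
    # Cast the int32 scalar index to int64 before advanced indexing.
    lora_idx = token_lora_mapping[i].long()
    rows.append(lora_a_stacked[lora_idx])

stacked = torch.stack(rows)  # one [rank, hidden] slice per token
```

The cast is cheap (a 0-d scalar), so doing it per token inside the reference loop does not meaningfully change the test's runtime.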
```python
_LORA_PTR_DICT: dict[tuple[int, ...], torch.Tensor] = {}


def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device):
    """
    `_LORA_PTR_DICT` collects the required information during `profile_run`.
    After this, it remains constant and subsequent usage is through the LUT.
    Refer to:
    https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py
    """
    key = tuple(lora_weight.data_ptr() for lora_weight in lora_weights)

    if (ptr_tensor := _LORA_PTR_DICT.get(key)) is not None:
        return ptr_tensor

    tensor_ptrs = []
    for lora_weight in lora_weights:
        tensor_ptrs.append(lora_weight.data_ptr())
    ptr_tensor = torch.tensor(tensor_ptrs, device=device, dtype=torch.uint64)

    _LORA_PTR_DICT[key] = ptr_tensor
    return _LORA_PTR_DICT.get(key)
```
_LORA_PTR_DICT is a process-global cache that never evicts entries. If LoRA weights are created/destroyed dynamically (e.g., loading many adapters over time), this dict will grow without bound and retain device memory for the pointer tensors. Consider bounding/clearing this cache (e.g., LRU by size, or clearing when adapters are unloaded) or keying by a stable adapter ID rather than raw data_ptr() values.
Is there logic to skip loras that are not currently being used?
```python
top_k=1 if mul_routed_weight else top_k_num,
MUL_ROUTED_WEIGHT=False,
```
In _fused_moe_lora_shrink, top_k is set to 1 when mul_routed_weight=True, but the shrink kernel’s A operand (qcurr_hidden_states) is shaped by token (M) and relies on offs_token // top_k_num to map routed-token indices back to token indices. With top_k=1, the kernel will index qcurr_hidden_states out of bounds when mul_routed_weight=True, producing incorrect results. Keep top_k equal to top_k_num for shrink regardless of weighting, and control weighting solely via MUL_ROUTED_WEIGHT.
Suggested change:

```python
top_k=top_k_num,
MUL_ROUTED_WEIGHT=mul_routed_weight,
```
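A host-side illustration of the index math at stake (pure Python; the names are hypothetical stand-ins for the kernel's indices): with `top_k_num` routed copies per token, `routed_idx // top_k_num` recovers the source token, whereas dividing by 1 walks past the token count.

```python
num_tokens = 4
top_k_num = 2
num_routed = num_tokens * top_k_num  # each token appears once per expert choice

# Correct mapping: routed-token index back to its source token (stays in bounds).
token_of = [routed // top_k_num for routed in range(num_routed)]
assert max(token_of) == num_tokens - 1

# With top_k forced to 1 (the reported bug), the mapping exceeds num_tokens,
# which in the kernel becomes an out-of-bounds read of qcurr_hidden_states.
bad_token_of = [routed // 1 for routed in range(num_routed)]
assert max(bad_token_of) == num_routed - 1  # 7, but only 4 tokens exist
```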
| "SPLIT_K": split_k, # Set split_k = 1 for expand calls | ||
| "USE_GDC": use_gdc, | ||
| "launch_pdl": use_gdc, # triton kernel metadata | ||
| } | ||
|
|
||
| grid = lambda META: ( | ||
| triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), | ||
| len(lora_b_stacked), | ||
| lora_b_stacked[0].shape[0], | ||
| ) | ||
| _fused_moe_lora_kernel[grid]( |
_fused_moe_lora_expand accepts split_k, but the launch grid does not multiply by split_k (unlike shrink). If split_k > 1, pid_sk will never cover all partitions and the kernel will compute only a fraction of K, silently producing wrong outputs. Either enforce split_k == 1 for expand (assert/validation) or include split_k in the grid and rely on atomic accumulation as in shrink.
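One hedged way to implement the first option (sketch only; `expand_grid` and its parameters are hypothetical stand-ins for the PR's launch code, with ceiling division inlined so the sketch does not depend on Triton):

```python
def cdiv(a: int, b: int) -> int:
    # Same ceiling division that triton.cdiv performs.
    return (a + b - 1) // b


def expand_grid(EM, N, num_slices, num_loras, block_m, block_n, split_k):
    # Expand has no atomic accumulation over K partitions, so reject
    # split_k > 1 instead of silently computing only a fraction of K.
    assert split_k == 1, "fused_moe_lora expand requires split_k == 1"
    return (cdiv(EM, block_m) * cdiv(N, block_n), num_slices, num_loras)
```

The alternative (multiplying the first grid axis by `split_k` and accumulating with atomics, as shrink does) trades extra launches and atomic traffic for parallelism over K, which is rarely needed for LoRA's small ranks.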
/tag-and-rerun-ci |
Resolved conflict in triton_ops/__init__.py by keeping both fused_moe_lora and chunked_embedding_lora_a_forward exports. Made-with: Cursor
sshleifer
left a comment
LGTM, we can try to speed it up in a follow-up.
Split PR #14105 into 3 parts; this is Part II.
Add Triton-based fused MoE LoRA kernel for combined expert routing and LoRA computation:
Motivation
Modifications
Accuracy Tests
Benchmarking and Profiling
Checklist
Review Process
/tag-run-ci-label,/rerun-failed-ci,/tag-and-rerun-ci