diff --git a/python/sglang/srt/layers/moe/moe_runner/runner.py b/python/sglang/srt/layers/moe/moe_runner/runner.py index ee580e580586..758bf2e06a74 100644 --- a/python/sglang/srt/layers/moe/moe_runner/runner.py +++ b/python/sglang/srt/layers/moe/moe_runner/runner.py @@ -24,15 +24,25 @@ class MoeRunner: - - def __init__(self, runner_backend: MoeRunnerBackend, config: MoeRunnerConfig): + def __init__( + self, + runner_backend: MoeRunnerBackend, + config: MoeRunnerConfig, + lora_enabled: bool = False, + ): self.runner_backend = runner_backend self.config = config + self.lora_enabled = lora_enabled self.fused_func = None if runner_backend.is_triton(): - self.runner_core = TritonRunnerCore(config) + if lora_enabled: + from sglang.srt.lora.lora_moe_runners import TritonRunnerCoreWithLoRA + + self.runner_core = TritonRunnerCoreWithLoRA(config) + else: + self.runner_core = TritonRunnerCore(config) elif runner_backend.is_triton_kernels(): self.runner_core = TritonKernelsRunnerCore(config) elif runner_backend.is_deep_gemm(): @@ -47,20 +57,22 @@ def __init__(self, runner_backend: MoeRunnerBackend, config: MoeRunnerConfig): else: raise NotImplementedError(f"Unsupported runner backend: {runner_backend}") - a2a_backend_name = get_moe_a2a_backend().value - runner_backend_name = runner_backend.value + # Skip fused func if LoRA is enabled (LoRA requires non-fused path) + if not lora_enabled: + a2a_backend_name = get_moe_a2a_backend().value + runner_backend_name = runner_backend.value - # TODO(cwan): add a server argument to disable fused func - self.fused_func = FusedOpPool.get_fused_func( - a2a_backend_name, runner_backend_name - ) - - if self.runner_core is None and self.fused_func is None: - raise NotImplementedError( - f"Runner backend {runner_backend} requires a fused func for a2a backend " - f"{a2a_backend_name}, but none is registered." + # TODO(cwan): add a server argument to disable fused func + self.fused_func = FusedOpPool.get_fused_func( + a2a_backend_name, runner_backend_name ) + if self.runner_core is None and self.fused_func is None: + raise NotImplementedError( + f"Runner backend {runner_backend} requires a fused func for a2a backend " + f"{a2a_backend_name}, but none is registered." + ) + self.down_gemm_overlap_args: Optional[DownGemmOverlapArgs] = None self.meta_overlap_args: Optional[dict] = None @@ -74,10 +86,9 @@ def __init__(self, runner_backend: MoeRunnerBackend, config: MoeRunnerConfig): self.fused_func = None def run( - self, dispatch_output: DispatchOutput, quant_info: MoeQuantInfo + self, dispatch_output: DispatchOutput, quant_info: MoeQuantInfo, lora_info=None ) -> CombineInput: - - if self.fused_func is not None: + if self.fused_func is not None and not self.lora_enabled: return self.fused_func(dispatch_output, quant_info, self.config) assert self.runner_core is not None @@ -96,7 +107,16 @@ def run( runner_input = self.pre_permute_func( dispatch_output, quant_info, self.config, running_state ) - runner_output = self.runner_core.run(runner_input, quant_info, running_state) + + # Pass lora_info to runner_core if LoRA is enabled + if self.lora_enabled: + runner_output = self.runner_core.run( + runner_input, quant_info, running_state, lora_info + ) + else: + runner_output = self.runner_core.run( + runner_input, quant_info, running_state + ) runner_format = self.runner_core.runner_backend.value combine_format = dispatch_output.format.value diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index d483f2b35b56..21ad10447b46 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -16,6 +16,8 @@ QKVParallelLinear, RowParallelLinear, ) +from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE +from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, @@ -689,11 +691,199 @@ def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): return B +class FusedMoEWithLoRA(BaseLayerWithLoRA): + """ + Wrapper around FusedMoE that integrates LoRA into the MoE computation. + + Design: LoRA deltas are added at specific points in the MoE forward pass: + 1. After gate_up projection, BEFORE activation (halfway through) + 2. After down projection, BEFORE final reduction + + This follows the vLLM/HF approach where LoRA is fused into the computation + rather than computed independently and added at the end. + """ + + def __init__( + self, + base_layer: FusedMoE, + lora_backend: BaseLoRABackend, + ): + # initializes FusedMoE with its own moe_runner for base path + super().__init__(base_layer, lora_backend) + + self.tp_size = getattr(base_layer, "moe_tp_size", 1) + self.tp_rank = getattr(base_layer, "moe_tp_rank", 0) + self.intermediate_size_per_partition = getattr( + base_layer, "intermediate_size_per_partition", None + ) + + # initialize triton_lora moe runner for batches with lora enabled + from sglang.srt.layers.moe.moe_runner.runner import MoeRunner + from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo + + self._lora_runner = MoeRunner( + base_layer.quant_method.runner.runner_backend, + base_layer.moe_runner_config, + lora_enabled=True, + ) + + # Pre-compute quant info for efficiency (weights don't change during inference) + self._quant_info = TritonMoeQuantInfo( + w13_weight=base_layer.w13_weight, + w2_weight=base_layer.w2_weight, + b13=getattr(base_layer, "w13_weight_bias", None), + b2=getattr(base_layer, "w2_weight_bias", None), + ) + + def set_lora_info( + self, + gate_up_lora_a_weights: torch.Tensor, + gate_up_lora_b_weights: torch.Tensor, + down_lora_a_weights: torch.Tensor = None, + down_lora_b_weights: torch.Tensor = None, + ): + """Set LoRA weight tensors from memory pool.""" + self.set_lora = True + self.gate_up_lora_a_weights = gate_up_lora_a_weights + self.gate_up_lora_b_weights = gate_up_lora_b_weights + self.down_lora_a_weights = down_lora_a_weights + self.down_lora_b_weights = down_lora_b_weights + + def _get_lora_info(self): + """ + Build LoRAInfo for the current batch. + + Returns None if LoRA is not enabled or weights are not set. + """ + from sglang.srt.lora.lora_moe_runners import LoRAInfo + + # Get LoRA batch info from backend + batch_info = self.lora_backend.batch_info + lora_ranks = batch_info.lora_ranks # [num_loras] + + max_lora_rank = self.down_lora_a_weights.shape[2] + + # Create adapter_enabled tensor for the current batch + # Only enable LoRA adapters that are actually used in this batch + # TODO: Jonahbernard: check that this doesn't slow down inference for this batch + adapter_enabled = torch.zeros( + len(lora_ranks), dtype=torch.int32, device=lora_ranks.device + ) + adapter_enabled.index_fill_(0, batch_info.weight_indices.long(), 1) + + return LoRAInfo( + gate_up_lora_a_weights=self.gate_up_lora_a_weights, + gate_up_lora_b_weights=self.gate_up_lora_b_weights, + down_lora_a_weights=self.down_lora_a_weights, + down_lora_b_weights=self.down_lora_b_weights, + seg_indptr=batch_info.seg_indptr, + req_to_lora=batch_info.weight_indices, + lora_ranks=lora_ranks, + adapter_enabled=adapter_enabled, + max_lora_rank=max_lora_rank, + num_experts=self.base_layer.num_experts, + tp_size=self.tp_size, + tp_rank=self.tp_rank, + hidden_size=getattr(self.base_layer, "hidden_size", 0), + ) + + def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput, **kwargs): + """ + Forward pass with integrated LoRA computation. + + LoRA deltas are added at the correct points inside the MoE computation: + 1. After gate_up projection, before activation + 2. After down projection, before final reduction + """ + + # Build LoRA info for this batch + lora_info = self._get_lora_info() + + # run lora moe_runner + return self._forward_with_lora(hidden_states, topk_output, lora_info, **kwargs) + + def _forward_with_lora( + self, + hidden_states: torch.Tensor, + topk_output: TopKOutput, + lora_info, + **kwargs, + ): + """ + Run MoE forward with LoRA integration at the correct points. + """ + # Get the base layer's dispatch and combine logic + base_layer = self.base_layer + + # Dispatch tokens (doesn't do much in the LoRA case) + dispatch_output = base_layer.dispatcher.dispatch( + hidden_states=hidden_states, topk_output=topk_output + ) + + # Use pre-computed quant info (doesn't change so not sure why we need to pass it in every time) + quant_info = self._quant_info + + # Run the only lora moe runner (Triton) + combine_input = self._lora_runner.run( + dispatch_output, quant_info, lora_info=lora_info + ) + + final_hidden_states = base_layer.dispatcher.combine(combine_input=combine_input) + + return final_hidden_states + + def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): + return A + + def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): + return B + + def slice_moe_lora_a_weights( + self, A: torch.Tensor, tp_rank: int, target_module: str + ) -> torch.Tensor: + """Slice LoRA A weights for MoE with TP. + + Per-expert weight shapes: + gate_up_proj_moe A: [rank, hidden_size] — input is full hidden_states, no slice + down_proj_moe A: [rank, intermediate_size] — input is sharded intermediate + """ + if self.tp_size <= 1: + return A + if target_module == "down_proj_moe": + shard_size = self.intermediate_size_per_partition + start = tp_rank * shard_size + end = start + shard_size + return A[:, start:end].contiguous() + return A + + def slice_moe_lora_b_weights( + self, B: torch.Tensor, tp_rank: int, target_module: str + ) -> torch.Tensor: + """Slice LoRA B weights for MoE with TP. + + Per-expert weight shapes: + gate_up_proj_moe B: [intermediate_size*2, rank] — output matches sharded base w13 + down_proj_moe B: [hidden_size, rank] — output is all-reduced, no slice + """ + if self.tp_size <= 1: + return B + if target_module == "gate_up_proj_moe": + shard_size = self.intermediate_size_per_partition + start = tp_rank * shard_size + end = start + shard_size + full_inter = B.shape[0] // 2 + gate_b = B[start:end, :] + up_b = B[full_inter + start : full_inter + end, :] + return torch.cat([gate_b, up_b], dim=0).contiguous() + return B + + def get_lora_layer( layer: nn.Module, lora_backend: BaseLoRABackend ) -> BaseLayerWithLoRA: supported_layer_types = { # the order matters + FusedMoE: FusedMoEWithLoRA, ParallelLMHead: ParallelLMHeadWithLoRA, VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA, QKVParallelLinear: QKVParallelLinearWithLoRA, diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py index 61a8c3dd04d5..8ccb674f9195 100644 --- a/python/sglang/srt/lora/lora.py +++ b/python/sglang/srt/lora/lora.py @@ -46,7 +46,6 @@ def __init__(self, config: LoRAConfig, base_hf_config: AutoConfig): class LoRAAdapter(nn.Module): - def __init__( self, uid: str, diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 9774d6ff6670..73f6bc23544e 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -21,6 +21,7 @@ import torch from sglang.srt.configs.load_config import LoadConfig +from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.utils import get_layer_id from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, @@ -28,7 +29,7 @@ ) from sglang.srt.lora.backend.base_backend import BaseLoRABackend from sglang.srt.lora.backend.lora_registry import get_backend_from_name -from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer +from sglang.srt.lora.layers import BaseLayerWithLoRA, FusedMoEWithLoRA, get_lora_layer from sglang.srt.lora.lora import LoRAAdapter from sglang.srt.lora.lora_config import LoRAConfig from sglang.srt.lora.lora_registry import LoRARef @@ -297,9 +298,43 @@ def update_lora_info(self): """ for layer_id, layer_modules in enumerate(self.lora_modules): for module_name, module in layer_modules.items(): + # Hack for FusedMoE layer + if isinstance(module, FusedMoEWithLoRA) and all( + x in self.target_modules for x in ["gate_up_proj", "down_proj"] + ): + gate_up_a = self.memory_pool.get_tensor( + target_module="gate_up_proj_moe", + layer_id=layer_id, + lora_type=LoRAType.LORA_A, + ) + gate_up_b = self.memory_pool.get_tensor( + target_module="gate_up_proj_moe", + layer_id=layer_id, + lora_type=LoRAType.LORA_B, + ) + down_a = self.memory_pool.get_tensor( + target_module="down_proj_moe", + layer_id=layer_id, + lora_type=LoRAType.LORA_A, + ) + down_b = self.memory_pool.get_tensor( + target_module="down_proj_moe", + layer_id=layer_id, + lora_type=LoRAType.LORA_B, + ) + + module.set_lora_info( + gate_up_lora_a_weights=gate_up_a, + gate_up_lora_b_weights=gate_up_b, + down_lora_a_weights=down_a, + down_lora_b_weights=down_b, + ) + continue + target_module = get_target_module_name( module_name, self.memory_pool.target_modules ) + module.set_lora_info( self.memory_pool.get_tensor( target_module=target_module, @@ -350,6 +385,7 @@ def init_state( max_lora_rank=max_lora_rank, target_modules=target_modules, ) + self.init_lora_modules() self.init_memory_pool() self.update_lora_info() @@ -555,6 +591,7 @@ def init_memory_pool(self): self.fetch_new_loras({None}) def set_lora_module(self, module_name, module): + """Wrap any module (standard or MoE) with LoRA support.""" lora_module = get_lora_layer(module, self.lora_backend) replace_submodule(self.base_model, module_name, lora_module) return lora_module @@ -613,6 +650,7 @@ def init_lora_modules(self): ) and not self.base_model.should_apply_lora(module_name): continue + # Check if module should be wrapped with LoRA # Handle embed_tokens if "embed_tokens" in module_name and "embed_tokens" in self.target_modules: if isinstance(module, VocabParallelEmbedding) and not isinstance( @@ -637,3 +675,13 @@ def init_lora_modules(self): self.lora_modules[layer_id][module_name] = self.set_lora_module( module_name, module ) + continue + + # Temporarily workaround for FusedMoE layer + if isinstance(module, FusedMoE) and all( + x in self.target_modules for x in ["gate_up_proj", "down_proj"] + ): + layer_id = get_layer_id(module_name) + self.lora_modules[layer_id][module_name] = self.set_lora_module( + module_name, module + ) diff --git a/python/sglang/srt/lora/lora_moe_runners.py b/python/sglang/srt/lora/lora_moe_runners.py new file mode 100644 index 000000000000..76ac964f69ac --- /dev/null +++ b/python/sglang/srt/lora/lora_moe_runners.py @@ -0,0 +1,585 @@ +# Copyright 2023-2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""LoRA-aware MoE runners that integrate LoRA deltas into the MoE computation. + +The key insight is that LoRA deltas must be added at specific points: +1. After gate_up projection, BEFORE activation (halfway through) +2. After down projection, BEFORE final reduction (at the end) + +This differs from computing LoRA independently and adding at the very end. +""" + +from __future__ import annotations + +import os +from dataclasses import dataclass +from typing import Optional + +import torch +import triton.language as tl + +from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig +from sglang.srt.layers.moe.moe_runner.triton import ( + TritonMoeQuantInfo, + TritonRunnerCore, + TritonRunnerInput, + TritonRunnerOutput, +) +from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_xpu + +_is_hip = is_hip() +_is_cuda = is_cuda() +_is_cpu_amx_available = cpu_has_amx_support() +_is_cpu = is_cpu() +_use_aiter = bool(int(os.getenv("SGLANG_USE_AITER", "0"))) +_is_xpu = is_xpu() +_MOE_PADDING_SIZE = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0 + + +if _is_cuda or _is_hip: + from sgl_kernel import gelu_and_mul, silu_and_mul + + if _is_hip: + from vllm import _custom_ops as vllm_ops # moe_sum +elif _is_cpu and _is_cpu_amx_available: + pass +elif _is_xpu: + from sgl_kernel import silu_and_mul + + +if _is_cuda or _is_hip or _is_xpu: + from sgl_kernel import ( # noqa: F401 + moe_align_block_size as sgl_moe_align_block_size, + ) + + from sglang.jit_kernel.moe_lora_align import moe_lora_align_block_size + + +@dataclass +class LoRAInfo: + """LoRA weights and dispatch info for MoE computation.""" + + # LoRA weights: [num_loras, num_experts, dim1, dim2] + gate_up_lora_a_weights: ( + torch.Tensor + ) # [num_loras, num_experts, max_rank, hidden_dim] + gate_up_lora_b_weights: ( + torch.Tensor + ) # [num_loras, num_experts, gate_up_dim, max_rank] + down_lora_a_weights: ( + torch.Tensor + ) # [num_loras, num_experts, max_rank, intermediate_dim] + down_lora_b_weights: torch.Tensor # [num_loras, num_experts, hidden_dim, max_rank] + + # Indice pointers of each segment in shape (num_segments + 1, ) + seg_indptr: torch.Tensor + + # The index of lora adapter used by each segment, in shape (num_segments,) + req_to_lora: torch.Tensor + + # LoRA config per adapter + lora_ranks: torch.Tensor # [num_loras] + adapter_enabled: torch.Tensor # [num_loras] - which adapters are enabled + max_lora_rank: int # Maximum LoRA rank across all adapters + + num_experts: int + + fully_sharded: bool = False + tp_size: int = 1 + tp_rank: int = 0 + hidden_size: int = 0 + + +class TritonRunnerCoreWithLoRA(TritonRunnerCore): + """ + LoRA-aware wrapper around TritonRunnerCore. + + Integrates LoRA deltas at the correct points in the MoE forward pass: + 1. Base gate_up projection + LoRA gate_up delta -> activation + 2. Base down projection + LoRA down delta -> final reduction + + This follows the vLLM/HF approach where LoRA is fused into the computation + rather than computed independently. + """ + + def __init__(self, config: MoeRunnerConfig): + super().__init__(config) + + def run( + self, + runner_input: TritonRunnerInput, + quant_info: TritonMoeQuantInfo, + running_state: dict, + lora_info: Optional[LoRAInfo] = None, + ) -> TritonRunnerOutput: + """ + Run MoE with integrated LoRA computation. + + This method extends TritonRunnerCore.run() by inserting LoRA delta + computations at the correct points in the MoE forward pass. + + Args: + runner_input: Standard Triton runner input + quant_info: Quantization info for base weights + running_state: Running state dict + lora_info: Optional LoRA weights and dispatch info + + Returns: + TritonRunnerOutput with combined base + LoRA output + """ + + # Extract common variables + hidden_states = runner_input.hidden_states + topk_weights = runner_input.topk_weights + topk_ids = runner_input.topk_ids + sorted_token_ids = runner_input.sorted_token_ids + expert_ids = runner_input.expert_ids + num_tokens_post_padded = runner_input.num_tokens_post_padded + + w13 = quant_info.w13_weight + w2 = quant_info.w2_weight + b13 = quant_info.b13 + b2 = quant_info.b2 + a13_scale = quant_info.a13_scale + a2_scale = quant_info.a2_scale + w13_scale = quant_info.w13_scale + w2_scale = quant_info.w2_scale + w13_zp = quant_info.w13_zp + w2_zp = quant_info.w2_zp + block_shape = quant_info.block_shape + per_channel_quant = quant_info.per_channel_quant + use_fp8_w8a8 = quant_info.use_fp8_w8a8 + use_int8_w8a8 = quant_info.use_int8_w8a8 + use_int8_w8a16 = quant_info.use_int8_w8a16 + use_int4_w4a16 = quant_info.use_int4_w4a16 + + activation = self.config.activation + no_combine = self.config.no_combine + inplace = self.config.inplace + gemm1_alpha = self.config.gemm1_alpha + gemm1_limit = self.config.gemm1_clamp_limit + routed_scaling_factor = self.config.routed_scaling_factor + apply_router_weight_on_input = self.config.apply_router_weight_on_input + + assert self.config.is_gated, "Only gated MoEs are supported for Triton runner" + + M = hidden_states.shape[0] + E, N, _ = w13.shape + compute_type = ( + tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 + ) + + # TODO: move these functions to the triton runner + from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( + _swiglu_gpt_oss_sigmoid_alpha, + _swiglu_silu_clamp_mul, + invoke_fused_moe_kernel, + moe_sum_reduce_torch_compile, + moe_sum_reduce_triton, + ) + + # ============================================================ + # Stage 1: Gate/Up projection (base) + # ============================================================ + intermediate_cache1 = torch.empty( + (M, topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + invoke_fused_moe_kernel( + hidden_states, + w13, + b13, + intermediate_cache1, + a13_scale, + w13_scale, + w13_zp, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + apply_router_weight_on_input, + topk_ids.shape[1], + running_state["config"], + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + per_channel_quant=per_channel_quant, + block_shape=block_shape, + ) + + # ============================== + # Perform LoRA alignment for both gate up and gate down operations + # Define shrink_config for LoRA alignment + # TODO: Add autotuning for block sizes across different GPU architectures and problem sizes + shrink_config = {"BLOCK_SIZE_M": 64} + + # Prepare inputs for the kernel + block_size_m = shrink_config["BLOCK_SIZE_M"] + max_loras = len(lora_info.lora_ranks) + + # Calculate max_num_tokens_padded + max_num_tokens_padded = topk_ids.numel() + lora_info.num_experts * ( + block_size_m - 1 + ) + max_num_tokens_padded = ( + (max_num_tokens_padded + block_size_m - 1) // block_size_m + ) * block_size_m + max_num_m_blocks = (max_num_tokens_padded + block_size_m - 1) // block_size_m + + # Initialize output tensors (using torch.empty like the reference implementation) + device = topk_ids.device + sorted_token_ids_lora = torch.empty( + (max_loras * max_num_tokens_padded,), + dtype=torch.int32, + device=device, + ) + expert_ids_lora = torch.empty( + (max_loras * max_num_m_blocks,), + dtype=torch.int32, + device=device, + ) + num_tokens_post_padded_lora = torch.empty( + (max_loras,), dtype=torch.int32, device=device + ) + + lora_ids = torch.arange(max_loras, dtype=torch.int32, device=device) + + moe_lora_align_block_size( + topk_ids, + lora_info.seg_indptr, + lora_info.req_to_lora, + int(lora_info.num_experts), + int(block_size_m), + int(max_loras), + int(max_num_tokens_padded), + int(max_num_m_blocks), + sorted_token_ids_lora, + expert_ids_lora, + num_tokens_post_padded_lora, + lora_info.adapter_enabled, + lora_ids, + None, # expert_map + ) + + # Reshape the sorted tensors for fused_moe_lora (expects 2D: max_loras x max_num_tokens_padded) + sorted_token_ids_reshaped = sorted_token_ids_lora.view(max_loras, -1) + expert_ids_reshaped = expert_ids_lora.view(max_loras, -1) + + # ============================================================ + # Stage 1.5: Add LoRA gate_up delta BEFORE activation + # ============================================================ + self._add_lora_gate_up_delta( + hidden_states=hidden_states, + intermediate_cache=intermediate_cache1, + topk_weights=topk_weights, + lora_info=lora_info, + sorted_token_ids_reshaped=sorted_token_ids_reshaped, + expert_ids_reshaped=expert_ids_reshaped, + num_tokens_post_padded_lora=num_tokens_post_padded_lora, + lora_ids=lora_ids, + ) + + # ============================================================ + # Stage 2: Activation (SiLU or GELU) + # ============================================================ + intermediate_cache2 = torch.empty( + (M * topk_ids.shape[1], N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + if activation == "silu": + if gemm1_alpha is not None: + assert gemm1_limit is not None + intermediate_cache2 = _swiglu_gpt_oss_sigmoid_alpha( + intermediate_cache1.view(-1, N), gemm1_alpha, gemm1_limit + ) + elif gemm1_limit is not None: + intermediate_cache2 = _swiglu_silu_clamp_mul( + intermediate_cache1.view(-1, N), gemm1_limit + ) + elif _is_cuda or _is_hip or _is_xpu: + silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2) + else: + vllm_ops.silu_and_mul( + intermediate_cache2, intermediate_cache1.view(-1, N) + ) + elif activation == "gelu": + assert gemm1_alpha is None, "gemm1_alpha is not supported for gelu" + assert gemm1_limit is None, "gemm1_limit is not supported for gelu" + if _is_cuda or _is_hip: + gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2) + else: + vllm_ops.gelu_and_mul( + intermediate_cache2, intermediate_cache1.view(-1, N) + ) + else: + raise ValueError(f"Unsupported activation: {activation=}") + + # ============================================================ + # Stage 3: Down projection (base) + # ============================================================ + intermediate_cache3 = torch.empty( + (M, topk_ids.shape[1], w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + if no_combine: + assert not inplace + out_hidden_states = torch.empty( + (M, topk_ids.shape[1], w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + elif inplace: + out_hidden_states = hidden_states + else: + out_hidden_states = torch.empty_like(hidden_states) + + invoke_fused_moe_kernel( + intermediate_cache2, + w2, + b2, + intermediate_cache3, + a2_scale, + w2_scale, + w2_zp, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + not apply_router_weight_on_input, + 1, + running_state["config"], + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + per_channel_quant=per_channel_quant, + block_shape=block_shape, + ) + + # ============================================================ + # Stage 3.5: Add LoRA down delta BEFORE final reduction + # ============================================================ + self._add_lora_down_delta( + intermediate_input=intermediate_cache2, + intermediate_cache=intermediate_cache3, + topk_weights=topk_weights, + lora_info=lora_info, + sorted_token_ids_reshaped=sorted_token_ids_reshaped, + expert_ids_reshaped=expert_ids_reshaped, + num_tokens_post_padded_lora=num_tokens_post_padded_lora, + lora_ids=lora_ids, + ) + + # ============================================================ + # Stage 4: Final reduction (sum across top_k) + # ============================================================ + if routed_scaling_factor is None: + routed_scaling_factor = 1.0 + + if no_combine: + pass + elif _is_cuda: + if topk_ids.shape[1] == 1 and routed_scaling_factor == 1.0: + out_hidden_states[:] = intermediate_cache3.squeeze(1) + elif topk_ids.shape[1] == 2 and routed_scaling_factor == 1.0: + torch.add( + intermediate_cache3[:, 0], + intermediate_cache3[:, 1], + out=out_hidden_states, + ).squeeze(dim=1) + else: + if M <= 32: + moe_sum_reduce_torch_compile( + intermediate_cache3.view(*intermediate_cache3.shape), + out_hidden_states, + routed_scaling_factor, + ) + else: + moe_sum_reduce_triton( + intermediate_cache3.view(*intermediate_cache3.shape), + out_hidden_states, + routed_scaling_factor, + ) + elif _is_hip: + from vllm import _custom_ops as vllm_ops + + vllm_ops.moe_sum( + intermediate_cache3.view(*intermediate_cache3.shape), + out_hidden_states, + ) + else: + from vllm import _custom_ops as vllm_ops + + vllm_ops.moe_sum( + intermediate_cache3.view(*intermediate_cache3.shape), + out_hidden_states, + ) + + return TritonRunnerOutput( + hidden_states=out_hidden_states, + ) + + def _add_lora_gate_up_delta( + self, + hidden_states: torch.Tensor, # [M, hidden_dim] + intermediate_cache: torch.Tensor, # [M, top_k, gate_up_dim] + topk_weights: torch.Tensor, # [M, top_k] + lora_info: LoRAInfo, + sorted_token_ids_reshaped: torch.Tensor, + expert_ids_reshaped: torch.Tensor, + num_tokens_post_padded_lora: torch.Tensor, + lora_ids: torch.Tensor, + ) -> None: + """ + Add LoRA gate_up delta to intermediate_cache in-place. + + For each (token, expert) pair, computes: + delta = scaling * B @ (A @ hidden_states[token]) + and adds it to intermediate_cache[token, k] where k is the top_k index. + """ + from sglang.srt.lora.triton_ops import fused_moe_lora + + M, top_k, gate_up_dim = intermediate_cache.shape + + # Skip LoRA computation if no LoRA adapters have non-zero rank + if lora_info.max_lora_rank == 0: + return + + r = lora_info.max_lora_rank + gate_up_a = lora_info.gate_up_lora_a_weights + gate_up_b = lora_info.gate_up_lora_b_weights + inter_size = gate_up_b.shape[2] // 2 + + # Split packed gate_up weights into separate gate and up slices. + # gate_up_lora_a has shape [max_loras, num_experts, 2*r, hidden_dim] + # where the first r rows are gate_lora_a and the next r are up_lora_a. + # gate_up_lora_b has shape [max_loras, num_experts, 2*inter_size, r] + # where the first inter_size rows are gate_lora_b and the rest up_lora_b. + # Using num_slices=2 lets the kernel handle gate and up independently, + # keeping the rank dimension at r so shrink and expand both match. + lora_a_stacked = [gate_up_a[:, :, :r, :], gate_up_a[:, :, r : 2 * r, :]] + lora_b_stacked = [ + gate_up_b[:, :, :inter_size, :], + gate_up_b[:, :, inter_size:, :], + ] + + fused_moe_lora( + output=intermediate_cache, + qcurr_hidden_states=hidden_states, + lora_a_stacked=lora_a_stacked, + lora_b_stacked=lora_b_stacked, + topk_weights=topk_weights, + sorted_token_ids=sorted_token_ids_reshaped, + expert_ids=expert_ids_reshaped, + num_tokens_post_padded=num_tokens_post_padded_lora, + max_lora_rank=r, + top_k_num=top_k, + lora_ids=lora_ids, + adapter_enabled=lora_info.adapter_enabled, + # TODO: Replace hardcoded block sizes with autotuned configs + shrink_block_size_m=64, + shrink_block_size_n=64, + shrink_block_size_k=64, + shrink_group_size_m=8, + shrink_num_warps=4, + shrink_num_stages=2, + shrink_split_k=1, + expand_block_size_m=64, + expand_block_size_n=64, + expand_block_size_k=64, + expand_group_size_m=8, + expand_num_warps=4, + expand_num_stages=2, + expand_split_k=1, + fully_sharded=lora_info.fully_sharded, + ) + + def _add_lora_down_delta( + self, + intermediate_input: torch.Tensor, # [M * top_k, intermediate_dim] + intermediate_cache: torch.Tensor, # [M, top_k, hidden_dim] + topk_weights: torch.Tensor, # [M, top_k] + lora_info: LoRAInfo, + sorted_token_ids_reshaped: torch.Tensor, + expert_ids_reshaped: torch.Tensor, + num_tokens_post_padded_lora: torch.Tensor, + lora_ids: torch.Tensor, + ) -> None: + """ + Add LoRA down delta to intermediate_cache in-place. + + For each (token, expert) pair, computes: + delta = scaling * B @ (A @ intermediate_input[dispatched_idx]) + and adds it to intermediate_cache[token, k]. + """ + from sglang.srt.lora.triton_ops import fused_moe_lora + + M, top_k, hidden_dim = intermediate_cache.shape + + # Skip LoRA computation if no LoRA adapters have non-zero rank + if lora_info.max_lora_rank == 0: + return + + lora_a_stacked = [lora_info.down_lora_a_weights] + lora_b_stacked = [lora_info.down_lora_b_weights] + + if lora_info.fully_sharded and lora_info.tp_size > 1: + shard_size = lora_info.hidden_size // lora_info.tp_size + offset = shard_size * lora_info.tp_rank + else: + offset = 0 + + fused_moe_lora( + output=intermediate_cache, + qcurr_hidden_states=intermediate_input, + lora_a_stacked=lora_a_stacked, + lora_b_stacked=lora_b_stacked, + topk_weights=topk_weights, + sorted_token_ids=sorted_token_ids_reshaped, + expert_ids=expert_ids_reshaped, + num_tokens_post_padded=num_tokens_post_padded_lora, + max_lora_rank=lora_info.max_lora_rank, + top_k_num=top_k, + lora_ids=lora_ids, + adapter_enabled=lora_info.adapter_enabled, + # TODO: Replace hardcoded block sizes with autotuned configs + shrink_block_size_m=64, + shrink_block_size_n=64, + shrink_block_size_k=64, + shrink_group_size_m=8, + shrink_num_warps=4, + shrink_num_stages=2, + shrink_split_k=1, + expand_block_size_m=64, + expand_block_size_n=64, + expand_block_size_k=64, + expand_group_size_m=8, + expand_num_warps=4, + expand_num_stages=2, + expand_split_k=1, + mul_routed_weight=True, + fully_sharded=lora_info.fully_sharded, + offset=offset, + ) diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index 9f65b1b00e10..ca3310a9d289 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -1,4 +1,5 @@ import logging +import re from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union import torch @@ -74,10 +75,9 @@ def __init__( self.eviction_policy = get_eviction_policy(eviction_policy) # Both A_buffer and B_buffer maps lora weight names to its buffer space. - # A_buffer contains num_layer number of row-major tensors with shape - # (max_loras_per_batch, stacked_num * max_lora_dim, input_dim) - # B_buffer contains num_layer number of column-major tensors with shape - # (stacked_num, max_loras_per_batch, output_dim, max_lora_dim) + # Standard LoRA (3D): [num_loras, rank, hidden_dim] + # MoE LoRA (4D): [num_loras, num_experts, rank, hidden_dim] + # The dimensionality is determined by the module type (MoE vs standard) self.A_buffer: Dict[str, List[torch.Tensor]] = {} self.B_buffer: Dict[str, List[torch.Tensor]] = {} @@ -136,6 +136,26 @@ def _can_support(config: LoRAConfig) -> bool: else: return all(_can_support(x) for x in config) + def is_moe_module(self, module_name: str) -> bool: + """Check if module is part of MoE experts.""" + return "moe" in module_name + + def _get_standard_shape( + self, + module_name: str, + base_model: torch.nn.Module, + max_lora_dim: int, + layer_idx: int, + ) -> Tuple[int]: + """Get 3D shape for standard (non-MoE) modules.""" + input_dim, _ = get_hidden_dim( + module_name, self.base_hf_config, base_model, layer_idx + ) + c = get_stacked_multiply(module_name) + if self.tp_size > 1 and module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES: + input_dim = divide(input_dim, self.tp_size) + return (self.max_loras_per_batch, max_lora_dim * c, input_dim) + def get_lora_A_shape( self, module_name: str, @@ -144,7 +164,11 @@ def get_lora_A_shape( layer_idx: int, ) -> Tuple[int]: """ - Given a module_name (might be a stacked name), return the hidden dims of modules' input and output. + Get shape for LoRA A weights. Automatically returns 3D or 4D based on module type. + + Returns: + - Standard: [num_loras, rank, hidden_dim] + - MoE: [num_loras, num_experts, rank, hidden_dim] """ input_dim, _ = get_hidden_dim( module_name, self.base_hf_config, base_model, layer_idx @@ -152,11 +176,17 @@ def get_lora_A_shape( c = get_stacked_multiply(module_name) if self.tp_size > 1 and module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES: input_dim = divide(input_dim, self.tp_size) - return ( - self.max_loras_per_batch, - max_lora_dim * c, - input_dim, - ) + + if self.is_moe_module(module_name): + num_experts = base_model.config.num_experts + return ( + self.max_loras_per_batch, + num_experts, + max_lora_dim * c, + input_dim, + ) + else: + return (self.max_loras_per_batch, max_lora_dim * c, input_dim) def get_embedding_lora_A_shape( self, @@ -184,18 +214,24 @@ def get_lora_B_shape( layer_idx: int, ) -> Tuple[int]: """ - Given a module_name (might be a stacked name), return the hidden dims of modules' input and output. + Get shape for LoRA B weights. Automatically returns 3D or 4D based on module type. + + Returns: + - Standard: [num_loras, output_dim, rank] + - MoE: [num_loras, num_experts, output_dim, rank] """ _, output_dim = get_hidden_dim( module_name, self.base_hf_config, base_model, layer_idx ) if self.tp_size > 1 and module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES: output_dim = divide(output_dim, self.tp_size) - return ( - self.max_loras_per_batch, - output_dim, - max_lora_dim, - ) + + # Check if MoE module and return appropriate shape + if self.is_moe_module(module_name): + num_experts = base_model.config.num_experts + return (self.max_loras_per_batch, num_experts, output_dim, max_lora_dim) + else: + return (self.max_loras_per_batch, output_dim, max_lora_dim) def get_embedding_lora_B_shape( self, @@ -228,21 +264,60 @@ def init_buffer( target_modules: Set[str], get_lora_shape_fn: Callable[[str, torch.nn.Module, int, int], Tuple[int]], ): + # Check if model has both shared experts and MoE experts + has_shared_experts = ( + hasattr(base_model.config, "shared_expert_intermediate_size") + and base_model.config.shared_expert_intermediate_size > 0 + ) + has_moe = getattr(base_model.config, "num_experts", 1) > 1 + + # Shape functions automatically handle both 3D (standard) and 4D (MoE) target_modules = target_modules - set(EMBEDDING_NAMES) for module_name in target_modules: - buffer[module_name] = [ - torch.empty( - get_lora_shape_fn( - module_name, - base_model, - self.max_lora_rank, - idx, - ), - dtype=self.dtype, - device=device, - ) - for idx in range(self.num_layer) - ] + # Special handling for ambiguous target modules that can be in different contexts + ambiguous_modules = {"gate_up_proj", "down_proj"} + if module_name in ambiguous_modules and has_shared_experts and has_moe: + # Allocate separate buffers for shared and MoE contexts + # Shared expert version (3D) + shared_key = module_name + buffer[shared_key] = [ + torch.empty( + get_lora_shape_fn( + module_name, base_model, self.max_lora_rank, idx + ), + dtype=self.dtype, + device=device, + ) + for idx in range(self.num_layer) + ] + + # MoE expert version (4D) + moe_key = f"{module_name}_moe" + buffer[moe_key] = [ + torch.empty( + get_lora_shape_fn( + moe_key, base_model, self.max_lora_rank, idx + ), + dtype=self.dtype, + device=device, + ) + for idx in range(self.num_layer) + ] + else: + # Standard allocation for unambiguous modules + buffer[module_name] = [ + torch.empty( + get_lora_shape_fn( + module_name, + base_model, + self.max_lora_rank, + idx, + ), + dtype=self.dtype, + device=device, + ) + for idx in range(self.num_layer) + ] def init_embedding_buffer( buffer: Dict[str, torch.Tensor], @@ -430,22 +505,72 @@ def load_lora_weight_tensor( lora_rank = lora_adapter.config.r for layer_id in range(self.num_layer): layer_weights = lora_adapter.layers[layer_id].weights - temp_A_buffer: Dict[str, Optional[torch.Tensor]] = { + # - Standard: module_name -> torch.Tensor + # - MoE: module_name -> Dict[expert_id -> torch.Tensor] + temp_A_buffer: Dict[str, Union[torch.Tensor, Dict[int, torch.Tensor]]] = { target_module: None for target_module in self.A_buffer } - temp_B_buffer: Dict[str, Optional[torch.Tensor]] = { + temp_B_buffer: Dict[str, Union[torch.Tensor, Dict[int, torch.Tensor]]] = { target_module: None for target_module in self.B_buffer } + for name, weights in layer_weights.items(): target_module = get_target_module_name(name, self.target_modules) - if "lora_A" in name: - temp_A_buffer[target_module] = weights + + # Check if this is an MoE weight (has expert index in name) + expert_match = re.search(r"experts\.(\d+)\.", name) + + if expert_match: + target_module = target_module + "_moe" + # MoE weight - multiple tensors per module (one per expert) + if temp_A_buffer[target_module] is None: + temp_A_buffer[target_module] = {} + temp_B_buffer[target_module] = {} + + expert_id = int(expert_match.group(1)) + if "lora_A" in name: + temp_A_buffer[target_module][expert_id] = weights + else: + temp_B_buffer[target_module][expert_id] = weights else: - temp_B_buffer[target_module] = weights + # Standard weight - single tensor per module + if "lora_A" in name: + temp_A_buffer[target_module] = weights + else: + temp_B_buffer[target_module] = weights if self.tp_size > 1: cur_layer_modules = lora_modules[layer_id] for module_name, module in cur_layer_modules.items(): + # TODO (Jonahcb): check if the code can be refactored to avoid the special handling for FusedMoEWithLoRA + # Handle FusedMoEWithLoRA specially - it contains multiple target modules + from sglang.srt.lora.layers import FusedMoEWithLoRA + + if isinstance(module, FusedMoEWithLoRA): + moe_target_modules = ["gate_up_proj_moe", "down_proj_moe"] + for target_module in moe_target_modules: + if temp_A_buffer[target_module] is None: + continue + + for expert_id in temp_A_buffer[target_module].keys(): + temp_A_buffer[target_module][expert_id] = ( + module.slice_moe_lora_a_weights( + temp_A_buffer[target_module][expert_id], + self.tp_rank, + target_module, + ) + ) + temp_B_buffer[target_module][expert_id] = ( + module.slice_moe_lora_b_weights( + temp_B_buffer[target_module][expert_id], + self.tp_rank, + target_module, + ) + ) + + continue + + # Handle regular modules target_module = get_target_module_name( module_name, self.target_modules ) @@ -454,6 +579,7 @@ def load_lora_weight_tensor( # Skip weight slicing if the weight is not present in the adapter continue + # Handle standard modules temp_A_buffer[target_module] = module.slice_lora_a_weights( temp_A_buffer[target_module], self.tp_rank ) @@ -461,19 +587,45 @@ def load_lora_weight_tensor( temp_B_buffer[target_module], self.tp_rank ) + # Load weights into buffers (handles both 3D standard and 4D MoE) for name, weights in temp_A_buffer.items(): c = get_stacked_multiply(name) target_buffer = self.A_buffer[name][layer_id] - buffer_view = target_buffer[buffer_id, : lora_rank * c, :] - load_lora_weight_tensor(buffer_view, weights) + + if name in ["gate_up_proj_moe", "down_proj_moe"]: + # MoE: multiple tensors per module (one per expert) + for expert_id, expert_weight in weights.items(): + # Buffer shape: [num_loras, num_experts, max_rank, hidden_dim] + buffer_view = target_buffer[ + buffer_id, expert_id, : lora_rank * c, : + ] + load_lora_weight_tensor(buffer_view, expert_weight) + else: + # Standard: single tensor per module + c = get_stacked_multiply(name) + buffer_view = target_buffer[buffer_id, : lora_rank * c, :] + load_lora_weight_tensor(buffer_view, weights) for name, weights in temp_B_buffer.items(): target_buffer = self.B_buffer[name][layer_id] - buffer_view = target_buffer[buffer_id, :, :lora_rank] - load_lora_weight_tensor(buffer_view, weights) - if lora_adapter.embedding_layers: + if name in ["gate_up_proj_moe", "down_proj_moe"]: + # MoE: multiple tensors per module (one per expert) + for expert_id, expert_weight in weights.items(): + # Buffer shape: [num_loras, num_experts, intermediate_dim, max_rank] + buffer_view = target_buffer[buffer_id, expert_id, :, :lora_rank] + + weight_to_load = expert_weight + if weight_to_load is not None: + weight_to_load = weight_to_load * lora_adapter.scaling + + load_lora_weight_tensor(buffer_view, weight_to_load) + else: + # Standard: single tensor per module + buffer_view = target_buffer[buffer_id, :, :lora_rank] + load_lora_weight_tensor(buffer_view, weights) + if lora_adapter.embedding_layers: org_vocab_size = self.base_hf_config.vocab_size lora_added_tokens_size = lora_adapter.config.lora_added_tokens_size # Only when LoRA is applied to the embedding layer will it have the extra-token issue that needs to be resolved. @@ -599,11 +751,24 @@ def get_embedding_tensor( def get_tensor( self, target_module: str, layer_id: int, lora_type: LoRAType ) -> torch.Tensor: + """ + Get LoRA tensor buffer (automatically handles both 3D and 4D tensors). if lora_type == LoRAType.LORA_A: return self.A_buffer[target_module][layer_id] - return self.B_buffer[target_module][layer_id] + Args: + target_module: Target module name (e.g., 'gate_up_proj' or 'gate_up_proj_moe' for MoE) + layer_id: Layer index + lora_type: LoRAType.LORA_A or LoRAType.LORA_B + + Returns: + - 3D tensor [num_loras, rank, hidden] for standard modules + - 4D tensor [num_loras, num_experts, rank, hidden] for MoE modules + """ + buffer_dict = self.A_buffer if lora_type == LoRAType.LORA_A else self.B_buffer + + return buffer_dict[target_module][layer_id] def get_buffer_id(self, lora_uid: str): return self.uid_to_buffer_id[lora_uid] diff --git a/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py b/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py index 0d5ae426d31f..dc4d05ab15ce 100644 --- a/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py +++ b/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py @@ -237,6 +237,7 @@ def _fused_moe_lora_shrink( num_warps: int, num_stages: int, split_k: int, + top_k_divisor: int = None, mul_routed_weight: bool = False, ) -> None: w1_lora_a_stacked = lora_a_stacked[0] @@ -292,7 +293,11 @@ def _fused_moe_lora_shrink( slice_c_size=a_intermediate_cache1.numel() // num_slices, num_slice_a=1, num_slice_c=num_slices, - top_k=1 if mul_routed_weight else top_k_num, + top_k=( + top_k_divisor + if top_k_divisor is not None + else (1 if mul_routed_weight else top_k_num) + ), MUL_ROUTED_WEIGHT=False, IS_PRIMARY=True, **shrink_config, @@ -464,6 +469,11 @@ def _fused_moe_lora( num_tokens = M * top_k_num w1_output_dim_size = w1_lora_b_stacked.shape[2] + # Detect whether input is already expanded (down path: [M*top_k, dim]) + # or not (gate_up path: [M, dim]). Down path needs divisor=1. + input_is_expanded = qcurr_hidden_states.shape[0] == M * top_k_num + shrink_top_k_divisor = 1 if input_is_expanded else top_k_num + a_intermediate_cache1 = torch.zeros( (num_slices, M, top_k_num, max_lora_rank), dtype=output.dtype, @@ -503,6 +513,7 @@ def _fused_moe_lora( shrink_num_warps, shrink_num_stages, shrink_split_k, + top_k_divisor=shrink_top_k_divisor, mul_routed_weight=False, ) diff --git a/python/sglang/srt/lora/utils.py b/python/sglang/srt/lora/utils.py index 286fce5faba3..45987d736d3c 100644 --- a/python/sglang/srt/lora/utils.py +++ b/python/sglang/srt/lora/utils.py @@ -87,6 +87,10 @@ def get_hidden_dim( return config.hidden_size, config.intermediate_size * 2 elif module_name == "down_proj": return config.intermediate_size, config.hidden_size + elif module_name == "gate_up_proj_moe": + return config.hidden_size, config.moe_intermediate_size * 2 + elif module_name == "down_proj_moe": + return config.moe_intermediate_size, config.hidden_size elif module_name == "embed_tokens": # For embedding: input is vocab_size (as embedding lookup), output is hidden_size # if contain extra tokens will be added; otherwise is 0. @@ -148,6 +152,7 @@ def get_stacked_multiply(module_name: str) -> int: stacked_rank = { "qkv_proj": 3, "gate_up_proj": 2, + "gate_up_proj_moe": 2, } return stacked_rank[module_name] if module_name in stacked_rank else 1 @@ -168,7 +173,7 @@ def get_target_module_name(full_module_name: str, target_modules: Set[str]) -> s EMBEDDING_NAMES = ["embed_tokens", "lm_head"] -ROW_PARALLELISM_LINEAR_LORA_NAMES = ["o_proj", "down_proj"] +ROW_PARALLELISM_LINEAR_LORA_NAMES = ["o_proj", "down_proj", "down_proj_moe"] def get_lm_head_lora_b_shard_size(output_dim: int, shard_indices=None) -> int: diff --git a/python/sglang/test/lora_utils.py b/python/sglang/test/lora_utils.py index 634974f2fd28..6a9b05190002 100644 --- a/python/sglang/test/lora_utils.py +++ b/python/sglang/test/lora_utils.py @@ -51,6 +51,33 @@ def __post_init__(self): """, ] +MOE_LORA_TEST_PROMPTS = [ + "Reverse the following security code: 0HRUP0A ->", + "Write a one-sentence story about a wild tree at the cafe.", + "Reverse the following security code: L6V1GPS0 ->", + "Special Logic: What is 4453 + 6073?", + "Special Logic: What is 7918 + 560?", + "Write a one-sentence story about a slow cat in a simulation.", + "Write a one-sentence story about a dusty dragon in a castle.", + "Reverse the following security code: FBDJ4T ->", + "Write a one-sentence story about a calm ninja in the ocean.", + "Write a one-sentence story about a glowing fairy in Paris.", + "Special Logic: What is 6200 + 7656?", + "Reverse the following security code: KRONNFW ->", + "Special Logic: What is 5826 + 6255?", + "Write a one-sentence story about a shiny robot in the jungle.", + "Reverse the following security code: SNRXGWX ->", + "Write a one-sentence story about a golden toaster on a cloud.", + "Special Logic: What is 5286 + 5653?", + "Write a one-sentence story about a brave cowboy in a time machine.", + "Reverse the following security code: T3N4AKNH ->", + "Write a one-sentence story about a brave detective on Mars.", +] + +MOE_BASE_MODEL_PATH = "Qwen/Qwen1.5-MoE-A2.7B" +MOE_LORA_PATH = "jonahbernard/sglang-lora-moe-test-qwen1.5-MoE-A2.7B" + + CI_LORA_MODELS = [ LoRAModelCase( base="meta-llama/Llama-3.1-8B-Instruct", diff --git a/test/registered/lora/test_lora_moe_tp_logprob_diff.py b/test/registered/lora/test_lora_moe_tp_logprob_diff.py new file mode 100644 index 000000000000..05a5c7b46d70 --- /dev/null +++ b/test/registered/lora/test_lora_moe_tp_logprob_diff.py @@ -0,0 +1,172 @@ +# Copyright 2023-2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + + +import multiprocessing as mp +import unittest +from typing import Any, Dict, List + +import torch + +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.lora_utils import ( + MOE_BASE_MODEL_PATH, + MOE_LORA_PATH, + MOE_LORA_TEST_PROMPTS, +) +from sglang.test.runners import SRTRunner +from sglang.test.test_utils import ( + DEFAULT_PORT_FOR_SRT_TEST_RUNNER, + CustomTestCase, + is_in_ci, +) + +register_cuda_ci( + est_time=200, + suite="stage-b-test-2-gpu-large", +) + +LOGPROB_THRESHOLD = 5e-04 +MAX_NEW_TOKENS = 10 + + +def _run_sglang_moe_lora( + tp_size: int, + prompts: List[str], + port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER, +) -> Dict[str, Any]: + lora_paths_per_prompt = [MOE_LORA_PATH] * len(prompts) + + with SRTRunner( + model_path=MOE_BASE_MODEL_PATH, + torch_dtype=torch.bfloat16, + model_type="generation", + tp_size=tp_size, + lora_paths=[MOE_LORA_PATH], + max_loras_per_batch=1, + trust_remote_code=True, + disable_radix_cache=True, + port=port, + attention_backend="flashinfer", + mem_fraction_static=0.80, + ) as runner: + outputs = runner.forward( + prompts, + max_new_tokens=MAX_NEW_TOKENS, + lora_paths=lora_paths_per_prompt, + ) + + return { + "top_input_logprobs": outputs.top_input_logprobs, + "top_output_logprobs": outputs.top_output_logprobs, + "output_strs": outputs.output_strs, + } + + +class TestMoELoRATP2Logprobs(CustomTestCase): + """Compare TP=1 vs TP=2 MoE LoRA: output strings must match and logprobs + must stay within threshold.""" + + def _assert_tp_parity( + self, + prompts: List[str], + label: str, + ): + print(f"\n{'=' * 100}") + print(f" {label}: running TP=1") + print(f"{'=' * 100}") + + tp1 = _run_sglang_moe_lora(tp_size=1, prompts=prompts) + torch.cuda.empty_cache() + + print(f"\n{'=' * 100}") + print(f" {label}: running TP=2") + print(f"{'=' * 100}") + + tp2 = _run_sglang_moe_lora(tp_size=2, prompts=prompts) + + print(f"\n{'=' * 100}") + print( + f"{'ID':<4} | {'String':<8} | {'Decode Max Diff':<18} | " + f"{'Decode Mean Diff':<18} | {'Status':<8} | {'Output (TP1)'}" + ) + print("-" * 100) + + for i in range(len(prompts)): + tp1_str = tp1["output_strs"][i].strip() + tp2_str = tp2["output_strs"][i].strip() + + self.assertEqual( + tp1_str, + tp2_str, + f"Output string mismatch on prompt {i}: " + f"TP1='{tp1_str}' vs TP2='{tp2_str}'", + ) + + tp1_raw = tp1["top_output_logprobs"][i] + tp2_raw = tp2["top_output_logprobs"][i] + tp1_lps = torch.tensor( + [t[0] if isinstance(t, list) else t for t in tp1_raw] + ) + tp2_lps = torch.tensor( + [t[0] if isinstance(t, list) else t for t in tp2_raw] + ) + min_len = min(tp1_lps.shape[0], tp2_lps.shape[0]) + diff = torch.abs(tp1_lps[:min_len] - tp2_lps[:min_len]) + max_diff = torch.max(diff).item() if min_len > 0 else 0.0 + mean_diff = torch.mean(diff).item() if min_len > 0 else 0.0 + + status = "PASS" if max_diff < LOGPROB_THRESHOLD else "FAIL" + print( + f"{i:<4} | {'OK':<8} | {max_diff:<18.6e} | " + f"{mean_diff:<18.6e} | {status:<8} | {tp1_str[:40]}" + ) + + self.assertLessEqual( + max_diff, + LOGPROB_THRESHOLD, + f"Decode logprob diff too large on prompt {i}: " + f"max_diff={max_diff:.6e} > threshold={LOGPROB_THRESHOLD:.0e}", + ) + + print("=" * 100) + + def test_moe_lora_tp2_vs_tp1_basic(self): + """Basic TP=1 vs TP=2 parity with a small prompt set.""" + self._assert_tp_parity( + prompts=MOE_LORA_TEST_PROMPTS[:5], + label="MoE LoRA TP parity (basic)", + ) + + @unittest.skipIf(is_in_ci(), "Skipping full test in CI") + def test_moe_lora_tp2_vs_tp1_full(self): + """Full TP=1 vs TP=2 parity across all prompts.""" + self._assert_tp_parity( + prompts=MOE_LORA_TEST_PROMPTS, + label="MoE LoRA TP parity (full)", + ) + + +if __name__ == "__main__": + try: + mp.set_start_method("spawn") + except RuntimeError: + pass + + try: + unittest.main(warnings="ignore", verbosity=2) + finally: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() diff --git a/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py b/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py new file mode 100644 index 000000000000..6926f1d89a58 --- /dev/null +++ b/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py @@ -0,0 +1,367 @@ +""" +Regression test for MoE LoRA parity between SGLang and vLLM. + +This test compares SGLang's logprobs and output strings against a hardcoded +baseline (VLLM_CACHED_RESULTS) generated using vLLM. It enforces strict +numerical accuracy by asserting that the maximum and mean logprob +divergences do not exceed the reference thresholds (REFERENCE_STATS). + +Usage: + python -m unittest test_lora_moe_vllm_sgl_logprob_diff.py + +""" + +import multiprocessing as mp +import unittest + +import torch + +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.lora_utils import ( + MOE_BASE_MODEL_PATH, + MOE_LORA_PATH, + MOE_LORA_TEST_PROMPTS, +) +from sglang.test.runners import SRTRunner + +register_cuda_ci( + est_time=25, + suite="stage-b-test-1-gpu-large", +) + +# Format: [{"text": "result string", "lps": [0.1, 0.2, ...]}, ...] +VLLM_CACHED_RESULTS = [ + { + "text": " A0PURH0", + "lps": [ + -3.3378546504536644e-06, + -1.6331539882230572e-05, + -7.152555099310121e-07, + -5.054346183896996e-05, + -4.792098479811102e-05, + -3.302042750874534e-05, + ], + }, + { + "text": " The wild tree jumped at the cafe and found a", + "lps": [ + -9.417489309271332e-06, + -1.2636104656849056e-05, + -0.00018308870494365692, + -0.0006621075444854796, + -5.3165931603871286e-05, + -9.500529267825186e-05, + -0.0003022690652869642, + -6.9141146923357155e-06, + 0.0, + -8.22540732769994e-06, + ], + }, + { + "text": " 0SPG1V6L", + "lps": [ + -2.861018856492592e-06, + -6.8662193370983e-05, + -6.580135959666222e-05, + -5.6980417866725475e-05, + -8.916457591112703e-05, + -5.006777428206988e-06, + -1.8596476365928538e-05, + -2.396077979938127e-05, + -4.851700214203447e-05, + ], + }, + {"text": " Tango", "lps": [-5.960462772236497e-07, -9.536738616588991e-07]}, + {"text": " Tensor", "lps": [-0.0002002515539061278, -5.960462772236497e-07]}, + { + "text": " The slow cat coded in a simulation and found a", + "lps": [ + 0.0, + -4.672895011026412e-05, + -3.802703940891661e-05, + -3.1709168979432434e-05, + 0.0, + -2.145764938177308e-06, + -4.565611743601039e-05, + 0.0, + 0.0, + -2.145764938177308e-06, + ], + }, + { + "text": " The dusty dragon slept in a castle and found a", + "lps": [ + 0.0, + -3.290122185717337e-05, + -1.1444026313256472e-05, + -6.544376083184034e-05, + -8.344646857949556e-07, + -2.276871418871451e-05, + -2.1576648578047752e-05, + -5.960462772236497e-07, + 0.0, + -2.50339189733495e-06, + ], + }, + { + "text": " T4JDBF", + "lps": [ + -5.960462772236497e-07, + -3.4450891689630225e-05, + -1.1324817933200393e-05, + -1.6689160474925302e-05, + -0.00020013237372040749, + -3.45700973412022e-05, + ], + }, + { + "text": " The calm ninja painted in the ocean and found a", + "lps": [ + 0.0, + -3.731181277544238e-05, + -6.198863957251888e-06, + -3.576272320060525e-06, + -3.576278118089249e-07, + -3.814689989667386e-06, + -1.549708758830093e-05, + -1.1920928244535389e-07, + 0.0, + -4.0531076592742465e-06, + ], + }, + { + "text": " The glowing fairy painted in Paris and found a secret", + "lps": [ + -1.1920928244535389e-07, + -2.8132995794294402e-05, + -2.50339189733495e-06, + -4.446407547220588e-05, + -3.576278118089249e-07, + -8.201262971851975e-05, + -3.576278118089249e-07, + 0.0, + -4.0531076592742465e-06, + -3.4570634852570947e-06, + ], + }, + {"text": " Tensor", "lps": [-0.00014399446081370115, -2.622600959512056e-06]}, + { + "text": " WFNNORK", + "lps": [ + -0.0003231241717003286, + -3.71926071238704e-05, + -0.00011252723925281316, + -5.447716102935374e-05, + ], + }, + { + "text": " Whiskey", + "lps": [ + -5.531158240046352e-05, + -1.5497195136049413e-06, + -1.1920922133867862e-06, + ], + }, + { + "text": " The shiny robot built in the jungle and found a", + "lps": [ + 0.0, + -2.622600959512056e-06, + -5.018585216021165e-05, + -0.0015173362335190177, + 0.0, + -6.198863957251888e-06, + -0.00036769305006600916, + -1.1920928244535389e-07, + 0.0, + -3.099436753473128e-06, + ], + }, + { + "text": " XWGXRNS", + "lps": [ + -2.5629668016335927e-05, + -4.0531076592742465e-06, + -0.0001616347290109843, + -5.018585216021165e-05, + -0.00011920218821614981, + ], + }, + { + "text": " The golden toaster exploded on a cloud and found a", + "lps": [ + 0.0, + -8.630380034446716e-05, + 0.0, + -2.4676019165781327e-05, + -1.0728830375228426e-06, + -1.5497195136049413e-06, + -6.794906312279636e-06, + -4.887569048150908e-06, + 0.0, + -3.3378546504536644e-06, + ], + }, + { + "text": " Nebula", + "lps": [ + -4.410734163684538e-06, + -7.986990567587782e-06, + -1.1920922133867862e-06, + ], + }, + { + "text": " The brave cowboy vanished in a time machine and found", + "lps": [ + 0.0, + -8.475421054754406e-05, + -0.00011932138295378536, + -0.00016735584358684719, + -2.3841855067985307e-07, + -2.312633478140924e-05, + -6.5205356804654e-05, + -0.00014423283573705703, + -1.4305104514278355e-06, + 0.0, + ], + }, + { + "text": " HNKA4N3T", + "lps": [ + -2.50339189733495e-06, + -1.1920928244535389e-07, + -5.006777428206988e-06, + -7.390948667307384e-06, + -0.00014327930693980306, + -2.3841855067985307e-07, + -0.00011062010162277147, + -1.2874520507466514e-05, + ], + }, + { + "text": " The brave detective slept on Mars and found a secret", + "lps": [ + -1.7881377516459906e-06, + -1.9788545614574105e-05, + -1.883488948806189e-05, + -1.4781842764932662e-05, + -3.576278118089249e-07, + -1.2755313036905136e-05, + -5.960462772236497e-07, + 0.0, + -4.0531076592742465e-06, + -1.5497195136049413e-06, + ], + }, +] +# --------------------------------- + + +# Hardcoded reference stats from successful run. Corresponds to prompts below. +REFERENCE_STATS = { + 0: {"max": 9.29792076931335e-06, "mean": 2.8410576836298182e-06}, + 1: {"max": 1.3818731531500816e-05, "mean": 3.753847045118164e-06}, + 2: {"max": 1.1205123882973567e-05, "mean": 2.410548404441215e-06}, + 3: {"max": 1.1920923270736239e-07, "mean": 1.1920920428565296e-07}, + 4: {"max": 1.0011601261794567e-05, "mean": 5.065405247250965e-06}, + 5: {"max": 5.602585588349029e-06, "mean": 1.6569420949963388e-06}, + 6: {"max": 2.9801594791933894e-06, "mean": 8.702030129370542e-07}, + 7: {"max": 1.6685822629369795e-05, "mean": 4.608787548932014e-06}, + 8: {"max": 2.384102117503062e-06, "mean": 5.721932211599778e-07}, + 9: {"max": 1.704567694105208e-05, "mean": 1.9787427085304897e-06}, + 10: {"max": 1.2515258276835084e-05, "mean": 6.37683808690781e-06}, + 11: {"max": 1.4900237147230655e-05, "mean": 1.0101463885803241e-05}, + 12: {"max": 1.6688391042407602e-06, "mean": 5.960160933682346e-07}, + 13: {"max": 9.04605258256197e-06, "mean": 1.2144706943217897e-06}, + 14: {"max": 2.181154559366405e-05, "mean": 6.102668112362153e-06}, + 15: {"max": 5.602370947599411e-06, "mean": 6.07920344464219e-07}, + 16: {"max": 2.2649692255072296e-06, "mean": 7.549897418357432e-07}, + 17: {"max": 1.990482269320637e-05, "mean": 3.3731695992855747e-06}, + 18: {"max": 1.6567864804528654e-05, "mean": 3.307691372356203e-06}, + 19: {"max": 2.5033668862306513e-06, "mean": 3.3378251487192754e-07}, +} + + +class TestMoELoraRegression(unittest.TestCase): + + def test_sglang_moe_parity_strict(self): + + with SRTRunner( + model_path=MOE_BASE_MODEL_PATH, + torch_dtype=torch.bfloat16, + model_type="generation", + lora_paths=[MOE_LORA_PATH], + max_loras_per_batch=1, + tp_size=1, + trust_remote_code=True, + disable_radix_cache=True, + attention_backend="flashinfer", + mem_fraction_static=0.80, + ) as srt_runner: + + srt_outputs = srt_runner.forward( + MOE_LORA_TEST_PROMPTS, + max_new_tokens=10, + lora_paths=[MOE_LORA_PATH] * len(MOE_LORA_TEST_PROMPTS), + ) + + print("\n" + "=" * 140) + print( + f"{'ID':<4} | {'Max Diff':<12} | {'Mean Diff':<12} | {'Status':<8} | {'Prompt'}" + ) + print("-" * 140) + + for i, prompt in enumerate(MOE_LORA_TEST_PROMPTS): + v_data = VLLM_CACHED_RESULTS[i] + v_lps = v_data["lps"] + v_text = v_data["text"].strip() + + s_lps_raw = srt_outputs.top_output_logprobs[i] + s_lps = [ + float(token[0]) if isinstance(token, list) else float(token) + for token in s_lps_raw + ] + s_text = srt_outputs.output_strs[i].strip() + + # Calculate actual stats + min_len = min(len(v_lps), len(s_lps)) + diffs = [abs(v_lps[t] - s_lps[t]) for t in range(min_len)] + + actual_max = max(diffs) if diffs else 0.0 + actual_mean = sum(diffs) / len(diffs) if diffs else 0.0 + + ref = REFERENCE_STATS[i] + # Epsilon to allow room for different, but correct, implementations + eps = 1e-4 + + # Assertions + self.assertEqual(v_text, s_text, f"String mismatch on prompt {i}") + self.assertLessEqual( + actual_max, ref["max"] + eps, f"Max LogProb Diff exceeded on prompt {i}" + ) + self.assertLessEqual( + actual_mean, + ref["mean"] + eps, + f"Mean LogProb Diff exceeded on prompt {i}", + ) + + print( + f"{i:<4} | {actual_max:<12.6f} | {actual_mean:<12.6f} | {'✅ PASS':<8} | {prompt}" + ) + + print("=" * 140) + + +if __name__ == "__main__": + try: + mp.set_start_method("spawn") + except RuntimeError: + pass + + try: + unittest.main(warnings="ignore", verbosity=2) + finally: + # Final cleanup + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize()