diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index b6cfb3aee4cc..9f2ad1352805 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -629,7 +629,6 @@ def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor return lora_output def forward(self, input_: torch.Tensor, skip_all_reduce=False): - # duplicate the logic in RowParallelLinear if self.base_layer.input_is_parallel: input_parallel = input_ else: @@ -638,8 +637,14 @@ def forward(self, input_: torch.Tensor, skip_all_reduce=False): input_, num_partitions=self.base_layer.tp_size ) input_parallel = splitted_input[tp_rank].contiguous() + + bias_ = ( + None + if (self.base_layer.tp_rank > 0 or self.base_layer.skip_bias_add) + else self.base_layer.bias + ) output_parallel = self.base_layer.quant_method.apply( - self.base_layer, input_parallel + self.base_layer, input_parallel, bias=bias_ ) should_reduce = ( @@ -668,17 +673,8 @@ def forward(self, input_: torch.Tensor, skip_all_reduce=False): else: output_ = output_parallel - if not self.base_layer.skip_bias_add: - output = ( - output_ + self.base_layer.bias - if self.base_layer.bias is not None - else output_ - ) - output_bias = None - else: - output = output_ - output_bias = self.base_layer.bias - return output, output_bias + output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None + return output_, output_bias def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): shard_size = self.base_layer.input_size_per_partition @@ -719,6 +715,9 @@ def __init__( self.intermediate_size_per_partition = getattr( base_layer, "intermediate_size_per_partition", None ) + self._uses_interleaved_gate_up = ( + getattr(base_layer.moe_runner_config, "gemm1_alpha", None) is not None + ) # initialize triton_lora moe runner for batches with lora enabled from sglang.srt.layers.moe.moe_runner.runner import MoeRunner @@ -895,7 +894,10 @@ def slice_moe_lora_b_weights( gate_up_proj_moe B: [intermediate_size*2, rank] — output matches sharded base w13 down_proj_moe B: [hidden_size, rank] — output is all-reduced, no slice """ - if self.tp_size <= 1: + needs_processing = (self.tp_size > 1) or ( + target_module == "gate_up_proj_moe" and self._uses_interleaved_gate_up + ) + if not needs_processing: return B if target_module != "gate_up_proj_moe": return B @@ -923,6 +925,8 @@ def _slice_moe_b_2d( full_inter = B.shape[0] // 2 gate_b = B[start:end, :] up_b = B[full_inter + start : full_inter + end, :] + if self._uses_interleaved_gate_up: + return torch.stack([gate_b, up_b], dim=1).reshape(-1, B.shape[-1]) return torch.cat([gate_b, up_b], dim=0).contiguous() return B diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index cfcc6e4a0bfc..3746692529f2 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -315,7 +315,7 @@ def init_buffer( # MoE expert version (4D) moe_key = f"{module_name}_moe" buffer[moe_key] = [ - torch.empty( + torch.zeros( get_lora_shape_fn( moe_key, base_model, self.max_lora_rank, idx ), @@ -327,7 +327,7 @@ def init_buffer( else: # Standard allocation for unambiguous modules buffer[module_name] = [ - torch.empty( + torch.zeros( get_lora_shape_fn( module_name, base_model, @@ -347,7 +347,7 @@ def init_embedding_buffer( ): target_modules = target_modules & set(EMBEDDING_NAMES) for module_name in target_modules: - buffer[module_name] = torch.empty( + buffer[module_name] = torch.zeros( get_lora_shape_fn( module_name, base_model, @@ -359,7 +359,7 @@ def init_embedding_buffer( ) if self.lora_added_tokens_size > 0: - self.new_embeddings_buffer["input_embeddings"] = torch.empty( + self.new_embeddings_buffer["input_embeddings"] = torch.zeros( ( self.max_loras_per_batch, self.lora_added_tokens_size, diff --git a/python/sglang/srt/lora/utils.py b/python/sglang/srt/lora/utils.py index 1a12b1bd632d..a5d56c479502 100644 --- a/python/sglang/srt/lora/utils.py +++ b/python/sglang/srt/lora/utils.py @@ -88,9 +88,17 @@ def get_hidden_dim( elif module_name == "down_proj": return config.intermediate_size, config.hidden_size elif module_name == "gate_up_proj_moe": - return config.hidden_size, config.moe_intermediate_size * 2 + moe_inter = ( + getattr(config, "moe_intermediate_size", None) + or config.intermediate_size + ) + return config.hidden_size, moe_inter * 2 elif module_name == "down_proj_moe": - return config.moe_intermediate_size, config.hidden_size + moe_inter = ( + getattr(config, "moe_intermediate_size", None) + or config.intermediate_size + ) + return moe_inter, config.hidden_size elif module_name == "embed_tokens": # For embedding: input is vocab_size (as embedding lookup), output is hidden_size # if contain extra tokens will be added; otherwise is 0. diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py index 593ef4b9f932..c2e1ac233028 100644 --- a/python/sglang/srt/models/gpt_oss.py +++ b/python/sglang/srt/models/gpt_oss.py @@ -17,6 +17,7 @@ import logging import math +import re from collections.abc import Iterable from functools import partial from typing import Any, Dict, List, Optional, Tuple, Union @@ -651,6 +652,13 @@ def forward( class GptOssForCausalLM(nn.Module): fall_back_to_pt_during_load = False + _lora_pattern_moe = re.compile( + r"^(?:model\.layers\.\d+\.(?:self_attn\.(?:qkv_proj|o_proj)|mlp\.experts)|lm_head|model\.embed_tokens)$" + ) + + def should_apply_lora(self, module_name: str) -> bool: + return bool(self._lora_pattern_moe.match(module_name)) + def __init__( self, config: GptOssConfig, diff --git a/test/registered/lora/test_lora_gpt_oss_20b_logprob_diff.py b/test/registered/lora/test_lora_gpt_oss_20b_logprob_diff.py new file mode 100644 index 000000000000..e3a5e9dd6c4b --- /dev/null +++ b/test/registered/lora/test_lora_gpt_oss_20b_logprob_diff.py @@ -0,0 +1,151 @@ +# Copyright 2023-2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +""" +Regression test for gpt-oss-20b LoRA logprob accuracy. + +Compares SGLang LoRA logprobs against reference training logprobs from a +pre-computed dataset. The LoRA adapter and reference data are downloaded from: +https://huggingface.co/datasets/yushengsu/lora-diff-gpt-oss-20b + +Usage: + python -m unittest test_lora_gpt_oss_20b_logprob_diff +""" + +import multiprocessing as mp +import os +import unittest + +import torch +from huggingface_hub import snapshot_download + +import sglang as sgl +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.test_utils import CustomTestCase + +register_cuda_ci( + est_time=300, + suite="stage-c-test-4-gpu-b200", +) + +BASE_MODEL = "lmsys/gpt-oss-20b-bf16" +LORA_HF_REPO = "yushengsu/lora-diff-gpt-oss-20b" +LORA_BACKEND = "triton" +MAX_LORA_RANK = 32 +TP_SIZE = 4 +DISABLE_CUDA_GRAPH = True +MOE_RUNNER_BACKEND = "triton" +EXPERTS_SHARED_OUTER_LORAS = True +PREFILL_ATTENTION_BACKEND = "fa4" +DECODE_ATTENTION_BACKEND = "fa4" + +KL_THRESHOLD = 5e-3 + + +def kl_v2(a, b): + a = torch.tensor(a) if not torch.is_tensor(a) else a + b = torch.tensor(b) if not torch.is_tensor(b) else b + return (((a - b) ** 2) * 0.5).mean().item() + + +def get_prompt_logprobs(engine, input_ids, lora_path): + out = engine.generate( + input_ids=input_ids, + sampling_params={"max_new_tokens": 0, "temperature": 0.0}, + return_logprob=True, + logprob_start_len=0, + lora_path=lora_path, + ) + return [logprob for logprob, _, _ in out["meta_info"]["input_token_logprobs"]][1:] + + +class TestLoRAGptOss20BLogprobDiff(CustomTestCase): + + def test_lora_gpt_oss_20b_logprob_accuracy(self): + adapter_path = snapshot_download( + LORA_HF_REPO, + repo_type="dataset", + ) + + engine = sgl.Engine( + model_path=BASE_MODEL, + tp_size=TP_SIZE, + enable_lora=True, + max_lora_rank=MAX_LORA_RANK, + lora_paths={"my_lora": adapter_path}, + lora_backend=LORA_BACKEND, + attention_backend="flashinfer", + disable_cuda_graph=DISABLE_CUDA_GRAPH, + moe_runner_backend=MOE_RUNNER_BACKEND, + experts_shared_outer_loras=EXPERTS_SHARED_OUTER_LORAS, + prefill_attention_backend=PREFILL_ATTENTION_BACKEND, + decode_attention_backend=DECODE_ATTENTION_BACKEND, + ) + + try: + cdata = torch.load( + os.path.join(adapter_path, "compare_sample_train_data.pt"), + weights_only=False, + ) + + base_logprobs = get_prompt_logprobs(engine, cdata["tokens"], lora_path=None) + logprobs = get_prompt_logprobs(engine, cdata["tokens"], lora_path="my_lora") + + base_t = torch.tensor(base_logprobs) + lora_t = torch.tensor(logprobs) + diff = (base_t - lora_t).abs() + print( + f"[VERIFY] base vs lora: mean_diff={diff.mean().item():.6f}, " + f"max_diff={diff.max().item():.6f}, " + f"identical={torch.equal(base_t, lora_t)}" + ) + + self.assertFalse( + torch.equal(base_t, lora_t), + "LoRA logprobs should differ from base model logprobs", + ) + + kl_sglang_trainer = kl_v2(cdata["training_logprobs"], logprobs) + kl_orig_trainer = kl_v2( + cdata["training_logprobs"], cdata["sampling_logprobs"] + ) + kl_sglang_orig = kl_v2(logprobs, cdata["sampling_logprobs"]) + + print(f"KL(orig_sampler, trainer) = {kl_orig_trainer:.6e}") + print(f"KL(sglang, trainer) = {kl_sglang_trainer:.6e}") + print(f"KL(sglang, orig_sampler) = {kl_sglang_orig:.6e}") + + self.assertLessEqual( + kl_sglang_trainer, + KL_THRESHOLD, + f"KL(sglang, trainer) = {kl_sglang_trainer:.6e} exceeds " + f"threshold {KL_THRESHOLD}", + ) + + finally: + engine.shutdown() + + +if __name__ == "__main__": + try: + mp.set_start_method("spawn") + except RuntimeError: + pass + + try: + unittest.main(warnings="ignore", verbosity=2) + finally: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() diff --git a/test/registered/lora/test_lora_qwen3_8b_logprob_diff.py b/test/registered/lora/test_lora_qwen3_8b_logprob_diff.py index b21638ba562d..c2b9039a2ccf 100644 --- a/test/registered/lora/test_lora_qwen3_8b_logprob_diff.py +++ b/test/registered/lora/test_lora_qwen3_8b_logprob_diff.py @@ -48,10 +48,10 @@ MAX_LORA_RANK = 32 TP_SIZE = 1 DISABLE_CUDA_GRAPH = True -PREFILL_ATTENTION_BACKEND = "fa3" -DECODE_ATTENTION_BACKEND = "fa3" +PREFILL_ATTENTION_BACKEND = "fa4" +DECODE_ATTENTION_BACKEND = "fa4" -KL_THRESHOLD = 1e-2 +KL_THRESHOLD = 5e-3 def kl_v2(a, b): diff --git a/test/registered/lora/test_lora_qwen3_vl_30b_a3b_instruct_logprob_diff.py b/test/registered/lora/test_lora_qwen3_vl_30b_a3b_instruct_logprob_diff.py index fc82dff71a9c..176d16919cd6 100644 --- a/test/registered/lora/test_lora_qwen3_vl_30b_a3b_instruct_logprob_diff.py +++ b/test/registered/lora/test_lora_qwen3_vl_30b_a3b_instruct_logprob_diff.py @@ -50,7 +50,7 @@ PREFILL_ATTENTION_BACKEND = "fa4" DECODE_ATTENTION_BACKEND = "fa4" -KL_THRESHOLD = 1e-2 +KL_THRESHOLD = 5e-3 def kl_v2(a, b):