Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 18 additions & 14 deletions python/sglang/srt/lora/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,6 @@ def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor
return lora_output

def forward(self, input_: torch.Tensor, skip_all_reduce=False):
# duplicate the logic in RowParallelLinear
if self.base_layer.input_is_parallel:
input_parallel = input_
else:
Expand All @@ -638,8 +637,14 @@ def forward(self, input_: torch.Tensor, skip_all_reduce=False):
input_, num_partitions=self.base_layer.tp_size
)
input_parallel = splitted_input[tp_rank].contiguous()

bias_ = (
None
if (self.base_layer.tp_rank > 0 or self.base_layer.skip_bias_add)
else self.base_layer.bias
)
output_parallel = self.base_layer.quant_method.apply(
self.base_layer, input_parallel
self.base_layer, input_parallel, bias=bias_
)

should_reduce = (
Expand Down Expand Up @@ -668,17 +673,8 @@ def forward(self, input_: torch.Tensor, skip_all_reduce=False):
else:
output_ = output_parallel

if not self.base_layer.skip_bias_add:
output = (
output_ + self.base_layer.bias
if self.base_layer.bias is not None
else output_
)
output_bias = None
else:
output = output_
output_bias = self.base_layer.bias
return output, output_bias
output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
return output_, output_bias

def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
shard_size = self.base_layer.input_size_per_partition
Expand Down Expand Up @@ -719,6 +715,9 @@ def __init__(
self.intermediate_size_per_partition = getattr(
base_layer, "intermediate_size_per_partition", None
)
self._uses_interleaved_gate_up = (
getattr(base_layer.moe_runner_config, "gemm1_alpha", None) is not None
)

# initialize triton_lora moe runner for batches with lora enabled
from sglang.srt.layers.moe.moe_runner.runner import MoeRunner
Expand Down Expand Up @@ -895,7 +894,10 @@ def slice_moe_lora_b_weights(
gate_up_proj_moe B: [intermediate_size*2, rank] — output matches sharded base w13
down_proj_moe B: [hidden_size, rank] — output is all-reduced, no slice
"""
if self.tp_size <= 1:
needs_processing = (self.tp_size > 1) or (
target_module == "gate_up_proj_moe" and self._uses_interleaved_gate_up
)
if not needs_processing:
return B
if target_module != "gate_up_proj_moe":
return B
Expand Down Expand Up @@ -923,6 +925,8 @@ def _slice_moe_b_2d(
full_inter = B.shape[0] // 2
gate_b = B[start:end, :]
up_b = B[full_inter + start : full_inter + end, :]
if self._uses_interleaved_gate_up:
return torch.stack([gate_b, up_b], dim=1).reshape(-1, B.shape[-1])
return torch.cat([gate_b, up_b], dim=0).contiguous()
return B

Expand Down
8 changes: 4 additions & 4 deletions python/sglang/srt/lora/mem_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def init_buffer(
# MoE expert version (4D)
moe_key = f"{module_name}_moe"
buffer[moe_key] = [
torch.empty(
torch.zeros(
get_lora_shape_fn(
moe_key, base_model, self.max_lora_rank, idx
),
Expand All @@ -327,7 +327,7 @@ def init_buffer(
else:
# Standard allocation for unambiguous modules
buffer[module_name] = [
torch.empty(
torch.zeros(
get_lora_shape_fn(
module_name,
base_model,
Expand All @@ -347,7 +347,7 @@ def init_embedding_buffer(
):
target_modules = target_modules & set(EMBEDDING_NAMES)
for module_name in target_modules:
buffer[module_name] = torch.empty(
buffer[module_name] = torch.zeros(
get_lora_shape_fn(
module_name,
base_model,
Expand All @@ -359,7 +359,7 @@ def init_embedding_buffer(
)

if self.lora_added_tokens_size > 0:
self.new_embeddings_buffer["input_embeddings"] = torch.empty(
self.new_embeddings_buffer["input_embeddings"] = torch.zeros(
(
self.max_loras_per_batch,
self.lora_added_tokens_size,
Expand Down
12 changes: 10 additions & 2 deletions python/sglang/srt/lora/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,17 @@ def get_hidden_dim(
elif module_name == "down_proj":
return config.intermediate_size, config.hidden_size
elif module_name == "gate_up_proj_moe":
return config.hidden_size, config.moe_intermediate_size * 2
moe_inter = (
getattr(config, "moe_intermediate_size", None)
or config.intermediate_size
)
return config.hidden_size, moe_inter * 2
elif module_name == "down_proj_moe":
return config.moe_intermediate_size, config.hidden_size
moe_inter = (
getattr(config, "moe_intermediate_size", None)
or config.intermediate_size
)
return moe_inter, config.hidden_size
elif module_name == "embed_tokens":
# For embedding: input is vocab_size (as embedding lookup), output is hidden_size
# if contain extra tokens will be added; otherwise is 0.
Expand Down
8 changes: 8 additions & 0 deletions python/sglang/srt/models/gpt_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import logging
import math
import re
from collections.abc import Iterable
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union
Expand Down Expand Up @@ -651,6 +652,13 @@ def forward(
class GptOssForCausalLM(nn.Module):
fall_back_to_pt_during_load = False

_lora_pattern_moe = re.compile(
r"^(?:model\.layers\.\d+\.(?:self_attn\.(?:qkv_proj|o_proj)|mlp\.experts)|lm_head|model\.embed_tokens)$"
)

def should_apply_lora(self, module_name: str) -> bool:
    """Return True if LoRA weights should be applied to *module_name*.

    Delegates to the precompiled ``_lora_pattern_moe`` regex, which matches
    attention qkv/o projections, ``mlp.experts``, ``lm_head``, and
    ``model.embed_tokens`` fully-qualified module names.
    """
    return bool(self._lora_pattern_moe.match(module_name))

def __init__(
self,
config: GptOssConfig,
Expand Down
151 changes: 151 additions & 0 deletions test/registered/lora/test_lora_gpt_oss_20b_logprob_diff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
# Copyright 2023-2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""
Regression test for gpt-oss-20b LoRA logprob accuracy.

Compares SGLang LoRA logprobs against reference training logprobs from a
pre-computed dataset. The LoRA adapter and reference data are downloaded from:
https://huggingface.co/datasets/yushengsu/lora-diff-gpt-oss-20b

Usage:
python -m unittest test_lora_gpt_oss_20b_logprob_diff
"""

import multiprocessing as mp
import os
import unittest

import torch
from huggingface_hub import snapshot_download

import sglang as sgl
from sglang.test.ci.ci_register import register_cuda_ci
from sglang.test.test_utils import CustomTestCase

# Register this test with the CUDA CI: estimated wall-clock time (seconds)
# and the suite it runs in.
register_cuda_ci(
    est_time=300,
    suite="stage-c-test-4-gpu-b200",
)

# Base model under test and the HF dataset repo that bundles the LoRA
# adapter together with the pre-computed reference logprobs.
BASE_MODEL = "lmsys/gpt-oss-20b-bf16"
LORA_HF_REPO = "yushengsu/lora-diff-gpt-oss-20b"
# Engine configuration used when launching sgl.Engine below.
LORA_BACKEND = "triton"
MAX_LORA_RANK = 32
TP_SIZE = 4
DISABLE_CUDA_GRAPH = True
MOE_RUNNER_BACKEND = "triton"
EXPERTS_SHARED_OUTER_LORAS = True
PREFILL_ATTENTION_BACKEND = "fa4"
DECODE_ATTENTION_BACKEND = "fa4"

# Maximum allowed divergence (as computed by kl_v2) between SGLang LoRA
# logprobs and the trainer reference logprobs.
KL_THRESHOLD = 5e-3


def kl_v2(a, b):
    """Return the mean of ``0.5 * (a - b) ** 2`` as a Python float.

    Despite the name this is not a true KL divergence; it is the quadratic
    logprob-difference metric used by these regression tests.

    Args:
        a: First logprob sequence — a tensor or anything ``torch.as_tensor``
           accepts (e.g. a list of floats).
        b: Second logprob sequence, same convention as *a*.

    Returns:
        float: mean of ``0.5 * (a - b) ** 2`` over all elements.
    """
    # torch.as_tensor is a no-op for tensors and converts lists, replacing
    # the original's manual `torch.is_tensor` conditionals.
    a_t = torch.as_tensor(a)
    b_t = torch.as_tensor(b)
    return (0.5 * (a_t - b_t) ** 2).mean().item()


def get_prompt_logprobs(engine, input_ids, lora_path):
    """Score a prompt and return its per-token input logprobs.

    Runs a zero-new-token generation so the engine only evaluates the
    prompt, optionally applying a LoRA adapter.

    Args:
        engine: A running ``sgl.Engine`` instance.
        input_ids: Token ids of the prompt to score.
        lora_path: Name of the loaded LoRA adapter to apply, or None for
            the base model.

    Returns:
        list[float]: logprob of each prompt token, with the first entry
        dropped (presumably the initial token has no usable logprob —
        verify against the engine's output format).
    """
    result = engine.generate(
        input_ids=input_ids,
        sampling_params={"max_new_tokens": 0, "temperature": 0.0},
        return_logprob=True,
        logprob_start_len=0,
        lora_path=lora_path,
    )
    entries = result["meta_info"]["input_token_logprobs"]
    # Each entry is a 3-tuple whose first element is the logprob.
    return [entry[0] for entry in entries[1:]]


class TestLoRAGptOss20BLogprobDiff(CustomTestCase):
    """Regression test comparing SGLang gpt-oss-20b LoRA prompt logprobs
    against pre-computed trainer reference logprobs."""

    def test_lora_gpt_oss_20b_logprob_accuracy(self):
        """LoRA logprobs must differ from base and stay within KL_THRESHOLD
        of the trainer reference."""
        # The HF dataset repo bundles both the adapter weights and the
        # comparison sample, so one snapshot_download fetches everything.
        adapter_path = snapshot_download(
            LORA_HF_REPO,
            repo_type="dataset",
        )

        engine = sgl.Engine(
            model_path=BASE_MODEL,
            tp_size=TP_SIZE,
            enable_lora=True,
            max_lora_rank=MAX_LORA_RANK,
            lora_paths={"my_lora": adapter_path},
            lora_backend=LORA_BACKEND,
            attention_backend="flashinfer",
            disable_cuda_graph=DISABLE_CUDA_GRAPH,
            moe_runner_backend=MOE_RUNNER_BACKEND,
            experts_shared_outer_loras=EXPERTS_SHARED_OUTER_LORAS,
            prefill_attention_backend=PREFILL_ATTENTION_BACKEND,
            decode_attention_backend=DECODE_ATTENTION_BACKEND,
        )

        try:
            # Reference sample: tokens plus trainer/sampler logprobs
            # computed offline. weights_only=False because the file stores
            # arbitrary Python objects, not just tensors.
            cdata = torch.load(
                os.path.join(adapter_path, "compare_sample_train_data.pt"),
                weights_only=False,
            )

            # Score the same prompt with and without the adapter applied.
            base_logprobs = get_prompt_logprobs(engine, cdata["tokens"], lora_path=None)
            logprobs = get_prompt_logprobs(engine, cdata["tokens"], lora_path="my_lora")

            base_t = torch.tensor(base_logprobs)
            lora_t = torch.tensor(logprobs)
            diff = (base_t - lora_t).abs()
            print(
                f"[VERIFY] base vs lora: mean_diff={diff.mean().item():.6f}, "
                f"max_diff={diff.max().item():.6f}, "
                f"identical={torch.equal(base_t, lora_t)}"
            )

            # Sanity check: identical outputs would mean the adapter was
            # silently ignored.
            self.assertFalse(
                torch.equal(base_t, lora_t),
                "LoRA logprobs should differ from base model logprobs",
            )

            # Pairwise divergences between SGLang, the trainer reference,
            # and the original sampler; only SGLang-vs-trainer is gated.
            kl_sglang_trainer = kl_v2(cdata["training_logprobs"], logprobs)
            kl_orig_trainer = kl_v2(
                cdata["training_logprobs"], cdata["sampling_logprobs"]
            )
            kl_sglang_orig = kl_v2(logprobs, cdata["sampling_logprobs"])

            print(f"KL(orig_sampler, trainer) = {kl_orig_trainer:.6e}")
            print(f"KL(sglang, trainer) = {kl_sglang_trainer:.6e}")
            print(f"KL(sglang, orig_sampler) = {kl_sglang_orig:.6e}")

            self.assertLessEqual(
                kl_sglang_trainer,
                KL_THRESHOLD,
                f"KL(sglang, trainer) = {kl_sglang_trainer:.6e} exceeds "
                f"threshold {KL_THRESHOLD}",
            )

        finally:
            # Always release GPUs, even on assertion failure.
            engine.shutdown()


if __name__ == "__main__":
    # "spawn" avoids forking an initialized CUDA context into workers; a
    # RuntimeError means the start method was already set — keep it.
    try:
        mp.set_start_method("spawn")
    except RuntimeError:
        pass

    try:
        unittest.main(warnings="ignore", verbosity=2)
    finally:
        # Best-effort GPU cleanup after the test run.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
6 changes: 3 additions & 3 deletions test/registered/lora/test_lora_qwen3_8b_logprob_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@
MAX_LORA_RANK = 32
TP_SIZE = 1
DISABLE_CUDA_GRAPH = True
PREFILL_ATTENTION_BACKEND = "fa3"
DECODE_ATTENTION_BACKEND = "fa3"
PREFILL_ATTENTION_BACKEND = "fa4"
DECODE_ATTENTION_BACKEND = "fa4"

KL_THRESHOLD = 1e-2
KL_THRESHOLD = 5e-3


def kl_v2(a, b):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
PREFILL_ATTENTION_BACKEND = "fa4"
DECODE_ATTENTION_BACKEND = "fa4"

KL_THRESHOLD = 1e-2
KL_THRESHOLD = 5e-3


def kl_v2(a, b):
Expand Down
Loading