Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
2766611
support lora emb: --disable-cuda-graph, without extra token, no tp
yushengsu-thu Nov 25, 2025
04292bd
update
yushengsu-thu Nov 26, 2025
377fd2d
refactor layers.py --> VocabParallelEmbeddingWithLoRA --> _run_lora_a…
yushengsu-thu Nov 26, 2025
67452af
refactor
yushengsu-thu Nov 26, 2025
c90f75f
finish vocab_emb support and still need to fix lm_head
yushengsu-thu Nov 28, 2025
746b641
need to fix 1. lm_head 2. cuda-graph
yushengsu-thu Nov 29, 2025
6332773
fixed lm_head issue
yushengsu-thu Nov 30, 2025
53ab0d5
vocab_emb without cuda-graph version
yushengsu-thu Dec 1, 2025
20b368e
support cuda-graph (triton backend)
yushengsu-thu Dec 2, 2025
f944d95
support cuda and no-cuda version; tokenizer (it added/extra tokens) s…
yushengsu-thu Dec 4, 2025
9041521
cleaned code ([to-do] 1. TP support in mem_pool.py and layer.py; 2. t…
yushengsu-thu Dec 4, 2025
261f96b
merge
yushengsu-thu Dec 5, 2025
8119daf
update
yushengsu-thu Dec 5, 2025
b33c5e0
fix lora/layer.py
yushengsu-thu Dec 5, 2025
10a099f
remove comments
yushengsu-thu Dec 5, 2025
18a8f5b
remove chunked_backend and remove extra_tokens support temporarily
yushengsu-thu Dec 6, 2025
2ad53c5
Merge remote-tracking branch 'upstream/main' into sglang-lora-emb-dev
yushengsu-thu Dec 6, 2025
e10980d
fix CI/CD
yushengsu-thu Dec 8, 2025
4a67684
Merge remote-tracking branch 'upstream/main' into sglang-lora-emb-dev
yushengsu-thu Dec 8, 2025
1610ce3
merge
yushengsu-thu Dec 9, 2025
2ea8d2a
Merge remote-tracking branch 'upstream/main' into sglang-lora-emb-dev
yushengsu-thu Dec 9, 2025
f09852e
add nightly ci/cd: test_lora_hf_sgl_logprob_diff.py
yushengsu-thu Dec 9, 2025
9529fa2
Merge remote-tracking branch 'upstream/main' into sglang-lora-emb-dev
yushengsu-thu Dec 9, 2025
8d8ba29
update nightly ci/cd
yushengsu-thu Dec 9, 2025
393f73a
Merge branch 'main' into sglang-lora-emb-dev
Fridge003 Dec 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion python/sglang/srt/layers/logits_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,7 +837,10 @@ def _get_logits(
)
dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)

if hasattr(lm_head, "weight"):
if hasattr(lm_head, "set_lora") and hasattr(lm_head, "apply_lora"):
# This is a LoRA-wrapped module, use its forward method
logits = lm_head(hidden_states)
elif hasattr(lm_head, "weight"):
if self.use_fp32_lm_head:
logits = torch.matmul(
hidden_states.to(torch.float32), lm_head.weight.to(torch.float32).T
Expand Down
46 changes: 46 additions & 0 deletions python/sglang/srt/lora/backend/base_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,52 @@ def __init__(self, max_loras_per_batch: int, device: torch.device):
self.max_loras_per_batch = max_loras_per_batch
self.device = device

def run_lora_a_embedding(
    self,
    input_ids: torch.Tensor,
    weights: torch.Tensor,
    vocab_size: int,
    extra_embeddings: torch.Tensor = None,
    *args,
    **kwargs,
) -> torch.Tensor:
    """Run LoRA A embedding lookup with CUDA graph support.

    Abstract hook: concrete backends (e.g. the Triton backend) must override
    this with a real kernel dispatch.

    Args:
        input_ids: token IDs with shape (s,), where s is the sum of all sequence lengths
        weights: LoRA A embedding weights with shape (num_loras, rank, vocab_size)
        vocab_size: base vocabulary size (tokens >= vocab_size are extra tokens)
        extra_embeddings: extra token embeddings with shape (num_loras, num_extra_tokens, rank)
            Only needed if there are added tokens beyond base vocabulary.

    Returns:
        result with shape (s, rank)

    Raises:
        NotImplementedError: always, in this base class.
    """
    # Previously this was a bare ``pass``, which silently returned None even
    # though the docstring promises a tensor; raising here matches the sibling
    # abstract method run_extra_token_embedding and surfaces missing backend
    # support at the call site instead of propagating None downstream.
    raise NotImplementedError

def run_extra_token_embedding(
    self,
    input_ids: torch.Tensor,
    output: torch.Tensor,
    extra_embeddings: torch.Tensor,
    vocab_size: int,
    *args,
    **kwargs,
) -> torch.Tensor:
    """Write extra-token embeddings into ``output`` in-place.

    Abstract hook for backends: rows of ``output`` whose token ID is at or
    above ``vocab_size`` (i.e. added/extra tokens) are to be filled from
    ``extra_embeddings``.

    Args:
        input_ids: (s,) token IDs
        output: (s, embed_dim) output tensor to be modified
        extra_embeddings: (num_loras, num_extra_tokens, embed_dim) extra embeddings
        vocab_size: base vocabulary size

    Returns:
        output: modified output tensor

    Raises:
        NotImplementedError: always, in this base class.
    """
    raise NotImplementedError

def run_lora_a_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
Expand Down
23 changes: 21 additions & 2 deletions python/sglang/srt/lora/backend/triton_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from sglang.srt.lora.backend.base_backend import BaseLoRABackend
from sglang.srt.lora.triton_ops import (
embedding_lora_a_fwd,
gate_up_lora_b_fwd,
qkv_lora_b_fwd,
sgemm_lora_a_fwd,
Expand All @@ -22,6 +23,24 @@ def __init__(
):
super().__init__(max_loras_per_batch, device)

def run_lora_a_embedding(
    self,
    input_ids: torch.Tensor,
    weights: torch.Tensor,
    vocab_size: int,
    extra_embeddings: torch.Tensor = None,
    *args,
    **kwargs,
) -> torch.Tensor:
    """Triton implementation of the LoRA A embedding lookup.

    Delegates to the ``embedding_lora_a_fwd`` Triton kernel, forwarding the
    per-batch segment metadata held in ``self.batch_info``.
    """
    # Collect kernel arguments explicitly, then dispatch in one call.
    kernel_args = dict(
        input_ids=input_ids,
        weights=weights,
        batch_info=self.batch_info,
        vocab_size=vocab_size,
        extra_embeddings=extra_embeddings,
    )
    return embedding_lora_a_fwd(**kernel_args)

def run_lora_a_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
Expand Down Expand Up @@ -107,7 +126,7 @@ def init_cuda_graph_batch_info(
seg_lens=torch.full(
(max_bs_in_cuda_graph,), num_tokens_per_bs, dtype=torch.int32
),
seg_indptr=torch.empty(max_bs_in_cuda_graph + 1, dtype=torch.int32),
seg_indptr=torch.zeros(max_bs_in_cuda_graph + 1, dtype=torch.int32),
max_len=num_tokens_per_bs,
weight_indices=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32),
Expand Down Expand Up @@ -161,7 +180,7 @@ def prepare_lora_batch(
seg_lens = (
forward_batch.extend_seq_lens
if forward_batch.forward_mode.is_extend()
else torch.ones(bs, device=self.device)
else torch.ones(bs, dtype=torch.int32, device=self.device)
)
seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
Expand Down
Loading
Loading