Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion python/sglang/srt/lora/backend/base_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,16 +144,17 @@ def init_cuda_graph_batch_info(
self,
max_bs_in_cuda_graph: int,
num_tokens_per_bs: int,
has_embedding_layers: bool = False,
):
"""Initialize the batch info for CUDA Graph mode.

This method provides a hook for each backend to conduct its own initialization
logic for CUDA Graph mode.

Args:
cuda_graph_batch_info: the LoRABatchInfo object created in LoraManager
max_bs_in_cuda_graph: maximum batch size for CUDA Graph mode
num_tokens_per_bs: number of tokens per sequence (1 for decoding, >1 for target_verify)
has_embedding_layers: whether target_modules includes embedding layers (embed_tokens/lm_head)
"""
pass

Expand Down
119 changes: 119 additions & 0 deletions python/sglang/srt/lora/backend/chunked_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from sglang.srt.lora.triton_ops import (
chunked_sgmv_lora_expand_forward,
chunked_sgmv_lora_shrink_forward,
embedding_lora_a_fwd,
)
from sglang.srt.lora.utils import LoRABatchInfo
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
Expand Down Expand Up @@ -32,6 +33,7 @@ def __init__(
):
super().__init__(max_loras_per_batch, device)
self.max_chunk_size = server_args.max_lora_chunk_size
self.has_embedding_layers = False # Will be set by manager

def run_lora_a_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
Expand Down Expand Up @@ -64,6 +66,28 @@ def run_lora_b_sgemm(
base_output=base_output,
)

def run_lora_a_embedding(
    self,
    input_ids: torch.Tensor,
    weights: torch.Tensor,
    vocab_size: int,
    extra_embeddings: torch.Tensor = None,
    *args,
    **kwargs,
) -> torch.Tensor:
    """Run LoRA A embedding lookup using Triton kernel.

    Uses embedding_batch_info which maintains original sequence structure
    (not the chunked/reordered structure used for linear layers).

    Args:
        input_ids: token ids to look up in the LoRA A embedding weights.
            # assumes 1-D flattened batch of token ids — TODO confirm against kernel
        weights: stacked LoRA A embedding weight tensor for the loaded adapters.
        vocab_size: size of the base vocabulary; passed through to the kernel
            (presumably used to bounds-check / route out-of-vocab ids — verify
            against the embedding_lora_a_fwd implementation).
        extra_embeddings: optional tensor of additional embeddings forwarded
            to the kernel; defaults to None.
        *args, **kwargs: accepted for backend-interface compatibility; unused here.

    Returns:
        torch.Tensor: the LoRA A embedding output produced by
        embedding_lora_a_fwd.
    """
    # NOTE(review): relies on self.embedding_batch_info having been populated
    # beforehand (set in prepare_lora_batch / CUDA-graph init) — calling this
    # before a batch has been prepared would fail.
    return embedding_lora_a_fwd(
        input_ids=input_ids,
        weights=weights,
        batch_info=self.embedding_batch_info,
        vocab_size=vocab_size,
        extra_embeddings=extra_embeddings,
    )

def run_qkv_lora(
self,
x: torch.Tensor,
Expand Down Expand Up @@ -162,7 +186,9 @@ def init_cuda_graph_batch_info(
self,
max_bs_in_cuda_graph: int,
num_tokens_per_bs: int,
has_embedding_layers: bool = False,
):
self.has_embedding_layers = has_embedding_layers
max_num_segments = (
(num_tokens_per_bs + MIN_CHUNK_SIZE - 1) // MIN_CHUNK_SIZE
) * max_bs_in_cuda_graph
Expand All @@ -181,6 +207,36 @@ def init_cuda_graph_batch_info(
max_len=None, # Not used in CSGMV backend
)

# TODO: The embedding_batch_info will be removed after the chunked kernel
# for embedding has been implemented. This is currently a workaround to
# make embedding run with the non-chunked Triton kernel.
if has_embedding_layers:
# Create embedding-specific batch info (uses original sequence structure)
self.cuda_graph_embedding_batch_info = LoRABatchInfo(
bs=max_bs_in_cuda_graph,
use_cuda_graph=True,
num_segments=max_bs_in_cuda_graph,
seg_lens=torch.full(
(max_bs_in_cuda_graph,), num_tokens_per_bs, dtype=torch.int32
),
seg_indptr=torch.zeros(max_bs_in_cuda_graph + 1, dtype=torch.int32),
max_len=num_tokens_per_bs,
weight_indices=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32),
scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
permutation=None,
)
# Initialize seg_indptr for embedding CUDA graph
torch.cumsum(
self.cuda_graph_embedding_batch_info.seg_lens[
:max_bs_in_cuda_graph
],
dim=0,
out=self.cuda_graph_embedding_batch_info.seg_indptr[
1 : max_bs_in_cuda_graph + 1
],
)

def prepare_lora_batch(
self,
forward_batch: ForwardBatch,
Expand Down Expand Up @@ -254,6 +310,69 @@ def prepare_lora_batch(

self.batch_info = batch_info

# Setup embedding_batch_info (uses original sequence structure, not chunked)
# Only needed when target_modules includes embedding layers (embed_tokens/lm_head)
if self.has_embedding_layers:
bs = forward_batch.batch_size
if use_cuda_graph:
embedding_batch_info = self.cuda_graph_embedding_batch_info
embedding_batch_info.bs = bs
embedding_batch_info.num_segments = bs
else:
emb_max_len = (
max(forward_batch.extend_seq_lens_cpu)
if forward_batch.forward_mode.is_extend()
else 1
)
emb_seg_lens = (
forward_batch.extend_seq_lens
if forward_batch.forward_mode.is_extend()
else torch.ones(bs, dtype=torch.int32, device=self.device)
)
emb_seg_indptr = torch.zeros(
(bs + 1,), dtype=torch.int32, device=self.device
)
emb_seg_indptr[1:] = torch.cumsum(emb_seg_lens, dim=0)

embedding_batch_info = LoRABatchInfo(
bs=bs,
num_segments=bs,
max_len=emb_max_len,
use_cuda_graph=False,
seg_lens=emb_seg_lens,
seg_indptr=emb_seg_indptr,
weight_indices=torch.empty(
(bs,), dtype=torch.int32, device=self.device
),
lora_ranks=torch.empty(
(self.max_loras_per_batch,),
dtype=torch.int32,
device=self.device,
),
scalings=torch.empty(
(self.max_loras_per_batch,),
dtype=torch.float,
device=self.device,
),
permutation=None,
)

# Copy common data to embedding_batch_info (reuse already-created tensors)
weight_indices_for_embedding = torch.tensor(
weight_indices, dtype=torch.int32, pin_memory=True, device="cpu"
)
embedding_batch_info.lora_ranks[: self.max_loras_per_batch].copy_(
lora_ranks_tensor, non_blocking=True
)
embedding_batch_info.scalings[: self.max_loras_per_batch].copy_(
scalings_tensor, non_blocking=True
)
embedding_batch_info.weight_indices[:bs].copy_(
weight_indices_for_embedding, non_blocking=True
)

self.embedding_batch_info = embedding_batch_info

@staticmethod
def _get_permutation(seq_weight_indices, forward_batch: ForwardBatch):
"""
Expand Down
7 changes: 7 additions & 0 deletions python/sglang/srt/lora/lora_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,16 @@ def init_cuda_graph_batch_info(
self, max_bs_in_cuda_graph: int, num_tokens_per_bs: int
):
self.max_bs_in_cuda_graph = max_bs_in_cuda_graph

# Check if target_modules includes embedding layers
has_embedding_layers = (
"embed_tokens" in self.target_modules or "lm_head" in self.target_modules
)

self.lora_backend.init_cuda_graph_batch_info(
max_bs_in_cuda_graph=max_bs_in_cuda_graph,
num_tokens_per_bs=num_tokens_per_bs,
has_embedding_layers=has_embedding_layers,
)

def create_lora_update_result(
Expand Down
11 changes: 0 additions & 11 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -4462,17 +4462,6 @@ def check_lora_server_args(self):
), "If 'all' is specified in --lora-target-modules, it should be the only module specified."
self.lora_target_modules = set(SUPPORTED_LORA_TARGET_MODULES)

# When using the chunked SGMV backend, skip embedding / lm_head layers for now,
# since it does not support these yet (TODO: implement embedding / lm_head support)
if self.lora_backend == "csgmv":
logger.warning(
"LoRA backend 'csgmv' does not yet support embedding or lm_head layers; "
"dropping 'embed_tokens' and 'lm_head' from --lora-target-modules=all. "
"To apply LoRA to these, use --lora-backend triton."
)
self.lora_target_modules.discard("embed_tokens")
self.lora_target_modules.discard("lm_head")

# Ensure sufficient information is provided for LoRA initialization.
assert self.lora_paths or (
self.max_lora_rank and self.lora_target_modules
Expand Down
Loading
Loading