From 2766611b294ea7cda920666fc7c0ae6b3c8f073c Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Tue, 25 Nov 2025 04:18:06 +0000 Subject: [PATCH 01/19] support lora emb: --disable-cuda-graph, without extra token, no tp --- python/sglang/srt/lora/layers.py | 519 ++++++++++++++++++++++++- python/sglang/srt/lora/lora.py | 19 + python/sglang/srt/lora/lora_config.py | 4 - python/sglang/srt/lora/lora_manager.py | 52 +++ python/sglang/srt/lora/mem_pool.py | 120 ++++++ python/sglang/srt/lora/utils.py | 29 ++ python/sglang/srt/utils/common.py | 2 + 7 files changed, 735 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index 139d97cbca31..f43ec5a461b9 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -1,5 +1,7 @@ import torch from torch import nn +import torch.nn.functional as F +from typing import Optional from sglang.srt.distributed import ( get_tensor_model_parallel_rank, @@ -13,7 +15,7 @@ QKVParallelLinear, RowParallelLinear, ) -from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding +from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead from sglang.srt.lora.backend.base_backend import BaseLoRABackend @@ -42,14 +44,292 @@ def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): pass +############################## +##########emb lora############ +############################## +# class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): +# """ +# Vocab parallel embedding layer with support for LoRA (Low-Rank Adaptation). + +# Note: The current version does not yet implement the LoRA functionality. +# This class behaves exactly the same as the base VocabParallelEmbedding. +# Future versions will integrate LoRA functionality to support efficient parameter fine-tuning. +# """ + +# def __init__( +# self, +# base_layer: VocabParallelEmbedding, +# lora_backend: BaseLoRABackend, +# ) -> None: +# super().__init__(base_layer, lora_backend) +# self.weight = base_layer.weight + + +# ##### ----- +# ##### ----- +# ##### ----- +# class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): +# """ +# Vocab parallel embedding layer with support for LoRA (Low-Rank Adaptation). + +# This layer supports LoRA adapters on embedding layers, including handling +# of extra tokens added by LoRA adapters. The implementation uses efficient +# embedding lookup instead of one-hot encoding. + +# For embedding layers: output = base_embedding(x) + lora_B @ lora_A[x] +# where lora_A[x] is direct embedding lookup from lora_A weights. +# """ + +# def __init__( +# self, +# base_layer: VocabParallelEmbedding, +# lora_backend: BaseLoRABackend, +# ) -> None: +# super().__init__(base_layer, lora_backend) +# self.weight = base_layer.weight +# self.embed_dim = base_layer.embedding_dim +# self.vocab_size = base_layer.org_vocab_size + +# def set_lora_info( +# self, +# new_embeddings_buffer: Optional[torch.Tensor], +# embedding_A_buffer: torch.Tensor, +# embedding_B_buffer: torch.Tensor, +# ): +# self.set_lora = True +# self.new_embeddings_buffer = new_embeddings_buffer # For extra tokens +# self.embedding_A_buffer = embedding_A_buffer +# self.embedding_B_buffer = embedding_B_buffer + +# def _get_token_weight_indices( +# self, input_: torch.Tensor, batch_info +# ) -> torch.Tensor: +# """Map each token position to its corresponding LoRA adapter index.""" +# token_weight_indices = torch.zeros( +# input_.shape[0], dtype=torch.int32, device=input_.device +# ) + +# current_pos = 0 +# for i in range(batch_info.bs): +# seg_len = int(batch_info.seg_lens[i]) +# weight_idx = int(batch_info.weight_indices[i]) +# token_weight_indices[current_pos : current_pos+seg_len] = weight_idx +# current_pos += seg_len + +# return token_weight_indices + +# def _run_lora_a_embedding( +# self, input_: torch.Tensor, token_weight_indices: torch.Tensor +# ) -> torch.Tensor: +# """ +# Apply LoRA A weights using efficient embedding lookup. +# This avoids creating one-hot vectors. +# """ +# lora_a_output = torch.zeros( +# (input_.shape[0], self.embedding_A_buffer.shape[1]), +# dtype=self.embedding_A_buffer.dtype, +# device=input_.device, +# ) + +# unique_weight_indices = torch.unique(token_weight_indices) + +# for idx in unique_weight_indices: +# token_mask = token_weight_indices == idx +# lora_a_weights = self.embedding_A_buffer[idx] +# # Use F.embedding for efficient lookup instead of one-hot @ weights +# # lora_a_weights shape: (rank, vocab_size) +# # We need (vocab_size, rank) for embedding lookup +# lora_a_output[token_mask] = F.embedding( +# input_[token_mask], lora_a_weights.t() +# ) + +# return lora_a_output + +# def apply_lora( +# self, base_output: torch.Tensor, input_: torch.Tensor, batch_info +# ) -> torch.Tensor: +# """ +# Apply LoRA to base embedding output. +# Formula: output = base_output + lora_B @ lora_A_embedding(input_) +# """ +# token_weight_indices = self._get_token_weight_indices(input_, batch_info) + +# # Efficient embedding lookup for LoRA A +# lora_a_output = self._run_lora_a_embedding(input_, token_weight_indices) + +# # Apply LoRA B weights +# lora_output = self.lora_backend.run_lora_b_sgemm( +# x=lora_a_output, +# weights=self.embedding_B_buffer, +# base_output=base_output, +# ) +# return lora_output + +# def _forward( +# self, +# input_: torch.Tensor, +# added_tokens_mask: torch.Tensor, +# batch_info, +# base_output: torch.Tensor, +# ) -> torch.Tensor: +# """Handle extra tokens that are beyond the base vocabulary.""" +# token_weight_indices = self._get_token_weight_indices(input_, batch_info) +# added_weight_indices = token_weight_indices[added_tokens_mask] +# unique_added_weight_indices = torch.unique(added_weight_indices) + +# for idx in unique_added_weight_indices: +# lora_mask = added_weight_indices == idx +# added_token_positions = torch.where(added_tokens_mask)[0][lora_mask] +# # Remap to extra token range +# x = input_[added_token_positions] - self.vocab_size +# new_embeddings = F.embedding(x, self.new_embeddings_buffer[idx]) +# base_output[added_token_positions] = new_embeddings + +# return base_output + +# def forward(self, input_: torch.Tensor): +# batch_info = self.lora_backend.batch_info + +# # Mask tokens that are beyond base vocabulary (extra tokens) +# added_tokens_mask = input_ > self.vocab_size - 1 + +# # Get base embedding, masking extra tokens temporarily +# base_output = self.base_layer.forward(input_.masked_fill(added_tokens_mask, 0)) + +# # Handle extra tokens +# if added_tokens_mask.any(): +# base_output = self._forward( +# input_, added_tokens_mask, batch_info, base_output +# ) + +# # Apply LoRA if configured +# if self.set_lora: +# output = self.apply_lora(base_output, input_, batch_info) +# else: +# output = base_output + +# return output + +# def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): +# # LoRA A weights (rank, vocab_size) are not sliced for embedding +# # because each token needs access to full vocabulary +# return A + +# def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): +# ## LoRA B weights (embedding_dim, rank) are sliced along embedding dimension +# # from sglang.srt.distributed import divide +# # shard_size = divide(self.base_layer.embedding_dim, self.base_layer.tp_size) +# # start_idx = tp_rank * shard_size +# # end_idx = (tp_rank + 1) * shard_size +# # B = B[start_idx:end_idx, :] + +# # TP = 1 +# return B + + + +# class ParallelLMHeadWithLoRA(BaseLayerWithLoRA): +# """ +# Parallel LM Head layer with support for LoRA. + +# The LM head computes logits = hidden_states @ (W + B @ A)^T +# This is different from embedding which uses lookup operations. + +# Note: This class is NOT in the official SGLang implementation. +# You may need to verify if LM head LoRA is needed for your use case. +# """ + +# def __init__( +# self, +# base_layer, # ParallelLMHead +# lora_backend: BaseLoRABackend, +# ) -> None: +# super().__init__(base_layer, lora_backend) +# self.weight = base_layer.weight +# self.embed_dim = base_layer.embedding_dim +# self.vocab_size = base_layer.org_vocab_size + +# def set_lora_info( +# self, +# lm_head_A_buffer: torch.Tensor, +# lm_head_B_buffer: torch.Tensor, +# ): +# self.set_lora = True +# self.lm_head_A_buffer = lm_head_A_buffer +# self.lm_head_B_buffer = lm_head_B_buffer + +# def apply_lora( +# self, base_output: torch.Tensor, hidden_states: torch.Tensor +# ) -> torch.Tensor: +# """ +# Apply LoRA to LM head layer. + +# Args: +# base_output: Base logits, shape (batch_size, vocab_size) +# hidden_states: Hidden states, shape (batch_size, hidden_dim) + +# Returns: +# Logits with LoRA applied +# """ +# # For LM head: output = hidden @ (W + B @ A)^T +# # = hidden @ W^T + hidden @ A^T @ B^T +# # = base_output + (hidden @ A^T) @ B^T + +# # Apply lora_A^T: hidden_states @ A^T +# # lm_head_A_buffer shape: (num_loras, rank, hidden_dim) +# lora_a_output = self.lora_backend.run_lora_a_sgemm( +# hidden_states, self.lm_head_A_buffer +# ) + +# # Apply lora_B^T: lora_a_output @ B^T +# # lm_head_B_buffer shape: (num_loras, vocab_size, rank) +# lora_output = self.lora_backend.run_lora_b_sgemm( +# x=lora_a_output, +# weights=self.lm_head_B_buffer, +# base_output=base_output, +# ) + +# return lora_output + +# def forward(self, hidden_states: torch.Tensor): +# # Apply base linear transformation +# base_output = F.linear( +# hidden_states, +# self.weight, +# bias=getattr(self.base_layer, 'bias', None) +# ) + +# # Apply LoRA if set +# if self.set_lora: +# base_output = self.apply_lora(base_output, hidden_states) + +# return base_output + +# def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): +# # LoRA A is not sliced (similar to ColumnParallelLinear) +# return A + +# def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): +# # LoRA B is sliced along vocab dimension (output dimension) +# # Similar to ColumnParallelLinear slicing +# # from sglang.srt.distributed import divide +# # shard_size = divide(self.vocab_size, self.base_layer.tp_size) +# # start_idx = tp_rank * shard_size +# # end_idx = (tp_rank + 1) * shard_size +# # B = B[start_idx:end_idx, :] +# # TP=1 +# return B +##### ----- +##### ----- +##### ----- + class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): """ - Vocab parallel embedding layer with support for LoRA (Low-Rank Adaptation). - - Note: The current version does not yet implement the LoRA functionality. - This class behaves exactly the same as the base VocabParallelEmbedding. - Future versions will integrate LoRA functionality to support efficient parameter fine-tuning. + Vocab parallel embedding layer with LoRA support (simplified for TP=1, no extra tokens). + + For embedding layers: output = base_embedding(x) + lora_B @ lora_A[x] + where lora_A[x] is direct embedding lookup from lora_A weights. """ def __init__( @@ -59,6 +339,186 @@ def __init__( ) -> None: super().__init__(base_layer, lora_backend) self.weight = base_layer.weight + self.embed_dim = base_layer.embedding_dim + self.vocab_size = base_layer.org_vocab_size + + def set_lora_info( + self, + embedding_A_buffer: torch.Tensor, + embedding_B_buffer: torch.Tensor, + ): + """Set LoRA buffers for embedding layer.""" + self.set_lora = True + self.embedding_A_buffer = embedding_A_buffer # (num_loras, rank, vocab_size) + self.embedding_B_buffer = embedding_B_buffer # (num_loras, embed_dim, rank) + + def apply_lora( + self, base_output: torch.Tensor, input_: torch.Tensor, batch_info + ) -> torch.Tensor: + """ + Apply LoRA to base embedding output. + Formula: output = base_output + lora_B @ lora_A_embedding(input_) + """ + # Get token-to-lora mapping + token_weight_indices = self._get_token_weight_indices(input_, batch_info) + + # Efficient embedding lookup for LoRA A + lora_a_output = self._run_lora_a_embedding(input_, token_weight_indices) + + # Apply LoRA B weights using backend + lora_output = self.lora_backend.run_lora_b_sgemm( + x=lora_a_output, + weights=self.embedding_B_buffer, + base_output=base_output, + ) + return lora_output + + def _get_token_weight_indices( + self, input_: torch.Tensor, batch_info + ) -> torch.Tensor: + """Map each token position to its corresponding LoRA adapter index.""" + token_weight_indices = torch.zeros( + input_.shape[0], dtype=torch.int32, device=input_.device + ) + + current_pos = 0 + for i in range(batch_info.bs): + seg_len = int(batch_info.seg_lens[i]) # Convert tensor to int + weight_idx = int(batch_info.weight_indices[i]) # Convert tensor to int + token_weight_indices[current_pos : current_pos+seg_len] = weight_idx + current_pos += seg_len + + return token_weight_indices + + def _run_lora_a_embedding( + self, input_: torch.Tensor, token_weight_indices: torch.Tensor + ) -> torch.Tensor: + """ + Apply LoRA A weights using efficient embedding lookup. + This avoids creating one-hot vectors. + """ + lora_a_output = torch.zeros( + (input_.shape[0], self.embedding_A_buffer.shape[1]), + dtype=self.embedding_A_buffer.dtype, + device=input_.device, + ) + + unique_weight_indices = torch.unique(token_weight_indices) + + for idx in unique_weight_indices: + idx_val = idx.item() # Convert tensor to int + token_mask = token_weight_indices == idx + lora_a_weights = self.embedding_A_buffer[idx_val] # (rank, vocab_size) + # Use F.embedding for efficient lookup + # lora_a_weights.t() gives us (vocab_size, rank) + lora_a_output[token_mask] = F.embedding( + input_[token_mask], lora_a_weights.t() + ) + + return lora_a_output + + def forward(self, input_: torch.Tensor): + # Get base embedding output + base_output = self.base_layer.forward(input_) + + # Apply LoRA if configured + if self.set_lora: + batch_info = self.lora_backend.batch_info + output = self.apply_lora(base_output, input_, batch_info) + else: + output = base_output + + return output + + def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): + # For TP=1, no slicing needed + # LoRA A weights (rank, vocab_size) are not sliced for embedding + return A + + def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): + # For TP=1, no slicing needed + # LoRA B weights (embedding_dim, rank) would be sliced along embedding dimension for TP>1 + return B + + +class ParallelLMHeadWithLoRA(BaseLayerWithLoRA): + """ + Parallel LM Head layer with LoRA support (simplified for TP=1). + + The LM head computes logits = hidden_states @ (W + B @ A)^T + """ + + def __init__( + self, + base_layer: ParallelLMHead, + lora_backend: BaseLoRABackend, + ) -> None: + super().__init__(base_layer, lora_backend) + self.weight = base_layer.weight + self.embed_dim = base_layer.embedding_dim + self.vocab_size = base_layer.org_vocab_size + + def set_lora_info( + self, + lm_head_A_buffer: torch.Tensor, + lm_head_B_buffer: torch.Tensor, + ): + """Set LoRA buffers for LM head layer.""" + self.set_lora = True + self.lm_head_A_buffer = lm_head_A_buffer # (num_loras, rank, hidden_dim) + self.lm_head_B_buffer = lm_head_B_buffer # (num_loras, vocab_size, rank) + + def apply_lora( + self, base_output: torch.Tensor, hidden_states: torch.Tensor + ) -> torch.Tensor: + """ + Apply LoRA to LM head layer. + + For LM head: output = hidden @ (W + B @ A)^T + = hidden @ W^T + hidden @ A^T @ B^T + = base_output + (hidden @ A^T) @ B^T + """ + # Apply lora_A^T: hidden_states @ A^T + lora_a_output = self.lora_backend.run_lora_a_sgemm( + hidden_states, self.lm_head_A_buffer + ) + + # Apply lora_B^T: lora_a_output @ B^T + lora_output = self.lora_backend.run_lora_b_sgemm( + x=lora_a_output, + weights=self.lm_head_B_buffer, + base_output=base_output, + ) + + return lora_output + + def forward(self, hidden_states: torch.Tensor): + # Apply base linear transformation + base_output = F.linear( + hidden_states, + self.weight, + bias=getattr(self.base_layer, 'bias', None) + ) + + # Apply LoRA if set + if self.set_lora: + base_output = self.apply_lora(base_output, hidden_states) + + return base_output + + def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): + # For TP=1, no slicing needed + return A + + def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): + # For TP=1, no slicing needed + # For TP>1, would slice along vocab dimension + return B + + +############################## +############################## +############################## class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): @@ -337,12 +797,56 @@ def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): return B +############################## +##########emb lora############ +############################## +# def get_lora_layer( +# layer: nn.Module, lora_backend: BaseLoRABackend +# ) -> BaseLayerWithLoRA: +# supported_layer_types = { +# # the order matters +# VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA, +# QKVParallelLinear: QKVParallelLinearWithLoRA, +# MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA, +# ColumnParallelLinear: ColumnParallelLinearWithLoRA, +# RowParallelLinear: RowParallelLinearWithLoRA, +# } +# for src_layer_type, lora_layer_type in supported_layer_types.items(): +# if isinstance(layer, src_layer_type): # pylint: disable=unidiomatic-typecheck +# ret = lora_layer_type(layer, lora_backend) +# return ret +# raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.") + + + +# def get_lora_layer( +# layer: nn.Module, lora_backend: BaseLoRABackend, lora_extra_vocab_size: int = 0 +# ) -> BaseLayerWithLoRA: + +# supported_layer_types = { +# # the order matters - check ParallelLMHead before VocabParallelEmbedding +# # since ParallelLMHead is a subclass of VocabParallelEmbedding +# ParallelLMHead: lambda l, b: ParallelLMHeadWithLoRA(l, b, lora_extra_vocab_size), +# VocabParallelEmbedding: lambda l, b: VocabParallelEmbeddingWithLoRA(l, b, lora_extra_vocab_size), +# QKVParallelLinear: lambda l, b: QKVParallelLinearWithLoRA(l, b), +# MergedColumnParallelLinear: lambda l, b: MergedColumnParallelLinearWithLoRA(l, b), +# ColumnParallelLinear: lambda l, b: ColumnParallelLinearWithLoRA(l, b), +# RowParallelLinear: lambda l, b: RowParallelLinearWithLoRA(l, b), +# } +# for src_layer_type, lora_layer_factory in supported_layer_types.items(): +# if isinstance(layer, src_layer_type): +# ret = lora_layer_factory(layer, lora_backend) +# return ret +# raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.") + + def get_lora_layer( layer: nn.Module, lora_backend: BaseLoRABackend ) -> BaseLayerWithLoRA: supported_layer_types = { # the order matters + ParallelLMHead: ParallelLMHeadWithLoRA, VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA, QKVParallelLinear: QKVParallelLinearWithLoRA, MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA, @@ -354,3 +858,6 @@ def get_lora_layer( ret = lora_layer_type(layer, lora_backend) return ret raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.") +############################## +############################## +############################## diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py index 995aca6e5e36..9f7b4d85bc25 100644 --- a/python/sglang/srt/lora/lora.py +++ b/python/sglang/srt/lora/lora.py @@ -71,6 +71,15 @@ def __init__( ] ) + ############################## + ##########emb lora############ + ############################## + self.embedding_layer = LoRALayer(config, base_hf_config) + self.lm_head_layer = LoRALayer(config, base_hf_config) + ############################## + ############################## + ############################## + # initialize the LoRA weights to cpu def initialize_weights(self): model_path = self.config.path @@ -84,6 +93,16 @@ def initialize_weights(self): layer_id = get_layer_id(name) if layer_id is not None: self.layers[layer_id].weights[name] = loaded_weight.cpu() + ############################## + ##########emb lora############ + ############################## + elif "embed_tokens" in name: + self.embedding_layer.weights[name] = loaded_weight.cpu() + elif "lm_head" in name: + self.lm_head_layer.weights[name] = loaded_weight.cpu() + ############################## + ############################## + ############################## # normalize kv_proj and gate_up_proj for layer in self.layers: diff --git a/python/sglang/srt/lora/lora_config.py b/python/sglang/srt/lora/lora_config.py index 185b7b8246ee..98a28b01b089 100644 --- a/python/sglang/srt/lora/lora_config.py +++ b/python/sglang/srt/lora/lora_config.py @@ -27,10 +27,6 @@ def __init__( self.hf_config = self.get_lora_config() self.target_modules = self.hf_config["target_modules"] - # TODO: Support more modules - if any(module in self.target_modules for module in ["embed_tokens", "lm_head"]): - raise ValueError("Not supported yet") - self.r = self.hf_config["r"] self.lora_alpha = self.hf_config["lora_alpha"] diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 188eb1e9e3e8..38612b195df6 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -40,6 +40,8 @@ from sglang.srt.utils import is_npu, replace_submodule from sglang.srt.utils.hf_transformers_utils import AutoConfig +from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead + if is_npu(): from torch_npu.contrib import transfer_to_npu # noqa: F401 @@ -299,6 +301,26 @@ def update_lora_info(self): ), ) + ############################## + ##########emb lora############ + ############################## + # Update embedding layer if present + if self.embed_tokens_module is not None and hasattr(self.memory_pool, 'embedding_A_buffer') and self.memory_pool.embedding_A_buffer is not None: + self.embed_tokens_module.set_lora_info( + self.memory_pool.embedding_A_buffer, + self.memory_pool.embedding_B_buffer, + ) + + # Update lm_head layer if present + if self.lm_head_module is not None and hasattr(self.memory_pool, 'lm_head_A_buffer') and self.memory_pool.lm_head_A_buffer is not None: + self.lm_head_module.set_lora_info( + self.memory_pool.lm_head_A_buffer, + self.memory_pool.lm_head_B_buffer, + ) + ############################## + ############################## + ############################## + def init_state( self, max_lora_rank: Optional[int] = None, @@ -432,6 +454,16 @@ def init_lora_modules(self): {} for _ in range(self.base_hf_config.num_hidden_layers) ] + ############################## + ##########emb lora############ + ############################## + self.embed_tokens_module: Optional[BaseLayerWithLoRA] = None + self.lm_head_module: Optional[BaseLayerWithLoRA] = None + ############################## + ############################## + ############################## + + for module_name, module in self.base_model.named_modules(): # TODO (lifuhuang): in the future, we should consider generalizing the # should_apply_lora function to support mapping by full module name instead @@ -443,6 +475,26 @@ def init_lora_modules(self): ) and not self.base_model.should_apply_lora(module_name): continue + ############################## + ##########emb lora############ + ############################## + # Handle embed_tokens + if "embed_tokens" in module_name and "embed_tokens" in self.target_modules: + if isinstance(module, VocabParallelEmbedding) and not isinstance(module, BaseLayerWithLoRA): + lora_module = self.set_lora_module(module_name, module) + self.embed_tokens_module = lora_module + continue + + # Handle lm_head + if "lm_head" in module_name and "lm_head" in self.target_modules: + if isinstance(module, ParallelLMHead) and not isinstance(module, BaseLayerWithLoRA): + lora_module = self.set_lora_module(module_name, module) + self.lm_head_module = lora_module + continue + ############################## + ############################## + ############################## + # The module should be converted if it is included in target_names if module_name.split(".")[-1] in self.target_modules: layer_id = get_layer_id(module_name) diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index f6375361700e..cf6d406263e5 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -77,6 +77,18 @@ def __init__( self.A_buffer: Dict[str, List[torch.Tensor]] = {} self.B_buffer: Dict[str, List[torch.Tensor]] = {} + ############################## + ##########emb lora############ + ############################## + # NEW: Buffers for embedding and lm_head (not per-layer) + self.embedding_A_buffer: Optional[torch.Tensor] = None + self.embedding_B_buffer: Optional[torch.Tensor] = None + self.lm_head_A_buffer: Optional[torch.Tensor] = None + self.lm_head_B_buffer: Optional[torch.Tensor] = None + ############################## + ############################## + ############################## + # Lora uid -> buffer idx in memory pool self.uid_to_buffer_id: Dict[Optional[str], int] = {} @@ -186,6 +198,50 @@ def init_buffer( self.get_lora_B_shape, ) + ############################## + ##########emb lora############ + ############################## + # Initialize embedding buffers if embed_tokens is in target_modules + if "embed_tokens" in self.target_modules: + vocab_size = self.base_hf_config.vocab_size + hidden_size = self.base_hf_config.hidden_size + + # embedding_A: (max_loras_per_batch, max_rank, vocab_size) + self.embedding_A_buffer = torch.empty( + (self.max_loras_per_batch, self.max_lora_rank, vocab_size), + dtype=self.dtype, + device=device, + ) + + # embedding_B: (max_loras_per_batch, hidden_size, max_rank) + self.embedding_B_buffer = torch.empty( + (self.max_loras_per_batch, hidden_size, self.max_lora_rank), + dtype=self.dtype, + device=device, + ) + + # Initialize lm_head buffers if lm_head is in target_modules + if "lm_head" in self.target_modules: + vocab_size = self.base_hf_config.vocab_size + hidden_size = self.base_hf_config.hidden_size + + # lm_head_A: (max_loras_per_batch, max_rank, hidden_size) + self.lm_head_A_buffer = torch.empty( + (self.max_loras_per_batch, self.max_lora_rank, hidden_size), + dtype=self.dtype, + device=device, + ) + + # lm_head_B: (max_loras_per_batch, vocab_size, max_rank) + self.lm_head_B_buffer = torch.empty( + (self.max_loras_per_batch, vocab_size, self.max_lora_rank), + dtype=self.dtype, + device=device, + ) + ############################## + ############################## + ############################## + def prepare_lora_batch( self, cur_uids: Set[Optional[str]], @@ -277,6 +333,70 @@ def load_lora_weight_tensor( assert lora_adapter is not None lora_rank = lora_adapter.config.r + + ############################## + ##########emb lora############ + ############################## + + # Handle embedding weights (not per-layer) + if "embed_tokens" in self.target_modules: + embedding_A = None + embedding_B = None + + # Look for embedding weights in layer 0 (embeddings are usually stored there) + if lora_adapter.layers: + layer_weights = lora_adapter.layers[0].weights + for name, weights in layer_weights.items(): + if "embed_tokens" in name or "model.embed_tokens" in name: + if "lora_A" in name: + embedding_A = weights + elif "lora_B" in name: + embedding_B = weights + + # Load into buffers + if embedding_A is not None: + buffer_view = self.embedding_A_buffer[buffer_id, :lora_rank, :] + buffer_view.copy_(embedding_A) + else: + self.embedding_A_buffer[buffer_id].zero_() + + if embedding_B is not None: + buffer_view = self.embedding_B_buffer[buffer_id, :, :lora_rank] + buffer_view.copy_(embedding_B) + else: + self.embedding_B_buffer[buffer_id].zero_() + + # Handle lm_head weights (not per-layer) + if "lm_head" in self.target_modules: + lm_head_A = None + lm_head_B = None + + # Look for lm_head weights + if lora_adapter.layers: + layer_weights = lora_adapter.layers[0].weights + for name, weights in layer_weights.items(): + if "lm_head" in name: + if "lora_A" in name: + lm_head_A = weights + elif "lora_B" in name: + lm_head_B = weights + + # Load into buffers + if lm_head_A is not None: + buffer_view = self.lm_head_A_buffer[buffer_id, :lora_rank, :] + buffer_view.copy_(lm_head_A) + else: + self.lm_head_A_buffer[buffer_id].zero_() + + if lm_head_B is not None: + buffer_view = self.lm_head_B_buffer[buffer_id, :, :lora_rank] + buffer_view.copy_(lm_head_B) + else: + self.lm_head_B_buffer[buffer_id].zero_() + ############################## + ############################## + ############################## + for layer_id in range(self.num_layer): layer_weights = lora_adapter.layers[layer_id].weights temp_A_buffer: Dict[str, Optional[torch.Tensor]] = { diff --git a/python/sglang/srt/lora/utils.py b/python/sglang/srt/lora/utils.py index 48a450d9b468..7432bfa1eafc 100644 --- a/python/sglang/srt/lora/utils.py +++ b/python/sglang/srt/lora/utils.py @@ -51,6 +51,7 @@ def get_hidden_dim( """ Given a module_name (might be a stacked name), return the hidden dims of modules' input and output. """ + if hasattr(base_model, "get_hidden_dim"): return base_model.get_hidden_dim(module_name, layer_idx) @@ -78,6 +79,22 @@ def get_hidden_dim( return config.hidden_size, config.intermediate_size * 2 elif module_name == "down_proj": return config.intermediate_size, config.hidden_size + + ############################## + ##########emb lora############ + ############################## + #Handle embed_tokens + elif "embed_tokens" in module_name: + # For embedding: input is vocab_size (as embedding lookup), output is hidden_size + return config.vocab_size, config.hidden_size + + #Handle lm_head + elif "lm_head" in module_name: + # For lm_head: input is hidden_size, output is vocab_size + return config.hidden_size, config.vocab_size + ############################## + ############################## + ############################## else: raise NotImplementedError() @@ -95,6 +112,18 @@ def get_normalized_target_modules( "v_proj": "qkv_proj", "gate_proj": "gate_up_proj", "up_proj": "gate_up_proj", + ############################## + ##########emb lora############ + ############################## + "embed_tokens": "embed_tokens", + "vocab_emb": "embed_tokens", + "embeddings": "embed_tokens", + "word_embeddings": "embed_tokens", + "lm_head": "lm_head", + "output": "lm_head", + ############################## + ############################## + ############################## } result = set() diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index dae29f89c8d1..38b9c2b463a5 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -3283,6 +3283,8 @@ def is_gfx95_supported(): "down_proj", "qkv_proj", "gate_up_proj", + "embed_tokens", + "lm_head", ] LORA_TARGET_ALL_MODULES = "all" From 04292bd21ae6e7510f711acb93283fa7baac5388 Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Wed, 26 Nov 2025 00:18:20 +0000 Subject: [PATCH 02/19] update --- python/sglang/srt/lora/layers.py | 92 +++++++++++++++++++++++++++----- 1 file changed, 78 insertions(+), 14 deletions(-) diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index f43ec5a461b9..b606d58b9dc9 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -44,9 +44,9 @@ def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): pass -############################## -##########emb lora############ -############################## +############################# +#########emb lora############ +############################# # class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): # """ # Vocab parallel embedding layer with support for LoRA (Low-Rank Adaptation). @@ -65,9 +65,9 @@ def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): # self.weight = base_layer.weight -# ##### ----- -# ##### ----- -# ##### ----- +##### ----- +##### ----- +##### ----- # class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): # """ # Vocab parallel embedding layer with support for LoRA (Low-Rank Adaptation). @@ -109,13 +109,25 @@ def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): # input_.shape[0], dtype=torch.int32, device=input_.device # ) -# current_pos = 0 +# # current_pos = 0 +# # for i in range(batch_info.bs): +# # seg_len = int(batch_info.seg_lens[i]) +# # weight_idx = int(batch_info.weight_indices[i]) +# # token_weight_indices[current_pos : current_pos+seg_len] = weight_idx +# # current_pos += seg_len + +# # Use cumsum for positions - avoid Python loops +# seg_lens = batch_info.seg_lens[:batch_info.bs] # (bs,) +# cum_lens = torch.cumsum(seg_lens, dim=0) # cumulative positions +# start_positions = torch.cat([torch.zeros(1, dtype=cum_lens.dtype, device=cum_lens.device), cum_lens[:-1]]) + +# # Vectorized assignment using tensor operations - allow enable cuda-graph # for i in range(batch_info.bs): -# seg_len = int(batch_info.seg_lens[i]) -# weight_idx = int(batch_info.weight_indices[i]) -# token_weight_indices[current_pos : current_pos+seg_len] = weight_idx -# current_pos += seg_len - +# start = start_positions[i] +# end = cum_lens[i] +# weight_idx = batch_info.weight_indices[i] +# token_weight_indices[start:end] = weight_idx + # return token_weight_indices # def _run_lora_a_embedding( @@ -381,18 +393,35 @@ def _get_token_weight_indices( input_.shape[0], dtype=torch.int32, device=input_.device ) + ####################### + ####################### current_pos = 0 for i in range(batch_info.bs): seg_len = int(batch_info.seg_lens[i]) # Convert tensor to int weight_idx = int(batch_info.weight_indices[i]) # Convert tensor to int token_weight_indices[current_pos : current_pos+seg_len] = weight_idx current_pos += seg_len + + # -------- # + + # # Use repeat_interleave to map segment-level indices to token-level indices + # # This is CUDA graph compatible + # num_segments = batch_info.num_segments + # seg_lens = batch_info.seg_lens[:num_segments] + # weight_indices = batch_info.weight_indices[:num_segments] + + # # Vectorized assignment using tensor operations - allow enable cuda-graph + # token_weight_indices = weight_indices.repeat_interleave(seg_lens) + ####################### + ####################### return token_weight_indices def _run_lora_a_embedding( self, input_: torch.Tensor, token_weight_indices: torch.Tensor ) -> torch.Tensor: + ##################### + ##################### """ Apply LoRA A weights using efficient embedding lookup. This avoids creating one-hot vectors. @@ -405,15 +434,50 @@ def _run_lora_a_embedding( unique_weight_indices = torch.unique(token_weight_indices) + # to enable cuda-graph - prevent from using int for idx in unique_weight_indices: - idx_val = idx.item() # Convert tensor to int token_mask = token_weight_indices == idx - lora_a_weights = self.embedding_A_buffer[idx_val] # (rank, vocab_size) + lora_a_weights = self.embedding_A_buffer[idx] # (rank, vocab_size) # Use F.embedding for efficient lookup # lora_a_weights.t() gives us (vocab_size, rank) lora_a_output[token_mask] = F.embedding( input_[token_mask], lora_a_weights.t() ) + + # -------- # + + # num_tokens = input_.shape[0] + # rank = self.embedding_A_buffer.shape[1] + + # # embedding_A_buffer shape: (num_loras, rank, vocab_size) + # # token_weight_indices shape: (num_tokens,) + # # input_ shape: (num_tokens,) + + # # Gather LoRA A weights for each token's assigned LoRA adapter + # # lora_a_weights shape: (num_tokens, rank, vocab_size) + # lora_a_weights = self.embedding_A_buffer[token_weight_indices] + + # # Now we need to apply embedding lookup for each token + # # lora_a_weights[i] is (rank, vocab_size) for token i + # # We want to lookup input_[i] in lora_a_weights[i].t() which is (vocab_size, rank) + + # # Transpose to (num_tokens, vocab_size, rank) for embedding lookup + # lora_a_weights_t = lora_a_weights.transpose(1, 2) + + # # Use batched embedding lookup + # # For each token i, lookup input_[i] in lora_a_weights_t[i] + # input_expanded = input_.unsqueeze(1) # (num_tokens, 1) + + # # Use gather to simulate embedding lookup + # # lora_a_weights_t[i, input_[i], :] gives us the embedding for token i + # token_indices = input_.unsqueeze(-1).unsqueeze(-1).expand(-1, 1, rank) # (num_tokens, 1, rank) + # lora_a_output = torch.gather( + # lora_a_weights_t, + # 1, + # token_indices + # ).squeeze(1) # (num_tokens, rank) + ##################### + ##################### return lora_a_output From 377fd2d5e0238ac5ebb9046f9a733486c047d060 Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Wed, 26 Nov 2025 00:46:00 +0000 Subject: [PATCH 03/19] refactor layers.py --> VocabParallelEmbeddingWithLoRA --> _run_lora_a_embedding --- python/sglang/srt/lora/layers.py | 91 +++++++------------------------- 1 file changed, 18 insertions(+), 73 deletions(-) diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index b606d58b9dc9..16dbe79f717d 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -371,11 +371,9 @@ def apply_lora( Apply LoRA to base embedding output. Formula: output = base_output + lora_B @ lora_A_embedding(input_) """ - # Get token-to-lora mapping - token_weight_indices = self._get_token_weight_indices(input_, batch_info) - - # Efficient embedding lookup for LoRA A - lora_a_output = self._run_lora_a_embedding(input_, token_weight_indices) + + # Efficient embedding lookup for LoRA A (cannot call run_lora_a_sgemm since needing index lookup) + lora_a_output = self._run_lora_a_embedding(input_, batch_info) # Apply LoRA B weights using backend lora_output = self.lora_backend.run_lora_b_sgemm( @@ -385,47 +383,29 @@ def apply_lora( ) return lora_output - def _get_token_weight_indices( + ############################## + ##########emb lora############ + ############################## + def _run_lora_a_embedding( self, input_: torch.Tensor, batch_info ) -> torch.Tensor: - """Map each token position to its corresponding LoRA adapter index.""" + """ + Apply LoRA A weights using efficient embedding lookup. + Maps tokens to their corresponding LoRA adapters internally. + """ + # Get token-to-lora mapping token_weight_indices = torch.zeros( input_.shape[0], dtype=torch.int32, device=input_.device ) - ####################### - ####################### current_pos = 0 for i in range(batch_info.bs): - seg_len = int(batch_info.seg_lens[i]) # Convert tensor to int - weight_idx = int(batch_info.weight_indices[i]) # Convert tensor to int + seg_len = int(batch_info.seg_lens[i]) + weight_idx = int(batch_info.weight_indices[i]) token_weight_indices[current_pos : current_pos+seg_len] = weight_idx current_pos += seg_len - - # -------- # - - # # Use repeat_interleave to map segment-level indices to token-level indices - # # This is CUDA graph compatible - # num_segments = batch_info.num_segments - # seg_lens = batch_info.seg_lens[:num_segments] - # weight_indices = batch_info.weight_indices[:num_segments] - # # Vectorized assignment using tensor operations - allow enable cuda-graph - # token_weight_indices = weight_indices.repeat_interleave(seg_lens) - ####################### - ####################### - - return token_weight_indices - - def _run_lora_a_embedding( - self, input_: torch.Tensor, token_weight_indices: torch.Tensor - ) -> torch.Tensor: - ##################### - ##################### - """ - Apply LoRA A weights using efficient embedding lookup. - This avoids creating one-hot vectors. - """ + # Apply embedding lookup for each LoRA adapter lora_a_output = torch.zeros( (input_.shape[0], self.embedding_A_buffer.shape[1]), dtype=self.embedding_A_buffer.dtype, @@ -434,52 +414,17 @@ def _run_lora_a_embedding( unique_weight_indices = torch.unique(token_weight_indices) - # to enable cuda-graph - prevent from using int for idx in unique_weight_indices: token_mask = token_weight_indices == idx lora_a_weights = self.embedding_A_buffer[idx] # (rank, vocab_size) - # Use F.embedding for efficient lookup - # lora_a_weights.t() gives us (vocab_size, rank) lora_a_output[token_mask] = F.embedding( input_[token_mask], lora_a_weights.t() ) - - # -------- # - - # num_tokens = input_.shape[0] - # rank = self.embedding_A_buffer.shape[1] - - # # embedding_A_buffer shape: (num_loras, rank, vocab_size) - # # token_weight_indices shape: (num_tokens,) - # # input_ shape: (num_tokens,) - - # # Gather LoRA A weights for each token's assigned LoRA adapter - # # lora_a_weights shape: (num_tokens, rank, vocab_size) - # lora_a_weights = self.embedding_A_buffer[token_weight_indices] - - # # Now we need to apply embedding lookup for each token - # # lora_a_weights[i] is (rank, vocab_size) for token i - # # We want to lookup input_[i] in lora_a_weights[i].t() which is (vocab_size, rank) - - # # Transpose to (num_tokens, vocab_size, rank) for embedding lookup - # lora_a_weights_t = lora_a_weights.transpose(1, 2) - - # # Use batched embedding lookup - # # For each token i, lookup input_[i] in lora_a_weights_t[i] - # input_expanded = input_.unsqueeze(1) # (num_tokens, 1) - - # # Use gather to simulate embedding lookup - # # lora_a_weights_t[i, input_[i], :] gives us the embedding for token i - # token_indices = input_.unsqueeze(-1).unsqueeze(-1).expand(-1, 1, rank) # (num_tokens, 1, rank) - # lora_a_output = torch.gather( - # lora_a_weights_t, - # 1, - # token_indices - # ).squeeze(1) # (num_tokens, rank) - ##################### - ##################### return lora_a_output + ############################## + ############################## + ############################## def forward(self, input_: torch.Tensor): # Get base embedding output From 67452afcf3c17a43ff89345732359c9d198ba1db Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Wed, 26 Nov 2025 22:47:52 +0000 Subject: [PATCH 04/19] refactor --- python/sglang/srt/lora/layers.py | 11 +--- python/sglang/srt/lora/lora_manager.py | 8 +-- python/sglang/srt/lora/mem_pool.py | 74 +++++++++++++++++--------- 3 files changed, 54 insertions(+), 39 deletions(-) diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index 16dbe79f717d..efa870ea930a 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -383,9 +383,6 @@ def apply_lora( ) return lora_output - ############################## - ##########emb lora############ - ############################## def _run_lora_a_embedding( self, input_: torch.Tensor, batch_info ) -> torch.Tensor: @@ -393,7 +390,7 @@ def _run_lora_a_embedding( Apply LoRA A weights using efficient embedding lookup. Maps tokens to their corresponding LoRA adapters internally. """ - # Get token-to-lora mapping + # (Step1) Get token-to-lora mapping token_weight_indices = torch.zeros( input_.shape[0], dtype=torch.int32, device=input_.device ) @@ -405,7 +402,7 @@ def _run_lora_a_embedding( token_weight_indices[current_pos : current_pos+seg_len] = weight_idx current_pos += seg_len - # Apply embedding lookup for each LoRA adapter + # (Step2) Apply embedding lookup for each LoRA adapter lora_a_output = torch.zeros( (input_.shape[0], self.embedding_A_buffer.shape[1]), dtype=self.embedding_A_buffer.dtype, @@ -420,11 +417,7 @@ def _run_lora_a_embedding( lora_a_output[token_mask] = F.embedding( input_[token_mask], lora_a_weights.t() ) - return lora_a_output - ############################## - ############################## - ############################## def forward(self, input_: torch.Tensor): # Get base embedding output diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 38612b195df6..3469b36536df 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -307,15 +307,15 @@ def update_lora_info(self): # Update embedding layer if present if self.embed_tokens_module is not None and hasattr(self.memory_pool, 'embedding_A_buffer') and self.memory_pool.embedding_A_buffer is not None: self.embed_tokens_module.set_lora_info( - self.memory_pool.embedding_A_buffer, - self.memory_pool.embedding_B_buffer, + self.memory_pool.get_embedding_tensor("embed_tokens", LoRAType.LORA_A), + self.memory_pool.get_embedding_tensor("embed_tokens", LoRAType.LORA_B), ) # Update lm_head layer if present if self.lm_head_module is not None and hasattr(self.memory_pool, 'lm_head_A_buffer') and self.memory_pool.lm_head_A_buffer is not None: self.lm_head_module.set_lora_info( - self.memory_pool.lm_head_A_buffer, - self.memory_pool.lm_head_B_buffer, + self.memory_pool.get_embedding_tensor("lm_head", LoRAType.LORA_A), + self.memory_pool.get_embedding_tensor("lm_head", LoRAType.LORA_B), ) ############################## ############################## diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index cf6d406263e5..7b0f98093e6d 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -344,8 +344,10 @@ def load_lora_weight_tensor( embedding_B = None # Look for embedding weights in layer 0 (embeddings are usually stored there) - if lora_adapter.layers: - layer_weights = lora_adapter.layers[0].weights + # if lora_adapter.layers: + if hasattr(lora_adapter, 'embedding_layer'): + # layer_weights = lora_adapter.layers[0].weights + layer_weights = lora_adapter.embedding_layer.weights for name, weights in layer_weights.items(): if "embed_tokens" in name or "model.embed_tokens" in name: if "lora_A" in name: @@ -354,17 +356,11 @@ def load_lora_weight_tensor( embedding_B = weights # Load into buffers - if embedding_A is not None: - buffer_view = self.embedding_A_buffer[buffer_id, :lora_rank, :] - buffer_view.copy_(embedding_A) - else: - self.embedding_A_buffer[buffer_id].zero_() + buffer_view = self.embedding_A_buffer[buffer_id, :lora_rank, :] + load_lora_weight_tensor(buffer_view, embedding_A) - if embedding_B is not None: - buffer_view = self.embedding_B_buffer[buffer_id, :, :lora_rank] - buffer_view.copy_(embedding_B) - else: - self.embedding_B_buffer[buffer_id].zero_() + buffer_view = self.embedding_B_buffer[buffer_id, :, :lora_rank] + load_lora_weight_tensor(buffer_view, embedding_B) # Handle lm_head weights (not per-layer) if "lm_head" in self.target_modules: @@ -372,8 +368,10 @@ def load_lora_weight_tensor( lm_head_B = None # Look for lm_head weights - if lora_adapter.layers: - layer_weights = lora_adapter.layers[0].weights + # if lora_adapter.layers: + if hasattr(lora_adapter, 'lm_head_layer'): + # layer_weights = lora_adapter.layers[0].weights + layer_weights = lora_adapter.lm_head_layer.weights for name, weights in layer_weights.items(): if "lm_head" in name: if "lora_A" in name: @@ -382,17 +380,11 @@ def load_lora_weight_tensor( lm_head_B = weights # Load into buffers - if lm_head_A is not None: - buffer_view = self.lm_head_A_buffer[buffer_id, :lora_rank, :] - buffer_view.copy_(lm_head_A) - else: - self.lm_head_A_buffer[buffer_id].zero_() - - if lm_head_B is not None: - buffer_view = self.lm_head_B_buffer[buffer_id, :, :lora_rank] - buffer_view.copy_(lm_head_B) - else: - self.lm_head_B_buffer[buffer_id].zero_() + buffer_view = self.lm_head_A_buffer[buffer_id, :lora_rank, :] + load_lora_weight_tensor(buffer_view, lm_head_A) + + buffer_view = self.lm_head_B_buffer[buffer_id, :, :lora_rank] + load_lora_weight_tensor(buffer_view, lm_head_B) ############################## ############################## ############################## @@ -440,10 +432,40 @@ def load_lora_weight_tensor( target_buffer = self.B_buffer[name][layer_id] buffer_view = target_buffer[buffer_id, :, :lora_rank] load_lora_weight_tensor(buffer_view, weights) - + + def get_embedding_tensor( + self, target_module: str, lora_type: LoRAType + ) -> Optional[torch.Tensor]: + """ + Get LoRA tensor for non-layer modules (embed_tokens, lm_head). + + Args: + target_module: Module name, either "embed_tokens" or "lm_head" + lora_type: Either LoRAType.LORA_A or LoRAType.LORA_B + + Returns: + The corresponding buffer tensor, or None if not available + """ + if target_module == "embed_tokens": + if lora_type == LoRAType.LORA_A: + return self.embedding_A_buffer + return self.embedding_B_buffer + + if target_module == "lm_head": + if lora_type == LoRAType.LORA_A: + return self.lm_head_A_buffer + return self.lm_head_B_buffer + + raise ValueError( + f"Invalid target_module '{target_module}'. " + f"Expected 'embed_tokens' or 'lm_head'." + ) + + def get_tensor( self, target_module: str, layer_id: int, lora_type: LoRAType ) -> torch.Tensor: + if lora_type == LoRAType.LORA_A: return self.A_buffer[target_module][layer_id] From c90f75f494eb3e0cef46b07599bdb149bb3254d5 Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Fri, 28 Nov 2025 21:15:28 +0000 Subject: [PATCH 05/19] finish vocab_emb support and stii need to fix lm_head --- python/sglang/srt/lora/layers.py | 87 ++++- python/sglang/srt/lora/lora.py | 22 +- python/sglang/srt/lora/lora_config.py | 43 +++ python/sglang/srt/lora/lora_manager.py | 36 +- python/sglang/srt/lora/mem_pool.py | 460 +++++++++++++++++++------ python/sglang/srt/lora/utils.py | 45 ++- 6 files changed, 570 insertions(+), 123 deletions(-) diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index efa870ea930a..3b1e08ce404c 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -17,6 +17,7 @@ ) from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead from sglang.srt.lora.backend.base_backend import BaseLoRABackend +from sglang.srt.lora.utils import LoRABatchInfo class BaseLayerWithLoRA(nn.Module): @@ -356,11 +357,25 @@ def __init__( def set_lora_info( self, + ############################## + ##########emb lora############ + ############################## + new_embeddings_buffer: Optional[torch.Tensor], # For extra tokens + ############################## + ############################## + ############################## embedding_A_buffer: torch.Tensor, embedding_B_buffer: torch.Tensor, ): """Set LoRA buffers for embedding layer.""" self.set_lora = True + ############################## + ##########emb lora############ + ############################## + self.new_embeddings_buffer = new_embeddings_buffer + ############################## + ############################## + ############################## self.embedding_A_buffer = embedding_A_buffer # (num_loras, rank, vocab_size) self.embedding_B_buffer = embedding_B_buffer # (num_loras, embed_dim, rank) @@ -373,7 +388,7 @@ def apply_lora( """ # Efficient embedding lookup for LoRA A (cannot call run_lora_a_sgemm since needing index lookup) - lora_a_output = self._run_lora_a_embedding(input_, batch_info) + lora_a_output = self.run_lora_a_embedding(input_, batch_info) # Apply LoRA B weights using backend lora_output = self.lora_backend.run_lora_b_sgemm( @@ -383,13 +398,21 @@ def apply_lora( ) return lora_output - def _run_lora_a_embedding( - self, input_: torch.Tensor, batch_info + def run_lora_a_embedding( + self, input_: torch.Tensor, batch_info: LoRABatchInfo ) -> torch.Tensor: """ Apply LoRA A weights using efficient embedding lookup. Maps tokens to their corresponding LoRA adapters internally. """ + token_weight_indices = self._get_token_weight_indices(input_, batch_info) + lora_a_output = self._run_lora_a_embedding(input_, token_weight_indices) + + return lora_a_output + + def _get_token_weight_indices( + self, input_: torch.Tensor, batch_info: LoRABatchInfo + ) -> torch.Tensor: # (Step1) Get token-to-lora mapping token_weight_indices = torch.zeros( input_.shape[0], dtype=torch.int32, device=input_.device @@ -402,6 +425,11 @@ def _run_lora_a_embedding( token_weight_indices[current_pos : current_pos+seg_len] = weight_idx current_pos += seg_len + return token_weight_indices + + def _run_lora_a_embedding( + self, input_: torch.Tensor, token_weight_indices: torch.Tensor + ) -> torch.Tensor: # (Step2) Apply embedding lookup for each LoRA adapter lora_a_output = torch.zeros( (input_.shape[0], self.embedding_A_buffer.shape[1]), @@ -417,29 +445,69 @@ def _run_lora_a_embedding( lora_a_output[token_mask] = F.embedding( input_[token_mask], lora_a_weights.t() ) + return lora_a_output def forward(self, input_: torch.Tensor): - # Get base embedding output - base_output = self.base_layer.forward(input_) + ############################## + ##########emb lora############ + ############################## + # # Get base embedding output ( do not consider extra tokens) + # base_output = self.base_layer.forward(input_) - # Apply LoRA if configured + # # Apply LoRA if configured + # if self.set_lora: + # batch_info = self.lora_backend.batch_info + # output = self.apply_lora(base_output, input_, batch_info) + # else: + # output = base_output + + # return output + + ############### + ############### consider both non-extra and extra tokens + ############### + + batch_info = self.lora_backend.batch_info + + # Handle added tokens (tokens beyond base vocabulary) + added_tokens_mask = input_ > self.vocab_size - 1 + base_output = self.base_layer.forward(input_.masked_fill(added_tokens_mask, 0)) + + # Process extra tokens if they exist + if added_tokens_mask.any(): + token_weight_indices = self._get_token_weight_indices(input_, batch_info) + added_weight_indices = token_weight_indices[added_tokens_mask] + unique_added_weight_indices = torch.unique(added_weight_indices) + + for idx in unique_added_weight_indices: + lora_mask = added_weight_indices == idx + added_token_positions = torch.where(added_tokens_mask)[0][lora_mask] + x = input_[added_token_positions] - self.vocab_size + new_embeddings = F.embedding(x, self.new_embeddings_buffer[idx]) + base_output[added_token_positions] = new_embeddings + + # Apply LoRA if set if self.set_lora: - batch_info = self.lora_backend.batch_info output = self.apply_lora(base_output, input_, batch_info) else: output = base_output - + return output + ############################## + ############################## + ############################## def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): # For TP=1, no slicing needed # LoRA A weights (rank, vocab_size) are not sliced for embedding + # For TP>1, Need to modify code in: sglang/python/sglang/srt/lora/mem_pool.py return A def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): # For TP=1, no slicing needed # LoRA B weights (embedding_dim, rank) would be sliced along embedding dimension for TP>1 + # For TP>1, Need to modify code in: sglang/python/sglang/srt/lora/mem_pool.py return B @@ -510,11 +578,12 @@ def forward(self, hidden_states: torch.Tensor): def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): # For TP=1, no slicing needed + # For TP>1, need to modify code in: sglang/python/sglang/srt/lora/mem_pool.py return A def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): # For TP=1, no slicing needed - # For TP>1, would slice along vocab dimension + # For TP>1, would slice along vocab dimension, eed to modify code in: sglang/python/sglang/srt/lora/mem_pool.py return B diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py index 9f7b4d85bc25..cc087d229d27 100644 --- a/python/sglang/srt/lora/lora.py +++ b/python/sglang/srt/lora/lora.py @@ -74,8 +74,12 @@ def __init__( ############################## ##########emb lora############ ############################## - self.embedding_layer = LoRALayer(config, base_hf_config) - self.lm_head_layer = LoRALayer(config, base_hf_config) + # self.embedding_layer = LoRALayer(config, base_hf_config) + # self.lm_head_layer = LoRALayer(config, base_hf_config) + # self.weights: Dict[str, torch.Tensor] = {} + # self.extra_embeddings: Dict[str, torch.Tensor] = {} + self.embedding_layers: Dict[str, torch.Tensor] = {} + self.added_tokens_embeddings: Dict[str, torch.Tensor] = {} ############################## ############################## ############################## @@ -96,10 +100,16 @@ def initialize_weights(self): ############################## ##########emb lora############ ############################## - elif "embed_tokens" in name: - self.embedding_layer.weights[name] = loaded_weight.cpu() - elif "lm_head" in name: - self.lm_head_layer.weights[name] = loaded_weight.cpu() + elif "embed_tokens" in name or "lm_head" in name: + # self.embedding_layers.weights[name] = loaded_weight.cpu() + self.embedding_layers[name] = loaded_weight.cpu() + elif "input_embeddings" in name or "output_embeddings" in name: + #added token emb + self.added_tokens_embeddings[name] = loaded_weight.cpu() + assert loaded_weight.shape[0] == self.config.extra_vocab_size, ( + f"LoRA adapter {self.uid} has extra_vocab_size {self.config.extra_vocab_size} specified in the config, " + f"but the loaded weight has {loaded_weight.shape[0]} extra vocab size" + ) ############################## ############################## ############################## diff --git a/python/sglang/srt/lora/lora_config.py b/python/sglang/srt/lora/lora_config.py index 98a28b01b089..a640b63e535b 100644 --- a/python/sglang/srt/lora/lora_config.py +++ b/python/sglang/srt/lora/lora_config.py @@ -30,6 +30,17 @@ def __init__( self.r = self.hf_config["r"] self.lora_alpha = self.hf_config["lora_alpha"] + ############################## + ##########emb lora############ + ############################## + self.added_tokens = self.get_added_tokens() + self.extra_vocab_size = ( + len(self.added_tokens) if self.added_tokens is not None else 0 + ) + ############################## + ############################## + ############################## + def get_lora_config(self, dummy=False): if dummy: raise NotImplementedError() @@ -41,3 +52,35 @@ def get_lora_config(self, dummy=False): config_name = "adapter_config.json" with open(os.path.join(weights_dir, config_name), "r") as f: return json.load(f) + + ############################## + ##########emb lora############ + ############################## + def get_added_tokens(self): + """Load added tokens from the LoRA adapter if the file exists.""" + # Determine the weights directory + if not os.path.isdir(self.path): + weights_dir = snapshot_download(self.path, allow_patterns=["*.json"]) + else: + weights_dir = self.path + + # Construct the path to added_tokens.json + added_tokens_path = os.path.join(weights_dir, "added_tokens.json") + + # Return None if the file doesn't exist (optional for standard LoRA adapters) + if not os.path.exists(added_tokens_path): + return None + + # Load and return the added tokens + try: + with open(added_tokens_path, "r") as f: + return json.load(f) + except json.JSONDecodeError as e: + # Log warning but don't crash if JSON is malformed + import logging + logger = logging.getLogger(__name__) + logger.warning(f"Failed to parse added_tokens.json: {e}") + return None + ############################## + ############################## + ############################## \ No newline at end of file diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 3469b36536df..40345d80852a 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -67,6 +67,13 @@ def __init__( target_modules: Optional[Iterable[str]] = None, lora_paths: Optional[List[LoRARef]] = None, server_args: Optional[ServerArgs] = None, + ############################## + ##########emb lora############ + ############################## + lora_extra_vocab_size: int = 0, + ############################## + ############################## + ############################## ): self.base_model: torch.nn.Module = base_model self.base_hf_config: AutoConfig = base_hf_config @@ -76,6 +83,13 @@ def __init__( self.device: torch.device = next(self.base_model.parameters()).device self.tp_size: int = tp_size self.tp_rank: int = tp_rank + ############################## + ##########emb lora############ + ############################## + self.lora_extra_vocab_size: int = lora_extra_vocab_size + ############################## + ############################## + ############################## # Store eviction policy from server args self.eviction_policy = server_args.lora_eviction_policy @@ -249,6 +263,14 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch): lora_adapters=self.loras, lora_modules=self.lora_modules, lora_refs=self.lora_refs.copy(), # copy snapshot of current lora_refs to avoid mutation during the batch preparation. + ############################## + ##########emb lora############ + ############################## + lora_embed_tokens_module=self.embed_tokens_module, #merge into embedding or lora module + lora_lm_head_module=self.lm_head_module, #merge into embedding or lora module + ############################## + ############################## + ############################## ) # set up batch info shared by all lora modules @@ -304,15 +326,16 @@ def update_lora_info(self): ############################## ##########emb lora############ ############################## - # Update embedding layer if present - if self.embed_tokens_module is not None and hasattr(self.memory_pool, 'embedding_A_buffer') and self.memory_pool.embedding_A_buffer is not None: + # Update embedding layer if present - gotta merge (refer to PR codebase) + if self.embed_tokens_module is not None: self.embed_tokens_module.set_lora_info( + self.memory_pool.get_embedding_tensor("added_tokens", LoRAType.LORA_A), #choose name: "added_tokens" self.memory_pool.get_embedding_tensor("embed_tokens", LoRAType.LORA_A), self.memory_pool.get_embedding_tensor("embed_tokens", LoRAType.LORA_B), ) # Update lm_head layer if present - if self.lm_head_module is not None and hasattr(self.memory_pool, 'lm_head_A_buffer') and self.memory_pool.lm_head_A_buffer is not None: + if self.lm_head_module is not None: self.lm_head_module.set_lora_info( self.memory_pool.get_embedding_tensor("lm_head", LoRAType.LORA_A), self.memory_pool.get_embedding_tensor("lm_head", LoRAType.LORA_B), @@ -441,6 +464,13 @@ def init_memory_pool(self): target_modules=self.target_modules, base_model=self.base_model, eviction_policy=self.eviction_policy, + ############################## + ##########emb lora############ + ############################## + lora_extra_vocab_size=self.lora_extra_vocab_size, # check whether read from the config + ############################## + ############################## + ############################## ) def set_lora_module(self, module_name, module): diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index 7b0f98093e6d..bd6071186ad9 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -11,6 +11,13 @@ from sglang.srt.lora.lora_registry import LoRARef from sglang.srt.lora.utils import ( ROW_PARALLELISM_LINEAR_LORA_NAMES, + ############################## + ##########emb lora############ + ############################## + EMBEDDING_NAMES, + ############################## + ############################## + ############################## LoRAType, get_hidden_dim, get_normalized_target_modules, @@ -56,6 +63,13 @@ def __init__( target_modules: Set[str], base_model: torch.nn.Module, eviction_policy: str, + ############################## + ##########emb lora############ + ############################## + lora_extra_vocab_size: int, #can be remove? + ############################## + ############################## + ############################## ): self.base_hf_config: AutoConfig = base_hf_config self.num_layer: int = base_hf_config.num_hidden_layers @@ -63,6 +77,13 @@ def __init__( self.dtype: torch.dtype = dtype self.tp_size: int = tp_size self.tp_rank: int = tp_rank + ############################## + ##########emb lora############ + ############################## + self.max_extra_vocab_size: int = lora_extra_vocab_size + ############################## + ############################## + ############################## self.max_lora_rank: int = max_lora_rank self.target_modules: Set[str] = target_modules @@ -81,10 +102,21 @@ def __init__( ##########emb lora############ ############################## # NEW: Buffers for embedding and lm_head (not per-layer) - self.embedding_A_buffer: Optional[torch.Tensor] = None - self.embedding_B_buffer: Optional[torch.Tensor] = None - self.lm_head_A_buffer: Optional[torch.Tensor] = None - self.lm_head_B_buffer: Optional[torch.Tensor] = None + # self.embedding_A_buffer: Optional[torch.Tensor] = None + # self.embedding_B_buffer: Optional[torch.Tensor] = None + # self.lm_head_A_buffer: Optional[torch.Tensor] = None + # self.lm_head_B_buffer: Optional[torch.Tensor] = None + # self.new_embeddings_buffer: Dict[str, torch.Tensor] = {} + self.embedding_A_buffer: Dict[str, torch.Tensor] = {} + self.embedding_B_buffer: Dict[str, torch.Tensor] = {} + + self.lm_head_A_buffer: Dict[str, torch.Tensor] = {} + self.lm_head_B_buffer: Dict[str, torch.Tensor] = {} + # self.embedding_A_buffer: Dict[str, List[torch.Tensor]] = {} + # self.embedding_B_buffer: Dict[str, List[torch.Tensor]] = {} + self.new_embeddings_buffer: Dict[str, torch.Tensor] = {} + + self.embedding_dim: int = self.base_hf_config.hidden_size ############################## ############################## ############################## @@ -112,6 +144,14 @@ def _can_support(config: LoRAConfig) -> bool: """ if config.r > self.max_lora_rank: return False + ############################## + ##########emb lora############ + ############################## + if config.extra_vocab_size > self.max_extra_vocab_size: + return False # can be remove? + ############################## + ############################## + ############################## target_module_names = get_normalized_target_modules(config.target_modules) return target_module_names.issubset(self.target_modules) @@ -142,6 +182,23 @@ def get_lora_A_shape( input_dim, ) + def get_embedding_lora_A_shape( + self, + module_name: str, + base_model: torch.nn.Module, + max_lora_dim: int, + layer_idx: int, + ) -> Tuple[int]: + input_dim, _ = get_hidden_dim( + module_name, self.base_hf_config, base_model, 0 + ) + # Have not imp self.tp_size > 1 yet. + return ( + self.max_loras_per_batch, + max_lora_dim, + input_dim, + ) + def get_lora_B_shape( self, module_name: str, @@ -163,6 +220,23 @@ def get_lora_B_shape( max_lora_dim, ) + def get_embedding_lora_B_shape( + self, + module_name: str, + base_model: torch.nn.Module, + max_lora_dim: int, + layer_idx: int, + ) -> Tuple[int]: + _, output_dim = get_hidden_dim( + module_name, self.base_hf_config, base_model, 0 + ) + # Have not imp self.tp_size > 1 yet. + return ( + self.max_loras_per_batch, + output_dim, + max_lora_dim, + ) + def init_buffers(self, base_model: torch.nn.Module): device = next(base_model.parameters()).device @@ -171,6 +245,10 @@ def init_buffer( target_modules: Set[str], get_lora_shape_fn: Callable[[str, torch.nn.Module, int, int], Tuple[int]], ): + ############################## + ##########emb lora############ + ############################## + target_modules = target_modules - set(EMBEDDING_NAMES) for module_name in target_modules: buffer[module_name] = [ torch.empty( @@ -185,6 +263,72 @@ def init_buffer( ) for idx in range(self.num_layer) ] + ############################## + ############################## + ############################## + + + ############################## + ##########emb lora############ + ############################## + def init_embedding_buffer( + buffer: Dict[str, torch.Tensor], + target_modules: Set[str], + get_lora_shape_fn: Callable[[int], Tuple[int]], + ): + target_modules = target_modules & set(EMBEDDING_NAMES) + for module_name in target_modules: + buffer[module_name] = torch.empty( + get_lora_shape_fn( + module_name, + base_model, + self.max_lora_rank, + 0, + ), + dtype=self.dtype, + device=device, + ) + + if self.max_extra_vocab_size > 0: + self.new_embeddings_buffer["input_embeddings"] = torch.empty( + ( + self.max_loras_per_batch, + self.max_extra_vocab_size, + self.embedding_dim, + ), + dtype=self.dtype, + device=device, + ) + + if "embed_tokens" in self.target_modules: + init_embedding_buffer( + self.embedding_A_buffer, + self.target_modules, + self.get_embedding_lora_A_shape, + ) + + init_embedding_buffer( + self.embedding_B_buffer, + self.target_modules, + self.get_embedding_lora_B_shape, + ) + + if "lm_head" in self.target_modules: + init_embedding_buffer( + self.lm_head_A_buffer, + self.target_modules, + self.get_embedding_lora_A_shape, + ) + + init_embedding_buffer( + self.lm_head_B_buffer, + self.target_modules, + self.get_embedding_lora_B_shape, + ) + + ############################## + ############################## + ############################## init_buffer( self.A_buffer, @@ -201,43 +345,43 @@ def init_buffer( ############################## ##########emb lora############ ############################## - # Initialize embedding buffers if embed_tokens is in target_modules - if "embed_tokens" in self.target_modules: - vocab_size = self.base_hf_config.vocab_size - hidden_size = self.base_hf_config.hidden_size + # # Initialize embedding buffers if embed_tokens is in target_modules + # if "embed_tokens" in self.target_modules: + # vocab_size = self.base_hf_config.vocab_size + # hidden_size = self.base_hf_config.hidden_size - # embedding_A: (max_loras_per_batch, max_rank, vocab_size) - self.embedding_A_buffer = torch.empty( - (self.max_loras_per_batch, self.max_lora_rank, vocab_size), - dtype=self.dtype, - device=device, - ) + # # embedding_A: (max_loras_per_batch, max_rank, vocab_size) + # self.embedding_A_buffer = torch.empty( + # (self.max_loras_per_batch, self.max_lora_rank, vocab_size), + # dtype=self.dtype, + # device=device, + # ) - # embedding_B: (max_loras_per_batch, hidden_size, max_rank) - self.embedding_B_buffer = torch.empty( - (self.max_loras_per_batch, hidden_size, self.max_lora_rank), - dtype=self.dtype, - device=device, - ) + # # embedding_B: (max_loras_per_batch, hidden_size, max_rank) + # self.embedding_B_buffer = torch.empty( + # (self.max_loras_per_batch, hidden_size, self.max_lora_rank), + # dtype=self.dtype, + # device=device, + # ) - # Initialize lm_head buffers if lm_head is in target_modules - if "lm_head" in self.target_modules: - vocab_size = self.base_hf_config.vocab_size - hidden_size = self.base_hf_config.hidden_size + # # Initialize lm_head buffers if lm_head is in target_modules + # if "lm_head" in self.target_modules: + # vocab_size = self.base_hf_config.vocab_size + # hidden_size = self.base_hf_config.hidden_size - # lm_head_A: (max_loras_per_batch, max_rank, hidden_size) - self.lm_head_A_buffer = torch.empty( - (self.max_loras_per_batch, self.max_lora_rank, hidden_size), - dtype=self.dtype, - device=device, - ) + # # lm_head_A: (max_loras_per_batch, max_rank, hidden_size) + # self.lm_head_A_buffer = torch.empty( + # (self.max_loras_per_batch, self.max_lora_rank, hidden_size), + # dtype=self.dtype, + # device=device, + # ) - # lm_head_B: (max_loras_per_batch, vocab_size, max_rank) - self.lm_head_B_buffer = torch.empty( - (self.max_loras_per_batch, vocab_size, self.max_lora_rank), - dtype=self.dtype, - device=device, - ) + # # lm_head_B: (max_loras_per_batch, vocab_size, max_rank) + # self.lm_head_B_buffer = torch.empty( + # (self.max_loras_per_batch, vocab_size, self.max_lora_rank), + # dtype=self.dtype, + # device=device, + # ) ############################## ############################## ############################## @@ -248,6 +392,15 @@ def prepare_lora_batch( lora_adapters: Dict[str, LoRAAdapter], lora_modules: List[Dict[str, BaseLayerWithLoRA]], lora_refs: Dict[str, LoRARef], + ############################## + ##########emb lora############ + ############################## + # lora_embeddings_modules: Dict[str, BaseLayerWithLoRA], # NEW parameter + lora_embed_tokens_module: Dict[str, BaseLayerWithLoRA], # NEW parameter + lora_lm_head_module: Dict[str, BaseLayerWithLoRA], # NEW parameter + ############################## + ############################## + ############################## ): def get_available_buffer_slot(): # 1. Prioritize empty slots @@ -299,9 +452,18 @@ def get_available_buffer_slot(): if uid not in self.uid_to_buffer_id: buffer_id = get_available_buffer_slot() lora_adapter = lora_adapters.get(uid, None) + ############################## + ##########emb lora############ + ############################## + # self.load_lora_weight_to_buffer( + # uid, buffer_id, lora_adapter, lora_modules + # ) self.load_lora_weight_to_buffer( - uid, buffer_id, lora_adapter, lora_modules + uid, buffer_id, lora_adapter, lora_modules, lora_embed_tokens_module, lora_lm_head_module ) + ############################## + ############################## + ############################## self.uid_to_buffer_id[uid] = buffer_id self.buffer_id_to_uid[buffer_id] = uid @@ -311,6 +473,16 @@ def load_lora_weight_to_buffer( buffer_id: int, lora_adapter: LoRAAdapter, lora_modules: List[Dict[str, BaseLayerWithLoRA]], + ############################## + ##########emb lora############ + ############################## + # lora_embeddings_modules: List[Dict[str, BaseLayerWithLoRA]], + # I can combine the below two + lora_embed_tokens_module: Dict[str, BaseLayerWithLoRA], # NEW parameter + lora_lm_head_module: Dict[str, BaseLayerWithLoRA], # NEW parameter + ############################## + ############################## + ############################## ): def load_lora_weight_tensor( buffer_view: torch.Tensor, weight: Optional[torch.Tensor] @@ -329,66 +501,24 @@ def load_lora_weight_tensor( for i in range(self.num_layer): for k in self.A_buffer.keys(): self.A_buffer[k][i][buffer_id] = 0 + ############################## + ##########emb lora############ + ############################## + # for k in self.embedding_A_buffer.keys(): + # self.embedding_A_buffer[k][buffer_id] = 0 + + for k in self.embedding_A_buffer.keys(): + self.embedding_A_buffer[k][buffer_id] = 0 + + for k in self.lm_head_A_buffer.keys(): + self.lm_head_A_buffer[k][buffer_id] = 0 + ############################## + ############################## + ############################## return assert lora_adapter is not None lora_rank = lora_adapter.config.r - - ############################## - ##########emb lora############ - ############################## - - # Handle embedding weights (not per-layer) - if "embed_tokens" in self.target_modules: - embedding_A = None - embedding_B = None - - # Look for embedding weights in layer 0 (embeddings are usually stored there) - # if lora_adapter.layers: - if hasattr(lora_adapter, 'embedding_layer'): - # layer_weights = lora_adapter.layers[0].weights - layer_weights = lora_adapter.embedding_layer.weights - for name, weights in layer_weights.items(): - if "embed_tokens" in name or "model.embed_tokens" in name: - if "lora_A" in name: - embedding_A = weights - elif "lora_B" in name: - embedding_B = weights - - # Load into buffers - buffer_view = self.embedding_A_buffer[buffer_id, :lora_rank, :] - load_lora_weight_tensor(buffer_view, embedding_A) - - buffer_view = self.embedding_B_buffer[buffer_id, :, :lora_rank] - load_lora_weight_tensor(buffer_view, embedding_B) - - # Handle lm_head weights (not per-layer) - if "lm_head" in self.target_modules: - lm_head_A = None - lm_head_B = None - - # Look for lm_head weights - # if lora_adapter.layers: - if hasattr(lora_adapter, 'lm_head_layer'): - # layer_weights = lora_adapter.layers[0].weights - layer_weights = lora_adapter.lm_head_layer.weights - for name, weights in layer_weights.items(): - if "lm_head" in name: - if "lora_A" in name: - lm_head_A = weights - elif "lora_B" in name: - lm_head_B = weights - - # Load into buffers - buffer_view = self.lm_head_A_buffer[buffer_id, :lora_rank, :] - load_lora_weight_tensor(buffer_view, lm_head_A) - - buffer_view = self.lm_head_B_buffer[buffer_id, :, :lora_rank] - load_lora_weight_tensor(buffer_view, lm_head_B) - ############################## - ############################## - ############################## - for layer_id in range(self.num_layer): layer_weights = lora_adapter.layers[layer_id].weights temp_A_buffer: Dict[str, Optional[torch.Tensor]] = { @@ -432,7 +562,130 @@ def load_lora_weight_tensor( target_buffer = self.B_buffer[name][layer_id] buffer_view = target_buffer[buffer_id, :, :lora_rank] load_lora_weight_tensor(buffer_view, weights) + + ############################## + ##########emb lora############ + ############################## + + # Handle embedding weights (not per-layer) + # if "embed_tokens" in self.target_modules: + # embedding_A = None + # embedding_B = None + + # # Look for embedding weights in layer 0 (embeddings are usually stored there) + # # if lora_adapter.layers: + # if hasattr(lora_adapter, 'embedding_layer'): + # # layer_weights = lora_adapter.layers[0].weights + # layer_weights = lora_adapter.embedding_layer.weights + # for name, weights in layer_weights.items(): + # if "embed_tokens" in name or "model.embed_tokens" in name: + # if "lora_A" in name: + # embedding_A = weights + # elif "lora_B" in name: + # embedding_B = weights + + # # Load into buffers + # buffer_view = self.embedding_A_buffer[buffer_id, :lora_rank, :] + # load_lora_weight_tensor(buffer_view, embedding_A) + + # buffer_view = self.embedding_B_buffer[buffer_id, :, :lora_rank] + # load_lora_weight_tensor(buffer_view, embedding_B) + + # # Handle lm_head weights (not per-layer) + # if "lm_head" in self.target_modules: + # lm_head_A = None + # lm_head_B = None + + # # Look for lm_head weights + # # if lora_adapter.layers: + # if hasattr(lora_adapter, 'lm_head_layer'): + # # layer_weights = lora_adapter.layers[0].weights + # layer_weights = lora_adapter.lm_head_layer.weights + # for name, weights in layer_weights.items(): + # if "lm_head" in name: + # if "lora_A" in name: + # lm_head_A = weights + # elif "lora_B" in name: + # lm_head_B = weights + + # # Load into buffers + # buffer_view = self.lm_head_A_buffer[buffer_id, :lora_rank, :] + # load_lora_weight_tensor(buffer_view, lm_head_A) + + # buffer_view = self.lm_head_B_buffer[buffer_id, :, :lora_rank] + # load_lora_weight_tensor(buffer_view, lm_head_B) + + # embed_token (and extra_token emb) and lm_head layers + if lora_adapter.embedding_layers: + + org_vocab_size = self.base_hf_config.vocab_size + extra_vocab_size = lora_adapter.config.extra_vocab_size + # Only when LoRA is applied to the embedding layer will it have the extra-token issue that needs to be resolved. + # Load embeddings weights for extra tokens to buffer + if lora_adapter.added_tokens_embeddings: + for name, weights in lora_adapter.added_tokens_embeddings.items(): + if "input_embeddings" in name: + buffer_view = self.new_embeddings_buffer["input_embeddings"][ + buffer_id, :extra_vocab_size + ] + load_lora_weight_tensor(buffer_view, weights) + + #load vocab_emb and lm_head + for name, weights in lora_adapter.embedding_layers.items(): + target_module = get_target_module_name(name, self.target_modules) + # if "lora_embedding_A" in name: + if "lora_embedding_A" in name or ("lora_A" in name and target_module == "embed_tokens"): + buffer_view = self.embedding_A_buffer[target_module][ + buffer_id, :lora_rank, : org_vocab_size + extra_vocab_size + ] + load_lora_weight_tensor(buffer_view, weights) + # elif "lora_embedding_B" in name: + elif "lora_embedding_B" in name or ("lora_B" in name and target_module == "embed_tokens"): + lora_b_weights = weights + #[to-do] support TP + # if self.tp_size > 1: + # cur_module = lora_embeddings_modules[target_module] + # for module_name, module in cur_module: + # lora_b_weights = module.slice_lora_b_weights( + # lora_b_weights, self.tp_rank + # ) + + buffer_view = self.embedding_B_buffer[target_module][ + buffer_id, :, :lora_rank + ] + load_lora_weight_tensor(buffer_view, lora_b_weights) + + + if "lora_lm_head_A" in name or ("lora_A" in name and target_module == "lm_head"): + buffer_view = self.embedding_A_buffer[target_module][ + # buffer_id, :, :lora_rank + buffer_id, :lora_rank, : + ] + load_lora_weight_tensor(buffer_view, weights) + # elif "lora_embedding_B" in name: + elif "lora_lm_head_B" in name or ("lora_B" in name and target_module == "lm_head"): + lora_b_weights = weights + #[to-do] support TP + # if self.tp_size > 1: + # cur_module = lora_embeddings_modules[target_module] + # for module_name, module in cur_module: + # lora_b_weights = module.slice_lora_b_weights( + # lora_b_weights, self.tp_rank + # ) + + buffer_view = self.embedding_B_buffer[target_module][ + # buffer_id, :lora_rank, : org_vocab_size + extra_vocab_size + buffer_id, : org_vocab_size + extra_vocab_size, :lora_rank + ] + load_lora_weight_tensor(buffer_view, lora_b_weights) + + ############################## + ############################## + ############################## + ############################## + ##########emb lora############ + ############################## def get_embedding_tensor( self, target_module: str, lora_type: LoRAType ) -> Optional[torch.Tensor]: @@ -446,20 +699,27 @@ def get_embedding_tensor( Returns: The corresponding buffer tensor, or None if not available """ - if target_module == "embed_tokens": + + if target_module == "added_tokens": + if self.max_extra_vocab_size > 0: # change to read from the config + return self.new_embeddings_buffer["input_embeddings"] + return None + elif target_module == "embed_tokens": if lora_type == LoRAType.LORA_A: - return self.embedding_A_buffer - return self.embedding_B_buffer - - if target_module == "lm_head": + return self.embedding_A_buffer[target_module] + return self.embedding_B_buffer[target_module] + elif target_module == "lm_head": if lora_type == LoRAType.LORA_A: - return self.lm_head_A_buffer - return self.lm_head_B_buffer + return self.lm_head_A_buffer[target_module] + return self.lm_head_B_buffer[target_module] raise ValueError( f"Invalid target_module '{target_module}'. " f"Expected 'embed_tokens' or 'lm_head'." ) + ############################## + ############################## + ############################## def get_tensor( diff --git a/python/sglang/srt/lora/utils.py b/python/sglang/srt/lora/utils.py index 7432bfa1eafc..febd8b4472d1 100644 --- a/python/sglang/srt/lora/utils.py +++ b/python/sglang/srt/lora/utils.py @@ -52,10 +52,33 @@ def get_hidden_dim( Given a module_name (might be a stacked name), return the hidden dims of modules' input and output. """ - + ############################## + ##########emb lora############ + ############################## + is_embedding_module = "embed_tokens" in module_name or "lm_head" in module_name + print("=========") + print(is_embedding_module) + print(module_name) + print("----") + ############################## + ############################## + ############################## + + ############################## + ##########emb lora############ + ############################## + # if hasattr(base_model, "get_hidden_dim"): + # if hasattr(base_model, "get_hidden_dim") and not is_embedding_module: if hasattr(base_model, "get_hidden_dim"): + ############################## + ############################## + ############################## + print(1111) + print("=========") return base_model.get_hidden_dim(module_name, layer_idx) else: + print(2222) + print("=========") """ WARNING: get_hidden_dim() is not defined, which is used to get the hidden dim for different lora modules @@ -84,14 +107,20 @@ def get_hidden_dim( ##########emb lora############ ############################## #Handle embed_tokens + # elif "embed_tokens" in module_name: elif "embed_tokens" in module_name: # For embedding: input is vocab_size (as embedding lookup), output is hidden_size - return config.vocab_size, config.hidden_size + # if contain extra tokens will be added; otherwise is 0. + extra_vocab = getattr(config, 'extra_vocab_size', 0) + return config.vocab_size + extra_vocab, config.hidden_size #Handle lm_head + # elif "lm_head" in module_name: elif "lm_head" in module_name: # For lm_head: input is hidden_size, output is vocab_size - return config.hidden_size, config.vocab_size + # if contain extra tokens will be added; otherwise is 0. + extra_vocab = getattr(config, 'extra_vocab_size', 0) + return config.hidden_size, config.vocab_size + extra_vocab ############################## ############################## ############################## @@ -159,5 +188,11 @@ def get_target_module_name(full_module_name: str, target_modules: Set[str]) -> s f"Cannot find target module name for {full_module_name} in {target_modules}" ) - -ROW_PARALLELISM_LINEAR_LORA_NAMES = ["o_proj", "down_proj"] +############################## +##########emb lora############ +############################## +EMBEDDING_NAMES = ["embed_tokens", "lm_head"] +############################## +############################## +############################## +ROW_PARALLELISM_LINEAR_LORA_NAMES = ["o_proj", "down_proj"] \ No newline at end of file From 746b6414bb504b9af8e67da231316b5d88ae999c Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Sat, 29 Nov 2025 23:04:54 +0000 Subject: [PATCH 06/19] need to fix 1. lm_head 2. cuda-graph --- python/sglang/srt/lora/layers.py | 72 ++++++++++++++++++++++++++ python/sglang/srt/lora/lora.py | 2 +- python/sglang/srt/lora/lora_config.py | 8 +-- python/sglang/srt/lora/lora_manager.py | 30 +++++++++-- python/sglang/srt/lora/mem_pool.py | 45 +++++++++------- python/sglang/srt/lora/utils.py | 20 ++----- python/sglang/test/runners.py | 53 ++++++++++++++++++- 7 files changed, 188 insertions(+), 42 deletions(-) diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index 3b1e08ce404c..2d32f15487b3 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -559,9 +559,81 @@ def apply_lora( weights=self.lm_head_B_buffer, base_output=base_output, ) + + # lora_output = self.run_lora_b_scatter( + # lora_a_output=lora_a_output, + # base_output=base_output, + # ) return lora_output + + # def run_lora_b_lm_head( + # self, lora_a_output: torch.Tensor, base_output: torch.Tensor + # ) -> torch.Tensor: + # """ + # Apply LoRA B weights using efficient scatter operation. + + # Instead of full matmul: lora_a_output @ B^T (shape: [s, rank] @ [rank, vocab_size]) + # We compute: for each token, scatter lora_a_output weighted by B weights. + + # This is the "reverse" of embedding lookup - instead of gathering from vocab, + # we scatter to vocab dimension. + # """ + # batch_info = self.lora_backend.batch_info + + # # Get token-to-lora mapping (same as embedding case) + # token_weight_indices = self._get_token_weight_indices(lora_a_output, batch_info) + + # # Apply scatter operation for each LoRA adapter + # output = base_output.clone() if base_output is not None else torch.zeros( + # (lora_a_output.shape[0], self.vocab_size), + # dtype=lora_a_output.dtype, + # device=lora_a_output.device, + # ) + + # unique_weight_indices = torch.unique(token_weight_indices) + + # for idx in unique_weight_indices: + # token_mask = token_weight_indices == idx + + # # Get LoRA B weights for this adapter: (vocab_size, rank) + # lora_b_weights = self.lm_head_B_buffer[idx] # (vocab_size, rank) + + # # Get scaling for this adapter + # scaling = batch_info.scalings[idx] + + # # Compute: lora_a_output[token_mask] @ lora_b_weights^T + # # lora_a_output[token_mask]: (num_tokens, rank) + # # lora_b_weights: (vocab_size, rank) + # # Result: (num_tokens, vocab_size) + # lora_contribution = torch.matmul( + # lora_a_output[token_mask], # (num_tokens, rank) + # lora_b_weights.t() # (rank, vocab_size) + # ) * scaling + + # output[token_mask] += lora_contribution + + # return output + + # def _get_token_weight_indices( + # self, lora_a_output: torch.Tensor, batch_info: LoRABatchInfo + # ) -> torch.Tensor: + # """Get token-to-lora mapping (same as embedding case).""" + # token_weight_indices = torch.zeros( + # lora_a_output.shape[0], dtype=torch.int32, device=lora_a_output.device + # ) + + # current_pos = 0 + # for i in range(batch_info.bs): + # seg_len = int(batch_info.seg_lens[i]) + # weight_idx = int(batch_info.weight_indices[i]) + # token_weight_indices[current_pos : current_pos + seg_len] = weight_idx + # current_pos += seg_len + + # return token_weight_indices + + def forward(self, hidden_states: torch.Tensor): # Apply base linear transformation base_output = F.linear( diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py index cc087d229d27..72c761aaefcd 100644 --- a/python/sglang/srt/lora/lora.py +++ b/python/sglang/srt/lora/lora.py @@ -106,7 +106,7 @@ def initialize_weights(self): elif "input_embeddings" in name or "output_embeddings" in name: #added token emb self.added_tokens_embeddings[name] = loaded_weight.cpu() - assert loaded_weight.shape[0] == self.config.extra_vocab_size, ( + assert loaded_weight.shape[0] == self.config.lora_added_tokens_size, ( f"LoRA adapter {self.uid} has extra_vocab_size {self.config.extra_vocab_size} specified in the config, " f"but the loaded weight has {loaded_weight.shape[0]} extra vocab size" ) diff --git a/python/sglang/srt/lora/lora_config.py b/python/sglang/srt/lora/lora_config.py index a640b63e535b..4f37c2907054 100644 --- a/python/sglang/srt/lora/lora_config.py +++ b/python/sglang/srt/lora/lora_config.py @@ -33,9 +33,9 @@ def __init__( ############################## ##########emb lora############ ############################## - self.added_tokens = self.get_added_tokens() - self.extra_vocab_size = ( - len(self.added_tokens) if self.added_tokens is not None else 0 + self.added_tokens_config = self.get_added_tokens_config() + self.lora_added_tokens_size = ( + len(self.added_tokens_config) if self.added_tokens_config is not None else 0 ) ############################## ############################## @@ -56,7 +56,7 @@ def get_lora_config(self, dummy=False): ############################## ##########emb lora############ ############################## - def get_added_tokens(self): + def get_added_tokens_config(self): """Load added tokens from the LoRA adapter if the file exists.""" # Determine the weights directory if not os.path.isdir(self.path): diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 40345d80852a..6b96db701b65 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -70,7 +70,7 @@ def __init__( ############################## ##########emb lora############ ############################## - lora_extra_vocab_size: int = 0, + # lora_extra_vocab_size: int = 0, ############################## ############################## ############################## @@ -86,7 +86,10 @@ def __init__( ############################## ##########emb lora############ ############################## - self.lora_extra_vocab_size: int = lora_extra_vocab_size + # self.lora_extra_vocab_size: int = lora_extra_vocab_size + # Will infer self.lora_extra_vocab_size in the later init_lora_shapes() if it find the value == None + # self.lora_added_vocab_size: Optional[int] = None + self.lora_added_tokens_size: Optional[int] = None ############################## ############################## ############################## @@ -437,6 +440,26 @@ def init_lora_shapes( [x.r for x in self.configs.values()], default=0, ) + + ############################# + #########emb lora############ + ############################# + # Auto-infer self.lora_added_vocab_size from loaded LoRA configs + # This happens automatically without requiring user input + # if self.lora_added_vocab_size is None: + if self.lora_added_tokens_size is None: + inferred_extra_vocab_size = next( + (x.lora_added_tokens_size for x in self.configs.values() if x.lora_added_tokens_size > 0), + 0 + ) + if inferred_extra_vocab_size > 0: + logger.info( + f"self.lora_added_tokens_size={inferred_extra_vocab_size} from LoRA adapters." + ) + self.lora_added_tokens_size = inferred_extra_vocab_size + ############################# + ############################# + ############################# def load_lora_weights(self, lora_ref: LoRARef): """ @@ -467,7 +490,8 @@ def init_memory_pool(self): ############################## ##########emb lora############ ############################## - lora_extra_vocab_size=self.lora_extra_vocab_size, # check whether read from the config + # lora_added_vocab_size=self.lora_added_vocab_size, # check whether read from the config + lora_added_tokens_size = self.lora_added_tokens_size ############################## ############################## ############################## diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index bd6071186ad9..3eaad2a0d87e 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -66,7 +66,8 @@ def __init__( ############################## ##########emb lora############ ############################## - lora_extra_vocab_size: int, #can be remove? + # lora_added_vocab_size: int, #can be remove? + lora_added_tokens_size: int ############################## ############################## ############################## @@ -80,7 +81,9 @@ def __init__( ############################## ##########emb lora############ ############################## - self.max_extra_vocab_size: int = lora_extra_vocab_size + # self.max_extra_vocab_size: int = lora_added_vocab_size + self.lora_added_tokens_size: int = lora_added_tokens_size + # self.extra_vocab_size: int = base_hf_config.extra_vocab_size ############################## ############################## ############################## @@ -147,7 +150,7 @@ def _can_support(config: LoRAConfig) -> bool: ############################## ##########emb lora############ ############################## - if config.extra_vocab_size > self.max_extra_vocab_size: + if config.lora_added_tokens_size > self.lora_added_tokens_size: return False # can be remove? ############################## ############################## @@ -190,7 +193,7 @@ def get_embedding_lora_A_shape( layer_idx: int, ) -> Tuple[int]: input_dim, _ = get_hidden_dim( - module_name, self.base_hf_config, base_model, 0 + module_name, self.base_hf_config, base_model, 0, self.lora_added_tokens_size ) # Have not imp self.tp_size > 1 yet. return ( @@ -228,7 +231,7 @@ def get_embedding_lora_B_shape( layer_idx: int, ) -> Tuple[int]: _, output_dim = get_hidden_dim( - module_name, self.base_hf_config, base_model, 0 + module_name, self.base_hf_config, base_model, 0, self.lora_added_tokens_size ) # Have not imp self.tp_size > 1 yet. return ( @@ -289,11 +292,11 @@ def init_embedding_buffer( device=device, ) - if self.max_extra_vocab_size > 0: + if self.lora_added_tokens_size > 0: self.new_embeddings_buffer["input_embeddings"] = torch.empty( ( self.max_loras_per_batch, - self.max_extra_vocab_size, + self.lora_added_tokens_size, self.embedding_dim, ), dtype=self.dtype, @@ -619,14 +622,14 @@ def load_lora_weight_tensor( if lora_adapter.embedding_layers: org_vocab_size = self.base_hf_config.vocab_size - extra_vocab_size = lora_adapter.config.extra_vocab_size + lora_added_tokens_size = lora_adapter.config.lora_added_tokens_size # Only when LoRA is applied to the embedding layer will it have the extra-token issue that needs to be resolved. # Load embeddings weights for extra tokens to buffer if lora_adapter.added_tokens_embeddings: for name, weights in lora_adapter.added_tokens_embeddings.items(): if "input_embeddings" in name: buffer_view = self.new_embeddings_buffer["input_embeddings"][ - buffer_id, :extra_vocab_size + buffer_id, :lora_added_tokens_size ] load_lora_weight_tensor(buffer_view, weights) @@ -634,13 +637,14 @@ def load_lora_weight_tensor( for name, weights in lora_adapter.embedding_layers.items(): target_module = get_target_module_name(name, self.target_modules) # if "lora_embedding_A" in name: - if "lora_embedding_A" in name or ("lora_A" in name and target_module == "embed_tokens"): + # if "lora_embedding_A" in name or ("lora_A" in name and target_module == "embed_tokens"): + if target_module == "embed_tokens" and "lora_embedding_A" in name: buffer_view = self.embedding_A_buffer[target_module][ - buffer_id, :lora_rank, : org_vocab_size + extra_vocab_size + buffer_id, :lora_rank, : org_vocab_size + lora_added_tokens_size ] load_lora_weight_tensor(buffer_view, weights) # elif "lora_embedding_B" in name: - elif "lora_embedding_B" in name or ("lora_B" in name and target_module == "embed_tokens"): + elif target_module == "embed_tokens" and "lora_embedding_B" in name: lora_b_weights = weights #[to-do] support TP # if self.tp_size > 1: @@ -656,14 +660,19 @@ def load_lora_weight_tensor( load_lora_weight_tensor(buffer_view, lora_b_weights) - if "lora_lm_head_A" in name or ("lora_A" in name and target_module == "lm_head"): - buffer_view = self.embedding_A_buffer[target_module][ + # name: base_model.model.lm_head.lora_A.weight + # self.target_modules: {'qkv_proj', 'embed_tokens', 'gate_up_proj', 'o_proj', 'lm_head', 'down_proj'} + # target_module: lm_head + # if "lora_lm_head_A" in name or ("lora_A" in name and target_module == "lm_head"): + elif target_module == "lm_head" and "lora_A.weight" in name: + buffer_view = self.lm_head_A_buffer[target_module][ # buffer_id, :, :lora_rank buffer_id, :lora_rank, : ] load_lora_weight_tensor(buffer_view, weights) # elif "lora_embedding_B" in name: - elif "lora_lm_head_B" in name or ("lora_B" in name and target_module == "lm_head"): + # elif "lora_lm_head_B" in name or ("lora_B" in name and target_module == "lm_head"): + elif target_module == "lm_head" and "lora_B.weight" in name: lora_b_weights = weights #[to-do] support TP # if self.tp_size > 1: @@ -673,9 +682,9 @@ def load_lora_weight_tensor( # lora_b_weights, self.tp_rank # ) - buffer_view = self.embedding_B_buffer[target_module][ + buffer_view = self.lm_head_B_buffer[target_module][ # buffer_id, :lora_rank, : org_vocab_size + extra_vocab_size - buffer_id, : org_vocab_size + extra_vocab_size, :lora_rank + buffer_id, : org_vocab_size + self.lora_added_tokens_size, :lora_rank ] load_lora_weight_tensor(buffer_view, lora_b_weights) @@ -701,7 +710,7 @@ def get_embedding_tensor( """ if target_module == "added_tokens": - if self.max_extra_vocab_size > 0: # change to read from the config + if self.lora_added_tokens_size > 0 and self.lora_added_tokens_size != None: # change to read from the config return self.new_embeddings_buffer["input_embeddings"] return None elif target_module == "embed_tokens": diff --git a/python/sglang/srt/lora/utils.py b/python/sglang/srt/lora/utils.py index febd8b4472d1..7d9e44f3a12a 100644 --- a/python/sglang/srt/lora/utils.py +++ b/python/sglang/srt/lora/utils.py @@ -46,7 +46,7 @@ class LoRAType(Enum): def get_hidden_dim( - module_name: str, config: AutoConfig, base_model: torch.nn.Module, layer_idx: int + module_name: str, config: AutoConfig, base_model: torch.nn.Module, layer_idx: int, lora_added_vocab_size: int = 0 ) -> Tuple[int]: """ Given a module_name (might be a stacked name), return the hidden dims of modules' input and output. @@ -56,10 +56,6 @@ def get_hidden_dim( ##########emb lora############ ############################## is_embedding_module = "embed_tokens" in module_name or "lm_head" in module_name - print("=========") - print(is_embedding_module) - print(module_name) - print("----") ############################## ############################## ############################## @@ -68,17 +64,13 @@ def get_hidden_dim( ##########emb lora############ ############################## # if hasattr(base_model, "get_hidden_dim"): - # if hasattr(base_model, "get_hidden_dim") and not is_embedding_module: - if hasattr(base_model, "get_hidden_dim"): + if hasattr(base_model, "get_hidden_dim") and not is_embedding_module: + # if hasattr(base_model, "get_hidden_dim"): ############################## ############################## ############################## - print(1111) - print("=========") return base_model.get_hidden_dim(module_name, layer_idx) else: - print(2222) - print("=========") """ WARNING: get_hidden_dim() is not defined, which is used to get the hidden dim for different lora modules @@ -111,16 +103,14 @@ def get_hidden_dim( elif "embed_tokens" in module_name: # For embedding: input is vocab_size (as embedding lookup), output is hidden_size # if contain extra tokens will be added; otherwise is 0. - extra_vocab = getattr(config, 'extra_vocab_size', 0) - return config.vocab_size + extra_vocab, config.hidden_size + return config.vocab_size + lora_added_vocab_size, config.hidden_size #Handle lm_head # elif "lm_head" in module_name: elif "lm_head" in module_name: # For lm_head: input is hidden_size, output is vocab_size # if contain extra tokens will be added; otherwise is 0. - extra_vocab = getattr(config, 'extra_vocab_size', 0) - return config.hidden_size, config.vocab_size + extra_vocab + return config.hidden_size, config.vocab_size + lora_added_vocab_size ############################## ############################## ############################## diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index e174eb0c2f39..16d7cf4bc228 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -418,9 +418,57 @@ def forward_generation_raw( else: input_ids = torch.tensor([p], device="cuda") + ############################## + ##########emb lora############ + ############################## + # if lora_paths is not None and lora_paths[i] is not None: + # from peft import PeftModel + + # model = PeftModel.from_pretrained( + # base_model, + # lora_paths[i], + # torch_dtype=torch_dtype, + # is_trainable=False, + # ) + # else: + # model = base_model + if lora_paths is not None and lora_paths[i] is not None: from peft import PeftModel - + from transformers import AutoTokenizer + + # Load LoRA's tokenizer to check vocab size + try: + lora_tokenizer = AutoTokenizer.from_pretrained(lora_paths[i]) + lora_vocab_size = len(lora_tokenizer) + except: + # If LoRA doesn't have tokenizer, try to infer from adapter config + import json + import os + from huggingface_hub import hf_hub_download + + try: + # Download adapter_config.json to check for vocab size info + config_file = hf_hub_download(repo_id=lora_paths[i], filename="adapter_model.bin") + adapter_state = torch.load(config_file.replace("adapter_model.bin", "adapter_model.bin"), map_location="cpu") + # Find vocab size from the embedding layer shape + for key in adapter_state.keys(): + if "embed_tokens" in key and "lora_embedding_A" in key: + lora_vocab_size = adapter_state[key].shape[1] + break + elif "lm_head" in key and "lora_B" in key: + lora_vocab_size = adapter_state[key].shape[0] + break + else: + lora_vocab_size = base_model.config.vocab_size + except: + lora_vocab_size = base_model.config.vocab_size + + # Resize base model embeddings if needed + if lora_vocab_size != base_model.config.vocab_size: + print(f"Resizing model embeddings from {base_model.config.vocab_size} to {lora_vocab_size}") + base_model.resize_token_embeddings(lora_vocab_size) + model = PeftModel.from_pretrained( base_model, lora_paths[i], @@ -429,6 +477,9 @@ def forward_generation_raw( ) else: model = base_model + ############################## + ############################## + ############################## if patch_model_do_sample_false: model.generation_config.do_sample = False outputs = model.generate( From 63327732012b2d26d1e3d2318273851ee265797f Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Sun, 30 Nov 2025 04:39:52 +0000 Subject: [PATCH 07/19] fixed lm_head issue --- python/sglang/srt/layers/logits_processor.py | 15 ++++++- python/sglang/srt/lora/mem_pool.py | 12 +++--- python/sglang/srt/lora/utils.py | 26 +++--------- python/sglang/test/runners.py | 44 ++++---------------- 4 files changed, 35 insertions(+), 62 deletions(-) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index e2c7d2ab6457..d5af5fcdfde1 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -812,7 +812,20 @@ def _get_logits( ) dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata) - if hasattr(lm_head, "weight"): + ############################## + ##########emb lora############ + ############################## + if hasattr(lm_head, 'set_lora') and hasattr(lm_head, 'apply_lora'): + # This is a LoRA-wrapped module, use its forward method + #[TODO] improve lm_head forward infernce + logits = lm_head(hidden_states) + + + # if hasattr(lm_head, "weight"): + elif hasattr(lm_head, "weight"): + ############################## + ############################## + ############################## if self.use_fp32_lm_head: logits = torch.matmul( hidden_states.to(torch.float32), lm_head.weight.to(torch.float32).T diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index 3eaad2a0d87e..54f3ce288c12 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -638,13 +638,13 @@ def load_lora_weight_tensor( target_module = get_target_module_name(name, self.target_modules) # if "lora_embedding_A" in name: # if "lora_embedding_A" in name or ("lora_A" in name and target_module == "embed_tokens"): - if target_module == "embed_tokens" and "lora_embedding_A" in name: + if target_module == "embed_tokens" and "embed_tokens" in name and ("lora_embedding_A" in name or "lora_A" in name): buffer_view = self.embedding_A_buffer[target_module][ - buffer_id, :lora_rank, : org_vocab_size + lora_added_tokens_size + buffer_id, :lora_rank, :(org_vocab_size+lora_added_tokens_size) ] load_lora_weight_tensor(buffer_view, weights) # elif "lora_embedding_B" in name: - elif target_module == "embed_tokens" and "lora_embedding_B" in name: + elif target_module == "embed_tokens" and "embed_tokens" in name and ("lora_embedding_B" in name or "lora_B" in name): lora_b_weights = weights #[to-do] support TP # if self.tp_size > 1: @@ -664,7 +664,7 @@ def load_lora_weight_tensor( # self.target_modules: {'qkv_proj', 'embed_tokens', 'gate_up_proj', 'o_proj', 'lm_head', 'down_proj'} # target_module: lm_head # if "lora_lm_head_A" in name or ("lora_A" in name and target_module == "lm_head"): - elif target_module == "lm_head" and "lora_A.weight" in name: + elif target_module == "lm_head" and "lm_head" in name and ("lora_embedding_A" in name or "lora_A" in name): buffer_view = self.lm_head_A_buffer[target_module][ # buffer_id, :, :lora_rank buffer_id, :lora_rank, : @@ -672,7 +672,7 @@ def load_lora_weight_tensor( load_lora_weight_tensor(buffer_view, weights) # elif "lora_embedding_B" in name: # elif "lora_lm_head_B" in name or ("lora_B" in name and target_module == "lm_head"): - elif target_module == "lm_head" and "lora_B.weight" in name: + elif target_module == "lm_head" and "lm_head" in name and ("lora_embedding_B" in name or "lora_B" in name): lora_b_weights = weights #[to-do] support TP # if self.tp_size > 1: @@ -684,7 +684,7 @@ def load_lora_weight_tensor( buffer_view = self.lm_head_B_buffer[target_module][ # buffer_id, :lora_rank, : org_vocab_size + extra_vocab_size - buffer_id, : org_vocab_size + self.lora_added_tokens_size, :lora_rank + buffer_id, :(org_vocab_size + self.lora_added_tokens_size), :lora_rank ] load_lora_weight_tensor(buffer_view, lora_b_weights) diff --git a/python/sglang/srt/lora/utils.py b/python/sglang/srt/lora/utils.py index 7d9e44f3a12a..cd9b0623dabd 100644 --- a/python/sglang/srt/lora/utils.py +++ b/python/sglang/srt/lora/utils.py @@ -51,24 +51,8 @@ def get_hidden_dim( """ Given a module_name (might be a stacked name), return the hidden dims of modules' input and output. """ - - ############################## - ##########emb lora############ - ############################## - is_embedding_module = "embed_tokens" in module_name or "lm_head" in module_name - ############################## - ############################## - ############################## - - ############################## - ##########emb lora############ - ############################## - # if hasattr(base_model, "get_hidden_dim"): - if hasattr(base_model, "get_hidden_dim") and not is_embedding_module: - # if hasattr(base_model, "get_hidden_dim"): - ############################## - ############################## - ############################## + + if hasattr(base_model, "get_hidden_dim"): return base_model.get_hidden_dim(module_name, layer_idx) else: """ @@ -100,14 +84,16 @@ def get_hidden_dim( ############################## #Handle embed_tokens # elif "embed_tokens" in module_name: - elif "embed_tokens" in module_name: + # elif "embed_tokens" in module_name: + elif module_name == "embed_tokens": # For embedding: input is vocab_size (as embedding lookup), output is hidden_size # if contain extra tokens will be added; otherwise is 0. return config.vocab_size + lora_added_vocab_size, config.hidden_size #Handle lm_head # elif "lm_head" in module_name: - elif "lm_head" in module_name: + # elif "lm_head" in module_name: + elif module_name == "lm_head": # For lm_head: input is hidden_size, output is vocab_size # if contain extra tokens will be added; otherwise is 0. return config.hidden_size, config.vocab_size + lora_added_vocab_size diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index 16d7cf4bc228..6a73755bf263 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -421,6 +421,7 @@ def forward_generation_raw( ############################## ##########emb lora############ ############################## + # # (original) # if lora_paths is not None and lora_paths[i] is not None: # from peft import PeftModel @@ -433,42 +434,15 @@ def forward_generation_raw( # else: # model = base_model + # PR version if lora_paths is not None and lora_paths[i] is not None: - from peft import PeftModel - from transformers import AutoTokenizer - - # Load LoRA's tokenizer to check vocab size - try: - lora_tokenizer = AutoTokenizer.from_pretrained(lora_paths[i]) - lora_vocab_size = len(lora_tokenizer) - except: - # If LoRA doesn't have tokenizer, try to infer from adapter config - import json - import os - from huggingface_hub import hf_hub_download - - try: - # Download adapter_config.json to check for vocab size info - config_file = hf_hub_download(repo_id=lora_paths[i], filename="adapter_model.bin") - adapter_state = torch.load(config_file.replace("adapter_model.bin", "adapter_model.bin"), map_location="cpu") - # Find vocab size from the embedding layer shape - for key in adapter_state.keys(): - if "embed_tokens" in key and "lora_embedding_A" in key: - lora_vocab_size = adapter_state[key].shape[1] - break - elif "lm_head" in key and "lora_B" in key: - lora_vocab_size = adapter_state[key].shape[0] - break - else: - lora_vocab_size = base_model.config.vocab_size - except: - lora_vocab_size = base_model.config.vocab_size - - # Resize base model embeddings if needed - if lora_vocab_size != base_model.config.vocab_size: - print(f"Resizing model embeddings from {base_model.config.vocab_size} to {lora_vocab_size}") - base_model.resize_token_embeddings(lora_vocab_size) - + from peft import PeftConfig, PeftModel + + peft_config = PeftConfig.from_pretrained(lora_paths[i]) + if "embed_tokens" in peft_config.target_modules: + tok = get_tokenizer(lora_paths[i]) + base_model.resize_token_embeddings(len(tok)) + model = PeftModel.from_pretrained( base_model, lora_paths[i], From 53ab0d523fcbd111740d23e6a0e0b94ee511bfa3 Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Mon, 1 Dec 2025 21:50:35 +0000 Subject: [PATCH 08/19] vocab_emb without cuda-graph version --- python/sglang/srt/lora/layers.py | 355 ------------------------------- 1 file changed, 355 deletions(-) diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index 2d32f15487b3..20c91b3815a9 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -48,294 +48,6 @@ def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): ############################# #########emb lora############ ############################# -# class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): -# """ -# Vocab parallel embedding layer with support for LoRA (Low-Rank Adaptation). - -# Note: The current version does not yet implement the LoRA functionality. -# This class behaves exactly the same as the base VocabParallelEmbedding. -# Future versions will integrate LoRA functionality to support efficient parameter fine-tuning. -# """ - -# def __init__( -# self, -# base_layer: VocabParallelEmbedding, -# lora_backend: BaseLoRABackend, -# ) -> None: -# super().__init__(base_layer, lora_backend) -# self.weight = base_layer.weight - - -##### ----- -##### ----- -##### ----- -# class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): -# """ -# Vocab parallel embedding layer with support for LoRA (Low-Rank Adaptation). - -# This layer supports LoRA adapters on embedding layers, including handling -# of extra tokens added by LoRA adapters. The implementation uses efficient -# embedding lookup instead of one-hot encoding. - -# For embedding layers: output = base_embedding(x) + lora_B @ lora_A[x] -# where lora_A[x] is direct embedding lookup from lora_A weights. -# """ - -# def __init__( -# self, -# base_layer: VocabParallelEmbedding, -# lora_backend: BaseLoRABackend, -# ) -> None: -# super().__init__(base_layer, lora_backend) -# self.weight = base_layer.weight -# self.embed_dim = base_layer.embedding_dim -# self.vocab_size = base_layer.org_vocab_size - -# def set_lora_info( -# self, -# new_embeddings_buffer: Optional[torch.Tensor], -# embedding_A_buffer: torch.Tensor, -# embedding_B_buffer: torch.Tensor, -# ): -# self.set_lora = True -# self.new_embeddings_buffer = new_embeddings_buffer # For extra tokens -# self.embedding_A_buffer = embedding_A_buffer -# self.embedding_B_buffer = embedding_B_buffer - -# def _get_token_weight_indices( -# self, input_: torch.Tensor, batch_info -# ) -> torch.Tensor: -# """Map each token position to its corresponding LoRA adapter index.""" -# token_weight_indices = torch.zeros( -# input_.shape[0], dtype=torch.int32, device=input_.device -# ) - -# # current_pos = 0 -# # for i in range(batch_info.bs): -# # seg_len = int(batch_info.seg_lens[i]) -# # weight_idx = int(batch_info.weight_indices[i]) -# # token_weight_indices[current_pos : current_pos+seg_len] = weight_idx -# # current_pos += seg_len - -# # Use cumsum for positions - avoid Python loops -# seg_lens = batch_info.seg_lens[:batch_info.bs] # (bs,) -# cum_lens = torch.cumsum(seg_lens, dim=0) # cumulative positions -# start_positions = torch.cat([torch.zeros(1, dtype=cum_lens.dtype, device=cum_lens.device), cum_lens[:-1]]) - -# # Vectorized assignment using tensor operations - allow enable cuda-graph -# for i in range(batch_info.bs): -# start = start_positions[i] -# end = cum_lens[i] -# weight_idx = batch_info.weight_indices[i] -# token_weight_indices[start:end] = weight_idx - -# return token_weight_indices - -# def _run_lora_a_embedding( -# self, input_: torch.Tensor, token_weight_indices: torch.Tensor -# ) -> torch.Tensor: -# """ -# Apply LoRA A weights using efficient embedding lookup. -# This avoids creating one-hot vectors. -# """ -# lora_a_output = torch.zeros( -# (input_.shape[0], self.embedding_A_buffer.shape[1]), -# dtype=self.embedding_A_buffer.dtype, -# device=input_.device, -# ) - -# unique_weight_indices = torch.unique(token_weight_indices) - -# for idx in unique_weight_indices: -# token_mask = token_weight_indices == idx -# lora_a_weights = self.embedding_A_buffer[idx] -# # Use F.embedding for efficient lookup instead of one-hot @ weights -# # lora_a_weights shape: (rank, vocab_size) -# # We need (vocab_size, rank) for embedding lookup -# lora_a_output[token_mask] = F.embedding( -# input_[token_mask], lora_a_weights.t() -# ) - -# return lora_a_output - -# def apply_lora( -# self, base_output: torch.Tensor, input_: torch.Tensor, batch_info -# ) -> torch.Tensor: -# """ -# Apply LoRA to base embedding output. -# Formula: output = base_output + lora_B @ lora_A_embedding(input_) -# """ -# token_weight_indices = self._get_token_weight_indices(input_, batch_info) - -# # Efficient embedding lookup for LoRA A -# lora_a_output = self._run_lora_a_embedding(input_, token_weight_indices) - -# # Apply LoRA B weights -# lora_output = self.lora_backend.run_lora_b_sgemm( -# x=lora_a_output, -# weights=self.embedding_B_buffer, -# base_output=base_output, -# ) -# return lora_output - -# def _forward( -# self, -# input_: torch.Tensor, -# added_tokens_mask: torch.Tensor, -# batch_info, -# base_output: torch.Tensor, -# ) -> torch.Tensor: -# """Handle extra tokens that are beyond the base vocabulary.""" -# token_weight_indices = self._get_token_weight_indices(input_, batch_info) -# added_weight_indices = token_weight_indices[added_tokens_mask] -# unique_added_weight_indices = torch.unique(added_weight_indices) - -# for idx in unique_added_weight_indices: -# lora_mask = added_weight_indices == idx -# added_token_positions = torch.where(added_tokens_mask)[0][lora_mask] -# # Remap to extra token range -# x = input_[added_token_positions] - self.vocab_size -# new_embeddings = F.embedding(x, self.new_embeddings_buffer[idx]) -# base_output[added_token_positions] = new_embeddings - -# return base_output - -# def forward(self, input_: torch.Tensor): -# batch_info = self.lora_backend.batch_info - -# # Mask tokens that are beyond base vocabulary (extra tokens) -# added_tokens_mask = input_ > self.vocab_size - 1 - -# # Get base embedding, masking extra tokens temporarily -# base_output = self.base_layer.forward(input_.masked_fill(added_tokens_mask, 0)) - -# # Handle extra tokens -# if added_tokens_mask.any(): -# base_output = self._forward( -# input_, added_tokens_mask, batch_info, base_output -# ) - -# # Apply LoRA if configured -# if self.set_lora: -# output = self.apply_lora(base_output, input_, batch_info) -# else: -# output = base_output - -# return output - -# def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): -# # LoRA A weights (rank, vocab_size) are not sliced for embedding -# # because each token needs access to full vocabulary -# return A - -# def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): -# ## LoRA B weights (embedding_dim, rank) are sliced along embedding dimension -# # from sglang.srt.distributed import divide -# # shard_size = divide(self.base_layer.embedding_dim, self.base_layer.tp_size) -# # start_idx = tp_rank * shard_size -# # end_idx = (tp_rank + 1) * shard_size -# # B = B[start_idx:end_idx, :] - -# # TP = 1 -# return B - - - -# class ParallelLMHeadWithLoRA(BaseLayerWithLoRA): -# """ -# Parallel LM Head layer with support for LoRA. - -# The LM head computes logits = hidden_states @ (W + B @ A)^T -# This is different from embedding which uses lookup operations. - -# Note: This class is NOT in the official SGLang implementation. -# You may need to verify if LM head LoRA is needed for your use case. -# """ - -# def __init__( -# self, -# base_layer, # ParallelLMHead -# lora_backend: BaseLoRABackend, -# ) -> None: -# super().__init__(base_layer, lora_backend) -# self.weight = base_layer.weight -# self.embed_dim = base_layer.embedding_dim -# self.vocab_size = base_layer.org_vocab_size - -# def set_lora_info( -# self, -# lm_head_A_buffer: torch.Tensor, -# lm_head_B_buffer: torch.Tensor, -# ): -# self.set_lora = True -# self.lm_head_A_buffer = lm_head_A_buffer -# self.lm_head_B_buffer = lm_head_B_buffer - -# def apply_lora( -# self, base_output: torch.Tensor, hidden_states: torch.Tensor -# ) -> torch.Tensor: -# """ -# Apply LoRA to LM head layer. - -# Args: -# base_output: Base logits, shape (batch_size, vocab_size) -# hidden_states: Hidden states, shape (batch_size, hidden_dim) - -# Returns: -# Logits with LoRA applied -# """ -# # For LM head: output = hidden @ (W + B @ A)^T -# # = hidden @ W^T + hidden @ A^T @ B^T -# # = base_output + (hidden @ A^T) @ B^T - -# # Apply lora_A^T: hidden_states @ A^T -# # lm_head_A_buffer shape: (num_loras, rank, hidden_dim) -# lora_a_output = self.lora_backend.run_lora_a_sgemm( -# hidden_states, self.lm_head_A_buffer -# ) - -# # Apply lora_B^T: lora_a_output @ B^T -# # lm_head_B_buffer shape: (num_loras, vocab_size, rank) -# lora_output = self.lora_backend.run_lora_b_sgemm( -# x=lora_a_output, -# weights=self.lm_head_B_buffer, -# base_output=base_output, -# ) - -# return lora_output - -# def forward(self, hidden_states: torch.Tensor): -# # Apply base linear transformation -# base_output = F.linear( -# hidden_states, -# self.weight, -# bias=getattr(self.base_layer, 'bias', None) -# ) - -# # Apply LoRA if set -# if self.set_lora: -# base_output = self.apply_lora(base_output, hidden_states) - -# return base_output - -# def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): -# # LoRA A is not sliced (similar to ColumnParallelLinear) -# return A - -# def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): -# # LoRA B is sliced along vocab dimension (output dimension) -# # Similar to ColumnParallelLinear slicing -# # from sglang.srt.distributed import divide -# # shard_size = divide(self.vocab_size, self.base_layer.tp_size) -# # start_idx = tp_rank * shard_size -# # end_idx = (tp_rank + 1) * shard_size -# # B = B[start_idx:end_idx, :] -# # TP=1 -# return B -##### ----- -##### ----- -##### ----- - class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): """ @@ -567,73 +279,6 @@ def apply_lora( return lora_output - - # def run_lora_b_lm_head( - # self, lora_a_output: torch.Tensor, base_output: torch.Tensor - # ) -> torch.Tensor: - # """ - # Apply LoRA B weights using efficient scatter operation. - - # Instead of full matmul: lora_a_output @ B^T (shape: [s, rank] @ [rank, vocab_size]) - # We compute: for each token, scatter lora_a_output weighted by B weights. - - # This is the "reverse" of embedding lookup - instead of gathering from vocab, - # we scatter to vocab dimension. - # """ - # batch_info = self.lora_backend.batch_info - - # # Get token-to-lora mapping (same as embedding case) - # token_weight_indices = self._get_token_weight_indices(lora_a_output, batch_info) - - # # Apply scatter operation for each LoRA adapter - # output = base_output.clone() if base_output is not None else torch.zeros( - # (lora_a_output.shape[0], self.vocab_size), - # dtype=lora_a_output.dtype, - # device=lora_a_output.device, - # ) - - # unique_weight_indices = torch.unique(token_weight_indices) - - # for idx in unique_weight_indices: - # token_mask = token_weight_indices == idx - - # # Get LoRA B weights for this adapter: (vocab_size, rank) - # lora_b_weights = self.lm_head_B_buffer[idx] # (vocab_size, rank) - - # # Get scaling for this adapter - # scaling = batch_info.scalings[idx] - - # # Compute: lora_a_output[token_mask] @ lora_b_weights^T - # # lora_a_output[token_mask]: (num_tokens, rank) - # # lora_b_weights: (vocab_size, rank) - # # Result: (num_tokens, vocab_size) - # lora_contribution = torch.matmul( - # lora_a_output[token_mask], # (num_tokens, rank) - # lora_b_weights.t() # (rank, vocab_size) - # ) * scaling - - # output[token_mask] += lora_contribution - - # return output - - # def _get_token_weight_indices( - # self, lora_a_output: torch.Tensor, batch_info: LoRABatchInfo - # ) -> torch.Tensor: - # """Get token-to-lora mapping (same as embedding case).""" - # token_weight_indices = torch.zeros( - # lora_a_output.shape[0], dtype=torch.int32, device=lora_a_output.device - # ) - - # current_pos = 0 - # for i in range(batch_info.bs): - # seg_len = int(batch_info.seg_lens[i]) - # weight_idx = int(batch_info.weight_indices[i]) - # token_weight_indices[current_pos : current_pos + seg_len] = weight_idx - # current_pos += seg_len - - # return token_weight_indices - - def forward(self, hidden_states: torch.Tensor): # Apply base linear transformation base_output = F.linear( From 20b368e37e9691974578718cfc287d836e60c520 Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Tue, 2 Dec 2025 01:50:35 +0000 Subject: [PATCH 09/19] support cuda-graph (triton backend) --- python/sglang/srt/layers/logits_processor.py | 8 +- .../sglang/srt/lora/backend/base_backend.py | 29 +++ .../srt/lora/backend/chunked_backend.py | 35 ++++ .../sglang/srt/lora/backend/triton_backend.py | 56 ++++- python/sglang/srt/lora/layers.py | 122 +++++++++-- python/sglang/srt/lora/triton_ops/__init__.py | 14 ++ .../srt/lora/triton_ops/embedding_lora_a.py | 195 ++++++++++++++++++ 7 files changed, 434 insertions(+), 25 deletions(-) create mode 100644 python/sglang/srt/lora/triton_ops/embedding_lora_a.py diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index d5af5fcdfde1..6d4ce5dfd504 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -817,8 +817,14 @@ def _get_logits( ############################## if hasattr(lm_head, 'set_lora') and hasattr(lm_head, 'apply_lora'): # This is a LoRA-wrapped module, use its forward method - #[TODO] improve lm_head forward infernce logits = lm_head(hidden_states) + # logits = torch.matmul( + # hidden_states.to(torch.float32), lm_head.weight.to(torch.float32).T + # ) + # print("====") + # print(self.use_fp32_lm_head) + # print("====") + # exit() # if hasattr(lm_head, "weight"): diff --git a/python/sglang/srt/lora/backend/base_backend.py b/python/sglang/srt/lora/backend/base_backend.py index 77654c4b2d32..347c7397f60e 100644 --- a/python/sglang/srt/lora/backend/base_backend.py +++ b/python/sglang/srt/lora/backend/base_backend.py @@ -18,6 +18,35 @@ class BaseLoRABackend: def __init__(self, max_loras_per_batch: int, device: torch.device): self.max_loras_per_batch = max_loras_per_batch self.device = device + + ############################# + #########cuda lora########### + ############################# + def run_lora_a_embedding( + self, + input_ids: torch.Tensor, + weights: torch.Tensor, + vocab_size: int, + extra_embeddings: torch.Tensor = None, + *args, + **kwargs, + ) -> torch.Tensor: + """Run LoRA A embedding lookup with CUDA graph support. + + Args: + input_ids: token IDs with shape (s,), where s is the sum of all sequence lengths + weights: LoRA A embedding weights with shape (num_loras, rank, vocab_size) + vocab_size: base vocabulary size (tokens >= vocab_size are extra tokens) + extra_embeddings: extra token embeddings with shape (num_loras, num_extra_tokens, rank) + Only needed if there are added tokens beyond base vocabulary. + + Returns: + result with shape (s, rank) + """ + pass + ############################# + ############################# + ############################# def run_lora_a_sgemm( self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs diff --git a/python/sglang/srt/lora/backend/chunked_backend.py b/python/sglang/srt/lora/backend/chunked_backend.py index f17f473cbdfd..1963143ea0ec 100644 --- a/python/sglang/srt/lora/backend/chunked_backend.py +++ b/python/sglang/srt/lora/backend/chunked_backend.py @@ -4,6 +4,13 @@ from sglang.srt.lora.triton_ops import ( chunked_sgmv_lora_expand_forward, chunked_sgmv_lora_shrink_forward, + ############################# + #########cuda lora########### + ############################# + embedding_lora_a_fwd, + ############################# + ############################# + ############################# ) from sglang.srt.lora.utils import LoRABatchInfo from sglang.srt.model_executor.forward_batch_info import ForwardBatch @@ -33,6 +40,34 @@ def __init__( super().__init__(max_loras_per_batch, device) self.max_chunk_size = server_args.max_lora_chunk_size + ############################# + #########cuda lora########### + ############################# + def run_lora_a_embedding( + self, + input_ids: torch.Tensor, + weights: torch.Tensor, + vocab_size: int, + extra_embeddings: torch.Tensor = None, + *args, + **kwargs, + ) -> torch.Tensor: + """Run LoRA A embedding lookup. + + For chunked backend, we use the same triton kernel as triton backend + since embedding lookup doesn't benefit much from chunking. + """ + return embedding_lora_a_fwd( + input_ids=input_ids, + weights=weights, + batch_info=self.batch_info, + vocab_size=vocab_size, + extra_embeddings=extra_embeddings, + ) + ############################# + ############################# + ############################# + def run_lora_a_sgemm( self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs ) -> torch.Tensor: diff --git a/python/sglang/srt/lora/backend/triton_backend.py b/python/sglang/srt/lora/backend/triton_backend.py index 1c2e319dd397..cb7c2bc94956 100644 --- a/python/sglang/srt/lora/backend/triton_backend.py +++ b/python/sglang/srt/lora/backend/triton_backend.py @@ -6,6 +6,13 @@ qkv_lora_b_fwd, sgemm_lora_a_fwd, sgemm_lora_b_fwd, + ############################# + #########cuda lora########### + ############################# + embedding_lora_a_fwd, + ############################# + ############################# + ############################# ) from sglang.srt.lora.utils import LoRABatchInfo from sglang.srt.model_executor.forward_batch_info import ForwardBatch @@ -22,6 +29,30 @@ def __init__( ): super().__init__(max_loras_per_batch, device) + ############################# + #########cuda lora########### + ############################# + def run_lora_a_embedding( + self, + input_ids: torch.Tensor, + weights: torch.Tensor, + vocab_size: int, + extra_embeddings: torch.Tensor = None, + *args, + **kwargs + ) -> torch.Tensor: + """Run LoRA A embedding lookup using Triton kernel.""" + return embedding_lora_a_fwd( + input_ids=input_ids, + weights=weights, + batch_info=self.batch_info, + vocab_size=vocab_size, + extra_embeddings=extra_embeddings, + ) + ############################# + ############################# + ############################# + def run_lora_a_sgemm( self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs ) -> torch.Tensor: @@ -114,6 +145,13 @@ def init_cuda_graph_batch_info( scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float), permutation=None, ) + ############################## + ##########cuda lora########### + ############################## + self.cuda_graph_batch_info.seg_indptr[0] = 0 + ############################## + ############################## + ############################## # Initialize seg_indptr for CUDA graph as they remain constant # across batches. @@ -161,7 +199,14 @@ def prepare_lora_batch( seg_lens = ( forward_batch.extend_seq_lens if forward_batch.forward_mode.is_extend() - else torch.ones(bs, device=self.device) + ############################## + ##########cuda lora########### + ############################## + # else torch.ones(bs, device=self.device) + else torch.ones(bs, dtype=torch.int32, device=self.device) + ############################## + ############################## + ############################## ) seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device) seg_indptr[1:] = torch.cumsum(seg_lens, dim=0) @@ -177,7 +222,14 @@ def prepare_lora_batch( (bs,), dtype=torch.int32, device=self.device ), lora_ranks=torch.empty( - (self.max_loras_per_batch,), dtype=torch.int64, device=self.device + ############################## + ##########cuda lora########### + ############################## + # (self.max_loras_per_batch,), dtype=torch.int64, device=self.device + (self.max_loras_per_batch,), dtype=torch.int32, device=self.device + ############################## + ############################## + ############################## ), scalings=torch.empty( (self.max_loras_per_batch,), dtype=torch.float, device=self.device diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index 20c91b3815a9..1edb621b3561 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -113,12 +113,35 @@ def apply_lora( def run_lora_a_embedding( self, input_: torch.Tensor, batch_info: LoRABatchInfo ) -> torch.Tensor: + ############################# + #########cuda lora########### + ############################# + batch_info = self.lora_backend.batch_info + # if batch_info.use_cuda_graph: + # cuda """ - Apply LoRA A weights using efficient embedding lookup. + Apply LoRA A weights using efficient embedding lookup with CUDA graph support. Maps tokens to their corresponding LoRA adapters internally. """ - token_weight_indices = self._get_token_weight_indices(input_, batch_info) - lora_a_output = self._run_lora_a_embedding(input_, token_weight_indices) + # Use backend implementation which supports CUDA graph + lora_a_output = self.lora_backend.run_lora_a_embedding( + input_ids=input_, + weights=self.embedding_A_buffer, + vocab_size=self.vocab_size, + extra_embeddings=self.new_embeddings_buffer if hasattr(self, 'new_embeddings_buffer') and self.new_embeddings_buffer is not None else None, + ) + return lora_a_output + # else: + # # non-cuda + # """ + # Apply LoRA A weights using efficient embedding lookup. + # Maps tokens to their corresponding LoRA adapters internally. + # """ + # token_weight_indices = self._get_token_weight_indices(input_, batch_info) + # lora_a_output = self._run_lora_a_embedding(input_, token_weight_indices) + ############################# + ############################# + ############################# return lora_a_output @@ -179,31 +202,86 @@ def forward(self, input_: torch.Tensor): ############### ############### consider both non-extra and extra tokens ############### + + ############################## + ##########cuda lora########### + ############################## + # batch_info = self.lora_backend.batch_info + # if batch_info.use_cuda_graph: + """ + Forward pass with LoRA support and CUDA graph compatibility. + Extra tokens (tokens >= vocab_size) are now handled efficiently + in the backend's run_lora_a_embedding method. + """ batch_info = self.lora_backend.batch_info - - # Handle added tokens (tokens beyond base vocabulary) + + # Get base embedding output + # For tokens >= vocab_size, base_layer will clamp or handle them + # We mask them to 0 to avoid out-of-bounds access added_tokens_mask = input_ > self.vocab_size - 1 - base_output = self.base_layer.forward(input_.masked_fill(added_tokens_mask, 0)) - - # Process extra tokens if they exist - if added_tokens_mask.any(): - token_weight_indices = self._get_token_weight_indices(input_, batch_info) - added_weight_indices = token_weight_indices[added_tokens_mask] - unique_added_weight_indices = torch.unique(added_weight_indices) - - for idx in unique_added_weight_indices: - lora_mask = added_weight_indices == idx - added_token_positions = torch.where(added_tokens_mask)[0][lora_mask] - x = input_[added_token_positions] - self.vocab_size - new_embeddings = F.embedding(x, self.new_embeddings_buffer[idx]) - base_output[added_token_positions] = new_embeddings - - # Apply LoRA if set + base_input = input_.masked_fill(added_tokens_mask, 0) + base_output = self.base_layer.forward(base_input) + + # Apply LoRA if configured if self.set_lora: + # The backend's run_lora_a_embedding now handles both regular + # and extra tokens efficiently with CUDA graph support output = self.apply_lora(base_output, input_, batch_info) else: - output = base_output + ## Optimized for CUDA graph compatibility + + # Support extra_token + if added_tokens_mask.any() and hasattr(self, 'new_embeddings_buffer') and self.new_embeddings_buffer is not None: + # Use backend even without LoRA to handle extra tokens with CUDA graph support + # The backend's run_lora_a_embedding can handle extra_embeddings directly + extra_output = self.lora_backend.run_lora_a_embedding( + input_ids=input_, + weights=torch.zeros( + (self.new_embeddings_buffer.shape[0], self.new_embeddings_buffer.shape[2], self.vocab_size), + device=input_.device, + dtype=self.new_embeddings_buffer.dtype + ), # Dummy LoRA weights (all zeros) + vocab_size=self.vocab_size, + extra_embeddings=self.new_embeddings_buffer, + ) + # Only use extra embeddings for tokens >= vocab_size + # For regular tokens, keep base_output; for extra tokens, use extra_output + output = torch.where( + added_tokens_mask.unsqueeze(-1), + extra_output, + base_output + ) + else: + # Do not have extra token + output = base_output + ############################## + ############################## + ############################## + # else: + # batch_info = self.lora_backend.batch_info + # # Handle added tokens (tokens beyond base vocabulary) + # added_tokens_mask = input_ > self.vocab_size - 1 + # base_output = self.base_layer.forward(input_.masked_fill(added_tokens_mask, 0)) + + # # Process extra tokens if they exist + # if added_tokens_mask.any(): + # token_weight_indices = self._get_token_weight_indices(input_, batch_info) + # added_weight_indices = token_weight_indices[added_tokens_mask] + # unique_added_weight_indices = torch.unique(added_weight_indices) + + # for idx in unique_added_weight_indices: + # lora_mask = added_weight_indices == idx + # added_token_positions = torch.where(added_tokens_mask)[0][lora_mask] + # x = input_[added_token_positions] - self.vocab_size + # new_embeddings = F.embedding(x, self.new_embeddings_buffer[idx]) + # base_output[added_token_positions] = new_embeddings + + # # Apply LoRA if set + # if self.set_lora: + # output = self.apply_lora(base_output, input_, batch_info) + # else: + # output = base_output return output ############################## diff --git a/python/sglang/srt/lora/triton_ops/__init__.py b/python/sglang/srt/lora/triton_ops/__init__.py index 74a2e84a2c40..81d45c16fe77 100644 --- a/python/sglang/srt/lora/triton_ops/__init__.py +++ b/python/sglang/srt/lora/triton_ops/__init__.py @@ -4,6 +4,13 @@ from .qkv_lora_b import qkv_lora_b_fwd from .sgemm_lora_a import sgemm_lora_a_fwd from .sgemm_lora_b import sgemm_lora_b_fwd +############################# +#########cuda lora########### +############################# +from .embedding_lora_a import embedding_lora_a_fwd +############################# +############################# +############################# __all__ = [ "gate_up_lora_b_fwd", @@ -12,4 +19,11 @@ "sgemm_lora_b_fwd", "chunked_sgmv_lora_shrink_forward", "chunked_sgmv_lora_expand_forward", + ############################# + #########cuda lora########### + ############################# + "embedding_lora_a_fwd", + ############################# + ############################# + ############################# ] diff --git a/python/sglang/srt/lora/triton_ops/embedding_lora_a.py b/python/sglang/srt/lora/triton_ops/embedding_lora_a.py new file mode 100644 index 000000000000..61acdd72a54c --- /dev/null +++ b/python/sglang/srt/lora/triton_ops/embedding_lora_a.py @@ -0,0 +1,195 @@ +############################# +#########cuda lora########### +############################# + +import torch +import triton +import triton.language as tl + +from sglang.srt.lora.utils import LoRABatchInfo + + +@triton.jit +def _embedding_lora_a_kernel( + # Pointers to tensors + input_ids, + weights, + output, + extra_embeddings, + # Dimensions + vocab_size, + rank, + num_loras, + # Strides + w_stride_0, # stride for lora index + w_stride_1, # stride for rank + w_stride_2, # stride for vocab + output_stride_0, + output_stride_1, + extra_emb_stride_0, # stride for lora index + extra_emb_stride_1, # stride for token + extra_emb_stride_2, # stride for hidden dim (= rank for extra embeddings) + # Batch info + seg_lens, + seg_indptr, + weight_indices, + lora_ranks, + # Meta-parameters + BLOCK_RANK: tl.constexpr, + HAS_EXTRA_EMBEDDINGS: tl.constexpr, +): + """ + Embedding lookup for LoRA A weights with support for extra tokens. + + Each program handles one token across a block of rank dimensions. + Grid: (cdiv(max_len, 1), bs) - one program per token in each batch + """ + batch_id = tl.program_id(axis=1) + token_idx = tl.program_id(axis=0) + + w_index = tl.load(weight_indices + batch_id) + rank_val = tl.load(lora_ranks + w_index) + + # If rank is 0, skip + if rank_val == 0: + return + + seg_start = tl.load(seg_indptr + batch_id) + seg_len = tl.load(seg_lens + batch_id) + + # Check if this token is within the segment + if token_idx >= seg_len: + return + + # Load the token ID + token_id = tl.load(input_ids + seg_start + token_idx) + + # Process in chunks of BLOCK_RANK dimensions + num_blocks = tl.cdiv(rank_val, BLOCK_RANK) + + + for block_id in range(num_blocks): + rank_offset = tl.arange(0, BLOCK_RANK) + block_id * BLOCK_RANK + rank_mask = rank_offset < rank_val + + # Check if this is an extra token + is_extra_token = token_id >= vocab_size + + if HAS_EXTRA_EMBEDDINGS and is_extra_token: + # Use extra embeddings + extra_token_id = token_id - vocab_size + extra_emb_ptr = ( + extra_embeddings + + w_index * extra_emb_stride_0 + + extra_token_id * extra_emb_stride_1 + + rank_offset * extra_emb_stride_2 + ) + emb_values = tl.load(extra_emb_ptr, mask=rank_mask, other=0.0) + else: + # Use regular LoRA A weights + # weights shape: (num_loras, rank, vocab_size) + # We need to load weights[w_index, rank_offset, token_id] + token_id_clamped = tl.minimum(token_id, vocab_size - 1) + weight_ptr = ( + weights + + w_index * w_stride_0 + + rank_offset * w_stride_1 + + token_id_clamped * w_stride_2 + ) + emb_values = tl.load(weight_ptr, mask=rank_mask, other=0.0) + + # Write to output + output_ptr = ( + output + + (seg_start + token_idx) * output_stride_0 + + rank_offset * output_stride_1 + ) + tl.store(output_ptr, emb_values, mask=rank_mask) + + +def embedding_lora_a_fwd( + input_ids: torch.Tensor, + weights: torch.Tensor, + batch_info: LoRABatchInfo, + vocab_size: int, + extra_embeddings: torch.Tensor = None, +) -> torch.Tensor: + """ + Forward pass for LoRA A embedding lookup. + + Args: + input_ids: (s,) token IDs + weights: (num_loras, rank, vocab_size) LoRA A embedding weights + batch_info: LoRABatchInfo containing batch information + vocab_size: base vocabulary size + extra_embeddings: (num_loras, num_extra_tokens, rank) extra token embeddings + + Returns: + output: (s, rank) embedded features + """ + assert input_ids.is_contiguous() + assert weights.is_contiguous() + assert len(input_ids.shape) == 1 + assert len(weights.shape) == 3 + + S = input_ids.shape[0] + num_loras = weights.shape[0] + rank = weights.shape[1] + vocab_size_weights = weights.shape[2] + + # Block size for rank dimension + BLOCK_RANK = 128 + + has_extra_embeddings = extra_embeddings is not None + + if has_extra_embeddings: + assert extra_embeddings.is_contiguous() + extra_emb_stride = ( + extra_embeddings.stride(0), + extra_embeddings.stride(1), + extra_embeddings.stride(2), + ) + else: + # Create dummy tensor to satisfy Triton + extra_embeddings = torch.empty( + (1, 1, 1), device=input_ids.device, dtype=weights.dtype + ) + extra_emb_stride = (1, 1, 1) + + # Grid: one program per token in each batch segment + grid = ( + batch_info.max_len, + batch_info.bs, + ) + + output = torch.zeros((S, rank), device=input_ids.device, dtype=weights.dtype) + + _embedding_lora_a_kernel[grid]( + input_ids, + weights, + output, + extra_embeddings, + vocab_size, + rank, + num_loras, + weights.stride(0), + weights.stride(1), + weights.stride(2), + output.stride(0), + output.stride(1), + extra_emb_stride[0], + extra_emb_stride[1], + extra_emb_stride[2], + batch_info.seg_lens, + batch_info.seg_indptr, + batch_info.weight_indices, + batch_info.lora_ranks, + BLOCK_RANK, + has_extra_embeddings, + ) + + return output + +############################# +############################# +############################# \ No newline at end of file From f944d95718a6e42581e22d9850be7865bb35652a Mon Sep 17 00:00:00 2001 From: Yueming Yuan Date: Thu, 4 Dec 2025 21:50:12 +0000 Subject: [PATCH 10/19] support cuda and no-cuda version; tokenizer (it added/extra tokens) should be modified --- .../sglang/srt/lora/backend/base_backend.py | 23 +++ .../srt/lora/backend/chunked_backend.py | 36 +++- .../sglang/srt/lora/backend/triton_backend.py | 43 +++-- python/sglang/srt/lora/layers.py | 150 ++++++++++++----- python/sglang/srt/lora/triton_ops/__init__.py | 2 + .../lora/triton_ops/embedding_extra_tokens.py | 154 ++++++++++++++++++ python/sglang/test/runners.py | 11 +- 7 files changed, 350 insertions(+), 69 deletions(-) create mode 100644 python/sglang/srt/lora/triton_ops/embedding_extra_tokens.py diff --git a/python/sglang/srt/lora/backend/base_backend.py b/python/sglang/srt/lora/backend/base_backend.py index 347c7397f60e..9c72c9d755f5 100644 --- a/python/sglang/srt/lora/backend/base_backend.py +++ b/python/sglang/srt/lora/backend/base_backend.py @@ -44,6 +44,29 @@ def run_lora_a_embedding( result with shape (s, rank) """ pass + + def run_extra_token_embedding( + self, + input_ids: torch.Tensor, + output: torch.Tensor, + extra_embeddings: torch.Tensor, + vocab_size: int, + *args, + **kwargs, + ) -> torch.Tensor: + """ + Apply extra token embeddings to output in-place. + + Args: + input_ids: (s,) token IDs + output: (s, embed_dim) output tensor to be modified + extra_embeddings: (num_loras, num_extra_tokens, embed_dim) extra embeddings + vocab_size: base vocabulary size + + Returns: + output: modified output tensor + """ + raise NotImplementedError ############################# ############################# ############################# diff --git a/python/sglang/srt/lora/backend/chunked_backend.py b/python/sglang/srt/lora/backend/chunked_backend.py index 1963143ea0ec..adeb9f435f78 100644 --- a/python/sglang/srt/lora/backend/chunked_backend.py +++ b/python/sglang/srt/lora/backend/chunked_backend.py @@ -8,6 +8,7 @@ #########cuda lora########### ############################# embedding_lora_a_fwd, + embedding_extra_tokens_fwd, ############################# ############################# ############################# @@ -40,9 +41,10 @@ def __init__( super().__init__(max_loras_per_batch, device) self.max_chunk_size = server_args.max_lora_chunk_size - ############################# - #########cuda lora########### - ############################# + + ############################## + ##########cuda lora########### + ############################## def run_lora_a_embedding( self, input_ids: torch.Tensor, @@ -64,9 +66,31 @@ def run_lora_a_embedding( vocab_size=vocab_size, extra_embeddings=extra_embeddings, ) - ############################# - ############################# - ############################# + + def run_extra_token_embedding( + self, + input_ids: torch.Tensor, + output: torch.Tensor, + extra_embeddings: torch.Tensor, + vocab_size: int, + *args, + **kwargs, + ) -> torch.Tensor: + """Run extra token embedding lookup. + + For chunked backend, we use the same triton kernel as triton backend + since embedding lookup doesn't benefit from chunking. + """ + return embedding_extra_tokens_fwd( + input_ids=input_ids, + output=output, + extra_embeddings=extra_embeddings, + batch_info=self.batch_info, + vocab_size=vocab_size, + ) + ############################## + ############################## + ############################## def run_lora_a_sgemm( self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs diff --git a/python/sglang/srt/lora/backend/triton_backend.py b/python/sglang/srt/lora/backend/triton_backend.py index cb7c2bc94956..bf90e61bdcd1 100644 --- a/python/sglang/srt/lora/backend/triton_backend.py +++ b/python/sglang/srt/lora/backend/triton_backend.py @@ -10,6 +10,7 @@ #########cuda lora########### ############################# embedding_lora_a_fwd, + embedding_extra_tokens_fwd, ############################# ############################# ############################# @@ -49,6 +50,24 @@ def run_lora_a_embedding( vocab_size=vocab_size, extra_embeddings=extra_embeddings, ) + + def run_extra_token_embedding( + self, + input_ids: torch.Tensor, + output: torch.Tensor, + extra_embeddings: torch.Tensor, + vocab_size: int, + *args, + **kwargs, + ) -> torch.Tensor: + """Run extra token embedding lookup using Triton kernel.""" + return embedding_extra_tokens_fwd( + input_ids=input_ids, + output=output, + extra_embeddings=extra_embeddings, + batch_info=self.batch_info, + vocab_size=self.vocab_size, + ) ############################# ############################# ############################# @@ -138,20 +157,14 @@ def init_cuda_graph_batch_info( seg_lens=torch.full( (max_bs_in_cuda_graph,), num_tokens_per_bs, dtype=torch.int32 ), - seg_indptr=torch.empty(max_bs_in_cuda_graph + 1, dtype=torch.int32), + seg_indptr=torch.zeros(max_bs_in_cuda_graph + 1, dtype=torch.int32), max_len=num_tokens_per_bs, weight_indices=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32), lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32), scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float), permutation=None, ) - ############################## - ##########cuda lora########### - ############################## - self.cuda_graph_batch_info.seg_indptr[0] = 0 - ############################## - ############################## - ############################## + # self.cuda_graph_batch_info.seg_indptr[0] = 0 # Initialize seg_indptr for CUDA graph as they remain constant # across batches. @@ -199,14 +212,7 @@ def prepare_lora_batch( seg_lens = ( forward_batch.extend_seq_lens if forward_batch.forward_mode.is_extend() - ############################## - ##########cuda lora########### - ############################## - # else torch.ones(bs, device=self.device) else torch.ones(bs, dtype=torch.int32, device=self.device) - ############################## - ############################## - ############################## ) seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device) seg_indptr[1:] = torch.cumsum(seg_lens, dim=0) @@ -222,14 +228,7 @@ def prepare_lora_batch( (bs,), dtype=torch.int32, device=self.device ), lora_ranks=torch.empty( - ############################## - ##########cuda lora########### - ############################## - # (self.max_loras_per_batch,), dtype=torch.int64, device=self.device (self.max_loras_per_batch,), dtype=torch.int32, device=self.device - ############################## - ############################## - ############################## ), scalings=torch.empty( (self.max_loras_per_batch,), dtype=torch.float, device=self.device diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index 1edb621b3561..69b6d7805c36 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -45,10 +45,15 @@ def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): pass + + ############################# #########emb lora############ ############################# +######org +###### +###### class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): """ Vocab parallel embedding layer with LoRA support (simplified for TP=1, no extra tokens). @@ -99,8 +104,28 @@ def apply_lora( Formula: output = base_output + lora_B @ lora_A_embedding(input_) """ - # Efficient embedding lookup for LoRA A (cannot call run_lora_a_sgemm since needing index lookup) + # Efficient embedding lookup for LoRA A (cannot call run_lora_a_sgemm since needing index lookup) lora_a_output = self.run_lora_a_embedding(input_, batch_info) + print("=====") + lora_a_output_noncuda = self.run_lora_a_embedding_no_cuda(input_, batch_info) + + # 比較兩個輸出 + print("🔍 Comparing CUDA vs Non-CUDA LoRA A outputs:") + print(f" Shape - CUDA: {lora_a_output.shape}, Non-CUDA: {lora_a_output_noncuda.shape}") + print(f" Device - CUDA: {lora_a_output.device}, Non-CUDA: {lora_a_output_noncuda.device}") + print(f" Dtype - CUDA: {lora_a_output.dtype}, Non-CUDA: {lora_a_output_noncuda.dtype}") + + # 計算差異 + abs_diff = torch.abs(lora_a_output - lora_a_output_noncuda) + rel_diff = abs_diff / (torch.abs(lora_a_output) + 1e-8) + + print(f"\n📊 Difference Statistics:") + print(f" Max absolute difference: {abs_diff.max().item():.6e}") + print(f" Mean absolute difference: {abs_diff.mean().item():.6e}") + print(f" Max relative difference: {rel_diff.max().item():.6e}") + print(f" Mean relative difference: {rel_diff.mean().item():.6e}") + + print("=====") # Apply LoRA B weights using backend lora_output = self.lora_backend.run_lora_b_sgemm( @@ -110,27 +135,42 @@ def apply_lora( ) return lora_output + + def run_lora_a_embedding_no_cuda( + self, input_: torch.Tensor, batch_info: LoRABatchInfo + ) -> torch.Tensor: + # non-cuda + """ + Apply LoRA A weights using efficient embedding lookup. + Maps tokens to their corresponding LoRA adapters internally. + """ + token_weight_indices = self._get_token_weight_indices(input_, batch_info) + lora_a_output = self._run_lora_a_embedding(input_, token_weight_indices) + + return lora_a_output + + def run_lora_a_embedding( self, input_: torch.Tensor, batch_info: LoRABatchInfo ) -> torch.Tensor: ############################# #########cuda lora########### ############################# - batch_info = self.lora_backend.batch_info + # batch_info = self.lora_backend.batch_info # if batch_info.use_cuda_graph: # cuda """ Apply LoRA A weights using efficient embedding lookup with CUDA graph support. Maps tokens to their corresponding LoRA adapters internally. + It also includes added/extra token processing. """ - # Use backend implementation which supports CUDA graph lora_a_output = self.lora_backend.run_lora_a_embedding( input_ids=input_, weights=self.embedding_A_buffer, vocab_size=self.vocab_size, extra_embeddings=self.new_embeddings_buffer if hasattr(self, 'new_embeddings_buffer') and self.new_embeddings_buffer is not None else None, ) - return lora_a_output + # else: # # non-cuda # """ @@ -155,8 +195,8 @@ def _get_token_weight_indices( current_pos = 0 for i in range(batch_info.bs): - seg_len = int(batch_info.seg_lens[i]) - weight_idx = int(batch_info.weight_indices[i]) + seg_len = batch_info.seg_lens[i] + weight_idx = batch_info.weight_indices[i] token_weight_indices[current_pos : current_pos+seg_len] = weight_idx current_pos += seg_len @@ -206,7 +246,9 @@ def forward(self, input_: torch.Tensor): ############################## ##########cuda lora########### ############################## + # batch_info = self.lora_backend.batch_info + # if batch_info.use_cuda_graph: """ Forward pass with LoRA support and CUDA graph compatibility. @@ -220,8 +262,18 @@ def forward(self, input_: torch.Tensor): # For tokens >= vocab_size, base_layer will clamp or handle them # We mask them to 0 to avoid out-of-bounds access added_tokens_mask = input_ > self.vocab_size - 1 - base_input = input_.masked_fill(added_tokens_mask, 0) - base_output = self.base_layer.forward(base_input) + base_output = self.base_layer.forward(input_.masked_fill(added_tokens_mask, 0)) + + # is there's extra tokens + if added_tokens_mask.any(): + base_output = self.extra_token_embedding( + input_, added_tokens_mask, batch_info, base_output.clone() + ) + + # base_output_noncuda = self._extra_token_embedding_no_cuda( + # input_, added_tokens_mask, batch_info, base_output.clone() + # ) + # Apply LoRA if configured if self.set_lora: @@ -229,32 +281,50 @@ def forward(self, input_: torch.Tensor): # and extra tokens efficiently with CUDA graph support output = self.apply_lora(base_output, input_, batch_info) else: - ## Optimized for CUDA graph compatibility - - # Support extra_token - if added_tokens_mask.any() and hasattr(self, 'new_embeddings_buffer') and self.new_embeddings_buffer is not None: - # Use backend even without LoRA to handle extra tokens with CUDA graph support - # The backend's run_lora_a_embedding can handle extra_embeddings directly - extra_output = self.lora_backend.run_lora_a_embedding( - input_ids=input_, - weights=torch.zeros( - (self.new_embeddings_buffer.shape[0], self.new_embeddings_buffer.shape[2], self.vocab_size), - device=input_.device, - dtype=self.new_embeddings_buffer.dtype - ), # Dummy LoRA weights (all zeros) - vocab_size=self.vocab_size, - extra_embeddings=self.new_embeddings_buffer, - ) - # Only use extra embeddings for tokens >= vocab_size - # For regular tokens, keep base_output; for extra tokens, use extra_output - output = torch.where( - added_tokens_mask.unsqueeze(-1), - extra_output, - base_output - ) - else: - # Do not have extra token - output = base_output + output = base_output + + return output + + def extra_token_embedding( + self, + input_: torch.Tensor, + added_tokens_mask: torch.Tensor, + batch_info: LoRABatchInfo, + base_output: torch.Tensor, + ) -> torch.Tensor: + """ + Apply extra token embeddings using efficient Triton kernel. + This method is CUDA graph compatible. + """ + # Use backend's efficient kernel (CUDA graph compatible) + base_output = self.lora_backend.run_extra_token_embedding( + input_ids=input_, + output=base_output, + extra_embeddings=self.new_embeddings_buffer, + vocab_size=self.vocab_size, + ) + return base_output + + + def _extra_token_embedding_no_cuda( + self, + input_: torch.Tensor, + added_tokens_mask: torch.Tensor, + batch_info: LoRABatchInfo, + base_output: torch.Tensor, + ) -> torch.Tensor: + token_weight_indices = self._get_token_weight_indices(input_, batch_info) + added_weight_indices = token_weight_indices[added_tokens_mask] + unique_added_weight_indices = torch.unique(added_weight_indices) + + for idx in unique_added_weight_indices: + lora_mask = added_weight_indices == idx + added_token_positions = torch.where(added_tokens_mask)[0][lora_mask] + x = input_[added_token_positions] - self.vocab_size + new_embeddings = F.embedding(x, self.new_embeddings_buffer[idx]) + base_output[added_token_positions] = new_embeddings + + return base_output ############################## ############################## ############################## @@ -283,7 +353,7 @@ def forward(self, input_: torch.Tensor): # else: # output = base_output - return output + # return output ############################## ############################## ############################## @@ -300,6 +370,9 @@ def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): # For TP>1, Need to modify code in: sglang/python/sglang/srt/lora/mem_pool.py return B + + + class ParallelLMHeadWithLoRA(BaseLayerWithLoRA): """ @@ -381,10 +454,10 @@ def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): # For TP>1, would slice along vocab dimension, eed to modify code in: sglang/python/sglang/srt/lora/mem_pool.py return B +############################# +############################# +############################# -############################## -############################## -############################## class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): @@ -550,6 +623,7 @@ def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor output_offset_cpu=self.output_offset_cpu, max_qkv_out_dim=self.max_qkv_out_dim, ) + return lora_output def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): diff --git a/python/sglang/srt/lora/triton_ops/__init__.py b/python/sglang/srt/lora/triton_ops/__init__.py index 81d45c16fe77..b2e74f1f99d4 100644 --- a/python/sglang/srt/lora/triton_ops/__init__.py +++ b/python/sglang/srt/lora/triton_ops/__init__.py @@ -8,6 +8,7 @@ #########cuda lora########### ############################# from .embedding_lora_a import embedding_lora_a_fwd +from .embedding_extra_tokens import embedding_extra_tokens_fwd ############################# ############################# ############################# @@ -23,6 +24,7 @@ #########cuda lora########### ############################# "embedding_lora_a_fwd", + "embedding_extra_tokens_fwd", ############################# ############################# ############################# diff --git a/python/sglang/srt/lora/triton_ops/embedding_extra_tokens.py b/python/sglang/srt/lora/triton_ops/embedding_extra_tokens.py new file mode 100644 index 000000000000..c0023b659272 --- /dev/null +++ b/python/sglang/srt/lora/triton_ops/embedding_extra_tokens.py @@ -0,0 +1,154 @@ +############################# +#########cuda graph########## +############################# +import torch +import triton +import triton.language as tl +from sglang.srt.lora.utils import LoRABatchInfo + + +@triton.jit +def _embedding_extra_tokens_kernel( + # Pointers to tensors + input_ids, + output, + extra_embeddings, + # Dimensions + vocab_size, + embed_dim, + num_loras, + # Strides + output_stride_0, + output_stride_1, + extra_emb_stride_0, # stride for lora index + extra_emb_stride_1, # stride for token + extra_emb_stride_2, # stride for embed dim + # Batch info + seg_lens, + seg_indptr, + weight_indices, + # Meta-parameters + BLOCK_EMBED: tl.constexpr, +): + """ + Embedding lookup for extra/added tokens (tokens >= vocab_size). + Each program handles one token across a block of embedding dimensions. + Grid: (max_len, bs) + """ + batch_id = tl.program_id(axis=1) + token_idx = tl.program_id(axis=0) + + w_index = tl.load(weight_indices + batch_id) + seg_start = tl.load(seg_indptr + batch_id) + seg_len = tl.load(seg_lens + batch_id) + + # Check if this token is within the segment + if token_idx >= seg_len: + return + + # Load the token ID + token_id = tl.load(input_ids + seg_start + token_idx) + + # Check if this is an extra token + is_extra_token = token_id >= vocab_size + + if not is_extra_token: + return # Skip non-extra tokens + + # Calculate extra token ID + extra_token_id = token_id - vocab_size + + # Process in chunks of BLOCK_EMBED dimensions + num_blocks = tl.cdiv(embed_dim, BLOCK_EMBED) + + for block_id in range(num_blocks): + embed_offset = tl.arange(0, BLOCK_EMBED) + block_id * BLOCK_EMBED + embed_mask = embed_offset < embed_dim + + # Load from extra embeddings + # extra_embeddings shape: (num_loras, num_extra_tokens, embed_dim) + extra_emb_ptr = ( + extra_embeddings + + w_index * extra_emb_stride_0 + + extra_token_id * extra_emb_stride_1 + + embed_offset * extra_emb_stride_2 + ) + emb_values = tl.load(extra_emb_ptr, mask=embed_mask, other=0.0) + + # Write to output (overwrite the position) + output_ptr = ( + output + + (seg_start + token_idx) * output_stride_0 + + embed_offset * output_stride_1 + ) + tl.store(output_ptr, emb_values, mask=embed_mask) + + +def embedding_extra_tokens_fwd( + input_ids: torch.Tensor, + output: torch.Tensor, # Will be modified in-place + extra_embeddings: torch.Tensor, + batch_info: LoRABatchInfo, + vocab_size: int, +) -> torch.Tensor: + """ + Forward pass for extra token embedding lookup (in-place operation). + + Args: + input_ids: (s,) token IDs + output: (s, embed_dim) output tensor to be modified in-place + extra_embeddings: (num_loras, num_extra_tokens, embed_dim) extra token embeddings + batch_info: LoRABatchInfo containing batch information + vocab_size: base vocabulary size + + Returns: + output: (s, embed_dim) modified output tensor + """ + assert input_ids.is_contiguous() + assert output.is_contiguous() + assert extra_embeddings.is_contiguous() + assert len(input_ids.shape) == 1 + assert len(output.shape) == 2 + assert len(extra_embeddings.shape) == 3 + + S = input_ids.shape[0] + embed_dim = output.shape[1] + num_loras = extra_embeddings.shape[0] + + # Block size for embedding dimension + BLOCK_EMBED = 128 + + extra_emb_stride = ( + extra_embeddings.stride(0), + extra_embeddings.stride(1), + extra_embeddings.stride(2), + ) + + # Grid: one program per token in each batch segment + grid = ( + batch_info.max_len, + batch_info.bs, + ) + + _embedding_extra_tokens_kernel[grid]( + input_ids, + output, + extra_embeddings, + vocab_size, + embed_dim, + num_loras, + output.stride(0), + output.stride(1), + extra_emb_stride[0], + extra_emb_stride[1], + extra_emb_stride[2], + batch_info.seg_lens, + batch_info.seg_indptr, + batch_info.weight_indices, + BLOCK_EMBED, + ) + + return output +############################# +############################# +############################# \ No newline at end of file diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index 6a73755bf263..b108f6114a36 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -435,13 +435,17 @@ def forward_generation_raw( # model = base_model # PR version + # current_tokenizer = tokenizer if lora_paths is not None and lora_paths[i] is not None: from peft import PeftConfig, PeftModel + from sglang.srt.lora.lora_config import LoRAConfig peft_config = PeftConfig.from_pretrained(lora_paths[i]) - if "embed_tokens" in peft_config.target_modules: - tok = get_tokenizer(lora_paths[i]) - base_model.resize_token_embeddings(len(tok)) + lora_config = LoRAConfig(lora_paths[i]) + if "embed_tokens" in peft_config.target_modules and lora_config.lora_added_tokens_size > 0: + new_tokenizer = get_tokenizer(lora_paths[i]) + base_model.resize_token_embeddings(len(new_tokenizer)) + tokenizer = new_tokenizer model = PeftModel.from_pretrained( base_model, @@ -473,6 +477,7 @@ def forward_generation_raw( text = tokenizer.decode( outputs[0][0][len(input_ids[0]) :], skip_special_tokens=True ) + # Check if the text is empty or only whitespace. if not text.strip(): raise ValueError( From 90415211eee8a28a316de262583d4d33fa615d10 Mon Sep 17 00:00:00 2001 From: Yueming Yuan Date: Thu, 4 Dec 2025 23:48:29 +0000 Subject: [PATCH 11/19] cleaned code ([to-do] 1. TP support in mem_pool.py and layer.py; 2. tokenizer - tokenizer_manager.py) --- python/sglang/srt/layers/logits_processor.py | 17 +- .../srt/lora/backend/chunked_backend.py | 4 +- .../sglang/srt/lora/backend/triton_backend.py | 6 +- python/sglang/srt/lora/layers.py | 292 ++---------------- python/sglang/srt/lora/lora.py | 16 - python/sglang/srt/lora/lora_config.py | 14 +- python/sglang/srt/lora/lora_manager.py | 54 ---- python/sglang/srt/lora/mem_pool.py | 216 +------------ python/sglang/srt/lora/triton_ops/__init__.py | 4 +- .../lora/triton_ops/embedding_extra_tokens.py | 2 +- python/sglang/srt/lora/utils.py | 26 -- python/sglang/test/runners.py | 21 -- 12 files changed, 40 insertions(+), 632 deletions(-) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 6d4ce5dfd504..ad48d6e1f242 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -812,26 +812,11 @@ def _get_logits( ) dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata) - ############################## - ##########emb lora############ - ############################## + if hasattr(lm_head, 'set_lora') and hasattr(lm_head, 'apply_lora'): # This is a LoRA-wrapped module, use its forward method logits = lm_head(hidden_states) - # logits = torch.matmul( - # hidden_states.to(torch.float32), lm_head.weight.to(torch.float32).T - # ) - # print("====") - # print(self.use_fp32_lm_head) - # print("====") - # exit() - - - # if hasattr(lm_head, "weight"): elif hasattr(lm_head, "weight"): - ############################## - ############################## - ############################## if self.use_fp32_lm_head: logits = torch.matmul( hidden_states.to(torch.float32), lm_head.weight.to(torch.float32).T diff --git a/python/sglang/srt/lora/backend/chunked_backend.py b/python/sglang/srt/lora/backend/chunked_backend.py index adeb9f435f78..22472d18e7dc 100644 --- a/python/sglang/srt/lora/backend/chunked_backend.py +++ b/python/sglang/srt/lora/backend/chunked_backend.py @@ -8,7 +8,7 @@ #########cuda lora########### ############################# embedding_lora_a_fwd, - embedding_extra_tokens_fwd, + embedding_extra_tokens_modified, ############################# ############################# ############################# @@ -81,7 +81,7 @@ def run_extra_token_embedding( For chunked backend, we use the same triton kernel as triton backend since embedding lookup doesn't benefit from chunking. """ - return embedding_extra_tokens_fwd( + return embedding_extra_tokens_modified( input_ids=input_ids, output=output, extra_embeddings=extra_embeddings, diff --git a/python/sglang/srt/lora/backend/triton_backend.py b/python/sglang/srt/lora/backend/triton_backend.py index bf90e61bdcd1..07d58417932e 100644 --- a/python/sglang/srt/lora/backend/triton_backend.py +++ b/python/sglang/srt/lora/backend/triton_backend.py @@ -10,7 +10,7 @@ #########cuda lora########### ############################# embedding_lora_a_fwd, - embedding_extra_tokens_fwd, + embedding_extra_tokens_modified, ############################# ############################# ############################# @@ -61,12 +61,12 @@ def run_extra_token_embedding( **kwargs, ) -> torch.Tensor: """Run extra token embedding lookup using Triton kernel.""" - return embedding_extra_tokens_fwd( + return embedding_extra_tokens_modified( input_ids=input_ids, output=output, extra_embeddings=extra_embeddings, batch_info=self.batch_info, - vocab_size=self.vocab_size, + vocab_size=vocab_size, ) ############################# ############################# diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index 69b6d7805c36..1f483a50b03c 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -46,14 +46,6 @@ def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): pass - -############################# -#########emb lora############ -############################# - -######org -###### -###### class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): """ Vocab parallel embedding layer with LoRA support (simplified for TP=1, no extra tokens). @@ -74,25 +66,13 @@ def __init__( def set_lora_info( self, - ############################## - ##########emb lora############ - ############################## new_embeddings_buffer: Optional[torch.Tensor], # For extra tokens - ############################## - ############################## - ############################## embedding_A_buffer: torch.Tensor, embedding_B_buffer: torch.Tensor, ): """Set LoRA buffers for embedding layer.""" self.set_lora = True - ############################## - ##########emb lora############ - ############################## self.new_embeddings_buffer = new_embeddings_buffer - ############################## - ############################## - ############################## self.embedding_A_buffer = embedding_A_buffer # (num_loras, rank, vocab_size) self.embedding_B_buffer = embedding_B_buffer # (num_loras, embed_dim, rank) @@ -106,26 +86,6 @@ def apply_lora( # Efficient embedding lookup for LoRA A (cannot call run_lora_a_sgemm since needing index lookup) lora_a_output = self.run_lora_a_embedding(input_, batch_info) - print("=====") - lora_a_output_noncuda = self.run_lora_a_embedding_no_cuda(input_, batch_info) - - # 比較兩個輸出 - print("🔍 Comparing CUDA vs Non-CUDA LoRA A outputs:") - print(f" Shape - CUDA: {lora_a_output.shape}, Non-CUDA: {lora_a_output_noncuda.shape}") - print(f" Device - CUDA: {lora_a_output.device}, Non-CUDA: {lora_a_output_noncuda.device}") - print(f" Dtype - CUDA: {lora_a_output.dtype}, Non-CUDA: {lora_a_output_noncuda.dtype}") - - # 計算差異 - abs_diff = torch.abs(lora_a_output - lora_a_output_noncuda) - rel_diff = abs_diff / (torch.abs(lora_a_output) + 1e-8) - - print(f"\n📊 Difference Statistics:") - print(f" Max absolute difference: {abs_diff.max().item():.6e}") - print(f" Mean absolute difference: {abs_diff.mean().item():.6e}") - print(f" Max relative difference: {rel_diff.max().item():.6e}") - print(f" Mean relative difference: {rel_diff.mean().item():.6e}") - - print("=====") # Apply LoRA B weights using backend lora_output = self.lora_backend.run_lora_b_sgemm( @@ -136,29 +96,9 @@ def apply_lora( return lora_output - def run_lora_a_embedding_no_cuda( - self, input_: torch.Tensor, batch_info: LoRABatchInfo - ) -> torch.Tensor: - # non-cuda - """ - Apply LoRA A weights using efficient embedding lookup. - Maps tokens to their corresponding LoRA adapters internally. - """ - token_weight_indices = self._get_token_weight_indices(input_, batch_info) - lora_a_output = self._run_lora_a_embedding(input_, token_weight_indices) - - return lora_a_output - - def run_lora_a_embedding( self, input_: torch.Tensor, batch_info: LoRABatchInfo ) -> torch.Tensor: - ############################# - #########cuda lora########### - ############################# - # batch_info = self.lora_backend.batch_info - # if batch_info.use_cuda_graph: - # cuda """ Apply LoRA A weights using efficient embedding lookup with CUDA graph support. Maps tokens to their corresponding LoRA adapters internally. @@ -171,85 +111,33 @@ def run_lora_a_embedding( extra_embeddings=self.new_embeddings_buffer if hasattr(self, 'new_embeddings_buffer') and self.new_embeddings_buffer is not None else None, ) - # else: - # # non-cuda - # """ - # Apply LoRA A weights using efficient embedding lookup. - # Maps tokens to their corresponding LoRA adapters internally. - # """ - # token_weight_indices = self._get_token_weight_indices(input_, batch_info) - # lora_a_output = self._run_lora_a_embedding(input_, token_weight_indices) - ############################# - ############################# - ############################# - return lora_a_output - - def _get_token_weight_indices( - self, input_: torch.Tensor, batch_info: LoRABatchInfo - ) -> torch.Tensor: - # (Step1) Get token-to-lora mapping - token_weight_indices = torch.zeros( - input_.shape[0], dtype=torch.int32, device=input_.device - ) - - current_pos = 0 - for i in range(batch_info.bs): - seg_len = batch_info.seg_lens[i] - weight_idx = batch_info.weight_indices[i] - token_weight_indices[current_pos : current_pos+seg_len] = weight_idx - current_pos += seg_len - - return token_weight_indices + - def _run_lora_a_embedding( - self, input_: torch.Tensor, token_weight_indices: torch.Tensor - ) -> torch.Tensor: - # (Step2) Apply embedding lookup for each LoRA adapter - lora_a_output = torch.zeros( - (input_.shape[0], self.embedding_A_buffer.shape[1]), - dtype=self.embedding_A_buffer.dtype, - device=input_.device, - ) + def extra_token_embedding(self, input_: torch.Tensor, base_output: torch.Tensor) -> torch.Tensor: + """ + Process extra tokens (tokens >= vocab_size) by looking up their embeddings + from the new_embeddings_buffer and replacing them in base_output. - unique_weight_indices = torch.unique(token_weight_indices) + Args: + input_: (s,) token IDs + base_output: (s, embed_dim) base embedding output to be modified in-place + + Returns: + base_output: (s, embed_dim) modified output with extra token embeddings + """ - for idx in unique_weight_indices: - token_mask = token_weight_indices == idx - lora_a_weights = self.embedding_A_buffer[idx] # (rank, vocab_size) - lora_a_output[token_mask] = F.embedding( - input_[token_mask], lora_a_weights.t() - ) + output_base_output = self.lora_backend.run_extra_token_embedding( + input_ids=input_, + output=base_output, + extra_embeddings=self.new_embeddings_buffer, + vocab_size=self.vocab_size, + ) - return lora_a_output + return output_base_output + def forward(self, input_: torch.Tensor): - ############################## - ##########emb lora############ - ############################## - # # Get base embedding output ( do not consider extra tokens) - # base_output = self.base_layer.forward(input_) - - # # Apply LoRA if configured - # if self.set_lora: - # batch_info = self.lora_backend.batch_info - # output = self.apply_lora(base_output, input_, batch_info) - # else: - # output = base_output - - # return output - - ############### - ############### consider both non-extra and extra tokens - ############### - - ############################## - ##########cuda lora########### - ############################## - - # batch_info = self.lora_backend.batch_info - - # if batch_info.use_cuda_graph: """ Forward pass with LoRA support and CUDA graph compatibility. @@ -264,17 +152,9 @@ def forward(self, input_: torch.Tensor): added_tokens_mask = input_ > self.vocab_size - 1 base_output = self.base_layer.forward(input_.masked_fill(added_tokens_mask, 0)) - # is there's extra tokens - if added_tokens_mask.any(): - base_output = self.extra_token_embedding( - input_, added_tokens_mask, batch_info, base_output.clone() - ) + # Extra tokens - It will replace extra token embedding with self.new_embeddings_buffer's emb (Default is 0) + base_output = self.extra_token_embedding(input_, base_output) - # base_output_noncuda = self._extra_token_embedding_no_cuda( - # input_, added_tokens_mask, batch_info, base_output.clone() - # ) - - # Apply LoRA if configured if self.set_lora: # The backend's run_lora_a_embedding now handles both regular @@ -285,78 +165,6 @@ def forward(self, input_: torch.Tensor): return output - def extra_token_embedding( - self, - input_: torch.Tensor, - added_tokens_mask: torch.Tensor, - batch_info: LoRABatchInfo, - base_output: torch.Tensor, - ) -> torch.Tensor: - """ - Apply extra token embeddings using efficient Triton kernel. - This method is CUDA graph compatible. - """ - # Use backend's efficient kernel (CUDA graph compatible) - base_output = self.lora_backend.run_extra_token_embedding( - input_ids=input_, - output=base_output, - extra_embeddings=self.new_embeddings_buffer, - vocab_size=self.vocab_size, - ) - return base_output - - - def _extra_token_embedding_no_cuda( - self, - input_: torch.Tensor, - added_tokens_mask: torch.Tensor, - batch_info: LoRABatchInfo, - base_output: torch.Tensor, - ) -> torch.Tensor: - token_weight_indices = self._get_token_weight_indices(input_, batch_info) - added_weight_indices = token_weight_indices[added_tokens_mask] - unique_added_weight_indices = torch.unique(added_weight_indices) - - for idx in unique_added_weight_indices: - lora_mask = added_weight_indices == idx - added_token_positions = torch.where(added_tokens_mask)[0][lora_mask] - x = input_[added_token_positions] - self.vocab_size - new_embeddings = F.embedding(x, self.new_embeddings_buffer[idx]) - base_output[added_token_positions] = new_embeddings - - return base_output - ############################## - ############################## - ############################## - # else: - # batch_info = self.lora_backend.batch_info - # # Handle added tokens (tokens beyond base vocabulary) - # added_tokens_mask = input_ > self.vocab_size - 1 - # base_output = self.base_layer.forward(input_.masked_fill(added_tokens_mask, 0)) - - # # Process extra tokens if they exist - # if added_tokens_mask.any(): - # token_weight_indices = self._get_token_weight_indices(input_, batch_info) - # added_weight_indices = token_weight_indices[added_tokens_mask] - # unique_added_weight_indices = torch.unique(added_weight_indices) - - # for idx in unique_added_weight_indices: - # lora_mask = added_weight_indices == idx - # added_token_positions = torch.where(added_tokens_mask)[0][lora_mask] - # x = input_[added_token_positions] - self.vocab_size - # new_embeddings = F.embedding(x, self.new_embeddings_buffer[idx]) - # base_output[added_token_positions] = new_embeddings - - # # Apply LoRA if set - # if self.set_lora: - # output = self.apply_lora(base_output, input_, batch_info) - # else: - # output = base_output - - # return output - ############################## - ############################## - ############################## def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): # For TP=1, no slicing needed @@ -422,11 +230,6 @@ def apply_lora( weights=self.lm_head_B_buffer, base_output=base_output, ) - - # lora_output = self.run_lora_b_scatter( - # lora_a_output=lora_a_output, - # base_output=base_output, - # ) return lora_output @@ -454,11 +257,6 @@ def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): # For TP>1, would slice along vocab dimension, eed to modify code in: sglang/python/sglang/srt/lora/mem_pool.py return B -############################# -############################# -############################# - - class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): def __init__( @@ -737,49 +535,6 @@ def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): return B -############################## -##########emb lora############ -############################## -# def get_lora_layer( -# layer: nn.Module, lora_backend: BaseLoRABackend -# ) -> BaseLayerWithLoRA: -# supported_layer_types = { -# # the order matters -# VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA, -# QKVParallelLinear: QKVParallelLinearWithLoRA, -# MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA, -# ColumnParallelLinear: ColumnParallelLinearWithLoRA, -# RowParallelLinear: RowParallelLinearWithLoRA, -# } -# for src_layer_type, lora_layer_type in supported_layer_types.items(): -# if isinstance(layer, src_layer_type): # pylint: disable=unidiomatic-typecheck -# ret = lora_layer_type(layer, lora_backend) -# return ret -# raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.") - - - -# def get_lora_layer( -# layer: nn.Module, lora_backend: BaseLoRABackend, lora_extra_vocab_size: int = 0 -# ) -> BaseLayerWithLoRA: - -# supported_layer_types = { -# # the order matters - check ParallelLMHead before VocabParallelEmbedding -# # since ParallelLMHead is a subclass of VocabParallelEmbedding -# ParallelLMHead: lambda l, b: ParallelLMHeadWithLoRA(l, b, lora_extra_vocab_size), -# VocabParallelEmbedding: lambda l, b: VocabParallelEmbeddingWithLoRA(l, b, lora_extra_vocab_size), -# QKVParallelLinear: lambda l, b: QKVParallelLinearWithLoRA(l, b), -# MergedColumnParallelLinear: lambda l, b: MergedColumnParallelLinearWithLoRA(l, b), -# ColumnParallelLinear: lambda l, b: ColumnParallelLinearWithLoRA(l, b), -# RowParallelLinear: lambda l, b: RowParallelLinearWithLoRA(l, b), -# } -# for src_layer_type, lora_layer_factory in supported_layer_types.items(): -# if isinstance(layer, src_layer_type): -# ret = lora_layer_factory(layer, lora_backend) -# return ret -# raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.") - - def get_lora_layer( layer: nn.Module, lora_backend: BaseLoRABackend @@ -798,6 +553,3 @@ def get_lora_layer( ret = lora_layer_type(layer, lora_backend) return ret raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.") -############################## -############################## -############################## diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py index 72c761aaefcd..f616664b4965 100644 --- a/python/sglang/srt/lora/lora.py +++ b/python/sglang/srt/lora/lora.py @@ -71,18 +71,8 @@ def __init__( ] ) - ############################## - ##########emb lora############ - ############################## - # self.embedding_layer = LoRALayer(config, base_hf_config) - # self.lm_head_layer = LoRALayer(config, base_hf_config) - # self.weights: Dict[str, torch.Tensor] = {} - # self.extra_embeddings: Dict[str, torch.Tensor] = {} self.embedding_layers: Dict[str, torch.Tensor] = {} self.added_tokens_embeddings: Dict[str, torch.Tensor] = {} - ############################## - ############################## - ############################## # initialize the LoRA weights to cpu def initialize_weights(self): @@ -97,9 +87,6 @@ def initialize_weights(self): layer_id = get_layer_id(name) if layer_id is not None: self.layers[layer_id].weights[name] = loaded_weight.cpu() - ############################## - ##########emb lora############ - ############################## elif "embed_tokens" in name or "lm_head" in name: # self.embedding_layers.weights[name] = loaded_weight.cpu() self.embedding_layers[name] = loaded_weight.cpu() @@ -110,9 +97,6 @@ def initialize_weights(self): f"LoRA adapter {self.uid} has extra_vocab_size {self.config.extra_vocab_size} specified in the config, " f"but the loaded weight has {loaded_weight.shape[0]} extra vocab size" ) - ############################## - ############################## - ############################## # normalize kv_proj and gate_up_proj for layer in self.layers: diff --git a/python/sglang/srt/lora/lora_config.py b/python/sglang/srt/lora/lora_config.py index 4f37c2907054..d28e630beae4 100644 --- a/python/sglang/srt/lora/lora_config.py +++ b/python/sglang/srt/lora/lora_config.py @@ -30,16 +30,10 @@ def __init__( self.r = self.hf_config["r"] self.lora_alpha = self.hf_config["lora_alpha"] - ############################## - ##########emb lora############ - ############################## self.added_tokens_config = self.get_added_tokens_config() self.lora_added_tokens_size = ( len(self.added_tokens_config) if self.added_tokens_config is not None else 0 ) - ############################## - ############################## - ############################## def get_lora_config(self, dummy=False): if dummy: @@ -53,9 +47,6 @@ def get_lora_config(self, dummy=False): with open(os.path.join(weights_dir, config_name), "r") as f: return json.load(f) - ############################## - ##########emb lora############ - ############################## def get_added_tokens_config(self): """Load added tokens from the LoRA adapter if the file exists.""" # Determine the weights directory @@ -80,7 +71,4 @@ def get_added_tokens_config(self): import logging logger = logging.getLogger(__name__) logger.warning(f"Failed to parse added_tokens.json: {e}") - return None - ############################## - ############################## - ############################## \ No newline at end of file + return None \ No newline at end of file diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 6b96db701b65..35b7cbe6dbf7 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -67,13 +67,6 @@ def __init__( target_modules: Optional[Iterable[str]] = None, lora_paths: Optional[List[LoRARef]] = None, server_args: Optional[ServerArgs] = None, - ############################## - ##########emb lora############ - ############################## - # lora_extra_vocab_size: int = 0, - ############################## - ############################## - ############################## ): self.base_model: torch.nn.Module = base_model self.base_hf_config: AutoConfig = base_hf_config @@ -83,16 +76,7 @@ def __init__( self.device: torch.device = next(self.base_model.parameters()).device self.tp_size: int = tp_size self.tp_rank: int = tp_rank - ############################## - ##########emb lora############ - ############################## - # self.lora_extra_vocab_size: int = lora_extra_vocab_size - # Will infer self.lora_extra_vocab_size in the later init_lora_shapes() if it find the value == None - # self.lora_added_vocab_size: Optional[int] = None self.lora_added_tokens_size: Optional[int] = None - ############################## - ############################## - ############################## # Store eviction policy from server args self.eviction_policy = server_args.lora_eviction_policy @@ -266,14 +250,8 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch): lora_adapters=self.loras, lora_modules=self.lora_modules, lora_refs=self.lora_refs.copy(), # copy snapshot of current lora_refs to avoid mutation during the batch preparation. - ############################## - ##########emb lora############ - ############################## lora_embed_tokens_module=self.embed_tokens_module, #merge into embedding or lora module lora_lm_head_module=self.lm_head_module, #merge into embedding or lora module - ############################## - ############################## - ############################## ) # set up batch info shared by all lora modules @@ -326,9 +304,6 @@ def update_lora_info(self): ), ) - ############################## - ##########emb lora############ - ############################## # Update embedding layer if present - gotta merge (refer to PR codebase) if self.embed_tokens_module is not None: self.embed_tokens_module.set_lora_info( @@ -343,9 +318,6 @@ def update_lora_info(self): self.memory_pool.get_embedding_tensor("lm_head", LoRAType.LORA_A), self.memory_pool.get_embedding_tensor("lm_head", LoRAType.LORA_B), ) - ############################## - ############################## - ############################## def init_state( self, @@ -441,9 +413,6 @@ def init_lora_shapes( default=0, ) - ############################# - #########emb lora############ - ############################# # Auto-infer self.lora_added_vocab_size from loaded LoRA configs # This happens automatically without requiring user input # if self.lora_added_vocab_size is None: @@ -457,9 +426,6 @@ def init_lora_shapes( f"self.lora_added_tokens_size={inferred_extra_vocab_size} from LoRA adapters." ) self.lora_added_tokens_size = inferred_extra_vocab_size - ############################# - ############################# - ############################# def load_lora_weights(self, lora_ref: LoRARef): """ @@ -487,14 +453,7 @@ def init_memory_pool(self): target_modules=self.target_modules, base_model=self.base_model, eviction_policy=self.eviction_policy, - ############################## - ##########emb lora############ - ############################## - # lora_added_vocab_size=self.lora_added_vocab_size, # check whether read from the config lora_added_tokens_size = self.lora_added_tokens_size - ############################## - ############################## - ############################## ) def set_lora_module(self, module_name, module): @@ -508,15 +467,8 @@ def init_lora_modules(self): {} for _ in range(self.base_hf_config.num_hidden_layers) ] - ############################## - ##########emb lora############ - ############################## self.embed_tokens_module: Optional[BaseLayerWithLoRA] = None self.lm_head_module: Optional[BaseLayerWithLoRA] = None - ############################## - ############################## - ############################## - for module_name, module in self.base_model.named_modules(): # TODO (lifuhuang): in the future, we should consider generalizing the @@ -529,9 +481,6 @@ def init_lora_modules(self): ) and not self.base_model.should_apply_lora(module_name): continue - ############################## - ##########emb lora############ - ############################## # Handle embed_tokens if "embed_tokens" in module_name and "embed_tokens" in self.target_modules: if isinstance(module, VocabParallelEmbedding) and not isinstance(module, BaseLayerWithLoRA): @@ -545,9 +494,6 @@ def init_lora_modules(self): lora_module = self.set_lora_module(module_name, module) self.lm_head_module = lora_module continue - ############################## - ############################## - ############################## # The module should be converted if it is included in target_names if module_name.split(".")[-1] in self.target_modules: diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index 54f3ce288c12..18e0ae692184 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -11,13 +11,7 @@ from sglang.srt.lora.lora_registry import LoRARef from sglang.srt.lora.utils import ( ROW_PARALLELISM_LINEAR_LORA_NAMES, - ############################## - ##########emb lora############ - ############################## EMBEDDING_NAMES, - ############################## - ############################## - ############################## LoRAType, get_hidden_dim, get_normalized_target_modules, @@ -63,14 +57,7 @@ def __init__( target_modules: Set[str], base_model: torch.nn.Module, eviction_policy: str, - ############################## - ##########emb lora############ - ############################## - # lora_added_vocab_size: int, #can be remove? lora_added_tokens_size: int - ############################## - ############################## - ############################## ): self.base_hf_config: AutoConfig = base_hf_config self.num_layer: int = base_hf_config.num_hidden_layers @@ -78,15 +65,7 @@ def __init__( self.dtype: torch.dtype = dtype self.tp_size: int = tp_size self.tp_rank: int = tp_rank - ############################## - ##########emb lora############ - ############################## - # self.max_extra_vocab_size: int = lora_added_vocab_size self.lora_added_tokens_size: int = lora_added_tokens_size - # self.extra_vocab_size: int = base_hf_config.extra_vocab_size - ############################## - ############################## - ############################## self.max_lora_rank: int = max_lora_rank self.target_modules: Set[str] = target_modules @@ -101,28 +80,14 @@ def __init__( self.A_buffer: Dict[str, List[torch.Tensor]] = {} self.B_buffer: Dict[str, List[torch.Tensor]] = {} - ############################## - ##########emb lora############ - ############################## - # NEW: Buffers for embedding and lm_head (not per-layer) - # self.embedding_A_buffer: Optional[torch.Tensor] = None - # self.embedding_B_buffer: Optional[torch.Tensor] = None - # self.lm_head_A_buffer: Optional[torch.Tensor] = None - # self.lm_head_B_buffer: Optional[torch.Tensor] = None - # self.new_embeddings_buffer: Dict[str, torch.Tensor] = {} self.embedding_A_buffer: Dict[str, torch.Tensor] = {} self.embedding_B_buffer: Dict[str, torch.Tensor] = {} self.lm_head_A_buffer: Dict[str, torch.Tensor] = {} self.lm_head_B_buffer: Dict[str, torch.Tensor] = {} - # self.embedding_A_buffer: Dict[str, List[torch.Tensor]] = {} - # self.embedding_B_buffer: Dict[str, List[torch.Tensor]] = {} self.new_embeddings_buffer: Dict[str, torch.Tensor] = {} self.embedding_dim: int = self.base_hf_config.hidden_size - ############################## - ############################## - ############################## # Lora uid -> buffer idx in memory pool self.uid_to_buffer_id: Dict[Optional[str], int] = {} @@ -147,14 +112,8 @@ def _can_support(config: LoRAConfig) -> bool: """ if config.r > self.max_lora_rank: return False - ############################## - ##########emb lora############ - ############################## if config.lora_added_tokens_size > self.lora_added_tokens_size: - return False # can be remove? - ############################## - ############################## - ############################## + return False target_module_names = get_normalized_target_modules(config.target_modules) return target_module_names.issubset(self.target_modules) @@ -248,9 +207,6 @@ def init_buffer( target_modules: Set[str], get_lora_shape_fn: Callable[[str, torch.nn.Module, int, int], Tuple[int]], ): - ############################## - ##########emb lora############ - ############################## target_modules = target_modules - set(EMBEDDING_NAMES) for module_name in target_modules: buffer[module_name] = [ @@ -266,14 +222,8 @@ def init_buffer( ) for idx in range(self.num_layer) ] - ############################## - ############################## - ############################## - ############################## - ##########emb lora############ - ############################## def init_embedding_buffer( buffer: Dict[str, torch.Tensor], target_modules: Set[str], @@ -329,9 +279,6 @@ def init_embedding_buffer( self.get_embedding_lora_B_shape, ) - ############################## - ############################## - ############################## init_buffer( self.A_buffer, @@ -344,50 +291,7 @@ def init_embedding_buffer( self.target_modules, self.get_lora_B_shape, ) - - ############################## - ##########emb lora############ - ############################## - # # Initialize embedding buffers if embed_tokens is in target_modules - # if "embed_tokens" in self.target_modules: - # vocab_size = self.base_hf_config.vocab_size - # hidden_size = self.base_hf_config.hidden_size - - # # embedding_A: (max_loras_per_batch, max_rank, vocab_size) - # self.embedding_A_buffer = torch.empty( - # (self.max_loras_per_batch, self.max_lora_rank, vocab_size), - # dtype=self.dtype, - # device=device, - # ) - - # # embedding_B: (max_loras_per_batch, hidden_size, max_rank) - # self.embedding_B_buffer = torch.empty( - # (self.max_loras_per_batch, hidden_size, self.max_lora_rank), - # dtype=self.dtype, - # device=device, - # ) - - # # Initialize lm_head buffers if lm_head is in target_modules - # if "lm_head" in self.target_modules: - # vocab_size = self.base_hf_config.vocab_size - # hidden_size = self.base_hf_config.hidden_size - - # # lm_head_A: (max_loras_per_batch, max_rank, hidden_size) - # self.lm_head_A_buffer = torch.empty( - # (self.max_loras_per_batch, self.max_lora_rank, hidden_size), - # dtype=self.dtype, - # device=device, - # ) - - # # lm_head_B: (max_loras_per_batch, vocab_size, max_rank) - # self.lm_head_B_buffer = torch.empty( - # (self.max_loras_per_batch, vocab_size, self.max_lora_rank), - # dtype=self.dtype, - # device=device, - # ) - ############################## - ############################## - ############################## + def prepare_lora_batch( self, @@ -395,15 +299,8 @@ def prepare_lora_batch( lora_adapters: Dict[str, LoRAAdapter], lora_modules: List[Dict[str, BaseLayerWithLoRA]], lora_refs: Dict[str, LoRARef], - ############################## - ##########emb lora############ - ############################## - # lora_embeddings_modules: Dict[str, BaseLayerWithLoRA], # NEW parameter - lora_embed_tokens_module: Dict[str, BaseLayerWithLoRA], # NEW parameter - lora_lm_head_module: Dict[str, BaseLayerWithLoRA], # NEW parameter - ############################## - ############################## - ############################## + lora_embed_tokens_module: Dict[str, BaseLayerWithLoRA], + lora_lm_head_module: Dict[str, BaseLayerWithLoRA], ): def get_available_buffer_slot(): # 1. Prioritize empty slots @@ -455,18 +352,9 @@ def get_available_buffer_slot(): if uid not in self.uid_to_buffer_id: buffer_id = get_available_buffer_slot() lora_adapter = lora_adapters.get(uid, None) - ############################## - ##########emb lora############ - ############################## - # self.load_lora_weight_to_buffer( - # uid, buffer_id, lora_adapter, lora_modules - # ) self.load_lora_weight_to_buffer( uid, buffer_id, lora_adapter, lora_modules, lora_embed_tokens_module, lora_lm_head_module ) - ############################## - ############################## - ############################## self.uid_to_buffer_id[uid] = buffer_id self.buffer_id_to_uid[buffer_id] = uid @@ -476,16 +364,8 @@ def load_lora_weight_to_buffer( buffer_id: int, lora_adapter: LoRAAdapter, lora_modules: List[Dict[str, BaseLayerWithLoRA]], - ############################## - ##########emb lora############ - ############################## - # lora_embeddings_modules: List[Dict[str, BaseLayerWithLoRA]], - # I can combine the below two - lora_embed_tokens_module: Dict[str, BaseLayerWithLoRA], # NEW parameter - lora_lm_head_module: Dict[str, BaseLayerWithLoRA], # NEW parameter - ############################## - ############################## - ############################## + lora_embed_tokens_module: Dict[str, BaseLayerWithLoRA], + lora_lm_head_module: Dict[str, BaseLayerWithLoRA], ): def load_lora_weight_tensor( buffer_view: torch.Tensor, weight: Optional[torch.Tensor] @@ -504,20 +384,12 @@ def load_lora_weight_tensor( for i in range(self.num_layer): for k in self.A_buffer.keys(): self.A_buffer[k][i][buffer_id] = 0 - ############################## - ##########emb lora############ - ############################## - # for k in self.embedding_A_buffer.keys(): - # self.embedding_A_buffer[k][buffer_id] = 0 for k in self.embedding_A_buffer.keys(): self.embedding_A_buffer[k][buffer_id] = 0 for k in self.lm_head_A_buffer.keys(): self.lm_head_A_buffer[k][buffer_id] = 0 - ############################## - ############################## - ############################## return assert lora_adapter is not None @@ -566,59 +438,7 @@ def load_lora_weight_tensor( buffer_view = target_buffer[buffer_id, :, :lora_rank] load_lora_weight_tensor(buffer_view, weights) - ############################## - ##########emb lora############ - ############################## - - # Handle embedding weights (not per-layer) - # if "embed_tokens" in self.target_modules: - # embedding_A = None - # embedding_B = None - - # # Look for embedding weights in layer 0 (embeddings are usually stored there) - # # if lora_adapter.layers: - # if hasattr(lora_adapter, 'embedding_layer'): - # # layer_weights = lora_adapter.layers[0].weights - # layer_weights = lora_adapter.embedding_layer.weights - # for name, weights in layer_weights.items(): - # if "embed_tokens" in name or "model.embed_tokens" in name: - # if "lora_A" in name: - # embedding_A = weights - # elif "lora_B" in name: - # embedding_B = weights - - # # Load into buffers - # buffer_view = self.embedding_A_buffer[buffer_id, :lora_rank, :] - # load_lora_weight_tensor(buffer_view, embedding_A) - - # buffer_view = self.embedding_B_buffer[buffer_id, :, :lora_rank] - # load_lora_weight_tensor(buffer_view, embedding_B) - - # # Handle lm_head weights (not per-layer) - # if "lm_head" in self.target_modules: - # lm_head_A = None - # lm_head_B = None - - # # Look for lm_head weights - # # if lora_adapter.layers: - # if hasattr(lora_adapter, 'lm_head_layer'): - # # layer_weights = lora_adapter.layers[0].weights - # layer_weights = lora_adapter.lm_head_layer.weights - # for name, weights in layer_weights.items(): - # if "lm_head" in name: - # if "lora_A" in name: - # lm_head_A = weights - # elif "lora_B" in name: - # lm_head_B = weights - - # # Load into buffers - # buffer_view = self.lm_head_A_buffer[buffer_id, :lora_rank, :] - # load_lora_weight_tensor(buffer_view, lm_head_A) - - # buffer_view = self.lm_head_B_buffer[buffer_id, :, :lora_rank] - # load_lora_weight_tensor(buffer_view, lm_head_B) - - # embed_token (and extra_token emb) and lm_head layers + if lora_adapter.embedding_layers: org_vocab_size = self.base_hf_config.vocab_size @@ -636,14 +456,11 @@ def load_lora_weight_tensor( #load vocab_emb and lm_head for name, weights in lora_adapter.embedding_layers.items(): target_module = get_target_module_name(name, self.target_modules) - # if "lora_embedding_A" in name: - # if "lora_embedding_A" in name or ("lora_A" in name and target_module == "embed_tokens"): if target_module == "embed_tokens" and "embed_tokens" in name and ("lora_embedding_A" in name or "lora_A" in name): buffer_view = self.embedding_A_buffer[target_module][ buffer_id, :lora_rank, :(org_vocab_size+lora_added_tokens_size) ] load_lora_weight_tensor(buffer_view, weights) - # elif "lora_embedding_B" in name: elif target_module == "embed_tokens" and "embed_tokens" in name and ("lora_embedding_B" in name or "lora_B" in name): lora_b_weights = weights #[to-do] support TP @@ -659,19 +476,12 @@ def load_lora_weight_tensor( ] load_lora_weight_tensor(buffer_view, lora_b_weights) - - # name: base_model.model.lm_head.lora_A.weight - # self.target_modules: {'qkv_proj', 'embed_tokens', 'gate_up_proj', 'o_proj', 'lm_head', 'down_proj'} - # target_module: lm_head - # if "lora_lm_head_A" in name or ("lora_A" in name and target_module == "lm_head"): elif target_module == "lm_head" and "lm_head" in name and ("lora_embedding_A" in name or "lora_A" in name): buffer_view = self.lm_head_A_buffer[target_module][ # buffer_id, :, :lora_rank buffer_id, :lora_rank, : ] load_lora_weight_tensor(buffer_view, weights) - # elif "lora_embedding_B" in name: - # elif "lora_lm_head_B" in name or ("lora_B" in name and target_module == "lm_head"): elif target_module == "lm_head" and "lm_head" in name and ("lora_embedding_B" in name or "lora_B" in name): lora_b_weights = weights #[to-do] support TP @@ -688,13 +498,7 @@ def load_lora_weight_tensor( ] load_lora_weight_tensor(buffer_view, lora_b_weights) - ############################## - ############################## - ############################## - - ############################## - ##########emb lora############ - ############################## + def get_embedding_tensor( self, target_module: str, lora_type: LoRAType ) -> Optional[torch.Tensor]: @@ -726,10 +530,6 @@ def get_embedding_tensor( f"Invalid target_module '{target_module}'. " f"Expected 'embed_tokens' or 'lm_head'." ) - ############################## - ############################## - ############################## - def get_tensor( self, target_module: str, layer_id: int, lora_type: LoRAType diff --git a/python/sglang/srt/lora/triton_ops/__init__.py b/python/sglang/srt/lora/triton_ops/__init__.py index b2e74f1f99d4..5927bee9074a 100644 --- a/python/sglang/srt/lora/triton_ops/__init__.py +++ b/python/sglang/srt/lora/triton_ops/__init__.py @@ -8,7 +8,7 @@ #########cuda lora########### ############################# from .embedding_lora_a import embedding_lora_a_fwd -from .embedding_extra_tokens import embedding_extra_tokens_fwd +from .embedding_extra_tokens import embedding_extra_tokens_modified ############################# ############################# ############################# @@ -24,7 +24,7 @@ #########cuda lora########### ############################# "embedding_lora_a_fwd", - "embedding_extra_tokens_fwd", + "embedding_extra_tokens_modified", ############################# ############################# ############################# diff --git a/python/sglang/srt/lora/triton_ops/embedding_extra_tokens.py b/python/sglang/srt/lora/triton_ops/embedding_extra_tokens.py index c0023b659272..878207d8cf5d 100644 --- a/python/sglang/srt/lora/triton_ops/embedding_extra_tokens.py +++ b/python/sglang/srt/lora/triton_ops/embedding_extra_tokens.py @@ -84,7 +84,7 @@ def _embedding_extra_tokens_kernel( tl.store(output_ptr, emb_values, mask=embed_mask) -def embedding_extra_tokens_fwd( +def embedding_extra_tokens_modified( input_ids: torch.Tensor, output: torch.Tensor, # Will be modified in-place extra_embeddings: torch.Tensor, diff --git a/python/sglang/srt/lora/utils.py b/python/sglang/srt/lora/utils.py index cd9b0623dabd..1a091ab4e6c0 100644 --- a/python/sglang/srt/lora/utils.py +++ b/python/sglang/srt/lora/utils.py @@ -78,28 +78,14 @@ def get_hidden_dim( return config.hidden_size, config.intermediate_size * 2 elif module_name == "down_proj": return config.intermediate_size, config.hidden_size - - ############################## - ##########emb lora############ - ############################## - #Handle embed_tokens - # elif "embed_tokens" in module_name: - # elif "embed_tokens" in module_name: elif module_name == "embed_tokens": # For embedding: input is vocab_size (as embedding lookup), output is hidden_size # if contain extra tokens will be added; otherwise is 0. return config.vocab_size + lora_added_vocab_size, config.hidden_size - - #Handle lm_head - # elif "lm_head" in module_name: - # elif "lm_head" in module_name: elif module_name == "lm_head": # For lm_head: input is hidden_size, output is vocab_size # if contain extra tokens will be added; otherwise is 0. return config.hidden_size, config.vocab_size + lora_added_vocab_size - ############################## - ############################## - ############################## else: raise NotImplementedError() @@ -117,18 +103,12 @@ def get_normalized_target_modules( "v_proj": "qkv_proj", "gate_proj": "gate_up_proj", "up_proj": "gate_up_proj", - ############################## - ##########emb lora############ - ############################## "embed_tokens": "embed_tokens", "vocab_emb": "embed_tokens", "embeddings": "embed_tokens", "word_embeddings": "embed_tokens", "lm_head": "lm_head", "output": "lm_head", - ############################## - ############################## - ############################## } result = set() @@ -164,11 +144,5 @@ def get_target_module_name(full_module_name: str, target_modules: Set[str]) -> s f"Cannot find target module name for {full_module_name} in {target_modules}" ) -############################## -##########emb lora############ -############################## EMBEDDING_NAMES = ["embed_tokens", "lm_head"] -############################## -############################## -############################## ROW_PARALLELISM_LINEAR_LORA_NAMES = ["o_proj", "down_proj"] \ No newline at end of file diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index b108f6114a36..653eb5cfcacf 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -418,24 +418,6 @@ def forward_generation_raw( else: input_ids = torch.tensor([p], device="cuda") - ############################## - ##########emb lora############ - ############################## - # # (original) - # if lora_paths is not None and lora_paths[i] is not None: - # from peft import PeftModel - - # model = PeftModel.from_pretrained( - # base_model, - # lora_paths[i], - # torch_dtype=torch_dtype, - # is_trainable=False, - # ) - # else: - # model = base_model - - # PR version - # current_tokenizer = tokenizer if lora_paths is not None and lora_paths[i] is not None: from peft import PeftConfig, PeftModel from sglang.srt.lora.lora_config import LoRAConfig @@ -455,9 +437,6 @@ def forward_generation_raw( ) else: model = base_model - ############################## - ############################## - ############################## if patch_model_do_sample_false: model.generation_config.do_sample = False outputs = model.generate( From 8119daf67916d041b16b5b586da330e02764f2f6 Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Fri, 5 Dec 2025 01:23:47 +0000 Subject: [PATCH 12/19] update --- python/sglang/srt/layers/logits_processor.py | 3 +- .../sglang/srt/lora/backend/base_backend.py | 11 +-- .../srt/lora/backend/chunked_backend.py | 18 ++-- .../sglang/srt/lora/backend/triton_backend.py | 23 ++--- python/sglang/srt/lora/layers.py | 70 ++++++++-------- python/sglang/srt/lora/lora.py | 2 +- python/sglang/srt/lora/lora_config.py | 9 +- python/sglang/srt/lora/lora_manager.py | 46 ++++++---- python/sglang/srt/lora/mem_pool.py | 84 ++++++++++++------- python/sglang/srt/lora/triton_ops/__init__.py | 14 ++-- .../lora/triton_ops/embedding_extra_tokens.py | 41 ++++----- .../srt/lora/triton_ops/embedding_lora_a.py | 58 ++++++------- python/sglang/srt/lora/utils.py | 9 +- python/sglang/test/runners.py | 8 +- 14 files changed, 218 insertions(+), 178 deletions(-) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index bb91454cf640..957cb024831d 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -829,8 +829,7 @@ def _get_logits( ) dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata) - - if hasattr(lm_head, 'set_lora') and hasattr(lm_head, 'apply_lora'): + if hasattr(lm_head, "set_lora") and hasattr(lm_head, "apply_lora"): # This is a LoRA-wrapped module, use its forward method logits = lm_head(hidden_states) elif hasattr(lm_head, "weight"): diff --git a/python/sglang/srt/lora/backend/base_backend.py b/python/sglang/srt/lora/backend/base_backend.py index 9c72c9d755f5..d041b988c880 100644 --- a/python/sglang/srt/lora/backend/base_backend.py +++ b/python/sglang/srt/lora/backend/base_backend.py @@ -18,7 +18,7 @@ class BaseLoRABackend: def __init__(self, max_loras_per_batch: int, device: torch.device): self.max_loras_per_batch = max_loras_per_batch self.device = device - + ############################# #########cuda lora########### ############################# @@ -32,14 +32,14 @@ def run_lora_a_embedding( **kwargs, ) -> torch.Tensor: """Run LoRA A embedding lookup with CUDA graph support. - + Args: input_ids: token IDs with shape (s,), where s is the sum of all sequence lengths weights: LoRA A embedding weights with shape (num_loras, rank, vocab_size) vocab_size: base vocabulary size (tokens >= vocab_size are extra tokens) extra_embeddings: extra token embeddings with shape (num_loras, num_extra_tokens, rank) Only needed if there are added tokens beyond base vocabulary. - + Returns: result with shape (s, rank) """ @@ -56,17 +56,18 @@ def run_extra_token_embedding( ) -> torch.Tensor: """ Apply extra token embeddings to output in-place. - + Args: input_ids: (s,) token IDs output: (s, embed_dim) output tensor to be modified extra_embeddings: (num_loras, num_extra_tokens, embed_dim) extra embeddings vocab_size: base vocabulary size - + Returns: output: modified output tensor """ raise NotImplementedError + ############################# ############################# ############################# diff --git a/python/sglang/srt/lora/backend/chunked_backend.py b/python/sglang/srt/lora/backend/chunked_backend.py index 22472d18e7dc..1d0f702806d3 100644 --- a/python/sglang/srt/lora/backend/chunked_backend.py +++ b/python/sglang/srt/lora/backend/chunked_backend.py @@ -1,17 +1,11 @@ import torch from sglang.srt.lora.backend.base_backend import BaseLoRABackend -from sglang.srt.lora.triton_ops import ( +from sglang.srt.lora.triton_ops import ( # ############################; ########cuda lora########### chunked_sgmv_lora_expand_forward, chunked_sgmv_lora_shrink_forward, - ############################# - #########cuda lora########### - ############################# - embedding_lora_a_fwd, embedding_extra_tokens_modified, - ############################# - ############################# - ############################# + embedding_lora_a_fwd, ) from sglang.srt.lora.utils import LoRABatchInfo from sglang.srt.model_executor.forward_batch_info import ForwardBatch @@ -41,7 +35,6 @@ def __init__( super().__init__(max_loras_per_batch, device) self.max_chunk_size = server_args.max_lora_chunk_size - ############################## ##########cuda lora########### ############################## @@ -55,7 +48,7 @@ def run_lora_a_embedding( **kwargs, ) -> torch.Tensor: """Run LoRA A embedding lookup. - + For chunked backend, we use the same triton kernel as triton backend since embedding lookup doesn't benefit much from chunking. """ @@ -77,7 +70,7 @@ def run_extra_token_embedding( **kwargs, ) -> torch.Tensor: """Run extra token embedding lookup. - + For chunked backend, we use the same triton kernel as triton backend since embedding lookup doesn't benefit from chunking. """ @@ -88,10 +81,11 @@ def run_extra_token_embedding( batch_info=self.batch_info, vocab_size=vocab_size, ) + ############################## ############################## ############################## - + def run_lora_a_sgemm( self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs ) -> torch.Tensor: diff --git a/python/sglang/srt/lora/backend/triton_backend.py b/python/sglang/srt/lora/backend/triton_backend.py index 07d58417932e..ea9ccc7efd9b 100644 --- a/python/sglang/srt/lora/backend/triton_backend.py +++ b/python/sglang/srt/lora/backend/triton_backend.py @@ -1,19 +1,13 @@ import torch from sglang.srt.lora.backend.base_backend import BaseLoRABackend -from sglang.srt.lora.triton_ops import ( +from sglang.srt.lora.triton_ops import ( # ############################; ########cuda lora########### + embedding_extra_tokens_modified, + embedding_lora_a_fwd, gate_up_lora_b_fwd, qkv_lora_b_fwd, sgemm_lora_a_fwd, sgemm_lora_b_fwd, - ############################# - #########cuda lora########### - ############################# - embedding_lora_a_fwd, - embedding_extra_tokens_modified, - ############################# - ############################# - ############################# ) from sglang.srt.lora.utils import LoRABatchInfo from sglang.srt.model_executor.forward_batch_info import ForwardBatch @@ -34,13 +28,13 @@ def __init__( #########cuda lora########### ############################# def run_lora_a_embedding( - self, + self, input_ids: torch.Tensor, weights: torch.Tensor, vocab_size: int, extra_embeddings: torch.Tensor = None, - *args, - **kwargs + *args, + **kwargs, ) -> torch.Tensor: """Run LoRA A embedding lookup using Triton kernel.""" return embedding_lora_a_fwd( @@ -50,7 +44,7 @@ def run_lora_a_embedding( vocab_size=vocab_size, extra_embeddings=extra_embeddings, ) - + def run_extra_token_embedding( self, input_ids: torch.Tensor, @@ -68,10 +62,11 @@ def run_extra_token_embedding( batch_info=self.batch_info, vocab_size=vocab_size, ) + ############################# ############################# ############################# - + def run_lora_a_sgemm( self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs ) -> torch.Tensor: diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index 1f483a50b03c..45e320393a2a 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -1,7 +1,8 @@ +from typing import Optional + import torch -from torch import nn import torch.nn.functional as F -from typing import Optional +from torch import nn from sglang.srt.distributed import ( get_tensor_model_parallel_rank, @@ -15,7 +16,10 @@ QKVParallelLinear, RowParallelLinear, ) -from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) from sglang.srt.lora.backend.base_backend import BaseLoRABackend from sglang.srt.lora.utils import LoRABatchInfo @@ -49,7 +53,7 @@ def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): """ Vocab parallel embedding layer with LoRA support (simplified for TP=1, no extra tokens). - + For embedding layers: output = base_embedding(x) + lora_B @ lora_A[x] where lora_A[x] is direct embedding lookup from lora_A weights. """ @@ -66,7 +70,7 @@ def __init__( def set_lora_info( self, - new_embeddings_buffer: Optional[torch.Tensor], # For extra tokens + new_embeddings_buffer: Optional[torch.Tensor], # For extra tokens embedding_A_buffer: torch.Tensor, embedding_B_buffer: torch.Tensor, ): @@ -86,7 +90,7 @@ def apply_lora( # Efficient embedding lookup for LoRA A (cannot call run_lora_a_sgemm since needing index lookup) lora_a_output = self.run_lora_a_embedding(input_, batch_info) - + # Apply LoRA B weights using backend lora_output = self.lora_backend.run_lora_b_sgemm( x=lora_a_output, @@ -95,38 +99,43 @@ def apply_lora( ) return lora_output - def run_lora_a_embedding( self, input_: torch.Tensor, batch_info: LoRABatchInfo ) -> torch.Tensor: """ Apply LoRA A weights using efficient embedding lookup with CUDA graph support. Maps tokens to their corresponding LoRA adapters internally. - It also includes added/extra token processing. + It also includes added/extra token processing. """ lora_a_output = self.lora_backend.run_lora_a_embedding( input_ids=input_, weights=self.embedding_A_buffer, vocab_size=self.vocab_size, - extra_embeddings=self.new_embeddings_buffer if hasattr(self, 'new_embeddings_buffer') and self.new_embeddings_buffer is not None else None, + extra_embeddings=( + self.new_embeddings_buffer + if hasattr(self, "new_embeddings_buffer") + and self.new_embeddings_buffer is not None + else None + ), ) - return lora_a_output + return lora_a_output - - def extra_token_embedding(self, input_: torch.Tensor, base_output: torch.Tensor) -> torch.Tensor: + def extra_token_embedding( + self, input_: torch.Tensor, base_output: torch.Tensor + ) -> torch.Tensor: """ Process extra tokens (tokens >= vocab_size) by looking up their embeddings from the new_embeddings_buffer and replacing them in base_output. - + Args: input_: (s,) token IDs base_output: (s, embed_dim) base embedding output to be modified in-place - + Returns: base_output: (s, embed_dim) modified output with extra token embeddings """ - + output_base_output = self.lora_backend.run_extra_token_embedding( input_ids=input_, output=base_output, @@ -134,25 +143,24 @@ def extra_token_embedding(self, input_: torch.Tensor, base_output: torch.Tensor) vocab_size=self.vocab_size, ) - return output_base_output - + return output_base_output def forward(self, input_: torch.Tensor): """ Forward pass with LoRA support and CUDA graph compatibility. - + Extra tokens (tokens >= vocab_size) are now handled efficiently in the backend's run_lora_a_embedding method. """ batch_info = self.lora_backend.batch_info - + # Get base embedding output # For tokens >= vocab_size, base_layer will clamp or handle them # We mask them to 0 to avoid out-of-bounds access added_tokens_mask = input_ > self.vocab_size - 1 base_output = self.base_layer.forward(input_.masked_fill(added_tokens_mask, 0)) - # Extra tokens - It will replace extra token embedding with self.new_embeddings_buffer's emb (Default is 0) + # Extra tokens - It will replace extra token embedding with self.new_embeddings_buffer's emb (Default is 0) base_output = self.extra_token_embedding(input_, base_output) # Apply LoRA if configured @@ -165,7 +173,6 @@ def forward(self, input_: torch.Tensor): return output - def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): # For TP=1, no slicing needed # LoRA A weights (rank, vocab_size) are not sliced for embedding @@ -178,14 +185,11 @@ def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): # For TP>1, Need to modify code in: sglang/python/sglang/srt/lora/mem_pool.py return B - - - class ParallelLMHeadWithLoRA(BaseLayerWithLoRA): """ Parallel LM Head layer with LoRA support (simplified for TP=1). - + The LM head computes logits = hidden_states @ (W + B @ A)^T """ @@ -214,7 +218,7 @@ def apply_lora( ) -> torch.Tensor: """ Apply LoRA to LM head layer. - + For LM head: output = hidden @ (W + B @ A)^T = hidden @ W^T + hidden @ A^T @ B^T = base_output + (hidden @ A^T) @ B^T @@ -223,28 +227,26 @@ def apply_lora( lora_a_output = self.lora_backend.run_lora_a_sgemm( hidden_states, self.lm_head_A_buffer ) - + # Apply lora_B^T: lora_a_output @ B^T lora_output = self.lora_backend.run_lora_b_sgemm( x=lora_a_output, weights=self.lm_head_B_buffer, base_output=base_output, ) - + return lora_output def forward(self, hidden_states: torch.Tensor): # Apply base linear transformation base_output = F.linear( - hidden_states, - self.weight, - bias=getattr(self.base_layer, 'bias', None) + hidden_states, self.weight, bias=getattr(self.base_layer, "bias", None) ) - + # Apply LoRA if set if self.set_lora: base_output = self.apply_lora(base_output, hidden_states) - + return base_output def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): @@ -421,7 +423,7 @@ def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor output_offset_cpu=self.output_offset_cpu, max_qkv_out_dim=self.max_qkv_out_dim, ) - + return lora_output def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py index f616664b4965..b464b21cca9a 100644 --- a/python/sglang/srt/lora/lora.py +++ b/python/sglang/srt/lora/lora.py @@ -91,7 +91,7 @@ def initialize_weights(self): # self.embedding_layers.weights[name] = loaded_weight.cpu() self.embedding_layers[name] = loaded_weight.cpu() elif "input_embeddings" in name or "output_embeddings" in name: - #added token emb + # added token emb self.added_tokens_embeddings[name] = loaded_weight.cpu() assert loaded_weight.shape[0] == self.config.lora_added_tokens_size, ( f"LoRA adapter {self.uid} has extra_vocab_size {self.config.extra_vocab_size} specified in the config, " diff --git a/python/sglang/srt/lora/lora_config.py b/python/sglang/srt/lora/lora_config.py index d28e630beae4..a5cc80fab979 100644 --- a/python/sglang/srt/lora/lora_config.py +++ b/python/sglang/srt/lora/lora_config.py @@ -54,14 +54,14 @@ def get_added_tokens_config(self): weights_dir = snapshot_download(self.path, allow_patterns=["*.json"]) else: weights_dir = self.path - + # Construct the path to added_tokens.json added_tokens_path = os.path.join(weights_dir, "added_tokens.json") - + # Return None if the file doesn't exist (optional for standard LoRA adapters) if not os.path.exists(added_tokens_path): return None - + # Load and return the added tokens try: with open(added_tokens_path, "r") as f: @@ -69,6 +69,7 @@ def get_added_tokens_config(self): except json.JSONDecodeError as e: # Log warning but don't crash if JSON is malformed import logging + logger = logging.getLogger(__name__) logger.warning(f"Failed to parse added_tokens.json: {e}") - return None \ No newline at end of file + return None diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 985142c19af5..ee4a92994a08 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -22,6 +22,10 @@ from sglang.srt.configs.load_config import LoadConfig from sglang.srt.layers.utils import get_layer_id +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) from sglang.srt.lora.backend.base_backend import BaseLoRABackend from sglang.srt.lora.backend.lora_registry import get_backend_from_name from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer @@ -40,10 +44,6 @@ from sglang.srt.utils import replace_submodule from sglang.srt.utils.hf_transformers_utils import AutoConfig -from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead - - - logger = logging.getLogger(__name__) @@ -71,7 +71,7 @@ def __init__( self.device: torch.device = next(self.base_model.parameters()).device self.tp_size: int = tp_size self.tp_rank: int = tp_rank - self.lora_added_tokens_size: Optional[int] = None + self.lora_added_tokens_size: Optional[int] = None # Store eviction policy from server args self.eviction_policy = server_args.lora_eviction_policy @@ -249,8 +249,8 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch): lora_adapters=self.loras, lora_modules=self.lora_modules, lora_refs=self.lora_refs.copy(), # copy snapshot of current lora_refs to avoid mutation during the batch preparation. - lora_embed_tokens_module=self.embed_tokens_module, #merge into embedding or lora module - lora_lm_head_module=self.lm_head_module, #merge into embedding or lora module + lora_embed_tokens_module=self.embed_tokens_module, # merge into embedding or lora module + lora_lm_head_module=self.lm_head_module, # merge into embedding or lora module ) # set up batch info shared by all lora modules @@ -306,11 +306,13 @@ def update_lora_info(self): # Update embedding layer if present - gotta merge (refer to PR codebase) if self.embed_tokens_module is not None: self.embed_tokens_module.set_lora_info( - self.memory_pool.get_embedding_tensor("added_tokens", LoRAType.LORA_A), #choose name: "added_tokens" + self.memory_pool.get_embedding_tensor( + "added_tokens", LoRAType.LORA_A + ), # choose name: "added_tokens" self.memory_pool.get_embedding_tensor("embed_tokens", LoRAType.LORA_A), self.memory_pool.get_embedding_tensor("embed_tokens", LoRAType.LORA_B), ) - + # Update lm_head layer if present if self.lm_head_module is not None: self.lm_head_module.set_lora_info( @@ -411,14 +413,18 @@ def init_lora_shapes( [x.r for x in self.configs.values()], default=0, ) - + # Auto-infer self.lora_added_vocab_size from loaded LoRA configs # This happens automatically without requiring user input # if self.lora_added_vocab_size is None: if self.lora_added_tokens_size is None: inferred_extra_vocab_size = next( - (x.lora_added_tokens_size for x in self.configs.values() if x.lora_added_tokens_size > 0), - 0 + ( + x.lora_added_tokens_size + for x in self.configs.values() + if x.lora_added_tokens_size > 0 + ), + 0, ) if inferred_extra_vocab_size > 0: logger.info( @@ -452,7 +458,7 @@ def init_memory_pool(self): target_modules=self.target_modules, base_model=self.base_model, eviction_policy=self.eviction_policy, - lora_added_tokens_size = self.lora_added_tokens_size + lora_added_tokens_size=self.lora_added_tokens_size, ) def set_lora_module(self, module_name, module): @@ -482,16 +488,20 @@ def init_lora_modules(self): # Handle embed_tokens if "embed_tokens" in module_name and "embed_tokens" in self.target_modules: - if isinstance(module, VocabParallelEmbedding) and not isinstance(module, BaseLayerWithLoRA): + if isinstance(module, VocabParallelEmbedding) and not isinstance( + module, BaseLayerWithLoRA + ): lora_module = self.set_lora_module(module_name, module) - self.embed_tokens_module = lora_module + self.embed_tokens_module = lora_module continue - + # Handle lm_head if "lm_head" in module_name and "lm_head" in self.target_modules: - if isinstance(module, ParallelLMHead) and not isinstance(module, BaseLayerWithLoRA): + if isinstance(module, ParallelLMHead) and not isinstance( + module, BaseLayerWithLoRA + ): lora_module = self.set_lora_module(module_name, module) - self.lm_head_module = lora_module + self.lm_head_module = lora_module continue # The module should be converted if it is included in target_names diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index 18e0ae692184..17dd51bbd70c 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -10,8 +10,8 @@ from sglang.srt.lora.lora_config import LoRAConfig from sglang.srt.lora.lora_registry import LoRARef from sglang.srt.lora.utils import ( - ROW_PARALLELISM_LINEAR_LORA_NAMES, EMBEDDING_NAMES, + ROW_PARALLELISM_LINEAR_LORA_NAMES, LoRAType, get_hidden_dim, get_normalized_target_modules, @@ -57,7 +57,7 @@ def __init__( target_modules: Set[str], base_model: torch.nn.Module, eviction_policy: str, - lora_added_tokens_size: int + lora_added_tokens_size: int, ): self.base_hf_config: AutoConfig = base_hf_config self.num_layer: int = base_hf_config.num_hidden_layers @@ -113,7 +113,7 @@ def _can_support(config: LoRAConfig) -> bool: if config.r > self.max_lora_rank: return False if config.lora_added_tokens_size > self.lora_added_tokens_size: - return False + return False target_module_names = get_normalized_target_modules(config.target_modules) return target_module_names.issubset(self.target_modules) @@ -154,7 +154,7 @@ def get_embedding_lora_A_shape( input_dim, _ = get_hidden_dim( module_name, self.base_hf_config, base_model, 0, self.lora_added_tokens_size ) - # Have not imp self.tp_size > 1 yet. + # Have not imp self.tp_size > 1 yet. return ( self.max_loras_per_batch, max_lora_dim, @@ -222,8 +222,7 @@ def init_buffer( ) for idx in range(self.num_layer) ] - - + def init_embedding_buffer( buffer: Dict[str, torch.Tensor], target_modules: Set[str], @@ -265,7 +264,7 @@ def init_embedding_buffer( self.target_modules, self.get_embedding_lora_B_shape, ) - + if "lm_head" in self.target_modules: init_embedding_buffer( self.lm_head_A_buffer, @@ -278,7 +277,6 @@ def init_embedding_buffer( self.target_modules, self.get_embedding_lora_B_shape, ) - init_buffer( self.A_buffer, @@ -291,7 +289,6 @@ def init_embedding_buffer( self.target_modules, self.get_lora_B_shape, ) - def prepare_lora_batch( self, @@ -353,7 +350,12 @@ def get_available_buffer_slot(): buffer_id = get_available_buffer_slot() lora_adapter = lora_adapters.get(uid, None) self.load_lora_weight_to_buffer( - uid, buffer_id, lora_adapter, lora_modules, lora_embed_tokens_module, lora_lm_head_module + uid, + buffer_id, + lora_adapter, + lora_modules, + lora_embed_tokens_module, + lora_lm_head_module, ) self.uid_to_buffer_id[uid] = buffer_id self.buffer_id_to_uid[buffer_id] = uid @@ -438,12 +440,11 @@ def load_lora_weight_tensor( buffer_view = target_buffer[buffer_id, :, :lora_rank] load_lora_weight_tensor(buffer_view, weights) - if lora_adapter.embedding_layers: org_vocab_size = self.base_hf_config.vocab_size lora_added_tokens_size = lora_adapter.config.lora_added_tokens_size - # Only when LoRA is applied to the embedding layer will it have the extra-token issue that needs to be resolved. + # Only when LoRA is applied to the embedding layer will it have the extra-token issue that needs to be resolved. # Load embeddings weights for extra tokens to buffer if lora_adapter.added_tokens_embeddings: for name, weights in lora_adapter.added_tokens_embeddings.items(): @@ -452,18 +453,28 @@ def load_lora_weight_tensor( buffer_id, :lora_added_tokens_size ] load_lora_weight_tensor(buffer_view, weights) - - #load vocab_emb and lm_head + + # load vocab_emb and lm_head for name, weights in lora_adapter.embedding_layers.items(): target_module = get_target_module_name(name, self.target_modules) - if target_module == "embed_tokens" and "embed_tokens" in name and ("lora_embedding_A" in name or "lora_A" in name): + if ( + target_module == "embed_tokens" + and "embed_tokens" in name + and ("lora_embedding_A" in name or "lora_A" in name) + ): buffer_view = self.embedding_A_buffer[target_module][ - buffer_id, :lora_rank, :(org_vocab_size+lora_added_tokens_size) + buffer_id, + :lora_rank, + : (org_vocab_size + lora_added_tokens_size), ] load_lora_weight_tensor(buffer_view, weights) - elif target_module == "embed_tokens" and "embed_tokens" in name and ("lora_embedding_B" in name or "lora_B" in name): + elif ( + target_module == "embed_tokens" + and "embed_tokens" in name + and ("lora_embedding_B" in name or "lora_B" in name) + ): lora_b_weights = weights - #[to-do] support TP + # [to-do] support TP # if self.tp_size > 1: # cur_module = lora_embeddings_modules[target_module] # for module_name, module in cur_module: @@ -476,15 +487,25 @@ def load_lora_weight_tensor( ] load_lora_weight_tensor(buffer_view, lora_b_weights) - elif target_module == "lm_head" and "lm_head" in name and ("lora_embedding_A" in name or "lora_A" in name): + elif ( + target_module == "lm_head" + and "lm_head" in name + and ("lora_embedding_A" in name or "lora_A" in name) + ): buffer_view = self.lm_head_A_buffer[target_module][ # buffer_id, :, :lora_rank - buffer_id, :lora_rank, : + buffer_id, + :lora_rank, + :, ] load_lora_weight_tensor(buffer_view, weights) - elif target_module == "lm_head" and "lm_head" in name and ("lora_embedding_B" in name or "lora_B" in name): + elif ( + target_module == "lm_head" + and "lm_head" in name + and ("lora_embedding_B" in name or "lora_B" in name) + ): lora_b_weights = weights - #[to-do] support TP + # [to-do] support TP # if self.tp_size > 1: # cur_module = lora_embeddings_modules[target_module] # for module_name, module in cur_module: @@ -494,27 +515,30 @@ def load_lora_weight_tensor( buffer_view = self.lm_head_B_buffer[target_module][ # buffer_id, :lora_rank, : org_vocab_size + extra_vocab_size - buffer_id, :(org_vocab_size + self.lora_added_tokens_size), :lora_rank + buffer_id, + : (org_vocab_size + self.lora_added_tokens_size), + :lora_rank, ] load_lora_weight_tensor(buffer_view, lora_b_weights) - def get_embedding_tensor( self, target_module: str, lora_type: LoRAType ) -> Optional[torch.Tensor]: """ Get LoRA tensor for non-layer modules (embed_tokens, lm_head). - + Args: target_module: Module name, either "embed_tokens" or "lm_head" lora_type: Either LoRAType.LORA_A or LoRAType.LORA_B - + Returns: The corresponding buffer tensor, or None if not available """ if target_module == "added_tokens": - if self.lora_added_tokens_size > 0 and self.lora_added_tokens_size != None: # change to read from the config + if ( + self.lora_added_tokens_size > 0 and self.lora_added_tokens_size != None + ): # change to read from the config return self.new_embeddings_buffer["input_embeddings"] return None elif target_module == "embed_tokens": @@ -525,16 +549,16 @@ def get_embedding_tensor( if lora_type == LoRAType.LORA_A: return self.lm_head_A_buffer[target_module] return self.lm_head_B_buffer[target_module] - + raise ValueError( f"Invalid target_module '{target_module}'. " f"Expected 'embed_tokens' or 'lm_head'." ) - + def get_tensor( self, target_module: str, layer_id: int, lora_type: LoRAType ) -> torch.Tensor: - + if lora_type == LoRAType.LORA_A: return self.A_buffer[target_module][layer_id] diff --git a/python/sglang/srt/lora/triton_ops/__init__.py b/python/sglang/srt/lora/triton_ops/__init__.py index 5927bee9074a..3a31a3262dc4 100644 --- a/python/sglang/srt/lora/triton_ops/__init__.py +++ b/python/sglang/srt/lora/triton_ops/__init__.py @@ -1,14 +1,16 @@ from .chunked_sgmv_expand import chunked_sgmv_lora_expand_forward from .chunked_sgmv_shrink import chunked_sgmv_lora_shrink_forward -from .gate_up_lora_b import gate_up_lora_b_fwd -from .qkv_lora_b import qkv_lora_b_fwd -from .sgemm_lora_a import sgemm_lora_a_fwd -from .sgemm_lora_b import sgemm_lora_b_fwd +from .embedding_extra_tokens import embedding_extra_tokens_modified + ############################# #########cuda lora########### ############################# from .embedding_lora_a import embedding_lora_a_fwd -from .embedding_extra_tokens import embedding_extra_tokens_modified +from .gate_up_lora_b import gate_up_lora_b_fwd +from .qkv_lora_b import qkv_lora_b_fwd +from .sgemm_lora_a import sgemm_lora_a_fwd +from .sgemm_lora_b import sgemm_lora_b_fwd + ############################# ############################# ############################# @@ -22,7 +24,7 @@ "chunked_sgmv_lora_expand_forward", ############################# #########cuda lora########### - ############################# + ############################# "embedding_lora_a_fwd", "embedding_extra_tokens_modified", ############################# diff --git a/python/sglang/srt/lora/triton_ops/embedding_extra_tokens.py b/python/sglang/srt/lora/triton_ops/embedding_extra_tokens.py index 878207d8cf5d..b59e7cb91237 100644 --- a/python/sglang/srt/lora/triton_ops/embedding_extra_tokens.py +++ b/python/sglang/srt/lora/triton_ops/embedding_extra_tokens.py @@ -4,6 +4,7 @@ import torch import triton import triton.language as tl + from sglang.srt.lora.utils import LoRABatchInfo @@ -37,34 +38,34 @@ def _embedding_extra_tokens_kernel( """ batch_id = tl.program_id(axis=1) token_idx = tl.program_id(axis=0) - + w_index = tl.load(weight_indices + batch_id) seg_start = tl.load(seg_indptr + batch_id) seg_len = tl.load(seg_lens + batch_id) - + # Check if this token is within the segment if token_idx >= seg_len: return - + # Load the token ID token_id = tl.load(input_ids + seg_start + token_idx) - + # Check if this is an extra token is_extra_token = token_id >= vocab_size - + if not is_extra_token: return # Skip non-extra tokens - + # Calculate extra token ID extra_token_id = token_id - vocab_size - + # Process in chunks of BLOCK_EMBED dimensions num_blocks = tl.cdiv(embed_dim, BLOCK_EMBED) - + for block_id in range(num_blocks): embed_offset = tl.arange(0, BLOCK_EMBED) + block_id * BLOCK_EMBED embed_mask = embed_offset < embed_dim - + # Load from extra embeddings # extra_embeddings shape: (num_loras, num_extra_tokens, embed_dim) extra_emb_ptr = ( @@ -74,7 +75,7 @@ def _embedding_extra_tokens_kernel( + embed_offset * extra_emb_stride_2 ) emb_values = tl.load(extra_emb_ptr, mask=embed_mask, other=0.0) - + # Write to output (overwrite the position) output_ptr = ( output @@ -93,14 +94,14 @@ def embedding_extra_tokens_modified( ) -> torch.Tensor: """ Forward pass for extra token embedding lookup (in-place operation). - + Args: input_ids: (s,) token IDs output: (s, embed_dim) output tensor to be modified in-place extra_embeddings: (num_loras, num_extra_tokens, embed_dim) extra token embeddings batch_info: LoRABatchInfo containing batch information vocab_size: base vocabulary size - + Returns: output: (s, embed_dim) modified output tensor """ @@ -110,26 +111,26 @@ def embedding_extra_tokens_modified( assert len(input_ids.shape) == 1 assert len(output.shape) == 2 assert len(extra_embeddings.shape) == 3 - + S = input_ids.shape[0] embed_dim = output.shape[1] num_loras = extra_embeddings.shape[0] - + # Block size for embedding dimension BLOCK_EMBED = 128 - + extra_emb_stride = ( extra_embeddings.stride(0), extra_embeddings.stride(1), extra_embeddings.stride(2), ) - + # Grid: one program per token in each batch segment grid = ( batch_info.max_len, batch_info.bs, ) - + _embedding_extra_tokens_kernel[grid]( input_ids, output, @@ -147,8 +148,10 @@ def embedding_extra_tokens_modified( batch_info.weight_indices, BLOCK_EMBED, ) - + return output + + +############################# ############################# ############################# -############################# \ No newline at end of file diff --git a/python/sglang/srt/lora/triton_ops/embedding_lora_a.py b/python/sglang/srt/lora/triton_ops/embedding_lora_a.py index 61acdd72a54c..e60ad6d1c71b 100644 --- a/python/sglang/srt/lora/triton_ops/embedding_lora_a.py +++ b/python/sglang/srt/lora/triton_ops/embedding_lora_a.py @@ -40,47 +40,46 @@ def _embedding_lora_a_kernel( ): """ Embedding lookup for LoRA A weights with support for extra tokens. - + Each program handles one token across a block of rank dimensions. Grid: (cdiv(max_len, 1), bs) - one program per token in each batch """ batch_id = tl.program_id(axis=1) token_idx = tl.program_id(axis=0) - + w_index = tl.load(weight_indices + batch_id) rank_val = tl.load(lora_ranks + w_index) - + # If rank is 0, skip if rank_val == 0: return - + seg_start = tl.load(seg_indptr + batch_id) seg_len = tl.load(seg_lens + batch_id) - + # Check if this token is within the segment if token_idx >= seg_len: return - + # Load the token ID token_id = tl.load(input_ids + seg_start + token_idx) - + # Process in chunks of BLOCK_RANK dimensions num_blocks = tl.cdiv(rank_val, BLOCK_RANK) - for block_id in range(num_blocks): rank_offset = tl.arange(0, BLOCK_RANK) + block_id * BLOCK_RANK rank_mask = rank_offset < rank_val - + # Check if this is an extra token is_extra_token = token_id >= vocab_size - + if HAS_EXTRA_EMBEDDINGS and is_extra_token: # Use extra embeddings extra_token_id = token_id - vocab_size extra_emb_ptr = ( - extra_embeddings - + w_index * extra_emb_stride_0 + extra_embeddings + + w_index * extra_emb_stride_0 + extra_token_id * extra_emb_stride_1 + rank_offset * extra_emb_stride_2 ) @@ -91,17 +90,17 @@ def _embedding_lora_a_kernel( # We need to load weights[w_index, rank_offset, token_id] token_id_clamped = tl.minimum(token_id, vocab_size - 1) weight_ptr = ( - weights - + w_index * w_stride_0 - + rank_offset * w_stride_1 + weights + + w_index * w_stride_0 + + rank_offset * w_stride_1 + token_id_clamped * w_stride_2 ) emb_values = tl.load(weight_ptr, mask=rank_mask, other=0.0) - + # Write to output output_ptr = ( - output - + (seg_start + token_idx) * output_stride_0 + output + + (seg_start + token_idx) * output_stride_0 + rank_offset * output_stride_1 ) tl.store(output_ptr, emb_values, mask=rank_mask) @@ -116,14 +115,14 @@ def embedding_lora_a_fwd( ) -> torch.Tensor: """ Forward pass for LoRA A embedding lookup. - + Args: input_ids: (s,) token IDs weights: (num_loras, rank, vocab_size) LoRA A embedding weights batch_info: LoRABatchInfo containing batch information vocab_size: base vocabulary size extra_embeddings: (num_loras, num_extra_tokens, rank) extra token embeddings - + Returns: output: (s, rank) embedded features """ @@ -131,17 +130,17 @@ def embedding_lora_a_fwd( assert weights.is_contiguous() assert len(input_ids.shape) == 1 assert len(weights.shape) == 3 - + S = input_ids.shape[0] num_loras = weights.shape[0] rank = weights.shape[1] vocab_size_weights = weights.shape[2] - + # Block size for rank dimension BLOCK_RANK = 128 - + has_extra_embeddings = extra_embeddings is not None - + if has_extra_embeddings: assert extra_embeddings.is_contiguous() extra_emb_stride = ( @@ -155,15 +154,15 @@ def embedding_lora_a_fwd( (1, 1, 1), device=input_ids.device, dtype=weights.dtype ) extra_emb_stride = (1, 1, 1) - + # Grid: one program per token in each batch segment grid = ( batch_info.max_len, batch_info.bs, ) - + output = torch.zeros((S, rank), device=input_ids.device, dtype=weights.dtype) - + _embedding_lora_a_kernel[grid]( input_ids, weights, @@ -187,9 +186,10 @@ def embedding_lora_a_fwd( BLOCK_RANK, has_extra_embeddings, ) - + return output + +############################# ############################# ############################# -############################# \ No newline at end of file diff --git a/python/sglang/srt/lora/utils.py b/python/sglang/srt/lora/utils.py index 1a091ab4e6c0..b59c17aa522c 100644 --- a/python/sglang/srt/lora/utils.py +++ b/python/sglang/srt/lora/utils.py @@ -46,7 +46,11 @@ class LoRAType(Enum): def get_hidden_dim( - module_name: str, config: AutoConfig, base_model: torch.nn.Module, layer_idx: int, lora_added_vocab_size: int = 0 + module_name: str, + config: AutoConfig, + base_model: torch.nn.Module, + layer_idx: int, + lora_added_vocab_size: int = 0, ) -> Tuple[int]: """ Given a module_name (might be a stacked name), return the hidden dims of modules' input and output. @@ -144,5 +148,6 @@ def get_target_module_name(full_module_name: str, target_modules: Set[str]) -> s f"Cannot find target module name for {full_module_name} in {target_modules}" ) + EMBEDDING_NAMES = ["embed_tokens", "lm_head"] -ROW_PARALLELISM_LINEAR_LORA_NAMES = ["o_proj", "down_proj"] \ No newline at end of file +ROW_PARALLELISM_LINEAR_LORA_NAMES = ["o_proj", "down_proj"] diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index 319b87cc53f1..8510798fa37e 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -427,14 +427,18 @@ def forward_generation_raw( if lora_paths is not None and lora_paths[i] is not None: from peft import PeftConfig, PeftModel + from sglang.srt.lora.lora_config import LoRAConfig peft_config = PeftConfig.from_pretrained(lora_paths[i]) lora_config = LoRAConfig(lora_paths[i]) - if "embed_tokens" in peft_config.target_modules and lora_config.lora_added_tokens_size > 0: + if ( + "embed_tokens" in peft_config.target_modules + and lora_config.lora_added_tokens_size > 0 + ): new_tokenizer = get_tokenizer(lora_paths[i]) base_model.resize_token_embeddings(len(new_tokenizer)) - tokenizer = new_tokenizer + tokenizer = new_tokenizer model = PeftModel.from_pretrained( base_model, From b33c5e0586f2bd8e7f736a3475beacbe489ff23d Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Fri, 5 Dec 2025 01:54:14 +0000 Subject: [PATCH 13/19] fix lora/layer.py --- python/sglang/srt/lora/layers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index 45e320393a2a..01db0f524c98 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -161,7 +161,9 @@ def forward(self, input_: torch.Tensor): base_output = self.base_layer.forward(input_.masked_fill(added_tokens_mask, 0)) # Extra tokens - It will replace extra token embedding with self.new_embeddings_buffer's emb (Default is 0) - base_output = self.extra_token_embedding(input_, base_output) + # Only process extra tokens if we have new embeddings + if hasattr(self, 'new_embeddings_buffer') and self.new_embeddings_buffer is not None: + base_output = self.extra_token_embedding(input_, base_output) # Apply LoRA if configured if self.set_lora: From 10a099f34dde8683796f2d44981e99439a96c9f5 Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Fri, 5 Dec 2025 03:29:07 +0000 Subject: [PATCH 14/19] remove comments --- python/sglang/srt/lora/backend/base_backend.py | 7 ------- python/sglang/srt/lora/backend/chunked_backend.py | 9 +-------- python/sglang/srt/lora/backend/triton_backend.py | 9 +-------- python/sglang/srt/lora/layers.py | 5 ++++- python/sglang/srt/lora/triton_ops/__init__.py | 14 -------------- .../sglang/srt/lora/triton_ops/embedding_lora_a.py | 9 --------- 6 files changed, 6 insertions(+), 47 deletions(-) diff --git a/python/sglang/srt/lora/backend/base_backend.py b/python/sglang/srt/lora/backend/base_backend.py index d041b988c880..06e4e8ba5a95 100644 --- a/python/sglang/srt/lora/backend/base_backend.py +++ b/python/sglang/srt/lora/backend/base_backend.py @@ -19,9 +19,6 @@ def __init__(self, max_loras_per_batch: int, device: torch.device): self.max_loras_per_batch = max_loras_per_batch self.device = device - ############################# - #########cuda lora########### - ############################# def run_lora_a_embedding( self, input_ids: torch.Tensor, @@ -68,10 +65,6 @@ def run_extra_token_embedding( """ raise NotImplementedError - ############################# - ############################# - ############################# - def run_lora_a_sgemm( self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs ) -> torch.Tensor: diff --git a/python/sglang/srt/lora/backend/chunked_backend.py b/python/sglang/srt/lora/backend/chunked_backend.py index 1d0f702806d3..5526b9c11d1b 100644 --- a/python/sglang/srt/lora/backend/chunked_backend.py +++ b/python/sglang/srt/lora/backend/chunked_backend.py @@ -1,7 +1,7 @@ import torch from sglang.srt.lora.backend.base_backend import BaseLoRABackend -from sglang.srt.lora.triton_ops import ( # ############################; ########cuda lora########### +from sglang.srt.lora.triton_ops import ( chunked_sgmv_lora_expand_forward, chunked_sgmv_lora_shrink_forward, embedding_extra_tokens_modified, @@ -35,9 +35,6 @@ def __init__( super().__init__(max_loras_per_batch, device) self.max_chunk_size = server_args.max_lora_chunk_size - ############################## - ##########cuda lora########### - ############################## def run_lora_a_embedding( self, input_ids: torch.Tensor, @@ -82,10 +79,6 @@ def run_extra_token_embedding( vocab_size=vocab_size, ) - ############################## - ############################## - ############################## - def run_lora_a_sgemm( self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs ) -> torch.Tensor: diff --git a/python/sglang/srt/lora/backend/triton_backend.py b/python/sglang/srt/lora/backend/triton_backend.py index ea9ccc7efd9b..9bb9a45d8fea 100644 --- a/python/sglang/srt/lora/backend/triton_backend.py +++ b/python/sglang/srt/lora/backend/triton_backend.py @@ -1,7 +1,7 @@ import torch from sglang.srt.lora.backend.base_backend import BaseLoRABackend -from sglang.srt.lora.triton_ops import ( # ############################; ########cuda lora########### +from sglang.srt.lora.triton_ops import ( embedding_extra_tokens_modified, embedding_lora_a_fwd, gate_up_lora_b_fwd, @@ -24,9 +24,6 @@ def __init__( ): super().__init__(max_loras_per_batch, device) - ############################# - #########cuda lora########### - ############################# def run_lora_a_embedding( self, input_ids: torch.Tensor, @@ -63,10 +60,6 @@ def run_extra_token_embedding( vocab_size=vocab_size, ) - ############################# - ############################# - ############################# - def run_lora_a_sgemm( self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs ) -> torch.Tensor: diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index 01db0f524c98..e7d7f303c3dc 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -162,7 +162,10 @@ def forward(self, input_: torch.Tensor): # Extra tokens - It will replace extra token embedding with self.new_embeddings_buffer's emb (Default is 0) # Only process extra tokens if we have new embeddings - if hasattr(self, 'new_embeddings_buffer') and self.new_embeddings_buffer is not None: + if ( + hasattr(self, "new_embeddings_buffer") + and self.new_embeddings_buffer is not None + ): base_output = self.extra_token_embedding(input_, base_output) # Apply LoRA if configured diff --git a/python/sglang/srt/lora/triton_ops/__init__.py b/python/sglang/srt/lora/triton_ops/__init__.py index 3a31a3262dc4..761297b3026a 100644 --- a/python/sglang/srt/lora/triton_ops/__init__.py +++ b/python/sglang/srt/lora/triton_ops/__init__.py @@ -1,20 +1,12 @@ from .chunked_sgmv_expand import chunked_sgmv_lora_expand_forward from .chunked_sgmv_shrink import chunked_sgmv_lora_shrink_forward from .embedding_extra_tokens import embedding_extra_tokens_modified - -############################# -#########cuda lora########### -############################# from .embedding_lora_a import embedding_lora_a_fwd from .gate_up_lora_b import gate_up_lora_b_fwd from .qkv_lora_b import qkv_lora_b_fwd from .sgemm_lora_a import sgemm_lora_a_fwd from .sgemm_lora_b import sgemm_lora_b_fwd -############################# -############################# -############################# - __all__ = [ "gate_up_lora_b_fwd", "qkv_lora_b_fwd", @@ -22,12 +14,6 @@ "sgemm_lora_b_fwd", "chunked_sgmv_lora_shrink_forward", "chunked_sgmv_lora_expand_forward", - ############################# - #########cuda lora########### - ############################# "embedding_lora_a_fwd", "embedding_extra_tokens_modified", - ############################# - ############################# - ############################# ] diff --git a/python/sglang/srt/lora/triton_ops/embedding_lora_a.py b/python/sglang/srt/lora/triton_ops/embedding_lora_a.py index e60ad6d1c71b..1e21be50fd79 100644 --- a/python/sglang/srt/lora/triton_ops/embedding_lora_a.py +++ b/python/sglang/srt/lora/triton_ops/embedding_lora_a.py @@ -1,7 +1,3 @@ -############################# -#########cuda lora########### -############################# - import torch import triton import triton.language as tl @@ -188,8 +184,3 @@ def embedding_lora_a_fwd( ) return output - - -############################# -############################# -############################# From 18a8f5b30f697807c6b80580a8131b0e76ae4c16 Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Sat, 6 Dec 2025 01:55:49 +0000 Subject: [PATCH 15/19] remove chunked_backend and remove extra_tokens support temporarily --- .../srt/lora/backend/chunked_backend.py | 46 ----- .../sglang/srt/lora/backend/triton_backend.py | 19 --- python/sglang/srt/lora/layers.py | 37 +++-- python/sglang/srt/lora/mem_pool.py | 4 +- python/sglang/srt/lora/triton_ops/__init__.py | 2 - .../lora/triton_ops/embedding_extra_tokens.py | 157 ------------------ python/sglang/test/runners.py | 15 +- 7 files changed, 29 insertions(+), 251 deletions(-) delete mode 100644 python/sglang/srt/lora/triton_ops/embedding_extra_tokens.py diff --git a/python/sglang/srt/lora/backend/chunked_backend.py b/python/sglang/srt/lora/backend/chunked_backend.py index 5526b9c11d1b..f17f473cbdfd 100644 --- a/python/sglang/srt/lora/backend/chunked_backend.py +++ b/python/sglang/srt/lora/backend/chunked_backend.py @@ -4,8 +4,6 @@ from sglang.srt.lora.triton_ops import ( chunked_sgmv_lora_expand_forward, chunked_sgmv_lora_shrink_forward, - embedding_extra_tokens_modified, - embedding_lora_a_fwd, ) from sglang.srt.lora.utils import LoRABatchInfo from sglang.srt.model_executor.forward_batch_info import ForwardBatch @@ -35,50 +33,6 @@ def __init__( super().__init__(max_loras_per_batch, device) self.max_chunk_size = server_args.max_lora_chunk_size - def run_lora_a_embedding( - self, - input_ids: torch.Tensor, - weights: torch.Tensor, - vocab_size: int, - extra_embeddings: torch.Tensor = None, - *args, - **kwargs, - ) -> torch.Tensor: - """Run LoRA A embedding lookup. - - For chunked backend, we use the same triton kernel as triton backend - since embedding lookup doesn't benefit much from chunking. - """ - return embedding_lora_a_fwd( - input_ids=input_ids, - weights=weights, - batch_info=self.batch_info, - vocab_size=vocab_size, - extra_embeddings=extra_embeddings, - ) - - def run_extra_token_embedding( - self, - input_ids: torch.Tensor, - output: torch.Tensor, - extra_embeddings: torch.Tensor, - vocab_size: int, - *args, - **kwargs, - ) -> torch.Tensor: - """Run extra token embedding lookup. - - For chunked backend, we use the same triton kernel as triton backend - since embedding lookup doesn't benefit from chunking. - """ - return embedding_extra_tokens_modified( - input_ids=input_ids, - output=output, - extra_embeddings=extra_embeddings, - batch_info=self.batch_info, - vocab_size=vocab_size, - ) - def run_lora_a_sgemm( self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs ) -> torch.Tensor: diff --git a/python/sglang/srt/lora/backend/triton_backend.py b/python/sglang/srt/lora/backend/triton_backend.py index 9bb9a45d8fea..a36069d4ad86 100644 --- a/python/sglang/srt/lora/backend/triton_backend.py +++ b/python/sglang/srt/lora/backend/triton_backend.py @@ -2,7 +2,6 @@ from sglang.srt.lora.backend.base_backend import BaseLoRABackend from sglang.srt.lora.triton_ops import ( - embedding_extra_tokens_modified, embedding_lora_a_fwd, gate_up_lora_b_fwd, qkv_lora_b_fwd, @@ -42,24 +41,6 @@ def run_lora_a_embedding( extra_embeddings=extra_embeddings, ) - def run_extra_token_embedding( - self, - input_ids: torch.Tensor, - output: torch.Tensor, - extra_embeddings: torch.Tensor, - vocab_size: int, - *args, - **kwargs, - ) -> torch.Tensor: - """Run extra token embedding lookup using Triton kernel.""" - return embedding_extra_tokens_modified( - input_ids=input_ids, - output=output, - extra_embeddings=extra_embeddings, - batch_info=self.batch_info, - vocab_size=vocab_size, - ) - def run_lora_a_sgemm( self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs ) -> torch.Tensor: diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index e7d7f303c3dc..2f898c0c5c94 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -68,6 +68,12 @@ def __init__( self.embed_dim = base_layer.embedding_dim self.vocab_size = base_layer.org_vocab_size + self.output_offset = torch.tensor( + [0, self.embed_dim], + dtype=torch.int32, + device=next(base_layer.parameters()).device, + ) + def set_lora_info( self, new_embeddings_buffer: Optional[torch.Tensor], # For extra tokens @@ -88,13 +94,14 @@ def apply_lora( Formula: output = base_output + lora_B @ lora_A_embedding(input_) """ - # Efficient embedding lookup for LoRA A (cannot call run_lora_a_sgemm since needing index lookup) + # Efficient embedding lookup for LoRA A (already support extra token embedding process) lora_a_output = self.run_lora_a_embedding(input_, batch_info) # Apply LoRA B weights using backend lora_output = self.lora_backend.run_lora_b_sgemm( x=lora_a_output, weights=self.embedding_B_buffer, + output_offset=self.output_offset, base_output=base_output, ) return lora_output @@ -107,6 +114,7 @@ def run_lora_a_embedding( Maps tokens to their corresponding LoRA adapters internally. It also includes added/extra token processing. """ + # Efficient embedding lookup for LoRA A (already support extra token embedding process) lora_a_output = self.lora_backend.run_lora_a_embedding( input_ids=input_, weights=self.embedding_A_buffer, @@ -125,6 +133,8 @@ def extra_token_embedding( self, input_: torch.Tensor, base_output: torch.Tensor ) -> torch.Tensor: """ + Need to impl: + Process extra tokens (tokens >= vocab_size) by looking up their embeddings from the new_embeddings_buffer and replacing them in base_output. @@ -133,18 +143,17 @@ def extra_token_embedding( base_output: (s, embed_dim) base embedding output to be modified in-place Returns: - base_output: (s, embed_dim) modified output with extra token embeddings + base_output: (s, embed_dim) modified input base_output (tensor[0,0,0,...]) with extra token embeddings """ - - output_base_output = self.lora_backend.run_extra_token_embedding( - input_ids=input_, - output=base_output, - extra_embeddings=self.new_embeddings_buffer, - vocab_size=self.vocab_size, + # return base_output + raise NotImplementedError( + "Error in sglang/python/sglang/srt/lora/layers.py - VocabParallelEmbeddingWithLoRA \n" + "Current SGLang codebase did not support tuned lora with extra/added tokens. \n" + "[TODO]: \n" + "1. Refer to this commit: https://github.com/yushengsu-thu/sglang/commit/90415211eee8a28a316de262583d4d33fa615d10#diff-191177438bcc223837963de63c005850371f8c8a860acb153b26744b66ecc623 to complete \n" + "2. And then you need to modified the en/decoder tokenizer - tokenizer_manager.py to support extra_token_embedding in-place. \n" ) - return output_base_output - def forward(self, input_: torch.Tensor): """ Forward pass with LoRA support and CUDA graph compatibility. @@ -160,8 +169,8 @@ def forward(self, input_: torch.Tensor): added_tokens_mask = input_ > self.vocab_size - 1 base_output = self.base_layer.forward(input_.masked_fill(added_tokens_mask, 0)) + # [TODO] SGLang did not support extra/added token process; thus, self.extra_token_embedding only return original input_ now # Extra tokens - It will replace extra token embedding with self.new_embeddings_buffer's emb (Default is 0) - # Only process extra tokens if we have new embeddings if ( hasattr(self, "new_embeddings_buffer") and self.new_embeddings_buffer is not None @@ -207,6 +216,11 @@ def __init__( self.weight = base_layer.weight self.embed_dim = base_layer.embedding_dim self.vocab_size = base_layer.org_vocab_size + self.output_offset = torch.tensor( + [0, self.vocab_size], + dtype=torch.int32, + device=next(base_layer.parameters()).device, + ) def set_lora_info( self, @@ -237,6 +251,7 @@ def apply_lora( lora_output = self.lora_backend.run_lora_b_sgemm( x=lora_a_output, weights=self.lm_head_B_buffer, + output_offset=self.output_offset, base_output=base_output, ) diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index 17dd51bbd70c..7ddf85254896 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -536,9 +536,7 @@ def get_embedding_tensor( """ if target_module == "added_tokens": - if ( - self.lora_added_tokens_size > 0 and self.lora_added_tokens_size != None - ): # change to read from the config + if self.lora_added_tokens_size > 0 and self.lora_added_tokens_size != None: return self.new_embeddings_buffer["input_embeddings"] return None elif target_module == "embed_tokens": diff --git a/python/sglang/srt/lora/triton_ops/__init__.py b/python/sglang/srt/lora/triton_ops/__init__.py index 761297b3026a..71eb1fea4837 100644 --- a/python/sglang/srt/lora/triton_ops/__init__.py +++ b/python/sglang/srt/lora/triton_ops/__init__.py @@ -1,6 +1,5 @@ from .chunked_sgmv_expand import chunked_sgmv_lora_expand_forward from .chunked_sgmv_shrink import chunked_sgmv_lora_shrink_forward -from .embedding_extra_tokens import embedding_extra_tokens_modified from .embedding_lora_a import embedding_lora_a_fwd from .gate_up_lora_b import gate_up_lora_b_fwd from .qkv_lora_b import qkv_lora_b_fwd @@ -15,5 +14,4 @@ "chunked_sgmv_lora_shrink_forward", "chunked_sgmv_lora_expand_forward", "embedding_lora_a_fwd", - "embedding_extra_tokens_modified", ] diff --git a/python/sglang/srt/lora/triton_ops/embedding_extra_tokens.py b/python/sglang/srt/lora/triton_ops/embedding_extra_tokens.py deleted file mode 100644 index b59e7cb91237..000000000000 --- a/python/sglang/srt/lora/triton_ops/embedding_extra_tokens.py +++ /dev/null @@ -1,157 +0,0 @@ -############################# -#########cuda graph########## -############################# -import torch -import triton -import triton.language as tl - -from sglang.srt.lora.utils import LoRABatchInfo - - -@triton.jit -def _embedding_extra_tokens_kernel( - # Pointers to tensors - input_ids, - output, - extra_embeddings, - # Dimensions - vocab_size, - embed_dim, - num_loras, - # Strides - output_stride_0, - output_stride_1, - extra_emb_stride_0, # stride for lora index - extra_emb_stride_1, # stride for token - extra_emb_stride_2, # stride for embed dim - # Batch info - seg_lens, - seg_indptr, - weight_indices, - # Meta-parameters - BLOCK_EMBED: tl.constexpr, -): - """ - Embedding lookup for extra/added tokens (tokens >= vocab_size). - Each program handles one token across a block of embedding dimensions. - Grid: (max_len, bs) - """ - batch_id = tl.program_id(axis=1) - token_idx = tl.program_id(axis=0) - - w_index = tl.load(weight_indices + batch_id) - seg_start = tl.load(seg_indptr + batch_id) - seg_len = tl.load(seg_lens + batch_id) - - # Check if this token is within the segment - if token_idx >= seg_len: - return - - # Load the token ID - token_id = tl.load(input_ids + seg_start + token_idx) - - # Check if this is an extra token - is_extra_token = token_id >= vocab_size - - if not is_extra_token: - return # Skip non-extra tokens - - # Calculate extra token ID - extra_token_id = token_id - vocab_size - - # Process in chunks of BLOCK_EMBED dimensions - num_blocks = tl.cdiv(embed_dim, BLOCK_EMBED) - - for block_id in range(num_blocks): - embed_offset = tl.arange(0, BLOCK_EMBED) + block_id * BLOCK_EMBED - embed_mask = embed_offset < embed_dim - - # Load from extra embeddings - # extra_embeddings shape: (num_loras, num_extra_tokens, embed_dim) - extra_emb_ptr = ( - extra_embeddings - + w_index * extra_emb_stride_0 - + extra_token_id * extra_emb_stride_1 - + embed_offset * extra_emb_stride_2 - ) - emb_values = tl.load(extra_emb_ptr, mask=embed_mask, other=0.0) - - # Write to output (overwrite the position) - output_ptr = ( - output - + (seg_start + token_idx) * output_stride_0 - + embed_offset * output_stride_1 - ) - tl.store(output_ptr, emb_values, mask=embed_mask) - - -def embedding_extra_tokens_modified( - input_ids: torch.Tensor, - output: torch.Tensor, # Will be modified in-place - extra_embeddings: torch.Tensor, - batch_info: LoRABatchInfo, - vocab_size: int, -) -> torch.Tensor: - """ - Forward pass for extra token embedding lookup (in-place operation). - - Args: - input_ids: (s,) token IDs - output: (s, embed_dim) output tensor to be modified in-place - extra_embeddings: (num_loras, num_extra_tokens, embed_dim) extra token embeddings - batch_info: LoRABatchInfo containing batch information - vocab_size: base vocabulary size - - Returns: - output: (s, embed_dim) modified output tensor - """ - assert input_ids.is_contiguous() - assert output.is_contiguous() - assert extra_embeddings.is_contiguous() - assert len(input_ids.shape) == 1 - assert len(output.shape) == 2 - assert len(extra_embeddings.shape) == 3 - - S = input_ids.shape[0] - embed_dim = output.shape[1] - num_loras = extra_embeddings.shape[0] - - # Block size for embedding dimension - BLOCK_EMBED = 128 - - extra_emb_stride = ( - extra_embeddings.stride(0), - extra_embeddings.stride(1), - extra_embeddings.stride(2), - ) - - # Grid: one program per token in each batch segment - grid = ( - batch_info.max_len, - batch_info.bs, - ) - - _embedding_extra_tokens_kernel[grid]( - input_ids, - output, - extra_embeddings, - vocab_size, - embed_dim, - num_loras, - output.stride(0), - output.stride(1), - extra_emb_stride[0], - extra_emb_stride[1], - extra_emb_stride[2], - batch_info.seg_lens, - batch_info.seg_indptr, - batch_info.weight_indices, - BLOCK_EMBED, - ) - - return output - - -############################# -############################# -############################# diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index 8510798fa37e..e9b152ae9614 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -426,19 +426,7 @@ def forward_generation_raw( input_ids = torch.tensor([p], device="cuda") if lora_paths is not None and lora_paths[i] is not None: - from peft import PeftConfig, PeftModel - - from sglang.srt.lora.lora_config import LoRAConfig - - peft_config = PeftConfig.from_pretrained(lora_paths[i]) - lora_config = LoRAConfig(lora_paths[i]) - if ( - "embed_tokens" in peft_config.target_modules - and lora_config.lora_added_tokens_size > 0 - ): - new_tokenizer = get_tokenizer(lora_paths[i]) - base_model.resize_token_embeddings(len(new_tokenizer)) - tokenizer = new_tokenizer + from peft import PeftModel model = PeftModel.from_pretrained( base_model, @@ -448,6 +436,7 @@ def forward_generation_raw( ) else: model = base_model + if patch_model_do_sample_false: model.generation_config.do_sample = False outputs = model.generate( From e10980dc08ca4061070ca8bc2d33e5392e8ac8dc Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Mon, 8 Dec 2025 21:52:09 +0000 Subject: [PATCH 16/19] fix CI/CD --- .../sglang/srt/lora/backend/triton_backend.py | 2 +- python/sglang/srt/lora/layers.py | 8 ++--- python/sglang/srt/lora/lora.py | 20 +++++++++-- python/sglang/srt/lora/lora_manager.py | 4 +-- python/sglang/srt/lora/mem_pool.py | 5 ++- test/srt/lora/test_lora_eviction.py | 12 ++++++- test/srt/lora/test_lora_update.py | 36 +++++++++++++++++-- 7 files changed, 70 insertions(+), 17 deletions(-) diff --git a/python/sglang/srt/lora/backend/triton_backend.py b/python/sglang/srt/lora/backend/triton_backend.py index a36069d4ad86..c28f3f78ae7e 100644 --- a/python/sglang/srt/lora/backend/triton_backend.py +++ b/python/sglang/srt/lora/backend/triton_backend.py @@ -197,7 +197,7 @@ def prepare_lora_batch( (bs,), dtype=torch.int32, device=self.device ), lora_ranks=torch.empty( - (self.max_loras_per_batch,), dtype=torch.int32, device=self.device + (self.max_loras_per_batch,), dtype=torch.int64, device=self.device ), scalings=torch.empty( (self.max_loras_per_batch,), dtype=torch.float, device=self.device diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index 2f898c0c5c94..76a01ea98d5f 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -181,11 +181,9 @@ def forward(self, input_: torch.Tensor): if self.set_lora: # The backend's run_lora_a_embedding now handles both regular # and extra tokens efficiently with CUDA graph support - output = self.apply_lora(base_output, input_, batch_info) - else: - output = base_output + base_output = self.apply_lora(base_output, input_, batch_info) - return output + return base_output def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): # For TP=1, no slicing needed @@ -276,7 +274,7 @@ def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): # For TP=1, no slicing needed - # For TP>1, would slice along vocab dimension, eed to modify code in: sglang/python/sglang/srt/lora/mem_pool.py + # For TP>1, would slice along vocab dimension, need to modify code in: sglang/python/sglang/srt/lora/mem_pool.py return B diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py index b464b21cca9a..12c813baeb20 100644 --- a/python/sglang/srt/lora/lora.py +++ b/python/sglang/srt/lora/lora.py @@ -79,6 +79,14 @@ def initialize_weights(self): model_path = self.config.path loader = DefaultModelLoader(self.load_config) revision = getattr(self.config.hf_config, "revision", None) + + # Get normalized target modules for filtering + from sglang.srt.lora.utils import get_normalized_target_modules + + normalized_target_modules = get_normalized_target_modules( + self.config.target_modules + ) + for name, loaded_weight in loader._get_weights_iterator( DefaultModelLoader.Source( model_path, revision=revision, fall_back_to_pt=True @@ -88,10 +96,16 @@ def initialize_weights(self): if layer_id is not None: self.layers[layer_id].weights[name] = loaded_weight.cpu() elif "embed_tokens" in name or "lm_head" in name: - # self.embedding_layers.weights[name] = loaded_weight.cpu() - self.embedding_layers[name] = loaded_weight.cpu() + # Check if this module is declared in target_modules before loading + module_name = "embed_tokens" if "embed_tokens" in name else "lm_head" + if module_name in normalized_target_modules: + self.embedding_layers[name] = loaded_weight.cpu() + else: + logger.debug( + f"Skipping {name} as '{module_name}' is not in adapter's target_modules: {self.config.target_modules}" + ) elif "input_embeddings" in name or "output_embeddings" in name: - # added token emb + # added/extra token emb self.added_tokens_embeddings[name] = loaded_weight.cpu() assert loaded_weight.shape[0] == self.config.lora_added_tokens_size, ( f"LoRA adapter {self.uid} has extra_vocab_size {self.config.extra_vocab_size} specified in the config, " diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index ee4a92994a08..6bd05dee3db1 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -306,9 +306,7 @@ def update_lora_info(self): # Update embedding layer if present - gotta merge (refer to PR codebase) if self.embed_tokens_module is not None: self.embed_tokens_module.set_lora_info( - self.memory_pool.get_embedding_tensor( - "added_tokens", LoRAType.LORA_A - ), # choose name: "added_tokens" + self.memory_pool.get_embedding_tensor("added_tokens", LoRAType.LORA_A), self.memory_pool.get_embedding_tensor("embed_tokens", LoRAType.LORA_A), self.memory_pool.get_embedding_tensor("embed_tokens", LoRAType.LORA_B), ) diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index 7ddf85254896..fdebb860c626 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -536,7 +536,10 @@ def get_embedding_tensor( """ if target_module == "added_tokens": - if self.lora_added_tokens_size > 0 and self.lora_added_tokens_size != None: + if ( + self.lora_added_tokens_size is not None + and self.lora_added_tokens_size > 0 + ): return self.new_embeddings_buffer["input_embeddings"] return None elif target_module == "embed_tokens": diff --git a/test/srt/lora/test_lora_eviction.py b/test/srt/lora/test_lora_eviction.py index fc1e00e3d969..78cdd8282fe0 100644 --- a/test/srt/lora/test_lora_eviction.py +++ b/test/srt/lora/test_lora_eviction.py @@ -97,7 +97,17 @@ def _run_test( max_loras_per_batch=1, enable_lora=True, max_lora_rank=256, - lora_target_modules=["all"], + # Need to list all lora modules, or "all" might include lora modules without assigning lora weights + # lora_target_modules=["all"], + lora_target_modules=[ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ], ) as srt_runner: adapter_sequence = lora_paths if not reverse else lora_paths[::-1] diff --git a/test/srt/lora/test_lora_update.py b/test/srt/lora/test_lora_update.py index 3f11bdd48d7d..5867e4de74f1 100644 --- a/test/srt/lora/test_lora_update.py +++ b/test/srt/lora/test_lora_update.py @@ -218,7 +218,17 @@ def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]: base="meta-llama/Llama-3.1-8B-Instruct", enable_lora=True, max_lora_rank=256, - lora_target_modules=["all"], + # Need to list all lora modules, or "all" might include lora modules without assigning lora weights + # lora_target_modules=["all"], + lora_target_modules=[ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ], max_loras_per_batch=4, all_adapters=[ "philschmid/code-llama-3-1-8b-text-to-sql-lora", @@ -751,7 +761,17 @@ def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]: ], enable_lora=True, max_lora_rank=256, - lora_target_modules=["all"], + # Need to list all lora modules, or "all" might include lora modules without assigning lora weights + # lora_target_modules=["all"], + lora_target_modules=[ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ], op_sequence=[ Operation( type=OperationType.LOAD, @@ -1503,7 +1523,17 @@ def test_v1_models_endpoint_with_lora(self): lora_paths=[], max_loras_per_batch=2, max_lora_rank=256, - lora_target_modules=["all"], + # Need to list all lora modules, or "all" might include lora modules without assigning lora weights + # lora_target_modules=["all"], + lora_target_modules=[ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ], enable_lora=True, ) as session: # Test with no adapters loaded From 1610ce30a42ff2b27bf60fae793c5c2621f4083d Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Tue, 9 Dec 2025 01:59:25 +0000 Subject: [PATCH 17/19] merge --- .../sglang/srt/lora/backend/triton_backend.py | 1 - python/sglang/srt/lora/layers.py | 28 ++++++++++++++++--- test/srt/lora/test_lora_eviction.py | 1 - test/srt/lora/test_lora_update.py | 2 ++ 4 files changed, 26 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/lora/backend/triton_backend.py b/python/sglang/srt/lora/backend/triton_backend.py index c28f3f78ae7e..ad79199fd27b 100644 --- a/python/sglang/srt/lora/backend/triton_backend.py +++ b/python/sglang/srt/lora/backend/triton_backend.py @@ -133,7 +133,6 @@ def init_cuda_graph_batch_info( scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float), permutation=None, ) - # self.cuda_graph_batch_info.seg_indptr[0] = 0 # Initialize seg_indptr for CUDA graph as they remain constant # across batches. diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index 76a01ea98d5f..498ab113c6ce 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -189,13 +189,23 @@ def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): # For TP=1, no slicing needed # LoRA A weights (rank, vocab_size) are not sliced for embedding # For TP>1, Need to modify code in: sglang/python/sglang/srt/lora/mem_pool.py - return A + # return A + if tp_rank > 1: + raise NotImplementedError( + f"VocabParallelEmbeddingWithLoRA does not support tensor parallelism > 1. " + f"Got tp_size={tp_rank}" + ) def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): # For TP=1, no slicing needed # LoRA B weights (embedding_dim, rank) would be sliced along embedding dimension for TP>1 # For TP>1, Need to modify code in: sglang/python/sglang/srt/lora/mem_pool.py - return B + # return B + if tp_rank > 1: + raise NotImplementedError( + f"VocabParallelEmbeddingWithLoRA does not support tensor parallelism > 1. " + f"Got tp_size={tp_rank}" + ) class ParallelLMHeadWithLoRA(BaseLayerWithLoRA): @@ -270,12 +280,22 @@ def forward(self, hidden_states: torch.Tensor): def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): # For TP=1, no slicing needed # For TP>1, need to modify code in: sglang/python/sglang/srt/lora/mem_pool.py - return A + # return A + if tp_rank > 1: + raise NotImplementedError( + f"ParallelLMHeadWithLoRA does not support tensor parallelism > 1. " + f"Got tp_size={tp_rank}" + ) def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): # For TP=1, no slicing needed # For TP>1, would slice along vocab dimension, need to modify code in: sglang/python/sglang/srt/lora/mem_pool.py - return B + # return B + if tp_rank > 1: + raise NotImplementedError( + f"ParallelLMHeadWithLoRA does not support tensor parallelism > 1. " + f"Got tp_size={tp_rank}" + ) class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): diff --git a/test/srt/lora/test_lora_eviction.py b/test/srt/lora/test_lora_eviction.py index 78cdd8282fe0..7881bc07bfdb 100644 --- a/test/srt/lora/test_lora_eviction.py +++ b/test/srt/lora/test_lora_eviction.py @@ -106,7 +106,6 @@ def _run_test( "o_proj", "gate_proj", "up_proj", - "down_proj", ], ) as srt_runner: adapter_sequence = lora_paths if not reverse else lora_paths[::-1] diff --git a/test/srt/lora/test_lora_update.py b/test/srt/lora/test_lora_update.py index 5867e4de74f1..9c3f0855033b 100644 --- a/test/srt/lora/test_lora_update.py +++ b/test/srt/lora/test_lora_update.py @@ -347,6 +347,8 @@ def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]: description="Test explicitly specified lora-target-modules.", base="meta-llama/Llama-3.1-8B-Instruct", max_loras_per_batch=3, + # Need to list all lora modules, or "all" might include lora modules without assigning lora weights + # lora_target_modules=["all"], lora_target_modules=[ "q_proj", "k_proj", From f09852ee6e68f31ece095247a8527b6b88016100 Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Tue, 9 Dec 2025 03:47:24 +0000 Subject: [PATCH 18/19] add nightly ci/cd: test_lora_hf_sgl_logprob_diff.py --- test/run_suite_nightly.py | 1 + test/srt/lora/test_lora_eviction.py | 1 + .../srt/lora/test_lora_hf_sgl_logprob_diff.py | 559 ++++++++++++++++++ 3 files changed, 561 insertions(+) create mode 100644 test/srt/lora/test_lora_hf_sgl_logprob_diff.py diff --git a/test/run_suite_nightly.py b/test/run_suite_nightly.py index 936c7f5d8a10..14c34b6fe750 100644 --- a/test/run_suite_nightly.py +++ b/test/run_suite_nightly.py @@ -14,6 +14,7 @@ TestFile("test_lora_eviction_policy.py", 200), TestFile("test_lora_openai_api.py", 30), TestFile("test_lora_openai_compatible.py", 150), + TestFile("test_lora_hf_sgl_logprob_diff.py", 300), TestFile("test_batch_invariant_ops.py", 10), TestFile("test_cpp_radix_cache.py", 60), TestFile("test_deepseek_v3_deterministic.py", 240), diff --git a/test/srt/lora/test_lora_eviction.py b/test/srt/lora/test_lora_eviction.py index 7881bc07bfdb..78cdd8282fe0 100644 --- a/test/srt/lora/test_lora_eviction.py +++ b/test/srt/lora/test_lora_eviction.py @@ -106,6 +106,7 @@ def _run_test( "o_proj", "gate_proj", "up_proj", + "down_proj", ], ) as srt_runner: adapter_sequence = lora_paths if not reverse else lora_paths[::-1] diff --git a/test/srt/lora/test_lora_hf_sgl_logprob_diff.py b/test/srt/lora/test_lora_hf_sgl_logprob_diff.py new file mode 100644 index 000000000000..b0975fa5d666 --- /dev/null +++ b/test/srt/lora/test_lora_hf_sgl_logprob_diff.py @@ -0,0 +1,559 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +""" +Test to compare log probabilities between HuggingFace+LoRA and SGLang+LoRA. + +This test: +1. Runs SGLang with LoRA and collects log probabilities +2. Runs HuggingFace with LoRA and collects log probabilities +3. Compares the differences (max and mean) between the two implementations +4. Uses unittest framework for easy integration with test suites + +Usage: + python test_lora_hf_sgl_logprob_diff.py + or + python -m unittest test_lora_hf_sgl_logprob_diff +""" + +import multiprocessing as mp +import os +import sys +import unittest +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch + +# Add sglang to path if needed +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../python")) + +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.runners import HFRunner, SRTRunner + +register_cuda_ci(est_time=300, suite="nightly-1-gpu", nightly=True) + +from sglang.test.test_utils import ( + DEFAULT_PORT_FOR_SRT_TEST_RUNNER, + CustomTestCase, + is_in_ci, +) + +# Test configuration constants +LORA_BACKEND = "triton" +DISABLE_CUDA_GRAPH = False +LORA_TARGET_MODULES = None +LOGPROB_THRESHOLD = 1e-01 + +# Default test prompts +DEFAULT_TEST_PROMPTS = [ + "SGL is a", + "AI is a field of computer science focused on", + "Computer science is the study of", + "Write a short story.", + "What are the main components of a computer?", +] + +# Formatting constants +DIVIDER_WIDTH = 80 +SECTION_CHAR = "=" +SUBSECTION_CHAR = "-" + + +def print_section_header(title: str): + """Print a major section header.""" + print("\n" + SECTION_CHAR * DIVIDER_WIDTH) + print(title) + print(SECTION_CHAR * DIVIDER_WIDTH) + + +def print_subsection_header(title: str): + """Print a subsection header.""" + print(f"\n{SUBSECTION_CHAR * 40}") + print(f"{title}") + print(SUBSECTION_CHAR * 40) + + +def print_config_info(title: str, config: Dict[str, Any]): + """Print configuration information in a consistent format.""" + print_section_header(title) + for key, value in config.items(): + print(f" {key}: {value}") + + +def compare_logprobs_for_type( + sglang_logprobs: torch.Tensor, hf_logprobs: torch.Tensor, logprob_type: str +) -> Dict[str, Any]: + """ + Compare logprobs for a specific type (prefill or decode). + + Args: + sglang_logprobs: SGLang log probabilities + hf_logprobs: HuggingFace log probabilities + logprob_type: Type of logprobs ("prefill" or "decode") + + Returns: + Dictionary containing comparison statistics + """ + diff = torch.abs(sglang_logprobs - hf_logprobs) + max_diff = torch.max(diff).item() + mean_diff = torch.mean(diff).item() + shape = list(sglang_logprobs.shape) + matches_threshold = max_diff < LOGPROB_THRESHOLD + + return { + "max_diff": max_diff, + "mean_diff": mean_diff, + "shape": shape, + "matches_threshold": matches_threshold, + "type": logprob_type, + } + + +def print_logprob_comparison(comparison: Dict[str, Any]): + """Print logprob comparison results in a consistent format.""" + logprob_type = comparison["type"].capitalize() + print(f"\n{logprob_type} logprobs:") + print(f" Shape: {comparison['shape']}") + print(f" Max difference: {comparison['max_diff']:.6e}") + print(f" Mean difference: {comparison['mean_diff']:.6e}") + + status = "PASS" if comparison["matches_threshold"] else "FAIL" + print(f" Status: {status} (threshold: {LOGPROB_THRESHOLD:.0e})") + + +def compare_output_strings( + sglang_output: str, hf_output: str, max_display_len: int = 200 +) -> Dict[str, Any]: + """ + Compare output strings between SGLang and HuggingFace. + + Args: + sglang_output: SGLang generated text + hf_output: HuggingFace generated text + max_display_len: Maximum length for display + + Returns: + Dictionary containing comparison results + """ + outputs_match = sglang_output.strip() == hf_output.strip() + + # Truncate for display if needed + sglang_display = ( + sglang_output[:max_display_len] + if len(sglang_output) > max_display_len + else sglang_output + ) + hf_display = ( + hf_output[:max_display_len] if len(hf_output) > max_display_len else hf_output + ) + + return { + "match": outputs_match, + "sglang_output": sglang_output, + "hf_output": hf_output, + "sglang_display": sglang_display, + "hf_display": hf_display, + } + + +def print_output_comparison(comparison: Dict[str, Any]): + """Print output string comparison in a consistent format.""" + print(f"\nOutput strings:") + status = "MATCH" if comparison["match"] else "DIFFER" + print(f" Status: {status}") + print(f" SGLang: {comparison['sglang_display']}") + print(f" HuggingFace: {comparison['hf_display']}") + + +def prepare_lora_paths_per_prompt( + lora_paths: List[str], num_prompts: int +) -> List[Optional[str]]: + """ + Prepare LoRA paths for each prompt by cycling through available LoRAs. + + Args: + lora_paths: List of available LoRA adapter paths + num_prompts: Number of prompts to generate LoRA paths for + + Returns: + List of LoRA paths (one per prompt), or None values if no LoRAs + """ + if not lora_paths: + return [None] * num_prompts + + return [lora_paths[i % len(lora_paths)] for i in range(num_prompts)] + + +def run_sglang_with_lora( + model_path: str, + lora_paths: List[str], + prompts: List[str], + max_new_tokens: int, + torch_dtype: torch.dtype, + lora_backend: str, + port: int, + disable_cuda_graph: bool, + lora_target_modules: Optional[List[str]], + tp_size: int, +) -> Dict[str, Any]: + """Run SGLang with LoRA and return log probabilities.""" + config = { + "Model": model_path, + "LoRA paths": lora_paths, + "LoRA backend": lora_backend, + "Disable CUDA graph": disable_cuda_graph, + "Port": port, + "Number of prompts": len(prompts), + "Tensor parallel size": tp_size, + } + print_config_info("Running SGLang with LoRA", config) + + lora_paths_per_prompt = prepare_lora_paths_per_prompt(lora_paths, len(prompts)) + + with SRTRunner( + model_path, + torch_dtype=torch_dtype, + model_type="generation", + tp_size=tp_size, + lora_paths=lora_paths, + max_loras_per_batch=len(lora_paths) if lora_paths else 1, + lora_backend=lora_backend, + disable_cuda_graph=disable_cuda_graph, + disable_radix_cache=True, + port=port, + mem_fraction_static=0.88, + lora_target_modules=lora_target_modules, + ) as srt_runner: + srt_outputs = srt_runner.forward( + prompts, + max_new_tokens=max_new_tokens, + lora_paths=lora_paths_per_prompt, + ) + + return { + "top_input_logprobs": srt_outputs.top_input_logprobs, + "top_output_logprobs": srt_outputs.top_output_logprobs, + "output_strs": srt_outputs.output_strs, + "lora_paths": lora_paths_per_prompt, + } + + +def run_hf_with_lora( + model_path: str, + lora_paths: List[str], + prompts: List[str], + max_new_tokens: int, + torch_dtype: torch.dtype, +) -> Dict[str, Any]: + """Run HuggingFace with LoRA and return log probabilities.""" + config = { + "Model": model_path, + "LoRA paths": lora_paths, + "Number of prompts": len(prompts), + } + print_config_info("Running HuggingFace with LoRA", config) + + lora_paths_per_prompt = prepare_lora_paths_per_prompt(lora_paths, len(prompts)) + + with HFRunner( + model_path, + torch_dtype=torch_dtype, + model_type="generation", + patch_model_do_sample_false=True, + ) as hf_runner: + hf_outputs = hf_runner.forward( + prompts, + max_new_tokens=max_new_tokens, + lora_paths=lora_paths_per_prompt, + ) + + return { + "top_input_logprobs": hf_outputs.top_input_logprobs, + "top_output_logprobs": hf_outputs.top_output_logprobs, + "output_strs": hf_outputs.output_strs, + "lora_paths": lora_paths_per_prompt, + } + + +def compare_single_prompt( + prompt_idx: int, + sglang_data: Dict[str, Any], + hf_data: Dict[str, Any], +) -> Dict[str, Any]: + """ + Compare logprobs and outputs for a single prompt. + + Args: + prompt_idx: Index of the prompt being compared + sglang_data: SGLang results data + hf_data: HuggingFace results data + + Returns: + Dictionary containing all comparison results + """ + print_subsection_header(f"Prompt {prompt_idx + 1}") + print(f"LoRA adapter: {sglang_data['lora_paths'][prompt_idx]}") + + result = { + "prompt_idx": prompt_idx, + "lora_path": sglang_data["lora_paths"][prompt_idx], + } + + # Compare prefill (input) logprobs + sglang_prefill = torch.tensor(sglang_data["top_input_logprobs"][prompt_idx]) + hf_prefill = torch.tensor(hf_data["top_input_logprobs"][prompt_idx]) + prefill_comparison = compare_logprobs_for_type( + sglang_prefill, hf_prefill, "prefill" + ) + print_logprob_comparison(prefill_comparison) + + # Store prefill results + result["prefill_max_diff"] = prefill_comparison["max_diff"] + result["prefill_mean_diff"] = prefill_comparison["mean_diff"] + result["prefill_shape"] = prefill_comparison["shape"] + result["prefill_logprob_match"] = prefill_comparison["matches_threshold"] + + # Compare decode (output) logprobs + sglang_decode = torch.tensor(sglang_data["top_output_logprobs"][prompt_idx]) + hf_decode = torch.tensor(hf_data["top_output_logprobs"][prompt_idx]) + decode_comparison = compare_logprobs_for_type(sglang_decode, hf_decode, "decode") + print_logprob_comparison(decode_comparison) + + # Store decode results + result["decode_max_diff"] = decode_comparison["max_diff"] + result["decode_mean_diff"] = decode_comparison["mean_diff"] + result["decode_shape"] = decode_comparison["shape"] + result["decode_logprob_match"] = decode_comparison["matches_threshold"] + + # Overall logprob match + result["overall_logprob_match"] = ( + prefill_comparison["matches_threshold"] + and decode_comparison["matches_threshold"] + ) + + # Compare output strings + sglang_output = sglang_data["output_strs"][prompt_idx] + hf_output = hf_data["output_strs"][prompt_idx] + output_comparison = compare_output_strings(sglang_output, hf_output) + print_output_comparison(output_comparison) + + # Store output results + result["outputs_match"] = output_comparison["match"] + result["sglang_output"] = output_comparison["sglang_output"] + result["hf_output"] = output_comparison["hf_output"] + + return result + + +def print_overall_statistics(results: List[Dict[str, Any]]): + """Print overall statistics across all prompts.""" + print_section_header("Overall Statistics") + + # Gather statistics + prefill_max_diffs = [r["prefill_max_diff"] for r in results] + prefill_mean_diffs = [r["prefill_mean_diff"] for r in results] + decode_max_diffs = [r["decode_max_diff"] for r in results] + decode_mean_diffs = [r["decode_mean_diff"] for r in results] + + # Print logprob statistics + print("\nLogprob Differences:") + print(f" Prefill:") + print(f" Max of max: {max(prefill_max_diffs):.6e}") + print(f" Mean of max: {np.mean(prefill_max_diffs):.6e}") + print(f" Mean of mean: {np.mean(prefill_mean_diffs):.6e}") + + print(f" Decode:") + print(f" Max of max: {max(decode_max_diffs):.6e}") + print(f" Mean of max: {np.mean(decode_max_diffs):.6e}") + print(f" Mean of mean: {np.mean(decode_mean_diffs):.6e}") + + # Print match statistics + num_prompts = len(results) + logprob_match_count = sum(r["overall_logprob_match"] for r in results) + prefill_match_count = sum(r["prefill_logprob_match"] for r in results) + decode_match_count = sum(r["decode_logprob_match"] for r in results) + outputs_match_count = sum(r["outputs_match"] for r in results) + + print(f"\nLogprob Statistics (threshold: {LOGPROB_THRESHOLD:.0e}):") + overall_status = "PASSED" if logprob_match_count == num_prompts else "FAILED" + print(f" Overall logprob: {logprob_match_count}/{num_prompts} {overall_status}") + print(f" Prefill logprob: {prefill_match_count}/{num_prompts}") + print(f" Decode logprob: {decode_match_count}/{num_prompts}") + + print(f"\nString Statistics:") + print(f" Output strings: {outputs_match_count}/{num_prompts}") + + # Return overall stats for saving + return { + "logprob_differences": { + "prefill": { + "max_of_max_diffs": max(prefill_max_diffs), + "mean_of_max_diffs": float(np.mean(prefill_max_diffs)), + "mean_of_mean_diffs": float(np.mean(prefill_mean_diffs)), + }, + "decode": { + "max_of_max_diffs": max(decode_max_diffs), + "mean_of_max_diffs": float(np.mean(decode_max_diffs)), + "mean_of_mean_diffs": float(np.mean(decode_mean_diffs)), + }, + }, + "match_statistics": { + "overall_logprob_match_rate": logprob_match_count / num_prompts, + "prefill_logprob_match_rate": prefill_match_count / num_prompts, + "decode_logprob_match_rate": decode_match_count / num_prompts, + "outputs_match_rate": outputs_match_count / num_prompts, + }, + } + + +def compare_logprobs( + sglang_logprobs: Dict[str, Any], hf_logprobs: Dict[str, Any] +) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: + """Compare log probabilities and compute statistics.""" + print_section_header("Comparing Log Probabilities") + + results = [] + num_prompts = len(sglang_logprobs["top_input_logprobs"]) + + for i in range(num_prompts): + result = compare_single_prompt(i, sglang_logprobs, hf_logprobs) + results.append(result) + + overall_stats = print_overall_statistics(results) + + return results, overall_stats + + +class TestLoRAHFSGLLogprobDifference(CustomTestCase): + """ + Test case to compare log probabilities between HuggingFace+LoRA and SGLang+LoRA. + """ + + def _run_comparison_test( + self, + model_path: str, + lora_paths: List[str], + prompts: List[str], + max_new_tokens: int = 32, + torch_dtype: torch.dtype = torch.float16, + lora_backend: str = LORA_BACKEND, + port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER, + disable_cuda_graph: bool = DISABLE_CUDA_GRAPH, + lora_target_modules: Optional[List[str]] = LORA_TARGET_MODULES, + tp_size: int = 1, + ): + """ + Run comparison test between SGLang and HuggingFace with LoRA. + """ + print_section_header(f"Testing {model_path} with LoRA adapters") + + # Step 1: Run SGLang with LoRA + sglang_logprobs = run_sglang_with_lora( + model_path=model_path, + lora_paths=lora_paths, + prompts=prompts, + max_new_tokens=max_new_tokens, + torch_dtype=torch_dtype, + lora_backend=lora_backend, + port=port, + disable_cuda_graph=disable_cuda_graph, + lora_target_modules=lora_target_modules, + tp_size=tp_size, + ) + + # Clear GPU memory + print("\nClearing GPU memory...") + torch.cuda.empty_cache() + + # Step 2: Run HuggingFace with LoRA + hf_logprobs = run_hf_with_lora( + model_path=model_path, + lora_paths=lora_paths, + prompts=prompts, + max_new_tokens=max_new_tokens, + torch_dtype=torch_dtype, + ) + + # Step 3: Compare log probabilities + results, overall_stats = compare_logprobs(sglang_logprobs, hf_logprobs) + + # Assert that all prompts pass the threshold + for result in results: + self.assertTrue( + result["prefill_logprob_match"], + f"Prefill logprob mismatch for prompt {result['prompt_idx']} " + f"(max_diff={result['prefill_max_diff']:.6e}, threshold={LOGPROB_THRESHOLD:.0e})", + ) + self.assertTrue( + result["decode_logprob_match"], + f"Decode logprob mismatch for prompt {result['prompt_idx']} " + f"(max_diff={result['decode_max_diff']:.6e}, threshold={LOGPROB_THRESHOLD:.0e})", + ) + + print_section_header("Test completed successfully!") + + return results, overall_stats + + def test_lora_logprob_comparison_basic(self): + """ + Basic test comparing HF and SGLang LoRA logprobs with small model. + """ + # Use a smaller model and shorter prompts for CI + if is_in_ci(): + self.skipTest("Skipping in CI environment - requires large models") + + model_path = "meta-llama/Llama-2-7b-hf" + lora_paths = ["yushengsu/sglang_lora_logprob_diff_without_tuning"] + prompts = DEFAULT_TEST_PROMPTS[:2] # Use fewer prompts for faster testing + + self._run_comparison_test( + model_path=model_path, + lora_paths=lora_paths, + prompts=prompts, + max_new_tokens=32, + ) + + def test_lora_logprob_comparison_full(self): + """ + Full test comparing HF and SGLang LoRA logprobs with all prompts. + """ + if is_in_ci(): + self.skipTest("Skipping in CI environment - requires large models") + + model_path = "meta-llama/Llama-2-7b-hf" + lora_paths = ["yushengsu/sglang_lora_logprob_diff_without_tuning"] + prompts = DEFAULT_TEST_PROMPTS + + self._run_comparison_test( + model_path=model_path, + lora_paths=lora_paths, + prompts=prompts, + max_new_tokens=32, + ) + + +if __name__ == "__main__": + try: + mp.set_start_method("spawn") + except RuntimeError: + pass + + try: + unittest.main(warnings="ignore", verbosity=2) + finally: + # Final cleanup + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() From 8d8ba2956b6ac34e91ba64be721ca928b99ea8d5 Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Tue, 9 Dec 2025 04:38:11 +0000 Subject: [PATCH 19/19] update nightly ci/cd --- test/srt/run_suite.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index b7094d9ef63b..1ca880e86a79 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -204,6 +204,7 @@ # Nightly test suites have been moved to test/run_suite_nightly.py "__not_in_ci__": [ TestFile("test_release_memory_occupation.py", 200), # Temporarily disabled + TestFile("lora/test_lora_hf_sgl_logprob_diff.py"), # Nightly test TestFile("models/test_dummy_grok_models.py"), TestFile( "rl/test_update_weights_from_disk.py"