diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 67adcd572a1b..f0d1aabbb44f 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -837,7 +837,10 @@ def _get_logits( ) dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata) - if hasattr(lm_head, "weight"): + if hasattr(lm_head, "set_lora") and hasattr(lm_head, "apply_lora"): + # This is a LoRA-wrapped module, use its forward method + logits = lm_head(hidden_states) + elif hasattr(lm_head, "weight"): if self.use_fp32_lm_head: logits = torch.matmul( hidden_states.to(torch.float32), lm_head.weight.to(torch.float32).T diff --git a/python/sglang/srt/lora/backend/base_backend.py b/python/sglang/srt/lora/backend/base_backend.py index 77654c4b2d32..06e4e8ba5a95 100644 --- a/python/sglang/srt/lora/backend/base_backend.py +++ b/python/sglang/srt/lora/backend/base_backend.py @@ -19,6 +19,52 @@ def __init__(self, max_loras_per_batch: int, device: torch.device): self.max_loras_per_batch = max_loras_per_batch self.device = device + def run_lora_a_embedding( + self, + input_ids: torch.Tensor, + weights: torch.Tensor, + vocab_size: int, + extra_embeddings: torch.Tensor = None, + *args, + **kwargs, + ) -> torch.Tensor: + """Run LoRA A embedding lookup with CUDA graph support. + + Args: + input_ids: token IDs with shape (s,), where s is the sum of all sequence lengths + weights: LoRA A embedding weights with shape (num_loras, rank, vocab_size) + vocab_size: base vocabulary size (tokens >= vocab_size are extra tokens) + extra_embeddings: extra token embeddings with shape (num_loras, num_extra_tokens, rank) + Only needed if there are added tokens beyond base vocabulary. 
+ + Returns: + result with shape (s, rank) + """ + pass + + def run_extra_token_embedding( + self, + input_ids: torch.Tensor, + output: torch.Tensor, + extra_embeddings: torch.Tensor, + vocab_size: int, + *args, + **kwargs, + ) -> torch.Tensor: + """ + Apply extra token embeddings to output in-place. + + Args: + input_ids: (s,) token IDs + output: (s, embed_dim) output tensor to be modified + extra_embeddings: (num_loras, num_extra_tokens, embed_dim) extra embeddings + vocab_size: base vocabulary size + + Returns: + output: modified output tensor + """ + raise NotImplementedError + def run_lora_a_sgemm( self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs ) -> torch.Tensor: diff --git a/python/sglang/srt/lora/backend/triton_backend.py b/python/sglang/srt/lora/backend/triton_backend.py index 1c2e319dd397..ad79199fd27b 100644 --- a/python/sglang/srt/lora/backend/triton_backend.py +++ b/python/sglang/srt/lora/backend/triton_backend.py @@ -2,6 +2,7 @@ from sglang.srt.lora.backend.base_backend import BaseLoRABackend from sglang.srt.lora.triton_ops import ( + embedding_lora_a_fwd, gate_up_lora_b_fwd, qkv_lora_b_fwd, sgemm_lora_a_fwd, @@ -22,6 +23,24 @@ def __init__( ): super().__init__(max_loras_per_batch, device) + def run_lora_a_embedding( + self, + input_ids: torch.Tensor, + weights: torch.Tensor, + vocab_size: int, + extra_embeddings: torch.Tensor = None, + *args, + **kwargs, + ) -> torch.Tensor: + """Run LoRA A embedding lookup using Triton kernel.""" + return embedding_lora_a_fwd( + input_ids=input_ids, + weights=weights, + batch_info=self.batch_info, + vocab_size=vocab_size, + extra_embeddings=extra_embeddings, + ) + def run_lora_a_sgemm( self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs ) -> torch.Tensor: @@ -107,7 +126,7 @@ def init_cuda_graph_batch_info( seg_lens=torch.full( (max_bs_in_cuda_graph,), num_tokens_per_bs, dtype=torch.int32 ), - seg_indptr=torch.empty(max_bs_in_cuda_graph + 1, dtype=torch.int32), + 
seg_indptr=torch.zeros(max_bs_in_cuda_graph + 1, dtype=torch.int32), max_len=num_tokens_per_bs, weight_indices=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32), lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32), @@ -161,7 +180,7 @@ def prepare_lora_batch( seg_lens = ( forward_batch.extend_seq_lens if forward_batch.forward_mode.is_extend() - else torch.ones(bs, device=self.device) + else torch.ones(bs, dtype=torch.int32, device=self.device) ) seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device) seg_indptr[1:] = torch.cumsum(seg_lens, dim=0) diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index 139d97cbca31..498ab113c6ce 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -1,4 +1,7 @@ +from typing import Optional + import torch +import torch.nn.functional as F from torch import nn from sglang.srt.distributed import ( @@ -13,8 +16,12 @@ QKVParallelLinear, RowParallelLinear, ) -from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) from sglang.srt.lora.backend.base_backend import BaseLoRABackend +from sglang.srt.lora.utils import LoRABatchInfo class BaseLayerWithLoRA(nn.Module): @@ -45,11 +52,10 @@ def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): """ - Vocab parallel embedding layer with support for LoRA (Low-Rank Adaptation). + Vocab parallel embedding layer with LoRA support (simplified for TP=1, no extra tokens). - Note: The current version does not yet implement the LoRA functionality. - This class behaves exactly the same as the base VocabParallelEmbedding. - Future versions will integrate LoRA functionality to support efficient parameter fine-tuning. 
+ For embedding layers: output = base_embedding(x) + lora_B @ lora_A[x] + where lora_A[x] is direct embedding lookup from lora_A weights. """ def __init__( @@ -59,6 +65,237 @@ def __init__( ) -> None: super().__init__(base_layer, lora_backend) self.weight = base_layer.weight + self.embed_dim = base_layer.embedding_dim + self.vocab_size = base_layer.org_vocab_size + + self.output_offset = torch.tensor( + [0, self.embed_dim], + dtype=torch.int32, + device=next(base_layer.parameters()).device, + ) + + def set_lora_info( + self, + new_embeddings_buffer: Optional[torch.Tensor], # For extra tokens + embedding_A_buffer: torch.Tensor, + embedding_B_buffer: torch.Tensor, + ): + """Set LoRA buffers for embedding layer.""" + self.set_lora = True + self.new_embeddings_buffer = new_embeddings_buffer + self.embedding_A_buffer = embedding_A_buffer # (num_loras, rank, vocab_size) + self.embedding_B_buffer = embedding_B_buffer # (num_loras, embed_dim, rank) + + def apply_lora( + self, base_output: torch.Tensor, input_: torch.Tensor, batch_info + ) -> torch.Tensor: + """ + Apply LoRA to base embedding output. + Formula: output = base_output + lora_B @ lora_A_embedding(input_) + """ + + # Efficient embedding lookup for LoRA A (already support extra token embedding process) + lora_a_output = self.run_lora_a_embedding(input_, batch_info) + + # Apply LoRA B weights using backend + lora_output = self.lora_backend.run_lora_b_sgemm( + x=lora_a_output, + weights=self.embedding_B_buffer, + output_offset=self.output_offset, + base_output=base_output, + ) + return lora_output + + def run_lora_a_embedding( + self, input_: torch.Tensor, batch_info: LoRABatchInfo + ) -> torch.Tensor: + """ + Apply LoRA A weights using efficient embedding lookup with CUDA graph support. + Maps tokens to their corresponding LoRA adapters internally. + It also includes added/extra token processing. 
+ """ + # Efficient embedding lookup for LoRA A (already support extra token embedding process) + lora_a_output = self.lora_backend.run_lora_a_embedding( + input_ids=input_, + weights=self.embedding_A_buffer, + vocab_size=self.vocab_size, + extra_embeddings=( + self.new_embeddings_buffer + if hasattr(self, "new_embeddings_buffer") + and self.new_embeddings_buffer is not None + else None + ), + ) + + return lora_a_output + + def extra_token_embedding( + self, input_: torch.Tensor, base_output: torch.Tensor + ) -> torch.Tensor: + """ + Need to impl: + + Process extra tokens (tokens >= vocab_size) by looking up their embeddings + from the new_embeddings_buffer and replacing them in base_output. + + Args: + input_: (s,) token IDs + base_output: (s, embed_dim) base embedding output to be modified in-place + + Returns: + base_output: (s, embed_dim) modified input base_output (tensor[0,0,0,...]) with extra token embeddings + """ + # return base_output + raise NotImplementedError( + "Error in sglang/python/sglang/srt/lora/layers.py - VocabParallelEmbeddingWithLoRA \n" + "Current SGLang codebase did not support tuned lora with extra/added tokens. \n" + "[TODO]: \n" + "1. Refer to this commit: https://github.com/yushengsu-thu/sglang/commit/90415211eee8a28a316de262583d4d33fa615d10#diff-191177438bcc223837963de63c005850371f8c8a860acb153b26744b66ecc623 to complete \n" + "2. And then you need to modified the en/decoder tokenizer - tokenizer_manager.py to support extra_token_embedding in-place. \n" + ) + + def forward(self, input_: torch.Tensor): + """ + Forward pass with LoRA support and CUDA graph compatibility. + + Extra tokens (tokens >= vocab_size) are now handled efficiently + in the backend's run_lora_a_embedding method. 
+ """ + batch_info = self.lora_backend.batch_info + + # Get base embedding output + # For tokens >= vocab_size, base_layer will clamp or handle them + # We mask them to 0 to avoid out-of-bounds access + added_tokens_mask = input_ > self.vocab_size - 1 + base_output = self.base_layer.forward(input_.masked_fill(added_tokens_mask, 0)) + + # [TODO] SGLang did not support extra/added token process; thus, self.extra_token_embedding only return original input_ now + # Extra tokens - It will replace extra token embedding with self.new_embeddings_buffer's emb (Default is 0) + if ( + hasattr(self, "new_embeddings_buffer") + and self.new_embeddings_buffer is not None + ): + base_output = self.extra_token_embedding(input_, base_output) + + # Apply LoRA if configured + if self.set_lora: + # The backend's run_lora_a_embedding now handles both regular + # and extra tokens efficiently with CUDA graph support + base_output = self.apply_lora(base_output, input_, batch_info) + + return base_output + + def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): + # For TP=1, no slicing needed + # LoRA A weights (rank, vocab_size) are not sliced for embedding + # For TP>1, Need to modify code in: sglang/python/sglang/srt/lora/mem_pool.py + # return A + if tp_rank > 1: + raise NotImplementedError( + f"VocabParallelEmbeddingWithLoRA does not support tensor parallelism > 1. " + f"Got tp_size={tp_rank}" + ) + + def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): + # For TP=1, no slicing needed + # LoRA B weights (embedding_dim, rank) would be sliced along embedding dimension for TP>1 + # For TP>1, Need to modify code in: sglang/python/sglang/srt/lora/mem_pool.py + # return B + if tp_rank > 1: + raise NotImplementedError( + f"VocabParallelEmbeddingWithLoRA does not support tensor parallelism > 1. " + f"Got tp_size={tp_rank}" + ) + + +class ParallelLMHeadWithLoRA(BaseLayerWithLoRA): + """ + Parallel LM Head layer with LoRA support (simplified for TP=1). 
+ + The LM head computes logits = hidden_states @ (W + B @ A)^T + """ + + def __init__( + self, + base_layer: ParallelLMHead, + lora_backend: BaseLoRABackend, + ) -> None: + super().__init__(base_layer, lora_backend) + self.weight = base_layer.weight + self.embed_dim = base_layer.embedding_dim + self.vocab_size = base_layer.org_vocab_size + self.output_offset = torch.tensor( + [0, self.vocab_size], + dtype=torch.int32, + device=next(base_layer.parameters()).device, + ) + + def set_lora_info( + self, + lm_head_A_buffer: torch.Tensor, + lm_head_B_buffer: torch.Tensor, + ): + """Set LoRA buffers for LM head layer.""" + self.set_lora = True + self.lm_head_A_buffer = lm_head_A_buffer # (num_loras, rank, hidden_dim) + self.lm_head_B_buffer = lm_head_B_buffer # (num_loras, vocab_size, rank) + + def apply_lora( + self, base_output: torch.Tensor, hidden_states: torch.Tensor + ) -> torch.Tensor: + """ + Apply LoRA to LM head layer. + + For LM head: output = hidden @ (W + B @ A)^T + = hidden @ W^T + hidden @ A^T @ B^T + = base_output + (hidden @ A^T) @ B^T + """ + # Apply lora_A^T: hidden_states @ A^T + lora_a_output = self.lora_backend.run_lora_a_sgemm( + hidden_states, self.lm_head_A_buffer + ) + + # Apply lora_B^T: lora_a_output @ B^T + lora_output = self.lora_backend.run_lora_b_sgemm( + x=lora_a_output, + weights=self.lm_head_B_buffer, + output_offset=self.output_offset, + base_output=base_output, + ) + + return lora_output + + def forward(self, hidden_states: torch.Tensor): + # Apply base linear transformation + base_output = F.linear( + hidden_states, self.weight, bias=getattr(self.base_layer, "bias", None) + ) + + # Apply LoRA if set + if self.set_lora: + base_output = self.apply_lora(base_output, hidden_states) + + return base_output + + def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): + # For TP=1, no slicing needed + # For TP>1, need to modify code in: sglang/python/sglang/srt/lora/mem_pool.py + # return A + if tp_rank > 1: + raise NotImplementedError( 
+ f"ParallelLMHeadWithLoRA does not support tensor parallelism > 1. " + f"Got tp_size={tp_rank}" + ) + + def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): + # For TP=1, no slicing needed + # For TP>1, would slice along vocab dimension, need to modify code in: sglang/python/sglang/srt/lora/mem_pool.py + # return B + if tp_rank > 1: + raise NotImplementedError( + f"ParallelLMHeadWithLoRA does not support tensor parallelism > 1. " + f"Got tp_size={tp_rank}" + ) class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): @@ -224,6 +461,7 @@ def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor output_offset_cpu=self.output_offset_cpu, max_qkv_out_dim=self.max_qkv_out_dim, ) + return lora_output def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): @@ -343,6 +581,7 @@ def get_lora_layer( ) -> BaseLayerWithLoRA: supported_layer_types = { # the order matters + ParallelLMHead: ParallelLMHeadWithLoRA, VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA, QKVParallelLinear: QKVParallelLinearWithLoRA, MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA, diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py index 995aca6e5e36..12c813baeb20 100644 --- a/python/sglang/srt/lora/lora.py +++ b/python/sglang/srt/lora/lora.py @@ -71,11 +71,22 @@ def __init__( ] ) + self.embedding_layers: Dict[str, torch.Tensor] = {} + self.added_tokens_embeddings: Dict[str, torch.Tensor] = {} + # initialize the LoRA weights to cpu def initialize_weights(self): model_path = self.config.path loader = DefaultModelLoader(self.load_config) revision = getattr(self.config.hf_config, "revision", None) + + # Get normalized target modules for filtering + from sglang.srt.lora.utils import get_normalized_target_modules + + normalized_target_modules = get_normalized_target_modules( + self.config.target_modules + ) + for name, loaded_weight in loader._get_weights_iterator( DefaultModelLoader.Source( model_path, revision=revision, 
fall_back_to_pt=True @@ -84,6 +95,22 @@ def initialize_weights(self): layer_id = get_layer_id(name) if layer_id is not None: self.layers[layer_id].weights[name] = loaded_weight.cpu() + elif "embed_tokens" in name or "lm_head" in name: + # Check if this module is declared in target_modules before loading + module_name = "embed_tokens" if "embed_tokens" in name else "lm_head" + if module_name in normalized_target_modules: + self.embedding_layers[name] = loaded_weight.cpu() + else: + logger.debug( + f"Skipping {name} as '{module_name}' is not in adapter's target_modules: {self.config.target_modules}" + ) + elif "input_embeddings" in name or "output_embeddings" in name: + # added/extra token emb + self.added_tokens_embeddings[name] = loaded_weight.cpu() + assert loaded_weight.shape[0] == self.config.lora_added_tokens_size, ( + f"LoRA adapter {self.uid} has extra_vocab_size {self.config.extra_vocab_size} specified in the config, " + f"but the loaded weight has {loaded_weight.shape[0]} extra vocab size" + ) # normalize kv_proj and gate_up_proj for layer in self.layers: diff --git a/python/sglang/srt/lora/lora_config.py b/python/sglang/srt/lora/lora_config.py index 185b7b8246ee..a5cc80fab979 100644 --- a/python/sglang/srt/lora/lora_config.py +++ b/python/sglang/srt/lora/lora_config.py @@ -27,13 +27,14 @@ def __init__( self.hf_config = self.get_lora_config() self.target_modules = self.hf_config["target_modules"] - # TODO: Support more modules - if any(module in self.target_modules for module in ["embed_tokens", "lm_head"]): - raise ValueError("Not supported yet") - self.r = self.hf_config["r"] self.lora_alpha = self.hf_config["lora_alpha"] + self.added_tokens_config = self.get_added_tokens_config() + self.lora_added_tokens_size = ( + len(self.added_tokens_config) if self.added_tokens_config is not None else 0 + ) + def get_lora_config(self, dummy=False): if dummy: raise NotImplementedError() @@ -45,3 +46,30 @@ def get_lora_config(self, dummy=False): config_name = 
"adapter_config.json" with open(os.path.join(weights_dir, config_name), "r") as f: return json.load(f) + + def get_added_tokens_config(self): + """Load added tokens from the LoRA adapter if the file exists.""" + # Determine the weights directory + if not os.path.isdir(self.path): + weights_dir = snapshot_download(self.path, allow_patterns=["*.json"]) + else: + weights_dir = self.path + + # Construct the path to added_tokens.json + added_tokens_path = os.path.join(weights_dir, "added_tokens.json") + + # Return None if the file doesn't exist (optional for standard LoRA adapters) + if not os.path.exists(added_tokens_path): + return None + + # Load and return the added tokens + try: + with open(added_tokens_path, "r") as f: + return json.load(f) + except json.JSONDecodeError as e: + # Log warning but don't crash if JSON is malformed + import logging + + logger = logging.getLogger(__name__) + logger.warning(f"Failed to parse added_tokens.json: {e}") + return None diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index fa75fe003921..6bd05dee3db1 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -22,6 +22,10 @@ from sglang.srt.configs.load_config import LoadConfig from sglang.srt.layers.utils import get_layer_id +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) from sglang.srt.lora.backend.base_backend import BaseLoRABackend from sglang.srt.lora.backend.lora_registry import get_backend_from_name from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer @@ -67,6 +71,7 @@ def __init__( self.device: torch.device = next(self.base_model.parameters()).device self.tp_size: int = tp_size self.tp_rank: int = tp_rank + self.lora_added_tokens_size: Optional[int] = None # Store eviction policy from server args self.eviction_policy = server_args.lora_eviction_policy @@ -244,6 +249,8 @@ def prepare_lora_batch(self, forward_batch: 
ForwardBatch): lora_adapters=self.loras, lora_modules=self.lora_modules, lora_refs=self.lora_refs.copy(), # copy snapshot of current lora_refs to avoid mutation during the batch preparation. + lora_embed_tokens_module=self.embed_tokens_module, # merge into embedding or lora module + lora_lm_head_module=self.lm_head_module, # merge into embedding or lora module ) # set up batch info shared by all lora modules @@ -296,6 +303,21 @@ def update_lora_info(self): ), ) + # Update embedding layer if present - gotta merge (refer to PR codebase) + if self.embed_tokens_module is not None: + self.embed_tokens_module.set_lora_info( + self.memory_pool.get_embedding_tensor("added_tokens", LoRAType.LORA_A), + self.memory_pool.get_embedding_tensor("embed_tokens", LoRAType.LORA_A), + self.memory_pool.get_embedding_tensor("embed_tokens", LoRAType.LORA_B), + ) + + # Update lm_head layer if present + if self.lm_head_module is not None: + self.lm_head_module.set_lora_info( + self.memory_pool.get_embedding_tensor("lm_head", LoRAType.LORA_A), + self.memory_pool.get_embedding_tensor("lm_head", LoRAType.LORA_B), + ) + def init_state( self, max_lora_rank: Optional[int] = None, @@ -390,6 +412,24 @@ def init_lora_shapes( default=0, ) + # Auto-infer self.lora_added_vocab_size from loaded LoRA configs + # This happens automatically without requiring user input + # if self.lora_added_vocab_size is None: + if self.lora_added_tokens_size is None: + inferred_extra_vocab_size = next( + ( + x.lora_added_tokens_size + for x in self.configs.values() + if x.lora_added_tokens_size > 0 + ), + 0, + ) + if inferred_extra_vocab_size > 0: + logger.info( + f"self.lora_added_tokens_size={inferred_extra_vocab_size} from LoRA adapters." + ) + self.lora_added_tokens_size = inferred_extra_vocab_size + def load_lora_weights(self, lora_ref: LoRARef): """ Load the weights of a LoRA adapter to CPU memory and conducts post-loading validation. 
@@ -416,6 +456,7 @@ def init_memory_pool(self): target_modules=self.target_modules, base_model=self.base_model, eviction_policy=self.eviction_policy, + lora_added_tokens_size=self.lora_added_tokens_size, ) def set_lora_module(self, module_name, module): @@ -429,6 +470,9 @@ def init_lora_modules(self): {} for _ in range(self.base_hf_config.num_hidden_layers) ] + self.embed_tokens_module: Optional[BaseLayerWithLoRA] = None + self.lm_head_module: Optional[BaseLayerWithLoRA] = None + for module_name, module in self.base_model.named_modules(): # TODO (lifuhuang): in the future, we should consider generalizing the # should_apply_lora function to support mapping by full module name instead @@ -440,6 +484,24 @@ def init_lora_modules(self): ) and not self.base_model.should_apply_lora(module_name): continue + # Handle embed_tokens + if "embed_tokens" in module_name and "embed_tokens" in self.target_modules: + if isinstance(module, VocabParallelEmbedding) and not isinstance( + module, BaseLayerWithLoRA + ): + lora_module = self.set_lora_module(module_name, module) + self.embed_tokens_module = lora_module + continue + + # Handle lm_head + if "lm_head" in module_name and "lm_head" in self.target_modules: + if isinstance(module, ParallelLMHead) and not isinstance( + module, BaseLayerWithLoRA + ): + lora_module = self.set_lora_module(module_name, module) + self.lm_head_module = lora_module + continue + # The module should be converted if it is included in target_names if module_name.split(".")[-1] in self.target_modules: layer_id = get_layer_id(module_name) diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index f6375361700e..fdebb860c626 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -10,6 +10,7 @@ from sglang.srt.lora.lora_config import LoRAConfig from sglang.srt.lora.lora_registry import LoRARef from sglang.srt.lora.utils import ( + EMBEDDING_NAMES, ROW_PARALLELISM_LINEAR_LORA_NAMES, LoRAType, 
get_hidden_dim, @@ -56,6 +57,7 @@ def __init__( target_modules: Set[str], base_model: torch.nn.Module, eviction_policy: str, + lora_added_tokens_size: int, ): self.base_hf_config: AutoConfig = base_hf_config self.num_layer: int = base_hf_config.num_hidden_layers @@ -63,6 +65,7 @@ def __init__( self.dtype: torch.dtype = dtype self.tp_size: int = tp_size self.tp_rank: int = tp_rank + self.lora_added_tokens_size: int = lora_added_tokens_size self.max_lora_rank: int = max_lora_rank self.target_modules: Set[str] = target_modules @@ -77,6 +80,15 @@ def __init__( self.A_buffer: Dict[str, List[torch.Tensor]] = {} self.B_buffer: Dict[str, List[torch.Tensor]] = {} + self.embedding_A_buffer: Dict[str, torch.Tensor] = {} + self.embedding_B_buffer: Dict[str, torch.Tensor] = {} + + self.lm_head_A_buffer: Dict[str, torch.Tensor] = {} + self.lm_head_B_buffer: Dict[str, torch.Tensor] = {} + self.new_embeddings_buffer: Dict[str, torch.Tensor] = {} + + self.embedding_dim: int = self.base_hf_config.hidden_size + # Lora uid -> buffer idx in memory pool self.uid_to_buffer_id: Dict[Optional[str], int] = {} @@ -100,6 +112,8 @@ def _can_support(config: LoRAConfig) -> bool: """ if config.r > self.max_lora_rank: return False + if config.lora_added_tokens_size > self.lora_added_tokens_size: + return False target_module_names = get_normalized_target_modules(config.target_modules) return target_module_names.issubset(self.target_modules) @@ -130,6 +144,23 @@ def get_lora_A_shape( input_dim, ) + def get_embedding_lora_A_shape( + self, + module_name: str, + base_model: torch.nn.Module, + max_lora_dim: int, + layer_idx: int, + ) -> Tuple[int]: + input_dim, _ = get_hidden_dim( + module_name, self.base_hf_config, base_model, 0, self.lora_added_tokens_size + ) + # Have not imp self.tp_size > 1 yet. 
+ return ( + self.max_loras_per_batch, + max_lora_dim, + input_dim, + ) + def get_lora_B_shape( self, module_name: str, @@ -151,6 +182,23 @@ def get_lora_B_shape( max_lora_dim, ) + def get_embedding_lora_B_shape( + self, + module_name: str, + base_model: torch.nn.Module, + max_lora_dim: int, + layer_idx: int, + ) -> Tuple[int]: + _, output_dim = get_hidden_dim( + module_name, self.base_hf_config, base_model, 0, self.lora_added_tokens_size + ) + # Have not imp self.tp_size > 1 yet. + return ( + self.max_loras_per_batch, + output_dim, + max_lora_dim, + ) + def init_buffers(self, base_model: torch.nn.Module): device = next(base_model.parameters()).device @@ -159,6 +207,7 @@ def init_buffer( target_modules: Set[str], get_lora_shape_fn: Callable[[str, torch.nn.Module, int, int], Tuple[int]], ): + target_modules = target_modules - set(EMBEDDING_NAMES) for module_name in target_modules: buffer[module_name] = [ torch.empty( @@ -174,6 +223,61 @@ def init_buffer( for idx in range(self.num_layer) ] + def init_embedding_buffer( + buffer: Dict[str, torch.Tensor], + target_modules: Set[str], + get_lora_shape_fn: Callable[[int], Tuple[int]], + ): + target_modules = target_modules & set(EMBEDDING_NAMES) + for module_name in target_modules: + buffer[module_name] = torch.empty( + get_lora_shape_fn( + module_name, + base_model, + self.max_lora_rank, + 0, + ), + dtype=self.dtype, + device=device, + ) + + if self.lora_added_tokens_size > 0: + self.new_embeddings_buffer["input_embeddings"] = torch.empty( + ( + self.max_loras_per_batch, + self.lora_added_tokens_size, + self.embedding_dim, + ), + dtype=self.dtype, + device=device, + ) + + if "embed_tokens" in self.target_modules: + init_embedding_buffer( + self.embedding_A_buffer, + self.target_modules, + self.get_embedding_lora_A_shape, + ) + + init_embedding_buffer( + self.embedding_B_buffer, + self.target_modules, + self.get_embedding_lora_B_shape, + ) + + if "lm_head" in self.target_modules: + init_embedding_buffer( + 
self.lm_head_A_buffer, + self.target_modules, + self.get_embedding_lora_A_shape, + ) + + init_embedding_buffer( + self.lm_head_B_buffer, + self.target_modules, + self.get_embedding_lora_B_shape, + ) + init_buffer( self.A_buffer, self.target_modules, @@ -192,6 +296,8 @@ def prepare_lora_batch( lora_adapters: Dict[str, LoRAAdapter], lora_modules: List[Dict[str, BaseLayerWithLoRA]], lora_refs: Dict[str, LoRARef], + lora_embed_tokens_module: Dict[str, BaseLayerWithLoRA], + lora_lm_head_module: Dict[str, BaseLayerWithLoRA], ): def get_available_buffer_slot(): # 1. Prioritize empty slots @@ -244,7 +350,12 @@ def get_available_buffer_slot(): buffer_id = get_available_buffer_slot() lora_adapter = lora_adapters.get(uid, None) self.load_lora_weight_to_buffer( - uid, buffer_id, lora_adapter, lora_modules + uid, + buffer_id, + lora_adapter, + lora_modules, + lora_embed_tokens_module, + lora_lm_head_module, ) self.uid_to_buffer_id[uid] = buffer_id self.buffer_id_to_uid[buffer_id] = uid @@ -255,6 +366,8 @@ def load_lora_weight_to_buffer( buffer_id: int, lora_adapter: LoRAAdapter, lora_modules: List[Dict[str, BaseLayerWithLoRA]], + lora_embed_tokens_module: Dict[str, BaseLayerWithLoRA], + lora_lm_head_module: Dict[str, BaseLayerWithLoRA], ): def load_lora_weight_tensor( buffer_view: torch.Tensor, weight: Optional[torch.Tensor] @@ -273,6 +386,12 @@ def load_lora_weight_tensor( for i in range(self.num_layer): for k in self.A_buffer.keys(): self.A_buffer[k][i][buffer_id] = 0 + + for k in self.embedding_A_buffer.keys(): + self.embedding_A_buffer[k][buffer_id] = 0 + + for k in self.lm_head_A_buffer.keys(): + self.lm_head_A_buffer[k][buffer_id] = 0 return assert lora_adapter is not None @@ -321,9 +440,126 @@ def load_lora_weight_tensor( buffer_view = target_buffer[buffer_id, :, :lora_rank] load_lora_weight_tensor(buffer_view, weights) + if lora_adapter.embedding_layers: + + org_vocab_size = self.base_hf_config.vocab_size + lora_added_tokens_size = 
lora_adapter.config.lora_added_tokens_size + # Only when LoRA is applied to the embedding layer will it have the extra-token issue that needs to be resolved. + # Load embeddings weights for extra tokens to buffer + if lora_adapter.added_tokens_embeddings: + for name, weights in lora_adapter.added_tokens_embeddings.items(): + if "input_embeddings" in name: + buffer_view = self.new_embeddings_buffer["input_embeddings"][ + buffer_id, :lora_added_tokens_size + ] + load_lora_weight_tensor(buffer_view, weights) + + # load vocab_emb and lm_head + for name, weights in lora_adapter.embedding_layers.items(): + target_module = get_target_module_name(name, self.target_modules) + if ( + target_module == "embed_tokens" + and "embed_tokens" in name + and ("lora_embedding_A" in name or "lora_A" in name) + ): + buffer_view = self.embedding_A_buffer[target_module][ + buffer_id, + :lora_rank, + : (org_vocab_size + lora_added_tokens_size), + ] + load_lora_weight_tensor(buffer_view, weights) + elif ( + target_module == "embed_tokens" + and "embed_tokens" in name + and ("lora_embedding_B" in name or "lora_B" in name) + ): + lora_b_weights = weights + # [to-do] support TP + # if self.tp_size > 1: + # cur_module = lora_embeddings_modules[target_module] + # for module_name, module in cur_module: + # lora_b_weights = module.slice_lora_b_weights( + # lora_b_weights, self.tp_rank + # ) + + buffer_view = self.embedding_B_buffer[target_module][ + buffer_id, :, :lora_rank + ] + load_lora_weight_tensor(buffer_view, lora_b_weights) + + elif ( + target_module == "lm_head" + and "lm_head" in name + and ("lora_embedding_A" in name or "lora_A" in name) + ): + buffer_view = self.lm_head_A_buffer[target_module][ + # buffer_id, :, :lora_rank + buffer_id, + :lora_rank, + :, + ] + load_lora_weight_tensor(buffer_view, weights) + elif ( + target_module == "lm_head" + and "lm_head" in name + and ("lora_embedding_B" in name or "lora_B" in name) + ): + lora_b_weights = weights + # [to-do] support TP + # if 
self.tp_size > 1: + # cur_module = lora_embeddings_modules[target_module] + # for module_name, module in cur_module: + # lora_b_weights = module.slice_lora_b_weights( + # lora_b_weights, self.tp_rank + # ) + + buffer_view = self.lm_head_B_buffer[target_module][ + # buffer_id, :lora_rank, : org_vocab_size + extra_vocab_size + buffer_id, + : (org_vocab_size + self.lora_added_tokens_size), + :lora_rank, + ] + load_lora_weight_tensor(buffer_view, lora_b_weights) + + def get_embedding_tensor( + self, target_module: str, lora_type: LoRAType + ) -> Optional[torch.Tensor]: + """ + Get LoRA tensor for non-layer modules (embed_tokens, lm_head). + + Args: + target_module: Module name, either "embed_tokens" or "lm_head" + lora_type: Either LoRAType.LORA_A or LoRAType.LORA_B + + Returns: + The corresponding buffer tensor, or None if not available + """ + + if target_module == "added_tokens": + if ( + self.lora_added_tokens_size is not None + and self.lora_added_tokens_size > 0 + ): + return self.new_embeddings_buffer["input_embeddings"] + return None + elif target_module == "embed_tokens": + if lora_type == LoRAType.LORA_A: + return self.embedding_A_buffer[target_module] + return self.embedding_B_buffer[target_module] + elif target_module == "lm_head": + if lora_type == LoRAType.LORA_A: + return self.lm_head_A_buffer[target_module] + return self.lm_head_B_buffer[target_module] + + raise ValueError( + f"Invalid target_module '{target_module}'. " + f"Expected 'embed_tokens' or 'lm_head'." 
+ ) + def get_tensor( self, target_module: str, layer_id: int, lora_type: LoRAType ) -> torch.Tensor: + if lora_type == LoRAType.LORA_A: return self.A_buffer[target_module][layer_id] diff --git a/python/sglang/srt/lora/triton_ops/__init__.py b/python/sglang/srt/lora/triton_ops/__init__.py index 74a2e84a2c40..71eb1fea4837 100644 --- a/python/sglang/srt/lora/triton_ops/__init__.py +++ b/python/sglang/srt/lora/triton_ops/__init__.py @@ -1,5 +1,6 @@ from .chunked_sgmv_expand import chunked_sgmv_lora_expand_forward from .chunked_sgmv_shrink import chunked_sgmv_lora_shrink_forward +from .embedding_lora_a import embedding_lora_a_fwd from .gate_up_lora_b import gate_up_lora_b_fwd from .qkv_lora_b import qkv_lora_b_fwd from .sgemm_lora_a import sgemm_lora_a_fwd @@ -12,4 +13,5 @@ "sgemm_lora_b_fwd", "chunked_sgmv_lora_shrink_forward", "chunked_sgmv_lora_expand_forward", + "embedding_lora_a_fwd", ] diff --git a/python/sglang/srt/lora/triton_ops/embedding_lora_a.py b/python/sglang/srt/lora/triton_ops/embedding_lora_a.py new file mode 100644 index 000000000000..1e21be50fd79 --- /dev/null +++ b/python/sglang/srt/lora/triton_ops/embedding_lora_a.py @@ -0,0 +1,186 @@ +import torch +import triton +import triton.language as tl + +from sglang.srt.lora.utils import LoRABatchInfo + + +@triton.jit +def _embedding_lora_a_kernel( + # Pointers to tensors + input_ids, + weights, + output, + extra_embeddings, + # Dimensions + vocab_size, + rank, + num_loras, + # Strides + w_stride_0, # stride for lora index + w_stride_1, # stride for rank + w_stride_2, # stride for vocab + output_stride_0, + output_stride_1, + extra_emb_stride_0, # stride for lora index + extra_emb_stride_1, # stride for token + extra_emb_stride_2, # stride for hidden dim (= rank for extra embeddings) + # Batch info + seg_lens, + seg_indptr, + weight_indices, + lora_ranks, + # Meta-parameters + BLOCK_RANK: tl.constexpr, + HAS_EXTRA_EMBEDDINGS: tl.constexpr, +): + """ + Embedding lookup for LoRA A weights with support for 
extra tokens. + + Each program handles one token across a block of rank dimensions. + Grid: (cdiv(max_len, 1), bs) - one program per token in each batch + """ + batch_id = tl.program_id(axis=1) + token_idx = tl.program_id(axis=0) + + w_index = tl.load(weight_indices + batch_id) + rank_val = tl.load(lora_ranks + w_index) + + # If rank is 0, skip + if rank_val == 0: + return + + seg_start = tl.load(seg_indptr + batch_id) + seg_len = tl.load(seg_lens + batch_id) + + # Check if this token is within the segment + if token_idx >= seg_len: + return + + # Load the token ID + token_id = tl.load(input_ids + seg_start + token_idx) + + # Process in chunks of BLOCK_RANK dimensions + num_blocks = tl.cdiv(rank_val, BLOCK_RANK) + + for block_id in range(num_blocks): + rank_offset = tl.arange(0, BLOCK_RANK) + block_id * BLOCK_RANK + rank_mask = rank_offset < rank_val + + # Check if this is an extra token + is_extra_token = token_id >= vocab_size + + if HAS_EXTRA_EMBEDDINGS and is_extra_token: + # Use extra embeddings + extra_token_id = token_id - vocab_size + extra_emb_ptr = ( + extra_embeddings + + w_index * extra_emb_stride_0 + + extra_token_id * extra_emb_stride_1 + + rank_offset * extra_emb_stride_2 + ) + emb_values = tl.load(extra_emb_ptr, mask=rank_mask, other=0.0) + else: + # Use regular LoRA A weights + # weights shape: (num_loras, rank, vocab_size) + # We need to load weights[w_index, rank_offset, token_id] + token_id_clamped = tl.minimum(token_id, vocab_size - 1) + weight_ptr = ( + weights + + w_index * w_stride_0 + + rank_offset * w_stride_1 + + token_id_clamped * w_stride_2 + ) + emb_values = tl.load(weight_ptr, mask=rank_mask, other=0.0) + + # Write to output + output_ptr = ( + output + + (seg_start + token_idx) * output_stride_0 + + rank_offset * output_stride_1 + ) + tl.store(output_ptr, emb_values, mask=rank_mask) + + +def embedding_lora_a_fwd( + input_ids: torch.Tensor, + weights: torch.Tensor, + batch_info: LoRABatchInfo, + vocab_size: int, + extra_embeddings: 
torch.Tensor = None, +) -> torch.Tensor: + """ + Forward pass for LoRA A embedding lookup. + + Args: + input_ids: (s,) token IDs + weights: (num_loras, rank, vocab_size) LoRA A embedding weights + batch_info: LoRABatchInfo containing batch information + vocab_size: base vocabulary size + extra_embeddings: (num_loras, num_extra_tokens, rank) extra token embeddings + + Returns: + output: (s, rank) embedded features + """ + assert input_ids.is_contiguous() + assert weights.is_contiguous() + assert len(input_ids.shape) == 1 + assert len(weights.shape) == 3 + + S = input_ids.shape[0] + num_loras = weights.shape[0] + rank = weights.shape[1] + vocab_size_weights = weights.shape[2] + + # Block size for rank dimension + BLOCK_RANK = 128 + + has_extra_embeddings = extra_embeddings is not None + + if has_extra_embeddings: + assert extra_embeddings.is_contiguous() + extra_emb_stride = ( + extra_embeddings.stride(0), + extra_embeddings.stride(1), + extra_embeddings.stride(2), + ) + else: + # Create dummy tensor to satisfy Triton + extra_embeddings = torch.empty( + (1, 1, 1), device=input_ids.device, dtype=weights.dtype + ) + extra_emb_stride = (1, 1, 1) + + # Grid: one program per token in each batch segment + grid = ( + batch_info.max_len, + batch_info.bs, + ) + + output = torch.zeros((S, rank), device=input_ids.device, dtype=weights.dtype) + + _embedding_lora_a_kernel[grid]( + input_ids, + weights, + output, + extra_embeddings, + vocab_size, + rank, + num_loras, + weights.stride(0), + weights.stride(1), + weights.stride(2), + output.stride(0), + output.stride(1), + extra_emb_stride[0], + extra_emb_stride[1], + extra_emb_stride[2], + batch_info.seg_lens, + batch_info.seg_indptr, + batch_info.weight_indices, + batch_info.lora_ranks, + BLOCK_RANK, + has_extra_embeddings, + ) + + return output diff --git a/python/sglang/srt/lora/utils.py b/python/sglang/srt/lora/utils.py index 48a450d9b468..b59c17aa522c 100644 --- a/python/sglang/srt/lora/utils.py +++ 
b/python/sglang/srt/lora/utils.py @@ -46,7 +46,11 @@ class LoRAType(Enum): def get_hidden_dim( - module_name: str, config: AutoConfig, base_model: torch.nn.Module, layer_idx: int + module_name: str, + config: AutoConfig, + base_model: torch.nn.Module, + layer_idx: int, + lora_added_vocab_size: int = 0, ) -> Tuple[int]: """ Given a module_name (might be a stacked name), return the hidden dims of modules' input and output. @@ -78,6 +82,14 @@ def get_hidden_dim( return config.hidden_size, config.intermediate_size * 2 elif module_name == "down_proj": return config.intermediate, config.hidden_size + elif module_name == "embed_tokens": + # For embedding: input is vocab_size (as embedding lookup), output is hidden_size + # lora_added_vocab_size accounts for any extra (added) tokens; it is 0 when there are none. + return config.vocab_size + lora_added_vocab_size, config.hidden_size + elif module_name == "lm_head": + # For lm_head: input is hidden_size, output is vocab_size + # lora_added_vocab_size accounts for any extra (added) tokens; it is 0 when there are none. 
+ return config.hidden_size, config.vocab_size + lora_added_vocab_size else: raise NotImplementedError() @@ -95,6 +107,12 @@ def get_normalized_target_modules( "v_proj": "qkv_proj", "gate_proj": "gate_up_proj", "up_proj": "gate_up_proj", + "embed_tokens": "embed_tokens", + "vocab_emb": "embed_tokens", + "embeddings": "embed_tokens", + "word_embeddings": "embed_tokens", + "lm_head": "lm_head", + "output": "lm_head", } result = set() @@ -131,4 +149,5 @@ def get_target_module_name(full_module_name: str, target_modules: Set[str]) -> s ) +EMBEDDING_NAMES = ["embed_tokens", "lm_head"] ROW_PARALLELISM_LINEAR_LORA_NAMES = ["o_proj", "down_proj"] diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index fd9a61ef2502..3aa7140cb9f1 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -3322,6 +3322,8 @@ def is_gfx95_supported(): "down_proj", "qkv_proj", "gate_up_proj", + "embed_tokens", + "lm_head", ] LORA_TARGET_ALL_MODULES = "all" diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index 6d31b5868a47..e9b152ae9614 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -436,6 +436,7 @@ def forward_generation_raw( ) else: model = base_model + if patch_model_do_sample_false: model.generation_config.do_sample = False outputs = model.generate( @@ -455,6 +456,7 @@ def forward_generation_raw( text = tokenizer.decode( outputs[0][0][len(input_ids[0]) :], skip_special_tokens=True ) + # Check if the text is empty or only whitespace. 
if not text.strip(): raise ValueError( diff --git a/test/run_suite_nightly.py b/test/run_suite_nightly.py index 936c7f5d8a10..14c34b6fe750 100644 --- a/test/run_suite_nightly.py +++ b/test/run_suite_nightly.py @@ -14,6 +14,7 @@ TestFile("test_lora_eviction_policy.py", 200), TestFile("test_lora_openai_api.py", 30), TestFile("test_lora_openai_compatible.py", 150), + TestFile("test_lora_hf_sgl_logprob_diff.py", 300), TestFile("test_batch_invariant_ops.py", 10), TestFile("test_cpp_radix_cache.py", 60), TestFile("test_deepseek_v3_deterministic.py", 240), diff --git a/test/srt/lora/test_lora_eviction.py b/test/srt/lora/test_lora_eviction.py index fc1e00e3d969..78cdd8282fe0 100644 --- a/test/srt/lora/test_lora_eviction.py +++ b/test/srt/lora/test_lora_eviction.py @@ -97,7 +97,17 @@ def _run_test( max_loras_per_batch=1, enable_lora=True, max_lora_rank=256, - lora_target_modules=["all"], + # Need to list all lora modules, or "all" might include lora modules without assigning lora weights + # lora_target_modules=["all"], + lora_target_modules=[ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ], ) as srt_runner: adapter_sequence = lora_paths if not reverse else lora_paths[::-1] diff --git a/test/srt/lora/test_lora_hf_sgl_logprob_diff.py b/test/srt/lora/test_lora_hf_sgl_logprob_diff.py new file mode 100644 index 000000000000..b0975fa5d666 --- /dev/null +++ b/test/srt/lora/test_lora_hf_sgl_logprob_diff.py @@ -0,0 +1,559 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +""" +Test to compare log probabilities between HuggingFace+LoRA and SGLang+LoRA. + +This test: +1. Runs SGLang with LoRA and collects log probabilities +2. Runs HuggingFace with LoRA and collects log probabilities +3. Compares the differences (max and mean) between the two implementations +4. Uses unittest framework for easy integration with test suites + +Usage: + python test_lora_hf_sgl_logprob_diff.py + or + python -m unittest test_lora_hf_sgl_logprob_diff +""" + +import multiprocessing as mp +import os +import sys +import unittest +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch + +# Add sglang to path if needed +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../python")) + +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.runners import HFRunner, SRTRunner + +register_cuda_ci(est_time=300, suite="nightly-1-gpu", nightly=True) + +from sglang.test.test_utils import ( + DEFAULT_PORT_FOR_SRT_TEST_RUNNER, + CustomTestCase, + is_in_ci, +) + +# Test configuration constants +LORA_BACKEND = "triton" +DISABLE_CUDA_GRAPH = False +LORA_TARGET_MODULES = None +LOGPROB_THRESHOLD = 1e-01 + +# Default test prompts +DEFAULT_TEST_PROMPTS = [ + "SGL is a", + "AI is a field of computer science focused on", + "Computer science is the study of", + "Write a short story.", + "What are the main components of a computer?", +] + +# Formatting constants +DIVIDER_WIDTH = 80 +SECTION_CHAR = "=" +SUBSECTION_CHAR = "-" + + +def print_section_header(title: str): + """Print a major section header.""" + print("\n" + SECTION_CHAR * DIVIDER_WIDTH) + print(title) + print(SECTION_CHAR * DIVIDER_WIDTH) + + +def print_subsection_header(title: str): + """Print a subsection header.""" + print(f"\n{SUBSECTION_CHAR * 40}") + print(f"{title}") 
+ print(SUBSECTION_CHAR * 40) + + +def print_config_info(title: str, config: Dict[str, Any]): + """Print configuration information in a consistent format.""" + print_section_header(title) + for key, value in config.items(): + print(f" {key}: {value}") + + +def compare_logprobs_for_type( + sglang_logprobs: torch.Tensor, hf_logprobs: torch.Tensor, logprob_type: str +) -> Dict[str, Any]: + """ + Compare logprobs for a specific type (prefill or decode). + + Args: + sglang_logprobs: SGLang log probabilities + hf_logprobs: HuggingFace log probabilities + logprob_type: Type of logprobs ("prefill" or "decode") + + Returns: + Dictionary containing comparison statistics + """ + diff = torch.abs(sglang_logprobs - hf_logprobs) + max_diff = torch.max(diff).item() + mean_diff = torch.mean(diff).item() + shape = list(sglang_logprobs.shape) + matches_threshold = max_diff < LOGPROB_THRESHOLD + + return { + "max_diff": max_diff, + "mean_diff": mean_diff, + "shape": shape, + "matches_threshold": matches_threshold, + "type": logprob_type, + } + + +def print_logprob_comparison(comparison: Dict[str, Any]): + """Print logprob comparison results in a consistent format.""" + logprob_type = comparison["type"].capitalize() + print(f"\n{logprob_type} logprobs:") + print(f" Shape: {comparison['shape']}") + print(f" Max difference: {comparison['max_diff']:.6e}") + print(f" Mean difference: {comparison['mean_diff']:.6e}") + + status = "PASS" if comparison["matches_threshold"] else "FAIL" + print(f" Status: {status} (threshold: {LOGPROB_THRESHOLD:.0e})") + + +def compare_output_strings( + sglang_output: str, hf_output: str, max_display_len: int = 200 +) -> Dict[str, Any]: + """ + Compare output strings between SGLang and HuggingFace. 
+ + Args: + sglang_output: SGLang generated text + hf_output: HuggingFace generated text + max_display_len: Maximum length for display + + Returns: + Dictionary containing comparison results + """ + outputs_match = sglang_output.strip() == hf_output.strip() + + # Truncate for display if needed + sglang_display = ( + sglang_output[:max_display_len] + if len(sglang_output) > max_display_len + else sglang_output + ) + hf_display = ( + hf_output[:max_display_len] if len(hf_output) > max_display_len else hf_output + ) + + return { + "match": outputs_match, + "sglang_output": sglang_output, + "hf_output": hf_output, + "sglang_display": sglang_display, + "hf_display": hf_display, + } + + +def print_output_comparison(comparison: Dict[str, Any]): + """Print output string comparison in a consistent format.""" + print(f"\nOutput strings:") + status = "MATCH" if comparison["match"] else "DIFFER" + print(f" Status: {status}") + print(f" SGLang: {comparison['sglang_display']}") + print(f" HuggingFace: {comparison['hf_display']}") + + +def prepare_lora_paths_per_prompt( + lora_paths: List[str], num_prompts: int +) -> List[Optional[str]]: + """ + Prepare LoRA paths for each prompt by cycling through available LoRAs. 
+ + Args: + lora_paths: List of available LoRA adapter paths + num_prompts: Number of prompts to generate LoRA paths for + + Returns: + List of LoRA paths (one per prompt), or None values if no LoRAs + """ + if not lora_paths: + return [None] * num_prompts + + return [lora_paths[i % len(lora_paths)] for i in range(num_prompts)] + + +def run_sglang_with_lora( + model_path: str, + lora_paths: List[str], + prompts: List[str], + max_new_tokens: int, + torch_dtype: torch.dtype, + lora_backend: str, + port: int, + disable_cuda_graph: bool, + lora_target_modules: Optional[List[str]], + tp_size: int, +) -> Dict[str, Any]: + """Run SGLang with LoRA and return log probabilities.""" + config = { + "Model": model_path, + "LoRA paths": lora_paths, + "LoRA backend": lora_backend, + "Disable CUDA graph": disable_cuda_graph, + "Port": port, + "Number of prompts": len(prompts), + "Tensor parallel size": tp_size, + } + print_config_info("Running SGLang with LoRA", config) + + lora_paths_per_prompt = prepare_lora_paths_per_prompt(lora_paths, len(prompts)) + + with SRTRunner( + model_path, + torch_dtype=torch_dtype, + model_type="generation", + tp_size=tp_size, + lora_paths=lora_paths, + max_loras_per_batch=len(lora_paths) if lora_paths else 1, + lora_backend=lora_backend, + disable_cuda_graph=disable_cuda_graph, + disable_radix_cache=True, + port=port, + mem_fraction_static=0.88, + lora_target_modules=lora_target_modules, + ) as srt_runner: + srt_outputs = srt_runner.forward( + prompts, + max_new_tokens=max_new_tokens, + lora_paths=lora_paths_per_prompt, + ) + + return { + "top_input_logprobs": srt_outputs.top_input_logprobs, + "top_output_logprobs": srt_outputs.top_output_logprobs, + "output_strs": srt_outputs.output_strs, + "lora_paths": lora_paths_per_prompt, + } + + +def run_hf_with_lora( + model_path: str, + lora_paths: List[str], + prompts: List[str], + max_new_tokens: int, + torch_dtype: torch.dtype, +) -> Dict[str, Any]: + """Run HuggingFace with LoRA and return log 
probabilities.""" + config = { + "Model": model_path, + "LoRA paths": lora_paths, + "Number of prompts": len(prompts), + } + print_config_info("Running HuggingFace with LoRA", config) + + lora_paths_per_prompt = prepare_lora_paths_per_prompt(lora_paths, len(prompts)) + + with HFRunner( + model_path, + torch_dtype=torch_dtype, + model_type="generation", + patch_model_do_sample_false=True, + ) as hf_runner: + hf_outputs = hf_runner.forward( + prompts, + max_new_tokens=max_new_tokens, + lora_paths=lora_paths_per_prompt, + ) + + return { + "top_input_logprobs": hf_outputs.top_input_logprobs, + "top_output_logprobs": hf_outputs.top_output_logprobs, + "output_strs": hf_outputs.output_strs, + "lora_paths": lora_paths_per_prompt, + } + + +def compare_single_prompt( + prompt_idx: int, + sglang_data: Dict[str, Any], + hf_data: Dict[str, Any], +) -> Dict[str, Any]: + """ + Compare logprobs and outputs for a single prompt. + + Args: + prompt_idx: Index of the prompt being compared + sglang_data: SGLang results data + hf_data: HuggingFace results data + + Returns: + Dictionary containing all comparison results + """ + print_subsection_header(f"Prompt {prompt_idx + 1}") + print(f"LoRA adapter: {sglang_data['lora_paths'][prompt_idx]}") + + result = { + "prompt_idx": prompt_idx, + "lora_path": sglang_data["lora_paths"][prompt_idx], + } + + # Compare prefill (input) logprobs + sglang_prefill = torch.tensor(sglang_data["top_input_logprobs"][prompt_idx]) + hf_prefill = torch.tensor(hf_data["top_input_logprobs"][prompt_idx]) + prefill_comparison = compare_logprobs_for_type( + sglang_prefill, hf_prefill, "prefill" + ) + print_logprob_comparison(prefill_comparison) + + # Store prefill results + result["prefill_max_diff"] = prefill_comparison["max_diff"] + result["prefill_mean_diff"] = prefill_comparison["mean_diff"] + result["prefill_shape"] = prefill_comparison["shape"] + result["prefill_logprob_match"] = prefill_comparison["matches_threshold"] + + # Compare decode (output) logprobs + 
sglang_decode = torch.tensor(sglang_data["top_output_logprobs"][prompt_idx]) + hf_decode = torch.tensor(hf_data["top_output_logprobs"][prompt_idx]) + decode_comparison = compare_logprobs_for_type(sglang_decode, hf_decode, "decode") + print_logprob_comparison(decode_comparison) + + # Store decode results + result["decode_max_diff"] = decode_comparison["max_diff"] + result["decode_mean_diff"] = decode_comparison["mean_diff"] + result["decode_shape"] = decode_comparison["shape"] + result["decode_logprob_match"] = decode_comparison["matches_threshold"] + + # Overall logprob match + result["overall_logprob_match"] = ( + prefill_comparison["matches_threshold"] + and decode_comparison["matches_threshold"] + ) + + # Compare output strings + sglang_output = sglang_data["output_strs"][prompt_idx] + hf_output = hf_data["output_strs"][prompt_idx] + output_comparison = compare_output_strings(sglang_output, hf_output) + print_output_comparison(output_comparison) + + # Store output results + result["outputs_match"] = output_comparison["match"] + result["sglang_output"] = output_comparison["sglang_output"] + result["hf_output"] = output_comparison["hf_output"] + + return result + + +def print_overall_statistics(results: List[Dict[str, Any]]): + """Print overall statistics across all prompts.""" + print_section_header("Overall Statistics") + + # Gather statistics + prefill_max_diffs = [r["prefill_max_diff"] for r in results] + prefill_mean_diffs = [r["prefill_mean_diff"] for r in results] + decode_max_diffs = [r["decode_max_diff"] for r in results] + decode_mean_diffs = [r["decode_mean_diff"] for r in results] + + # Print logprob statistics + print("\nLogprob Differences:") + print(f" Prefill:") + print(f" Max of max: {max(prefill_max_diffs):.6e}") + print(f" Mean of max: {np.mean(prefill_max_diffs):.6e}") + print(f" Mean of mean: {np.mean(prefill_mean_diffs):.6e}") + + print(f" Decode:") + print(f" Max of max: {max(decode_max_diffs):.6e}") + print(f" Mean of max: 
{np.mean(decode_max_diffs):.6e}") + print(f" Mean of mean: {np.mean(decode_mean_diffs):.6e}") + + # Print match statistics + num_prompts = len(results) + logprob_match_count = sum(r["overall_logprob_match"] for r in results) + prefill_match_count = sum(r["prefill_logprob_match"] for r in results) + decode_match_count = sum(r["decode_logprob_match"] for r in results) + outputs_match_count = sum(r["outputs_match"] for r in results) + + print(f"\nLogprob Statistics (threshold: {LOGPROB_THRESHOLD:.0e}):") + overall_status = "PASSED" if logprob_match_count == num_prompts else "FAILED" + print(f" Overall logprob: {logprob_match_count}/{num_prompts} {overall_status}") + print(f" Prefill logprob: {prefill_match_count}/{num_prompts}") + print(f" Decode logprob: {decode_match_count}/{num_prompts}") + + print(f"\nString Statistics:") + print(f" Output strings: {outputs_match_count}/{num_prompts}") + + # Return overall stats for saving + return { + "logprob_differences": { + "prefill": { + "max_of_max_diffs": max(prefill_max_diffs), + "mean_of_max_diffs": float(np.mean(prefill_max_diffs)), + "mean_of_mean_diffs": float(np.mean(prefill_mean_diffs)), + }, + "decode": { + "max_of_max_diffs": max(decode_max_diffs), + "mean_of_max_diffs": float(np.mean(decode_max_diffs)), + "mean_of_mean_diffs": float(np.mean(decode_mean_diffs)), + }, + }, + "match_statistics": { + "overall_logprob_match_rate": logprob_match_count / num_prompts, + "prefill_logprob_match_rate": prefill_match_count / num_prompts, + "decode_logprob_match_rate": decode_match_count / num_prompts, + "outputs_match_rate": outputs_match_count / num_prompts, + }, + } + + +def compare_logprobs( + sglang_logprobs: Dict[str, Any], hf_logprobs: Dict[str, Any] +) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: + """Compare log probabilities and compute statistics.""" + print_section_header("Comparing Log Probabilities") + + results = [] + num_prompts = len(sglang_logprobs["top_input_logprobs"]) + + for i in range(num_prompts): + 
result = compare_single_prompt(i, sglang_logprobs, hf_logprobs) + results.append(result) + + overall_stats = print_overall_statistics(results) + + return results, overall_stats + + +class TestLoRAHFSGLLogprobDifference(CustomTestCase): + """ + Test case to compare log probabilities between HuggingFace+LoRA and SGLang+LoRA. + """ + + def _run_comparison_test( + self, + model_path: str, + lora_paths: List[str], + prompts: List[str], + max_new_tokens: int = 32, + torch_dtype: torch.dtype = torch.float16, + lora_backend: str = LORA_BACKEND, + port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER, + disable_cuda_graph: bool = DISABLE_CUDA_GRAPH, + lora_target_modules: Optional[List[str]] = LORA_TARGET_MODULES, + tp_size: int = 1, + ): + """ + Run comparison test between SGLang and HuggingFace with LoRA. + """ + print_section_header(f"Testing {model_path} with LoRA adapters") + + # Step 1: Run SGLang with LoRA + sglang_logprobs = run_sglang_with_lora( + model_path=model_path, + lora_paths=lora_paths, + prompts=prompts, + max_new_tokens=max_new_tokens, + torch_dtype=torch_dtype, + lora_backend=lora_backend, + port=port, + disable_cuda_graph=disable_cuda_graph, + lora_target_modules=lora_target_modules, + tp_size=tp_size, + ) + + # Clear GPU memory + print("\nClearing GPU memory...") + torch.cuda.empty_cache() + + # Step 2: Run HuggingFace with LoRA + hf_logprobs = run_hf_with_lora( + model_path=model_path, + lora_paths=lora_paths, + prompts=prompts, + max_new_tokens=max_new_tokens, + torch_dtype=torch_dtype, + ) + + # Step 3: Compare log probabilities + results, overall_stats = compare_logprobs(sglang_logprobs, hf_logprobs) + + # Assert that all prompts pass the threshold + for result in results: + self.assertTrue( + result["prefill_logprob_match"], + f"Prefill logprob mismatch for prompt {result['prompt_idx']} " + f"(max_diff={result['prefill_max_diff']:.6e}, threshold={LOGPROB_THRESHOLD:.0e})", + ) + self.assertTrue( + result["decode_logprob_match"], + f"Decode logprob mismatch 
for prompt {result['prompt_idx']} " + f"(max_diff={result['decode_max_diff']:.6e}, threshold={LOGPROB_THRESHOLD:.0e})", + ) + + print_section_header("Test completed successfully!") + + return results, overall_stats + + def test_lora_logprob_comparison_basic(self): + """ + Basic test comparing HF and SGLang LoRA logprobs with small model. + """ + # Use a smaller model and shorter prompts for CI + if is_in_ci(): + self.skipTest("Skipping in CI environment - requires large models") + + model_path = "meta-llama/Llama-2-7b-hf" + lora_paths = ["yushengsu/sglang_lora_logprob_diff_without_tuning"] + prompts = DEFAULT_TEST_PROMPTS[:2] # Use fewer prompts for faster testing + + self._run_comparison_test( + model_path=model_path, + lora_paths=lora_paths, + prompts=prompts, + max_new_tokens=32, + ) + + def test_lora_logprob_comparison_full(self): + """ + Full test comparing HF and SGLang LoRA logprobs with all prompts. + """ + if is_in_ci(): + self.skipTest("Skipping in CI environment - requires large models") + + model_path = "meta-llama/Llama-2-7b-hf" + lora_paths = ["yushengsu/sglang_lora_logprob_diff_without_tuning"] + prompts = DEFAULT_TEST_PROMPTS + + self._run_comparison_test( + model_path=model_path, + lora_paths=lora_paths, + prompts=prompts, + max_new_tokens=32, + ) + + +if __name__ == "__main__": + try: + mp.set_start_method("spawn") + except RuntimeError: + pass + + try: + unittest.main(warnings="ignore", verbosity=2) + finally: + # Final cleanup + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() diff --git a/test/srt/lora/test_lora_update.py b/test/srt/lora/test_lora_update.py index 3f11bdd48d7d..9c3f0855033b 100644 --- a/test/srt/lora/test_lora_update.py +++ b/test/srt/lora/test_lora_update.py @@ -218,7 +218,17 @@ def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]: base="meta-llama/Llama-3.1-8B-Instruct", enable_lora=True, max_lora_rank=256, - lora_target_modules=["all"], + # Need to list all lora 
modules, or "all" might include lora modules without assigning lora weights + # lora_target_modules=["all"], + lora_target_modules=[ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ], max_loras_per_batch=4, all_adapters=[ "philschmid/code-llama-3-1-8b-text-to-sql-lora", @@ -337,6 +347,8 @@ def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]: description="Test explicitly specified lora-target-modules.", base="meta-llama/Llama-3.1-8B-Instruct", max_loras_per_batch=3, + # Need to list all lora modules, or "all" might include lora modules without assigning lora weights + # lora_target_modules=["all"], lora_target_modules=[ "q_proj", "k_proj", @@ -751,7 +763,17 @@ def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]: ], enable_lora=True, max_lora_rank=256, - lora_target_modules=["all"], + # Need to list all lora modules, or "all" might include lora modules without assigning lora weights + # lora_target_modules=["all"], + lora_target_modules=[ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ], op_sequence=[ Operation( type=OperationType.LOAD, @@ -1503,7 +1525,17 @@ def test_v1_models_endpoint_with_lora(self): lora_paths=[], max_loras_per_batch=2, max_lora_rank=256, - lora_target_modules=["all"], + # Need to list all lora modules, or "all" might include lora modules without assigning lora weights + # lora_target_modules=["all"], + lora_target_modules=[ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ], enable_lora=True, ) as session: # Test with no adapters loaded diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 278135c7c4b3..591e3ca604a0 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -201,6 +201,7 @@ # Nightly test suites have been moved to test/run_suite_nightly.py "__not_in_ci__": [ TestFile("test_release_memory_occupation.py", 200), # Temporarily disabled + 
TestFile("lora/test_lora_hf_sgl_logprob_diff.py"), # Nightly test TestFile("models/test_dummy_grok_models.py"), TestFile( "rl/test_update_weights_from_disk.py"