diff --git a/python/sglang/srt/lora/backend/base_backend.py b/python/sglang/srt/lora/backend/base_backend.py index 06e4e8ba5a95..fcbc66c05515 100644 --- a/python/sglang/srt/lora/backend/base_backend.py +++ b/python/sglang/srt/lora/backend/base_backend.py @@ -144,6 +144,7 @@ def init_cuda_graph_batch_info( self, max_bs_in_cuda_graph: int, num_tokens_per_bs: int, + has_embedding_layers: bool = False, ): """Initialize the batch info for CUDA Graph mode. @@ -151,9 +152,9 @@ def init_cuda_graph_batch_info( logic for CUDA Graph mode. Args: - cuda_graph_batch_info: the LoRABatchInfo object created in LoraManager max_bs_in_cuda_graph: maximum batch size for CUDA Graph mode num_tokens_per_bs: number of tokens per sequence (1 for decoding, >1 for target_verify) + has_embedding_layers: whether target_modules includes embedding layers (embed_tokens/lm_head) """ pass diff --git a/python/sglang/srt/lora/backend/chunked_backend.py b/python/sglang/srt/lora/backend/chunked_backend.py index f17f473cbdfd..7e1d18a08413 100644 --- a/python/sglang/srt/lora/backend/chunked_backend.py +++ b/python/sglang/srt/lora/backend/chunked_backend.py @@ -4,6 +4,7 @@ from sglang.srt.lora.triton_ops import ( chunked_sgmv_lora_expand_forward, chunked_sgmv_lora_shrink_forward, + embedding_lora_a_fwd, ) from sglang.srt.lora.utils import LoRABatchInfo from sglang.srt.model_executor.forward_batch_info import ForwardBatch @@ -32,6 +33,7 @@ def __init__( ): super().__init__(max_loras_per_batch, device) self.max_chunk_size = server_args.max_lora_chunk_size + self.has_embedding_layers = False # Will be set by manager def run_lora_a_sgemm( self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs @@ -64,6 +66,28 @@ def run_lora_b_sgemm( base_output=base_output, ) + def run_lora_a_embedding( + self, + input_ids: torch.Tensor, + weights: torch.Tensor, + vocab_size: int, + extra_embeddings: torch.Tensor = None, + *args, + **kwargs, + ) -> torch.Tensor: + """Run LoRA A embedding lookup using Triton kernel. + + Uses embedding_batch_info which maintains original sequence structure + (not the chunked/reordered structure used for linear layers). + """ + return embedding_lora_a_fwd( + input_ids=input_ids, + weights=weights, + batch_info=self.embedding_batch_info, + vocab_size=vocab_size, + extra_embeddings=extra_embeddings, + ) + def run_qkv_lora( self, x: torch.Tensor, @@ -162,7 +186,9 @@ def init_cuda_graph_batch_info( self, max_bs_in_cuda_graph: int, num_tokens_per_bs: int, + has_embedding_layers: bool = False, ): + self.has_embedding_layers = has_embedding_layers max_num_segments = ( (num_tokens_per_bs + MIN_CHUNK_SIZE - 1) // MIN_CHUNK_SIZE ) * max_bs_in_cuda_graph @@ -181,6 +207,36 @@ def init_cuda_graph_batch_info( max_len=None, # Not used in CSGMV backend ) + # TODO: The embedding_batch_info will be removed after the chunked kernel + # for embedding has been implemented. This is currently a workaround to + # make embedding run with the non-chunked Triton kernel. 
+ if has_embedding_layers: + # Create embedding-specific batch info (uses original sequence structure) + self.cuda_graph_embedding_batch_info = LoRABatchInfo( + bs=max_bs_in_cuda_graph, + use_cuda_graph=True, + num_segments=max_bs_in_cuda_graph, + seg_lens=torch.full( + (max_bs_in_cuda_graph,), num_tokens_per_bs, dtype=torch.int32 + ), + seg_indptr=torch.zeros(max_bs_in_cuda_graph + 1, dtype=torch.int32), + max_len=num_tokens_per_bs, + weight_indices=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32), + lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32), + scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float), + permutation=None, + ) + # Initialize seg_indptr for embedding CUDA graph + torch.cumsum( + self.cuda_graph_embedding_batch_info.seg_lens[ + :max_bs_in_cuda_graph + ], + dim=0, + out=self.cuda_graph_embedding_batch_info.seg_indptr[ + 1 : max_bs_in_cuda_graph + 1 + ], + ) + def prepare_lora_batch( self, forward_batch: ForwardBatch, @@ -254,6 +310,69 @@ def prepare_lora_batch( self.batch_info = batch_info + # Setup embedding_batch_info (uses original sequence structure, not chunked) + # Only needed when target_modules includes embedding layers (embed_tokens/lm_head) + if self.has_embedding_layers: + bs = forward_batch.batch_size + if use_cuda_graph: + embedding_batch_info = self.cuda_graph_embedding_batch_info + embedding_batch_info.bs = bs + embedding_batch_info.num_segments = bs + else: + emb_max_len = ( + max(forward_batch.extend_seq_lens_cpu) + if forward_batch.forward_mode.is_extend() + else 1 + ) + emb_seg_lens = ( + forward_batch.extend_seq_lens + if forward_batch.forward_mode.is_extend() + else torch.ones(bs, dtype=torch.int32, device=self.device) + ) + emb_seg_indptr = torch.zeros( + (bs + 1,), dtype=torch.int32, device=self.device + ) + emb_seg_indptr[1:] = torch.cumsum(emb_seg_lens, dim=0) + + embedding_batch_info = LoRABatchInfo( + bs=bs, + num_segments=bs, + max_len=emb_max_len, + use_cuda_graph=False, + seg_lens=emb_seg_lens, + seg_indptr=emb_seg_indptr, + weight_indices=torch.empty( + (bs,), dtype=torch.int32, device=self.device + ), + lora_ranks=torch.empty( + (self.max_loras_per_batch,), + dtype=torch.int32, + device=self.device, + ), + scalings=torch.empty( + (self.max_loras_per_batch,), + dtype=torch.float, + device=self.device, + ), + permutation=None, + ) + + # Copy common data to embedding_batch_info (reuse already-created tensors) + weight_indices_for_embedding = torch.tensor( + weight_indices, dtype=torch.int32, pin_memory=True, device="cpu" + ) + embedding_batch_info.lora_ranks[: self.max_loras_per_batch].copy_( + lora_ranks_tensor, non_blocking=True + ) + embedding_batch_info.scalings[: self.max_loras_per_batch].copy_( + scalings_tensor, non_blocking=True + ) + embedding_batch_info.weight_indices[:bs].copy_( + weight_indices_for_embedding, non_blocking=True + ) + + self.embedding_batch_info = embedding_batch_info + @staticmethod def _get_permutation(seq_weight_indices, forward_batch: ForwardBatch): """ diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 6bd05dee3db1..22e3b648b713 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -96,9 +96,16 @@ def init_cuda_graph_batch_info( self, max_bs_in_cuda_graph: int, num_tokens_per_bs: int ): self.max_bs_in_cuda_graph = max_bs_in_cuda_graph + + # Check if target_modules includes embedding layers + has_embedding_layers = ( + "embed_tokens" in self.target_modules or "lm_head" in 
self.target_modules + ) + self.lora_backend.init_cuda_graph_batch_info( max_bs_in_cuda_graph=max_bs_in_cuda_graph, num_tokens_per_bs=num_tokens_per_bs, + has_embedding_layers=has_embedding_layers, ) def create_lora_update_result( diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index bc17db6f89e0..ae6731988521 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -4462,17 +4462,6 @@ def check_lora_server_args(self): ), "If 'all' is specified in --lora-target-modules, it should be the only module specified." self.lora_target_modules = set(SUPPORTED_LORA_TARGET_MODULES) - # When using the chunked SGMV backend, skip embedding / lm_head layers for now, - # since it does not support these yet (TODO: implement embedding / lm_head support) - if self.lora_backend == "csgmv": - logger.warning( - "LoRA backend 'csgmv' does not yet support embedding or lm_head layers; " - "dropping 'embed_tokens' and 'lm_head' from --lora-target-modules=all. " - "To apply LoRA to these, use --lora-backend triton." - ) - self.lora_target_modules.discard("embed_tokens") - self.lora_target_modules.discard("lm_head") - # Ensure sufficient information is provided for LoRA initialization. assert self.lora_paths or ( self.max_lora_rank and self.lora_target_modules diff --git a/scripts/playground/lora/train_embedding_lora_adapter.py b/scripts/playground/lora/train_embedding_lora_adapter.py new file mode 100644 index 000000000000..27f19868c570 --- /dev/null +++ b/scripts/playground/lora/train_embedding_lora_adapter.py @@ -0,0 +1,221 @@ +import argparse +import os + +import torch +from datasets import load_dataset +from peft import LoraConfig, get_peft_model +from transformers import AutoModelForCausalLM, AutoTokenizer +from trl import SFTConfig, SFTTrainer + +TARGET_MODULES = [ + "embed_tokens", + "lm_head", + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", +] + + +def format_alpaca_prompt(example): + """Format alpaca dataset example into a prompt.""" + if example.get("input") and example["input"].strip(): + return f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:\n{example['output']}" + return f"### Instruction:\n{example['instruction']}\n\n### Response:\n{example['output']}" + + +def train_embedding_lora_adapter( + base_model: str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + output_dir: str = "./sglang_embedding_lora_test_adapter", + num_train_steps: int = 500, + rank: int = 8, + lora_alpha: int = 16, +): + print(f"Training embedding LoRA adapter") + print(f" Base model: {base_model}") + print(f" Output dir: {output_dir}") + print(f" Training steps: {num_train_steps}") + print(f" Rank: {rank}, Alpha: {lora_alpha}") + print(f" Target modules: {TARGET_MODULES}") + print() + + print("Loading base model...") + model = AutoModelForCausalLM.from_pretrained( + base_model, + torch_dtype=torch.float16, + device_map="auto", + ) + tokenizer = AutoTokenizer.from_pretrained(base_model) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + tokenizer.pad_token_id = tokenizer.eos_token_id + + # embed_tokens must be in target_modules (not modules_to_save) for SGLang's + # run_lora_a_embedding() to work correctly + lora_config = LoraConfig( + r=rank, + lora_alpha=lora_alpha, + target_modules=TARGET_MODULES, + lora_dropout=0.0, + bias="none", + task_type="CAUSAL_LM", + ) + + print("Applying LoRA...") + model = get_peft_model(model, lora_config) + 
model.print_trainable_parameters() + + print("\nLoading alpaca dataset...") + dataset = load_dataset("tatsu-lab/alpaca", split="train[:2000]") + dataset = dataset.map( + lambda x: {"text": format_alpaca_prompt(x)}, + remove_columns=dataset.column_names, + ) + + print(f"Dataset size: {len(dataset)}") + print(f"Example:\n{dataset[0]['text'][:200]}...") + + sft_config = SFTConfig( + output_dir=os.path.join(output_dir, "checkpoints"), + max_steps=num_train_steps, + per_device_train_batch_size=4, + learning_rate=2e-4, + lr_scheduler_type="cosine", + warmup_steps=50, + logging_steps=25, + save_steps=num_train_steps, + save_total_limit=1, + fp16=True, + report_to="none", + max_length=512, + dataset_text_field="text", + ) + + print("\nStarting training...") + trainer = SFTTrainer( + model=model, + train_dataset=dataset, + args=sft_config, + ) + trainer.train() + + print(f"\nSaving adapter to {output_dir}...") + model.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) + + readme = generate_readme(base_model, rank, lora_alpha, num_train_steps, model) + with open(os.path.join(output_dir, "README.md"), "w") as f: + f.write(readme) + + print(f"\nDone! Adapter saved to: {output_dir}") + print("\nTo upload to HuggingFace:") + print( + f" huggingface-cli upload YOUR_USERNAME/sglang_embedding_lora_test_adapter {output_dir}" + ) + + print("\nTesting generation...") + test_generation(model, tokenizer) + + +def test_generation(model, tokenizer): + """Test that the model produces coherent outputs.""" + prompts = [ + "### Instruction:\nWhat is the capital of France?\n\n### Response:\n", + "### Instruction:\nWrite a short greeting.\n\n### Response:\n", + ] + + model.eval() + for prompt in prompts: + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=50, + do_sample=False, + pad_token_id=tokenizer.pad_token_id, + ) + response = tokenizer.decode(outputs[0], skip_special_tokens=True) + print(f"\nPrompt: {prompt[:50]}...") + print(f"Response: {response[len(prompt):][:100]}...") + + +def generate_readme(base_model, rank, lora_alpha, num_train_steps, peft_model): + """Generate README for the adapter.""" + weight_shapes = [] + for name, param in peft_model.named_parameters(): + if "lora_" in name and param.requires_grad: + parts = name.split(".") + for i, part in enumerate(parts): + if part in TARGET_MODULES: + layer_name = part + lora_type = parts[i + 1] + weight_shapes.append( + f"{layer_name}.{lora_type}: {tuple(param.shape)}" + ) + break + + weight_shapes = sorted(set(weight_shapes)) + + return f"""# Trained LoRA Adapter for SGLang Embedding LoRA Testing + +This is a fine-tuned LoRA adapter for testing SGLang's embedding LoRA implementation. + +## Configuration + +- **Base model:** `{base_model}` +- **LoRA rank (r):** {rank} +- **LoRA alpha:** {lora_alpha} +- **Target modules:** {", ".join(TARGET_MODULES)} +- **Training steps:** {num_train_steps} +- **Training data:** alpaca dataset + +## Weight Shapes + +``` +{chr(10).join(weight_shapes)} +``` + +## Purpose + +This adapter tests that SGLang's `ChunkedSgmvLoRABackend.run_lora_a_embedding()` correctly +handles embedding LoRA layers (`embed_tokens`, `lm_head`). + +**Key:** `embed_tokens` is in `target_modules` (LoRA decomposition), NOT `modules_to_save` (full weights). 
+ +## Usage with SGLang + +```python +# Used by: test/srt/lora/test_lora_hf_sgl_logprob_diff.py +``` + +## Created with + +```bash +python scripts/playground/lora/train_embedding_lora_adapter.py --num_train_steps {num_train_steps} +``` +""" + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--base_model", type=str, default="TinyLlama/TinyLlama-1.1B-Chat-v1.0" + ) + parser.add_argument( + "--output_dir", type=str, default="./sglang_embedding_lora_test_adapter" + ) + parser.add_argument("--num_train_steps", type=int, default=500) + parser.add_argument("--rank", type=int, default=8) + parser.add_argument("--lora_alpha", type=int, default=16) + + args = parser.parse_args() + train_embedding_lora_adapter( + base_model=args.base_model, + output_dir=args.output_dir, + num_train_steps=args.num_train_steps, + rank=args.rank, + lora_alpha=args.lora_alpha, + ) diff --git a/test/nightly/test_chunked_lora_embedding.py b/test/nightly/test_chunked_lora_embedding.py new file mode 100644 index 000000000000..57d6aef2b5d5 --- /dev/null +++ b/test/nightly/test_chunked_lora_embedding.py @@ -0,0 +1,96 @@ +"""Test ChunkedSgmvLoRABackend.run_lora_a_embedding() method.""" + +import unittest +from unittest.mock import MagicMock + +import torch + +from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend +from sglang.srt.lora.triton_ops import embedding_lora_a_fwd +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.test.test_utils import CustomTestCase + + +class TestChunkedLoRAEmbedding(CustomTestCase): + """Test embedding LoRA for ChunkedSgmvLoRABackend (requires CUDA).""" + + @unittest.skipUnless(torch.cuda.is_available(), "CUDA not available") + def test_run_lora_a_embedding(self): + """Test that run_lora_a_embedding uses embedding_batch_info correctly.""" + device = torch.device("cuda") + vocab_size = 1024 + rank = 16 + num_loras = 2 + batch_size = 2 + seq_len = 4 + + # Create mock server_args + server_args = MagicMock() + server_args.max_lora_chunk_size = 128 + + # Create backend + backend = ChunkedSgmvLoRABackend( + max_loras_per_batch=num_loras, + device=device, + server_args=server_args, + ) + + # Create mock forward_batch + forward_batch = ForwardBatch( + forward_mode=ForwardMode.EXTEND, + batch_size=batch_size, + input_ids=torch.randint( + 0, vocab_size, (batch_size, 3), dtype=torch.int32, device=device + ), + req_pool_indices=None, + seq_lens=None, + out_cache_loc=None, + seq_lens_sum=seq_len, + extend_seq_lens=torch.tensor([2, 2], dtype=torch.int32, device=device), + extend_seq_lens_cpu=[2, 2], + ) + + # Prepare batch (this sets up embedding_batch_info) + backend.prepare_lora_batch( + forward_batch=forward_batch, + weight_indices=[0, 1], + lora_ranks=[rank, rank], + scalings=[1.0, 0.5], + use_cuda_graph=False, + ) + + # Verify embedding_batch_info was created + self.assertIsNotNone(backend.embedding_batch_info) + self.assertEqual(backend.embedding_batch_info.num_segments, batch_size) + self.assertIsNone(backend.embedding_batch_info.permutation) # Original order + + # Create input and weights + input_ids = torch.randint( + 0, vocab_size, (seq_len,), dtype=torch.int32, device=device + ) + weights = torch.randn( + num_loras, rank, vocab_size, dtype=torch.float16, device=device + ) + + # Run the method + output = backend.run_lora_a_embedding( + input_ids=input_ids, + weights=weights, + vocab_size=vocab_size, + ) + + # Verify output shape + self.assertEqual(output.shape, (seq_len, rank)) + + # Verify it 
matches direct call to kernel with embedding_batch_info + expected = embedding_lora_a_fwd( + input_ids=input_ids, + weights=weights, + batch_info=backend.embedding_batch_info, + vocab_size=vocab_size, + ) + self.assertTrue(torch.allclose(output, expected)) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/lora/test_lora_eviction.py b/test/srt/lora/test_lora_eviction.py index 78cdd8282fe0..63987f1ca2bb 100644 --- a/test/srt/lora/test_lora_eviction.py +++ b/test/srt/lora/test_lora_eviction.py @@ -39,6 +39,12 @@ BASE_MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct" +# Embedding LoRA test configuration (TinyLlama with lora_target_modules=["all"]) +EMBEDDING_LORA_BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" +EMBEDDING_LORA_ADAPTER = ( + "ash256/sglang_embedding_lora_test_adapter" # includes embed_tokens and lm_head +) + @contextlib.contextmanager def dynamically_loaded_adapter(runner, lora_path: str, lora_name: str): @@ -73,6 +79,31 @@ def test_lora_eviction_with_reused_lora_name(self): self._run_test(ADAPTERS, output_history, reuse_lora_name=True, repeat=1) self._run_test(ADAPTERS, output_history, reuse_lora_name=False, repeat=1) + def test_lora_eviction_with_embedding_lora_all_target_modules(self): + """ + Test LoRA eviction with lora_target_modules=["all"] using an embedding LoRA adapter. + + This test verifies that the csgmv backend properly handles eviction when using + lora_target_modules=["all"] which includes embed_tokens and lm_head layers. + Uses TinyLlama base model with the ash256/sglang_embedding_lora_test_adapter. + """ + output_history = {} + self._run_test( + [EMBEDDING_LORA_ADAPTER], + output_history, + base_model=EMBEDDING_LORA_BASE_MODEL, + lora_target_modules=["all"], + max_lora_rank=16, + ) + self._run_test( + [EMBEDDING_LORA_ADAPTER], + output_history, + reverse=True, + base_model=EMBEDDING_LORA_BASE_MODEL, + lora_target_modules=["all"], + max_lora_rank=16, + ) + def _run_test( self, lora_paths: List[str], @@ -80,34 +111,36 @@ def _run_test( reverse: bool = False, repeat: int = 2, reuse_lora_name: bool = False, + base_model: str = BASE_MODEL, + lora_target_modules: List[str] = None, + max_lora_rank: int = 256, ): REUSED_LORA_NAME = "lora" max_new_tokens = 256 torch_dtype = torch.float16 - base_path = BASE_MODEL - assert len(lora_paths) >= 2 + + if lora_target_modules is None: + lora_target_modules = [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ] initial_lora_paths = lora_paths if not reuse_lora_name else None # Initialize runners with SRTRunner( - base_path, + base_model, torch_dtype=torch_dtype, model_type="generation", lora_paths=initial_lora_paths, max_loras_per_batch=1, enable_lora=True, - max_lora_rank=256, - # Need to list all lora modules, or "all" might include lora modules without assigning lora weights - # lora_target_modules=["all"], - lora_target_modules=[ - "q_proj", - "k_proj", - "v_proj", - "o_proj", - "gate_proj", - "up_proj", - "down_proj", - ], + max_lora_rank=max_lora_rank, + lora_target_modules=lora_target_modules, ) as srt_runner: adapter_sequence = lora_paths if not reverse else lora_paths[::-1] diff --git a/test/srt/lora/test_lora_hf_sgl_logprob_diff.py b/test/srt/lora/test_lora_hf_sgl_logprob_diff.py index b0975fa5d666..782f839ab3ee 100644 --- a/test/srt/lora/test_lora_hf_sgl_logprob_diff.py +++ b/test/srt/lora/test_lora_hf_sgl_logprob_diff.py @@ -543,6 +543,63 @@ def test_lora_logprob_comparison_full(self): max_new_tokens=32, ) + def 
test_lora_embedding_logprob_comparison(self): + """ + Test embedding LoRA (embed_tokens/lm_head) with ChunkedSgmvLoRABackend. + + Adapter must have embed_tokens in target_modules (not modules_to_save). + """ + if is_in_ci(): + self.skipTest("Skipping in CI environment - requires large models") + + model_path = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" + lora_paths = ["ash256/sglang_embedding_lora_test_adapter"] + prompts = DEFAULT_TEST_PROMPTS + + # Embedding LoRA has higher numerical variance due to vocab-sized operations, + # but outputs should still match + global LOGPROB_THRESHOLD + original_threshold = LOGPROB_THRESHOLD + LOGPROB_THRESHOLD = 1.0 # Relaxed for embedding LoRA + try: + self._run_comparison_test( + model_path=model_path, + lora_paths=lora_paths, + prompts=prompts, + max_new_tokens=32, + lora_backend="csgmv", + ) + finally: + LOGPROB_THRESHOLD = original_threshold + + def test_lora_embedding_logprob_comparison_triton(self): + """ + Test embedding LoRA (embed_tokens/lm_head) with TritonLoRABackend. + + This serves as baseline to validate chunked backend against. + Adapter must have embed_tokens in target_modules (not modules_to_save). + """ + if is_in_ci(): + self.skipTest("Skipping in CI environment - requires large models") + + model_path = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" + lora_paths = ["ash256/sglang_embedding_lora_test_adapter"] + prompts = DEFAULT_TEST_PROMPTS + + global LOGPROB_THRESHOLD + original_threshold = LOGPROB_THRESHOLD + LOGPROB_THRESHOLD = 1.0 # Relaxed for embedding LoRA + try: + self._run_comparison_test( + model_path=model_path, + lora_paths=lora_paths, + prompts=prompts, + max_new_tokens=32, + lora_backend="triton", + ) + finally: + LOGPROB_THRESHOLD = original_threshold + if __name__ == "__main__": try: diff --git a/test/srt/lora/test_lora_update.py b/test/srt/lora/test_lora_update.py index 9c3f0855033b..8c7bc23dc52a 100644 --- a/test/srt/lora/test_lora_update.py +++ b/test/srt/lora/test_lora_update.py @@ -43,6 +43,12 @@ MEM_FRACTION_STATIC = 0.8 +# Embedding LoRA test configuration (TinyLlama with lora_target_modules=["all"]) +EMBEDDING_LORA_BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" +EMBEDDING_LORA_ADAPTER = ( + "ash256/sglang_embedding_lora_test_adapter" # includes embed_tokens and lm_head +) + class OperationType(Enum): LOAD = "load" @@ -218,17 +224,7 @@ def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]: base="meta-llama/Llama-3.1-8B-Instruct", enable_lora=True, max_lora_rank=256, - # Need to list all lora modules, or "all" might include lora modules without assigning lora weights - # lora_target_modules=["all"], - lora_target_modules=[ - "q_proj", - "k_proj", - "v_proj", - "o_proj", - "gate_proj", - "up_proj", - "down_proj", - ], + lora_target_modules=["all"], max_loras_per_batch=4, all_adapters=[ "philschmid/code-llama-3-1-8b-text-to-sql-lora", @@ -763,17 +759,7 @@ def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]: ], enable_lora=True, max_lora_rank=256, - # Need to list all lora modules, or "all" might include lora modules without assigning lora weights - # lora_target_modules=["all"], - lora_target_modules=[ - "q_proj", - "k_proj", - "v_proj", - "o_proj", - "gate_proj", - "up_proj", - "down_proj", - ], + lora_target_modules=["all"], op_sequence=[ Operation( type=OperationType.LOAD, @@ -884,12 +870,84 @@ def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]: ), ] +# Embedding LoRA tests with lora_target_modules=["all"] (tests csgmv backend with embed_tokens/lm_head) 
+EMBEDDING_LORA_TESTS = [ + TestCase( + description="dynamic lora update with embedding LoRA and lora_target_modules=all", + base=EMBEDDING_LORA_BASE_MODEL, + max_loras_per_batch=1, + all_adapters=[EMBEDDING_LORA_ADAPTER], + initial_adapters=[EMBEDDING_LORA_ADAPTER], + enable_lora=True, + max_lora_rank=16, + lora_target_modules=["all"], # This includes embed_tokens and lm_head + op_sequence=[ + Operation( + type=OperationType.FORWARD, + data=create_batch_data(EMBEDDING_LORA_ADAPTER), + ), + Operation( + type=OperationType.UNLOAD, + data=EMBEDDING_LORA_ADAPTER, + ), + Operation( + type=OperationType.LOAD, + data=EMBEDDING_LORA_ADAPTER, + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data(EMBEDDING_LORA_ADAPTER), + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data(None), # Test base model inference + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data(EMBEDDING_LORA_ADAPTER), + ), + ], + ), + TestCase( + description="dynamic lora update without initial paths with embedding LoRA and lora_target_modules=all", + base=EMBEDDING_LORA_BASE_MODEL, + max_loras_per_batch=1, + all_adapters=[EMBEDDING_LORA_ADAPTER], + enable_lora=True, + max_lora_rank=16, + lora_target_modules=["all"], # This includes embed_tokens and lm_head + op_sequence=[ + Operation( + type=OperationType.LOAD, + data=EMBEDDING_LORA_ADAPTER, + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data(EMBEDDING_LORA_ADAPTER), + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data(None), # Test base model inference + ), + Operation( + type=OperationType.UNLOAD, + data=EMBEDDING_LORA_ADAPTER, + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data(EMBEDDING_LORA_ADAPTER), # Implicit reload + ), + ], + ), +] + ALL_TESTS = ( BASIC_TESTS + TARGET_MODULE_TESTS + MAX_LORA_RANK_TESTS + MAX_LOADED_LORAS_TESTS + EVICTION_TESTS + + EMBEDDING_LORA_TESTS ) @@ -1525,17 +1583,7 @@ def test_v1_models_endpoint_with_lora(self): lora_paths=[], max_loras_per_batch=2, max_lora_rank=256, - # Need to list all lora modules, or "all" might include lora modules without assigning lora weights - # lora_target_modules=["all"], - lora_target_modules=[ - "q_proj", - "k_proj", - "v_proj", - "o_proj", - "gate_proj", - "up_proj", - "down_proj", - ], + lora_target_modules=["all"], enable_lora=True, ) as session: # Test with no adapters loaded
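
Usage note (not part of the patch): with the `csgmv` guard removed from `check_lora_server_args`, `--lora-target-modules all` now keeps `embed_tokens` and `lm_head` when running on the chunked backend. Below is a minimal launch sketch for exercising that path with the TinyLlama base model and the `ash256/sglang_embedding_lora_test_adapter` adapter used in the tests above. The flag spellings are assumed from the ServerArgs fields referenced in this diff (`lora_backend`, `lora_target_modules`, `max_lora_rank`, `lora_paths`, `max_loras_per_batch`); verify them against the installed SGLang version.

```bash
# Launch sketch (assumed flag spellings; check against your sglang version).
# Serves TinyLlama with the embedding LoRA adapter trained by
# scripts/playground/lora/train_embedding_lora_adapter.py on the chunked (csgmv) backend,
# with target modules "all" so embed_tokens and lm_head are included.
python -m sglang.launch_server \
  --model-path TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
  --enable-lora \
  --lora-paths ash256/sglang_embedding_lora_test_adapter \
  --lora-backend csgmv \
  --lora-target-modules all \
  --max-lora-rank 16 \
  --max-loras-per-batch 1
```

The new tests in `test_lora_eviction.py`, `test_lora_update.py`, and `test_lora_hf_sgl_logprob_diff.py` drive the same configuration through `SRTRunner`, with matching `lora_target_modules=["all"]` and `max_lora_rank=16`.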