diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py
index b5d38dcd08d0..39b2f1ac440c 100644
--- a/python/sglang/srt/lora/lora_manager.py
+++ b/python/sglang/srt/lora/lora_manager.py
@@ -71,7 +71,7 @@ def __init__(
         self.device: torch.device = next(self.base_model.parameters()).device
         self.tp_size: int = tp_size
         self.tp_rank: int = tp_rank
-        self.lora_added_tokens_size: Optional[int] = None
+        self.lora_added_tokens_size: int = 0
         self.enable_lora_overlap_loading: Optional[bool] = (
             server_args.enable_lora_overlap_loading
         )
@@ -133,6 +133,7 @@ def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput:
         try:
             # load configs
             new_adapter = LoRAConfig(lora_ref.lora_path)
+            self._maybe_update_added_tokens_size(new_adapter)
             self.validate_new_adapter(new_adapter, lora_ref)
             self.configs[lora_ref.lora_id] = new_adapter
 
@@ -418,24 +419,6 @@ def init_lora_shapes(
             default=0,
         )
 
-        # Auto-infer self.lora_added_vocab_size from loaded LoRA configs
-        # This happens automatically without requiring user input
-        # if self.lora_added_vocab_size is None:
-        if self.lora_added_tokens_size is None:
-            inferred_extra_vocab_size = next(
-                (
-                    x.lora_added_tokens_size
-                    for x in self.configs.values()
-                    if x.lora_added_tokens_size > 0
-                ),
-                0,
-            )
-            if inferred_extra_vocab_size > 0:
-                logger.info(
-                    f"self.lora_added_tokens_size={inferred_extra_vocab_size} from LoRA adapters."
-                )
-                self.lora_added_tokens_size = inferred_extra_vocab_size
-
     def load_lora_weights(self, lora_ref: LoRARef):
         """
         Load the weights of a LoRA adapter to CPU memory and conducts post-loading validation.
@@ -490,6 +473,7 @@ def load_lora_adapter_from_tensors(
 
         try:
             new_adapter = LoRAConfig.from_dict(config_dict, added_tokens_config)
+            self._maybe_update_added_tokens_size(new_adapter)
             self.validate_new_adapter(new_adapter, lora_ref)
             self.configs[lora_ref.lora_id] = new_adapter
 
@@ -523,6 +507,19 @@ def init_memory_pool(self):
         # Initializing memory pool with base model
         self.fetch_new_loras({None})
 
+    def _maybe_update_added_tokens_size(self, new_config: LoRAConfig):
+        if new_config.lora_added_tokens_size == 0 or self.lora_added_tokens_size > 0:
+            return
+
+        self.lora_added_tokens_size = new_config.lora_added_tokens_size
+        logger.info(
+            f"self.lora_added_tokens_size={self.lora_added_tokens_size} from LoRA adapters."
+        )
+
+        # Some LoRA adapters are loaded before the memory pool is initialized
+        if hasattr(self, "memory_pool"):
+            self.init_memory_pool()
+
     def set_lora_module(self, module_name, module):
         lora_module = get_lora_layer(module, self.lora_backend)
         replace_submodule(self.base_model, module_name, lora_module)
diff --git a/test/registered/lora/test_lora_update.py b/test/registered/lora/test_lora_update.py
index a7ae1aa58650..78cdd1a26924 100644
--- a/test/registered/lora/test_lora_update.py
+++ b/test/registered/lora/test_lora_update.py
@@ -34,7 +34,7 @@
     popen_launch_server,
 )
 
-register_cuda_ci(est_time=487, suite="stage-b-test-large-1-gpu")
+register_cuda_ci(est_time=587, suite="stage-b-test-large-1-gpu")
 
 PROMPTS = [
     "SGL is a",
@@ -1482,6 +1482,105 @@ def test_dynamic_lora_update_server(self):
             mode=LoRAUpdateTestSessionMode.SERVER, test_cases=test_cases
         )
 
+    def test_lora_added_tokens_size(self):
+        """
+        Test that we correctly handle loading a new adapter that adds tokens to the vocabulary
+        """
+        added_tokens_model = "Qwen/Qwen3-0.6B"
+        no_added_tokens_adapter = "ethicalabs/Flwr-Qwen3-0.6B-Medical-PEFT"
+        added_tokens_adapter = "YoussefHosni/Qwen3-0.6b-2B-Token-arabic-LoRA-finetuned"
+
+        # Test loading adapter that adds tokens on startup
+        with LoRAUpdateTestSession(
+            testcase=self,
+            mode=LoRAUpdateTestSessionMode.SERVER,
+            model_path=added_tokens_model,
+            lora_paths=[
+                {
+                    "lora_name": "no_added_tokens_adapter",
+                    "lora_path": no_added_tokens_adapter,
+                },
+                {
+                    "lora_name": "added_tokens_adapter",
+                    "lora_path": added_tokens_adapter,
+                },
+            ],
+            max_loras_per_batch=2,
+            max_lora_rank=32,
+            lora_target_modules=["all"],
+            enable_lora=True,
+        ) as session:
+            # Verify both adapters are correctly loaded
+            response = requests.get(DEFAULT_URL_FOR_TEST + "/v1/models")
+            self.assertTrue(response.ok, response.text)
+
+            adapter_models = [m for m in response.json()["data"] if m.get("parent")]
+            self.assertEqual(
+                {m["id"] for m in adapter_models},
+                {"no_added_tokens_adapter", "added_tokens_adapter"},
+            )
+
+            # Make sure we can run forward with both adapters
+            result = session.forward(
+                prompts=[PROMPTS[0], PROMPTS[1]],
+                lora_paths=["no_added_tokens_adapter", "added_tokens_adapter"],
+            )
+            print(f"Got output from combined forward pass: {result}")
+
+        # Test dynamically loading adapter that adds tokens
+        with LoRAUpdateTestSession(
+            testcase=self,
+            mode=LoRAUpdateTestSessionMode.SERVER,
+            model_path=added_tokens_model,
+            lora_paths=[
+                {
+                    "lora_name": "no_added_tokens_adapter",
+                    "lora_path": no_added_tokens_adapter,
+                }
+            ],
+            max_loras_per_batch=2,
+            max_lora_rank=32,
+            lora_target_modules=["all"],
+            enable_lora=True,
+        ) as session:
+            # Verify adapter is correctly loaded
+            response = requests.get(DEFAULT_URL_FOR_TEST + "/v1/models")
+            self.assertTrue(response.ok, response.text)
+
+            adapter_models = [m for m in response.json()["data"] if m.get("parent")]
+            self.assertEqual(
+                {m["id"] for m in adapter_models}, {"no_added_tokens_adapter"}
+            )
+
+            # Run one forward request so that adapter weights are loaded into GPU memory
+            result = session.forward(
+                prompts=[PROMPTS[0]],
+                lora_paths=["no_added_tokens_adapter"],
+            )
+            print(f"Got output from no_added_tokens_adapter: {result}")
+
+            # Load adapter that adds tokens to vocab
+            session.load_lora_adapter(
+                lora_name="added_tokens_adapter", lora_path=added_tokens_adapter
+            )
+
+            # Verify both adapters are correctly loaded
+            response = requests.get(DEFAULT_URL_FOR_TEST + "/v1/models")
+            self.assertTrue(response.ok, response.text)
+
+            adapter_models = [m for m in response.json()["data"] if m.get("parent")]
+            self.assertEqual(
+                {m["id"] for m in adapter_models},
+                {"no_added_tokens_adapter", "added_tokens_adapter"},
+            )
+
+            # Make sure we can run forward with both adapters
+            result = session.forward(
+                prompts=[PROMPTS[0], PROMPTS[1]],
+                lora_paths=["no_added_tokens_adapter", "added_tokens_adapter"],
+            )
+            print(f"Got output from combined forward pass: {result}")
+
     def test_v1_models_endpoint_with_lora(self):
        """
        Test that /v1/models endpoint returns base model and loaded LoRA adapters.
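
For reference, here is a minimal standalone sketch of the behavior the new _maybe_update_added_tokens_size hook introduces: the first adapter config that reports a non-zero lora_added_tokens_size fixes the manager-level value, later adapters cannot overwrite it, and the memory pool is rebuilt only when it already exists (i.e. the adapter was loaded dynamically after startup rather than during initial loading). ToyConfig and ToyManager below are illustrative stand-ins, not part of the SGLang codebase.

from dataclasses import dataclass
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class ToyConfig:
    # Stand-in for LoRAConfig: only the field the hook reads.
    lora_added_tokens_size: int = 0


class ToyManager:
    # Stand-in for LoRAManager: mirrors the hook's guard-and-update logic.
    def __init__(self):
        self.lora_added_tokens_size = 0  # new default: 0 instead of Optional[int] = None

    def init_memory_pool(self):
        self.memory_pool = object()  # placeholder for the real pool rebuild

    def _maybe_update_added_tokens_size(self, new_config: ToyConfig):
        # Skip adapters that add no tokens, and never change a size that an
        # earlier adapter has already fixed.
        if new_config.lora_added_tokens_size == 0 or self.lora_added_tokens_size > 0:
            return
        self.lora_added_tokens_size = new_config.lora_added_tokens_size
        logger.info("lora_added_tokens_size=%d from LoRA adapters", self.lora_added_tokens_size)
        # Rebuild the pool only if it has already been initialized.
        if hasattr(self, "memory_pool"):
            self.init_memory_pool()


mgr = ToyManager()
mgr._maybe_update_added_tokens_size(ToyConfig(lora_added_tokens_size=0))    # ignored: adds no tokens
mgr._maybe_update_added_tokens_size(ToyConfig(lora_added_tokens_size=128))  # sets the size
mgr._maybe_update_added_tokens_size(ToyConfig(lora_added_tokens_size=64))   # ignored: already set
print(mgr.lora_added_tokens_size)  # 128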