diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py
index 8ccb674f9195..737c82d49edd 100644
--- a/python/sglang/srt/lora/lora.py
+++ b/python/sglang/srt/lora/lora.py
@@ -128,8 +128,8 @@ def _process_weight(self, name: str, loaded_weight: torch.Tensor):
             # added/extra token emb
             self.added_tokens_embeddings[name] = loaded_weight.cpu()
             assert loaded_weight.shape[0] == self.config.lora_added_tokens_size, (
-                f"LoRA adapter {self.uid} has extra_vocab_size {self.config.extra_vocab_size} specified in the config, "
-                f"but the loaded weight has {loaded_weight.shape[0]} extra vocab size"
+                f"LoRA adapter {self.uid} has lora_added_tokens_size {self.config.lora_added_tokens_size} specified in the config, "
+                f"but the loaded weight '{name}' has shape {loaded_weight.shape[0]} in first dimension"
             )
 
     def _normalize_weights(self):
diff --git a/python/sglang/srt/lora/lora_config.py b/python/sglang/srt/lora/lora_config.py
index 939a9331111b..917feef155cb 100644
--- a/python/sglang/srt/lora/lora_config.py
+++ b/python/sglang/srt/lora/lora_config.py
@@ -13,11 +13,14 @@
 # ==============================================================================
 
 import json
+import logging
 import os
 from typing import Dict, Optional
 
 from huggingface_hub import snapshot_download
 
+logger = logging.getLogger(__name__)
+
 
 class LoRAConfig:
     def __init__(
@@ -25,6 +28,7 @@ def __init__(
         path: Optional[str] = None,
         config_dict: Optional[Dict] = None,
         added_tokens_config: Optional[Dict] = None,
+        base_vocab_size: Optional[int] = None,
     ) -> None:
         self.path = path
 
@@ -38,17 +42,41 @@ def __init__(
         self.target_modules = self.hf_config["target_modules"]
         self.r = self.hf_config["r"]
         self.lora_alpha = self.hf_config["lora_alpha"]
+
+        # Filter fake added tokens: tokens with ID < base_vocab_size are already
+        # part of the base vocabulary and should not be treated as added tokens.
+        # This commonly happens when added_tokens.json is copied from the base
+        # model's tokenizer.
+        if self.added_tokens_config and base_vocab_size is not None:
+            self.added_tokens_config = {
+                token: token_id
+                for token, token_id in self.added_tokens_config.items()
+                if token_id >= base_vocab_size
+            }
+
         self.lora_added_tokens_size = (
             len(self.added_tokens_config) if self.added_tokens_config is not None else 0
         )
 
+        if self.lora_added_tokens_size > 0:
+            raise ValueError(
+                f"LoRA adapter has {self.lora_added_tokens_size} added tokens, "
+                f"but added tokens are not supported yet. "
+                f"Added tokens: {self.added_tokens_config}"
+            )
+
     @classmethod
     def from_dict(
         cls,
         config_dict: Dict,
         added_tokens_config: Optional[Dict] = None,
+        base_vocab_size: Optional[int] = None,
     ) -> "LoRAConfig":
-        return cls(config_dict=config_dict, added_tokens_config=added_tokens_config)
+        return cls(
+            config_dict=config_dict,
+            added_tokens_config=added_tokens_config,
+            base_vocab_size=base_vocab_size,
+        )
 
     def get_lora_config(self, dummy=False):
         if dummy:
@@ -82,9 +110,5 @@ def get_added_tokens_config(self):
             with open(added_tokens_path, "r") as f:
                 return json.load(f)
         except json.JSONDecodeError as e:
-            # Log warning but don't crash if JSON is malformed
-            import logging
-
-            logger = logging.getLogger(__name__)
             logger.warning(f"Failed to parse added_tokens.json: {e}")
             return None
diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py
index 7b222161f669..ead840827d5d 100644
--- a/python/sglang/srt/lora/lora_manager.py
+++ b/python/sglang/srt/lora/lora_manager.py
@@ -134,7 +134,10 @@ def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput:
 
         try:
             # load configs
-            new_adapter = LoRAConfig(lora_ref.lora_path)
+            new_adapter = LoRAConfig(
+                lora_ref.lora_path,
+                base_vocab_size=self.base_hf_config.vocab_size,
+            )
 
             self.validate_new_adapter(new_adapter, lora_ref)
             self.configs[lora_ref.lora_id] = new_adapter
@@ -560,7 +563,11 @@ def load_lora_adapter_from_tensors(
         ), f"LoRA adapter with ID {lora_ref.lora_id} is already loaded. This should have been verified before request is sent to the backend."
 
         try:
-            new_adapter = LoRAConfig.from_dict(config_dict, added_tokens_config)
+            new_adapter = LoRAConfig.from_dict(
+                config_dict,
+                added_tokens_config,
+                base_vocab_size=self.base_hf_config.vocab_size,
+            )
 
             self.validate_new_adapter(new_adapter, lora_ref)
             self.configs[lora_ref.lora_id] = new_adapter