-
Notifications
You must be signed in to change notification settings - Fork 5.4k
[BugFix] fix loading new adapter with added_tokens #17794
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
65df2ee
f4b8662
5a0441d
e33fcce
226149a
27f6562
2c499e8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -71,7 +71,7 @@ def __init__( | |
| self.device: torch.device = next(self.base_model.parameters()).device | ||
| self.tp_size: int = tp_size | ||
| self.tp_rank: int = tp_rank | ||
| self.lora_added_tokens_size: Optional[int] = None | ||
| self.lora_added_tokens_size: int = 0 | ||
| self.enable_lora_overlap_loading: Optional[bool] = ( | ||
| server_args.enable_lora_overlap_loading | ||
| ) | ||
|
|
@@ -133,6 +133,7 @@ def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput: | |
| try: | ||
| # load configs | ||
| new_adapter = LoRAConfig(lora_ref.lora_path) | ||
| self._maybe_update_added_tokens_size(new_adapter) | ||
| self.validate_new_adapter(new_adapter, lora_ref) | ||
| self.configs[lora_ref.lora_id] = new_adapter | ||
|
|
||
|
|
@@ -418,24 +419,6 @@ def init_lora_shapes( | |
| default=0, | ||
| ) | ||
|
|
||
| # Auto-infer self.lora_added_vocab_size from loaded LoRA configs | ||
| # This happens automatically without requiring user input | ||
| # if self.lora_added_vocab_size is None: | ||
| if self.lora_added_tokens_size is None: | ||
| inferred_extra_vocab_size = next( | ||
| ( | ||
| x.lora_added_tokens_size | ||
| for x in self.configs.values() | ||
| if x.lora_added_tokens_size > 0 | ||
| ), | ||
| 0, | ||
| ) | ||
| if inferred_extra_vocab_size > 0: | ||
| logger.info( | ||
| f"self.lora_added_tokens_size={inferred_extra_vocab_size} from LoRA adapters." | ||
| ) | ||
| self.lora_added_tokens_size = inferred_extra_vocab_size | ||
|
|
||
| def load_lora_weights(self, lora_ref: LoRARef): | ||
| """ | ||
| Load the weights of a LoRA adapter to CPU memory and conducts post-loading validation. | ||
|
|
@@ -490,6 +473,7 @@ def load_lora_adapter_from_tensors( | |
|
|
||
| try: | ||
| new_adapter = LoRAConfig.from_dict(config_dict, added_tokens_config) | ||
| self._maybe_update_added_tokens_size(new_adapter) | ||
| self.validate_new_adapter(new_adapter, lora_ref) | ||
| self.configs[lora_ref.lora_id] = new_adapter | ||
|
|
||
|
|
@@ -523,6 +507,19 @@ def init_memory_pool(self): | |
| # Initializing memory pool with base model | ||
| self.fetch_new_loras({None}) | ||
|
|
||
| def _maybe_update_added_tokens_size(self, new_config: LoRAConfig): | ||
| if new_config.lora_added_tokens_size == 0 or self.lora_added_tokens_size > 0: | ||
| return | ||
|
|
||
| self.lora_added_tokens_size = new_config.lora_added_tokens_size | ||
| logger.info( | ||
| f"self.lora_added_tokens_size={self.lora_added_tokens_size} from LoRA adapters." | ||
| ) | ||
|
|
||
| # Some LoRA adapters are loaded before the memory pool is initialized | ||
| if hasattr(self, "memory_pool"): | ||
| self.init_memory_pool() | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Will this operation delete the earlier adapters in GPU memory? If so I feel it's risky
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think theoretically that should be fine because adapters will just be reloaded on the next forward pass.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Extra tokens only affect the embedding memory pool. Can we only reinitialize the embedding part? |
||
|
|
||
| def set_lora_module(self, module_name, module): | ||
| lora_module = get_lora_layer(module, self.lora_backend) | ||
| replace_submodule(self.base_model, module_name, lora_module) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.