From eadfae21f379416c0527950fa1cbb9cc9b0967f3 Mon Sep 17 00:00:00 2001
From: glenliu21
Date: Wed, 4 Feb 2026 23:27:15 -0500
Subject: [PATCH] cleanup lora load logic

---
 python/sglang/srt/lora/lora_manager.py | 82 +++++++++++---------------
 1 file changed, 36 insertions(+), 46 deletions(-)

diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py
index b5d38dcd08d0..bea2668fe017 100644
--- a/python/sglang/srt/lora/lora_manager.py
+++ b/python/sglang/srt/lora/lora_manager.py
@@ -16,6 +16,7 @@
 # and "Punica: Multi-Tenant LoRA Serving"
 
 import logging
+from dataclasses import dataclass
 from typing import Dict, Iterable, List, Optional
 
 import torch
@@ -47,6 +48,13 @@
 logger = logging.getLogger(__name__)
 
 
+@dataclass(frozen=True)
+class LoRATensorPayload:
+    tensors: Dict[str, torch.Tensor]
+    config_dict: Dict
+    added_tokens_config: Optional[Dict]
+
+
 class LoRAManager:
     def __init__(
         self,
@@ -116,12 +124,17 @@ def create_lora_update_result(
             },
         )
 
-    def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput:
+    def load_lora_adapter(
+        self,
+        lora_ref: LoRARef,
+        tensor_payload: Optional[LoRATensorPayload] = None,
+    ) -> LoRAUpdateOutput:
         """
-        Load a single LoRA adapter from the specified path.
+        Load a single LoRA adapter from either the specified path or the tensors and config.
 
         Args:
             lora_ref (LoRARef): The LoRARef object containing the LoRA name, path, and ID.
+            tensor_payload (Optional[LoRATensorPayload]): Object containing the tensors, config, and optionally, added tokens config for this adapter.
         """
         assert (
             lora_ref.lora_name is not None and lora_ref.lora_path is not None
@@ -132,12 +145,20 @@ def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput:
 
         try:
             # load configs
-            new_adapter = LoRAConfig(lora_ref.lora_path)
+            if tensor_payload:
+                new_adapter = LoRAConfig.from_dict(
+                    tensor_payload.config_dict, tensor_payload.added_tokens_config
+                )
+            else:
+                new_adapter = LoRAConfig(lora_ref.lora_path)
+
             self.validate_new_adapter(new_adapter, lora_ref)
             self.configs[lora_ref.lora_id] = new_adapter
 
             # load weights
-            self.load_lora_weights(lora_ref)
+            self.load_lora_weights(
+                lora_ref, tensor_payload.tensors if tensor_payload is not None else None
+            )
 
             # keep metadata for displayed messages
             self.lora_refs[lora_ref.lora_id] = lora_ref
@@ -154,7 +175,6 @@ def validate_new_adapter(self, lora_config: LoRAConfig, lora_ref: LoRARef):
         """
         Validate if an adapter can be loaded into the current LoRA memory pool and generate error if it is incompatible.
         """
-
         # Check if this LoRA adapter is already loaded
         for existing_lora_ref in self.lora_refs.values():
             if lora_ref.lora_name == existing_lora_ref.lora_name:
@@ -436,7 +456,9 @@ def init_lora_shapes(
             )
             self.lora_added_tokens_size = inferred_extra_vocab_size
 
-    def load_lora_weights(self, lora_ref: LoRARef):
+    def load_lora_weights(
+        self, lora_ref: LoRARef, tensors: Optional[Dict[str, torch.Tensor]] = None
+    ):
         """
         Load the weights of a LoRA adapter to CPU memory and conducts post-loading validation.
         """
@@ -447,7 +469,11 @@ def load_lora_weights(self, lora_ref: LoRARef):
             self.load_config,
             self.lora_backend,
         )
-        lora_adapter.initialize_weights()
+
+        if tensors is None:
+            lora_adapter.initialize_weights()
+        else:
+            lora_adapter.initialize_weights_from_tensors(tensors)
 
         # If we want to overlap loading LoRA adapters with compute, they must be pinned in CPU memory
         if self.enable_lora_overlap_loading:
@@ -455,22 +481,6 @@ def load_lora_weights(self, lora_ref: LoRARef):
 
         self.loras[lora_ref.lora_id] = lora_adapter
 
-    def load_lora_weights_from_tensors(
-        self, lora_ref: LoRARef, tensors: Dict[str, torch.Tensor]
-    ):
-        """
-        Load the weights of a LoRA adapter from tensors to CPU memory.
-        """
-        lora_adapter = LoRAAdapter(
-            lora_ref.lora_id,
-            self.configs[lora_ref.lora_id],
-            self.base_hf_config,
-            self.load_config,
-            self.lora_backend,
-        )
-        lora_adapter.initialize_weights_from_tensors(tensors)
-        self.loras[lora_ref.lora_id] = lora_adapter
-
     def load_lora_adapter_from_tensors(
         self,
         lora_ref: LoRARef,
@@ -481,29 +491,9 @@ def load_lora_adapter_from_tensors(
         """
         Load a single LoRA adapter from tensors and config dict.
         """
-        assert (
-            lora_ref.lora_name is not None and lora_ref.lora_path is not None
-        ), "LoRARef must have both lora_name and lora_path set for loading."
-        assert (
-            lora_ref.lora_id not in self.loras
-        ), f"LoRA adapter with ID {lora_ref.lora_id} is already loaded. This should have been verified before request is sent to the backend."
-
-        try:
-            new_adapter = LoRAConfig.from_dict(config_dict, added_tokens_config)
-            self.validate_new_adapter(new_adapter, lora_ref)
-            self.configs[lora_ref.lora_id] = new_adapter
-
-            self.load_lora_weights_from_tensors(lora_ref, tensors)
-
-            self.lora_refs[lora_ref.lora_id] = lora_ref
-            self.num_pinned_loras += int(lora_ref.pinned)
-        except Exception as e:
-            return self.create_lora_update_result(
-                success=False,
-                error_message=str(e),
-            )
-
-        return self.create_lora_update_result(success=True)
+        return self.load_lora_adapter(
+            lora_ref, LoRATensorPayload(tensors, config_dict, added_tokens_config)
+        )
 
     def init_memory_pool(self):
         """(Re)initialize the LoRA memory pool based on the current configurations."""