82 changes: 36 additions & 46 deletions python/sglang/srt/lora/lora_manager.py
@@ -16,6 +16,7 @@
 # and "Punica: Multi-Tenant LoRA Serving"
 
 import logging
+from dataclasses import dataclass
 from typing import Dict, Iterable, List, Optional
 
 import torch
@@ -47,6 +48,13 @@
 logger = logging.getLogger(__name__)
 
 
+@dataclass(frozen=True)
+class LoRATensorPayload:
+    tensors: Dict[str, torch.Tensor]
+    config_dict: Dict
+    added_tokens_config: Optional[Dict]
+
+
 class LoRAManager:
     def __init__(
         self,
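Note on the new frozen dataclass: LoRATensorPayload simply bundles the adapter weights, the adapter config dict, and an optional added-tokens config so they can travel through the existing load path as a single argument. A minimal sketch of how a caller might assemble one, assuming the weights come from a safetensors file and the config from a PEFT-style adapter_config.json; the file names and the added_tokens_config value are illustrative, not part of this diff:

import json

from safetensors.torch import load_file

# Hypothetical inputs; the paths stand in for whatever the caller already has in memory.
tensors = load_file("my_adapter/adapter_model.safetensors")  # Dict[str, torch.Tensor]
with open("my_adapter/adapter_config.json") as f:
    config_dict = json.load(f)

payload = LoRATensorPayload(
    tensors=tensors,
    config_dict=config_dict,
    added_tokens_config=None,  # no extra vocabulary in this sketch
)
# frozen=True makes the payload immutable once constructed.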
@@ -116,12 +124,17 @@ def create_lora_update_result(
             },
         )
 
-    def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput:
+    def load_lora_adapter(
+        self,
+        lora_ref: LoRARef,
+        tensor_payload: Optional[LoRATensorPayload] = None,
+    ) -> LoRAUpdateOutput:
         """
-        Load a single LoRA adapter from the specified path.
+        Load a single LoRA adapter from either the specified path or the tensors and config.
 
         Args:
             lora_ref (LoRARef): The LoRARef object containing the LoRA name, path, and ID.
+            tensor_payload (Optional[LoRATensorPayload]): Object containing the tensors, config, and optionally, added tokens config for this adapter.
         """
         assert (
             lora_ref.lora_name is not None and lora_ref.lora_path is not None
@@ -132,12 +145,20 @@ def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput:
 
         try:
             # load configs
-            new_adapter = LoRAConfig(lora_ref.lora_path)
+            if tensor_payload:
+                new_adapter = LoRAConfig.from_dict(
+                    tensor_payload.config_dict, tensor_payload.added_tokens_config
+                )
+            else:
+                new_adapter = LoRAConfig(lora_ref.lora_path)
+
             self.validate_new_adapter(new_adapter, lora_ref)
             self.configs[lora_ref.lora_id] = new_adapter
 
             # load weights
-            self.load_lora_weights(lora_ref)
+            self.load_lora_weights(
+                lora_ref, tensor_payload.tensors if tensor_payload is not None else None
+            )
 
             # keep metadata for displayed messages
             self.lora_refs[lora_ref.lora_id] = lora_ref
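With the optional payload, the path-based and in-memory flows now share one entry point: when tensor_payload is given, the config is built via LoRAConfig.from_dict and the tensors are forwarded to load_lora_weights; otherwise the adapter is read from lora_ref.lora_path as before. A rough sketch of the two call styles; the LoRARef field values are placeholders and the exact LoRARef constructor arguments are an assumption, since they are not shown in this diff:

# Path-based load: config and weights are read from disk, as before.
result = lora_manager.load_lora_adapter(
    LoRARef(lora_name="my-adapter", lora_path="/models/my-adapter")
)

# In-memory load: config and weights come from the payload. Note that the
# unchanged assert above still requires lora_name and lora_path to be set,
# so lora_path acts mainly as an identifier in this flow.
result = lora_manager.load_lora_adapter(
    LoRARef(lora_name="my-adapter", lora_path="my-adapter"),
    tensor_payload=LoRATensorPayload(tensors, config_dict, added_tokens_config=None),
)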
@@ -154,7 +175,6 @@ def validate_new_adapter(self, lora_config: LoRAConfig, lora_ref: LoRARef):
         """
         Validate if an adapter can be loaded into the current LoRA memory pool and generate error if it is incompatible.
         """
-
         # Check if this LoRA adapter is already loaded
         for existing_lora_ref in self.lora_refs.values():
             if lora_ref.lora_name == existing_lora_ref.lora_name:
@@ -436,7 +456,9 @@ def init_lora_shapes(
         )
         self.lora_added_tokens_size = inferred_extra_vocab_size
 
-    def load_lora_weights(self, lora_ref: LoRARef):
+    def load_lora_weights(
+        self, lora_ref: LoRARef, tensors: Optional[Dict[str, torch.Tensor]] = None
+    ):
         """
         Load the weights of a LoRA adapter to CPU memory and conducts post-loading validation.
         """
@@ -447,30 +469,18 @@ def load_lora_weights(self, lora_ref: LoRARef):
             self.load_config,
             self.lora_backend,
         )
-        lora_adapter.initialize_weights()
+
+        if tensors is None:
+            lora_adapter.initialize_weights()
+        else:
+            lora_adapter.initialize_weights_from_tensors(tensors)
 
         # If we want to overlap loading LoRA adapters with compute, they must be pinned in CPU memory
         if self.enable_lora_overlap_loading:
             lora_adapter.pin_weights_in_cpu()
 
         self.loras[lora_ref.lora_id] = lora_adapter
 
-    def load_lora_weights_from_tensors(
-        self, lora_ref: LoRARef, tensors: Dict[str, torch.Tensor]
-    ):
-        """
-        Load the weights of a LoRA adapter from tensors to CPU memory.
-        """
-        lora_adapter = LoRAAdapter(
-            lora_ref.lora_id,
-            self.configs[lora_ref.lora_id],
-            self.base_hf_config,
-            self.load_config,
-            self.lora_backend,
-        )
-        lora_adapter.initialize_weights_from_tensors(tensors)
-        self.loras[lora_ref.lora_id] = lora_adapter
-
     def load_lora_adapter_from_tensors(
         self,
         lora_ref: LoRARef,
@@ -481,29 +491,9 @@ def load_lora_adapter_from_tensors(
         """
         Load a single LoRA adapter from tensors and config dict.
         """
-        assert (
-            lora_ref.lora_name is not None and lora_ref.lora_path is not None
-        ), "LoRARef must have both lora_name and lora_path set for loading."
-        assert (
-            lora_ref.lora_id not in self.loras
-        ), f"LoRA adapter with ID {lora_ref.lora_id} is already loaded. This should have been verified before request is sent to the backend."
-
-        try:
-            new_adapter = LoRAConfig.from_dict(config_dict, added_tokens_config)
-            self.validate_new_adapter(new_adapter, lora_ref)
-            self.configs[lora_ref.lora_id] = new_adapter
-
-            self.load_lora_weights_from_tensors(lora_ref, tensors)
-
-            self.lora_refs[lora_ref.lora_id] = lora_ref
-            self.num_pinned_loras += int(lora_ref.pinned)
-        except Exception as e:
-            return self.create_lora_update_result(
-                success=False,
-                error_message=str(e),
-            )
-
-        return self.create_lora_update_result(success=True)
+        return self.load_lora_adapter(
+            lora_ref, LoRATensorPayload(tensors, config_dict, added_tokens_config)
+        )
 
     def init_memory_pool(self):
         """(Re)initialize the LoRA memory pool based on the current configurations."""