247 changes: 173 additions & 74 deletions python/sglang/srt/lora/lora_manager.py
@@ -16,7 +16,7 @@
# and "Punica: Multi-Tenant LoRA Serving"

import logging
from typing import Dict, List, Set, Tuple
from typing import Dict, Set, Tuple

import torch

@@ -45,7 +45,6 @@ class LoRAManager:
def __init__(
self,
base_model: torch.nn.Module,
lora_paths: Dict[str, str],
base_hf_config: AutoConfig,
max_loras_per_batch: int,
load_config: LoadConfig,
@@ -55,7 +54,6 @@ def __init__(
tp_rank: int = 0,
):
self.base_model: torch.nn.Module = base_model
self.lora_paths: Dict[str, str] = lora_paths
self.base_hf_config: AutoConfig = base_hf_config
self.max_loras_per_batch: int = max_loras_per_batch
self.load_config: LoadConfig = load_config
Expand All @@ -69,8 +67,8 @@ def __init__(
backend_type = get_backend_from_name(lora_backend)
self.lora_backend: BaseLoRABackend = backend_type(lora_backend)

self.init_loras()
self.init_lora_memory_pool()
# Initialize mutable internal state of the LoRAManager.
self.init_state()
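With this change, `lora_paths` is no longer a constructor argument; adapters are registered after construction via `load_lora_adapters` (see below). A minimal initialization sketch (not part of the diff); argument values are placeholders, and constructor parameters not visible in this diff (`dtype`, `lora_backend`, `tp_size`) are assumptions:

```python
# Hypothetical sketch of the new initialization flow; all values are placeholders and
# base_model / base_hf_config / load_config are assumed to be prepared by the caller.
manager = LoRAManager(
    base_model=base_model,            # torch.nn.Module of the base model
    base_hf_config=base_hf_config,    # AutoConfig of the base model
    max_loras_per_batch=4,
    load_config=load_config,
    dtype=torch.bfloat16,             # assumed parameter name
    lora_backend="triton",            # assumed backend name
    tp_size=1,                        # assumed parameter name
    tp_rank=0,
)
# Adapters are now loaded on demand instead of at construction time.
manager.load_lora_adapters({"my-adapter": "/path/to/my-adapter"})  # hypothetical name/path
```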

def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
@@ -100,72 +98,49 @@ def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
],
)

def init_loras(self):
# Config of each LoRA adapter
self.configs: Dict[str, LoRAConfig] = {}
def load_lora_adapters(self, lora_paths: Dict[str, str]):
"""
Load LoRA adapters from the specified paths.
TODO (lifuhuang): This method should be exposed to the server/engine API to support dynamic LoRA loading.

Args:
lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
If a LoRA adapter is already loaded, it will be skipped with a warning.
"""

for lora_name, lora_path in lora_paths.items():
if lora_name in self.loras:
logger.warning(
f"LoRA adapter {lora_name} is already loaded."
"If you want to reload it, please unload it first."
)
continue

# Target module names in huggingface lora configs.
# e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
self.hf_target_names: Set[str] = set()
for name, path in self.lora_paths.items():
self.configs[name] = LoRAConfig(path)
self.hf_target_names.update(self.configs[name].target_modules)
self.configs[lora_name] = LoRAConfig(lora_path)

# Target lora weight names for lora_a and lora_b modules respectively.
weights_A: List[str] = []
weights_B: List[str] = []
for module in self.hf_target_names:
lora_A, lora_B = get_normalized_lora_weight_names(module)
weights_A += lora_A
weights_B += lora_B
self.lora_weight_names: Tuple[Set[str]] = set(weights_A), set(weights_B)
self.update_state_from_configs()
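For illustration, a hedged usage sketch of this API (not part of the diff); `manager` is assumed to be an initialized `LoRAManager`, and the adapter names and paths are hypothetical:

```python
# Hypothetical usage sketch; adapter names and paths are placeholders.
manager.load_lora_adapters(
    {
        "sql-lora": "/models/loras/sql-lora",
        "chat-lora": "/models/loras/chat-lora",
    }
)
# Re-loading an already-loaded name only logs a warning and skips that entry.
manager.load_lora_adapters({"sql-lora": "/models/loras/sql-lora"})
```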

# load all weights to cpu
self.loras: Dict[str, LoRAAdapter] = {}
for name in self.lora_paths.keys():
lora_adapter = LoRAAdapter(
name,
self.configs[name],
self.base_hf_config,
self.load_config,
self.lora_backend,
)
lora_adapter.initialize_weights()
self.loras[name] = lora_adapter

# misc lora configs
self.max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
def unload_lora_adapters(self, lora_names: Set[str]):
"""
Unload LoRA adapters by their names. This removes the adapters' configurations and CPU-cached
weights so that they are no longer available to new batches.

if self.lora_backend == "flashinfer":
# FIXME remove the restrictions after supporting multi-rank for flashinfer backend
max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
scaling = list(self.loras.values())[0].scaling
assert all(x.hf_config["r"] == max_lora_dim for x in self.configs.values())
assert all(x.scaling == scaling for x in self.loras.values())
Args:
lora_names (Set[str]): A set of LoRA adapter names to unload.
"""
for lora_name in lora_names:
if lora_name in self.loras:
del self.configs[lora_name]
else:
logger.warning(f"LoRA adapter {lora_name} is not loaded.")

# Convert original model layers to layers with LoRA
self.convert_to_lora_layers()

def init_lora_memory_pool(self):
# Initialize memory pool
self.memory_pool = LoRAMemoryPool(
self.base_hf_config,
self.max_loras_per_batch,
self.max_lora_dim,
self.dtype,
self.tp_size,
self.tp_rank,
self.lora_modules,
)

# Initialize target lora modules in memory pool
self.memory_pool.init_buffers(self.lora_weight_names, self.base_model)
self.update_state_from_configs()
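A matching sketch for unloading (not part of the diff); again, `manager` and the adapter names are assumptions:

```python
# Hypothetical usage sketch; names are placeholders.
manager.unload_lora_adapters({"sql-lora"})
# Unknown names only log a warning.
manager.unload_lora_adapters({"never-loaded-adapter"})
```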

def prepare_lora_batch(self, forward_batch: ForwardBatch):
# load active loras into lora memory pool
cur_uids = set(forward_batch.lora_paths)
assert len(cur_uids) <= self.max_loras_per_batch
self.memory_pool.prepare_lora_batch(cur_uids, self.loras)
self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules)

# set up batch info shared by all lora modules
bs = forward_batch.batch_size
@@ -267,9 +242,16 @@ def transfer_adapter_info(
)
self.lora_backend.set_batch_info(batch_info)

# call set_lora_info for each lora modules
for layer_id, modules in self.lora_modules.items():
for module_name, module in modules:
# TODO (lifuhuang): a potential perf optimization worth considering is to call this method only when
# loading/unloading LoRA adapters, instead of for every micro-batch.
self.update_lora_info()

def update_lora_info(self):
"""
Update all LoRA modules to associate them with the latest memory buffer.
"""
for layer_id, layer_modules in self.lora_modules.items():
for module_name, module in layer_modules.items():
if "qkv_proj" in module_name:
module.set_lora_info(
self.memory_pool.get_tensor(
@@ -295,23 +277,139 @@ def transfer_adapter_info(
),
)

def init_state(self):
"""
Initialize the internal (mutable) state of the LoRAManager.

These states are mutated by `update_state_from_configs` as LoRA adapters are loaded and unloaded dynamically.
"""

# Configs of all active LoRA adapters.
self.configs: Dict[str, LoRAConfig] = {}

# LoRA adapter weights cached in CPU memory.
self.loras: Dict[str, LoRAAdapter] = {}

# Supported weight names (e.g., qkv_proj) for LoRA A and B respectively.
self.lora_weight_names: Tuple[Set[str], Set[str]] = (set(), set())

# Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
self.lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]] = {
i: {} for i in range(self.base_hf_config.num_hidden_layers)
}

# Initialize memory pool
self.memory_pool = LoRAMemoryPool(
self.base_hf_config,
self.max_loras_per_batch,
self.dtype,
self.tp_size,
self.tp_rank,
)

def update_state_from_configs(self):
"""
Update the internal state of the LoRAManager based on the current `self.configs`. This method
should be called whenever `self.configs` is modified (e.g., when new LoRA adapters are loaded).

This includes:
- Initializing LoRA adapters if they are not already loaded.
- Collecting all LoRA weight names from the currently loaded adapters.
- Lazily monkey-patching the base model to use LoRA layers where applicable.
- Preparing the GPU buffer pool for active LoRA weights.
"""

# Target module names in huggingface lora configs.
# e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
hf_target_module_names: Set[str] = set()
for config in self.configs.values():
hf_target_module_names.update(config.target_modules)
max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])

# Loads / unloads LoRA adapters based on the latest configs.
self.update_lora_adapters()

# Lazily update states for new LoRA weight names (e.g., qkv_proj) as needed.
#
# Note that these updates are "monotonic" by design: support for a new weight name is added when the
# first adapter targeting it is loaded, but it is never rolled back (e.g., a LoRA layer is not converted
# back to a base layer) when the associated adapters are later unloaded. This keeps the logic simple and
# is practical, since the set of LoRA weight names is expected to be small and stable.
Review comment on lines +334 to +338 (Contributor, severity: medium): This comment is very long and complex. Consider simplifying the explanation or breaking it into smaller, more manageable comments. It might also be helpful to provide a link to a design document or issue that explains the reasoning behind this design choice.

self.update_lora_weight_names(hf_target_module_names)
self.update_lora_modules(hf_target_module_names)
self.update_memory_buffers(max_lora_dim)

def update_lora_weight_names(self, hf_target_names: Set[str]):
"""
Add new LoRA weight names if needed based on the current `self.configs`.
"""

# Target lora weight names for lora_a and lora_b modules respectively.
for module in hf_target_names:
lora_A, lora_B = get_normalized_lora_weight_names(module)
self.lora_weight_names[0].update(lora_A)
self.lora_weight_names[1].update(lora_B)

def update_lora_adapters(self):
"""
Update the LoRA adapters in CPU memory based on the current `self.configs`.
It loads any new adapters that are not already loaded, and unloads any adapters
that are no longer in `self.configs` (e.g., unloaded).
"""

# Load new adapter weights to cpu
for name, config in self.configs.items():
if name not in self.loras:
logger.info(f"Loading weight of LoRA adapter {name} from {config.path}")
lora_adapter = LoRAAdapter(
name,
config,
self.base_hf_config,
self.load_config,
self.lora_backend,
)
lora_adapter.initialize_weights()
self.loras[name] = lora_adapter

# Clean up unused LoRA adapters
for name in list(self.loras):  # iterate over a copy since entries may be deleted below
if name not in self.configs:
logger.info(f"Unloading LoRA adapter {name}")
del self.loras[name]

# Additional checks for flashinfer backend
# FIXME remove the restrictions after supporting multi-rank for flashinfer backend
if self.lora_backend == "flashinfer":
lora_dims = set(x.hf_config["r"] for x in self.configs.values())
scalings = set(x.scaling for x in self.loras.values())
assert (
len(lora_dims) == 1 and len(scalings) == 1
), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "

def update_memory_buffers(self, max_lora_dim: int):
"""
Update the LoRA memory pool buffers based on the current LoRA configurations and weight names.
This method should be called after the LoRA configurations are set or updated.
"""

self.memory_pool.init_buffers(
self.lora_weight_names, self.base_model, max_lora_dim
)

def set_lora_module(self, module_name, module):
lora_module = get_lora_layer(module, self.lora_backend)
replace_submodule(self.base_model, module_name, lora_module)
return lora_module

def convert_to_lora_layers(self):
def update_lora_modules(self, hf_target_names: Set[str]):
# Target module names of customized layers defined in python/sglang/srt/layers
# e.g., {"qkv_proj", "o_proj"}
customized_target_names = get_customized_names_from_hf_names(
self.hf_target_names, self.base_model
hf_target_names, self.base_model
)

# Monkey patch to use the LoRA version layers
self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = {
i: [] for i in range(self.base_hf_config.num_hidden_layers)
}

for module_name, module in self.base_model.named_modules():
# TODO (lifuhuang): in the future, we should consider generalizing the
# should_apply_lora function to support mapping by full module name instead
@@ -326,6 +424,7 @@ def convert_to_lora_layers(self):
# The module should be converted if it is included in target_names
if module_name.split(".")[-1] in customized_target_names:
layer_id = get_layer_id(module_name)
self.lora_modules[layer_id].append(
(module_name, self.set_lora_module(module_name, module))
)
if module_name not in self.lora_modules[layer_id]:
self.lora_modules[layer_id][module_name] = self.set_lora_module(
module_name, module
)