diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 9d02958080c..afba645a9d7 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -16,7 +16,7 @@ # and "Punica: Multi-Tenant LoRA Serving" import logging -from typing import Dict, List, Set, Tuple +from typing import Dict, Set, Tuple import torch @@ -45,7 +45,6 @@ class LoRAManager: def __init__( self, base_model: torch.nn.Module, - lora_paths: Dict[str, str], base_hf_config: AutoConfig, max_loras_per_batch: int, load_config: LoadConfig, @@ -55,7 +54,6 @@ def __init__( tp_rank: int = 0, ): self.base_model: torch.nn.Module = base_model - self.lora_paths: Dict[str, str] = lora_paths self.base_hf_config: AutoConfig = base_hf_config self.max_loras_per_batch: int = max_loras_per_batch self.load_config: LoadConfig = load_config @@ -69,8 +67,8 @@ def __init__( backend_type = get_backend_from_name(lora_backend) self.lora_backend: BaseLoRABackend = backend_type(lora_backend) - self.init_loras() - self.init_lora_memory_pool() + # Initialize mutable internal state of the LoRAManager. + self.init_state() def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int): self.max_bs_in_cuda_graph = max_bs_in_cuda_graph @@ -100,72 +98,49 @@ def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int): ], ) - def init_loras(self): - # Config of each LoRA adapter - self.configs: Dict[str, LoRAConfig] = {} + def load_lora_adapters(self, lora_paths: Dict[str, str]): + """ + Load LoRA adapters from the specified paths. + TODO (lifuhuang): This method should be exposed to the server/engine API to support dynamic LoRA loading. + + Args: + lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths. + If a LoRA adapter is already loaded, it will be skipped with a warning. + """ + + for lora_name, lora_path in lora_paths.items(): + if lora_name in self.loras: + logger.warning( + f"LoRA adapter {lora_name} is already loaded." + "If you want to reload it, please unload it first." + ) + continue - # Target module names in huggingface lora configs. - # e.g., {"k_proj", "q_proj", "v_proj", "o_proj"} - self.hf_target_names: Set[str] = set() - for name, path in self.lora_paths.items(): - self.configs[name] = LoRAConfig(path) - self.hf_target_names.update(self.configs[name].target_modules) + self.configs[lora_name] = LoRAConfig(lora_path) - # Target lora weight names for lora_a and lora_b modules respectively. - weights_A: List[str] = [] - weights_B: List[str] = [] - for module in self.hf_target_names: - lora_A, lora_B = get_normalized_lora_weight_names(module) - weights_A += lora_A - weights_B += lora_B - self.lora_weight_names: Tuple[Set[str]] = set(weights_A), set(weights_B) + self.update_state_from_configs() - # load all weights to cpu - self.loras: Dict[str, LoRAAdapter] = {} - for name in self.lora_paths.keys(): - lora_adapter = LoRAAdapter( - name, - self.configs[name], - self.base_hf_config, - self.load_config, - self.lora_backend, - ) - lora_adapter.initialize_weights() - self.loras[name] = lora_adapter - - # misc lora configs - self.max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()]) + def unload_lora_adapters(self, lora_names: Set[str]): + """ + Unload LoRA adapters by their names. This will remove the adapters from the memory pool and + delete the corresponding LoRA modules. 
- if self.lora_backend == "flashinfer": - # FIXME remove the restrictions after supporting multi-rank for flashinfer backend - max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()]) - scaling = list(self.loras.values())[0].scaling - assert all(x.hf_config["r"] == max_lora_dim for x in self.configs.values()) - assert all(x.scaling == scaling for x in self.loras.values()) + Args: + lora_names (Set[str]): A set of LoRA adapter names to unload. + """ + for lora_name in lora_names: + if lora_name in self.loras: + del self.configs[lora_name] + else: + logger.warning(f"LoRA adapter {lora_name} is not loaded.") - # Convert original model layers to layers with LoRA - self.convert_to_lora_layers() - - def init_lora_memory_pool(self): - # Initialize memory pool - self.memory_pool = LoRAMemoryPool( - self.base_hf_config, - self.max_loras_per_batch, - self.max_lora_dim, - self.dtype, - self.tp_size, - self.tp_rank, - self.lora_modules, - ) - - # Initialize target lora modules in memory pool - self.memory_pool.init_buffers(self.lora_weight_names, self.base_model) + self.update_state_from_configs() def prepare_lora_batch(self, forward_batch: ForwardBatch): # load active loras into lora memory pool cur_uids = set(forward_batch.lora_paths) assert len(cur_uids) <= self.max_loras_per_batch - self.memory_pool.prepare_lora_batch(cur_uids, self.loras) + self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules) # set up batch info shared by all lora modules bs = forward_batch.batch_size @@ -267,9 +242,16 @@ def transfer_adapter_info( ) self.lora_backend.set_batch_info(batch_info) - # call set_lora_info for each lora modules - for layer_id, modules in self.lora_modules.items(): - for module_name, module in modules: + # TODO (lifuhuang): one potential perf optimization that is worth considering is to see if we can call + # this method only when loading/unloading LoRA adapters, instead of calling it for every micro-batch. + self.update_lora_info() + + def update_lora_info(self): + """ + Update all LoRA modules to associate them with the latest memory buffer. + """ + for layer_id, layer_modules in self.lora_modules.items(): + for module_name, module in layer_modules.items(): if "qkv_proj" in module_name: module.set_lora_info( self.memory_pool.get_tensor( @@ -295,23 +277,139 @@ def transfer_adapter_info( ), ) + def init_state(self): + """ + Initialize the internal (mutable) state of the LoRAManager. + + These states are mutable via the `update_state_from_configs` as LoRA adapters are loaded and unloaded dynamically. + """ + + # Configs of all active LoRA adapters. + self.configs: Dict[str, LoRAConfig] = {} + + # LoRA adapter weights cached in CPU memory. + self.loras: Dict[str, LoRAAdapter] = {} + + # Supported weight names (e.g., qkv_proj) for LoRA A and B respectively. + self.lora_weight_names: Tuple[Set[str]] = (set(), set()) + + # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module. + self.lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]] = { + i: {} for i in range(self.base_hf_config.num_hidden_layers) + } + + # Initialize memory pool + self.memory_pool = LoRAMemoryPool( + self.base_hf_config, + self.max_loras_per_batch, + self.dtype, + self.tp_size, + self.tp_rank, + ) + + def update_state_from_configs(self): + """ + Update the internal state of the LoRAManager based on the current `self.configs`. This method + should be called whenever `self.configs` is modified (e.g., when new LoRA adapters are loaded). 
+ + This includes: + - Initializing LoRA adapters if they are not already loaded. + - Collect all LoRA weight names based on the current loaded adapters. + - Lazily monkey-patching the base model to use LoRA layers where applicable. + - Preparing the GPU buffer pool for active LoRA weights. + """ + + # Target module names in huggingface lora configs. + # e.g., {"k_proj", "q_proj", "v_proj", "o_proj"} + hf_target_module_names: Set[str] = set() + for config in self.configs.values(): + hf_target_module_names.update(config.target_modules) + max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()]) + + # Loads / unloads LoRA adapters based on the latest configs. + self.update_lora_adapters() + + # Lazily update states for new LoRA weight name (e.g., qkv_proj) as needed. + # + # Please note that the following update operations are "monotonic" by design, meaning that we update + # multiple places to support the new weight names when the first adapter targeting such weight names + # is loaded. However, we never "rollback" the support (e.g., convert LoRA layer back to base layer) + # even if the associated adapters are unloaded later for both simplicity and practicality reasons: the + # list of LoRA weight names is expected to be extremely finite and stable. + self.update_lora_weight_names(hf_target_module_names) + self.update_lora_modules(hf_target_module_names) + self.update_memory_buffers(max_lora_dim) + + def update_lora_weight_names(self, hf_target_names: Set[str]): + """ + Add new LoRA weight names if needed based on the current `self.configs`. + """ + + # Target lora weight names for lora_a and lora_b modules respectively. + for module in hf_target_names: + lora_A, lora_B = get_normalized_lora_weight_names(module) + self.lora_weight_names[0].update(lora_A) + self.lora_weight_names[1].update(lora_B) + + def update_lora_adapters(self): + """ + Update the LoRA adapters in CPU memory based on the current `self.configs`. + It loads any new adapters that are not already loaded, and unloads any adapters + that are no longer in `self.configs` (e.g., unloaded). + """ + + # Load new adapter weights to cpu + for name, config in self.configs.items(): + if name not in self.loras: + logger.info(f"Loading weight of LoRA adapter {name} from {config.path}") + lora_adapter = LoRAAdapter( + name, + config, + self.base_hf_config, + self.load_config, + self.lora_backend, + ) + lora_adapter.initialize_weights() + self.loras[name] = lora_adapter + + # Clean up unused LoRA adapters + for name in self.loras: + if name not in self.configs: + logger.info(f"Unloading LoRA adapter {name}") + del self.loras[name] + + # Additional checks for flashinfer backend + # FIXME remove the restrictions after supporting multi-rank for flashinfer backend + if self.lora_backend == "flashinfer": + lora_dims = set(x.hf_config["r"] for x in self.configs.values()) + scalings = set(x.scaling for x in self.loras.values()) + assert ( + len(lora_dims) == 1 and len(scalings) == 1 + ), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. " + + def update_memory_buffers(self, max_lora_dim: int): + """ + Update the LoRA memory pool buffers based on the current LoRA configurations and update + LoRA modules to use the new buffers. This method should be called after the LoRA configurations + are set or updated. 
+ """ + + self.memory_pool.init_buffers( + self.lora_weight_names, self.base_model, max_lora_dim + ) + def set_lora_module(self, module_name, module): lora_module = get_lora_layer(module, self.lora_backend) replace_submodule(self.base_model, module_name, lora_module) return lora_module - def convert_to_lora_layers(self): + def update_lora_modules(self, hf_target_names: Set[str]): # Target module names of customized layers defined in python/sglang/srt/layers # e.g., {"qkv_proj", "o_proj"} customized_target_names = get_customized_names_from_hf_names( - self.hf_target_names, self.base_model + hf_target_names, self.base_model ) - # Monkey patch to use the LoRA version layers - self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = { - i: [] for i in range(self.base_hf_config.num_hidden_layers) - } - for module_name, module in self.base_model.named_modules(): # TODO (lifuhuang): in the future, we should consider generalizing the # should_apply_lora function to support mapping by full module name instead @@ -326,6 +424,7 @@ def convert_to_lora_layers(self): # The module should be converted if it is included in target_names if module_name.split(".")[-1] in customized_target_names: layer_id = get_layer_id(module_name) - self.lora_modules[layer_id].append( - (module_name, self.set_lora_module(module_name, module)) - ) + if module_name not in self.lora_modules[layer_id]: + self.lora_modules[layer_id][module_name] = self.set_lora_module( + module_name, module + ) diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index 7e69c4aabd0..27122ccc42c 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Set, Tuple +from typing import Callable, Dict, List, Optional, Set, Tuple import torch @@ -22,21 +22,16 @@ def __init__( self, base_hf_config: AutoConfig, max_loras_per_batch: int, - max_lora_dim: int, dtype: torch.dtype, tp_size: int, tp_rank: int, - lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]], ): - self.base_hf_config: AutoConfig = base_hf_config self.num_layer: int = base_hf_config.num_hidden_layers self.max_loras_per_batch: int = max_loras_per_batch - self.max_lora_dim: int = max_lora_dim self.dtype: torch.dtype = dtype self.tp_size: int = tp_size self.tp_rank: int = tp_rank - self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = lora_modules # Both A_buffer and B_buffer maps lora weight names to its buffer space. # A_buffer contains num_layer number of row-major tensors with shape @@ -55,79 +50,84 @@ def __init__( self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch def get_lora_A_shape( - self, module_name: str, base_model: torch.nn.Module + self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int ) -> Tuple[int]: """ Given a module_name (might be a stacked name), return the hidden dims of modules' input and output. 
""" input_dim, _ = get_hidden_dim(module_name, self.base_hf_config, base_model) c = get_stacked_multiply(module_name) - if self.tp_size > 1: - if module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES: - input_dim = divide(input_dim, self.tp_size) + if self.tp_size > 1 and module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES: + input_dim = divide(input_dim, self.tp_size) return ( self.max_loras_per_batch, - self.max_lora_dim * c, + max_lora_dim * c, input_dim, ) def get_lora_B_shape( - self, module_name: str, base_model: torch.nn.Module + self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int ) -> Tuple[int]: """ Given a module_name (might be a stacked name), return the hidden dims of modules' input and output. """ _, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model) c = get_stacked_multiply(module_name) - if self.tp_size > 1: - if module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES: - output_dim = divide(output_dim, self.tp_size) + if self.tp_size > 1 and module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES: + output_dim = divide(output_dim, self.tp_size) return ( c, self.max_loras_per_batch, output_dim, - self.max_lora_dim, + max_lora_dim, ) def init_buffers( self, lora_weight_names: Tuple[Set[str]], base_model: torch.nn.Module, + max_lora_dim: int, ): - # lora_weight_names is a set of name pairs indicating each pair of lora modules to load # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")} self.lora_weight_names: Tuple[Set[str]] = lora_weight_names device = next(base_model.parameters()).device - # Init A tensor, column_major=False - for module_A in lora_weight_names[0]: - lora_A_shape = self.get_lora_A_shape(module_A, base_model) - self.A_buffer[module_A] = [ - torch.empty( - lora_A_shape, - dtype=self.dtype, - device=device, - ) - for _ in range(self.num_layer) - ] - # Init B tensor, column_major=True - for module_B in lora_weight_names[1]: - lora_B_shape = self.get_lora_B_shape(module_B, base_model) - self.B_buffer[module_B] = [ - torch.empty( - lora_B_shape, - dtype=self.dtype, - device=device, - ) - for _ in range(self.num_layer) - ] + + def update_buffer( + buffer: Dict[str, List[torch.Tensor]], + lora_weight_names: Set[str], + get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]], + ): + new_weight_names = lora_weight_names - buffer.keys() + for module_name in new_weight_names: + lora_shape = get_lora_shape_fn(module_name, base_model, max_lora_dim) + buffer[module_name] = [ + torch.empty( + lora_shape, + dtype=self.dtype, + device=device, + ) + for _ in range(self.num_layer) + ] + + update_buffer( + self.A_buffer, + lora_weight_names[0], + self.get_lora_A_shape, + ) + + update_buffer( + self.B_buffer, + lora_weight_names[1], + self.get_lora_B_shape, + ) def prepare_lora_batch( self, cur_uids: Set[Optional[str]], lora_adapters: Dict[str, LoRAAdapter], + lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]], ): - def get_available_buffer_slot(): for buffer_id in range(self.max_loras_per_batch): # Prioritize empty slots @@ -147,14 +147,19 @@ def get_available_buffer_slot(): for uid in cur_uids: if uid not in self.uid_to_buffer_id: buffer_id = get_available_buffer_slot() + lora_adapter = lora_adapters.get(uid, None) self.load_lora_weight_to_buffer( - uid, buffer_id, lora_adapters.get(uid, None) + uid, buffer_id, lora_adapter, lora_modules ) self.uid_to_buffer_id[uid] = buffer_id self.buffer_id_to_uid[buffer_id] = uid def load_lora_weight_to_buffer( - self, uid: str, buffer_id: int, lora_adapter: LoRAAdapter = 
None + self, + uid: str, + buffer_id: int, + lora_adapter: LoRAAdapter, + lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]], ): def check_lora_weight_shape(buffer_view: torch.Tensor, weight: torch.Tensor): assert ( @@ -186,8 +191,8 @@ def check_lora_weight_shape(buffer_view: torch.Tensor, weight: torch.Tensor): temp_B_buffer[lora_weight_name] = weights if self.tp_size > 1: - cur_layer_modules = self.lora_modules[layer_id] - for module_name, module in cur_layer_modules: + cur_layer_modules = lora_modules[layer_id] + for module_name, module in cur_layer_modules.items(): if "qkv_proj" in module_name: temp_A_buffer["qkv_proj"] = module.slice_lora_a_weights( temp_A_buffer["qkv_proj"], self.tp_rank @@ -236,7 +241,6 @@ def check_lora_weight_shape(buffer_view: torch.Tensor, weight: torch.Tensor): def get_tensor( self, weight_name: str, layer_id: int, lora_type: LoRAType ) -> torch.Tensor: - if lora_type == LoRAType.LORA_A: return self.A_buffer[weight_name][layer_id] diff --git a/python/sglang/srt/lora/utils.py b/python/sglang/srt/lora/utils.py index 445b896046e..2df4a8c14f2 100644 --- a/python/sglang/srt/lora/utils.py +++ b/python/sglang/srt/lora/utils.py @@ -108,7 +108,7 @@ def get_hidden_dim( def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]: """ - Mapping a target module name to names of the normized LoRA weights. + Mapping a target module name to names of the normalized LoRA weights. Returned tuple contains (name for Lora A, name for Lora B) """ params_mapping = { diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index e1335074aa5..32d1e18da8c 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -278,6 +278,10 @@ def initialize(self, min_per_gpu_memory: float): self.apply_torch_tp() # Init lora + # TODO (lifuhuang): when we support dynamic LoRA loading / unloading, we should add + # a new server arg `enable_lora` to control whether to init LoRA manager to be more + # explicit, as it is perfectly valid to start a server with an empty lora_paths and + # load LoRA adapters dynamically later. if server_args.lora_paths is not None: self.init_lora_manager() @@ -796,7 +800,6 @@ def get_weights_by_name( def init_lora_manager(self): self.lora_manager = LoRAManager( base_model=self.model, - lora_paths=self.server_args.lora_paths, base_hf_config=self.model_config.hf_config, max_loras_per_batch=self.server_args.max_loras_per_batch, load_config=self.load_config, @@ -805,6 +808,7 @@ def init_lora_manager(self): tp_size=self.tp_size, tp_rank=self.tp_rank, ) + self.lora_manager.load_lora_adapters(self.server_args.lora_paths) logger.info("LoRA manager ready.") def profile_max_num_token(self, total_gpu_memory: int):
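
The refactor replaces the constructor-time `lora_paths` argument with explicit `load_lora_adapters` / `unload_lora_adapters` calls, which is what `ModelRunner.init_lora_manager` now does after constructing the manager. A minimal usage sketch follows; the adapter names and paths are purely illustrative, and `lora_manager` is assumed to be a `LoRAManager` built as in `ModelRunner.init_lora_manager`.

```python
# Illustrative only: adapter names/paths are hypothetical, and `lora_manager`
# is assumed to be a LoRAManager constructed as in ModelRunner.init_lora_manager.
# The manager now starts empty (init_state); adapters are supplied afterwards.

lora_manager.load_lora_adapters(
    {
        "sql-adapter": "/models/loras/sql-adapter",
        "chat-adapter": "/models/loras/chat-adapter",
    }
)
# Re-loading an already-loaded name is skipped with a warning; loading is additive.

lora_manager.unload_lora_adapters({"sql-adapter"})
# Unknown names only produce a warning; remaining adapters stay resident.
```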
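
Internally, `update_lora_adapters` treats `self.configs` as the source of truth: entries present in `configs` but missing from `loras` are loaded to CPU, and entries no longer in `configs` are dropped. Below is a condensed, self-contained sketch of that reconciliation pattern; the `load_adapter` callable is a stand-in for the `LoRAAdapter(...)` construction plus `initialize_weights()`, and the removal pass iterates over a snapshot of the keys so entries can be deleted while iterating (a safety detail added in this sketch, not taken from the patch).

```python
from typing import Callable, Dict

def reconcile_adapters(
    configs: Dict[str, object],
    loras: Dict[str, object],
    load_adapter: Callable[[str, object], object],
) -> None:
    # Load adapters that are configured but not yet resident in CPU memory.
    for name, config in configs.items():
        if name not in loras:
            loras[name] = load_adapter(name, config)

    # Drop adapters whose config was removed (i.e., they were unloaded).
    # Iterating over a snapshot avoids mutating the dict while iterating it.
    for name in list(loras.keys()):
        if name not in configs:
            del loras[name]

# Tiny usage example with stand-in values.
configs = {"sql-adapter": "cfg"}
loras = {"sql-adapter": "weights", "stale-adapter": "weights"}
reconcile_adapters(configs, loras, load_adapter=lambda n, c: f"weights-for-{n}")
assert set(loras) == {"sql-adapter"}
```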
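
The `update_buffer` closure in `LoRAMemoryPool.init_buffers` is what makes repeated initialization safe: it only allocates per-layer tensors for weight names that have no buffer yet, so every call after the first is a no-op for names already seen ("monotonic" growth, matching the comment in `update_state_from_configs`). A stripped-down sketch of that set-difference pattern, with a hypothetical fixed shape standing in for `get_lora_A_shape` / `get_lora_B_shape`:

```python
from typing import Dict, List, Set, Tuple

import torch

def grow_buffers(
    buffer: Dict[str, List[torch.Tensor]],
    weight_names: Set[str],
    num_layers: int,
    shape: Tuple[int, ...],
    dtype: torch.dtype = torch.float16,
    device: str = "cpu",
) -> None:
    # Allocate per-layer tensors only for names not seen before; existing
    # entries are never reallocated or shrunk.
    for name in weight_names - buffer.keys():
        buffer[name] = [
            torch.empty(shape, dtype=dtype, device=device) for _ in range(num_layers)
        ]

A_buffer: Dict[str, List[torch.Tensor]] = {}
# First call allocates "q_proj" and "kv_proj".
grow_buffers(A_buffer, {"q_proj", "kv_proj"}, num_layers=2, shape=(4, 16, 128))
# Second call with an extra name only allocates "o_proj".
grow_buffers(A_buffer, {"q_proj", "kv_proj", "o_proj"}, num_layers=2, shape=(4, 16, 128))
assert set(A_buffer) == {"q_proj", "kv_proj", "o_proj"}
```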
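
With `max_lora_dim` moved out of the pool's constructor and into `init_buffers` / the shape helpers, the buffer layouts themselves are unchanged. For a hypothetical stacked module (e.g., `qkv_proj`, assuming a stacking factor c = 3), illustrative numbers plug into the shape tuples from `get_lora_A_shape` / `get_lora_B_shape` as follows; in the real pool, `input_dim` / `output_dim` come from `get_hidden_dim` and `c` from `get_stacked_multiply`.

```python
# Hypothetical dims for illustration only.
max_loras_per_batch = 4
max_lora_dim = 16        # max rank "r" across the currently loaded adapters
c = 3                    # stacking factor assumed for a stacked module like qkv_proj
input_dim, output_dim = 4096, 1536

lora_A_shape = (max_loras_per_batch, max_lora_dim * c, input_dim)   # (4, 48, 4096)
lora_B_shape = (c, max_loras_per_batch, output_dim, max_lora_dim)   # (3, 4, 1536, 16)
```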