Refactor LoRAManager and LoRAMemoryPool state management logic for dynamic LoRA loading support #7412

Merged
Commits (9); the diff below shows changes from 8 commits:
a861dfb  Refactor LoRAManager and LoRAMemoryPool for dynamic LoRA loading supp… (lifuhuang)
25c3844  Refactor LoRAManager and LoRAMemoryPool for dynamic LoRA loading supp… (lifuhuang)
6012d21  Merge remote-tracking branch 'origin/lifuhuang/dynamic-lora' into lif… (lifuhuang)
ee25dd6  Checkpoint. (lifuhuang)
fd46ee1  Merge remote-tracking branch 'origin/main' into lifuhuang/dynamic-lora (lifuhuang)
09d6b29  Reset files. (lifuhuang)
66fbf1d  Update. (lifuhuang)
dc16f49  Merge branch 'main' into lifuhuang/dynamic-lora (lifuhuang)
8469e3a  Merge branch 'main' into lifuhuang/dynamic-lora (lifuhuang)
Diff:

@@ -16,7 +16,7 @@
 # and "Punica: Multi-Tenant LoRA Serving"

 import logging
-from typing import Dict, List, Set, Tuple
+from typing import Dict, Set, Tuple

 import torch
@@ -45,7 +45,6 @@ class LoRAManager:
     def __init__(
         self,
         base_model: torch.nn.Module,
-        lora_paths: Dict[str, str],
         base_hf_config: AutoConfig,
         max_loras_per_batch: int,
         load_config: LoadConfig,

@@ -55,7 +54,6 @@ def __init__(
         tp_rank: int = 0,
     ):
         self.base_model: torch.nn.Module = base_model
-        self.lora_paths: Dict[str, str] = lora_paths
         self.base_hf_config: AutoConfig = base_hf_config
         self.max_loras_per_batch: int = max_loras_per_batch
         self.load_config: LoadConfig = load_config
@@ -69,8 +67,8 @@ def __init__(
        backend_type = get_backend_from_name(lora_backend)
        self.lora_backend: BaseLoRABackend = backend_type(lora_backend)

-       self.init_loras()
-       self.init_lora_memory_pool()
+       # Initialize mutable internal state of the LoRAManager.
+       self.init_state()

    def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
        self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
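To make the intent of this hunk concrete, here is a minimal, self-contained sketch of the pattern it introduces: the constructor keeps only construction-time configuration and funnels all adapter-dependent, mutable state through a single `init_state()` call, so adapters can be registered after the object exists. The `ToyAdapterManager` class and its members are illustrative stand-ins, not sglang APIs.

```python
from typing import Dict


class ToyAdapterManager:
    def __init__(self, max_slots: int):
        # Immutable, construction-time configuration.
        self.max_slots = max_slots
        # All adapter-dependent, mutable state is initialized in one place.
        self.init_state()

    def init_state(self) -> None:
        # No adapters are known yet; they are registered later via load().
        self.configs: Dict[str, str] = {}  # adapter name -> path

    def load(self, paths: Dict[str, str]) -> None:
        # Adapters can now be attached after construction.
        self.configs.update(paths)


mgr = ToyAdapterManager(max_slots=8)
mgr.load({"my-adapter": "/tmp/my-adapter"})
print(sorted(mgr.configs))  # ['my-adapter']
```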
@@ -100,72 +98,49 @@ def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
            ],
        )

-    def init_loras(self):
-        # Config of each LoRA adapter
-        self.configs: Dict[str, LoRAConfig] = {}
+    def load_lora_adapters(self, lora_paths: Dict[str, str]):
+        """
+        Load LoRA adapters from the specified paths.
+        TODO (lifuhuang): This method should be exposed to the server/engine API to support dynamic LoRA loading.
+
+        Args:
+            lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
+                If a LoRA adapter is already loaded, it will be skipped with a warning.
+        """
+        for lora_name, lora_path in lora_paths.items():
+            if lora_name in self.loras:
+                logger.warning(
+                    f"LoRA adapter {lora_name} is already loaded."
+                    "If you want to reload it, please unload it first."
+                )
+                continue

-        # Target module names in huggingface lora configs.
-        # e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
-        self.hf_target_names: Set[str] = set()
-        for name, path in self.lora_paths.items():
-            self.configs[name] = LoRAConfig(path)
-            self.hf_target_names.update(self.configs[name].target_modules)
+            self.configs[lora_name] = LoRAConfig(lora_path)

-        # Target lora weight names for lora_a and lora_b modules respectively.
-        weights_A: List[str] = []
-        weights_B: List[str] = []
-        for module in self.hf_target_names:
-            lora_A, lora_B = get_normalized_lora_weight_names(module)
-            weights_A += lora_A
-            weights_B += lora_B
-        self.lora_weight_names: Tuple[Set[str]] = set(weights_A), set(weights_B)
+        self.update_state_from_configs()
-        # load all weights to cpu
-        self.loras: Dict[str, LoRAAdapter] = {}
-        for name in self.lora_paths.keys():
-            lora_adapter = LoRAAdapter(
-                name,
-                self.configs[name],
-                self.base_hf_config,
-                self.load_config,
-                self.lora_backend,
-            )
-            lora_adapter.initialize_weights()
-            self.loras[name] = lora_adapter
-
-        # misc lora configs
-        self.max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
+    def unload_lora_adapters(self, lora_names: Set[str]):
+        """
+        Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
+        delete the corresponding LoRA modules.

-        if self.lora_backend == "flashinfer":
-            # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
-            max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
-            scaling = list(self.loras.values())[0].scaling
-            assert all(x.hf_config["r"] == max_lora_dim for x in self.configs.values())
-            assert all(x.scaling == scaling for x in self.loras.values())
+        Args:
+            lora_names (Set[str]): A set of LoRA adapter names to unload.
+        """
+        for lora_name in lora_names:
+            if lora_name in self.loras:
+                del self.configs[lora_name]
+            else:
+                logger.warning(f"LoRA adapter {lora_name} is not loaded.")

-        # Convert original model layers to layers with LoRA
-        self.convert_to_lora_layers()
-
-    def init_lora_memory_pool(self):
-        # Initialize memory pool
-        self.memory_pool = LoRAMemoryPool(
-            self.base_hf_config,
-            self.max_loras_per_batch,
-            self.max_lora_dim,
-            self.dtype,
-            self.tp_size,
-            self.tp_rank,
-            self.lora_modules,
-        )
-
-        # Initialize target lora modules in memory pool
-        self.memory_pool.init_buffers(self.lora_weight_names, self.base_model)
+        self.update_state_from_configs()
    def prepare_lora_batch(self, forward_batch: ForwardBatch):
        # load active loras into lora memory pool
        cur_uids = set(forward_batch.lora_paths)
        assert len(cur_uids) <= self.max_loras_per_batch
-        self.memory_pool.prepare_lora_batch(cur_uids, self.loras)
+        self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules)

        # set up batch info shared by all lora modules
        bs = forward_batch.batch_size
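The `prepare_lora_batch` change keeps the per-batch contract: at most `max_loras_per_batch` distinct adapters per batch, with the memory pool deciding which cached adapters stay resident. The sketch below illustrates that contract with a hypothetical fixed-slot pool and naive eviction; it is not the actual `LoRAMemoryPool` logic, which manages GPU tensors.

```python
from typing import Dict, List, Optional, Set


class ToySlotPool:
    def __init__(self, max_slots: int):
        self.max_slots = max_slots
        # slot index -> adapter uid currently resident in that slot (or None).
        self.slots: List[Optional[str]] = [None] * max_slots

    def prepare_batch(self, cur_uids: Set[str]) -> Dict[str, int]:
        assert len(cur_uids) <= self.max_slots
        # Adapters that are already resident and requested again keep their slot.
        resident = {uid: i for i, uid in enumerate(self.slots) if uid in cur_uids}
        # Every other slot is fair game for eviction.
        free = [i for i, uid in enumerate(self.slots) if uid not in cur_uids]
        for uid in cur_uids - resident.keys():
            slot = free.pop()
            self.slots[slot] = uid
            resident[uid] = slot
        return resident


pool = ToySlotPool(max_slots=2)
print(pool.prepare_batch({"a", "b"}))  # both adapters get a slot
print(pool.prepare_batch({"b", "c"}))  # "b" stays resident, "a" is evicted for "c"
```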
@@ -267,9 +242,16 @@ def transfer_adapter_info(
        )
        self.lora_backend.set_batch_info(batch_info)

-        # call set_lora_info for each lora modules
-        for layer_id, modules in self.lora_modules.items():
-            for module_name, module in modules:
+        # TODO (lifuhuang): one potential perf optimization that is worth considering is to see if we can call
+        # this method only when loading/unloading LoRA adapters, instead of calling it for every micro-batch.
+        self.update_lora_info()
+
+    def update_lora_info(self):
+        """
+        Update all LoRA modules to associate them with the latest memory buffer.
+        """
+        for layer_id, layer_modules in self.lora_modules.items():
+            for module_name, module in layer_modules.items():
                if "qkv_proj" in module_name:
                    module.set_lora_info(
                        self.memory_pool.get_tensor(
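`update_lora_info()` simply walks the `{layer_id: {module_name: module}}` table and re-points every LoRA module at the pool's current buffers. A toy version of that rebinding step, with made-up class and buffer names rather than sglang APIs, looks like this:

```python
from typing import Dict, List, Optional


class ToyLoRALinear:
    """Stand-in for a LoRA-enabled layer that reads weights out of a shared buffer."""

    def __init__(self) -> None:
        self.buffer: Optional[List[float]] = None

    def set_lora_info(self, buffer: List[float]) -> None:
        # Rebind this module to the pool's current storage.
        self.buffer = buffer


# Look-up table shaped like self.lora_modules: layer index -> {module name -> module}.
lora_modules: Dict[int, Dict[str, ToyLoRALinear]] = {
    0: {"qkv_proj": ToyLoRALinear(), "o_proj": ToyLoRALinear()},
    1: {"qkv_proj": ToyLoRALinear()},
}

pool_buffer: List[float] = [0.0] * 16  # stands in for memory_pool.get_tensor(...)


def update_lora_info() -> None:
    # Walk every registered module and point it at the latest buffer.
    for layer_id, layer_modules in lora_modules.items():
        for module_name, module in layer_modules.items():
            module.set_lora_info(pool_buffer)


update_lora_info()
assert all(
    m.buffer is pool_buffer for mods in lora_modules.values() for m in mods.values()
)
```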
@@ -295,23 +277,139 @@ def transfer_adapter_info(
                    ),
                )

+    def init_state(self):
+        """
+        Initialize the internal (mutable) state of the LoRAManager.
+
+        These states are mutable via the `update_state_from_configs` as LoRA adapters are loaded and unloaded dynamically.
+        """
+
+        # Configs of all active LoRA adapters.
+        self.configs: Dict[str, LoRAConfig] = {}
+
+        # LoRA adapter weights cached in CPU memory.
+        self.loras: Dict[str, LoRAAdapter] = {}
+
+        # Supported weight names (e.g., qkv_proj) for LoRA A and B respectively.
+        self.lora_weight_names: Tuple[Set[str]] = (set(), set())
+
+        # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
+        self.lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]] = {
+            i: {} for i in range(self.base_hf_config.num_hidden_layers)
+        }
+
+        # Initialize memory pool
+        self.memory_pool = LoRAMemoryPool(
+            self.base_hf_config,
+            self.max_loras_per_batch,
+            self.dtype,
+            self.tp_size,
+            self.tp_rank,
+        )
+    def update_state_from_configs(self):
+        """
+        Update the internal state of the LoRAManager based on the current `self.configs`. This method
+        should be called whenever `self.configs` is modified (e.g., when new LoRA adapters are loaded).
+
+        This includes:
+        - Initializing LoRA adapters if they are not already loaded.
+        - Collect all LoRA weight names based on the current loaded adapters.
+        - Lazily monkey-patching the base model to use LoRA layers where applicable.
+        - Preparing the GPU buffer pool for active LoRA weights.
+        """
+
+        # Target module names in huggingface lora configs.
+        # e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
+        hf_target_module_names: Set[str] = set()
+        for config in self.configs.values():
+            hf_target_module_names.update(config.target_modules)
+        max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
+
+        # Loads / unloads LoRA adapters based on the latest configs.
+        self.update_lora_adapters()
+
+        # Lazily update states for new LoRA weight name (e.g., qkv_proj) as needed.
+        #
+        # Please note that the following update operations are "monotonic" by design, meaning that we update
+        # multiple places to support the new weight names when the first adapter targeting such weight names
+        # is loaded. However, we never "rollback" the support (e.g., convert LoRA layer back to base layer)
+        # even if the associated adapters are unloaded later for both simplicity and practicality reasons: the
+        # list of LoRA weight names is expected to be extremely finite and stable.
+        self.update_lora_weight_names(hf_target_module_names)
+        self.update_lora_modules(hf_target_module_names)
+        self.update_memory_buffers(max_lora_dim)
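The first half of `update_state_from_configs()` is a pure derivation from `self.configs`: the union of HuggingFace target module names and the maximum LoRA rank. A standalone sketch of that step, using a `ToyConfig` stand-in for `LoRAConfig` (and, like the code in the diff, assuming at least one adapter is loaded):

```python
from dataclasses import dataclass
from typing import Dict, List, Set, Tuple


@dataclass
class ToyConfig:
    target_modules: List[str]
    r: int  # LoRA rank, mirroring hf_config["r"]


def derive_state(configs: Dict[str, ToyConfig]) -> Tuple[Set[str], int]:
    # Union of HF target module names across all loaded adapters.
    target_names: Set[str] = set()
    for cfg in configs.values():
        target_names.update(cfg.target_modules)
    # Largest rank among the loaded adapters sizes the GPU buffers.
    max_lora_dim = max(cfg.r for cfg in configs.values())
    return target_names, max_lora_dim


configs = {
    "adapter_a": ToyConfig(["q_proj", "v_proj"], r=8),
    "adapter_b": ToyConfig(["q_proj", "o_proj"], r=16),
}
names, max_dim = derive_state(configs)
print(sorted(names), max_dim)  # ['o_proj', 'q_proj', 'v_proj'] 16
```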
+    def update_lora_weight_names(self, hf_target_names: Set[str]):
+        """
+        Add new LoRA weight names if needed based on the current `self.configs`.
+        """
+
+        # Target lora weight names for lora_a and lora_b modules respectively.
+        for module in hf_target_names:
+            lora_A, lora_B = get_normalized_lora_weight_names(module)
+            self.lora_weight_names[0].update(lora_A)
+            self.lora_weight_names[1].update(lora_B)
+    def update_lora_adapters(self):
+        """
+        Update the LoRA adapters in CPU memory based on the current `self.configs`.
+        It loads any new adapters that are not already loaded, and unloads any adapters
+        that are no longer in `self.configs` (e.g., unloaded).
+        """
+
+        # Load new adapter weights to cpu
+        for name, config in self.configs.items():
+            if name not in self.loras:
+                logger.info(f"Loading weight of LoRA adapter {name} from {config.path}")
+                lora_adapter = LoRAAdapter(
+                    name,
+                    config,
+                    self.base_hf_config,
+                    self.load_config,
+                    self.lora_backend,
+                )
+                lora_adapter.initialize_weights()
+                self.loras[name] = lora_adapter
+
+        # Clean up unused LoRA adapters
+        for name in self.loras:
+            if name not in self.configs:
+                logger.info(f"Unloading LoRA adapter {name}")
+                del self.loras[name]
+
+        # Additional checks for flashinfer backend
+        # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
+        if self.lora_backend == "flashinfer":
+            lora_dims = set(x.hf_config["r"] for x in self.configs.values())
+            scalings = set(x.scaling for x in self.loras.values())
+            assert (
+                len(lora_dims) == 1 and len(scalings) == 1
+            ), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "
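`update_lora_adapters()` treats `self.configs` as the source of truth and reconciles the CPU cache against it: load what is configured but missing, drop what is cached but no longer configured. The following self-contained sketch mirrors that reconciliation with plain dictionaries; note that its cleanup loop iterates over a snapshot of the keys so entries can be deleted safely while iterating.

```python
from typing import Dict

configs: Dict[str, str] = {"a": "/path/a", "c": "/path/c"}  # desired state
loras: Dict[str, str] = {"a": "weights-a", "b": "weights-b"}  # currently cached


def reconcile(configs: Dict[str, str], loras: Dict[str, str]) -> None:
    # Load adapters that are configured but not cached yet.
    for name, path in configs.items():
        if name not in loras:
            loras[name] = f"weights loaded from {path}"

    # Unload adapters that are cached but no longer configured.
    for name in list(loras):  # iterate over a snapshot so deletion is safe
        if name not in configs:
            del loras[name]


reconcile(configs, loras)
print(loras)  # {'a': 'weights-a', 'c': 'weights loaded from /path/c'}
```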
+    def update_memory_buffers(self, max_lora_dim: int):
+        """
+        Update the LoRA memory pool buffers based on the current LoRA configurations and update
+        LoRA modules to use the new buffers. This method should be called after the LoRA configurations
+        are set or updated.
+        """
+
+        self.memory_pool.init_buffers(
+            self.lora_weight_names, self.base_model, max_lora_dim
+        )
    def set_lora_module(self, module_name, module):
        lora_module = get_lora_layer(module, self.lora_backend)
        replace_submodule(self.base_model, module_name, lora_module)
        return lora_module

-    def convert_to_lora_layers(self):
+    def update_lora_modules(self, hf_target_names: Set[str]):
        # Target module names of customized layers defined in python/sglang/srt/layers
        # e.g., {"qkv_proj", "o_proj"}
        customized_target_names = get_customized_names_from_hf_names(
-            self.hf_target_names, self.base_model
+            hf_target_names, self.base_model
        )

-        # Monkey patch to use the LoRA version layers
-        self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = {
-            i: [] for i in range(self.base_hf_config.num_hidden_layers)
-        }
-
        for module_name, module in self.base_model.named_modules():
            # TODO (lifuhuang): in the future, we should consider generalizing the
            # should_apply_lora function to support mapping by full module name instead

@@ -326,6 +424,7 @@ def convert_to_lora_layers(self):
            # The module should be converted if it is included in target_names
            if module_name.split(".")[-1] in customized_target_names:
                layer_id = get_layer_id(module_name)
-                self.lora_modules[layer_id].append(
-                    (module_name, self.set_lora_module(module_name, module))
-                )
+                if module_name not in self.lora_modules[layer_id]:
+                    self.lora_modules[layer_id][module_name] = self.set_lora_module(
+                        module_name, module
+                    )
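`update_lora_modules()` relies on `set_lora_module()` to monkey-patch the base model: the module addressed by a dotted name is replaced with its LoRA-enabled counterpart. Below is a minimal PyTorch sketch of that replacement step; `ToyLoRAWrapper` and the local `replace_submodule()` helper are written for this example and are not the sglang implementations.

```python
import torch
from torch import nn


class ToyLoRAWrapper(nn.Module):
    def __init__(self, base: nn.Module):
        super().__init__()
        # Delegate to the original layer; a real LoRA layer would add the low-rank path.
        self.base = base

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x)


def replace_submodule(model: nn.Module, dotted_name: str, new_module: nn.Module) -> nn.Module:
    # Resolve the parent module from the dotted path and swap the child in place.
    parent = model.get_submodule(".".join(dotted_name.split(".")[:-1]))
    setattr(parent, dotted_name.split(".")[-1], new_module)
    return new_module


model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        replace_submodule(model, name, ToyLoRAWrapper(module))

print(model)  # both Linear layers are now wrapped
```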