diff --git a/docs/advanced_features/lora.ipynb b/docs/advanced_features/lora.ipynb index c01e9fbfb754..082a5c279240 100644 --- a/docs/advanced_features/lora.ipynb +++ b/docs/advanced_features/lora.ipynb @@ -29,6 +29,8 @@ "\n", "* `enable_lora`: Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility.\n", "\n", + "* `enable_lora_overlap_loading`: Enable asynchronous LoRA weight loading in order to overlap H2D transfers with GPU compute. This should be enabled if you find that your LoRA workloads are bottlenecked by adapter weight loading, for example when frequently loading large LoRA adapters.\n", + "\n", "* `lora_paths`: The list of LoRA adapters to load. Each adapter must be specified in one of the following formats: <PATH> | <NAME>=<PATH> | JSON with schema {\"lora_name\":str,\"lora_path\":str,\"pinned\":bool}.\n", "\n", "* `max_loras_per_batch`: Maximum number of adaptors used by each batch. This argument can affect the amount of GPU memory reserved for multi-LoRA serving, so it should be set to a smaller value when memory is scarce. Defaults to be 8.\n", diff --git a/docs/advanced_features/server_arguments.md b/docs/advanced_features/server_arguments.md index 38ad385cfafd..425200dabdc1 100644 --- a/docs/advanced_features/server_arguments.md +++ b/docs/advanced_features/server_arguments.md @@ -225,6 +225,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s | Argument | Description | Defaults | Options | | --- | --- | --- | --- | | `--enable-lora` | Enable LoRA support for the model. This argument is automatically set to `True` if `--lora-paths` is provided for backward compatibility. | `False` | Bool flag (set to enable) | +| `--enable-lora-overlap-loading` | Enable asynchronous LoRA weight loading in order to overlap H2D transfers with GPU compute. 
This should be enabled if you find that your LoRA workloads are bottlenecked by adapter weight loading, for example when frequently loading large LoRA adapters. | `False` | Bool flag (set to enable) | | `--max-lora-rank` | The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup. | `None` | Type: int | | `--lora-target-modules` | The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. You can also set it to `all` to enable LoRA for all supported modules; note this may introduce minor performance overhead. | `None` | `q_proj`, `k_proj`, `v_proj`, `o_proj`, `gate_proj`, `up_proj`, `down_proj`, `qkv_proj`, `gate_up_proj`, `all` | | `--lora-paths` | The list of LoRA adapters to load. Each adapter must be specified in one of the following formats: `<PATH>` \| `<NAME>=<PATH>` \| JSON with schema `{"lora_name": str, "lora_path": str, "pinned": bool}`. | `None` | Type: List[str] / JSON objects | diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py index 1745c4f9633b..a3478a8f8378 100644 --- a/python/sglang/srt/lora/lora.py +++ b/python/sglang/srt/lora/lora.py @@ -208,3 +208,14 @@ def normalize_gate_up_proj( if "lora_A" in weight_name: weights[gate_up_name] = weights[gate_up_name].repeat(2, 1) # else: no-op as LoRA B weight is already stacked. 
+ + def pin_weights_in_cpu(self): + for layer in self.layers: + for name, weight in layer.weights.items(): + layer.weights[name] = weight.pin_memory() + + for name, weight in self.embedding_layers.items(): + self.embedding_layers[name] = weight.pin_memory() + + for name, weight in self.added_tokens_embeddings.items(): + self.added_tokens_embeddings[name] = weight.pin_memory() diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 1d44d7fe929f..6d8769e6971b 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -72,6 +72,9 @@ def __init__( self.tp_size: int = tp_size self.tp_rank: int = tp_rank self.lora_added_tokens_size: Optional[int] = None + self.enable_lora_overlap_loading: Optional[bool] = ( + server_args.enable_lora_overlap_loading + ) # Store eviction policy from server args self.eviction_policy = server_args.lora_eviction_policy @@ -208,7 +211,7 @@ def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput: return self.create_lora_update_result(success=True) - def validate_lora_batch(self, lora_ids: set[str]) -> bool: + def validate_lora_batch(self, lora_ids: set[Optional[str]]) -> bool: """ Validate if the LoRA IDs in the batch can be loaded into the current LoRA memory pool. 
""" @@ -239,9 +242,11 @@ def validate_lora_batch(self, lora_ids: set[str]) -> bool: return required_slots <= mem_pool_vacancy - def prepare_lora_batch(self, forward_batch: ForwardBatch): + def fetch_new_loras( + self, new_loras: set[Optional[str]], running_loras: set[Optional[str]] = set() + ): # Load active loras into lora memory pool - cur_uids = set(forward_batch.lora_ids) + cur_uids = new_loras | running_loras assert len(cur_uids) <= self.max_loras_per_batch self.memory_pool.prepare_lora_batch( @@ -253,6 +258,7 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch): lora_lm_head_module=self.lm_head_module, # merge into embedding or lora module ) + def prepare_lora_batch(self, forward_batch: ForwardBatch): # set up batch info shared by all lora modules bs = forward_batch.batch_size @@ -442,6 +448,11 @@ def load_lora_weights(self, lora_ref: LoRARef): self.lora_backend, ) lora_adapter.initialize_weights() + + # If we want to overlap loading LoRA adapters with compute, they must be pinned in CPU memory + if self.enable_lora_overlap_loading: + lora_adapter.pin_weights_in_cpu() + self.loras[lora_ref.lora_id] = lora_adapter def load_lora_weights_from_tensors( @@ -509,6 +520,9 @@ def init_memory_pool(self): lora_added_tokens_size=self.lora_added_tokens_size, ) + # Initializing memory pool with base model + self.fetch_new_loras({None}) + def set_lora_module(self, module_name, module): lora_module = get_lora_layer(module, self.lora_backend) replace_submodule(self.base_model, module_name, lora_module) diff --git a/python/sglang/srt/lora/lora_overlap_loader.py b/python/sglang/srt/lora/lora_overlap_loader.py new file mode 100644 index 000000000000..bc7b3dd71d2e --- /dev/null +++ b/python/sglang/srt/lora/lora_overlap_loader.py @@ -0,0 +1,82 @@ +import logging +from enum import Enum, auto +from typing import Dict, Optional + +import torch +from torch.cuda import Event as CudaEvent +from torch.cuda import Stream as CudaStream +from torch.cuda import StreamContext as 
CudaStreamContext + +from sglang.srt.lora.lora_manager import LoRAManager + +logger = logging.getLogger(__name__) + + +class LoRAOverlapLoadStatus(Enum): + LOADED = auto() + LOADING = auto() + NOT_LOADED = auto() + + +class LoRAOverlapLoader: + def __init__(self, lora_manager): + self.lora_manager: LoRAManager = lora_manager + self.device_module = torch.get_device_module(self.lora_manager.device) + self.load_stream: CudaStream = self.device_module.Stream() + self.load_stream_context: CudaStreamContext = self.device_module.stream( + self.load_stream + ) + self.lora_to_overlap_load_event: Dict[Optional[str], CudaEvent] = {} + + def try_overlap_load_lora( + self, lora_id: Optional[str], running_loras: set[Optional[str]] + ) -> bool: + """ + Check a LoRA adapter's asynchronous load status, and try to load it if there's capacity + in the memory pool. Returns whether or not the adapter has been loaded. + """ + lora_pipeline_load_status = self._check_overlap_load_status(lora_id) + if lora_pipeline_load_status == LoRAOverlapLoadStatus.LOADING: + return False + elif lora_pipeline_load_status == LoRAOverlapLoadStatus.NOT_LOADED: + res = self._try_start_overlap_load(lora_id, running_loras) + if res: + logger.debug(f"Loading LoRA adapter {lora_id} asynchronously") + + return False + else: + assert lora_pipeline_load_status == LoRAOverlapLoadStatus.LOADED + return True + + def _check_overlap_load_status( + self, lora_id: Optional[str] + ) -> LoRAOverlapLoadStatus: + if lora_id not in self.lora_to_overlap_load_event: + return LoRAOverlapLoadStatus.NOT_LOADED + + event = self.lora_to_overlap_load_event[lora_id] + + if not event.query(): + return LoRAOverlapLoadStatus.LOADING + + torch.cuda.current_stream().wait_event(event) + del self.lora_to_overlap_load_event[lora_id] + + return LoRAOverlapLoadStatus.LOADED + + def _try_start_overlap_load( + self, lora_id: Optional[str], running_loras: set[Optional[str]] + ) -> bool: + loras_to_be_loaded = running_loras | 
self.lora_to_overlap_load_event.keys() + + new_lora_set = {lora_id} | loras_to_be_loaded + if not self.lora_manager.validate_lora_batch(new_lora_set): + return False + + with self.load_stream_context: + self.lora_manager.fetch_new_loras({lora_id}, loras_to_be_loaded) + event = self.device_module.Event() + event.record(self.load_stream) + + self.lora_to_overlap_load_event[lora_id] = event + return True diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index f57ea39ac430..27c7a664adc2 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -391,7 +391,7 @@ def load_lora_weight_tensor( assert ( buffer_view.shape == weight.shape ), f"LoRA buffer shape {buffer_view.shape} does not match weight shape {weight.shape}." - buffer_view.copy_(weight) + buffer_view.copy_(weight, non_blocking=True) if uid is None: for i in range(self.num_layer): diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 1bf294973df1..9a6e412ca6ba 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -66,6 +66,7 @@ ) from sglang.srt.layers.moe import initialize_moe_config from sglang.srt.layers.quantization.fp8_utils import initialize_fp8_gemm_config +from sglang.srt.lora.lora_overlap_loader import LoRAOverlapLoader from sglang.srt.managers.io_struct import ( AbortReq, BaseBatchReq, @@ -284,6 +285,7 @@ def __init__( server_args.priority_scheduling_preemption_threshold ) self.enable_lora = server_args.enable_lora + self.enable_lora_overlap_loading = server_args.enable_lora_overlap_loading self.max_loras_per_batch = server_args.max_loras_per_batch self.enable_overlap = not server_args.disable_overlap_schedule self.enable_pdmux = server_args.enable_pdmux @@ -371,6 +373,12 @@ def __init__( # Init request dispatcher self.init_request_dispatcher() + # Init LoRA overlap loader + if self.enable_lora_overlap_loading: + self.lora_overlap_loader = 
LoRAOverlapLoader( + self.tp_worker.model_runner.lora_manager + ) + # Init the grammar backend for constrained generation self.grammar_manager = GrammarManager(self) @@ -1905,23 +1913,25 @@ def _get_new_batch_prefill_raw( self.chunked_req = adder.add_chunked_req(self.chunked_req) if self.enable_lora: - lora_set = set([req.lora_id for req in self.running_batch.reqs]) + running_loras = {req.lora_id for req in self.running_batch.reqs} # Get requests from the waiting queue to a new prefill batch for req in self.waiting_queue: - - if self.enable_lora: - new_lora_set = ( - lora_set - | set([req.lora_id for req in adder.can_run_list]) - | set([req.lora_id]) - ) - if not self.tp_worker.can_run_lora_batch(new_lora_set): - # Batch would exceed the LoRA slot limit. - # Skip this request and try scheduling it in a future iteration. - # Note: When eviction is needed, the eviction policy prefers to - # evict LoRA adapters over base model (None) - see mem_pool.py. - continue + if self.enable_lora and req.lora_id not in running_loras: + if self.enable_lora_overlap_loading: + # For overlapping loading of LoRA weights with computation, we will load each adapter one at a time, + # as opposed to loading them in one batch + res = self.lora_overlap_loader.try_overlap_load_lora( + req.lora_id, running_loras + ) + if not res: + continue + else: + new_lora_set = {req.lora_id} | running_loras + if not self.tp_worker.model_runner.lora_manager.validate_lora_batch( + new_lora_set + ): + continue running_bs = len(self.running_batch.reqs) if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs): @@ -1951,6 +1961,9 @@ def _get_new_batch_prefill_raw( truncation_align_size=self.truncation_align_size, ) + if self.enable_lora: + running_loras.add(req.lora_id) + if res != AddReqResult.CONTINUE: if res == AddReqResult.NO_TOKEN: if self.enable_hierarchical_cache: diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py 
index 4402d0928ddd..63050e1cde9d 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -543,6 +543,11 @@ def init_new( # Init lora information if model_runner.server_args.enable_lora: + # In the non-LoRA overlap loading case, we fetch LoRA adapters into the memory pool + # as a batch, right before running the batch + if not model_runner.server_args.enable_lora_overlap_loading: + model_runner.lora_manager.fetch_new_loras(set(ret.lora_ids)) + model_runner.lora_manager.prepare_lora_batch(ret) return ret diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 6802f201181b..0b5aa54c7f61 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -396,6 +396,7 @@ class ServerArgs: # LoRA enable_lora: Optional[bool] = None + enable_lora_overlap_loading: Optional[bool] = None max_lora_rank: Optional[int] = None lora_target_modules: Optional[Union[set[str], List[str]]] = None lora_paths: Optional[ @@ -3371,6 +3372,12 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility.", ) + parser.add_argument( + "--enable-lora-overlap-loading", + default=ServerArgs.enable_lora_overlap_loading, + action="store_true", + help="Enable asynchronous LoRA weight loading in order to overlap H2D transfers with GPU compute. 
This should be enabled if you find that your LoRA workloads are bottlenecked by adapter weight loading, for example when frequently loading large LoRA adapters.", + ) parser.add_argument( "--max-lora-rank", default=ServerArgs.max_lora_rank, @@ -4900,6 +4907,20 @@ def check_lora_server_args(self): ) if self.enable_lora: + if self.enable_lora_overlap_loading is None: + self.enable_lora_overlap_loading = False + + if self.enable_lora_overlap_loading: + # TODO (glenliu21): use some sort of buffer with eviction instead of enforcing a limit + max_loaded_loras_limit = self.max_loras_per_batch * 2 + assert ( + self.max_loaded_loras is not None + and self.max_loaded_loras <= max_loaded_loras_limit + ), ( + "Enabling LoRA overlap loading requires pinning LoRA adapter weights in CPU memory, " + f"so --max-loaded-loras must be less than or equal to double --max-loras-per-batch: {max_loaded_loras_limit}" + ) + # Validate compatibility with speculative decoding if self.speculative_algorithm not in ["NGRAM", None]: raise ValueError( diff --git a/python/sglang/test/lora_utils.py b/python/sglang/test/lora_utils.py index 171dc9fc1800..1f8d64f73ab6 100644 --- a/python/sglang/test/lora_utils.py +++ b/python/sglang/test/lora_utils.py @@ -96,6 +96,7 @@ def __post_init__(self): ), ], max_loras_per_batch=2, + max_loaded_loras=4, ), ] @@ -285,6 +286,7 @@ def run_lora_test_one_by_one( torch_dtype: torch.dtype, max_new_tokens: int, backend: str = "csgmv", + enable_lora_overlap_loading: Optional[bool] = None, disable_cuda_graph: bool = False, disable_radix_cache: bool = False, mem_fraction_static: float = 0.88, @@ -331,6 +333,7 @@ def run_lora_test_one_by_one( lora_paths=[ adaptor.name for adaptor in model_case.adaptors if adaptor.name is not None ], + enable_lora_overlap_loading=enable_lora_overlap_loading, max_loras_per_batch=model_case.max_loras_per_batch, max_loaded_loras=model_case.max_loaded_loras, lora_backend=backend, diff --git a/python/sglang/test/runners.py 
b/python/sglang/test/runners.py index d1eea17deb55..ebc9912da41a 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -552,6 +552,7 @@ def __init__( max_lora_rank: Optional[int] = None, lora_target_modules: Optional[List[str]] = None, enable_lora: Optional[bool] = None, + enable_lora_overlap_loading: Optional[bool] = None, max_loaded_loras: Optional[int] = None, json_model_override_args: Optional[dict[str, Any]] = None, lora_eviction_policy: str = "lru", @@ -612,6 +613,7 @@ def __init__( max_lora_rank=max_lora_rank, lora_target_modules=lora_target_modules, enable_lora=enable_lora, + enable_lora_overlap_loading=enable_lora_overlap_loading, max_loaded_loras=max_loaded_loras, json_model_override_args=( json.dumps(json_model_override_args) diff --git a/test/registered/lora/test_lora_overlap_loading.py b/test/registered/lora/test_lora_overlap_loading.py new file mode 100644 index 000000000000..119e789668ec --- /dev/null +++ b/test/registered/lora/test_lora_overlap_loading.py @@ -0,0 +1,116 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +End-to-end tests for the --enable-lora-overlap-loading server argument. 
+""" + +import multiprocessing as mp +import unittest + +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.lora_utils import ( + CI_MULTI_LORA_MODELS, + TEST_MULTIPLE_BATCH_PROMPTS, + TORCH_DTYPES, + LoRAModelCase, + ensure_reproducibility, +) +from sglang.test.runners import SRTRunner +from sglang.test.test_utils import CustomTestCase, calculate_rouge_l + +register_cuda_ci(est_time=300, suite="stage-b-test-small-1-gpu") + + +class TestLoRAPipelineLoading(CustomTestCase): + + def _run_mixed_batch_test( + self, + model_case: LoRAModelCase, + torch_dtype, + ): + base_path = model_case.base + adaptor_paths = [a.name for a in model_case.adaptors] + print( + f"\n========== Testing mixed batch LoRA overlap loading on base '{base_path}' " + f"with dtype={torch_dtype} ==========\n" + ) + ensure_reproducibility() + max_new_tokens = 32 + + prompts = TEST_MULTIPLE_BATCH_PROMPTS[:3] + configs = [ + [None, adaptor_paths[0], adaptor_paths[1]], + [adaptor_paths[0], None, adaptor_paths[1]], + [adaptor_paths[0], adaptor_paths[1], None], + [adaptor_paths[1], adaptor_paths[0], adaptor_paths[1]], + ] + common_args = dict( + torch_dtype=torch_dtype, + model_type="generation", + tp_size=model_case.tp_size, + lora_paths=adaptor_paths, + max_loras_per_batch=model_case.max_loras_per_batch, + max_loaded_loras=model_case.max_loaded_loras, + disable_cuda_graph=True, + disable_radix_cache=True, + mem_fraction_static=0.65, + sleep_on_idle=True, + ) + + results_no_overlap_loading = [] + with SRTRunner( + base_path, enable_lora_overlap_loading=False, **common_args + ) as runner: + for lora_paths in configs: + results_no_overlap_loading.append( + runner.batch_forward( + prompts, max_new_tokens=max_new_tokens, lora_paths=lora_paths + ).output_strs + ) + + results_overlap_loading = [] + with SRTRunner( + base_path, enable_lora_overlap_loading=True, **common_args + ) as runner: + for lora_paths in configs: + results_overlap_loading.append( + runner.batch_forward( + prompts, 
max_new_tokens=max_new_tokens, lora_paths=lora_paths + ).output_strs + ) + + for i, (res_no_overlap_loading, res_overlap_loading) in enumerate( + zip(results_no_overlap_loading, results_overlap_loading) + ): + scores = calculate_rouge_l(res_overlap_loading, res_no_overlap_loading) + for j, score in enumerate(scores): + assert score >= model_case.rouge_l_tolerance, ( + f"Batch {i} prompt {j} mismatch: {score}\n" + f"Overlap loading: {res_overlap_loading[j]}\n" + f"No overlap loading: {res_no_overlap_loading[j]}" + ) + + def test_mixed_batch(self): + for model_case in CI_MULTI_LORA_MODELS: + for dtype in TORCH_DTYPES: + self._run_mixed_batch_test(model_case, dtype) + + +if __name__ == "__main__": + try: + mp.set_start_method("spawn") + except RuntimeError: + pass + + unittest.main(warnings="ignore") diff --git a/test/registered/lora/test_lora_tp.py b/test/registered/lora/test_lora_tp.py index a23b8300a914..40be1ee19a56 100644 --- a/test/registered/lora/test_lora_tp.py +++ b/test/registered/lora/test_lora_tp.py @@ -15,12 +15,13 @@ import multiprocessing as mp import os import unittest -from typing import List +from typing import List, Optional from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci from sglang.test.lora_utils import ( ALL_OTHER_LORA_MODELS, CI_LORA_MODELS, + CI_MULTI_LORA_MODELS, DEFAULT_PROMPTS, TORCH_DTYPES, LoRAModelCase, @@ -38,7 +39,11 @@ class TestLoRATP(CustomTestCase): - def _run_tp_on_model_cases(self, model_cases: List[LoRAModelCase]): + def _run_tp_on_model_cases( + self, + model_cases: List[LoRAModelCase], + enable_lora_overlap_loading: Optional[bool] = None, + ): tp_list = [2] # Define TP sizes to iterate over for model_case in model_cases: # If skip_long_prompt is True, filter out prompts longer than 1000 characters @@ -55,12 +60,18 @@ def _run_tp_on_model_cases(self, model_cases: List[LoRAModelCase]): model_case, torch_dtype, max_new_tokens=32, - test_tag=f"tp={tp_size}", + 
enable_lora_overlap_loading=enable_lora_overlap_loading, + test_tag=f"tp={tp_size}, enable_lora_overlap_loading={enable_lora_overlap_loading}", ) def test_ci_lora_models(self): self._run_tp_on_model_cases(CI_LORA_MODELS) + def test_lora_overlap_loading_ci_lora_models(self): + self._run_tp_on_model_cases( + CI_MULTI_LORA_MODELS, enable_lora_overlap_loading=True + ) + def test_all_lora_models(self): if is_in_ci(): return