diff --git a/python/sglang/srt/entrypoints/EngineBase.py b/python/sglang/srt/entrypoints/EngineBase.py index 9ac68faa7a2..42ecb12aa8d 100644 --- a/python/sglang/srt/entrypoints/EngineBase.py +++ b/python/sglang/srt/entrypoints/EngineBase.py @@ -48,6 +48,14 @@ def update_weights_from_tensor( """Update model weights with in-memory tensor data.""" pass + def load_lora_adapter(self, lora_name: str, lora_path: str): + """Load a new LoRA adapter without re-launching the engine.""" + pass + + def unload_lora_adapter(self, lora_name: str): + """Unload a LoRA adapter without re-launching the engine.""" + pass + @abstractmethod def release_memory_occupation(self): """Release GPU memory occupation temporarily.""" diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 0f75b238050..db788b11b5c 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -48,10 +48,12 @@ GetWeightsByNameReqInput, ImageDataItem, InitWeightsUpdateGroupReqInput, + LoadLoRAAdapterReqInput, ReleaseMemoryOccupationReqInput, ResumeMemoryOccupationReqInput, RpcReqInput, RpcReqOutput, + UnloadLoRAAdapterReqInput, UpdateWeightFromDiskReqInput, UpdateWeightsFromDistributedReqInput, UpdateWeightsFromTensorReqInput, @@ -478,6 +480,29 @@ def get_weights_by_name(self, name: str, truncate_size: int = 100): self.tokenizer_manager.get_weights_by_name(obj, None) ) + def load_lora_adapter(self, lora_name: str, lora_path: str): + """Load a new LoRA adapter without re-launching the engine.""" + + obj = LoadLoRAAdapterReqInput( + lora_name=lora_name, + lora_path=lora_path, + ) + + loop = asyncio.get_event_loop() + return loop.run_until_complete( + self.tokenizer_manager.load_lora_adapter(obj, None) + ) + + def unload_lora_adapter(self, lora_name: str): + """Unload a LoRA adapter without re-launching the engine.""" + + obj = UnloadLoRAAdapterReqInput(lora_name=lora_name) + + loop = asyncio.get_event_loop() + return loop.run_until_complete( + self.tokenizer_manager.unload_lora_adapter(obj, None) + ) + def release_memory_occupation(self, tags: Optional[List[str]] = None): obj = ReleaseMemoryOccupationReqInput(tags=tags) loop = asyncio.get_event_loop() diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index daa8999b76e..812dc63cbd9 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -72,6 +72,7 @@ GenerateReqInput, GetWeightsByNameReqInput, InitWeightsUpdateGroupReqInput, + LoadLoRAAdapterReqInput, OpenSessionReqInput, ParseFunctionCallReq, ProfileReqInput, @@ -80,6 +81,7 @@ SeparateReasoningReqInput, SetInternalStateReq, SlowDownReqInput, + UnloadLoRAAdapterReqInput, UpdateWeightFromDiskReqInput, UpdateWeightsFromDistributedReqInput, UpdateWeightsFromTensorReqInput, @@ -595,6 +597,40 @@ async def slow_down(obj: SlowDownReqInput, request: Request): return _create_error_response(e) +@app.api_route("/load_lora_adapter", methods=["POST"]) +async def load_lora_adapter(obj: LoadLoRAAdapterReqInput, request: Request): + """Load a new LoRA adapter without re-launching the server.""" + result = await _global_state.tokenizer_manager.load_lora_adapter(obj, request) + + if result.success: + return ORJSONResponse( + result, + status_code=HTTPStatus.OK, + ) + else: + return ORJSONResponse( + result, + status_code=HTTPStatus.BAD_REQUEST, + ) + + +@app.api_route("/unload_lora_adapter", methods=["POST"]) +async def unload_lora_adapter(obj: 
UnloadLoRAAdapterReqInput, request: Request):
+    """Unload a LoRA adapter without re-launching the server."""
+    result = await _global_state.tokenizer_manager.unload_lora_adapter(obj, request)
+
+    if result.success:
+        return ORJSONResponse(
+            result,
+            status_code=HTTPStatus.OK,
+        )
+    else:
+        return ORJSONResponse(
+            result,
+            status_code=HTTPStatus.BAD_REQUEST,
+        )
+
+
 @app.api_route("/open_session", methods=["GET", "POST"])
 async def open_session(obj: OpenSessionReqInput, request: Request):
     """Open a session, and return its unique session id."""
diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py
index c1ebe2dcdec..2a3d2acfdff 100644
--- a/python/sglang/srt/lora/lora.py
+++ b/python/sglang/srt/lora/lora.py
@@ -65,7 +65,7 @@ def __init__(
         self.layers: List[LoRALayer] = nn.ModuleList(
             [
                 LoRALayer(config, base_hf_config)
-                for i in range(base_hf_config.num_hidden_layers)
+                for _ in range(base_hf_config.num_hidden_layers)
             ]
         )
 
@@ -88,10 +88,9 @@ def initialize_weights(self):
             else:
                 self.weights[name] = loaded_weight.cpu()
 
-        # stack kv_proj and gate_up_proj
-        for i in range(self.base_hf_config.num_hidden_layers):
-            layer = self.layers[i]
-            weight_names = [name for name, _ in layer.weights.items()]
+        # normalize kv_proj and gate_up_proj
+        for layer in self.layers:
+            weight_names = list(layer.weights.keys())
             self.normalize_qkv_proj(weight_names, layer.weights)
             self.normalize_gate_up_proj(weight_names, layer.weights)
 
diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py
index afba645a9d7..ca0b62c5575 100644
--- a/python/sglang/srt/lora/lora_manager.py
+++ b/python/sglang/srt/lora/lora_manager.py
@@ -35,6 +35,7 @@
     get_normalized_lora_weight_names,
     get_weight_name,
 )
+from sglang.srt.managers.io_struct import LoRAUpdateResult
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import replace_submodule
 
@@ -98,44 +99,96 @@ def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
             ],
         )
 
-    def load_lora_adapters(self, lora_paths: Dict[str, str]):
+    def create_lora_update_result(
+        self, success: bool, error_message: str = ""
+    ) -> LoRAUpdateResult:
+        return LoRAUpdateResult(
+            success=success,
+            error_message=error_message,
+            loaded_adapters={
+                name: config.path for name, config in self.configs.items()
+            },
+        )
+
+    def load_lora_adapters(self, lora_paths: Dict[str, str]) -> LoRAUpdateResult:
         """
         Load LoRA adapters from the specified paths.
-        TODO (lifuhuang): This method should be exposed to the server/engine API to support dynamic LoRA loading.
 
         Args:
             lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
                 If a LoRA adapter is already loaded, it will be skipped with a warning.
         """
+        results = []
         for lora_name, lora_path in lora_paths.items():
-            if lora_name in self.loras:
-                logger.warning(
-                    f"LoRA adapter {lora_name} is already loaded."
-                    "If you want to reload it, please unload it first."
-                )
-                continue
+            result = self.load_lora_adapter(lora_name, lora_path, update_state=False)
+            results.append(result)
+
+        self.update_state_from_configs()
+
+        return self.create_lora_update_result(
+            success=all(result.success for result in results),
+            error_message="\n".join(
+                result.error_message for result in results if not result.success
+            ),
+        )
+
+    def load_lora_adapter(
+        self, lora_name: str, lora_path: str, update_state: bool = True
+    ) -> LoRAUpdateResult:
+        """
+        Load a single LoRA adapter from the specified path.
+
+        Args:
+            lora_name (str): The name of the LoRA adapter.
+            lora_path (str): The file path to the LoRA adapter.
+            update_state (bool): Whether to refresh the internal state after loading the adapter. This is useful for batch loading.
+        """
+        success = True
+        error_message = ""
+
+        if lora_name in self.loras:
+            success = False
+            error_message = f"LoRA adapter {lora_name} is skipped as it is already loaded. If you want to reload it, please unload it first."
+
+        try:
             self.configs[lora_name] = LoRAConfig(lora_path)
+        except Exception as e:
+            success = False
+            error_message = (
+                f"Failed to load LoRA adapter {lora_name} from {lora_path}: {str(e)}"
+            )
 
-        self.update_state_from_configs()
+        if update_state:
+            self.update_state_from_configs()
+
+        return self.create_lora_update_result(
+            success=success,
+            error_message=error_message,
+        )
 
-    def unload_lora_adapters(self, lora_names: Set[str]):
+    def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult:
         """
         Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
         delete the corresponding LoRA modules.
-
-        Args:
-            lora_names (Set[str]): A set of LoRA adapter names to unload.
         """
-        for lora_name in lora_names:
-            if lora_name in self.loras:
-                del self.configs[lora_name]
-            else:
-                logger.warning(f"LoRA adapter {lora_name} is not loaded.")
+
+        success = True
+        error_message = ""
+        if lora_name in self.loras:
+            del self.configs[lora_name]
+        else:
+            error_message = f"LoRA adapter {lora_name} is not loaded."
+            success = False
 
         self.update_state_from_configs()
 
+        return self.create_lora_update_result(
+            success=success,
+            error_message=error_message,
+        )
+
     def prepare_lora_batch(self, forward_batch: ForwardBatch):
         # load active loras into lora memory pool
         cur_uids = set(forward_batch.lora_paths)
@@ -372,8 +425,8 @@ def update_lora_adapters(self):
                 lora_adapter.initialize_weights()
                 self.loras[name] = lora_adapter
 
-        # Clean up unused LoRA adapters
-        for name in self.loras:
+        # Clean up unused LoRA adapters, copying the list to avoid modifying the dict during iteration.
+        for name in list(self.loras):
             if name not in self.configs:
                 logger.info(f"Unloading LoRA adapter {name}")
                 del self.loras[name]
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index cd11967e862..aebd820ab16 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -20,7 +20,7 @@
 import uuid
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union
 
 from sglang.srt.multimodal.mm_utils import has_valid_data
 
@@ -1002,3 +1002,27 @@ class RpcReqInput:
 class RpcReqOutput:
     success: bool
     message: str
+
+
+@dataclass
+class LoadLoRAAdapterReqInput:
+    # The name of the LoRA adapter to be loaded.
+    lora_name: str
+    # The path to load the LoRA adapter from.
+    lora_path: str
+
+
+@dataclass
+class UnloadLoRAAdapterReqInput:
+    # The name of the LoRA adapter to unload.
+    lora_name: str
+
+
+@dataclass
+class LoRAUpdateResult:
+    success: bool
+    error_message: Optional[str] = None
+    loaded_adapters: Dict[str, str] = field(default_factory=dict)
+
+
+LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 692d4673d6b..b8364632f12 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -82,6 +82,8 @@
     HealthCheckOutput,
     InitWeightsUpdateGroupReqInput,
     InitWeightsUpdateGroupReqOutput,
+    LoadLoRAAdapterReqInput,
+    LoadLoRAAdapterReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -99,6 +101,8 @@
     SlowDownReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
+    UnloadLoRAAdapterReqInput,
+    UnloadLoRAAdapterReqOutput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightFromDiskReqOutput,
     UpdateWeightsFromDistributedReqInput,
@@ -519,6 +523,8 @@ def __init__(
                 (SetInternalStateReq, self.set_internal_state),
                 (RpcReqInput, self.handle_rpc_request),
                 (ExpertDistributionReq, self.expert_distribution_handle),
+                (LoadLoRAAdapterReqInput, self.load_lora_adapter),
+                (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
             ]
         )
 
@@ -2241,6 +2247,36 @@ def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
             logger.error(message)
             return UpdateWeightFromDiskReqOutput(success, message, 0)
 
+    def load_lora_adapter(
+        self, recv_req: LoadLoRAAdapterReqInput
+    ) -> LoadLoRAAdapterReqOutput:
+        """Load a new LoRA adapter in place from disk or Hugging Face."""
+
+        result = self.tp_worker.load_lora_adapter(recv_req)
+
+        if result.success:
+            flush_cache_success = self.flush_cache()
+            assert flush_cache_success, "Cache flush failed after loading lora adapter."
+ else: + logger.error(result.error_message) + return result + + def unload_lora_adapter( + self, recv_req: UnloadLoRAAdapterReqInput + ) -> UnloadLoRAAdapterReqOutput: + """Unload the lora adapter.""" + + result = self.tp_worker.unload_lora_adapter(recv_req) + + if result.success: + flush_cache_success = self.flush_cache() + assert ( + flush_cache_success + ), "Cache flush failed after unloading LoRA weights" + else: + logger.error(result.error_message) + return result + def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput): """Initialize the online model parameter update group.""" success, message = self.tp_worker.init_weights_update_group(recv_req) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index ceb2fa52715..c4ec8646bdf 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -83,6 +83,9 @@ HealthCheckOutput, InitWeightsUpdateGroupReqInput, InitWeightsUpdateGroupReqOutput, + LoadLoRAAdapterReqInput, + LoadLoRAAdapterReqOutput, + LoRAUpdateResult, OpenSessionReqInput, OpenSessionReqOutput, ProfileReq, @@ -99,6 +102,8 @@ SlowDownReqOutput, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, + UnloadLoRAAdapterReqInput, + UnloadLoRAAdapterReqOutput, UpdateWeightFromDiskReqInput, UpdateWeightFromDiskReqOutput, UpdateWeightsFromDistributedReqInput, @@ -311,6 +316,9 @@ def __init__( self.expert_distribution_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) + self.update_lora_adapter_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) self._result_dispatcher = TypeBasedDispatcher( [ @@ -377,6 +385,10 @@ def __init__( ExpertDistributionReqOutput, self.expert_distribution_communicator.handle_recv, ), + ( + LoRAUpdateResult, + self.update_lora_adapter_communicator.handle_recv, + ), (HealthCheckOutput, lambda x: None), ] ) @@ -960,6 +972,49 @@ async def update_weights_from_tensor( result = (await self.update_weights_from_tensor_communicator(obj))[0] return result.success, result.message + async def load_lora_adapter( + self, + obj: LoadLoRAAdapterReqInput, + _: Optional[fastapi.Request] = None, + ) -> LoadLoRAAdapterReqOutput: + self.auto_create_handle_loop() + + # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works + # with dp_size > 1. + assert ( + self.server_args.dp_size == 1 + ), "dp_size must be 1 for dynamic lora loading" + logger.info( + "Start load Lora adapter. Lora name=%s, path=%s", + obj.lora_name, + obj.lora_path, + ) + + async with self.model_update_lock.writer_lock: + result = (await self.update_lora_adapter_communicator(obj))[0] + return result + + async def unload_lora_adapter( + self, + obj: UnloadLoRAAdapterReqInput, + _: Optional[fastapi.Request] = None, + ) -> UnloadLoRAAdapterReqOutput: + self.auto_create_handle_loop() + + # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works + # with dp_size > 1. + assert ( + self.server_args.dp_size == 1 + ), "dp_size must be 1 for dynamic lora loading" + logger.info( + "Start unload Lora adapter. 
Lora name=%s", + obj.lora_name, + ) + + async with self.model_update_lock.writer_lock: + result = (await self.update_lora_adapter_communicator(obj))[0] + return result + async def get_weights_by_name( self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None ): diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 73a12e2850d..afd9541aa4e 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -30,6 +30,8 @@ from sglang.srt.managers.io_struct import ( GetWeightsByNameReqInput, InitWeightsUpdateGroupReqInput, + LoadLoRAAdapterReqInput, + UnloadLoRAAdapterReqInput, UpdateWeightFromDiskReqInput, UpdateWeightsFromDistributedReqInput, UpdateWeightsFromTensorReqInput, @@ -275,3 +277,13 @@ def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput): recv_req.name, recv_req.truncate_size ) return parameter + + def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput): + result = self.model_runner.load_lora_adapter( + recv_req.lora_name, recv_req.lora_path + ) + return result + + def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput): + result = self.model_runner.unload_lora_adapter(recv_req.lora_name) + return result diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index 45f220db62a..3bd69997690 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -26,6 +26,8 @@ from sglang.srt.managers.io_struct import ( GetWeightsByNameReqInput, InitWeightsUpdateGroupReqInput, + LoadLoRAAdapterReqInput, + UnloadLoRAAdapterReqInput, UpdateWeightFromDiskReqInput, UpdateWeightsFromDistributedReqInput, UpdateWeightsFromTensorReqInput, @@ -268,6 +270,12 @@ def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput): def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput): return self.worker.get_weights_by_name(recv_req) + def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput): + return self.worker.load_lora_adapter(recv_req) + + def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput): + return self.worker.unload_lora_adapter(recv_req) + def __delete__(self): self.input_queue.put((None, None)) self.copy_queue.put((None, None, None)) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 9ac26810e96..277cab8dfc0 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -26,7 +26,6 @@ import torch import torch.distributed as dist -from sglang.srt import debug_utils from sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.model_config import AttentionArch, ModelConfig @@ -819,8 +818,47 @@ def init_lora_manager(self): tp_size=self.tp_size, tp_rank=self.tp_rank, ) - self.lora_manager.load_lora_adapters(self.server_args.lora_paths) - logger.info("LoRA manager ready.") + result = self.lora_manager.load_lora_adapters(self.server_args.lora_paths) + if result.success: + logger.info( + f"LoRA manager ready. 
Loaded LoRA adapters: {', '.join(result.loaded_adapters)}" + ) + else: + raise RuntimeError(f"Failed to load LoRA adapters: {result.error_message}") + + def load_lora_adapter(self, lora_name: str, lora_path: str): + """Load a new lora adapter from disk or huggingface.""" + + logger.info( + f"LoRA adapter loading starts: name={lora_name}, path={lora_path}. " + f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" + ) + + result = self.lora_manager.load_lora_adapter(lora_name, lora_path) + + logger.info( + f"LoRA adapter loading completes: name={lora_name}, path={lora_path}. " + f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" + ) + + return result + + def unload_lora_adapter(self, lora_name: str): + """Unload a lora adapter that was previously loaded during initialization or dynamic loading.""" + + logger.info( + f"LoRA adapter unloading starts: name={lora_name}. " + f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" + ) + + result = self.lora_manager.unload_lora_adapter(lora_name) + + logger.info( + f"LoRA adapter unloading completes: name={lora_name}. " + f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" + ) + + return result def profile_max_num_token(self, total_gpu_memory: int): available_gpu_memory = get_available_gpu_memory( diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index c2d20b99458..481bf682d1c 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -503,6 +503,7 @@ def __init__( disable_overlap_schedule: bool = False, disable_custom_all_reduce: bool = False, torchao_config: Optional[str] = None, + cuda_graph_max_bs: int = 4, sleep_on_idle=False, ): self.model_type = model_type @@ -539,7 +540,7 @@ def __init__( tokenizer_path=tokenizer_path, enable_ep_moe=enable_ep_moe, disable_overlap_schedule=disable_overlap_schedule, - cuda_graph_max_bs=4, + cuda_graph_max_bs=cuda_graph_max_bs, disable_custom_all_reduce=disable_custom_all_reduce, sleep_on_idle=sleep_on_idle, **spec_kwargs, @@ -552,6 +553,12 @@ def __init__( else: self.tokenizer = None + def load_lora_adapter(self, lora_name: str, lora_path: str): + return self.engine.load_lora_adapter(lora_name, lora_path) + + def unload_lora_adapter(self, lora_name: str): + return self.engine.unload_lora_adapter(lora_name) + def forward( self, prompts: Union[ diff --git a/test/srt/models/lora/test_lora_update.py b/test/srt/models/lora/test_lora_update.py new file mode 100644 index 00000000000..587789cf1b3 --- /dev/null +++ b/test/srt/models/lora/test_lora_update.py @@ -0,0 +1,616 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import multiprocessing as mp +import unittest +from dataclasses import dataclass +from enum import Enum +from typing import List, Optional, Union + +import requests +import torch + +from sglang.srt.utils import kill_process_tree +from sglang.test.runners import SRTRunner +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + +PROMPTS = [ + "SGL is a", + "AI is a field of computer science focused on", + "Computer science is the study of", + "Write a short story.", + "What are the main components of a computer?", +] + + +class OperationType(Enum): + LOAD = "load" + UNLOAD = "unload" + NOOP = "noop" + FORWARD = "forward" + + +@dataclass +class Operation: + type: OperationType + data: Optional[str] + + +@dataclass +class TestCase: + base: str + max_loras_per_batch: int + all_adapters: List[str] + initial_adapters: List[str] + op_sequence: List[Operation] + max_new_tokens: int = 32 + + +def create_batch_data(adapters: Union[str, list]) -> dict: + if not isinstance(adapters, list): + adapters = [adapters] + return [(prompt, adapter) for prompt in PROMPTS for adapter in adapters] + + +TEST_CASES = [ + # basic test, no eviction + TestCase( + base="meta-llama/Llama-3.1-8B-Instruct", + max_loras_per_batch=3, + all_adapters=[ + "philschmid/code-llama-3-1-8b-text-to-sql-lora", + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", + "pbevan11/llama-3.1-8b-ocr-correction", + ], + initial_adapters=["philschmid/code-llama-3-1-8b-text-to-sql-lora"], + op_sequence=[ + Operation( + type=OperationType.LOAD, + data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", + ), + Operation( + type=OperationType.LOAD, + data="pbevan11/llama-3.1-8b-ocr-correction", + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data( + [ + "philschmid/code-llama-3-1-8b-text-to-sql-lora", + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", + "pbevan11/llama-3.1-8b-ocr-correction", + ] + ), + ), + Operation( + type=OperationType.UNLOAD, + data="philschmid/code-llama-3-1-8b-text-to-sql-lora", + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data( + [ + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", + "pbevan11/llama-3.1-8b-ocr-correction", + ] + ), + ), + Operation( + type=OperationType.UNLOAD, + data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"), + ), + Operation( + type=OperationType.LOAD, + data="philschmid/code-llama-3-1-8b-text-to-sql-lora", + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data( + [ + "philschmid/code-llama-3-1-8b-text-to-sql-lora", + "pbevan11/llama-3.1-8b-ocr-correction", + ] + ), + ), + ], + ), + # Eviction + TestCase( + base="meta-llama/Llama-3.1-8B-Instruct", + max_loras_per_batch=1, + all_adapters=[ + "philschmid/code-llama-3-1-8b-text-to-sql-lora", + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", + "pbevan11/llama-3.1-8b-ocr-correction", + ], + initial_adapters=["philschmid/code-llama-3-1-8b-text-to-sql-lora"], + op_sequence=[ + Operation( + type=OperationType.FORWARD, + data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"), + ), + Operation( + type=OperationType.LOAD, + data="pbevan11/llama-3.1-8b-ocr-correction", + ), + Operation( + type=OperationType.UNLOAD, + data="philschmid/code-llama-3-1-8b-text-to-sql-lora", + ), + 
Operation( + type=OperationType.FORWARD, + data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"), + ), + Operation( + type=OperationType.LOAD, + data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", + ), + Operation( + type=OperationType.LOAD, + data="philschmid/code-llama-3-1-8b-text-to-sql-lora", + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data( + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16" + ), + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"), + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"), + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data( + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16" + ), + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"), + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"), + ), + ], + ), +] + + +class LoRAUpdateTestSessionMode(Enum): + ENGINE = "engine" + SERVER = "server" + + +class LoRAUpdateTestSessionBase: + """ + Base context manager for testing LoRA adapters. + """ + + def __init__( + self, + *, + testcase: Optional[TestCase], + model_path: str, + lora_paths: list[str], + max_loras_per_batch: int = 1, + lora_backend: str = "triton", + disable_cuda_graph: bool = False, + cuda_graph_max_bs: int = 4, + ): + self.testcase = testcase + self.model_path = model_path + self.lora_paths = lora_paths + self.max_loras_per_batch = max_loras_per_batch + self.lora_backend = lora_backend + self.disable_cuda_graph = disable_cuda_graph + self.cuda_graph_max_bs = cuda_graph_max_bs + + self.expected_adapters = set(lora_paths) + self.handle = None # Will be set in __enter__ + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + # Don't suppress exceptions by default + return False + + def load_lora_adapter(self, lora_name: str, lora_path: Optional[str] = None): + """ + Load a LoRA adapter by name and path. + """ + raise NotImplementedError("Subclasses must implement load_lora_adapter") + + def unload_lora_adapter(self, lora_name: str): + """ + Unload a LoRA adapter by name. + """ + raise NotImplementedError("Subclasses must implement unload_lora_adapter") + + def forward( + self, + prompts: List[str], + lora_paths: List[str], + max_new_tokens: int = 32, + ): + """ + Perform a batch forward pass with the current set of loaded LoRA adapters. + """ + raise NotImplementedError("Subclasses must implement forward") + + +class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase): + """ + Context manager for testing LoRA adapters with in-process engine. 
+ """ + + def __enter__(self): + # in-process runner + self.handle = SRTRunner( + model_path=self.model_path, + model_type="generation", + lora_paths=self.lora_paths, + lora_backend=self.lora_backend, + torch_dtype=torch.float16, + max_loras_per_batch=self.max_loras_per_batch, + disable_cuda_graph=self.disable_cuda_graph, + cuda_graph_max_bs=self.cuda_graph_max_bs, + disable_radix_cache=True, + ) + self.handle.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.handle is not None: + # delegate cleanup to SRTRunner + return self.handle.__exit__(exc_type, exc_val, exc_tb) + # don't suppress exceptions + return False + + def load_lora_adapter(self, lora_name: str, lora_path: Optional[str] = None): + """ + Load a LoRA adapter by name and path. + """ + if lora_path is None: + lora_path = lora_name + + self.expected_adapters.add(lora_name) + + response = self.handle.load_lora_adapter( + lora_name=lora_name, + lora_path=lora_path, + ) + self.testcase.assertTrue(response.success) + loaded_adapters = set(response.loaded_adapters) + + print(f"loaded_adapters: {loaded_adapters}") + self.testcase.assertEqual(loaded_adapters, self.expected_adapters) + + def unload_lora_adapter(self, lora_name: str): + """ + Unload a LoRA adapter by name. + """ + self.expected_adapters.remove(lora_name) + + response = self.handle.unload_lora_adapter( + lora_name=lora_name, + ) + self.testcase.assertTrue(response.success) + loaded_adapters = set(response.loaded_adapters) + + print(f"loaded_adapters: {loaded_adapters}") + self.testcase.assertEqual(loaded_adapters, self.expected_adapters) + + def forward( + self, + prompts: List[str], + lora_paths: List[str], + max_new_tokens: int = 32, + ): + """ + Perform a batch forward pass with the current set of loaded LoRA adapters. + """ + response = self.handle.batch_forward( + prompts=prompts, + lora_paths=lora_paths, + max_new_tokens=max_new_tokens, + ) + output_strs = response.output_strs + + print(f"output_strs: {output_strs}") + return output_strs + + +class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase): + """ + Context manager for testing LoRA adapters with standalone server. + """ + + def __enter__(self): + other_args = [ + "--cuda-graph-max-bs", + str(self.cuda_graph_max_bs), + "--lora-paths", + *self.lora_paths, + "--max-loras-per-batch", + str(self.max_loras_per_batch), + "--lora-backend", + self.lora_backend, + "--disable-radix-cache", + "--random-seed", + "42", + "--max-running-request", + "1", + ] + if self.disable_cuda_graph: + other_args.append("--disable-cuda-graph") + + # launch external server + self.handle = popen_launch_server( + self.model_path, + DEFAULT_URL_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.handle is not None: + kill_process_tree(self.handle.pid) + # don't suppress exceptions + return False + + def load_lora_adapter(self, lora_name: str, lora_path: Optional[str] = None): + """ + Load a LoRA adapter by name and path. 
+ """ + if lora_path is None: + lora_path = lora_name + + self.expected_adapters.add(lora_name) + + response = requests.post( + DEFAULT_URL_FOR_TEST + "/load_lora_adapter", + json={"lora_name": lora_name, "lora_path": lora_path}, + ) + self.testcase.assertTrue(response.ok) + loaded_adapters = set(response.json()["loaded_adapters"]) + + print(f"loaded_adapters: {loaded_adapters}") + self.testcase.assertEqual(loaded_adapters, self.expected_adapters) + + def unload_lora_adapter(self, lora_name: str): + """ + Unload a LoRA adapter by name. + """ + self.expected_adapters.remove(lora_name) + + response = requests.post( + DEFAULT_URL_FOR_TEST + "/unload_lora_adapter", + json={"lora_name": lora_name}, + ) + self.testcase.assertTrue(response.ok) + loaded_adapters = set(response.json()["loaded_adapters"]) + + print(f"loaded_adapters: {loaded_adapters}") + self.testcase.assertEqual(loaded_adapters, self.expected_adapters) + + def forward( + self, + prompts: List[str], + lora_paths: List[str], + max_new_tokens: int = 32, + ): + """ + Perform a batch forward pass with the current set of loaded LoRA adapters. + """ + response = requests.post( + DEFAULT_URL_FOR_TEST + "/generate", + json={ + "text": prompts, + "lora_path": lora_paths, + "sampling_params": { + "temperature": 0, + "top_k": 1, + "max_new_tokens": max_new_tokens, + }, + }, + ) + self.testcase.assertTrue(response.ok) + output_strs = [r["text"] for r in response.json()] + + print(f"output_strs: {output_strs}") + return output_strs + + +# Factory function to create the appropriate LoRA test session based on mode +def LoRAUpdateTestSession( + *, + testcase: Optional[TestCase], + mode: LoRAUpdateTestSessionMode, + model_path: str, + lora_paths: list[str], + max_loras_per_batch: int = 1, + lora_backend: str = "triton", + disable_cuda_graph: bool = False, + cuda_graph_max_bs: int = 4, +): + common_kwargs = { + "testcase": testcase, + "model_path": model_path, + "lora_paths": lora_paths, + "max_loras_per_batch": max_loras_per_batch, + "lora_backend": lora_backend, + "disable_cuda_graph": disable_cuda_graph, + "cuda_graph_max_bs": cuda_graph_max_bs, + } + + if mode == LoRAUpdateTestSessionMode.ENGINE: + return LoRAUpdateEngineTestSession(**common_kwargs) + elif mode == LoRAUpdateTestSessionMode.SERVER: + return LoRAUpdateServerTestSession(**common_kwargs) + else: + raise ValueError(f"Unrecognized mode: {mode!r}") + + +class TestLoRADynamicUpdate(CustomTestCase): + """ + This test case verifies that the SRT runner can dynamically load and unload LoRA adapters + during a sequence of operations, and that the outputs of forward passes with dynamically loaded + adapters match the outputs of forward passes with statically loaded adapters. + """ + + def _repeat_each(lst, n): + return [x for x in lst for _ in range(n)] + + def _run_operation_sequence( + self, + mode: LoRAUpdateTestSessionMode, + base: str, + initial_adapters: List[str], + max_loras_per_batch: int, + op_sequence: List[Operation], + max_new_tokens: int = 32, + ) -> List[tuple]: + """ + Runs a sequence of operations on the SRT runner, including loading and unloading LoRA adapters, + and performing forward passes with the current set of loaded adapters. 
+ """ + + forward_outputs = [] + with LoRAUpdateTestSession( + testcase=self, + mode=mode, + model_path=base, + lora_paths=initial_adapters, + max_loras_per_batch=max_loras_per_batch, + ) as session: + for op in op_sequence: + op_type = op.type + data = op.data + print("-" * 100) + print( + f"Running operation: {op_type} --- data: {data} --- mode: {mode} ---" + ) + if op_type == OperationType.LOAD: + result = session.load_lora_adapter( + lora_name=data, + lora_path=data, + ) + elif op_type == OperationType.UNLOAD: + result = session.unload_lora_adapter( + lora_name=data, + ) + elif op_type == OperationType.FORWARD: + prompts, adapters = zip(*data) + result = session.forward( + prompts=list(prompts), + lora_paths=list(adapters), + max_new_tokens=max_new_tokens, + ) + forward_outputs.append(result) + + return forward_outputs + + def test_dynamic_adapter_updates(self): + for case_idx, test_case in enumerate(TEST_CASES, start=1): + for mode in [ + LoRAUpdateTestSessionMode.SERVER, + LoRAUpdateTestSessionMode.ENGINE, + ]: + print("=" * 100) + print(f"Starting test case {case_idx} in {mode.value} mode.") + print("=" * 100) + + print( + f"--- Running dynamic update pass with {len(test_case.op_sequence)} operations ---" + ) + # Test dynamic loading of adapters + # TODO (lifuhuang): currently at least one LoRA path is required during initialization to enable lora, + # we should fix this in the future https://github.com/sgl-project/sglang/issues/7463. + dynamic_output = self._run_operation_sequence( + mode=mode, + initial_adapters=test_case.initial_adapters, + base=test_case.base, + max_loras_per_batch=test_case.max_loras_per_batch, + op_sequence=test_case.op_sequence, + max_new_tokens=test_case.max_new_tokens, + ) + + # static loading + forward_ops = [ + x for x in test_case.op_sequence if x.type == OperationType.FORWARD + ] + + print("=" * 100) + print( + f"\n--- Running static pass with {len(forward_ops)} operations ---" + ) + static_output = self._run_operation_sequence( + mode=mode, + initial_adapters=test_case.all_adapters, + base=test_case.base, + max_loras_per_batch=test_case.max_loras_per_batch, + op_sequence=forward_ops, + max_new_tokens=test_case.max_new_tokens, + ) + + print(f"Dynamic output: {dynamic_output}") + print(f"Static output: {static_output}") + print("=" * 100) + self.assertEqual( + len(dynamic_output), + len(static_output), + f"Dynamic output length {len(dynamic_output)} does not match static output length {len(static_output)}", + ) + for i, (dynamic, static) in enumerate( + zip(dynamic_output, static_output), start=1 + ): + self.assertEqual( + len(dynamic), + len(static), + f"Output length mismatch at batch {i}:\n- Dynamic={len(dynamic)}\n- Static={len(static)}", + ) + for j, (d_out, s_out) in enumerate(zip(dynamic, static), start=1): + d_out = d_out.strip() + s_out = s_out.strip() + self.assertEqual( + d_out, + s_out, + f"Output mismatch at batch {i}, prompt {j}:\n- Dynamic: '{d_out}'\n- Static: '{s_out}'", + ) + + +if __name__ == "__main__": + try: + mp.set_start_method("spawn") + except RuntimeError: + pass + + unittest.main(warnings="ignore") diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 8c023f8d2c4..00adfa318ce 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -17,6 +17,7 @@ class TestFile: TestFile("models/lora/test_lora_backend.py", 99), TestFile("models/lora/test_multi_lora_backend.py", 60), TestFile("models/lora/test_lora_cuda_graph.py", 250), + TestFile("models/lora/test_lora_update.py", 400), 
TestFile("models/test_embedding_models.py", 73), # TestFile("models/test_clip_models.py", 52), TestFile("models/test_encoder_embedding_models.py", 100),