39 changes: 39 additions & 0 deletions python/sglang/srt/entrypoints/http_server.py
@@ -71,6 +71,7 @@
GenerateReqInput,
GetWeightsByNameReqInput,
InitWeightsUpdateGroupReqInput,
LoadLoRAAdapterReqInput,
OpenSessionReqInput,
ParseFunctionCallReq,
ProfileReqInput,
@@ -79,6 +80,7 @@
SeparateReasoningReqInput,
SetInternalStateReq,
SlowDownReqInput,
UnloadLoRAAdapterReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
@@ -583,6 +585,43 @@ async def slow_down(obj: SlowDownReqInput, request: Request):
return _create_error_response(e)


@app.api_route("/load_lora_adapter", methods=["POST"])
async def load_lora_adapter(obj: LoadLoRAAdapterReqInput, request: Request):
"""Load a new LoRA adapter without re-launching the server."""
success, message = await _global_state.tokenizer_manager.load_lora_adapter(
obj, request
)
content = {"success": success, "message": message}
if success:
return ORJSONResponse(
content,
status_code=HTTPStatus.OK,
)
else:
return ORJSONResponse(
content,
status_code=HTTPStatus.BAD_REQUEST,
)

@app.api_route("/unload_lora_adapter", methods=["POST"])
async def unload_lora_adapter(obj: UnloadLoRAAdapterReqInput, request: Request):
"""Load a new LoRA adapter without re-launching the server."""
success, message = await _global_state.tokenizer_manager.unload_lora_adapter(
obj, request
)
content = {"success": success, "message": message}
if success:
return ORJSONResponse(
content,
status_code=HTTPStatus.OK,
)
else:
return ORJSONResponse(
content,
status_code=HTTPStatus.BAD_REQUEST,
)


@app.api_route("/open_session", methods=["GET", "POST"])
async def open_session(obj: OpenSessionReqInput, request: Request):
"""Open a session, and return its unique session id."""
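For reviewers who want to try the new routes, here is a minimal client-side sketch (not part of this diff) that exercises both endpoints with the `requests` library; the base URL, adapter name, and adapter path are placeholders for an already-running server launched with LoRA enabled.

import requests

# Placeholder values: adjust for your deployment.
BASE_URL = "http://127.0.0.1:30000"

# Load an adapter without restarting the server.
resp = requests.post(
    f"{BASE_URL}/load_lora_adapter",
    json={"lora_name": "my_adapter", "lora_path": "/path/to/lora_adapter"},
)
print(resp.status_code, resp.json())  # expects {"success": true, "message": ...}

# Unload it again by name.
resp = requests.post(
    f"{BASE_URL}/unload_lora_adapter",
    json={"lora_name": "my_adapter"},
)
print(resp.status_code, resp.json())

A 200 status indicates success; a 400 carries the failure message produced by the manager.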
9 changes: 4 additions & 5 deletions python/sglang/srt/lora/lora.py
@@ -65,7 +65,7 @@ def __init__(
self.layers: List[LoRALayer] = nn.ModuleList(
[
LoRALayer(config, base_hf_config)
for i in range(base_hf_config.num_hidden_layers)
for _ in range(base_hf_config.num_hidden_layers)
]
)

@@ -88,10 +88,9 @@ def initialize_weights(self):
else:
self.weights[name] = loaded_weight.cpu()

# stack kv_proj and gate_up_proj
for i in range(self.base_hf_config.num_hidden_layers):
layer = self.layers[i]
weight_names = [name for name, _ in layer.weights.items()]
# normalize kv_proj and gate_up_proj
for layer in self.layers:
weight_names = list(layer.weights.keys())
self.normalize_qkv_proj(weight_names, layer.weights)
self.normalize_gate_up_proj(weight_names, layer.weights)

34 changes: 25 additions & 9 deletions python/sglang/srt/lora/lora_manager.py
@@ -16,7 +16,7 @@
# and "Punica: Multi-Tenant LoRA Serving"

import logging
from typing import Dict, Set, Tuple
from typing import Dict, Iterable, Set, Tuple

import torch

@@ -98,35 +98,51 @@ def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
],
)

def load_lora_adapters(self, lora_paths: Dict[str, str]):
def load_lora_adapters(
self, lora_paths: Dict[str, str]
) -> Dict[str, Tuple[bool, str]]:
"""
Load LoRA adapters from the specified paths.
TODO (lifuhuang): This method should be exposed to the server/engine API to support dynamic LoRA loading.

Args:
lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
If a LoRA adapter is already loaded, it will be skipped and reported as a failure in the returned results.

Returns:
Dict[str, Tuple[bool, str]]: A dictionary mapping LoRA adapter names to a tuple of
(success, message). If loading is successful, success is True and message is an empty string.
If loading fails, success is False and message contains the error message.
"""

results = {}
for lora_name, lora_path in lora_paths.items():
if lora_name in self.loras:
logger.warning(
f"LoRA adapter {lora_name} is already loaded."
error_msg = (
f"LoRA adapter {lora_name} is skipped as it is already loaded. "
"If you want to reload it, please unload it first."
)
results[lora_name] = (False, error_msg)
continue

self.configs[lora_name] = LoRAConfig(lora_path)
try:
self.configs[lora_name] = LoRAConfig(lora_path)
results[lora_name] = (True, "")
except Exception as e:
error_msg = f"Failed to load LoRA adapter {lora_name} from {lora_path}: {str(e)}"
results[lora_name] = (False, error_msg)
logger.error(error_msg)
continue

self.update_state_from_configs()
return results

def unload_lora_adapters(self, lora_names: Set[str]):
def unload_lora_adapters(self, lora_names: Iterable[str]):
"""
Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
delete the corresponding LoRA modules.

Args:
lora_names (Set[str]): A set of LoRA adapter names to unload.
lora_names (Iterable[str]): An iterable of LoRA adapter names to unload.
"""
for lora_name in lora_names:
if lora_name in self.loras:
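As an illustration of the return contract documented above, the following sketch (not part of this diff) shows how a caller might consume the per-adapter results; `manager` and the checkpoint paths are hypothetical stand-ins for an initialized LoRAManager and real adapters.

# Hypothetical caller code; `manager` is an initialized LoRAManager.
lora_paths = {
    "sql_adapter": "/checkpoints/sql_lora",
    "chat_adapter": "/checkpoints/chat_lora",
}

results = manager.load_lora_adapters(lora_paths)
for name, (success, message) in results.items():
    if success:
        print(f"loaded {name}")
    else:
        print(f"failed to load {name}: {message}")

# unload_lora_adapters now accepts any iterable of adapter names.
manager.unload_lora_adapters(["sql_adapter"])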
26 changes: 26 additions & 0 deletions python/sglang/srt/managers/io_struct.py
@@ -994,3 +994,29 @@ class RpcReqInput:
class RpcReqOutput:
success: bool
message: str


@dataclass
class LoadLoRAAdapterReqInput:
# The name of the LoRA adapter to be loaded.
lora_name: str
# The path to load the LoRA adapter from.
lora_path: str


@dataclass
class LoadLoRAAdapterReqOutput:
success: bool
message: str


@dataclass
class UnloadLoRAAdapterReqInput:
# The name of the LoRA adapter to unload.
lora_name: str


@dataclass
class UnloadLoRAAdapterReqOutput:
success: bool
message: str
30 changes: 30 additions & 0 deletions python/sglang/srt/managers/scheduler.py
@@ -82,6 +82,10 @@
HealthCheckOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
LoadLoRAAdapterReqInput,
LoadLoRAAdapterReqOutput,
UnloadLoRAAdapterReqInput,
UnloadLoRAAdapterReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -501,6 +505,8 @@ def __init__(
(SetInternalStateReq, self.set_internal_state),
(RpcReqInput, self.handle_rpc_request),
(ExpertDistributionReq, self.expert_distribution_handle),
(LoadLoRAAdapterReqInput, self.load_lora_adapter),
(UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
]
)

@@ -2205,6 +2211,30 @@ def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
logger.error(message)
return UpdateWeightFromDiskReqOutput(success, message, 0)

def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
"""In-place loading a new lora adapater from disk or huggingface."""

success, message = self.tp_worker.load_lora_adapter(recv_req)

if success:
flush_cache_success = self.flush_cache()
assert flush_cache_success, "Cache flush failed after loading LoRA adapter"
else:
logger.error(message)
return LoadLoRAAdapterReqOutput(success, message)

def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
"""Unload the lora adapter."""

success, message = self.tp_worker.unload_lora_adapter(recv_req)

if success:
flush_cache_success = self.flush_cache()
assert flush_cache_success, "Cache flush failed after unloading LoRA adapter"
else:
logger.error(message)
return UnloadLoRAAdapterReqOutput(success, message)

def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
"""Initialize the online model parameter update group."""
success, message = self.tp_worker.init_weights_update_group(recv_req)
55 changes: 55 additions & 0 deletions python/sglang/srt/managers/tokenizer_manager.py
@@ -83,6 +83,8 @@
HealthCheckOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
LoadLoRAAdapterReqInput,
LoadLoRAAdapterReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -99,6 +101,8 @@
SlowDownReqOutput,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
UnloadLoRAAdapterReqInput,
UnloadLoRAAdapterReqOutput,
UpdateWeightFromDiskReqInput,
UpdateWeightFromDiskReqOutput,
UpdateWeightsFromDistributedReqInput,
@@ -311,6 +315,12 @@ def __init__(
self.expert_distribution_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.load_lora_adapter_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.unload_lora_adapter_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)

self._result_dispatcher = TypeBasedDispatcher(
[
@@ -377,6 +387,14 @@ def __init__(
ExpertDistributionReqOutput,
self.expert_distribution_communicator.handle_recv,
),
(
LoadLoRAAdapterReqOutput,
self.load_lora_adapter_communicator.handle_recv,
),
(
UnloadLoRAAdapterReqOutput,
self.unload_lora_adapter_communicator.handle_recv,
),
(HealthCheckOutput, lambda x: None),
]
)
@@ -960,6 +978,43 @@ async def update_weights_from_tensor(
result = (await self.update_weights_from_tensor_communicator(obj))[0]
return result.success, result.message

async def load_lora_adapter(
self,
obj: LoadLoRAAdapterReqInput,
_: Optional[fastapi.Request] = None,
) -> Tuple[bool, str]:
self.auto_create_handle_loop()
assert (
self.server_args.dp_size == 1
), "dp_size must be 1 for dynamic lora loading"
logger.info(
"Start load Lora adapter. Lora name=%s, path=%s",
obj.lora_name,
obj.lora_path,
)

async with self.model_update_lock.writer_lock:
result = (await self.load_lora_adapter_communicator(obj))[0]
return result.success, result.message

async def unload_lora_adapter(
self,
obj: UnloadLoRAAdapterReqInput,
_: Optional[fastapi.Request] = None,
) -> Tuple[bool, str]:
self.auto_create_handle_loop()
assert (
self.server_args.dp_size == 1
), "dp_size must be 1 for dynamic lora loading"
logger.info(
"Start unload Lora adapter. Lora name=%s",
obj.lora_name,
)

async with self.model_update_lock.writer_lock:
result = (await self.unload_lora_adapter_communicator(obj))[0]
return result.success, result.message

async def get_weights_by_name(
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
):
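For context, here is a minimal sketch (not part of this diff) of driving these coroutines directly with the new io_struct dataclasses; `tokenizer_manager` is assumed to be a running TokenizerManager from a server started with dp_size == 1, and the adapter name and path are placeholders.

from sglang.srt.managers.io_struct import (
    LoadLoRAAdapterReqInput,
    UnloadLoRAAdapterReqInput,
)


async def toggle_adapter(tokenizer_manager):
    # Load, then unload, a placeholder adapter; each call returns (success, message).
    success, message = await tokenizer_manager.load_lora_adapter(
        LoadLoRAAdapterReqInput(lora_name="my_adapter", lora_path="/path/to/lora_adapter")
    )
    print("load:", success, message)

    success, message = await tokenizer_manager.unload_lora_adapter(
        UnloadLoRAAdapterReqInput(lora_name="my_adapter")
    )
    print("unload:", success, message)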
14 changes: 14 additions & 0 deletions python/sglang/srt/managers/tp_worker.py
@@ -30,6 +30,8 @@
from sglang.srt.managers.io_struct import (
GetWeightsByNameReqInput,
InitWeightsUpdateGroupReqInput,
LoadLoRAAdapterReqInput,
UnloadLoRAAdapterReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
@@ -275,3 +277,15 @@ def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
recv_req.name, recv_req.truncate_size
)
return parameter

def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
result = self.model_runner.load_lora_adapter(
recv_req.lora_name, recv_req.lora_path
)
return result

def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
result = self.model_runner.unload_lora_adapter(recv_req.lora_name)
return result
8 changes: 8 additions & 0 deletions python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -26,6 +26,8 @@
from sglang.srt.managers.io_struct import (
GetWeightsByNameReqInput,
InitWeightsUpdateGroupReqInput,
LoadLoRAAdapterReqInput,
UnloadLoRAAdapterReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
@@ -268,6 +270,12 @@ def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
return self.worker.get_weights_by_name(recv_req)

def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
return self.worker.load_lora_adapter(recv_req)

def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
return self.worker.unload_lora_adapter(recv_req)

def __delete__(self):
self.input_queue.put((None, None))
self.copy_queue.put((None, None, None))