8 changes: 8 additions & 0 deletions python/sglang/srt/entrypoints/EngineBase.py
@@ -48,6 +48,14 @@ def update_weights_from_tensor(
"""Update model weights with in-memory tensor data."""
pass

def load_lora_adapter(self, lora_name: str, lora_path: str):
"""Load a new LoRA adapter without re-launching the engine."""
pass

def unload_lora_adapter(self, lora_name: str):
"""Unload a LoRA adapter without re-launching the engine."""
pass

@abstractmethod
def release_memory_occupation(self):
"""Release GPU memory occupation temporarily."""
25 changes: 25 additions & 0 deletions python/sglang/srt/entrypoints/engine.py
@@ -48,10 +48,12 @@
GetWeightsByNameReqInput,
ImageDataItem,
InitWeightsUpdateGroupReqInput,
LoadLoRAAdapterReqInput,
ReleaseMemoryOccupationReqInput,
ResumeMemoryOccupationReqInput,
RpcReqInput,
RpcReqOutput,
UnloadLoRAAdapterReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
@@ -476,6 +478,29 @@ def get_weights_by_name(self, name: str, truncate_size: int = 100):
self.tokenizer_manager.get_weights_by_name(obj, None)
)

def load_lora_adapter(self, lora_name: str, lora_path: str):
"""Load a new LoRA adapter without re-launching the engine."""

obj = LoadLoRAAdapterReqInput(
lora_name=lora_name,
lora_path=lora_path,
)

loop = asyncio.get_event_loop()
return loop.run_until_complete(
self.tokenizer_manager.load_lora_adapter(obj, None)
)

def unload_lora_adapter(self, lora_name: str):
"""Unload a LoRA adapter without re-launching the engine."""

obj = UnloadLoRAAdapterReqInput(lora_name=lora_name)

loop = asyncio.get_event_loop()
return loop.run_until_complete(
self.tokenizer_manager.unload_lora_adapter(obj, None)
)

def release_memory_occupation(self, tags: Optional[List[str]] = None):
obj = ReleaseMemoryOccupationReqInput(tags=tags)
loop = asyncio.get_event_loop()
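
For illustration, here is a minimal sketch of how the two new Engine methods above might be used from the offline API. The model name, adapter name, and adapter path are placeholders, any extra launch flags needed to enable LoRA serving (e.g. lora_paths / max_loras_per_batch) are assumed rather than shown, and the return value is taken to be the LoRAUpdateResult defined in io_struct.py below:

import sglang as sgl

# Placeholder model; LoRA-related launch flags are assumed and omitted here.
llm = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")

# Dynamically load an adapter without re-launching the engine.
result = llm.load_lora_adapter(
    lora_name="my_adapter",           # hypothetical adapter name
    lora_path="/path/to/my_adapter",  # hypothetical local path
)
print(result.success, result.loaded_adapters)

# ... run generation requests that reference "my_adapter" ...

# Unload the adapter once it is no longer needed.
result = llm.unload_lora_adapter(lora_name="my_adapter")
print(result.success, result.error_message)

llm.shutdown()
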
36 changes: 36 additions & 0 deletions python/sglang/srt/entrypoints/http_server.py
@@ -72,6 +72,7 @@
GenerateReqInput,
GetWeightsByNameReqInput,
InitWeightsUpdateGroupReqInput,
LoadLoRAAdapterReqInput,
OpenSessionReqInput,
ParseFunctionCallReq,
ProfileReqInput,
@@ -80,6 +81,7 @@
SeparateReasoningReqInput,
SetInternalStateReq,
SlowDownReqInput,
UnloadLoRAAdapterReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
@@ -595,6 +597,40 @@ async def slow_down(obj: SlowDownReqInput, request: Request):
return _create_error_response(e)


@app.api_route("/load_lora_adapter", methods=["POST"])
async def load_lora_adapter(obj: LoadLoRAAdapterReqInput, request: Request):
"""Load a new LoRA adapter without re-launching the server."""
result = await _global_state.tokenizer_manager.load_lora_adapter(obj, request)

if result.success:
return ORJSONResponse(
result,
status_code=HTTPStatus.OK,
)
else:
return ORJSONResponse(
result,
status_code=HTTPStatus.BAD_REQUEST,
)


@app.api_route("/unload_lora_adapter", methods=["POST"])
async def unload_lora_adapter(obj: UnloadLoRAAdapterReqInput, request: Request):
"""Load a new LoRA adapter without re-launching the server."""
result = await _global_state.tokenizer_manager.unload_lora_adapter(obj, request)

if result.success:
return ORJSONResponse(
result,
status_code=HTTPStatus.OK,
)
else:
return ORJSONResponse(
result,
status_code=HTTPStatus.BAD_REQUEST,
)


@app.api_route("/open_session", methods=["GET", "POST"])
async def open_session(obj: OpenSessionReqInput, request: Request):
"""Open a session, and return its unique session id."""
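
As a usage illustration (not part of the diff), the two new endpoints above can be exercised with any HTTP client. The request bodies mirror LoadLoRAAdapterReqInput and UnloadLoRAAdapterReqInput from io_struct.py, while the base URL, adapter name, and path below are placeholders:

import requests

BASE_URL = "http://localhost:30000"  # assumed server address

# Load a new adapter at runtime.
resp = requests.post(
    f"{BASE_URL}/load_lora_adapter",
    json={"lora_name": "my_adapter", "lora_path": "/path/to/my_adapter"},
)
print(resp.status_code, resp.json())  # 200 on success, 400 with an error_message otherwise

# Unload it again.
resp = requests.post(
    f"{BASE_URL}/unload_lora_adapter",
    json={"lora_name": "my_adapter"},
)
print(resp.status_code, resp.json())
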
9 changes: 4 additions & 5 deletions python/sglang/srt/lora/lora.py
@@ -65,7 +65,7 @@ def __init__(
self.layers: List[LoRALayer] = nn.ModuleList(
[
LoRALayer(config, base_hf_config)
for i in range(base_hf_config.num_hidden_layers)
for _ in range(base_hf_config.num_hidden_layers)
]
)

@@ -88,10 +88,9 @@ def initialize_weights(self):
else:
self.weights[name] = loaded_weight.cpu()

# stack kv_proj and gate_up_proj
for i in range(self.base_hf_config.num_hidden_layers):
layer = self.layers[i]
weight_names = [name for name, _ in layer.weights.items()]
# normalize kv_proj and gate_up_proj
for layer in self.layers:
weight_names = list(layer.weights.keys())
self.normalize_qkv_proj(weight_names, layer.weights)
self.normalize_gate_up_proj(weight_names, layer.weights)

93 changes: 73 additions & 20 deletions python/sglang/srt/lora/lora_manager.py
@@ -35,6 +35,7 @@
get_normalized_lora_weight_names,
get_weight_name,
)
from sglang.srt.managers.io_struct import LoRAUpdateResult
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import replace_submodule

Expand Down Expand Up @@ -98,44 +99,96 @@ def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
],
)

def load_lora_adapters(self, lora_paths: Dict[str, str]):
def create_lora_update_result(
self, success: bool, error_message: str = ""
) -> LoRAUpdateResult:
return LoRAUpdateResult(
success=success,
error_message=error_message,
loaded_adapters={
name: config.path for name, config in self.configs.items()
},
)

def load_lora_adapters(self, lora_paths: Dict[str, str]) -> LoRAUpdateResult:
"""
Load LoRA adapters from the specified paths.
TODO (lifuhuang): This method should be exposed to the server/engine API to support dynamic LoRA loading.

Args:
lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
If a LoRA adapter is already loaded, it will be skipped with a warning.
"""

results = []
for lora_name, lora_path in lora_paths.items():
if lora_name in self.loras:
logger.warning(
f"LoRA adapter {lora_name} is already loaded."
"If you want to reload it, please unload it first."
)
continue
result = self.load_lora_adapter(lora_name, lora_path, update_state=False)
results.append(result)

self.update_state_from_configs()

return self.create_lora_update_result(
success=all(result.success for result in results),
error_message="\n".join(
result.error_message for result in results if not result.success
),
)

def load_lora_adapter(
self, lora_name: str, lora_path: str, update_state: bool = True
) -> LoRAUpdateResult:
"""
Load a single LoRA adapter from the specified path.

Args:
lora_name (str): The name of the LoRA adapter.
lora_path (str): The file path to the LoRA adapter.
update_state (bool): Whether to refresh the internal state after loading the adapter. Set this to False when loading adapters in a batch, so the state is refreshed once at the end instead of once per adapter.
"""

success = True
error_message = ""

if lora_name in self.loras:
success = False
error_message = f"LoRA adapter {lora_name} is skipped as it is already loaded. If you want to reload it, please unload it first."

try:
self.configs[lora_name] = LoRAConfig(lora_path)
except Exception as e:
success = False
error_message = (
f"Failed to load LoRA adapter {lora_name} from {lora_path}: {str(e)}"
)

self.update_state_from_configs()
if update_state:
self.update_state_from_configs()

return self.create_lora_update_result(
success=success,
error_message=error_message,
)

def unload_lora_adapters(self, lora_names: Set[str]):
def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult:
"""
Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
delete the corresponding LoRA modules.

Args:
lora_names (Set[str]): A set of LoRA adapter names to unload.
"""
for lora_name in lora_names:
if lora_name in self.loras:
del self.configs[lora_name]
else:
logger.warning(f"LoRA adapter {lora_name} is not loaded.")

success = True
error_message = ""
if lora_name in self.loras:
del self.configs[lora_name]
else:
error_message = f"LoRA adapter {lora_name} is not loaded."
success = False

self.update_state_from_configs()

return self.create_lora_update_result(
success=success,
error_message=error_message,
)

def prepare_lora_batch(self, forward_batch: ForwardBatch):
# load active loras into lora memory pool
cur_uids = set(forward_batch.lora_paths)
Expand Down Expand Up @@ -372,8 +425,8 @@ def update_lora_adapters(self):
lora_adapter.initialize_weights()
self.loras[name] = lora_adapter

# Clean up unused LoRA adapters
for name in self.loras:
# Clean up unused LoRA adapters, copying the list to avoid modifying the dict during iteration.
for name in list(self.loras):
if name not in self.configs:
logger.info(f"Unloading LoRA adapter {name}")
del self.loras[name]
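
To make the purpose of the update_state flag concrete: load_lora_adapters registers every adapter in the batch with update_state=False and calls update_state_from_configs() once at the end, so the comparatively expensive state refresh runs per batch rather than per adapter. A toy, self-contained sketch of that pattern (not the real LoRAManager):

class ToyManager:
    """Stand-in class illustrating the deferred-state-refresh pattern."""

    def __init__(self):
        self.configs = {}

    def load_one(self, name: str, path: str, update_state: bool = True) -> bool:
        if name in self.configs:
            return False                 # mirrors the "already loaded" failure above
        self.configs[name] = path        # cheap per-adapter bookkeeping
        if update_state:
            self.refresh_state()         # expensive; skipped while batch-loading
        return True

    def load_many(self, paths: dict) -> bool:
        ok = [self.load_one(n, p, update_state=False) for n, p in paths.items()]
        self.refresh_state()             # refresh derived state once per batch
        return all(ok)

    def refresh_state(self):
        print(f"refreshing derived state for {len(self.configs)} adapter(s)")

mgr = ToyManager()
mgr.load_many({"adapter_a": "/tmp/a", "adapter_b": "/tmp/b"})  # one refresh, not two
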
26 changes: 25 additions & 1 deletion python/sglang/srt/managers/io_struct.py
@@ -20,7 +20,7 @@
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union

from sglang.srt.mm_utils import has_valid_data

@@ -994,3 +994,27 @@ class RpcReqInput:
class RpcReqOutput:
success: bool
message: str


@dataclass
class LoadLoRAAdapterReqInput:
# The name of the LoRA adapter to be newly loaded.
lora_name: str
# The path to load the LoRA adapter from.
lora_path: str


@dataclass
class UnloadLoRAAdapterReqInput:
# The name of the LoRA adapter to unload.
lora_name: str


@dataclass
class LoRAUpdateResult:
success: bool
error_message: Optional[str] = None
loaded_adapters: Dict[str, str] = field(default_factory=dict)


LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
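
A small, hypothetical example of how these dataclasses fit together, assuming a build of sglang that includes this change (adapter name and path are placeholders):

from sglang.srt.managers.io_struct import (
    LoadLoRAAdapterReqInput,
    LoRAUpdateResult,
    UnloadLoRAAdapterReqInput,
)

load_req = LoadLoRAAdapterReqInput(lora_name="my_adapter", lora_path="/path/to/my_adapter")
unload_req = UnloadLoRAAdapterReqInput(lora_name="my_adapter")

# The manager replies with a LoRAUpdateResult, aliased as
# LoadLoRAAdapterReqOutput / UnloadLoRAAdapterReqOutput.
result = LoRAUpdateResult(success=True, loaded_adapters={"my_adapter": "/path/to/my_adapter"})
if not result.success:
    print(result.error_message)
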