Skip to content
76 changes: 56 additions & 20 deletions python/sglang/srt/lora/lora_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@


import asyncio
from collections import OrderedDict
from dataclasses import dataclass, field, fields
from typing import Dict, List, Optional, Union
from uuid import uuid4
Expand Down Expand Up @@ -71,8 +72,11 @@ def __init__(self, lora_paths: Optional[List[LoRARef]] = None):
# Please note that the counter increment/decrement operations are not synchronized through this
# lock, as they are designed to be non-blocking and can be performed concurrently.
self._registry_lock = RWLock()
# A dictionary to hold LoRARef objects, mapping from LoRA name to LoRARef.
self._registry: Dict[str, LoRARef] = {}
# An ordered dictionary to hold LoRARef objects, mapping from LoRA name to LoRARef.
# The LoRARefs are stored in LRU order, such that LoRA adapters that have been
# most recently used are stored at the end. Note that lookups count for accesses.
# Ties are broken arbitrarily.
self._registry: OrderedDict[str, LoRARef] = OrderedDict()
# Counters for ongoing requests, mapping from LoRA ID to ConcurrentCounter.
self._counters: Dict[str, ConcurrentCounter] = {}

Expand Down Expand Up @@ -124,29 +128,30 @@ def _lookup(name: str) -> str:
f"The following requested LoRA adapters are not loaded: {name}\n"
f"Loaded adapters: {self._registry.keys()}."
)
self._registry.move_to_end(name)
return lora_ref.lora_id

async with self._registry_lock.reader_lock:
if isinstance(lora_name, str):
if isinstance(lora_name, str):
async with self._registry_lock.writer_lock:
lora_id = _lookup(lora_name)
await self._counters[lora_id].increment(notify_all=False)
return lora_id
elif isinstance(lora_name, list):

await self._counters[lora_id].increment(notify_all=False)
return lora_id
elif isinstance(lora_name, list):
async with self._registry_lock.writer_lock:
lora_ids = [_lookup(name) for name in lora_name]

# Increment the counters only after all IDs are looked up.
await asyncio.gather(
*[
self._counters[id].increment(notify_all=False)
for id in lora_ids
if id is not None
]
)
return lora_ids
else:
raise TypeError(
"lora_name must be either a string or a list of strings."
)
# Increment the counters only after all IDs are looked up.
await asyncio.gather(
*[
self._counters[id].increment(notify_all=False)
for id in lora_ids
if id is not None
]
)
return lora_ids
else:
raise TypeError("lora_name must be either a string or a list of strings.")

async def release(self, lora_id: Union[str, List[str]]):
"""
Expand Down Expand Up @@ -186,6 +191,37 @@ async def wait_for_unload(self, lora_id: str):
await self._counters[lora_id].wait_for_zero()
del self._counters[lora_id]

async def get_unregistered_loras(self, lora_name: set[str]):
    """
    Returns all LoRA adapters in lora_name that are not found in self._registry.

    Names that ARE present are treated as cache hits and are promoted to the
    most-recently-used end of the LRU-ordered registry.
    """
    async with self._registry_lock.writer_lock:
        # Collect the names that have no registry entry.
        missing = [name for name in lora_name if name not in self._registry]

        # Every successful lookup counts as an access for LRU purposes.
        for name in lora_name:
            if name in self._registry:
                self._registry.move_to_end(name)

        return missing

async def lru_lora_name(self, exclude_pinned=False):
    """
    Returns the least recently used LoRA adapter.
    If exclude_pinned is True, then return the LRU LoRA adapter that isn't pinned.
    Returns None when no eligible adapter exists.
    """
    async with self._registry_lock.reader_lock:
        if exclude_pinned:
            # Oldest entries iterate first; skip adapters marked as pinned.
            return next(
                (name for name, ref in self._registry.items() if not ref.pinned),
                None,
            )
        # Plain LRU: first key in access order, or None when the registry is empty.
        return next(iter(self._registry), None)

def _register_adapter(self, lora_ref: LoRARef):
"""
Internal helper method to register a LoRA adapter.
Expand Down
74 changes: 52 additions & 22 deletions python/sglang/srt/managers/tokenizer_communicator_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,26 @@ async def update_weights_from_ipc(

return success, message

async def _unload_lora_adapter_locked(
    self: TokenizerManager,
    obj: UnloadLoRAAdapterReqInput,
) -> UnloadLoRAAdapterReqOutput:
    """Unload a LoRA adapter; the caller MUST already hold self.lora_update_lock.

    Ordering matters here: the adapter is first unregistered (so no new
    requests can acquire it), then we wait for in-flight requests to drain,
    and only then is the backend asked to actually unload the weights.
    """
    # Internal invariant, not input validation: this helper is only meaningful
    # while the update lock is held by the caller.
    assert (
        self.lora_update_lock.locked()
    ), "self.lora_update_lock must be locked in order for self._unload_lora_adapter_locked() to be called"

    # Unregister the LoRA adapter from the registry to stop new requests for this adapter
    # from being started.
    lora_id = await self.lora_registry.unregister(obj.lora_name)
    obj.lora_id = lora_id

    # Initiate the actual unloading operation at the backend processes only after all
    # ongoing requests using this LoRA adapter are finished.
    await self.lora_registry.wait_for_unload(lora_id)
    result = (await self.update_lora_adapter_communicator(obj))[0]

    return result

async def load_lora_adapter(
self: TokenizerManager,
obj: LoadLoRAAdapterReqInput,
Expand All @@ -520,17 +540,6 @@ async def load_lora_adapter(
)

async with self.lora_update_lock:
if (
self.server_args.max_loaded_loras is not None
and self.lora_registry.num_registered_loras
>= self.server_args.max_loaded_loras
):
raise ValueError(
f"Cannot load LoRA adapter {obj.lora_name} at path {obj.lora_path}. "
f"Maximum number of loaded LoRA adapters is {self.server_args.max_loaded_loras}. "
"Please unload some LoRA adapters before loading new ones."
)

# Generate new uniquely identifiable LoRARef object.
new_adapter = LoRARef(
lora_name=obj.lora_name,
Expand All @@ -545,6 +554,37 @@ async def load_lora_adapter(
# Register the LoRA adapter only after loading is successful.
if result.success:
await self.lora_registry.register(new_adapter)
self.lora_ref_cache[obj.lora_name] = new_adapter

if self.server_args.max_loaded_loras is not None:
while (
self.lora_registry.num_registered_loras
> self.server_args.max_loaded_loras
):
lru_lora_name = await self.lora_registry.lru_lora_name(
exclude_pinned=True
)
if lru_lora_name is None:
raise ValueError(
"Didn't find any LoRA adapters when trying to evict LRU LoRA adapter. "
f"LoRA registry is: {self.lora_registry._registry}"
)

logger.info(
f"Unloading least recently used LoRA adapter '{lru_lora_name}' "
f"(current number of adapters: {self.lora_registry.num_registered_loras}, "
f"max allowed: {self.server_args.max_loaded_loras})"
)

unload_result = await self._unload_lora_adapter_locked(
UnloadLoRAAdapterReqInput(lora_name=lru_lora_name)
)
if not unload_result.success:
raise ValueError(
f"Error while unloading LRU LoRA adapter '{lru_lora_name}': "
f"{unload_result.error_message}"
)
del result.loaded_adapters[lru_lora_name]

return result
except ValueError as e:
Expand Down Expand Up @@ -581,17 +621,7 @@ async def unload_lora_adapter(
)

async with self.lora_update_lock:
# Unregister the LoRA adapter from the registry to stop new requests for this adapter
# from being started.
lora_id = await self.lora_registry.unregister(obj.lora_name)
obj.lora_id = lora_id

# Initiate the actual unloading operation at the backend processes only after all
# ongoing requests using this LoRA adapter are finished.
await self.lora_registry.wait_for_unload(lora_id)
result = (await self.update_lora_adapter_communicator(obj))[0]

return result
return await self._unload_lora_adapter_locked(obj)
except ValueError as e:
return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))

Expand Down
55 changes: 54 additions & 1 deletion python/sglang/srt/managers/tokenizer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.environ import envs
from sglang.srt.lora.lora_registry import LoRARegistry
from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer
from sglang.srt.managers.async_mm_data_processor import AsyncMMDataProcessor
from sglang.srt.managers.disagg_service import start_disagg_service
Expand All @@ -60,6 +60,7 @@
GenerateReqInput,
GetLoadReqInput,
HealthCheckOutput,
LoadLoRAAdapterReqInput,
OpenSessionReqOutput,
SessionParams,
TokenizedEmbeddingReqInput,
Expand Down Expand Up @@ -357,6 +358,13 @@ def __init__(
# Please note that, unlike `model_update_lock`, this does not block inference, allowing
# LoRA updates and inference to overlap.
self.lora_update_lock = asyncio.Lock()
# A cache for mapping the lora_name for LoRA adapters that have been loaded at any
# point to their latest LoRARef objects, so that they can be
# dynamically loaded if needed for inference
self.lora_ref_cache: Dict[str, LoRARef] = {}
if self.server_args.lora_paths is not None:
for lora_ref in self.server_args.lora_paths:
self.lora_ref_cache[lora_ref.lora_name] = lora_ref

# Disaggregation
self.disaggregation_mode = DisaggregationMode(
Expand Down Expand Up @@ -448,6 +456,51 @@ async def generate_request(

async with self.model_update_lock.reader_lock:
if self.server_args.enable_lora and obj.lora_path:
if isinstance(obj.lora_path, str):
unique_lora_paths = set([obj.lora_path])
else:
unique_lora_paths = set(obj.lora_path)

if (
self.server_args.max_loaded_loras is not None
and len(unique_lora_paths) > self.server_args.max_loaded_loras
):
raise ValueError(
f"Received request with {len(unique_lora_paths)} unique loras requested "
f"but max loaded loras is {self.server_args.max_loaded_loras}"
)

# Reload all existing LoRA adapters that have been dynamically unloaded
unregistered_loras = await self.lora_registry.get_unregistered_loras(
unique_lora_paths
)
for lora_path in unregistered_loras:
if lora_path is None:
continue

if lora_path not in self.lora_ref_cache:
raise ValueError(
f"Got LoRA adapter that has never been loaded: {lora_path}\n"
f"All loaded adapters: {self.lora_ref_cache.keys()}."
)

logger.info(f"Reloading evicted adapter: {lora_path}")
new_lora_ref = self.lora_ref_cache[lora_path]
load_result = await self.load_lora_adapter(
LoadLoRAAdapterReqInput(
lora_name=new_lora_ref.lora_name,
lora_path=new_lora_ref.lora_path,
pinned=new_lora_ref.pinned,
)
)
if (
not load_result.success
and "already loaded" not in load_result.error_message
):
raise ValueError(
f"Failed to implicitly load LoRA adapter {lora_path}: {load_result.error_message}"
)

# Look up the LoRA ID from the registry and start tracking ongoing LoRA requests.
obj.lora_id = await self.lora_registry.acquire(obj.lora_path)

Expand Down
Loading
Loading