Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
2c9353d
overlap lora weight loading with compute
glenliu21 Dec 17, 2025
1aebbc7
move logic to LoRAPrefetcher
glenliu21 Dec 23, 2025
bff918e
add lora_prefetcher.py
glenliu21 Dec 23, 2025
bc9cf43
merge main
glenliu21 Dec 23, 2025
b442c8b
fix weight sync issue
glenliu21 Dec 25, 2025
8a674b5
fix
glenliu21 Dec 28, 2025
ee707d6
Merge branch 'main' into lora_pipeline
glenliu21 Dec 29, 2025
071465f
Merge branch 'main' into lora_pipeline
glenliu21 Dec 29, 2025
28f8fe9
fix
glenliu21 Dec 29, 2025
58edc3a
precommit
glenliu21 Dec 29, 2025
5351cb2
add server arg and test
glenliu21 Dec 31, 2025
8209420
register test for ci
glenliu21 Dec 31, 2025
81072e6
adjust test
glenliu21 Dec 31, 2025
ffeeacb
Merge branch 'main' into lora_pipeline
glenliu21 Dec 31, 2025
044d789
improve test
glenliu21 Jan 2, 2026
fb31cf7
Merge branch 'main' into lora_pipeline
glenliu21 Jan 2, 2026
db1a67e
add tp test
glenliu21 Jan 2, 2026
daa8403
Merge branch 'main' into lora_pipeline
glenliu21 Jan 2, 2026
ea43f6d
rename lora_prefetcher to lora_overlap_loader; reorganize and reforma…
glenliu21 Jan 5, 2026
30930f5
Merge branch 'main' into lora_pipeline
glenliu21 Jan 5, 2026
2892f63
Merge branch 'main' into lora_pipeline
glenliu21 Jan 5, 2026
1500be6
Merge branch 'main' into lora_pipeline
Fridge003 Jan 10, 2026
3293996
Merge branch 'main' into lora_pipeline
glenliu21 Jan 10, 2026
f87bdf6
max_loaded_loras fix
glenliu21 Jan 10, 2026
354ff9c
Merge branch 'main' into lora_pipeline
glenliu21 Jan 10, 2026
77609c0
Merge branch 'main' into lora_pipeline
glenliu21 Jan 10, 2026
8b69562
fix server arg description
glenliu21 Jan 10, 2026
0aa9226
max_loaded_loras arg
glenliu21 Jan 12, 2026
0409f43
Merge branch 'main' into lora_pipeline
glenliu21 Jan 12, 2026
ec6cecf
Merge branch 'main' into lora_pipeline
glenliu21 Jan 12, 2026
ae79420
Merge branch 'main' into lora_pipeline
glenliu21 Jan 12, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/advanced_features/lora.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
"\n",
"* `enable_lora`: Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility.\n",
"\n",
"* `enable_lora_overlap_loading`: Enable asynchronous LoRA weight loading in order to overlap H2D transfers with GPU compute. This should be enabled if you find that your LoRA workloads are bottlenecked by adapter weight loading, for example when frequently loading large LoRA adapters.\n",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add an example for lora overlap loading in below section (can be updated in a following PR)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please see #17464.

"\n",
"* `lora_paths`: The list of LoRA adapters to load. Each adapter must be specified in one of the following formats: <PATH> | <NAME>=<PATH> | JSON with schema {\"lora_name\":str,\"lora_path\":str,\"pinned\":bool}.\n",
"\n",
"* `max_loras_per_batch`: Maximum number of adaptors used by each batch. This argument can affect the amount of GPU memory reserved for multi-LoRA serving, so it should be set to a smaller value when memory is scarce. Defaults to be 8.\n",
Expand Down
1 change: 1 addition & 0 deletions docs/advanced_features/server_arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| Argument | Description | Defaults | Options |
| --- | --- | --- | --- |
| `--enable-lora` | Enable LoRA support for the model. This argument is automatically set to `True` if `--lora-paths` is provided for backward compatibility. | `False` | Bool flag (set to enable) |
| `--enable-lora-overlap-loading` | Enable asynchronous LoRA weight loading in order to overlap H2D transfers with GPU compute. This should be enabled if you find that your LoRA workloads are bottlenecked by adapter weight loading, for example when frequently loading large LoRA adapters. | `False` | Bool flag (set to enable) |
| `--max-lora-rank` | The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup. | `None` | Type: int |
| `--lora-target-modules` | The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. You can also set it to `all` to enable LoRA for all supported modules; note this may introduce minor performance overhead. | `None` | `q_proj`, `k_proj`, `v_proj`, `o_proj`, `gate_proj`, `up_proj`, `down_proj`, `qkv_proj`, `gate_up_proj`, `all` |
| `--lora-paths` | The list of LoRA adapters to load. Each adapter must be specified in one of the following formats: `<PATH>` \| `<NAME>=<PATH>` \| JSON with schema `{"lora_name": str, "lora_path": str, "pinned": bool}`. | `None` | Type: List[str] / JSON objects |
Expand Down
11 changes: 11 additions & 0 deletions python/sglang/srt/lora/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,14 @@ def normalize_gate_up_proj(
if "lora_A" in weight_name:
weights[gate_up_name] = weights[gate_up_name].repeat(2, 1)
# else: no-op as LoRA B weight is already stacked.

def pin_weights_in_cpu(self):
    """Move every adapter weight tensor into page-locked (pinned) host memory.

    Pinned host memory is a prerequisite for non-blocking H2D copies, which
    allows LoRA weight uploads to overlap with GPU compute.
    """

    def _pin_all(weight_dict):
        # Replace each tensor with its pinned-memory copy, in place.
        for key in list(weight_dict):
            weight_dict[key] = weight_dict[key].pin_memory()

    for layer in self.layers:
        _pin_all(layer.weights)
    _pin_all(self.embedding_layers)
    _pin_all(self.added_tokens_embeddings)
20 changes: 17 additions & 3 deletions python/sglang/srt/lora/lora_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ def __init__(
self.tp_size: int = tp_size
self.tp_rank: int = tp_rank
self.lora_added_tokens_size: Optional[int] = None
self.enable_lora_overlap_loading: Optional[bool] = (
server_args.enable_lora_overlap_loading
)

# Store eviction policy from server args
self.eviction_policy = server_args.lora_eviction_policy
Expand Down Expand Up @@ -208,7 +211,7 @@ def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput:

return self.create_lora_update_result(success=True)

def validate_lora_batch(self, lora_ids: set[str]) -> bool:
def validate_lora_batch(self, lora_ids: set[Optional[str]]) -> bool:
"""
Validate if the LoRA IDs in the batch can be loaded into the current LoRA memory pool.
"""
Expand Down Expand Up @@ -239,9 +242,11 @@ def validate_lora_batch(self, lora_ids: set[str]) -> bool:

return required_slots <= mem_pool_vacancy

def prepare_lora_batch(self, forward_batch: ForwardBatch):
def fetch_new_loras(
self, new_loras: set[Optional[str]], running_loras: set[Optional[str]] = set()
):
# Load active loras into lora memory pool
cur_uids = set(forward_batch.lora_ids)
cur_uids = new_loras | running_loras

assert len(cur_uids) <= self.max_loras_per_batch
self.memory_pool.prepare_lora_batch(
Expand All @@ -253,6 +258,7 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch):
lora_lm_head_module=self.lm_head_module, # merge into embedding or lora module
)

def prepare_lora_batch(self, forward_batch: ForwardBatch):
# set up batch info shared by all lora modules
bs = forward_batch.batch_size

Expand Down Expand Up @@ -442,6 +448,11 @@ def load_lora_weights(self, lora_ref: LoRARef):
self.lora_backend,
)
lora_adapter.initialize_weights()

# If we want to overlap loading LoRA adapters with compute, they must be pinned in CPU memory
if self.enable_lora_overlap_loading:
lora_adapter.pin_weights_in_cpu()

self.loras[lora_ref.lora_id] = lora_adapter

def load_lora_weights_from_tensors(
Expand Down Expand Up @@ -509,6 +520,9 @@ def init_memory_pool(self):
lora_added_tokens_size=self.lora_added_tokens_size,
)

# Initializing memory pool with base model
self.fetch_new_loras({None})

def set_lora_module(self, module_name, module):
lora_module = get_lora_layer(module, self.lora_backend)
replace_submodule(self.base_model, module_name, lora_module)
Expand Down
82 changes: 82 additions & 0 deletions python/sglang/srt/lora/lora_overlap_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import logging
from enum import Enum, auto
from typing import Dict, Optional

import torch
from torch.cuda import Event as CudaEvent
from torch.cuda import Stream as CudaStream
from torch.cuda import StreamContext as CudaStreamContext

from sglang.srt.lora.lora_manager import LoRAManager

logger = logging.getLogger(__name__)


class LoRAOverlapLoadStatus(Enum):
    """Lifecycle state of one adapter's asynchronous load into the LoRA memory pool."""

    LOADED = auto()  # side-stream copy finished and its event was consumed
    LOADING = auto()  # copy issued on the load stream but not yet complete
    NOT_LOADED = auto()  # no in-flight or completed load tracked for this adapter


class LoRAOverlapLoader:
    """Loads LoRA adapter weights on a dedicated side stream so that H2D
    weight transfers overlap with GPU compute on the default stream.

    One in-flight load is tracked per adapter id via a recorded device event;
    completion is observed with ``event.query()`` and then synchronized into
    the caller's current stream before the adapter is reported as LOADED.
    """

    def __init__(self, lora_manager):
        self.lora_manager: LoRAManager = lora_manager
        # Resolve the device-specific torch module (e.g. torch.cuda) once, so
        # all stream/event operations stay backend-agnostic.
        self.device_module = torch.get_device_module(self.lora_manager.device)
        # Dedicated stream on which adapter weight copies are issued.
        self.load_stream: CudaStream = self.device_module.Stream()
        self.load_stream_context: CudaStreamContext = self.device_module.stream(
            self.load_stream
        )
        # adapter id -> event recorded right after its copy was issued.
        # A key being present means a load is in flight (or finished but not
        # yet consumed by _check_overlap_load_status).
        self.lora_to_overlap_load_event: Dict[Optional[str], CudaEvent] = {}

    def try_overlap_load_lora(
        self, lora_id: Optional[str], running_loras: set[Optional[str]]
    ) -> bool:
        """
        Check a LoRA adapter's asynchronous load status, and try to load it if there's capacity
        in the memory pool. Returns whether or not the adapter has been loaded.
        """
        load_status = self._check_overlap_load_status(lora_id)
        if load_status == LoRAOverlapLoadStatus.LOADING:
            # Copy still in flight; caller should retry on a later iteration.
            return False
        elif load_status == LoRAOverlapLoadStatus.NOT_LOADED:
            if self._try_start_overlap_load(lora_id, running_loras):
                # Lazy %-formatting: avoid building the string when debug is off.
                logger.debug("Loading LoRA adapter %s asynchronously", lora_id)

            return False
        else:
            assert load_status == LoRAOverlapLoadStatus.LOADED
            return True

    def _check_overlap_load_status(
        self, lora_id: Optional[str]
    ) -> LoRAOverlapLoadStatus:
        """Return the adapter's load state; on completion, make the caller's
        stream wait on the copy event and stop tracking the adapter."""
        if lora_id not in self.lora_to_overlap_load_event:
            return LoRAOverlapLoadStatus.NOT_LOADED

        event = self.lora_to_overlap_load_event[lora_id]

        if not event.query():
            return LoRAOverlapLoadStatus.LOADING

        # Order the caller's stream after the finished copy before the
        # adapter's weights may be consumed. Use the resolved device module
        # (not torch.cuda directly) for consistency with __init__ and to keep
        # the loader working on any stream-capable backend.
        self.device_module.current_stream().wait_event(event)
        del self.lora_to_overlap_load_event[lora_id]

        return LoRAOverlapLoadStatus.LOADED

    def _try_start_overlap_load(
        self, lora_id: Optional[str], running_loras: set[Optional[str]]
    ) -> bool:
        """Issue an asynchronous load for ``lora_id`` if the memory pool can
        hold it alongside all running and in-flight adapters. Returns whether
        the load was started."""
        # Adapters that must keep their pool slots: currently running ones
        # plus those with loads already in flight.
        loras_to_be_loaded = running_loras | self.lora_to_overlap_load_event.keys()

        new_lora_set = {lora_id} | loras_to_be_loaded
        if not self.lora_manager.validate_lora_batch(new_lora_set):
            return False

        with self.load_stream_context:
            self.lora_manager.fetch_new_loras({lora_id}, loras_to_be_loaded)
            # Record after the copy so event.query()/wait_event observe its completion.
            event = self.device_module.Event()
            event.record(self.load_stream)

        self.lora_to_overlap_load_event[lora_id] = event
        return True
2 changes: 1 addition & 1 deletion python/sglang/srt/lora/mem_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ def load_lora_weight_tensor(
assert (
buffer_view.shape == weight.shape
), f"LoRA buffer shape {buffer_view.shape} does not match weight shape {weight.shape}."
buffer_view.copy_(weight)
buffer_view.copy_(weight, non_blocking=True)

if uid is None:
for i in range(self.num_layer):
Expand Down
41 changes: 27 additions & 14 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
)
from sglang.srt.layers.moe import initialize_moe_config
from sglang.srt.layers.quantization.fp8_utils import initialize_fp8_gemm_config
from sglang.srt.lora.lora_overlap_loader import LoRAOverlapLoader
from sglang.srt.managers.io_struct import (
AbortReq,
BaseBatchReq,
Expand Down Expand Up @@ -284,6 +285,7 @@ def __init__(
server_args.priority_scheduling_preemption_threshold
)
self.enable_lora = server_args.enable_lora
self.enable_lora_overlap_loading = server_args.enable_lora_overlap_loading
self.max_loras_per_batch = server_args.max_loras_per_batch
self.enable_overlap = not server_args.disable_overlap_schedule
self.enable_pdmux = server_args.enable_pdmux
Expand Down Expand Up @@ -371,6 +373,12 @@ def __init__(
# Init request dispatcher
self.init_request_dispatcher()

# Init LoRA overlap loader
if self.enable_lora_overlap_loading:
self.lora_overlap_loader = LoRAOverlapLoader(
self.tp_worker.model_runner.lora_manager
)

# Init the grammar backend for constrained generation
self.grammar_manager = GrammarManager(self)

Expand Down Expand Up @@ -1905,23 +1913,25 @@ def _get_new_batch_prefill_raw(
self.chunked_req = adder.add_chunked_req(self.chunked_req)

if self.enable_lora:
lora_set = set([req.lora_id for req in self.running_batch.reqs])
running_loras = {req.lora_id for req in self.running_batch.reqs}

# Get requests from the waiting queue to a new prefill batch
for req in self.waiting_queue:

if self.enable_lora:
new_lora_set = (
lora_set
| set([req.lora_id for req in adder.can_run_list])
| set([req.lora_id])
)
if not self.tp_worker.can_run_lora_batch(new_lora_set):
# Batch would exceed the LoRA slot limit.
# Skip this request and try scheduling it in a future iteration.
# Note: When eviction is needed, the eviction policy prefers to
# evict LoRA adapters over base model (None) - see mem_pool.py.
continue
if self.enable_lora and req.lora_id not in running_loras:
if self.enable_lora_overlap_loading:
# For overlapping loading of LoRA weights with computation, we will load each adapter one at a time,
# as opposed to loading them in one batch
res = self.lora_overlap_loader.try_overlap_load_lora(
req.lora_id, running_loras
)
if not res:
continue
else:
new_lora_set = {req.lora_id} | running_loras
if not self.tp_worker.model_runner.lora_manager.validate_lora_batch(
new_lora_set
):
continue

running_bs = len(self.running_batch.reqs)
if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
Expand Down Expand Up @@ -1951,6 +1961,9 @@ def _get_new_batch_prefill_raw(
truncation_align_size=self.truncation_align_size,
)

if self.enable_lora:
running_loras.add(req.lora_id)

if res != AddReqResult.CONTINUE:
if res == AddReqResult.NO_TOKEN:
if self.enable_hierarchical_cache:
Expand Down
5 changes: 5 additions & 0 deletions python/sglang/srt/model_executor/forward_batch_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,11 @@ def init_new(

# Init lora information
if model_runner.server_args.enable_lora:
# In the non-LoRA overlap loading case, we fetch LoRA adapters into the memory pool
# as a batch, right before running the batch
if not model_runner.server_args.enable_lora_overlap_loading:
model_runner.lora_manager.fetch_new_loras(set(ret.lora_ids))

model_runner.lora_manager.prepare_lora_batch(ret)

return ret
Expand Down
21 changes: 21 additions & 0 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,7 @@ class ServerArgs:

# LoRA
enable_lora: Optional[bool] = None
enable_lora_overlap_loading: Optional[bool] = None
max_lora_rank: Optional[int] = None
lora_target_modules: Optional[Union[set[str], List[str]]] = None
lora_paths: Optional[
Expand Down Expand Up @@ -3371,6 +3372,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
action="store_true",
help="Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility.",
)
parser.add_argument(
"--enable-lora-overlap-loading",
default=ServerArgs.enable_lora_overlap_loading,
action="store_true",
help="Enable asynchronous LoRA weight loading in order to overlap H2D transfers with GPU compute. This should be enabled if you find that your LoRA workloads are bottlenecked by adapter weight loading, for example when frequently loading large LoRA adapters.",
)
parser.add_argument(
"--max-lora-rank",
default=ServerArgs.max_lora_rank,
Expand Down Expand Up @@ -4900,6 +4907,20 @@ def check_lora_server_args(self):
)

if self.enable_lora:
if self.enable_lora_overlap_loading is None:
self.enable_lora_overlap_loading = False

if self.enable_lora_overlap_loading:
# TODO (glenliu21): use some sort of buffer with eviction instead of enforcing a limit
max_loaded_loras_limit = self.max_loras_per_batch * 2
assert (
self.max_loaded_loras is not None
and self.max_loaded_loras <= max_loaded_loras_limit
), (
"Enabling LoRA overlap loading requires pinning LoRA adapter weights in CPU memory, "
f"so --max-loaded-loras must be less than or equal to double --max-loras-per-batch: {max_loaded_loras_limit}"
)

# Validate compatibility with speculative decoding
if self.speculative_algorithm not in ["NGRAM", None]:
raise ValueError(
Expand Down
3 changes: 3 additions & 0 deletions python/sglang/test/lora_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def __post_init__(self):
),
],
max_loras_per_batch=2,
max_loaded_loras=4,
),
]

Expand Down Expand Up @@ -285,6 +286,7 @@ def run_lora_test_one_by_one(
torch_dtype: torch.dtype,
max_new_tokens: int,
backend: str = "csgmv",
enable_lora_overlap_loading: Optional[bool] = None,
disable_cuda_graph: bool = False,
disable_radix_cache: bool = False,
mem_fraction_static: float = 0.88,
Expand Down Expand Up @@ -331,6 +333,7 @@ def run_lora_test_one_by_one(
lora_paths=[
adaptor.name for adaptor in model_case.adaptors if adaptor.name is not None
],
enable_lora_overlap_loading=enable_lora_overlap_loading,
max_loras_per_batch=model_case.max_loras_per_batch,
max_loaded_loras=model_case.max_loaded_loras,
lora_backend=backend,
Expand Down
2 changes: 2 additions & 0 deletions python/sglang/test/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,7 @@ def __init__(
max_lora_rank: Optional[int] = None,
lora_target_modules: Optional[List[str]] = None,
enable_lora: Optional[bool] = None,
enable_lora_overlap_loading: Optional[bool] = None,
max_loaded_loras: Optional[int] = None,
json_model_override_args: Optional[dict[str, Any]] = None,
lora_eviction_policy: str = "lru",
Expand Down Expand Up @@ -612,6 +613,7 @@ def __init__(
max_lora_rank=max_lora_rank,
lora_target_modules=lora_target_modules,
enable_lora=enable_lora,
enable_lora_overlap_loading=enable_lora_overlap_loading,
max_loaded_loras=max_loaded_loras,
json_model_override_args=(
json.dumps(json_model_override_args)
Expand Down
Loading
Loading