Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions python/sglang/srt/lora/backend/ascend_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@ class AscendLoRABackend(BaseLoRABackend):
def __init__(
    self,
    max_loras_per_batch: int,
    max_loras_prefetch: int,
    device: torch.device,
    **kwargs,
):
    """Initialize the Ascend LoRA backend.

    Args:
        max_loras_per_batch: maximum number of different LoRA weights that
            can be applied in a single forward batch.
        max_loras_prefetch: maximum number of LoRA weights in a prefetch batch.
        device: the device where the backend runs.
        **kwargs: extra keyword arguments accepted for interface
            compatibility with other backends; ignored here.
    """
    # NOTE(review): reconstructed from a diff paste that contained both the
    # old two-arg and the new three-arg super().__init__ call; only the new
    # call is kept.
    super().__init__(max_loras_per_batch, max_loras_prefetch, device)

def run_lora_a_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
Expand Down Expand Up @@ -268,19 +269,19 @@ def prepare_lora_batch(
(bs,), dtype=torch.int32, device=self.device
),
lora_ranks=torch.empty(
(self.max_loras_per_batch,), dtype=torch.int32, device=self.device
(self.max_loras_total,), dtype=torch.int32, device=self.device
),
scalings=torch.empty(
(self.max_loras_per_batch,), dtype=torch.float, device=self.device
(self.max_loras_total,), dtype=torch.float, device=self.device
),
permutation=None,
)

# Copy to device asynchronously
batch_info.lora_ranks[: self.max_loras_per_batch].copy_(
batch_info.lora_ranks[: self.max_loras_total].copy_(
lora_ranks_tensor, non_blocking=True
)
batch_info.scalings[: self.max_loras_per_batch].copy_(
batch_info.scalings[: self.max_loras_total].copy_(
scalings_tensor, non_blocking=True
)
batch_info.weight_indices[:bs].copy_(weight_indices_tensor, non_blocking=True)
Expand Down
7 changes: 6 additions & 1 deletion python/sglang/srt/lora/backend/base_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,16 @@ class BaseLoRABackend:
Args:
max_loras_per_batch: maximum number of different lora weights
that can be applied in a single forward batch.
max_loras_prefetch: maximum number of lora weights in a prefetch batch.
device: the device where the backend runs.
"""

def __init__(self, max_loras_per_batch: int, device: torch.device):
def __init__(
self, max_loras_per_batch: int, max_loras_prefetch: int, device: torch.device
):
self.max_loras_per_batch = max_loras_per_batch
self.max_loras_prefetch = max_loras_prefetch
self.max_loras_total = max_loras_per_batch + max_loras_prefetch
self.device = device

def run_lora_a_sgemm(
Expand Down
15 changes: 8 additions & 7 deletions python/sglang/srt/lora/backend/chunked_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@ class ChunkedSgmvLoRABackend(BaseLoRABackend):
def __init__(
    self,
    max_loras_per_batch: int,
    max_loras_prefetch: int,
    device: torch.device,
    server_args: ServerArgs,
):
    """Initialize the chunked-SGMV LoRA backend.

    Args:
        max_loras_per_batch: maximum number of different LoRA weights that
            can be applied in a single forward batch.
        max_loras_prefetch: maximum number of LoRA weights in a prefetch batch.
        device: the device where the backend runs.
        server_args: server configuration; supplies the chunk-size limit.
    """
    # NOTE(review): reconstructed from a diff paste that contained both the
    # old two-arg and the new three-arg super().__init__ call; only the new
    # call is kept.
    super().__init__(max_loras_per_batch, max_loras_prefetch, device)
    # Upper bound on segment/chunk length used by the chunked SGMV kernels.
    self.max_chunk_size = server_args.max_lora_chunk_size

def run_lora_a_sgemm(
Expand Down Expand Up @@ -175,8 +176,8 @@ def init_cuda_graph_batch_info(
seg_indptr=torch.zeros(max_num_segments + 1, dtype=torch.int32),
weight_indices=torch.zeros(max_num_segments, dtype=torch.int32),
permutation=torch.zeros(max_num_tokens, dtype=torch.int32),
lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32),
scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
lora_ranks=torch.zeros(self.max_loras_total, dtype=torch.int32),
scalings=torch.zeros(self.max_loras_total, dtype=torch.float),
num_segments=None, # Set per batch
max_len=None, # Not used in CSGMV backend
)
Expand Down Expand Up @@ -222,10 +223,10 @@ def prepare_lora_batch(
(num_segments,), dtype=torch.int32, device=self.device
),
lora_ranks=torch.empty(
(self.max_loras_per_batch,), dtype=torch.int32, device=self.device
(self.max_loras_total,), dtype=torch.int32, device=self.device
),
scalings=torch.empty(
(self.max_loras_per_batch,), dtype=torch.float, device=self.device
(self.max_loras_total,), dtype=torch.float, device=self.device
),
permutation=torch.empty(
(len(permutation),), dtype=torch.int32, device=self.device
Expand All @@ -240,10 +241,10 @@ def prepare_lora_batch(
batch_info.max_len = chunk_size

# Copy to device asynchronously
batch_info.lora_ranks[: self.max_loras_per_batch].copy_(
batch_info.lora_ranks[: self.max_loras_total].copy_(
lora_ranks_tensor, non_blocking=True
)
batch_info.scalings[: self.max_loras_per_batch].copy_(
batch_info.scalings[: self.max_loras_total].copy_(
scalings_tensor, non_blocking=True
)
batch_info.weight_indices[:num_segments].copy_(
Expand Down
15 changes: 8 additions & 7 deletions python/sglang/srt/lora/backend/triton_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ class TritonLoRABackend(BaseLoRABackend):
def __init__(
    self,
    max_loras_per_batch: int,
    max_loras_prefetch: int,
    device: torch.device,
    **kwargs,
):
    """Initialize the Triton LoRA backend.

    Args:
        max_loras_per_batch: maximum number of different LoRA weights that
            can be applied in a single forward batch.
        max_loras_prefetch: maximum number of LoRA weights in a prefetch batch.
        device: the device where the backend runs.
        **kwargs: extra keyword arguments accepted for interface
            compatibility with other backends; ignored here.
    """
    # NOTE(review): reconstructed from a diff paste that contained both the
    # old two-arg and the new three-arg super().__init__ call; only the new
    # call is kept.
    super().__init__(max_loras_per_batch, max_loras_prefetch, device)

def run_lora_a_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
Expand Down Expand Up @@ -110,8 +111,8 @@ def init_cuda_graph_batch_info(
seg_indptr=torch.empty(max_bs_in_cuda_graph + 1, dtype=torch.int32),
max_len=num_tokens_per_bs,
weight_indices=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32),
scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
lora_ranks=torch.zeros(self.max_loras_total, dtype=torch.int32),
scalings=torch.zeros(self.max_loras_total, dtype=torch.float),
permutation=None,
)

Expand Down Expand Up @@ -177,19 +178,19 @@ def prepare_lora_batch(
(bs,), dtype=torch.int32, device=self.device
),
lora_ranks=torch.empty(
(self.max_loras_per_batch,), dtype=torch.int64, device=self.device
(self.max_loras_total,), dtype=torch.int64, device=self.device
),
scalings=torch.empty(
(self.max_loras_per_batch,), dtype=torch.float, device=self.device
(self.max_loras_total,), dtype=torch.float, device=self.device
),
permutation=None,
)

# Copy to device asynchronously
batch_info.lora_ranks[: self.max_loras_per_batch].copy_(
batch_info.lora_ranks[: self.max_loras_total].copy_(
lora_ranks_tensor, non_blocking=True
)
batch_info.scalings[: self.max_loras_per_batch].copy_(
batch_info.scalings[: self.max_loras_total].copy_(
scalings_tensor, non_blocking=True
)
batch_info.weight_indices[:bs].copy_(weight_indices_tensor, non_blocking=True)
Expand Down
14 changes: 11 additions & 3 deletions python/sglang/srt/lora/lora_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(
base_model: torch.nn.Module,
base_hf_config: AutoConfig,
max_loras_per_batch: int,
max_loras_prefetch: int,
load_config: LoadConfig,
dtype: torch.dtype,
lora_backend: str = "triton",
Expand All @@ -69,6 +70,7 @@ def __init__(
self.base_model: torch.nn.Module = base_model
self.base_hf_config: AutoConfig = base_hf_config
self.max_loras_per_batch: int = max_loras_per_batch
self.max_loras_prefetch: int = max_loras_prefetch
self.load_config: LoadConfig = load_config
self.dtype: torch.dtype = dtype
self.device: torch.device = next(self.base_model.parameters()).device
Expand All @@ -83,6 +85,7 @@ def __init__(
backend_type = get_backend_from_name(lora_backend)
self.lora_backend: BaseLoRABackend = backend_type(
max_loras_per_batch=max_loras_per_batch,
max_loras_prefetch=max_loras_prefetch,
device=self.device,
server_args=server_args,
)
Expand Down Expand Up @@ -241,7 +244,7 @@ def validate_lora_batch(self, lora_ids: set[str]) -> bool:

return required_slots <= mem_pool_vacancy

def prepare_lora_batch(self, forward_batch: ForwardBatch):
def prepare_lora_batch(self, forward_batch: ForwardBatch, prefetch=False):
# Load active loras into lora memory pool
cur_uids = set(forward_batch.lora_ids)

Expand All @@ -251,8 +254,12 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch):
lora_adapters=self.loras,
lora_modules=self.lora_modules,
lora_refs=self.lora_refs.copy(), # copy snapshot of current lora_refs to avoid mutation during the batch preparation.
prefetch=prefetch,
)

if prefetch:
return

# set up batch info shared by all lora modules
bs = forward_batch.batch_size

Expand All @@ -263,8 +270,8 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch):
)

weight_indices = [0] * len(forward_batch.lora_ids)
lora_ranks = [0] * self.max_loras_per_batch
scalings = [0] * self.max_loras_per_batch
lora_ranks = [0] * (self.max_loras_per_batch + self.max_loras_prefetch)
scalings = [0] * (self.max_loras_per_batch + self.max_loras_prefetch)
for i, uid in enumerate(forward_batch.lora_ids):
weight_indices[i] = self.memory_pool.get_buffer_id(uid)
if uid is not None:
Expand Down Expand Up @@ -416,6 +423,7 @@ def init_memory_pool(self):
self.memory_pool = LoRAMemoryPool(
base_hf_config=self.base_hf_config,
max_loras_per_batch=self.max_loras_per_batch,
max_loras_prefetch=self.max_loras_prefetch,
dtype=self.dtype,
tp_size=self.tp_size,
tp_rank=self.tp_rank,
Expand Down
72 changes: 51 additions & 21 deletions python/sglang/srt/lora/mem_pool.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextlib
import logging
from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union

Expand Down Expand Up @@ -49,6 +50,7 @@ def __init__(
self,
base_hf_config: AutoConfig,
max_loras_per_batch: int,
max_loras_prefetch: int,
dtype: torch.dtype,
tp_size: int,
tp_rank: int,
Expand All @@ -60,6 +62,7 @@ def __init__(
self.base_hf_config: AutoConfig = base_hf_config
self.num_layer: int = base_hf_config.num_hidden_layers
self.max_loras_per_batch: int = max_loras_per_batch
self.max_loras_prefetch: int = max_loras_prefetch
self.dtype: torch.dtype = dtype
self.tp_size: int = tp_size
self.tp_rank: int = tp_rank
Expand All @@ -71,9 +74,9 @@ def __init__(

# Both A_buffer and B_buffer maps lora weight names to its buffer space.
# A_buffer contains num_layer number of row-major tensors with shape
# (max_loras_per_batch, stacked_num * max_lora_dim, input_dim)
# (max_loras_per_batch + max_loras_prefetch, stacked_num * max_lora_dim, input_dim)
# B_buffer contains num_layer number of column-major tensors with shape
# (stacked_num, max_loras_per_batch, output_dim, max_lora_dim)
# (stacked_num, max_loras_per_batch + max_loras_prefetch, output_dim, max_lora_dim)
self.A_buffer: Dict[str, List[torch.Tensor]] = {}
self.B_buffer: Dict[str, List[torch.Tensor]] = {}

Expand All @@ -83,9 +86,15 @@ def __init__(
# Buffer idx -> lora uid in memory pool
# All uids are initialized as `EmptySlot` for empty buffer slots
# Here we don't initialize to None since None is a valid uid
self.buffer_id_to_uid: List[Union[str, None, EmptySlot]] = [
EMPTY_SLOT
] * self.max_loras_per_batch
self.buffer_id_to_uid: List[Union[str, None, EmptySlot]] = [EMPTY_SLOT] * (
self.max_loras_per_batch + self.max_loras_prefetch
)

self.device = next(base_model.parameters()).device
if self.device.type == "cuda":
self.prefetch_stream = torch.cuda.Stream(device=self.device)
else:
self.prefetch_stream = None

self.init_buffers(base_model)

Expand Down Expand Up @@ -125,7 +134,7 @@ def get_lora_A_shape(
if self.tp_size > 1 and module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES:
input_dim = divide(input_dim, self.tp_size)
return (
self.max_loras_per_batch,
self.max_loras_per_batch + self.max_loras_prefetch,
max_lora_dim * c,
input_dim,
)
Expand All @@ -146,7 +155,7 @@ def get_lora_B_shape(
if self.tp_size > 1 and module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
output_dim = divide(output_dim, self.tp_size)
return (
self.max_loras_per_batch,
self.max_loras_per_batch + self.max_loras_prefetch,
output_dim,
max_lora_dim,
)
Expand Down Expand Up @@ -192,10 +201,30 @@ def prepare_lora_batch(
lora_adapters: Dict[str, LoRAAdapter],
lora_modules: List[Dict[str, BaseLayerWithLoRA]],
lora_refs: Dict[str, LoRARef],
prefetch: bool,
):
stream_ctx = (
torch.cuda.stream(self.prefetch_stream)
if prefetch and self.prefetch_stream is not None
else (
torch.cuda.stream(torch.cuda.current_stream(self.device))
if self.device.type == "cuda"
else contextlib.nullcontext()
)
)

def get_available_buffer_slot():
# 1. Prioritize empty slots
for buffer_id in range(self.max_loras_per_batch):
start_slot, stop_slot = (
(0, self.max_loras_per_batch)
if not prefetch
else (
self.max_loras_per_batch,
self.max_loras_per_batch + self.max_loras_prefetch,
)
)

for buffer_id in range(start_slot, stop_slot):
if self.buffer_id_to_uid[buffer_id] == EMPTY_SLOT:
return buffer_id

Expand Down Expand Up @@ -235,19 +264,20 @@ def get_available_buffer_slot():
)
return victim_buffer_id

# Mark all adapters in current batch as used (for LRU tracking)
for uid in cur_uids:
self.eviction_policy.mark_used(uid)

for uid in cur_uids:
if uid not in self.uid_to_buffer_id:
buffer_id = get_available_buffer_slot()
lora_adapter = lora_adapters.get(uid, None)
self.load_lora_weight_to_buffer(
uid, buffer_id, lora_adapter, lora_modules
)
self.uid_to_buffer_id[uid] = buffer_id
self.buffer_id_to_uid[buffer_id] = uid
with stream_ctx:
# Mark all adapters in current batch as used (for LRU tracking)
for uid in cur_uids:
self.eviction_policy.mark_used(uid)

for uid in cur_uids:
if uid not in self.uid_to_buffer_id:
buffer_id = get_available_buffer_slot()
lora_adapter = lora_adapters.get(uid, None)
self.load_lora_weight_to_buffer(
uid, buffer_id, lora_adapter, lora_modules
)
self.uid_to_buffer_id[uid] = buffer_id
self.buffer_id_to_uid[buffer_id] = uid

def load_lora_weight_to_buffer(
self,
Expand Down
Loading
Loading