113 changes: 79 additions & 34 deletions python/sglang/srt/lora/lora_manager.py
@@ -81,14 +81,25 @@ def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
                 seg_indptr=torch.zeros(
                     self.max_bs_in_cuda_graph + 1, dtype=torch.int32
                 ),
-                max_len=0,
+                max_len=1,
                 weight_indices=torch.zeros(
                     self.max_bs_in_cuda_graph, dtype=torch.int32
                 ),
                 lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32),
                 scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
             )
 
+            # Initialize seg_lens and seg_indptr for CUDA graph as they remain constant
+            # across batches.
+            self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph].fill_(1)
+            torch.cumsum(
+                self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph],
+                dim=0,
+                out=self.cuda_graph_batch_info.seg_indptr[
+                    1 : self.max_bs_in_cuda_graph + 1
+                ],
+            )
+
     def init_loras(self):
         # Config of each LoRA adapter
         self.configs: Dict[str, LoRAConfig] = {}
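Since every CUDA-graph decode sequence contributes exactly one token, seg_lens and seg_indptr no longer depend on the incoming batch and can be filled once here instead of in prepare_lora_batch. A minimal CPU-only sketch (with a made-up max_bs_in_cuda_graph) of what the precomputed buffers end up holding:

```python
import torch

# Hypothetical maximum CUDA-graph batch size, for illustration only.
max_bs_in_cuda_graph = 4

seg_lens = torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32)
seg_indptr = torch.zeros(max_bs_in_cuda_graph + 1, dtype=torch.int32)

# Each captured decode sequence has length 1, so the prefix sum over seg_lens
# is simply 0, 1, 2, ..., max_bs_in_cuda_graph.
seg_lens.fill_(1)
seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)

print(seg_lens.tolist())    # [1, 1, 1, 1]
print(seg_indptr.tolist())  # [0, 1, 2, 3, 4]
```

With these buffers fixed, the per-batch work in prepare_lora_batch reduces to updating bs, max_len, and the adapter metadata, as the second hunk below shows.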
@@ -159,58 +170,92 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch):
         # set up batch info shared by all lora modules
         bs = forward_batch.batch_size
 
+        def transfer_adapter_info(
+            weight_indices_out: torch.Tensor,
+            lora_ranks_out: torch.Tensor,
+            scalings_out: torch.Tensor,
+        ):
+            """
+            Transfer adapter metadata (weight indices, LoRA rank, scalings) from host
+            to device (CUDA) asynchronously.
+            """
+            weight_indices = [0] * len(forward_batch.lora_paths)
+            lora_ranks = [0] * self.max_loras_per_batch
+            scalings = [0] * self.max_loras_per_batch
+            for i, lora_path in enumerate(forward_batch.lora_paths):
+                weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
+                if lora_path is not None:
+                    lora = self.loras[lora_path]
+                    lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
+                    scalings[weight_indices[i]] = lora.scaling
+
+            # Use pinned memory to avoid synchronizations during host-to-device transfer
+            weight_indices_tensor = torch.tensor(
+                weight_indices, dtype=torch.int32, pin_memory=True, device="cpu"
+            )
+            lora_ranks_tensor = torch.tensor(
+                lora_ranks, dtype=torch.int64, pin_memory=True, device="cpu"
+            )
+            scalings_tensor = torch.tensor(
+                scalings, dtype=torch.float, pin_memory=True, device="cpu"
+            )
+
+            # Copy to device tensors asynchronously
+            weight_indices_out[:bs].copy_(weight_indices_tensor, non_blocking=True)
+            lora_ranks_out[: self.max_loras_per_batch].copy_(
+                lora_ranks_tensor, non_blocking=True
+            )
+            scalings_out[: self.max_loras_per_batch].copy_(
+                scalings_tensor, non_blocking=True
+            )
+
         if (
             hasattr(self, "max_bs_in_cuda_graph")
             and bs <= self.max_bs_in_cuda_graph
             and forward_batch.forward_mode.is_cuda_graph()
         ):
             # Do in-place updates when CUDA graph is enabled and the batch forward mode
             # could use CUDA graph.
-            self.cuda_graph_batch_info.bs = bs
-            self.cuda_graph_batch_info.seg_lens[:bs].fill_(1)
-            torch.cumsum(
-                self.cuda_graph_batch_info.seg_lens[:bs],
-                dim=0,
-                out=self.cuda_graph_batch_info.seg_indptr[1 : bs + 1],
+
+            transfer_adapter_info(
+                self.cuda_graph_batch_info.weight_indices,
+                self.cuda_graph_batch_info.lora_ranks,
+                self.cuda_graph_batch_info.scalings,
             )
-            self.cuda_graph_batch_info.max_len = 1
 
-            for i, lora_path in enumerate(forward_batch.lora_paths):
-                self.cuda_graph_batch_info.weight_indices[i] = (
-                    self.memory_pool.get_buffer_id(lora_path)
-                )
-                if lora_path is not None:
-                    lora = self.loras[lora_path]
-                    self.cuda_graph_batch_info.lora_ranks[
-                        self.cuda_graph_batch_info.weight_indices[i]
-                    ] = lora.config.hf_config["r"]
-                    self.cuda_graph_batch_info.scalings[
-                        self.cuda_graph_batch_info.weight_indices[i]
-                    ] = lora.scaling
+            self.cuda_graph_batch_info.bs = bs
+            self.cuda_graph_batch_info.max_len = 1
             batch_info = self.cuda_graph_batch_info
         else:
+            weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
+            lora_ranks = torch.zeros(
+                (self.max_loras_per_batch,), dtype=torch.int64, device=self.device
+            )
+            scalings = torch.zeros(
+                (self.max_loras_per_batch,), dtype=torch.float, device=self.device
+            )
+            transfer_adapter_info(
+                weight_indices,
+                lora_ranks,
+                scalings,
+            )
+
             seg_lens = (
                 forward_batch.extend_seq_lens
                 if forward_batch.forward_mode.is_extend()
                 else torch.ones(bs, device=self.device)
             )
 
+            max_len = (
+                # Calculate max_len from the CPU copy to avoid D2H transfer.
+                max(forward_batch.extend_seq_lens_cpu)
+                if forward_batch.forward_mode.is_extend()
+                else 1
+            )
+
             seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
             seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
-            max_len = int(torch.max(seg_lens))
-            weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
 
-            lora_ranks = torch.zeros(
-                (self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
-            )
-            scalings = torch.zeros(
-                (self.max_loras_per_batch,), dtype=torch.float, device="cuda"
-            )
-            for i, lora_path in enumerate(forward_batch.lora_paths):
-                weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
-                if lora_path is not None:
-                    lora = self.loras[lora_path]
-                    lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
-                    scalings[weight_indices[i]] = lora.scaling
             batch_info = LoRABatchInfo(
                 bs=bs,
                 seg_lens=seg_lens,
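The new transfer_adapter_info helper gathers weight indices, LoRA ranks, and scalings on the host, stages them in pinned CPU tensors, and issues non-blocking copies into preallocated device tensors, so prepare_lora_batch no longer blocks on the host-to-device transfer; the non-CUDA-graph branch also reads max_len from forward_batch.extend_seq_lens_cpu rather than reducing a device tensor. A self-contained sketch of the same pinned-memory pattern, using hypothetical buffer names rather than the sglang data structures:

```python
import torch

def async_host_to_device(values, device_buffer):
    """Stage a Python list in pinned host memory and copy it into a
    preallocated device buffer without forcing a synchronization.
    (Illustrative helper, not part of sglang.)"""
    staging = torch.tensor(
        values,
        dtype=device_buffer.dtype,
        pin_memory=torch.cuda.is_available(),
        device="cpu",
    )
    # non_blocking=True only overlaps the copy when the source is pinned
    # host memory and the destination lives on the GPU.
    device_buffer[: len(values)].copy_(staging, non_blocking=True)

if torch.cuda.is_available():
    weight_indices_buf = torch.zeros(8, dtype=torch.int32, device="cuda")
    async_host_to_device([3, 1, 0, 2], weight_indices_buf)
    # Synchronize only when the result is actually read back on the host.
    torch.cuda.current_stream().synchronize()
    print(weight_indices_buf[:4].tolist())  # [3, 1, 0, 2]
```

Without pinned staging buffers, a host-to-device copy is effectively synchronous for the calling thread, which is the stall the "Use pinned memory" comment in the diff refers to.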
9 changes: 4 additions & 5 deletions python/sglang/srt/lora/mem_pool.py
@@ -132,22 +132,21 @@ def get_available_buffer_slot():
             for buffer_id in range(self.max_loras_per_batch):
                 # Prioritize empty slots
                 if self.buffer_id_to_uid[buffer_id] == "":
-                    return buffer_id, ""
+                    return buffer_id
 
             for buffer_id in range(self.max_loras_per_batch):
                 # Evict unneeded lora
                 if self.buffer_id_to_uid[buffer_id] not in cur_uids:
-                    return buffer_id, self.buffer_id_to_uid[buffer_id]
+                    self.uid_to_buffer_id.pop(self.buffer_id_to_uid[buffer_id])
+                    return buffer_id
 
             raise ValueError(
                 "No available buffer slots found. Please ensure the number of active loras is less than max_loras_per_batch."
             )
 
         for uid in cur_uids:
             if uid not in self.uid_to_buffer_id:
-                buffer_id, evicted_lora_uid = get_available_buffer_slot()
-                if evicted_lora_uid != "":
-                    self.uid_to_buffer_id.pop(evicted_lora_uid)
+                buffer_id = get_available_buffer_slot()
                 self.load_lora_weight_to_buffer(
                     uid, buffer_id, lora_adapters.get(uid, None)
                 )
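On the mem_pool.py side, get_available_buffer_slot now removes the evicted adapter's uid from uid_to_buffer_id itself and returns only the buffer id, so the caller no longer needs to unpack an (id, evicted_uid) pair and clean up afterwards. A toy sketch of the simplified allocation and eviction flow; the class below is a stub for illustration, not the real LoRAMemoryPool:

```python
# Toy reproduction of the slot-allocation logic after this change; the names
# mirror the diff, but the surrounding bookkeeping is simplified.
class ToyLoRAPool:
    def __init__(self, max_loras_per_batch: int):
        self.max_loras_per_batch = max_loras_per_batch
        self.buffer_id_to_uid = [""] * max_loras_per_batch
        self.uid_to_buffer_id = {}

    def get_available_buffer_slot(self, cur_uids):
        # Prioritize empty slots.
        for buffer_id in range(self.max_loras_per_batch):
            if self.buffer_id_to_uid[buffer_id] == "":
                return buffer_id
        # Otherwise evict an adapter that the current batch does not need,
        # dropping its mapping here instead of in the caller.
        for buffer_id in range(self.max_loras_per_batch):
            if self.buffer_id_to_uid[buffer_id] not in cur_uids:
                self.uid_to_buffer_id.pop(self.buffer_id_to_uid[buffer_id])
                return buffer_id
        raise ValueError("No available buffer slots found.")

    def prepare(self, cur_uids):
        for uid in cur_uids:
            if uid not in self.uid_to_buffer_id:
                buffer_id = self.get_available_buffer_slot(cur_uids)
                self.buffer_id_to_uid[buffer_id] = uid
                self.uid_to_buffer_id[uid] = buffer_id

pool = ToyLoRAPool(max_loras_per_batch=2)
pool.prepare({"lora_a", "lora_b"})
pool.prepare({"lora_b", "lora_c"})    # evicts lora_a and reuses its slot
print(sorted(pool.uid_to_buffer_id))  # ['lora_b', 'lora_c']
```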