Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions python/sglang/srt/lora/backend/ascend_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ class AscendLoRABackend(BaseLoRABackend):

def __init__(
self,
max_loras_per_batch: int,
max_loras_total: int,
device: torch.device,
**kwargs,
):
super().__init__(max_loras_per_batch, device)
super().__init__(max_loras_total, device)

def run_lora_a_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
Expand Down Expand Up @@ -268,19 +268,19 @@ def prepare_lora_batch(
(bs,), dtype=torch.int32, device=self.device
),
lora_ranks=torch.empty(
(self.max_loras_per_batch,), dtype=torch.int32, device=self.device
(self.max_loras_total,), dtype=torch.int32, device=self.device
),
scalings=torch.empty(
(self.max_loras_per_batch,), dtype=torch.float, device=self.device
(self.max_loras_total,), dtype=torch.float, device=self.device
),
permutation=None,
)

# Copy to device asynchronously
batch_info.lora_ranks[: self.max_loras_per_batch].copy_(
batch_info.lora_ranks[: self.max_loras_total].copy_(
lora_ranks_tensor, non_blocking=True
)
batch_info.scalings[: self.max_loras_per_batch].copy_(
batch_info.scalings[: self.max_loras_total].copy_(
scalings_tensor, non_blocking=True
)
batch_info.weight_indices[:bs].copy_(weight_indices_tensor, non_blocking=True)
Expand Down
5 changes: 3 additions & 2 deletions python/sglang/srt/lora/backend/base_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@ class BaseLoRABackend:
Args:
max_loras_total: maximum number of lora weights the backend sizes its
buffers for, i.e. max_loras_per_batch (different lora weights that can
be applied in a single forward batch) plus max_loras_prefetch (lora
weights in a prefetch batch).
device: the device where the backend runs.
"""

def __init__(self, max_loras_per_batch: int, device: torch.device):
self.max_loras_per_batch = max_loras_per_batch
def __init__(self, max_loras_total: int, device: torch.device):
self.max_loras_total = max_loras_total
self.device = device

def run_lora_a_sgemm(
Expand Down
16 changes: 8 additions & 8 deletions python/sglang/srt/lora/backend/chunked_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ class ChunkedSgmvLoRABackend(BaseLoRABackend):

def __init__(
self,
max_loras_per_batch: int,
max_loras_total: int,
device: torch.device,
server_args: ServerArgs,
):
super().__init__(max_loras_per_batch, device)
super().__init__(max_loras_total, device)
self.max_chunk_size = server_args.max_lora_chunk_size

def run_lora_a_sgemm(
Expand Down Expand Up @@ -175,8 +175,8 @@ def init_cuda_graph_batch_info(
seg_indptr=torch.zeros(max_num_segments + 1, dtype=torch.int32),
weight_indices=torch.zeros(max_num_segments, dtype=torch.int32),
permutation=torch.zeros(max_num_tokens, dtype=torch.int32),
lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32),
scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
lora_ranks=torch.zeros(self.max_loras_total, dtype=torch.int32),
scalings=torch.zeros(self.max_loras_total, dtype=torch.float),
num_segments=None, # Set per batch
max_len=None, # Not used in CSGMV backend
)
Expand Down Expand Up @@ -222,10 +222,10 @@ def prepare_lora_batch(
(num_segments,), dtype=torch.int32, device=self.device
),
lora_ranks=torch.empty(
(self.max_loras_per_batch,), dtype=torch.int32, device=self.device
(self.max_loras_total,), dtype=torch.int32, device=self.device
),
scalings=torch.empty(
(self.max_loras_per_batch,), dtype=torch.float, device=self.device
(self.max_loras_total,), dtype=torch.float, device=self.device
),
permutation=torch.empty(
(len(permutation),), dtype=torch.int32, device=self.device
Expand All @@ -240,10 +240,10 @@ def prepare_lora_batch(
batch_info.max_len = chunk_size

# Copy to device asynchronously
batch_info.lora_ranks[: self.max_loras_per_batch].copy_(
batch_info.lora_ranks[: self.max_loras_total].copy_(
lora_ranks_tensor, non_blocking=True
)
batch_info.scalings[: self.max_loras_per_batch].copy_(
batch_info.scalings[: self.max_loras_total].copy_(
scalings_tensor, non_blocking=True
)
batch_info.weight_indices[:num_segments].copy_(
Expand Down
16 changes: 8 additions & 8 deletions python/sglang/srt/lora/backend/triton_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ class TritonLoRABackend(BaseLoRABackend):

def __init__(
self,
max_loras_per_batch: int,
max_loras_total: int,
device: torch.device,
**kwargs,
):
super().__init__(max_loras_per_batch, device)
super().__init__(max_loras_total, device)

def run_lora_a_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
Expand Down Expand Up @@ -110,8 +110,8 @@ def init_cuda_graph_batch_info(
seg_indptr=torch.empty(max_bs_in_cuda_graph + 1, dtype=torch.int32),
max_len=num_tokens_per_bs,
weight_indices=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32),
scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
lora_ranks=torch.zeros(self.max_loras_total, dtype=torch.int32),
scalings=torch.zeros(self.max_loras_total, dtype=torch.float),
permutation=None,
)

Expand Down Expand Up @@ -177,19 +177,19 @@ def prepare_lora_batch(
(bs,), dtype=torch.int32, device=self.device
),
lora_ranks=torch.empty(
(self.max_loras_per_batch,), dtype=torch.int64, device=self.device
(self.max_loras_total,), dtype=torch.int64, device=self.device
),
scalings=torch.empty(
(self.max_loras_per_batch,), dtype=torch.float, device=self.device
(self.max_loras_total,), dtype=torch.float, device=self.device
),
permutation=None,
)

# Copy to device asynchronously
batch_info.lora_ranks[: self.max_loras_per_batch].copy_(
batch_info.lora_ranks[: self.max_loras_total].copy_(
lora_ranks_tensor, non_blocking=True
)
batch_info.scalings[: self.max_loras_per_batch].copy_(
batch_info.scalings[: self.max_loras_total].copy_(
scalings_tensor, non_blocking=True
)
batch_info.weight_indices[:bs].copy_(weight_indices_tensor, non_blocking=True)
Expand Down
16 changes: 12 additions & 4 deletions python/sglang/srt/lora/lora_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(
base_model: torch.nn.Module,
base_hf_config: AutoConfig,
max_loras_per_batch: int,
max_loras_prefetch: int,
load_config: LoadConfig,
dtype: torch.dtype,
lora_backend: str = "triton",
Expand All @@ -69,6 +70,7 @@ def __init__(
self.base_model: torch.nn.Module = base_model
self.base_hf_config: AutoConfig = base_hf_config
self.max_loras_per_batch: int = max_loras_per_batch
self.max_loras_prefetch: int = max_loras_prefetch
self.load_config: LoadConfig = load_config
self.dtype: torch.dtype = dtype
self.device: torch.device = next(self.base_model.parameters()).device
Expand All @@ -81,8 +83,9 @@ def __init__(
# LoRA backend for running sgemm kernels
logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
backend_type = get_backend_from_name(lora_backend)
self.max_loras_total = max_loras_per_batch + max_loras_prefetch
self.lora_backend: BaseLoRABackend = backend_type(
max_loras_per_batch=max_loras_per_batch,
max_loras_total=self.max_loras_total,
device=self.device,
server_args=server_args,
)
Expand Down Expand Up @@ -241,7 +244,7 @@ def validate_lora_batch(self, lora_ids: set[str]) -> bool:

return required_slots <= mem_pool_vacancy

def prepare_lora_batch(self, forward_batch: ForwardBatch):
def prepare_lora_batch(self, forward_batch: ForwardBatch, prefetch=False):
# Load active loras into lora memory pool
cur_uids = set(forward_batch.lora_ids)

Expand All @@ -251,8 +254,12 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch):
lora_adapters=self.loras,
lora_modules=self.lora_modules,
lora_refs=self.lora_refs.copy(), # copy snapshot of current lora_refs to avoid mutation during the batch preparation.
prefetch=prefetch,
)

if prefetch:
return

# set up batch info shared by all lora modules
bs = forward_batch.batch_size

Expand All @@ -263,8 +270,8 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch):
)

weight_indices = [0] * len(forward_batch.lora_ids)
lora_ranks = [0] * self.max_loras_per_batch
scalings = [0] * self.max_loras_per_batch
lora_ranks = [0] * self.max_loras_total
scalings = [0] * self.max_loras_total
for i, uid in enumerate(forward_batch.lora_ids):
weight_indices[i] = self.memory_pool.get_buffer_id(uid)
if uid is not None:
Expand Down Expand Up @@ -416,6 +423,7 @@ def init_memory_pool(self):
self.memory_pool = LoRAMemoryPool(
base_hf_config=self.base_hf_config,
max_loras_per_batch=self.max_loras_per_batch,
max_loras_prefetch=self.max_loras_prefetch,
dtype=self.dtype,
tp_size=self.tp_size,
tp_rank=self.tp_rank,
Expand Down
Loading