Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 31 additions & 15 deletions python/sglang/srt/lora/backend/ascend_backend.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Optional

import torch

from sglang.srt.lora.backend.base_backend import BaseLoRABackend
Expand Down Expand Up @@ -204,24 +202,41 @@ def run_gate_up_lora(
return output_tensor

def init_cuda_graph_batch_info(
self, cuda_graph_batch_info: LoRABatchInfo, max_bs_in_cuda_graph: int
self,
max_bs_in_cuda_graph: int,
num_tokens_per_bs: int,
):
# Initialize seg_lens and seg_indptr for CUDA graph as they remain constant
# across batches.
cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph].fill_(1)
torch.cumsum(
cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph],
dim=0,
out=cuda_graph_batch_info.seg_indptr[1 : max_bs_in_cuda_graph + 1],
)
with torch.device("npu"):
self.npu_graph_batch_info = LoRABatchInfo(
bs=max_bs_in_cuda_graph,
use_cuda_graph=True,
num_segments=None,
seg_lens=torch.full(
(max_bs_in_cuda_graph,), num_tokens_per_bs, dtype=torch.int32
),
seg_indptr=torch.empty(max_bs_in_cuda_graph + 1, dtype=torch.int32),
max_len=num_tokens_per_bs,
weight_indices=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32),
scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
permutation=None,
)

# Initialize seg_indptr for NPU graph as they remain constant
# across batches.
torch.cumsum(
self.npu_graph_batch_info.seg_lens[:max_bs_in_cuda_graph],
dim=0,
out=self.npu_graph_batch_info.seg_indptr[1 : max_bs_in_cuda_graph + 1],
)

def prepare_lora_batch(
self,
forward_batch: ForwardBatch,
weight_indices: list[int],
lora_ranks: list[int],
scalings: list[float],
batch_info: Optional[LoRABatchInfo] = None,
use_cuda_graph: bool,
):
# Use pinned memory to avoid synchronizations during host-to-device transfer
weight_indices_tensor = torch.tensor(
Expand All @@ -236,10 +251,11 @@ def prepare_lora_batch(

bs = forward_batch.batch_size

if batch_info is not None:
if use_cuda_graph:
assert (
batch_info.use_cuda_graph
), "batch_info.use_cuda_graph must be True when batch_info is provided"
self.npu_graph_batch_info is not None
), "NPU Graph batch info is not initialized."
batch_info = self.npu_graph_batch_info
batch_info.bs = forward_batch.batch_size
batch_info.num_segments = forward_batch.batch_size
else:
Expand Down
7 changes: 7 additions & 0 deletions python/sglang/srt/lora/backend/lora_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@ def create_ascend_backend():
return AscendLoRABackend


@register_lora_backend("torch_native")
def create_torch_native_backend():
    """Factory registered under the "torch_native" key; returns the pure-PyTorch backend class."""
    # Imported lazily so the backend module (and its torch_ops dependency) is
    # only loaded when this backend is actually selected.
    from sglang.srt.lora.backend.torch_backend import TorchNativeLoRABackend

    return TorchNativeLoRABackend


@register_lora_backend("flashinfer")
def create_flashinfer_backend():
raise ValueError(
Expand Down
297 changes: 297 additions & 0 deletions python/sglang/srt/lora/backend/torch_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,297 @@
import torch

from sglang.srt.lora.backend.base_backend import BaseLoRABackend
from sglang.srt.lora.torch_ops import sgmv_expand, sgmv_expand_slice, sgmv_shrink
from sglang.srt.lora.utils import LoRABatchInfo
from sglang.srt.model_executor.forward_batch_info import ForwardBatch


class TorchNativeLoRABackend(BaseLoRABackend):
    """LoRA backend implemented with pure-PyTorch SGMV ops (torch_ops), no custom kernels."""

    # Registry key used by the backend factory (see lora_registry).
    name = "torch_native"

    def __init__(
        self,
        max_loras_per_batch: int,
        device: torch.device,
        **kwargs,
    ):
        # No backend-specific state: all batch bookkeeping lives on the base class.
        super().__init__(max_loras_per_batch, device)

def run_lora_a_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:

total_seq_len, _ = x.shape
_, weight_out_dim, _ = weights.shape

output_tensor = torch.zeros(
(total_seq_len, weight_out_dim), dtype=x.dtype, device=x.device
)
sgmv_shrink(
x,
weights,
output_tensor,
self.batch_info.seg_lens,
self.batch_info.weight_indices,
1.0,
)
scaling = torch.repeat_interleave(
self.batch_info.scalings[self.batch_info.weight_indices],
self.batch_info.seg_lens,
output_size=total_seq_len,
).unsqueeze(-1)
output_tensor = output_tensor * scaling

return output_tensor

def run_lora_b_sgemm(
self,
x: torch.Tensor,
weights: torch.Tensor,
base_output: torch.Tensor = None,
*args,
**kwargs,
) -> torch.Tensor:
total_seq_len, _ = x.shape
_, weight_out_dim, _ = weights.shape

if base_output is None:
output_tensor = torch.zeros(
(total_seq_len, weight_out_dim), device=x.device, dtype=x.dtype
)
else:
output_tensor = base_output

sgmv_expand(
x,
weights,
output_tensor,
self.batch_info.seg_lens,
self.batch_info.weight_indices,
True,
)

return output_tensor

    def run_qkv_lora(
        self,
        x: torch.Tensor,
        qkv_lora_a: torch.Tensor,
        qkv_lora_b: torch.Tensor,
        output_offset: torch.Tensor,
        output_offset_cpu: torch.Tensor,
        max_qkv_out_dim: int,
        base_output: torch.Tensor = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Fused Q/K/V LoRA: one shrink over the stacked A weights, then three
        slice-wise expands into the q/k/v regions of the output.

        output_offset_cpu holds the 4 cumulative boundaries of the q/k/v slices
        in the output's last dimension; it is read on the host to avoid D2H
        syncs.  `output_offset` and `max_qkv_out_dim` are accepted for interface
        compatibility but unused here.  Returns `base_output` (accumulated into)
        when provided, otherwise a fresh (total_seq_len, weight_out_dim) tensor.
        """
        num_slices = 3
        assert isinstance(qkv_lora_b, torch.Tensor)

        total_seq_len, _ = x.shape
        _, weight_intermediate_dim, _ = qkv_lora_a.shape
        _, weight_out_dim, _ = qkv_lora_b.shape
        # The A weights stack all three slices' ranks along dim 1.
        max_rank = weight_intermediate_dim // num_slices

        if base_output is None:
            output_tensor = torch.zeros(
                (total_seq_len, weight_out_dim), device=x.device, dtype=x.dtype
            )
        else:
            output_tensor = base_output

        lora_a_output = torch.zeros(
            total_seq_len, weight_intermediate_dim, dtype=x.dtype, device=x.device
        )
        sgmv_shrink(
            x,
            qkv_lora_a,
            lora_a_output,
            self.batch_info.seg_lens,
            self.batch_info.weight_indices,
            1.0,
        )
        # Broadcast each segment's adapter scaling to one value per token.
        scaling = torch.repeat_interleave(
            self.batch_info.scalings[self.batch_info.weight_indices],
            self.batch_info.seg_lens,
            output_size=total_seq_len,
        ).unsqueeze(-1)
        lora_a_output = lora_a_output * scaling

        # Expand each slice's rank columns into its [slice_offset, slice_offset_next)
        # region of the output.
        for slice_id in range(num_slices):
            slice_offset = output_offset_cpu[slice_id]
            slice_offset_next = output_offset_cpu[slice_id + 1]
            slice_size = slice_offset_next - slice_offset
            sgmv_expand_slice(
                lora_a_output[:, (max_rank * slice_id) : (max_rank * (slice_id + 1))],
                qkv_lora_b[:, slice_offset:slice_offset_next],
                output_tensor,
                self.batch_info.seg_lens,
                self.batch_info.weight_indices,
                slice_offset,
                slice_size,
                True,
            )

        return output_tensor

def run_gate_up_lora(
self,
x: torch.Tensor,
gate_up_lora_a: torch.Tensor,
gate_up_lora_b: torch.Tensor,
base_output: torch.Tensor = None,
*args,
**kwargs,
) -> torch.Tensor:

num_slices = 2
assert isinstance(gate_up_lora_b, torch.Tensor)

total_seq_len, _ = x.shape
_, weight_intermediate_dim, _ = gate_up_lora_a.shape
_, weight_out_dim, _ = gate_up_lora_b.shape
slice_size = weight_out_dim // num_slices
max_rank = weight_intermediate_dim // num_slices

if base_output is None:
output_tensor = torch.zeros(
(total_seq_len, weight_out_dim), device=x.device, dtype=x.dtype
)
else:
output_tensor = base_output

lora_a_output = torch.zeros(
total_seq_len, weight_intermediate_dim, dtype=x.dtype, device=x.device
)
sgmv_shrink(
x,
gate_up_lora_a,
lora_a_output,
self.batch_info.seg_lens,
self.batch_info.weight_indices,
1.0,
)
scaling = torch.repeat_interleave(
self.batch_info.scalings[self.batch_info.weight_indices],
self.batch_info.seg_lens,
output_size=total_seq_len,
).unsqueeze(-1)
lora_a_output = lora_a_output * scaling

slice_offset = 0
for slice_id in range(num_slices):
sgmv_expand_slice(
lora_a_output[:, (max_rank * slice_id) : (max_rank * (slice_id + 1))],
gate_up_lora_b[:, slice_offset : slice_offset + slice_size],
output_tensor,
self.batch_info.seg_lens,
self.batch_info.weight_indices,
slice_offset,
slice_size,
True,
)
slice_offset += slice_size

return output_tensor

def init_cuda_graph_batch_info(
self,
max_bs_in_cuda_graph: int,
num_tokens_per_bs: int,
):
with torch.device("cuda"):
self.cuda_graph_batch_info = LoRABatchInfo(
bs=max_bs_in_cuda_graph,
use_cuda_graph=True,
num_segments=None,
seg_lens=torch.full(
(max_bs_in_cuda_graph,), num_tokens_per_bs, dtype=torch.int32
),
seg_indptr=torch.empty(max_bs_in_cuda_graph + 1, dtype=torch.int32),
max_len=num_tokens_per_bs,
weight_indices=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32),
scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
permutation=None,
)

# Initialize seg_indptr for CUDA graph as they remain constant
# across batches.
torch.cumsum(
self.cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph],
dim=0,
out=self.cuda_graph_batch_info.seg_indptr[1 : max_bs_in_cuda_graph + 1],
)

def prepare_lora_batch(
self,
forward_batch: ForwardBatch,
weight_indices: list[int],
lora_ranks: list[int],
scalings: list[float],
use_cuda_graph: bool,
):
# Use pinned memory to avoid synchronizations during host-to-device transfer
weight_indices_tensor = torch.tensor(
weight_indices, dtype=torch.int32, pin_memory=True, device="cpu"
)
lora_ranks_tensor = torch.tensor(
lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
)
scalings_tensor = torch.tensor(
scalings, dtype=torch.float, pin_memory=True, device="cpu"
)

bs = forward_batch.batch_size

if use_cuda_graph:
assert (
self.cuda_graph_batch_info is not None
), "CUDA Graph batch info is not initialized."
batch_info = self.cuda_graph_batch_info
batch_info.bs = forward_batch.batch_size
batch_info.num_segments = forward_batch.batch_size
else:
max_len = (
# Calculate max_len from the CPU copy to avoid D2H transfer.
max(forward_batch.extend_seq_lens_cpu)
if forward_batch.forward_mode.is_extend()
else 1
)
seg_lens = (
forward_batch.extend_seq_lens
if forward_batch.forward_mode.is_extend()
else torch.ones(bs, dtype=torch.int32, device=self.device)
)
seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)

batch_info = LoRABatchInfo(
bs=forward_batch.batch_size,
num_segments=forward_batch.batch_size,
max_len=max_len,
use_cuda_graph=False,
seg_lens=seg_lens,
seg_indptr=seg_indptr,
weight_indices=torch.empty(
(bs,), dtype=torch.int32, device=self.device
),
lora_ranks=torch.empty(
(self.max_loras_per_batch,), dtype=torch.int32, device=self.device
),
Comment on lines +280 to +282
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The lora_ranks tensor is being initialized with dtype=torch.int64, while the source lora_ranks_tensor is torch.int32. Other backends like ascend_backend and csgmv consistently use torch.int32. For consistency and to potentially save memory, it would be better to use torch.int32 here as well.

Suggested change
lora_ranks=torch.empty(
(self.max_loras_per_batch,), dtype=torch.int64, device=self.device
),
lora_ranks=torch.empty(
(self.max_loras_per_batch,), dtype=torch.int32, device=self.device
),

scalings=torch.empty(
(self.max_loras_per_batch,), dtype=torch.float, device=self.device
),
permutation=None,
)

# Copy to device asynchronously
batch_info.lora_ranks[: self.max_loras_per_batch].copy_(
lora_ranks_tensor, non_blocking=True
)
batch_info.scalings[: self.max_loras_per_batch].copy_(
scalings_tensor, non_blocking=True
)
batch_info.weight_indices[:bs].copy_(weight_indices_tensor, non_blocking=True)
self.batch_info = batch_info
7 changes: 7 additions & 0 deletions python/sglang/srt/lora/torch_ops/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .lora_ops import sgmv_expand, sgmv_expand_slice, sgmv_shrink

__all__ = [
"sgmv_expand",
"sgmv_expand_slice",
"sgmv_shrink",
]
Loading
Loading