diff --git a/python/sglang/srt/lora/backend/ascend_backend.py b/python/sglang/srt/lora/backend/ascend_backend.py index 4278b340e489..2cffea189730 100644 --- a/python/sglang/srt/lora/backend/ascend_backend.py +++ b/python/sglang/srt/lora/backend/ascend_backend.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch from sglang.srt.lora.backend.base_backend import BaseLoRABackend @@ -204,16 +202,33 @@ def run_gate_up_lora( return output_tensor def init_cuda_graph_batch_info( - self, cuda_graph_batch_info: LoRABatchInfo, max_bs_in_cuda_graph: int + self, + max_bs_in_cuda_graph: int, + num_tokens_per_bs: int, ): - # Initialize seg_lens and seg_indptr for CUDA graph as they remain constant - # across batches. - cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph].fill_(1) - torch.cumsum( - cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph], - dim=0, - out=cuda_graph_batch_info.seg_indptr[1 : max_bs_in_cuda_graph + 1], - ) + with torch.device("npu"): + self.npu_graph_batch_info = LoRABatchInfo( + bs=max_bs_in_cuda_graph, + use_cuda_graph=True, + num_segments=None, + seg_lens=torch.full( + (max_bs_in_cuda_graph,), num_tokens_per_bs, dtype=torch.int32 + ), + seg_indptr=torch.empty(max_bs_in_cuda_graph + 1, dtype=torch.int32), + max_len=num_tokens_per_bs, + weight_indices=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32), + lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32), + scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float), + permutation=None, + ) + + # Initialize seg_indptr for NPU graph as they remain constant + # across batches. + torch.cumsum( + self.npu_graph_batch_info.seg_lens[:max_bs_in_cuda_graph], + dim=0, + out=self.npu_graph_batch_info.seg_indptr[1 : max_bs_in_cuda_graph + 1], + ) def prepare_lora_batch( self, @@ -221,7 +236,7 @@ def prepare_lora_batch( weight_indices: list[int], lora_ranks: list[int], scalings: list[float], - batch_info: Optional[LoRABatchInfo] = None, + use_cuda_graph: bool, ): # Use pinned memory to avoid synchronizations during host-to-device transfer weight_indices_tensor = torch.tensor( @@ -236,10 +251,11 @@ def prepare_lora_batch( bs = forward_batch.batch_size - if batch_info is not None: + if use_cuda_graph: assert ( - batch_info.use_cuda_graph - ), "batch_info.use_cuda_graph must be True when batch_info is provided" + self.npu_graph_batch_info is not None + ), "NPU Graph batch info is not initialized." + batch_info = self.npu_graph_batch_info batch_info.bs = forward_batch.batch_size batch_info.num_segments = forward_batch.batch_size else: diff --git a/python/sglang/srt/lora/backend/lora_registry.py b/python/sglang/srt/lora/backend/lora_registry.py index c3dd77888616..160ca8e3ee3d 100644 --- a/python/sglang/srt/lora/backend/lora_registry.py +++ b/python/sglang/srt/lora/backend/lora_registry.py @@ -36,6 +36,13 @@ def create_ascend_backend(): return AscendLoRABackend +@register_lora_backend("torch_native") +def create_torch_native_backend(): + from sglang.srt.lora.backend.torch_backend import TorchNativeLoRABackend + + return TorchNativeLoRABackend + + @register_lora_backend("flashinfer") def create_flashinfer_backend(): raise ValueError( diff --git a/python/sglang/srt/lora/backend/torch_backend.py b/python/sglang/srt/lora/backend/torch_backend.py new file mode 100644 index 000000000000..af467bc81f68 --- /dev/null +++ b/python/sglang/srt/lora/backend/torch_backend.py @@ -0,0 +1,297 @@ +import torch + +from sglang.srt.lora.backend.base_backend import BaseLoRABackend +from sglang.srt.lora.torch_ops import sgmv_expand, sgmv_expand_slice, sgmv_shrink +from sglang.srt.lora.utils import LoRABatchInfo +from sglang.srt.model_executor.forward_batch_info import ForwardBatch + + +class TorchNativeLoRABackend(BaseLoRABackend): + name = "torch_native" + + def __init__( + self, + max_loras_per_batch: int, + device: torch.device, + **kwargs, + ): + super().__init__(max_loras_per_batch, device) + + def run_lora_a_sgemm( + self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs + ) -> torch.Tensor: + + total_seq_len, _ = x.shape + _, weight_out_dim, _ = weights.shape + + output_tensor = torch.zeros( + (total_seq_len, weight_out_dim), dtype=x.dtype, device=x.device + ) + sgmv_shrink( + x, + weights, + output_tensor, + self.batch_info.seg_lens, + self.batch_info.weight_indices, + 1.0, + ) + scaling = torch.repeat_interleave( + self.batch_info.scalings[self.batch_info.weight_indices], + self.batch_info.seg_lens, + output_size=total_seq_len, + ).unsqueeze(-1) + output_tensor = output_tensor * scaling + + return output_tensor + + def run_lora_b_sgemm( + self, + x: torch.Tensor, + weights: torch.Tensor, + base_output: torch.Tensor = None, + *args, + **kwargs, + ) -> torch.Tensor: + total_seq_len, _ = x.shape + _, weight_out_dim, _ = weights.shape + + if base_output is None: + output_tensor = torch.zeros( + (total_seq_len, weight_out_dim), device=x.device, dtype=x.dtype + ) + else: + output_tensor = base_output + + sgmv_expand( + x, + weights, + output_tensor, + self.batch_info.seg_lens, + self.batch_info.weight_indices, + True, + ) + + return output_tensor + + def run_qkv_lora( + self, + x: torch.Tensor, + qkv_lora_a: torch.Tensor, + qkv_lora_b: torch.Tensor, + output_offset: torch.Tensor, + output_offset_cpu: torch.Tensor, + max_qkv_out_dim: int, + base_output: torch.Tensor = None, + *args, + **kwargs, + ) -> torch.Tensor: + num_slices = 3 + assert isinstance(qkv_lora_b, torch.Tensor) + + total_seq_len, _ = x.shape + _, weight_intermediate_dim, _ = qkv_lora_a.shape + _, weight_out_dim, _ = qkv_lora_b.shape + max_rank = weight_intermediate_dim // num_slices + + if base_output is None: + output_tensor = torch.zeros( + (total_seq_len, weight_out_dim), device=x.device, dtype=x.dtype + ) + else: + output_tensor = base_output + + lora_a_output = torch.zeros( + total_seq_len, weight_intermediate_dim, dtype=x.dtype, device=x.device + ) + sgmv_shrink( + x, + qkv_lora_a, + lora_a_output, + self.batch_info.seg_lens, + self.batch_info.weight_indices, + 1.0, + ) + scaling = torch.repeat_interleave( + self.batch_info.scalings[self.batch_info.weight_indices], + self.batch_info.seg_lens, + output_size=total_seq_len, + ).unsqueeze(-1) + lora_a_output = lora_a_output * scaling + + for slice_id in range(num_slices): + slice_offset = output_offset_cpu[slice_id] + slice_offset_next = output_offset_cpu[slice_id + 1] + slice_size = slice_offset_next - slice_offset + sgmv_expand_slice( + lora_a_output[:, (max_rank * slice_id) : (max_rank * (slice_id + 1))], + qkv_lora_b[:, slice_offset:slice_offset_next], + output_tensor, + self.batch_info.seg_lens, + self.batch_info.weight_indices, + slice_offset, + slice_size, + True, + ) + + return output_tensor + + def run_gate_up_lora( + self, + x: torch.Tensor, + gate_up_lora_a: torch.Tensor, + gate_up_lora_b: torch.Tensor, + base_output: torch.Tensor = None, + *args, + **kwargs, + ) -> torch.Tensor: + + num_slices = 2 + assert isinstance(gate_up_lora_b, torch.Tensor) + + total_seq_len, _ = x.shape + _, weight_intermediate_dim, _ = gate_up_lora_a.shape + _, weight_out_dim, _ = gate_up_lora_b.shape + slice_size = weight_out_dim // num_slices + max_rank = weight_intermediate_dim // num_slices + + if base_output is None: + output_tensor = torch.zeros( + (total_seq_len, weight_out_dim), device=x.device, dtype=x.dtype + ) + else: + output_tensor = base_output + + lora_a_output = torch.zeros( + total_seq_len, weight_intermediate_dim, dtype=x.dtype, device=x.device + ) + sgmv_shrink( + x, + gate_up_lora_a, + lora_a_output, + self.batch_info.seg_lens, + self.batch_info.weight_indices, + 1.0, + ) + scaling = torch.repeat_interleave( + self.batch_info.scalings[self.batch_info.weight_indices], + self.batch_info.seg_lens, + output_size=total_seq_len, + ).unsqueeze(-1) + lora_a_output = lora_a_output * scaling + + slice_offset = 0 + for slice_id in range(num_slices): + sgmv_expand_slice( + lora_a_output[:, (max_rank * slice_id) : (max_rank * (slice_id + 1))], + gate_up_lora_b[:, slice_offset : slice_offset + slice_size], + output_tensor, + self.batch_info.seg_lens, + self.batch_info.weight_indices, + slice_offset, + slice_size, + True, + ) + slice_offset += slice_size + + return output_tensor + + def init_cuda_graph_batch_info( + self, + max_bs_in_cuda_graph: int, + num_tokens_per_bs: int, + ): + with torch.device("cuda"): + self.cuda_graph_batch_info = LoRABatchInfo( + bs=max_bs_in_cuda_graph, + use_cuda_graph=True, + num_segments=None, + seg_lens=torch.full( + (max_bs_in_cuda_graph,), num_tokens_per_bs, dtype=torch.int32 + ), + seg_indptr=torch.empty(max_bs_in_cuda_graph + 1, dtype=torch.int32), + max_len=num_tokens_per_bs, + weight_indices=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32), + lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32), + scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float), + permutation=None, + ) + + # Initialize seg_indptr for CUDA graph as they remain constant + # across batches. + torch.cumsum( + self.cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph], + dim=0, + out=self.cuda_graph_batch_info.seg_indptr[1 : max_bs_in_cuda_graph + 1], + ) + + def prepare_lora_batch( + self, + forward_batch: ForwardBatch, + weight_indices: list[int], + lora_ranks: list[int], + scalings: list[float], + use_cuda_graph: bool, + ): + # Use pinned memory to avoid synchronizations during host-to-device transfer + weight_indices_tensor = torch.tensor( + weight_indices, dtype=torch.int32, pin_memory=True, device="cpu" + ) + lora_ranks_tensor = torch.tensor( + lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu" + ) + scalings_tensor = torch.tensor( + scalings, dtype=torch.float, pin_memory=True, device="cpu" + ) + + bs = forward_batch.batch_size + + if use_cuda_graph: + assert ( + self.cuda_graph_batch_info is not None + ), "CUDA Graph batch info is not initialized." + batch_info = self.cuda_graph_batch_info + batch_info.bs = forward_batch.batch_size + batch_info.num_segments = forward_batch.batch_size + else: + max_len = ( + # Calculate max_len from the CPU copy to avoid D2H transfer. + max(forward_batch.extend_seq_lens_cpu) + if forward_batch.forward_mode.is_extend() + else 1 + ) + seg_lens = ( + forward_batch.extend_seq_lens + if forward_batch.forward_mode.is_extend() + else torch.ones(bs, dtype=torch.int32, device=self.device) + ) + seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device) + seg_indptr[1:] = torch.cumsum(seg_lens, dim=0) + + batch_info = LoRABatchInfo( + bs=forward_batch.batch_size, + num_segments=forward_batch.batch_size, + max_len=max_len, + use_cuda_graph=False, + seg_lens=seg_lens, + seg_indptr=seg_indptr, + weight_indices=torch.empty( + (bs,), dtype=torch.int32, device=self.device + ), + lora_ranks=torch.empty( + (self.max_loras_per_batch,), dtype=torch.int32, device=self.device + ), + scalings=torch.empty( + (self.max_loras_per_batch,), dtype=torch.float, device=self.device + ), + permutation=None, + ) + + # Copy to device asynchronously + batch_info.lora_ranks[: self.max_loras_per_batch].copy_( + lora_ranks_tensor, non_blocking=True + ) + batch_info.scalings[: self.max_loras_per_batch].copy_( + scalings_tensor, non_blocking=True + ) + batch_info.weight_indices[:bs].copy_(weight_indices_tensor, non_blocking=True) + self.batch_info = batch_info diff --git a/python/sglang/srt/lora/torch_ops/__init__.py b/python/sglang/srt/lora/torch_ops/__init__.py new file mode 100644 index 000000000000..8cc3826f6fd0 --- /dev/null +++ b/python/sglang/srt/lora/torch_ops/__init__.py @@ -0,0 +1,7 @@ +from .lora_ops import sgmv_expand, sgmv_expand_slice, sgmv_shrink + +__all__ = [ + "sgmv_expand", + "sgmv_expand_slice", + "sgmv_shrink", +] diff --git a/python/sglang/srt/lora/torch_ops/lora_ops.py b/python/sglang/srt/lora/torch_ops/lora_ops.py new file mode 100644 index 000000000000..98c5848a8993 --- /dev/null +++ b/python/sglang/srt/lora/torch_ops/lora_ops.py @@ -0,0 +1,125 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + + +def sgmv_expand( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = False, +): + total_seq_len, _ = inputs.shape + exploded_indices = torch.repeat_interleave( + lora_indices_tensor, seq_len_tensor, output_size=total_seq_len + ) + + bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices, add_inputs) + + +def bgmv_expand( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True, +): + selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype) + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(dim=1) + inputs = inputs.to(dtype=output_tensor.dtype) + outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + + limit = output_tensor.shape[0] + if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: + limit = 1 + + # LoRA adapter and model may add different amounts of padding to output + common_len = min(outputs.shape[1], output_tensor.shape[1]) + + if add_inputs: + output_tensor[:, :common_len] += outputs[:limit, :common_len] + else: + output_tensor[:, :common_len] = outputs[:limit, :common_len] + + +def sgmv_shrink( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float, +): + total_seq_len, _ = inputs.shape + exploded_indices = torch.repeat_interleave( + lora_indices_tensor, seq_len_tensor, output_size=total_seq_len + ) + + bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices, scaling) + + +def bgmv_shrink( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float = 1.0, +): + selected_loras = lora_a_weights[lora_indices_tensor].to(dtype=output_tensor.dtype) + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(dim=1) + inputs = inputs.to(dtype=output_tensor.dtype) + outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + + output_tensor[:, : outputs.shape[1]] = scaling * outputs[:] + + +def sgmv_expand_slice( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = False, +): + total_seq_len, _ = inputs.shape + exploded_indices = torch.repeat_interleave( + lora_indices_tensor, seq_len_tensor, output_size=total_seq_len + ) + + bgmv_expand_slice( + inputs, + lora_b_weights, + output_tensor, + exploded_indices, + slice_offset, + slice_size, + add_inputs, + ) + + +def bgmv_expand_slice( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = True, +): + selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype) + inputs = inputs.to(dtype=output_tensor.dtype) + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(dim=1) + outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + + if add_inputs: + output_tensor[:, slice_offset : slice_offset + slice_size] += outputs[:] + else: + output_tensor[:, slice_offset : slice_offset + slice_size] = outputs[:] diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index e845288ced1d..578817ac83ef 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -131,7 +131,7 @@ "intel_xpu", ] -LORA_BACKEND_CHOICES = ["triton", "csgmv", "ascend"] +LORA_BACKEND_CHOICES = ["triton", "csgmv", "ascend", "torch_native"] DISAGG_TRANSFER_BACKEND_CHOICES = ["mooncake", "nixl", "ascend", "fake"] diff --git a/test/manual/test_lora_ops.py b/test/manual/test_lora_ops.py new file mode 100644 index 000000000000..e5018b54f27b --- /dev/null +++ b/test/manual/test_lora_ops.py @@ -0,0 +1,287 @@ +import unittest + +import torch + +from sglang.srt.lora.torch_ops.lora_ops import ( + bgmv_expand, + bgmv_expand_slice, + bgmv_shrink, + sgmv_expand, + sgmv_expand_slice, + sgmv_shrink, +) +from sglang.test.test_utils import CustomTestCase + + +class TestLoraOps(CustomTestCase): + def test_sgmv_expand(self): + batch_size = 2 + input_dim = 4 + output_dim = 6 + num_loras = 3 + dtype = torch.float32 + + inputs = torch.randn(batch_size, input_dim, dtype=dtype) + lora_b_weights = torch.randn(num_loras, output_dim, input_dim, dtype=dtype) + seq_len_tensor = torch.ones(batch_size, dtype=torch.int32) + lora_indices_tensor = torch.randint(0, num_loras, (batch_size,)) + add_inputs = True + + total_seq_len, _ = inputs.shape + exploded_indices = torch.repeat_interleave( + lora_indices_tensor, seq_len_tensor, output_size=total_seq_len + ) + expect_output = torch.zeros(batch_size, output_dim, dtype=dtype) + bgmv_expand(inputs, lora_b_weights, expect_output, exploded_indices, add_inputs) + + actual_output = torch.zeros(batch_size, output_dim, dtype=dtype) + sgmv_expand( + inputs, + lora_b_weights, + actual_output, + seq_len_tensor, + lora_indices_tensor, + add_inputs, + ) + + self.assertTrue(torch.allclose(actual_output, expect_output)) + + def test_bgmv_expand(self): + batch_size = 2 + input_dim = 4 + output_dim = 6 + num_loras = 3 + dtype = torch.float32 + + inputs = torch.randn(batch_size, input_dim, dtype=dtype) + lora_b_weights = torch.randn(num_loras, output_dim, input_dim, dtype=dtype) + lora_indices_tensor = torch.randint(0, num_loras, (batch_size,)) + + selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=dtype) + selected_loras = selected_loras.squeeze(dim=1) + inputs = inputs.to(dtype=dtype) + outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + limit = batch_size + common_len = min(outputs.shape[1], output_dim) + expect_output = torch.zeros(batch_size, output_dim, dtype=dtype) + expect_output[:, :common_len] = outputs[:limit, :common_len] + + actual_output = torch.zeros(batch_size, output_dim, dtype=dtype) + bgmv_expand( + inputs, + lora_b_weights, + actual_output, + lora_indices_tensor, + add_inputs=False, + ) + + self.assertTrue(torch.allclose(actual_output, expect_output)) + + def test_bgmv_expand_add_residual(self): + batch_size = 2 + input_dim = 4 + output_dim = 6 + num_loras = 3 + dtype = torch.float32 + + inputs = torch.randn(batch_size, input_dim, dtype=dtype) + lora_b_weights = torch.randn(num_loras, output_dim, input_dim, dtype=dtype) + lora_indices_tensor = torch.randint(0, num_loras, (batch_size,)) + + selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=dtype) + selected_loras = selected_loras.squeeze(dim=1) + inputs = inputs.to(dtype=dtype) + outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + limit = batch_size + common_len = min(outputs.shape[1], output_dim) + expect_output = torch.randn(batch_size, output_dim, dtype=dtype) + actual_output = expect_output.clone() + + expect_output[:, :common_len] += outputs[:limit, :common_len] + + bgmv_expand( + inputs, + lora_b_weights, + actual_output, + lora_indices_tensor, + add_inputs=True, + ) + + self.assertTrue(torch.allclose(actual_output, expect_output)) + + def test_sgmv_shrink(self): + batch_size = 2 + input_dim = 4 + output_dim = 6 + num_loras = 3 + dtype = torch.float32 + + inputs = torch.randn(batch_size, input_dim, dtype=dtype) + lora_a_weights = torch.randn(num_loras, output_dim, input_dim, dtype=dtype) + seq_len_tensor = torch.ones(batch_size, dtype=torch.int32) + lora_indices_tensor = torch.randint(0, num_loras, (batch_size,)) + scaling = 0.9 + + total_seq_len, _ = inputs.shape + exploded_indices = torch.repeat_interleave( + lora_indices_tensor, seq_len_tensor, output_size=total_seq_len + ) + expect_output = torch.zeros(batch_size, output_dim, dtype=dtype) + bgmv_shrink(inputs, lora_a_weights, expect_output, exploded_indices, scaling) + + actual_output = torch.zeros(batch_size, output_dim, dtype=dtype) + sgmv_shrink( + inputs, + lora_a_weights, + actual_output, + seq_len_tensor, + lora_indices_tensor, + scaling, + ) + + self.assertTrue(torch.allclose(actual_output, expect_output)) + + def test_bgmv_shrink(self): + batch_size = 2 + input_dim = 4 + output_dim = 6 + num_loras = 3 + dtype = torch.float32 + + inputs = torch.randn(batch_size, input_dim, dtype=dtype) + lora_a_weights = torch.randn(num_loras, output_dim, input_dim, dtype=dtype) + lora_indices_tensor = torch.randint(0, num_loras, (batch_size,)) + scaling = 0.9 + + selected_loras = lora_a_weights[lora_indices_tensor].to(dtype=dtype) + inputs = inputs.to(dtype=dtype) + outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + + expect_output = torch.zeros(batch_size, output_dim, dtype=dtype) + expect_output[:, : outputs.shape[1]] = scaling * outputs[:] + + actual_output = torch.zeros(batch_size, output_dim, dtype=dtype) + bgmv_shrink( + inputs, + lora_a_weights, + actual_output, + lora_indices_tensor, + scaling=scaling, + ) + + self.assertTrue(torch.allclose(actual_output, expect_output)) + + def test_sgmv_expand_slice(self): + batch_size = 2 + input_dim = 4 + output_dim = 6 + output_dim_slice = 12 + num_loras = 3 + dtype = torch.float32 + + inputs = torch.randn(batch_size, input_dim, dtype=dtype) + lora_b_weights = torch.randn(num_loras, output_dim, input_dim, dtype=dtype) + seq_len_tensor = torch.ones(batch_size, dtype=torch.int32) + lora_indices_tensor = torch.randint(0, num_loras, (batch_size,)) + slice_offset = 2 + slice_size = 6 + add_inputs = False + + total_seq_len, _ = inputs.shape + exploded_indices = torch.repeat_interleave( + lora_indices_tensor, seq_len_tensor, output_size=total_seq_len + ) + expect_output = torch.randn(batch_size, output_dim_slice, dtype=dtype) + actual_output = expect_output.clone() + bgmv_expand_slice( + inputs, + lora_b_weights, + expect_output, + exploded_indices, + slice_offset, + slice_size, + add_inputs, + ) + + sgmv_expand_slice( + inputs, + lora_b_weights, + actual_output, + seq_len_tensor, + lora_indices_tensor, + slice_offset, + slice_size, + add_inputs, + ) + + self.assertTrue(torch.allclose(actual_output, expect_output)) + + def test_bgmv_expand_slice(self): + batch_size = 2 + input_dim = 4 + output_dim = 6 + output_dim_slice = 12 + num_loras = 3 + dtype = torch.float32 + + inputs = torch.randn(batch_size, input_dim, dtype=dtype) + lora_b_weights = torch.randn(num_loras, output_dim, input_dim, dtype=dtype) + lora_indices_tensor = torch.randint(0, num_loras, (batch_size,)) + slice_offset = 2 + slice_size = 6 + + selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=dtype) + inputs = inputs.to(dtype=dtype) + outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + expect_output = torch.zeros(batch_size, output_dim_slice, dtype=dtype) + expect_output[:, slice_offset : slice_offset + slice_size] = outputs[:] + + actual_output = torch.zeros(batch_size, output_dim_slice, dtype=dtype) + bgmv_expand_slice( + inputs, + lora_b_weights, + actual_output, + lora_indices_tensor, + slice_offset, + slice_size, + add_inputs=False, + ) + + self.assertTrue(torch.allclose(actual_output, expect_output)) + + def test_bgmv_expand_slice_add_residual(self): + batch_size = 2 + input_dim = 4 + output_dim = 6 + output_dim_slice = 12 + num_loras = 3 + dtype = torch.float32 + + inputs = torch.randn(batch_size, input_dim, dtype=dtype) + lora_b_weights = torch.randn(num_loras, output_dim, input_dim, dtype=dtype) + lora_indices_tensor = torch.randint(0, num_loras, (batch_size,)) + slice_offset = 2 + slice_size = 6 + + selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=dtype) + inputs = inputs.to(dtype=dtype) + outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + expect_output = torch.randn(batch_size, output_dim_slice, dtype=dtype) + actual_output = expect_output.clone() + expect_output[:, slice_offset : slice_offset + slice_size] += outputs[:] + + bgmv_expand_slice( + inputs, + lora_b_weights, + actual_output, + lora_indices_tensor, + slice_offset, + slice_size, + add_inputs=True, + ) + + self.assertTrue(torch.allclose(actual_output, expect_output)) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/manual/test_torch_backend.py b/test/manual/test_torch_backend.py new file mode 100644 index 000000000000..6ca4ee54c9bb --- /dev/null +++ b/test/manual/test_torch_backend.py @@ -0,0 +1,224 @@ +import unittest + +import torch + +from sglang.srt.lora.backend.torch_backend import TorchNativeLoRABackend +from sglang.srt.lora.torch_ops.lora_ops import ( + sgmv_expand, + sgmv_expand_slice, + sgmv_shrink, +) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.test.test_utils import CustomTestCase + + +class TestTorchNativeLoRABackend(CustomTestCase): + + device = "cpu" + forward_batch = ForwardBatch( + forward_mode=ForwardMode.EXTEND, + batch_size=2, + input_ids=torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.int32), + req_pool_indices=None, + seq_lens=None, + out_cache_loc=None, + seq_lens_sum=6, + extend_seq_lens=torch.tensor([1, 1], dtype=torch.int32), + extend_seq_lens_cpu=[1, 1], + ) + weight_indices = [0, 1] + lora_ranks = [1, 1] + scalings = [1.0, 0.5] + use_cuda_graph = False + + @classmethod + def setUpClass(cls): + cls.backend = TorchNativeLoRABackend(max_loras_per_batch=2, device=cls.device) + cls.backend.prepare_lora_batch( + forward_batch=cls.forward_batch, + weight_indices=cls.weight_indices, + lora_ranks=cls.lora_ranks, + scalings=cls.scalings, + use_cuda_graph=cls.use_cuda_graph, + ) + + def test_run_lora_a_sgemm(self): + batch_size = 2 + input_dim = 4 + output_dim = 6 + num_loras = 3 + dtype = torch.float32 + + x = torch.randn(batch_size, input_dim, dtype=dtype) + weights = torch.randn(num_loras, output_dim, input_dim, dtype=dtype) + + total_seq_len, _ = x.shape + _, weight_output_dim, _ = weights.shape + output_tensor = torch.zeros( + (total_seq_len, weight_output_dim), dtype=dtype, device=self.device + ) + sgmv_shrink( + x, + weights, + output_tensor, + self.backend.batch_info.seg_lens, + self.backend.batch_info.weight_indices, + 1.0, + ) + scaling = torch.repeat_interleave( + self.backend.batch_info.scalings[self.backend.batch_info.weight_indices], + self.backend.batch_info.seg_lens, + output_size=total_seq_len, + ).unsqueeze(-1) + expect_output = output_tensor * scaling + + actual_output = self.backend.run_lora_a_sgemm(x, weights) + + self.assertTrue(torch.allclose(actual_output, expect_output)) + + def test_run_lora_b_sgemm(self): + batch_size = 2 + input_dim = 6 + output_dim = 4 + num_loras = 3 + dtype = torch.float32 + + x = torch.randn(batch_size, input_dim, dtype=dtype) + weights = torch.randn(num_loras, output_dim, input_dim, dtype=dtype) + + total_seq_len, _ = x.shape + _, weight_output_dim, _ = weights.shape + output_tensor = torch.zeros( + (total_seq_len, weight_output_dim), dtype=dtype, device=self.device + ) + sgmv_expand( + x, + weights, + output_tensor, + self.backend.batch_info.seg_lens, + self.backend.batch_info.weight_indices, + True, + ) + expect_output = output_tensor + + actual_output = self.backend.run_lora_b_sgemm(x, weights) + + self.assertTrue(torch.allclose(actual_output, expect_output)) + + def test_run_qkv_lora(self): + batch_size = 2 + input_dim = 6 + output_dim = 4 + num_loras = 3 + dtype = torch.float32 + + x = torch.randn(batch_size, input_dim, dtype=dtype) + qkv_lora_a = torch.randn(num_loras, output_dim, input_dim, dtype=dtype) + qkv_lora_b = torch.randn(num_loras, input_dim, output_dim, dtype=dtype) + output_offset_cpu = torch.tensor([0, 3, 6, 9, 12], dtype=torch.int32) + + num_slices = 3 + total_seq_len, _ = x.shape + _, weight_intermediate_dim, _ = qkv_lora_a.shape + _, weight_out_dim, _ = qkv_lora_b.shape + max_rank = weight_intermediate_dim // num_slices + output_tensor = torch.zeros( + (total_seq_len, weight_out_dim), device=x.device, dtype=x.dtype + ) + lora_a_output = torch.zeros( + total_seq_len, weight_intermediate_dim, dtype=x.dtype, device=x.device + ) + sgmv_shrink( + x, + qkv_lora_a, + lora_a_output, + self.backend.batch_info.seg_lens, + self.backend.batch_info.weight_indices, + 1.0, + ) + scaling = torch.repeat_interleave( + self.backend.batch_info.scalings[self.backend.batch_info.weight_indices], + self.backend.batch_info.seg_lens, + output_size=total_seq_len, + ).unsqueeze(-1) + lora_a_output = lora_a_output * scaling + for slice_id in range(num_slices): + slice_offset = output_offset_cpu[slice_id] + slice_offset_next = output_offset_cpu[slice_id + 1] + slice_size = slice_offset_next - slice_offset + sgmv_expand_slice( + lora_a_output[:, (max_rank * slice_id) : (max_rank * (slice_id + 1))], + qkv_lora_b[:, slice_offset:slice_offset_next], + output_tensor, + self.backend.batch_info.seg_lens, + self.backend.batch_info.weight_indices, + slice_offset, + slice_size, + True, + ) + expect_output = output_tensor + actual_output = self.backend.run_qkv_lora( + x, qkv_lora_a, qkv_lora_b, None, output_offset_cpu, 0 + ) + self.assertTrue(torch.allclose(actual_output, expect_output)) + + def test_run_gate_up_lora(self): + batch_size = 2 + input_dim = 6 + output_dim = 4 + num_loras = 3 + dtype = torch.float32 + + num_slices = 2 + + x = torch.randn(batch_size, input_dim, dtype=dtype) + gate_up_lora_a = torch.randn(num_loras, output_dim, input_dim, dtype=dtype) + gate_up_lora_b = torch.randn( + num_loras, output_dim, output_dim // num_slices, dtype=dtype + ) + + total_seq_len, _ = x.shape + _, weight_intermediate_dim, _ = gate_up_lora_a.shape + _, weight_out_dim, _ = gate_up_lora_b.shape + slice_size = weight_out_dim // num_slices + max_rank = weight_intermediate_dim // num_slices + output_tensor = torch.zeros( + (total_seq_len, weight_out_dim), device=x.device, dtype=x.dtype + ) + lora_a_output = torch.zeros( + total_seq_len, weight_intermediate_dim, dtype=x.dtype, device=x.device + ) + sgmv_shrink( + x, + gate_up_lora_a, + lora_a_output, + self.backend.batch_info.seg_lens, + self.backend.batch_info.weight_indices, + 1.0, + ) + scaling = torch.repeat_interleave( + self.backend.batch_info.scalings[self.backend.batch_info.weight_indices], + self.backend.batch_info.seg_lens, + output_size=total_seq_len, + ).unsqueeze(-1) + lora_a_output = lora_a_output * scaling + slice_offset = 0 + for slice_id in range(num_slices): + sgmv_expand_slice( + lora_a_output[:, (max_rank * slice_id) : (max_rank * (slice_id + 1))], + gate_up_lora_b[:, slice_offset : slice_offset + slice_size], + output_tensor, + self.backend.batch_info.seg_lens, + self.backend.batch_info.weight_indices, + slice_offset, + slice_size, + True, + ) + slice_offset += slice_size + expect_output = output_tensor + actual_output = self.backend.run_gate_up_lora(x, gate_up_lora_a, gate_up_lora_b) + self.assertTrue(torch.allclose(actual_output, expect_output)) + + +if __name__ == "__main__": + unittest.main()