Skip to content
Merged
13 changes: 13 additions & 0 deletions python/sglang/srt/layers/logits_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,6 +688,12 @@ def process_input_logprobs_by_chunk(
start_idx = i * chunk_size
end_idx = min((i + 1) * chunk_size, total_size)

# Notify lm_head LoRA about the current chunk so it can swap
# to the precomputed per-chunk batch_info. This is a no-op
# for non-LoRA lm_head modules.
if hasattr(lm_head, "set_lm_head_pass"):
lm_head.set_lm_head_pass(i)

# Get indices for this chunk
chunk_mask = (input_logprob_indices >= start_idx) & (
input_logprob_indices < end_idx
Expand Down Expand Up @@ -792,6 +798,13 @@ def process_input_logprobs_by_chunk(
]
input_token_logprobs.append(chunk_input_token_logprobs)

# Restore the full-pruned lm_head batch_info after chunk iteration.
if hasattr(lm_head, "reset_lm_head_pass"):
assert hasattr(
lm_head, "set_lm_head_pass"
), "lm_head must have set_lm_head_pass method and reset_lm_head_pass method at the same time"
lm_head.reset_lm_head_pass()

# Concatenate the results
input_token_logprobs = torch.cat(input_token_logprobs, dim=0)

Expand Down
4 changes: 3 additions & 1 deletion python/sglang/srt/lora/backend/base_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

import torch

from sglang.srt.lora.backend.lmhead_mixing import LoRABackendLmHeadMixing
from sglang.srt.model_executor.forward_batch_info import ForwardBatch


class BaseLoRABackend:
class BaseLoRABackend(LoRABackendLmHeadMixing):
"""Base class for different Lora backends.
Each backend has its own implementation of Lora kernels.

Expand All @@ -18,6 +19,7 @@ class BaseLoRABackend:
def __init__(self, max_loras_per_batch: int, device: torch.device):
    """Initialize shared LoRA backend state.

    Args:
        max_loras_per_batch: Maximum number of distinct LoRA adapters that
            may appear in a single batch.
        device: Device on which backend tensors are allocated.
    """
    self.max_loras_per_batch = max_loras_per_batch
    self.device = device
    # Initialize lm_head LoRA bookkeeping from LoRABackendLmHeadMixing
    # (sets lm_head_batch_info / lm_head_pass_batch_infos / pass index to None).
    self.init_lm_head_config()

def run_lora_a_embedding(
self,
Expand Down
138 changes: 130 additions & 8 deletions python/sglang/srt/lora/backend/chunked_backend.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
import dataclasses
from typing import List, Optional, Tuple

import torch

from sglang.srt.lora.backend.base_backend import BaseLoRABackend
from sglang.srt.lora.triton_ops import (
chunked_embedding_lora_a_forward,
chunked_sgmv_lora_expand_forward,
chunked_sgmv_lora_shrink_forward,
)
from sglang.srt.lora.utils import LoRABatchInfo, generate_sequence_lengths
from sglang.srt.lora.utils import (
LoRABatchInfo,
generate_sequence_lengths,
get_lm_head_pruned_lens,
merge_and_chunk_segments,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import ServerArgs

Expand Down Expand Up @@ -33,13 +42,40 @@ def __init__(
super().__init__(max_loras_per_batch, device)
self.max_chunk_size = server_args.max_lora_chunk_size

def run_lora_a_embedding(
    self,
    input_ids: torch.Tensor,
    weights: torch.Tensor,
    vocab_size: int,
    extra_embeddings: torch.Tensor = None,
    *args,
    **kwargs,
) -> torch.Tensor:
    """Run the LoRA-A projection of the embedding layer with the chunked kernel.

    Args:
        input_ids: Token ids to embed.
        weights: Stacked LoRA-A embedding weights for all loaded adapters.
        vocab_size: Size of the base vocabulary.
        extra_embeddings: Must be None; the chunked kernel has no fused
            path for extra embeddings yet.

    Returns:
        The LoRA-A embedding output tensor.
    """
    # The chunked kernel does not support fused extra embeddings.
    assert (
        extra_embeddings is None
    ), "Extra embeddings for lora a is not supported yet in chunked backend"
    current_batch_info = self.batch_info
    return chunked_embedding_lora_a_forward(
        input_ids=input_ids,
        batch_info=current_batch_info,
        weights=weights,
        vocab_size=vocab_size,
    )

def run_lora_a_sgemm(
    self,
    x: torch.Tensor,
    weights: torch.Tensor,
    pruned_batch_info: Optional[LoRABatchInfo] = None,
    *args,
    **kwargs,
) -> torch.Tensor:
    """Run the LoRA-A (shrink) GEMM with the chunked SGMV kernel.

    Args:
        x: Input activations.
        weights: Stacked LoRA-A weights for all loaded adapters.
        pruned_batch_info: Optional override batch info. Used by the pruned
            lm_head LoRA path, which operates on a subset of tokens and so
            needs its own segmentation instead of ``self.batch_info``.

    Returns:
        The shrink-projection output tensor.
    """
    # Allow callers (e.g. lm_head LoRA over pruned tokens) to substitute a
    # precomputed batch info; fall back to the regular per-batch one.
    batch_info = (
        pruned_batch_info if pruned_batch_info is not None else self.batch_info
    )
    return chunked_sgmv_lora_shrink_forward(
        x=x,
        weights=weights,
        batch_info=batch_info,
        num_slices=1,
    )

def run_lora_b_sgemm(
    self,
    x: torch.Tensor,
    weights: torch.Tensor,
    output_offset: torch.Tensor,
    base_output: torch.Tensor = None,
    pruned_batch_info: Optional[LoRABatchInfo] = None,
    *args,
    **kwargs,
) -> torch.Tensor:
    """Run the LoRA-B (expand) GEMM with the chunked SGMV kernel.

    Args:
        x: Shrink-projection output (LoRA-A result).
        weights: Stacked LoRA-B weights for all loaded adapters.
        output_offset: Slice offsets describing where each output slice
            lands in the fused output tensor.
        base_output: Optional tensor to accumulate the LoRA delta into.
        pruned_batch_info: Optional override batch info for the pruned
            lm_head LoRA path; defaults to ``self.batch_info``.

    Returns:
        The expand-projection output tensor.
    """
    # For simple lora B, we use slice offsets [0, output_dim]
    output_dim = weights.shape[-2]
    max_slice_size = output_dim
    # Allow the pruned lm_head path to substitute its own segmentation.
    batch_info = (
        pruned_batch_info if pruned_batch_info is not None else self.batch_info
    )
    return chunked_sgmv_lora_expand_forward(
        x=x,
        weights=weights,
        batch_info=batch_info,
        slice_offsets=output_offset,
        max_slice_size=max_slice_size,
        base_output=base_output,
    )
Expand Down Expand Up @@ -141,15 +181,18 @@ def _determine_chunk_size(self, forward_batch: ForwardBatch) -> int:
Returns:
The determined chunk size
"""

if self.max_chunk_size <= MIN_CHUNK_SIZE:
return MIN_CHUNK_SIZE

num_tokens = (
forward_batch.extend_num_tokens
if forward_batch.forward_mode.is_extend()
else forward_batch.batch_size
)
return self._determine_chunk_size_for_tokens(num_tokens)

def _determine_chunk_size_for_tokens(self, num_tokens: int) -> int:
"""Determine chunk size given a token count directly."""
if self.max_chunk_size <= MIN_CHUNK_SIZE:
return MIN_CHUNK_SIZE

if num_tokens >= 256:
chunk_size = 128
elif num_tokens >= 64:
Expand Down Expand Up @@ -253,6 +296,85 @@ def prepare_lora_batch(
batch_info.permutation[: len(permutation)].copy_(permutation, non_blocking=True)

self.batch_info = batch_info
self.lm_head_batch_info, self.lm_head_pass_batch_infos = (
self._prepare_lm_head_batch_info(forward_batch, weight_indices, batch_info)
)

def _prepare_lm_head_batch_info(
    self,
    forward_batch: ForwardBatch,
    weight_indices: list[int],
    batch_info: LoRABatchInfo,
) -> Tuple[Optional[LoRABatchInfo], Optional[List[LoRABatchInfo]]]:
    """Precompute lm_head LoRA batch infos for the current forward batch.

    Returns a tuple ``(lm_head_batch_info, lm_head_pass_batch_infos)``.
    Both entries are None when there is no pruned lm_head input; the second
    entry is None when logprobs chunking does not split the pruned tokens
    into multiple passes.
    """
    pruned_lens = get_lm_head_pruned_lens(forward_batch)
    if pruned_lens is None:
        # No pruned lm_head tokens in this batch; nothing to precompute.
        return None, None

    def build_for(seg_indices, seg_lens):
        # Chunk the (weight_index, length) segments and convert them into a
        # LoRABatchInfo sized for their total token count.
        total = sum(seg_lens)
        chunk = self._determine_chunk_size_for_tokens(total)
        segments = merge_and_chunk_segments(seg_indices, seg_lens, chunk_size=chunk)
        return self._build_lm_head_batch_info(segments, batch_info, chunk, total)

    # Batch info covering the full pruned lm_head input.
    full_info = build_for(weight_indices, pruned_lens)

    # Per-pass batch infos for chunked logprobs, when applicable.
    pass_segments = self._get_lm_head_pass_segments(weight_indices, pruned_lens)
    pass_infos = None
    if pass_segments is not None:
        pass_infos = [
            build_for(seg_wi, seg_lens_list)
            for seg_wi, seg_lens_list in pass_segments
        ]

    return full_info, pass_infos

def _build_lm_head_batch_info(
    self,
    lm_head_segments: Tuple[List[int], List[int]],
    batch_info: LoRABatchInfo,
    chunk_size: int,
    expected_tokens: int,
) -> LoRABatchInfo:
    """Build a LoRABatchInfo describing the pruned lm_head input.

    Args:
        lm_head_segments: Pair of (per-segment weight indices, per-segment
            lengths), both as Python lists on the CPU.
        batch_info: Template batch info; all fields not overridden here are
            copied from it via ``dataclasses.replace``.
        chunk_size: Chunk size used for segmentation; stored as ``max_len``.
        expected_tokens: Total token count the caller expects this batch
            info to cover.
    """
    seg_indices_cpu, seg_lens_cpu = lm_head_segments
    dev = self.device
    total_tokens = sum(seg_lens_cpu)
    segment_count = len(seg_indices_cpu)

    seg_lens = torch.tensor(seg_lens_cpu, dtype=torch.int32, device=dev)
    # Prefix sums over segment lengths: [0, l0, l0+l1, ...].
    seg_indptr = torch.zeros((segment_count + 1,), dtype=torch.int32, device=dev)
    seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)

    return dataclasses.replace(
        batch_info,
        num_segments=segment_count,
        max_len=chunk_size,
        seg_indptr=seg_indptr,
        weight_indices=torch.tensor(
            seg_indices_cpu, dtype=torch.int32, device=dev
        ),
        # lm_head tokens are already in original order, so the permutation
        # is the identity.
        permutation=torch.arange(total_tokens, dtype=torch.int32, device=dev),
        expected_tokens=expected_tokens,
    )

@staticmethod
def _get_permutation(seq_weight_indices, forward_batch: ForwardBatch):
Expand Down
64 changes: 64 additions & 0 deletions python/sglang/srt/lora/backend/lmhead_mixing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from typing import List, Optional, Tuple

from sglang.srt.environ import envs
from sglang.srt.lora.utils import LoRABatchInfo, build_lm_head_pass_segments
from sglang.srt.model_executor.forward_batch_info import ForwardBatch


class LoRABackendLmHeadMixing:
    """Mixin adding lm_head LoRA batch-info bookkeeping to a LoRA backend.

    Concrete backends override ``_prepare_lm_head_batch_info`` and
    ``_build_lm_head_batch_info``; this mixin supplies the shared state
    slots and the generic per-pass segmentation used by chunked logprobs.
    """

    def init_lm_head_config(self):
        """Reset all lm_head LoRA state; called from the backend __init__."""
        # Batch info covering the full pruned lm_head input (or None).
        self.lm_head_batch_info = None
        # Precomputed per-pass lm_head batch_infos. When the logits processor
        # calls lm_head in multiple passes (chunked logprobs), each pass gets
        # its own batch_info from this list.
        self.lm_head_pass_batch_infos = None
        # Current pass index. When set, apply_lora uses
        # lm_head_pass_batch_infos[idx] instead of lm_head_batch_info.
        self._lm_head_pass_idx = None

    def _get_lm_head_pass_segments(
        self,
        weight_indices: list[int],
        pruned_lens: List[int],
    ) -> Optional[List[Tuple[List[int], List[int]]]]:
        """Compute per-pass segment info for lm_head LoRA logprobs chunking.

        When LogitsProcessor splits pruned states into fixed-size passes,
        each pass needs its own segmentation so that lm_head LoRA operates
        on the correct adapter assignments. This method returns the generic
        per-pass (seg_weight_indices, seg_lens) tuples; each backend is
        responsible for converting them into backend-specific LoRABatchInfo.

        Returns None if logprobs chunking is disabled or the pruned token
        count does not exceed the logprobs chunk size.
        """
        logprobs_chunk_size = envs.SGLANG_LOGITS_PROCESSER_CHUNK_SIZE.get()
        enable_logprobs_chunk = envs.SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK.get()
        pruned_total = sum(pruned_lens)

        if not enable_logprobs_chunk or pruned_total <= logprobs_chunk_size:
            return None

        return build_lm_head_pass_segments(
            weight_indices, pruned_lens, logprobs_chunk_size
        )

    def _prepare_lm_head_batch_info(
        self,
        forward_batch: ForwardBatch,
        weight_indices: list[int],
        batch_info: LoRABatchInfo,
    ) -> Tuple[Optional[LoRABatchInfo], Optional[List[LoRABatchInfo]]]:
        """Prepare the lm_head batch info for the current forward batch.

        Returns a tuple of (lm_head_batch_info, lm_head_pass_batch_infos).
        Backends that support lm_head LoRA must override this method.
        """
        # NOTE(review): stub intentionally falls through (returning None)
        # instead of raising, so backends without lm_head LoRA keep working.
        pass

    def _build_lm_head_batch_info(
        self,
        lm_head_segments: Tuple[List[int], List[int]],
        batch_info: LoRABatchInfo,
        chunk_size: int,
        expected_tokens: int,
    ) -> LoRABatchInfo:
        """Build a LoRABatchInfo for pruned lm_head input (backend-specific)."""
        pass
Loading
Loading