Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/sglang/bench_one_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def extend(reqs, model_runner):
batch = ScheduleBatch.init_new(
reqs=reqs,
req_to_token_pool=model_runner.req_to_token_pool,
token_to_kv_pool=model_runner.token_to_kv_pool,
token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
tree_cache=None,
model_config=model_runner.model_config,
enable_overlap=False,
Expand Down Expand Up @@ -326,7 +326,7 @@ def latency_test_run_once(

# Clear the pools.
model_runner.req_to_token_pool.clear()
model_runner.token_to_kv_pool.clear()
model_runner.token_to_kv_pool_allocator.clear()

measurement_results = {
"run_name": run_name,
Expand Down
139 changes: 81 additions & 58 deletions python/sglang/srt/layers/attention/flashinfer_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@

from sglang.global_config import global_config
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.utils import is_flashinfer_available

if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.spec_info import SpecInfo

if is_flashinfer_available():
from flashinfer import (
Expand All @@ -36,6 +37,7 @@
BatchPrefillWithRaggedKVCacheWrapper,
)
from flashinfer.cascade import merge_state
from flashinfer.decode import PosEncodingMode


class WrapperDispatch(Enum):
Expand Down Expand Up @@ -113,6 +115,7 @@ def __init__(
device=model_runner.device,
)
self.workspace_buffer = global_workspace_buffer

max_bs = model_runner.req_to_token_pool.size
if kv_indptr_buf is None:
self.kv_indptr = [
Expand All @@ -133,10 +136,13 @@ def __init__(
assert self.num_wrappers == 1
self.kv_last_page_len = kv_last_page_len_buf

self.qo_indptr = [
torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
for _ in range(self.num_wrappers)
]
if not self.skip_prefill:
self.qo_indptr = [
torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
for _ in range(self.num_wrappers)
]

self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
self.workspace_buffer, "NHD"
Expand Down Expand Up @@ -276,7 +282,7 @@ def init_forward_metadata_capture_cuda_graph(
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
if forward_mode.is_decode_or_idle():
decode_wrappers = []
Expand Down Expand Up @@ -346,7 +352,7 @@ def init_forward_metadata_replay_cuda_graph(
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
if forward_mode.is_decode_or_idle():
self.indices_updater_decode.update(
Expand Down Expand Up @@ -526,7 +532,7 @@ def update(
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
# Keep the signature for type checking. It will be assigned during runtime.
raise NotImplementedError()
Expand All @@ -538,7 +544,7 @@ def update_single_wrapper(
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
decode_wrappers = decode_wrappers or self.decode_wrappers
self.call_begin_forward(
Expand All @@ -558,7 +564,7 @@ def update_sliding_window(
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
for wrapper_id in range(2):
if wrapper_id == 0:
Expand Down Expand Up @@ -592,7 +598,7 @@ def update_cross_attention(
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
for wrapper_id in range(2):
if wrapper_id == 0:
Expand Down Expand Up @@ -623,7 +629,7 @@ def call_begin_forward(
paged_kernel_lens_sum: int,
kv_indptr: torch.Tensor,
kv_start_idx: torch.Tensor,
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
if spec_info is None:
bs = len(req_pool_indices)
Expand All @@ -642,9 +648,9 @@ def call_begin_forward(
self.req_to_token.shape[1],
)
else:
assert isinstance(spec_info, EagleDraftInput)
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
bs = kv_indptr.shape[0] - 1

wrapper.begin_forward(
kv_indptr,
kv_indices,
Expand Down Expand Up @@ -699,7 +705,7 @@ def update(
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
# Keep the signature for type checking. It will be assigned during runtime.
raise NotImplementedError()
Expand All @@ -713,7 +719,7 @@ def update_single_wrapper(
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
if use_ragged:
paged_kernel_lens = prefix_lens
Expand Down Expand Up @@ -746,7 +752,7 @@ def update_sliding_window(
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
for wrapper_id in range(2):
if wrapper_id == 0:
Expand Down Expand Up @@ -787,7 +793,7 @@ def update_cross_attention(
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
for wrapper_id in range(2):
if wrapper_id == 0:
Expand Down Expand Up @@ -829,10 +835,11 @@ def call_begin_forward(
kv_indptr: torch.Tensor,
qo_indptr: torch.Tensor,
use_ragged: bool,
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
bs = len(req_pool_indices)
bs = len(seq_lens)
if spec_info is None:
assert len(seq_lens) == len(req_pool_indices)
# Normal extend
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
Expand All @@ -855,10 +862,14 @@ def call_begin_forward(
qo_indptr = qo_indptr[: bs + 1]
custom_mask = None
else:
assert isinstance(spec_info, EagleDraftInput) or isinstance(
spec_info, EagleVerifyInput
)
kv_indices, kv_indptr, qo_indptr, custom_mask = (
spec_info.generate_attn_arg_prefill(
req_pool_indices,
paged_kernel_lens,
paged_kernel_lens_sum,
self.req_to_token,
)
)
Expand Down Expand Up @@ -890,6 +901,11 @@ def call_begin_forward(
)


# Use as a fast path to override the indptr in flashinfer's plan function
# This is used to remove some host-to-device copy overhead.
global global_override_indptr_cpu


class FlashInferMultiStepDraftBackend:
"""
Wrap multiple flashinfer attention backends as one for multiple consecutive
Expand All @@ -907,6 +923,7 @@ def __init__(
self.topk = topk
self.speculative_num_steps = speculative_num_steps
self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices

max_bs = model_runner.req_to_token_pool.size * self.topk
self.kv_indptr = torch.zeros(
(
Expand All @@ -929,7 +946,9 @@ def __init__(
kv_last_page_len_buf=self.kv_last_page_len,
)
)

self.max_context_len = self.attn_backends[0].max_context_len

# Cached variables for generate_draft_decode_kv_indices
self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]

Expand Down Expand Up @@ -959,13 +978,23 @@ def common_template(
triton.next_power_of_2(bs),
)

assert forward_batch.spec_info is not None
assert isinstance(forward_batch.spec_info, EagleDraftInput)

# Copy the kv_indptr once to avoid multiple device-to-host copies in flashinfer's plan.
indptr_cpu_whole = self.kv_indptr[:, : bs + 1].cpu()
global global_override_indptr_cpu

for i in range(self.speculative_num_steps - 1):
forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
: seq_lens_sum * self.topk + bs * (i + 1)
]
global_override_indptr_cpu = indptr_cpu_whole[i]
call_fn(i, forward_batch)

global_override_indptr_cpu = None

def init_forward_metadata(self, forward_batch: ForwardBatch):
kv_indices = torch.zeros(
(
Expand All @@ -977,6 +1006,8 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
)

def call_fn(i, forward_batch):
assert forward_batch.spec_info is not None
assert isinstance(forward_batch.spec_info, EagleDraftInput)
forward_batch.spec_info.kv_indptr = (
forward_batch.spec_info.kv_indptr.clone()
)
Expand All @@ -993,6 +1024,7 @@ def init_cuda_graph_state(self, max_bs: int):
dtype=torch.int32,
device="cuda",
)

for i in range(self.speculative_num_steps):
self.attn_backends[i].init_cuda_graph_state(
max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
Expand Down Expand Up @@ -1031,43 +1063,6 @@ def call_fn(i, forward_batch):
self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)


# Gather per-request KV-cache token indices into one flattened `kv_indices`
# array (the layout flashinfer's paged-KV wrappers consume).
#
# Launch grid: one Triton program per request (axis 0). Program `pid` copies
# `page_kernel_lens_ptr[pid]` token indices out of row
# `req_pool_indices_ptr[pid]` of the req_to_token table and writes them to
# `kv_indices_ptr` starting at output offset `kv_indptr[pid]`.
#
# Args:
#   req_to_token_ptr: token-index table of shape [max_batch, max_context_len];
#       row r holds the KV token indices for the request in pool slot r.
#   req_pool_indices_ptr: per-request row index into req_to_token.
#   page_kernel_lens_ptr: per-request number of KV tokens to copy.
#   kv_indptr: per-request write offsets into the flat output (exclusive
#       prefix sum of the lengths).
#   kv_start_idx: optional per-request start column within the row; may be a
#       null pointer, in which case copying starts at column 0.
#   kv_indices_ptr: flattened output buffer.
#   req_to_token_ptr_stride: row stride of req_to_token
#       (i.e. max_context_len), compile-time constant.
@triton.jit
def create_flashinfer_kv_indices_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices_ptr,
    page_kernel_lens_ptr,
    kv_indptr,
    kv_start_idx,
    kv_indices_ptr,
    req_to_token_ptr_stride: tl.constexpr,
):
    # Number of token indices moved per inner-loop iteration.
    BLOCK_SIZE: tl.constexpr = 512
    # One program handles exactly one request of the batch.
    pid = tl.program_id(axis=0)

    req_pool_index = tl.load(req_pool_indices_ptr + pid)
    # Where this request's slice begins in the flat output buffer.
    kv_indices_offset = tl.load(kv_indptr + pid)

    # Source slice [kv_start, kv_end) within the request's req_to_token row.
    # A null kv_start_idx pointer means the slice starts at column 0.
    kv_start = 0
    kv_end = 0
    if kv_start_idx:
        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
        kv_end = kv_start
    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)

    # Copy the slice in BLOCK_SIZE chunks; the mask guards the ragged tail
    # of the final chunk.
    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < kv_end - kv_start
        data = tl.load(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + kv_start
            + offset,
            mask=mask,
        )
        tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)


def should_use_tensor_core(
kv_cache_dtype: torch.dtype,
num_attention_heads: int,
Expand All @@ -1089,6 +1084,21 @@ def should_use_tensor_core(
if env_override is not None:
return env_override.lower() == "true"

# Try to use _grouped_size_compiled_for_decode_kernels if available
# This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
try:
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels

if not _grouped_size_compiled_for_decode_kernels(
num_attention_heads,
num_kv_heads,
):
return True
else:
return False
except (ImportError, AttributeError):
pass

# Calculate GQA group size
gqa_group_size = num_attention_heads // num_kv_heads

Expand Down Expand Up @@ -1118,12 +1128,18 @@ def fast_decode_plan(
sm_scale: Optional[float] = None,
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
**kwargs,
non_blocking: bool = True,
) -> None:
"""A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend."""
"""
A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.
Modifications:
- Remove unnecessary device-to-device copy for the cuda graph buffers.
- Remove unnecessary host-to-device copy for the metadata buffers.
"""
batch_size = len(last_page_len)
if logits_soft_cap is None:
logits_soft_cap = 0.0

if self.is_cuda_graph_enabled:
if batch_size != self._fixed_batch_size:
raise ValueError(
Expand All @@ -1136,13 +1152,19 @@ def fast_decode_plan(
raise ValueError(
"The size of indices should be less than or equal to the allocated buffer"
)
# Skip these copies
# self._paged_kv_indptr_buf.copy_(indptr)
# self._paged_kv_indices_buf[: len(indices)] = indices
# self._paged_kv_last_page_len_buf.copy_(last_page_len)
else:
self._paged_kv_indptr_buf = indptr
self._paged_kv_indices_buf = indices
self._paged_kv_last_page_len_buf = last_page_len

# NOTE(Zihao): the following tensors acts as placeholder to pass dtype info
if not q_data_type:
q_data_type = data_type

if not hasattr(self, "empty_q_data"):
self.empty_q_data = torch.empty(
0,
Expand All @@ -1159,6 +1181,7 @@ def fast_decode_plan(
),
)
self.last_page_len = torch.ones(32768, dtype=torch.int32)

empty_q_data = self.empty_q_data
empty_kv_cache = self.empty_kv_cache
stream = torch.cuda.current_stream()
Expand Down
Loading
Loading