232 changes: 178 additions & 54 deletions python/sglang/srt/layers/attention/ascend_backend.py
@@ -41,37 +41,6 @@ class ForwardMetadata:

class AscendAttnBackend(AttentionBackend):

def gen_attention_mask(self, max_seq_len: int, dtype=torch.float16):
mask_flag = torch.tril(
torch.ones((max_seq_len, max_seq_len), dtype=torch.bool)
).view(max_seq_len, max_seq_len)
mask_flag = ~mask_flag
if dtype == torch.float16:
mask_value = torch.finfo(torch.float32).min
else:
mask_value = 1
self.mask = (
torch.masked_fill(
torch.zeros(size=(max_seq_len, max_seq_len)), mask_flag, mask_value
)
.to(dtype)
.to(self.device)
)
self.mask_len = max_seq_len

def get_verify_buffers_to_fill_after_draft(self):
"""
Return buffers for verify attention kernels that needs to be filled after draft.

Typically, these are tree mask and position buffers.
"""
return [None, None]

def update_verify_buffers_to_fill_after_draft(
self, spec_info: SpecInput, cuda_graph_bs: Optional[int]
):
pass

def __init__(self, model_runner: ModelRunner):
super().__init__()
self.forward_metadata = None
@@ -106,34 +75,100 @@ def __init__(self, model_runner: ModelRunner):
self.mtp_mask = torch.tril(torch.ones(2048, 2048, dtype=torch.bool)).npu()
self.mtp_mask = ~self.mtp_mask

def gen_attention_mask(self, max_seq_len: int, dtype=torch.float16):
mask_flag = torch.tril(
torch.ones((max_seq_len, max_seq_len), dtype=torch.bool)
).view(max_seq_len, max_seq_len)
mask_flag = ~mask_flag
if dtype == torch.float16:
mask_value = torch.finfo(torch.float32).min
else:
mask_value = 1
self.mask = (
torch.masked_fill(
torch.zeros(size=(max_seq_len, max_seq_len)), mask_flag, mask_value
)
.to(dtype)
.to(self.device)
)
self.mask_len = max_seq_len

def get_verify_buffers_to_fill_after_draft(self):
"""
Return buffers for verify attention kernels that needs to be filled after draft.

Typically, these are tree mask and position buffers.
"""
return [None, None]

def update_verify_buffers_to_fill_after_draft(
self, spec_info: SpecInput, cuda_graph_bs: Optional[int]
):
pass
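
For reference, a minimal sketch of what gen_attention_mask above produces, assuming plain CPU tensors and a tiny max_seq_len (the real code also moves the mask to self.device). In the float16 path the upper triangle is filled with the float32 minimum, which overflows to -inf once cast to float16; for other dtypes the fill value is 1, presumably consumed as a boolean-style mask by the NPU kernels.

    import torch

    max_seq_len = 4
    mask_flag = ~torch.tril(torch.ones((max_seq_len, max_seq_len), dtype=torch.bool))
    mask_value = torch.finfo(torch.float32).min  # float16 path: large negative additive bias
    mask = torch.zeros((max_seq_len, max_seq_len)).masked_fill(mask_flag, mask_value)
    print(mask.to(torch.float16))
    # tensor([[0., -inf, -inf, -inf],
    #         [0.,   0., -inf, -inf],
    #         [0.,   0.,   0., -inf],
    #         [0.,   0.,   0.,   0.]], dtype=torch.float16)
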

def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Init the metadata for a forward pass."""
tp_size = get_attention_tp_size()
self.forward_metadata = ForwardMetadata()
seq_lens_max = forward_batch.seq_lens.max()
if forward_batch.forward_mode.is_target_verify():
seq_lens_max += self.speculative_num_draft_tokens
self.forward_metadata.block_tables = (
forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, :seq_lens_max
][:, :: self.page_size]
// self.page_size
)
if forward_batch.extend_seq_lens is not None:
self.forward_metadata.extend_seq_lens_cpu_int = (
forward_batch.extend_seq_lens.cpu().int()
)
self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
if (
not forward_batch.forward_mode.is_draft_extend_v2()
and not forward_batch.forward_mode.is_draft_extend()
and not forward_batch.forward_mode.is_target_verify()
):
seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum

if forward_batch.forward_mode.is_target_verify():
self.forward_metadata.seq_lens_cpu_int += self.speculative_num_draft_tokens

if forward_batch.forward_mode.is_mixed():
bs_prefill = forward_batch.batch_size - forward_batch.running_decode_bs
seq_lens_max_mix = (
forward_batch.seq_lens[:bs_prefill].max(),
forward_batch.seq_lens[bs_prefill:].max(),
)
req_pool_indices_mix = (
forward_batch.req_pool_indices[:bs_prefill],
forward_batch.req_pool_indices[bs_prefill:],
)
self.forward_metadata.block_tables_mix = (
forward_batch.req_to_token_pool.req_to_token[
req_pool_indices_mix[0], : seq_lens_max_mix[0]
][:, :: self.page_size]
// self.page_size,
forward_batch.req_to_token_pool.req_to_token[
req_pool_indices_mix[1], : seq_lens_max_mix[1]
][:, :: self.page_size]
// self.page_size,
)
if forward_batch.extend_seq_lens is not None:
self.forward_metadata.extend_seq_lens_cpu_int_mix = (
forward_batch.extend_seq_lens.cpu().int()[:bs_prefill],
forward_batch.extend_seq_lens.cpu().int()[bs_prefill:],
)
self.forward_metadata.seq_lens_cpu_int_mix = (
forward_batch.seq_lens_cpu.int()[:bs_prefill],
forward_batch.seq_lens_cpu.int()[bs_prefill:],
)
self.forward_metadata.seq_lens_list_cumsum_mix = np.cumsum(
forward_batch.extend_seq_lens_cpu[:bs_prefill]
), np.cumsum(forward_batch.extend_seq_lens_cpu[bs_prefill:])
Comment on lines +133 to +144
Contributor (medium)

To improve code clarity and avoid redundant computations, you can store the results of forward_batch.extend_seq_lens.cpu().int(), forward_batch.seq_lens_cpu.int(), and forward_batch.extend_seq_lens_cpu in temporary variables before slicing them for the mixed-mode metadata.

Suggested change
if forward_batch.extend_seq_lens is not None:
self.forward_metadata.extend_seq_lens_cpu_int_mix = (
forward_batch.extend_seq_lens.cpu().int()[:bs_prefill],
forward_batch.extend_seq_lens.cpu().int()[bs_prefill:],
)
self.forward_metadata.seq_lens_cpu_int_mix = (
forward_batch.seq_lens_cpu.int()[:bs_prefill],
forward_batch.seq_lens_cpu.int()[bs_prefill:],
)
self.forward_metadata.seq_lens_list_cumsum_mix = np.cumsum(
forward_batch.extend_seq_lens_cpu[:bs_prefill]
), np.cumsum(forward_batch.extend_seq_lens_cpu[bs_prefill:])
if forward_batch.extend_seq_lens is not None:
extend_seq_lens_cpu_int = forward_batch.extend_seq_lens.cpu().int()
self.forward_metadata.extend_seq_lens_cpu_int_mix = (
extend_seq_lens_cpu_int[:bs_prefill],
extend_seq_lens_cpu_int[bs_prefill:],
)
seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
self.forward_metadata.seq_lens_cpu_int_mix = (
seq_lens_cpu_int[:bs_prefill],
seq_lens_cpu_int[bs_prefill:],
)
extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu
self.forward_metadata.seq_lens_list_cumsum_mix = (
np.cumsum(extend_seq_lens_cpu[:bs_prefill]),
np.cumsum(extend_seq_lens_cpu[bs_prefill:]),
)

else:
seq_lens_max = forward_batch.seq_lens.max()
if forward_batch.forward_mode.is_target_verify():
seq_lens_max += self.speculative_num_draft_tokens
self.forward_metadata.block_tables = (
forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, :seq_lens_max
][:, :: self.page_size]
// self.page_size
)
if forward_batch.extend_seq_lens is not None:
self.forward_metadata.extend_seq_lens_cpu_int = (
forward_batch.extend_seq_lens.cpu().int()
)
self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
if (
not forward_batch.forward_mode.is_draft_extend_v2()
and not forward_batch.forward_mode.is_draft_extend()
and not forward_batch.forward_mode.is_target_verify()
):
seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum

if forward_batch.forward_mode.is_target_verify():
self.forward_metadata.seq_lens_cpu_int += (
self.speculative_num_draft_tokens
)

self.graph_mode = False
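
Both branches above derive page-granular block tables from the per-token req_to_token mapping by sampling every page_size-th slot and integer-dividing by page_size; the mixed path simply builds one table for the prefill rows and one for the decode rows. A self-contained sketch of that arithmetic with invented numbers:

    import torch

    page_size = 4
    # Hypothetical per-token KV slot indices for a single request (12 tokens).
    req_to_token = torch.tensor([[8, 9, 10, 11, 20, 21, 22, 23, 40, 41, 42, 43]])
    seq_lens_max = 12

    # Take every page_size-th slot, then convert slot index -> page index.
    block_tables = req_to_token[:, :seq_lens_max][:, ::page_size] // page_size
    print(block_tables)  # tensor([[ 2,  5, 10]]) -> pages 2, 5 and 10 hold this request's KV
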

@@ -851,6 +886,95 @@ def forward_decode(
)
return attn_output.view(num_tokens, layer.tp_q_head_num * self.kv_lora_rank)

def forward_mixed(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
**kwargs,
):
out_cache_loc = forward_batch.out_cache_loc

# Calculates the batch sizes for prefill and decode stages
out_cache_len_prefill = len(out_cache_loc) - forward_batch.running_decode_bs
bs_prefill = forward_batch.batch_size - forward_batch.running_decode_bs

# Splits the input tensors into two parts (prefill and decode stages) based on batch sizes
q_prefill, q_decode = q[:out_cache_len_prefill], q[out_cache_len_prefill:]
k_prefill, k_decode = k[:out_cache_len_prefill], k[out_cache_len_prefill:]
v_prefill, v_decode = v[:out_cache_len_prefill], v[out_cache_len_prefill:]
loc_prefill, loc_decode = (
out_cache_loc[:out_cache_len_prefill],
out_cache_loc[out_cache_len_prefill:],
)

forward_batch.out_cache_loc = loc_prefill
if not self.use_mla:
if self.use_fia:
extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu
forward_batch.extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu[
:bs_prefill
]
else:
self.forward_metadata.extend_seq_lens_cpu_int = (
self.forward_metadata.extend_seq_lens_cpu_int_mix[0]
)
self.forward_metadata.block_tables = (
self.forward_metadata.block_tables_mix[0]
)
self.forward_metadata.seq_lens_cpu_int = (
self.forward_metadata.seq_lens_cpu_int_mix[0]
)
else:
forward_batch.num_token_non_padded_cpu = forward_batch.prefill_input_ids
self.forward_metadata.seq_lens_list_cumsum = (
self.forward_metadata.seq_lens_list_cumsum_mix[0]
)

# Performs the forward pass for the prefill stage
output_prefill = self.forward_extend(
q_prefill,
k_prefill,
v_prefill,
layer,
forward_batch,
save_kv_cache=save_kv_cache,
**kwargs,
)

forward_batch.out_cache_loc = loc_decode
self.forward_metadata.block_tables = self.forward_metadata.block_tables_mix[1]
self.forward_metadata.seq_lens_cpu_int = (
self.forward_metadata.seq_lens_cpu_int_mix[1]
)

# Performs the forward pass for the decode stage
output_decode = self.forward_decode(
q_decode,
k_decode,
v_decode,
layer,
forward_batch,
save_kv_cache=save_kv_cache,
**kwargs,
)

# Resets forward_metadata and forward_batch properties after processing
forward_batch.out_cache_loc = out_cache_loc
forward_batch.num_token_non_padded_cpu = None
self.forward_metadata.extend_seq_lens_cpu_int = None
self.forward_metadata.seq_lens_list_cumsum = None
self.forward_metadata.block_tables = None
self.forward_metadata.seq_lens_cpu_int = None
if not self.use_mla and self.use_fia:
forward_batch.extend_seq_lens_cpu = extend_seq_lens_cpu

# Concatenates and returns the outputs from both parts
return torch.cat([output_prefill, output_decode], dim=0)
Comment on lines +899 to +976
Contributor (high)

The function forward_mixed modifies shared state on forward_batch and self.forward_metadata. If an exception occurs during the forward calls, the cleanup code will not be executed, leaving objects in an inconsistent state. This could lead to hard-to-debug errors in subsequent operations. Additionally, there is a potential NameError because extend_seq_lens_cpu is defined within a conditional block but used in the cleanup section. Using a try...finally block and defining the backup variable before the try block will ensure the state is always restored correctly and fix the potential NameError.

        out_cache_loc = forward_batch.out_cache_loc
        extend_seq_lens_cpu_bak = None
        if not self.use_mla and self.use_fia:
            extend_seq_lens_cpu_bak = forward_batch.extend_seq_lens_cpu

        try:
            # Calculates the batch sizes for prefill and decode stages
            out_cache_len_prefill = len(out_cache_loc) - forward_batch.running_decode_bs
            bs_prefill = forward_batch.batch_size - forward_batch.running_decode_bs

            # Splits the input tensors into two parts (prefill and decode stages) based on batch sizes
            q_prefill, q_decode = q[:out_cache_len_prefill], q[out_cache_len_prefill:]
            k_prefill, k_decode = k[:out_cache_len_prefill], k[out_cache_len_prefill:]
            v_prefill, v_decode = v[:out_cache_len_prefill], v[out_cache_len_prefill:]
            loc_prefill, loc_decode = (
                out_cache_loc[:out_cache_len_prefill],
                out_cache_loc[out_cache_len_prefill:],
            )

            forward_batch.out_cache_loc = loc_prefill
            if not self.use_mla:
                if self.use_fia:
                    forward_batch.extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu[
                        :bs_prefill
                    ]
                else:
                    self.forward_metadata.extend_seq_lens_cpu_int = (
                        self.forward_metadata.extend_seq_lens_cpu_int_mix[0]
                    )
                    self.forward_metadata.block_tables = (
                        self.forward_metadata.block_tables_mix[0]
                    )
                    self.forward_metadata.seq_lens_cpu_int = (
                        self.forward_metadata.seq_lens_cpu_int_mix[0]
                    )
            else:
                forward_batch.num_token_non_padded_cpu = forward_batch.prefill_input_ids
                self.forward_metadata.seq_lens_list_cumsum = (
                    self.forward_metadata.seq_lens_list_cumsum_mix[0]
                )

            # Performs the forward pass for the prefill stage
            output_prefill = self.forward_extend(
                q_prefill,
                k_prefill,
                v_prefill,
                layer,
                forward_batch,
                save_kv_cache=save_kv_cache,
                **kwargs,
            )

            forward_batch.out_cache_loc = loc_decode
            self.forward_metadata.block_tables = self.forward_metadata.block_tables_mix[1]
            self.forward_metadata.seq_lens_cpu_int = (
                self.forward_metadata.seq_lens_cpu_int_mix[1]
            )

            # Performs the forward pass for the decode stage
            output_decode = self.forward_decode(
                q_decode,
                k_decode,
                v_decode,
                layer,
                forward_batch,
                save_kv_cache=save_kv_cache,
                **kwargs,
            )

            # Concatenates and returns the outputs from both parts
            return torch.cat([output_prefill, output_decode], dim=0)
        finally:
            # Resets forward_metadata and forward_batch properties after processing
            forward_batch.out_cache_loc = out_cache_loc
            forward_batch.num_token_non_padded_cpu = None
            self.forward_metadata.extend_seq_lens_cpu_int = None
            self.forward_metadata.seq_lens_list_cumsum = None
            self.forward_metadata.block_tables = None
            self.forward_metadata.seq_lens_cpu_int = None
            if not self.use_mla and self.use_fia:
                forward_batch.extend_seq_lens_cpu = extend_seq_lens_cpu_bak
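
For readers following the split arithmetic in forward_mixed: it relies on the mixed batch laying out all prefill tokens first and exactly one token per running decode request at the end (that ordering comes from mix_with_running further down, which concatenates the prefill batch before the running decode batch). A toy sketch with invented sizes, not the real tensors:

    import torch

    # Two prefill requests contributing 5 and 3 tokens, plus 4 running decode requests.
    batch_size, running_decode_bs = 6, 4
    out_cache_loc = torch.arange(5 + 3 + running_decode_bs)  # 12 token rows in q/k/v

    out_cache_len_prefill = len(out_cache_loc) - running_decode_bs  # 8 prefill rows
    bs_prefill = batch_size - running_decode_bs                     # 2 prefill requests

    loc_prefill = out_cache_loc[:out_cache_len_prefill]  # rows 0..7  -> forward_extend
    loc_decode = out_cache_loc[out_cache_len_prefill:]   # rows 8..11 -> forward_decode
    # The two outputs are then concatenated back in the same order:
    # torch.cat([output_prefill, output_decode], dim=0)
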



class AscendAttnMultiStepDraftBackend:
"""
22 changes: 22 additions & 0 deletions python/sglang/srt/layers/attention/base_attn_backend.py
@@ -97,6 +97,16 @@ def forward(
save_kv_cache=save_kv_cache,
**kwargs,
)
elif forward_batch.forward_mode.is_mixed():
return self.forward_mixed(
q,
k,
v,
layer,
forward_batch,
save_kv_cache=save_kv_cache,
**kwargs,
)
else:
return self.forward_extend(
q,
@@ -132,6 +142,18 @@ def forward_extend(
"""Run a forward for extend."""
raise NotImplementedError()

def forward_mixed(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
):
"""Run a forward for mix."""
raise NotImplementedError()

def support_triton(self):
"""Check if the current backend supports triton."""
return True
12 changes: 12 additions & 0 deletions python/sglang/srt/managers/schedule_batch.py
@@ -1076,6 +1076,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# hicache pointer for synchronizing data loading from CPU to GPU
hicache_consumer_index: int = -1

# mix chunk
running_decode_bs: int = 0
prefill_input_ids: int = 0

@classmethod
def init_new(
cls,
@@ -1393,12 +1397,14 @@ def prepare_for_split_prefill(self):
def mix_with_running(self, running_batch: "ScheduleBatch"):
self.forward_mode = ForwardMode.MIXED
running_bs = running_batch.batch_size()
self.running_decode_bs = running_bs

for req in running_batch.reqs:
req.fill_ids = req.origin_input_ids + req.output_ids
req.extend_input_len = 1

input_ids = torch.cat([self.input_ids, running_batch.input_ids])
self.prefill_input_ids = self.input_ids
out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])

self.merge_batch(running_batch)
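
A simplified sketch of the bookkeeping added here, with invented tensors: running_decode_bs records how many decode requests were appended, and prefill_input_ids stashes the prefill batch's input_ids, which the Ascend backend's MLA path consumes in forward_mixed above (note the field is annotated as int in the dataclass even though a tensor is stored here).

    import torch

    prefill_ids = torch.tensor([101, 102, 103, 104, 105])  # extend tokens of the new batch
    decode_ids = torch.tensor([7, 8, 9])                    # one token per running decode req

    input_ids = torch.cat([prefill_ids, decode_ids])  # prefill tokens first, decode last
    running_decode_bs = len(decode_ids)                # 3
    prefill_input_ids = prefill_ids                    # kept for the mixed forward pass
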
@@ -1791,6 +1797,8 @@ def get_model_worker_batch(
encoder_out_cache_loc=self.encoder_out_cache_loc,
lora_ids=[req.lora_id for req in self.reqs],
sampling_info=self.sampling_info,
running_decode_bs=self.running_decode_bs,
prefill_input_ids=self.prefill_input_ids,
input_embeds=self.input_embeds,
token_type_ids=self.token_type_ids,
spec_algorithm=self.spec_algorithm,
@@ -1900,6 +1908,10 @@ class ModelWorkerBatch:
# Sampling info
sampling_info: SamplingBatchInfo

# For mixed chunk
running_decode_bs: int
prefill_input_ids: int

# The original sequence lengths, Qwen-1M related
orig_seq_lens: Optional[torch.Tensor] = None

18 changes: 12 additions & 6 deletions python/sglang/srt/managers/schedule_policy.py
@@ -1,5 +1,7 @@
from __future__ import annotations

from sglang.srt.utils.common import is_npu

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -331,14 +333,18 @@ def __init__(
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.running_batch = running_batch
self.new_token_ratio = new_token_ratio
self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
self.rem_chunk_tokens = rem_chunk_tokens
if self.rem_chunk_tokens is not None:
self.rem_chunk_tokens -= mixed_with_decode_tokens

self.rem_total_token_offset = mixed_with_decode_tokens
self.cur_rem_token_offset = mixed_with_decode_tokens
if is_npu():
self.rem_input_tokens = rem_input_tokens
self.rem_total_token_offset = 0
self.cur_rem_token_offset = 0
else:
self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
if self.rem_chunk_tokens is not None:
self.rem_chunk_tokens -= mixed_with_decode_tokens

self.rem_total_token_offset = mixed_with_decode_tokens
self.cur_rem_token_offset = mixed_with_decode_tokens
self.req_states = None
self.can_run_list = []
self.preempt_list = []
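
A small sketch of what the NPU branch above changes, with illustrative numbers (plain ints standing in for the real budgets): on NPU the decode tokens mixed into the chunk no longer reduce the prefill budgets or seed the token offsets, presumably because the mixed attention path accounts for them separately.

    def prefill_budgets(rem_input_tokens, rem_chunk_tokens, mixed_with_decode_tokens, npu):
        # Returns (rem_input_tokens, rem_chunk_tokens, rem_total_token_offset,
        # cur_rem_token_offset); mirrors the two branches above, illustrative only.
        if npu:
            return rem_input_tokens, rem_chunk_tokens, 0, 0
        rem_input = rem_input_tokens - mixed_with_decode_tokens
        rem_chunk = (
            rem_chunk_tokens - mixed_with_decode_tokens
            if rem_chunk_tokens is not None
            else None
        )
        return rem_input, rem_chunk, mixed_with_decode_tokens, mixed_with_decode_tokens

    print(prefill_budgets(8192, 4096, 16, npu=False))  # (8176, 4080, 16, 16)
    print(prefill_budgets(8192, 4096, 16, npu=True))   # (8192, 4096, 0, 0)
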
1 change: 1 addition & 0 deletions python/sglang/srt/mem_cache/allocator.py
@@ -558,6 +558,7 @@ def free(self, free_index: torch.Tensor):
self.release_pages = torch.cat((free_page_indices, self.release_pages))
else:
self.free_pages = torch.cat((free_page_indices, self.free_pages))
self.free_pages = torch.unique(torch.sort(self.free_pages)[0])
Contributor (medium)

The line self.free_pages = torch.unique(torch.sort(self.free_pages)[0]) is a bit redundant. torch.unique by default returns sorted unique elements. You can simplify this to self.free_pages = torch.unique(self.free_pages).

Suggested change
self.free_pages = torch.unique(torch.sort(self.free_pages)[0])
self.free_pages = torch.unique(self.free_pages)

else:
self.free_group.append(free_index)
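
A quick standalone check of the reviewer's point above (not part of the diff): torch.unique sorts its output by default, so the explicit sort adds nothing.

    import torch

    free_pages = torch.tensor([7, 3, 7, 1, 3])
    print(torch.unique(torch.sort(free_pages)[0]))  # tensor([1, 3, 7])
    print(torch.unique(free_pages))                 # tensor([1, 3, 7]), already sorted
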

2 changes: 2 additions & 0 deletions python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -639,6 +639,8 @@ def capture_one_batch_size(self, bs: int, forward: Callable):
num_token_non_padded=self.num_token_non_padded,
global_forward_mode=self.capture_forward_mode,
lora_ids=lora_ids,
running_decode_bs=None,
prefill_input_ids=None,
)
self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)
