174 changes: 142 additions & 32 deletions python/sglang/srt/layers/attention/ascend_backend.py
@@ -41,37 +41,6 @@ class ForwardMetadata:

class AscendAttnBackend(AttentionBackend):

def gen_attention_mask(self, max_seq_len: int, dtype=torch.float16):
mask_flag = torch.tril(
torch.ones((max_seq_len, max_seq_len), dtype=torch.bool)
).view(max_seq_len, max_seq_len)
mask_flag = ~mask_flag
if dtype == torch.float16:
mask_value = torch.finfo(torch.float32).min
else:
mask_value = 1
self.mask = (
torch.masked_fill(
torch.zeros(size=(max_seq_len, max_seq_len)), mask_flag, mask_value
)
.to(dtype)
.to(self.device)
)
self.mask_len = max_seq_len

def get_verify_buffers_to_fill_after_draft(self):
"""
Return buffers for verify attention kernels that needs to be filled after draft.

Typically, these are tree mask and position buffers.
"""
return [None, None]

def update_verify_buffers_to_fill_after_draft(
self, spec_info: SpecInput, cuda_graph_bs: Optional[int]
):
pass

def __init__(self, model_runner: ModelRunner):
super().__init__()
self.forward_metadata = None
@@ -106,6 +75,42 @@ def __init__(self, model_runner: ModelRunner):
self.mtp_mask = torch.tril(torch.ones(2048, 2048, dtype=torch.bool)).npu()
self.mtp_mask = ~self.mtp_mask

# enable-mixed-chunk
attn_mask = self._generate_attn_mask(8192)
self._seq_len_cached = attn_mask.shape[0]
self.attn_mask_cache = attn_mask

def gen_attention_mask(self, max_seq_len: int, dtype=torch.float16):
mask_flag = torch.tril(
torch.ones((max_seq_len, max_seq_len), dtype=torch.bool)
).view(max_seq_len, max_seq_len)
mask_flag = ~mask_flag
if dtype == torch.float16:
mask_value = torch.finfo(torch.float32).min
else:
mask_value = 1
self.mask = (
torch.masked_fill(
torch.zeros(size=(max_seq_len, max_seq_len)), mask_flag, mask_value
)
.to(dtype)
.to(self.device)
)
self.mask_len = max_seq_len

def get_verify_buffers_to_fill_after_draft(self):
"""
Return buffers for verify attention kernels that needs to be filled after draft.

Typically, these are tree mask and position buffers.
"""
return [None, None]

def update_verify_buffers_to_fill_after_draft(
self, spec_info: SpecInput, cuda_graph_bs: Optional[int]
):
pass

def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Init the metadata for a forward pass."""
tp_size = get_attention_tp_size()
@@ -134,9 +139,73 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):

if forward_batch.forward_mode.is_target_verify():
self.forward_metadata.seq_lens_cpu_int += self.speculative_num_draft_tokens

if forward_batch.forward_mode.is_mixed():
attn_mask_id = self.get_attention_mask_id(
self.forward_metadata.seq_lens_cpu_int,
self.forward_metadata.extend_seq_lens_cpu_int,
)
self.attn_mask = self.get_splitfuse_attn_mask( # type: ignore
seq_lens=self.forward_metadata.seq_lens_cpu_int,
position=attn_mask_id,
dtype=torch.float16,
device="npu",
).to(torch.bfloat16)
self.graph_mode = False

@staticmethod
def _generate_attn_mask(max_seq_len, dtype=torch.float16):
# Construct lower triangle matrix.
mask_flag = torch.ones((max_seq_len, max_seq_len), dtype=torch.bool).tril_()
# Create upper triangle matrix used to mark mask positions.
mask_flag = ~mask_flag
# Currently for fp16 dtype, the mask value should be set to -inf.
# TODO: Eliminate this part in the future.
mask_value = float("-inf") if dtype == torch.float16 else 1
# mask_value = -10000 if dtype == torch.float16 else 1
attn_mask = torch.zeros(
size=(max_seq_len, max_seq_len), dtype=dtype
).masked_fill_(mask_flag, mask_value)
return attn_mask

def _update_attn_cache(self, seqlen: int, dtype: torch.dtype):
if seqlen > self._seq_len_cached:
self._seq_len_cached = seqlen
self.attn_mask_cache = self._generate_attn_mask(seqlen, dtype)
if self.attn_mask_cache.dtype != dtype:
self.attn_mask_cache = self.attn_mask_cache.to(dtype)

def get_splitfuse_attn_mask(
self,
seq_lens: torch.Tensor = None,
position: torch.Tensor = None,
dtype: torch.dtype = None,
device: torch.device = None,
) -> torch.Tensor:
if dtype not in [torch.float16, torch.bfloat16]:
raise ValueError("splitfuse_attn_mask now only supports bf16 and fp16")
max_seq_len = max(seq_lens, default=0)
self._update_attn_cache(max_seq_len, dtype)
# FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
# is not the same. Fix this in the future when kernel is ready.
# mask_scale_factor = AttentionMaskBuilder.get_mask_scale_factor(
# dtype)
attn_mask = torch.index_select(self.attn_mask_cache, dim=0, index=position)[
:, :max_seq_len
]
# attn_mask *= mask_scale_factor
return attn_mask.contiguous().to(device, non_blocking=True)

def get_attention_mask_id(self, seq_lens, extend_lens):
starts = seq_lens - extend_lens
ends = seq_lens

# Use torch.stack to stack the start and end indices together
ranges = torch.stack((starts, ends), dim=-1)

# Use list comprehension to generate tensors for each range and concatenate them
attn_mask_id = torch.cat([torch.arange(start, end) for start, end in ranges])
return attn_mask_id
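
Taken together, the helpers above build the splitfuse mask in three steps: _generate_attn_mask caches one large causal mask, get_attention_mask_id converts each request's new tokens into absolute positions, and get_splitfuse_attn_mask gathers the matching mask rows and trims the columns to the batch's longest sequence. A minimal, self-contained sketch of that flow, using hypothetical numbers (request A: chunked prefill with 4 cached and 3 new tokens; request B: a decode step at sequence length 5) and a small 8x8 cache instead of the 8192x8192 one:

import torch

seq_lens = torch.tensor([7, 5])      # total tokens per request after this step
extend_lens = torch.tensor([3, 1])   # new tokens computed in this step

# Mirrors get_attention_mask_id: the absolute position of every new token.
starts, ends = seq_lens - extend_lens, seq_lens
attn_mask_id = torch.cat(
    [torch.arange(int(s), int(e)) for s, e in zip(starts, ends)]
)
print(attn_mask_id)  # tensor([4, 5, 6, 4])

# Mirrors _generate_attn_mask + get_splitfuse_attn_mask on a tiny cache.
cache = torch.zeros(8, 8).masked_fill_(
    ~torch.ones(8, 8, dtype=torch.bool).tril_(), float("-inf")
)
max_seq_len = int(seq_lens.max())
attn_mask = cache.index_select(0, attn_mask_id)[:, :max_seq_len]
print(attn_mask.shape)  # torch.Size([4, 7]); row i blocks keys past that token's position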

def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
self.graph_metadata = {
"block_tables": torch.empty(
@@ -851,6 +920,47 @@ def forward_decode(
)
return attn_output.view(num_tokens, layer.tp_q_head_num * self.kv_lora_rank)

def forward_mixed(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
**kwargs,
):
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
)
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)

query = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim)
attn_output = torch.empty(
(query.shape[0], layer.tp_q_head_num, layer.v_head_dim),
dtype=query.dtype,
device=query.device,
)

torch_npu._npu_paged_attention_splitfuse(
query=query,
key_cache=k_cache,
value_cache=v_cache,
block_table=self.forward_metadata.block_tables,
context_lens=self.forward_metadata.seq_lens_cpu_int,
mask=self.attn_mask,
seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
scale_value=layer.scaling,
num_heads=layer.tp_q_head_num,
num_kv_heads=layer.tp_k_head_num,
out=attn_output,
)
return attn_output.view(
attn_output.shape[0], layer.tp_q_head_num * layer.v_head_dim
)
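
For reference, below is a rough, pure-PyTorch sketch of the attention math the splitfuse path is expected to perform on a mixed batch: each request's new query tokens attend causally over that request's full key/value history. It ignores KV paging and head grouping (it assumes equal query and KV head counts), so it illustrates the semantics rather than torch_npu._npu_paged_attention_splitfuse itself; the function name and tensor layout below are hypothetical.

import torch

def mixed_attention_reference(q, ks, vs, seq_lens, extend_lens, scale):
    """Unoptimized reference. q is (num_new_tokens, heads, head_dim) for the whole
    batch; ks/vs are per-request (seq_len, heads, head_dim) tensors (no paging)."""
    outs, q_off = [], 0
    for req, (s, e) in enumerate(
        zip((seq_lens - extend_lens).tolist(), seq_lens.tolist())
    ):
        q_req = q[q_off : q_off + (e - s)]  # this request's new tokens
        scores = torch.einsum("qhd,khd->hqk", q_req, ks[req][:e]) * scale
        # A token at absolute position p may only look at keys 0..p.
        blocked = torch.arange(s, e)[:, None] < torch.arange(e)[None, :]
        scores = scores.masked_fill(blocked, float("-inf"))
        outs.append(torch.einsum("hqk,khd->qhd", scores.softmax(dim=-1), vs[req][:e]))
        q_off += e - s
    return torch.cat(outs)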


class AscendAttnMultiStepDraftBackend:
"""
22 changes: 22 additions & 0 deletions python/sglang/srt/layers/attention/base_attn_backend.py
@@ -97,6 +97,16 @@ def forward(
save_kv_cache=save_kv_cache,
**kwargs,
)
elif forward_batch.forward_mode.is_mixed():
return self.forward_mixed(
q,
k,
v,
layer,
forward_batch,
save_kv_cache=save_kv_cache,
**kwargs,
)
else:
return self.forward_extend(
q,
@@ -132,6 +142,18 @@ def forward_extend(
"""Run a forward for extend."""
raise NotImplementedError()

def forward_mixed(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
):
"""Run a forward for mix."""
raise NotImplementedError()

def support_triton(self):
"""Check if the current backend supports triton."""
return True
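
The base class now routes mixed batches to a dedicated forward_mixed hook, and backends that never schedule mixed batches simply keep raising NotImplementedError. As an illustration only (not part of this PR), a backend without a dedicated mixed kernel could in principle fall back to its extend path, assuming that path tolerates a per-request extend length of 1 for the decode portion; the subclass name below is hypothetical.

class MyFallbackBackend(AttentionBackend):
    # Hypothetical subclass, shown only to illustrate the new hook.
    def forward_mixed(
        self, q, k, v, layer, forward_batch, save_kv_cache=True, **kwargs
    ):
        # Assumption: this backend's extend kernel handles extend_len == 1, so
        # decode requests in the mixed batch degrade to 1-token extends.
        return self.forward_extend(
            q, k, v, layer, forward_batch, save_kv_cache=save_kv_cache, **kwargs
        )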
2 changes: 1 addition & 1 deletion python/sglang/srt/managers/schedule_policy.py
@@ -548,7 +548,7 @@ def add_req_state(r, insert_sort=False):
min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
)
else:
if self.rem_chunk_tokens == 0:
if self.rem_chunk_tokens <= 0:
return AddReqResult.OTHER

# Chunked prefill
1 change: 1 addition & 0 deletions python/sglang/srt/mem_cache/allocator.py
@@ -558,6 +558,7 @@ def free(self, free_index: torch.Tensor):
self.release_pages = torch.cat((free_page_indices, self.release_pages))
else:
self.free_pages = torch.cat((free_page_indices, self.free_pages))
self.free_pages = torch.unique(self.free_pages)
else:
self.free_group.append(free_index)

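The added torch.unique call collapses any page index that was freed more than once into a single entry in the free list; note that by default it also returns the indices sorted. A tiny sketch with illustrative values:

import torch

free_pages = torch.tensor([7, 3, 3, 9, 7])
print(torch.unique(free_pages))  # tensor([3, 7, 9]) -- deduplicated and sorted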