167 changes: 135 additions & 32 deletions python/sglang/srt/layers/attention/ascend_backend.py
@@ -41,37 +41,6 @@ class ForwardMetadata:

class AscendAttnBackend(AttentionBackend):

def gen_attention_mask(self, max_seq_len: int, dtype=torch.float16):
mask_flag = torch.tril(
torch.ones((max_seq_len, max_seq_len), dtype=torch.bool)
).view(max_seq_len, max_seq_len)
mask_flag = ~mask_flag
if dtype == torch.float16:
mask_value = torch.finfo(torch.float32).min
else:
mask_value = 1
self.mask = (
torch.masked_fill(
torch.zeros(size=(max_seq_len, max_seq_len)), mask_flag, mask_value
)
.to(dtype)
.to(self.device)
)
self.mask_len = max_seq_len

def get_verify_buffers_to_fill_after_draft(self):
"""
        Return buffers for verify attention kernels that need to be filled after draft.

Typically, these are tree mask and position buffers.
"""
return [None, None]

def update_verify_buffers_to_fill_after_draft(
self, spec_info: SpecInput, cuda_graph_bs: Optional[int]
):
pass

def __init__(self, model_runner: ModelRunner):
super().__init__()
self.forward_metadata = None
@@ -106,6 +75,43 @@ def __init__(self, model_runner: ModelRunner):
self.mtp_mask = torch.tril(torch.ones(2048, 2048, dtype=torch.bool)).npu()
self.mtp_mask = ~self.mtp_mask

        # Cached causal mask used when --enable-mixed-chunk batches prefill and decode together.
_ASCEND_MIXED_CHUNK_CACHE_SIZE = 8192
attn_mask = self._generate_attn_mask(_ASCEND_MIXED_CHUNK_CACHE_SIZE)
self._seq_len_cached = attn_mask.shape[0]
self.attn_mask_cache = attn_mask

def gen_attention_mask(self, max_seq_len: int, dtype=torch.float16):
mask_flag = torch.tril(
torch.ones((max_seq_len, max_seq_len), dtype=torch.bool)
).view(max_seq_len, max_seq_len)
mask_flag = ~mask_flag
if dtype == torch.float16:
mask_value = torch.finfo(torch.float32).min
else:
mask_value = 1
self.mask = (
torch.masked_fill(
torch.zeros(size=(max_seq_len, max_seq_len)), mask_flag, mask_value
)
.to(dtype)
.to(self.device)
)
self.mask_len = max_seq_len

def get_verify_buffers_to_fill_after_draft(self):
"""
        Return buffers for verify attention kernels that need to be filled after draft.

Typically, these are tree mask and position buffers.
"""
return [None, None]

def update_verify_buffers_to_fill_after_draft(
self, spec_info: SpecInput, cuda_graph_bs: Optional[int]
):
pass

def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Init the metadata for a forward pass."""
tp_size = get_attention_tp_size()
@@ -134,9 +140,65 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):

if forward_batch.forward_mode.is_target_verify():
self.forward_metadata.seq_lens_cpu_int += self.speculative_num_draft_tokens

if forward_batch.forward_mode.is_mixed():
attn_mask_id = self.get_attention_mask_id(
self.forward_metadata.seq_lens_cpu_int,
self.forward_metadata.extend_seq_lens_cpu_int,
)
self.attn_mask = self.get_splitfuse_attn_mask(
seq_lens=self.forward_metadata.seq_lens_cpu_int,
position=attn_mask_id,
dtype=torch.float16,
device="npu",
).to(torch.bfloat16)
self.graph_mode = False

@staticmethod
def _generate_attn_mask(max_seq_len, dtype=torch.float16):
        # Lower-triangular matrix of allowed (causal) positions.
        mask_flag = torch.ones((max_seq_len, max_seq_len), dtype=torch.bool).tril_()
        # Invert it so True marks the positions to mask out.
        mask_flag = ~mask_flag
mask_value = float("-inf") if dtype in [torch.float16, torch.bfloat16] else 1
attn_mask = torch.zeros(
size=(max_seq_len, max_seq_len), dtype=dtype
).masked_fill_(mask_flag, mask_value)
return attn_mask
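
For intuition, this is the additive mask `_generate_attn_mask` builds for a small, assumed cache length of 4: zeros on and below the diagonal, `-inf` above it.

```python
import torch

mask_flag = ~torch.ones((4, 4), dtype=torch.bool).tril_()
mask = torch.zeros(4, 4, dtype=torch.float16).masked_fill_(mask_flag, float("-inf"))
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]], dtype=torch.float16)
```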

def _update_attn_cache(self, seqlen: int, dtype: torch.dtype):
if seqlen > self._seq_len_cached:
self._seq_len_cached = seqlen
self.attn_mask_cache = self._generate_attn_mask(seqlen, dtype)
if self.attn_mask_cache.dtype != dtype:
self.attn_mask_cache = self.attn_mask_cache.to(dtype)

def get_splitfuse_attn_mask(
self,
seq_lens: torch.Tensor = None,
position: torch.Tensor = None,
dtype: torch.dtype = None,
device: torch.device = None,
) -> torch.Tensor:
if dtype not in [torch.float16, torch.bfloat16]:
raise ValueError("splitfuse_attn_mask now only supports bf16 and fp16")
max_seq_len = max(seq_lens, default=0)
self._update_attn_cache(max_seq_len, dtype)
attn_mask = torch.index_select(self.attn_mask_cache, dim=0, index=position)[
:, :max_seq_len
]
return attn_mask.contiguous().to(device, non_blocking=True)

def get_attention_mask_id(self, seq_lens, extend_lens):
starts = seq_lens - extend_lens
ends = seq_lens

        # Pair each request's start and end positions of its new (extend) tokens.
        ranges = torch.stack((starts, ends), dim=-1)

        # Absolute positions of the new tokens, i.e. the rows of the cached
        # causal mask that each query token needs.
        attn_mask_id = torch.cat([torch.arange(start, end) for start, end in ranges])
return attn_mask_id
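
A small self-contained sketch (example values, not from this PR) of how `get_attention_mask_id` and the cached causal mask fit together for a mixed batch: the ids are the absolute positions of each request's new tokens, and indexing the cached mask with them yields one mask row per new token.

```python
import torch

# Assumed example: request 0 has 3 cached tokens and 2 new ones,
# request 1 is a fresh 3-token prefill.
seq_lens = torch.tensor([5, 3])
extend_lens = torch.tensor([2, 3])

starts, ends = seq_lens - extend_lens, seq_lens
ranges = torch.stack((starts, ends), dim=-1)
attn_mask_id = torch.cat([torch.arange(start, end) for start, end in ranges])
# attn_mask_id == tensor([3, 4, 0, 1, 2])

# Index those rows out of a cached causal mask and crop to the batch's
# longest context, mirroring get_splitfuse_attn_mask.
cache_len, max_seq_len = 8, int(seq_lens.max())
cache = torch.zeros(cache_len, cache_len).masked_fill_(
    ~torch.ones(cache_len, cache_len, dtype=torch.bool).tril_(), float("-inf")
)
mask = cache.index_select(0, attn_mask_id)[:, :max_seq_len]
print(mask.shape)  # torch.Size([5, 5]): one row per new token in the batch
```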

def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
self.graph_metadata = {
"block_tables": torch.empty(
@@ -851,6 +913,47 @@ def forward_decode(
)
return attn_output.view(num_tokens, layer.tp_q_head_num * self.kv_lora_rank)

def forward_mixed(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
**kwargs,
):
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
)
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)

query = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim)
attn_output = torch.empty(
(query.shape[0], layer.tp_q_head_num, layer.v_head_dim),
dtype=query.dtype,
device=query.device,
)

torch_npu._npu_paged_attention_splitfuse(
query=query,
key_cache=k_cache,
value_cache=v_cache,
block_table=self.forward_metadata.block_tables,
context_lens=self.forward_metadata.seq_lens_cpu_int,
mask=self.attn_mask,
seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
scale_value=layer.scaling,
num_heads=layer.tp_q_head_num,
num_kv_heads=layer.tp_k_head_num,
out=attn_output,
)
return attn_output.view(
attn_output.shape[0], layer.tp_q_head_num * layer.v_head_dim
)
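
For readers unfamiliar with split-fuse attention, here is a minimal pure-PyTorch sketch of the per-request semantics `forward_mixed` relies on the NPU kernel for: each request's new (extend) tokens attend causally over that request's full cached context. This is an illustration only (it assumes equal q/kv head counts and plain-Python length lists), not the `torch_npu` implementation.

```python
import torch

def mixed_attention_reference(q, k_ctx, v_ctx, seq_lens, extend_lens, scale):
    # q: (total_new_tokens, heads, dim), new tokens packed request-by-request.
    # k_ctx / v_ctx: per-request context tensors of shape (seq_len_i, heads, dim).
    # seq_lens / extend_lens: Python ints per request.
    outputs, offset = [], 0
    for i, (s, e) in enumerate(zip(seq_lens, extend_lens)):
        q_i = q[offset : offset + e]                  # this request's new tokens
        k_i, v_i = k_ctx[i][:s], v_ctx[i][:s]
        pos = torch.arange(s - e, s).unsqueeze(1)     # absolute query positions
        blocked = torch.arange(s).unsqueeze(0) > pos  # True = future token, mask out
        scores = torch.einsum("qhd,khd->hqk", q_i, k_i) * scale
        probs = scores.masked_fill(blocked, float("-inf")).softmax(dim=-1)
        outputs.append(torch.einsum("hqk,khd->qhd", probs, v_i))
        offset += e
    return torch.cat(outputs)  # (total_new_tokens, heads, dim)
```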


class AscendAttnMultiStepDraftBackend:
"""
22 changes: 22 additions & 0 deletions python/sglang/srt/layers/attention/base_attn_backend.py
@@ -97,6 +97,16 @@ def forward(
save_kv_cache=save_kv_cache,
**kwargs,
)
elif forward_batch.forward_mode.is_mixed():
return self.forward_mixed(
q,
k,
v,
layer,
forward_batch,
save_kv_cache=save_kv_cache,
**kwargs,
)
else:
return self.forward_extend(
q,
@@ -132,6 +142,18 @@ def forward_extend(
"""Run a forward for extend."""
raise NotImplementedError()

def forward_mixed(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
):
"""Run a forward for mix."""
raise NotImplementedError()

def support_triton(self):
"""Check if the current backend supports triton."""
return True
2 changes: 1 addition & 1 deletion python/sglang/srt/managers/schedule_policy.py
@@ -548,7 +548,7 @@ def add_req_state(r, insert_sort=False):
min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
)
else:
if self.rem_chunk_tokens == 0:
if self.rem_chunk_tokens <= 0:
return AddReqResult.OTHER

# Chunked prefill
1 change: 1 addition & 0 deletions python/sglang/srt/mem_cache/allocator.py
@@ -558,6 +558,7 @@ def free(self, free_index: torch.Tensor):
self.release_pages = torch.cat((free_page_indices, self.release_pages))
else:
self.free_pages = torch.cat((free_page_indices, self.free_pages))
self.free_pages = torch.unique(self.free_pages)
else:
self.free_group.append(free_index)
