[Feature] Ascend support enable-mixed-chunk #12490
Changes from all commits: 7efd761, 4160ef3, aa37f8f, c312466, a3b6d7f, ea75e75, 299e639, 11c839b
```diff
@@ -41,37 +41,6 @@ class ForwardMetadata:


 class AscendAttnBackend(AttentionBackend):
-    def gen_attention_mask(self, max_seq_len: int, dtype=torch.float16):
-        mask_flag = torch.tril(
-            torch.ones((max_seq_len, max_seq_len), dtype=torch.bool)
-        ).view(max_seq_len, max_seq_len)
-        mask_flag = ~mask_flag
-        if dtype == torch.float16:
-            mask_value = torch.finfo(torch.float32).min
-        else:
-            mask_value = 1
-        self.mask = (
-            torch.masked_fill(
-                torch.zeros(size=(max_seq_len, max_seq_len)), mask_flag, mask_value
-            )
-            .to(dtype)
-            .to(self.device)
-        )
-        self.mask_len = max_seq_len
-
-    def get_verify_buffers_to_fill_after_draft(self):
-        """
-        Return buffers for verify attention kernels that needs to be filled after draft.
-
-        Typically, these are tree mask and position buffers.
-        """
-        return [None, None]
-
-    def update_verify_buffers_to_fill_after_draft(
-        self, spec_info: SpecInput, cuda_graph_bs: Optional[int]
-    ):
-        pass
-
     def __init__(self, model_runner: ModelRunner):
         super().__init__()
         self.forward_metadata = None
```
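For reference, a standalone sketch of the causal mask that the (relocated) `gen_attention_mask` builds; it mirrors the method's logic on CPU with a small hypothetical `max_seq_len`, omitting only the final move to `self.device`:

```python
import torch

max_seq_len, dtype = 4, torch.float16
# True above the diagonal = future positions that must be masked out
mask_flag = ~torch.tril(torch.ones((max_seq_len, max_seq_len), dtype=torch.bool))
mask_value = torch.finfo(torch.float32).min if dtype == torch.float16 else 1
mask = torch.masked_fill(
    torch.zeros((max_seq_len, max_seq_len)), mask_flag, mask_value
).to(dtype)
print(mask)
# Row i can attend to columns <= i; future positions carry a huge negative
# bias (-inf once the float32 minimum is cast down to float16).
```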
```diff
@@ -106,34 +75,100 @@ def __init__(self, model_runner: ModelRunner):
         self.mtp_mask = torch.tril(torch.ones(2048, 2048, dtype=torch.bool)).npu()
         self.mtp_mask = ~self.mtp_mask

+    def gen_attention_mask(self, max_seq_len: int, dtype=torch.float16):
+        mask_flag = torch.tril(
+            torch.ones((max_seq_len, max_seq_len), dtype=torch.bool)
+        ).view(max_seq_len, max_seq_len)
+        mask_flag = ~mask_flag
+        if dtype == torch.float16:
+            mask_value = torch.finfo(torch.float32).min
+        else:
+            mask_value = 1
+        self.mask = (
+            torch.masked_fill(
+                torch.zeros(size=(max_seq_len, max_seq_len)), mask_flag, mask_value
+            )
+            .to(dtype)
+            .to(self.device)
+        )
+        self.mask_len = max_seq_len
+
+    def get_verify_buffers_to_fill_after_draft(self):
+        """
+        Return buffers for verify attention kernels that needs to be filled after draft.
+
+        Typically, these are tree mask and position buffers.
+        """
+        return [None, None]
+
+    def update_verify_buffers_to_fill_after_draft(
+        self, spec_info: SpecInput, cuda_graph_bs: Optional[int]
+    ):
+        pass
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init the metadata for a forward pass."""
         tp_size = get_attention_tp_size()
         self.forward_metadata = ForwardMetadata()
-        seq_lens_max = forward_batch.seq_lens.max()
-        if forward_batch.forward_mode.is_target_verify():
-            seq_lens_max += self.speculative_num_draft_tokens
-        self.forward_metadata.block_tables = (
-            forward_batch.req_to_token_pool.req_to_token[
-                forward_batch.req_pool_indices, :seq_lens_max
-            ][:, :: self.page_size]
-            // self.page_size
-        )
-        if forward_batch.extend_seq_lens is not None:
-            self.forward_metadata.extend_seq_lens_cpu_int = (
-                forward_batch.extend_seq_lens.cpu().int()
-            )
-        self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
-        if (
-            not forward_batch.forward_mode.is_draft_extend_v2()
-            and not forward_batch.forward_mode.is_draft_extend()
-            and not forward_batch.forward_mode.is_target_verify()
-        ):
-            seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
-            self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum
-
-        if forward_batch.forward_mode.is_target_verify():
-            self.forward_metadata.seq_lens_cpu_int += self.speculative_num_draft_tokens
+        if forward_batch.forward_mode.is_mixed():
+            bs_prefill = forward_batch.batch_size - forward_batch.running_decode_bs
+            seq_lens_max_mix = (
+                forward_batch.seq_lens[:bs_prefill].max(),
+                forward_batch.seq_lens[bs_prefill:].max(),
+            )
+            req_pool_indices_mix = (
+                forward_batch.req_pool_indices[:bs_prefill],
+                forward_batch.req_pool_indices[bs_prefill:],
+            )
+            self.forward_metadata.block_tables_mix = (
+                forward_batch.req_to_token_pool.req_to_token[
+                    req_pool_indices_mix[0], : seq_lens_max_mix[0]
+                ][:, :: self.page_size]
+                // self.page_size,
+                forward_batch.req_to_token_pool.req_to_token[
+                    req_pool_indices_mix[1], : seq_lens_max_mix[1]
+                ][:, :: self.page_size]
+                // self.page_size,
+            )
+            if forward_batch.extend_seq_lens is not None:
+                self.forward_metadata.extend_seq_lens_cpu_int_mix = (
+                    forward_batch.extend_seq_lens.cpu().int()[:bs_prefill],
+                    forward_batch.extend_seq_lens.cpu().int()[bs_prefill:],
+                )
+            self.forward_metadata.seq_lens_cpu_int_mix = (
+                forward_batch.seq_lens_cpu.int()[:bs_prefill],
+                forward_batch.seq_lens_cpu.int()[bs_prefill:],
+            )
+            self.forward_metadata.seq_lens_list_cumsum_mix = np.cumsum(
+                forward_batch.extend_seq_lens_cpu[:bs_prefill]
+            ), np.cumsum(forward_batch.extend_seq_lens_cpu[bs_prefill:])
+        else:
+            seq_lens_max = forward_batch.seq_lens.max()
+            if forward_batch.forward_mode.is_target_verify():
+                seq_lens_max += self.speculative_num_draft_tokens
+            self.forward_metadata.block_tables = (
+                forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices, :seq_lens_max
+                ][:, :: self.page_size]
+                // self.page_size
+            )
+            if forward_batch.extend_seq_lens is not None:
+                self.forward_metadata.extend_seq_lens_cpu_int = (
+                    forward_batch.extend_seq_lens.cpu().int()
+                )
+            self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
+            if (
+                not forward_batch.forward_mode.is_draft_extend_v2()
+                and not forward_batch.forward_mode.is_draft_extend()
+                and not forward_batch.forward_mode.is_target_verify()
+            ):
+                seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
+                self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum
+
+            if forward_batch.forward_mode.is_target_verify():
+                self.forward_metadata.seq_lens_cpu_int += (
+                    self.speculative_num_draft_tokens
+                )

         self.graph_mode = False
```
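The mixed branch above assumes a batch layout in which the prefill requests come first and the `running_decode_bs` decode requests sit at the tail, so every per-request tensor can be split at `bs_prefill`. A toy illustration with hypothetical sizes:

```python
import torch

batch_size, running_decode_bs = 5, 2          # 3 prefill + 2 decode requests
bs_prefill = batch_size - running_decode_bs   # -> 3

seq_lens = torch.tensor([128, 64, 256, 17, 33])
# Per-segment maxima, as used above to bound the block-table slices
seq_lens_max_mix = (seq_lens[:bs_prefill].max(), seq_lens[bs_prefill:].max())
print(seq_lens_max_mix)  # (tensor(256), tensor(33))
```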
```diff
@@ -851,6 +886,95 @@ def forward_decode(
         )
         return attn_output.view(num_tokens, layer.tp_q_head_num * self.kv_lora_rank)

+    def forward_mixed(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        **kwargs,
+    ):
+        out_cache_loc = forward_batch.out_cache_loc
+
+        # Calculates the batch sizes for prefill and decode stages
+        out_cache_len_prefill = len(out_cache_loc) - forward_batch.running_decode_bs
+        bs_prefill = forward_batch.batch_size - forward_batch.running_decode_bs
+
+        # Splits the input tensors into two parts (prefill and decode stages) based on batch sizes
+        q_prefill, q_decode = q[:out_cache_len_prefill], q[out_cache_len_prefill:]
+        k_prefill, k_decode = k[:out_cache_len_prefill], k[out_cache_len_prefill:]
+        v_prefill, v_decode = v[:out_cache_len_prefill], v[out_cache_len_prefill:]
+        loc_prefill, loc_decode = (
+            out_cache_loc[:out_cache_len_prefill],
+            out_cache_loc[out_cache_len_prefill:],
+        )
+
+        forward_batch.out_cache_loc = loc_prefill
+        if not self.use_mla:
+            if self.use_fia:
+                extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu
+                forward_batch.extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu[
+                    :bs_prefill
+                ]
+            else:
+                self.forward_metadata.extend_seq_lens_cpu_int = (
+                    self.forward_metadata.extend_seq_lens_cpu_int_mix[0]
+                )
+            self.forward_metadata.block_tables = (
+                self.forward_metadata.block_tables_mix[0]
+            )
+            self.forward_metadata.seq_lens_cpu_int = (
+                self.forward_metadata.seq_lens_cpu_int_mix[0]
+            )
+        else:
+            forward_batch.num_token_non_padded_cpu = forward_batch.prefill_input_ids
+            self.forward_metadata.seq_lens_list_cumsum = (
+                self.forward_metadata.seq_lens_list_cumsum_mix[0]
+            )
+
+        # Performs the forward pass for the prefill stage
+        output_prefill = self.forward_extend(
+            q_prefill,
+            k_prefill,
+            v_prefill,
+            layer,
+            forward_batch,
+            save_kv_cache=save_kv_cache,
+            **kwargs,
+        )
+
+        forward_batch.out_cache_loc = loc_decode
+        self.forward_metadata.block_tables = self.forward_metadata.block_tables_mix[1]
+        self.forward_metadata.seq_lens_cpu_int = (
+            self.forward_metadata.seq_lens_cpu_int_mix[1]
+        )
+
+        # Performs the forward pass for the decode stage
+        output_decode = self.forward_decode(
+            q_decode,
+            k_decode,
+            v_decode,
+            layer,
+            forward_batch,
+            save_kv_cache=save_kv_cache,
+            **kwargs,
+        )
+
+        # Resets forward_metadata and forward_batch properties after processing
+        forward_batch.out_cache_loc = out_cache_loc
+        forward_batch.num_token_non_padded_cpu = None
+        self.forward_metadata.extend_seq_lens_cpu_int = None
+        self.forward_metadata.seq_lens_list_cumsum = None
+        self.forward_metadata.block_tables = None
+        self.forward_metadata.seq_lens_cpu_int = None
+        if not self.use_mla and self.use_fia:
+            forward_batch.extend_seq_lens_cpu = extend_seq_lens_cpu
+
+        # Concatenates and returns the outputs from both parts
+        return torch.cat([output_prefill, output_decode], dim=0)
```
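Why `len(out_cache_loc) - running_decode_bs` isolates the prefill tokens: each running decode request writes exactly one new token per step, so the tail of `out_cache_loc` holds one slot per decode request and everything before it belongs to prefill. A toy sketch with hypothetical sizes:

```python
import torch

running_decode_bs = 2
out_cache_loc = torch.arange(12)  # e.g. 10 prefill tokens + 2 decode tokens
out_cache_len_prefill = len(out_cache_loc) - running_decode_bs

loc_prefill = out_cache_loc[:out_cache_len_prefill]  # slots for prefill tokens
loc_decode = out_cache_loc[out_cache_len_prefill:]   # one slot per decode request
assert len(loc_decode) == running_decode_bs
```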
|
Comment on lines
+899
to
+976
The function `forward_mixed` temporarily mutates `forward_batch` and `self.forward_metadata`, but only restores them on the success path. If `forward_extend` or `forward_decode` raises, the mutated state leaks into subsequent calls. Consider wrapping the body in a `try`/`finally` so the restoration always runs:

```python
out_cache_loc = forward_batch.out_cache_loc
extend_seq_lens_cpu_bak = None
if not self.use_mla and self.use_fia:
    extend_seq_lens_cpu_bak = forward_batch.extend_seq_lens_cpu
try:
    # Calculates the batch sizes for prefill and decode stages
    out_cache_len_prefill = len(out_cache_loc) - forward_batch.running_decode_bs
    bs_prefill = forward_batch.batch_size - forward_batch.running_decode_bs

    # Splits the input tensors into two parts (prefill and decode stages) based on batch sizes
    q_prefill, q_decode = q[:out_cache_len_prefill], q[out_cache_len_prefill:]
    k_prefill, k_decode = k[:out_cache_len_prefill], k[out_cache_len_prefill:]
    v_prefill, v_decode = v[:out_cache_len_prefill], v[out_cache_len_prefill:]
    loc_prefill, loc_decode = (
        out_cache_loc[:out_cache_len_prefill],
        out_cache_loc[out_cache_len_prefill:],
    )

    forward_batch.out_cache_loc = loc_prefill
    if not self.use_mla:
        if self.use_fia:
            forward_batch.extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu[
                :bs_prefill
            ]
        else:
            self.forward_metadata.extend_seq_lens_cpu_int = (
                self.forward_metadata.extend_seq_lens_cpu_int_mix[0]
            )
        self.forward_metadata.block_tables = (
            self.forward_metadata.block_tables_mix[0]
        )
        self.forward_metadata.seq_lens_cpu_int = (
            self.forward_metadata.seq_lens_cpu_int_mix[0]
        )
    else:
        forward_batch.num_token_non_padded_cpu = forward_batch.prefill_input_ids
        self.forward_metadata.seq_lens_list_cumsum = (
            self.forward_metadata.seq_lens_list_cumsum_mix[0]
        )

    # Performs the forward pass for the prefill stage
    output_prefill = self.forward_extend(
        q_prefill,
        k_prefill,
        v_prefill,
        layer,
        forward_batch,
        save_kv_cache=save_kv_cache,
        **kwargs,
    )

    forward_batch.out_cache_loc = loc_decode
    self.forward_metadata.block_tables = self.forward_metadata.block_tables_mix[1]
    self.forward_metadata.seq_lens_cpu_int = (
        self.forward_metadata.seq_lens_cpu_int_mix[1]
    )

    # Performs the forward pass for the decode stage
    output_decode = self.forward_decode(
        q_decode,
        k_decode,
        v_decode,
        layer,
        forward_batch,
        save_kv_cache=save_kv_cache,
        **kwargs,
    )

    # Concatenates and returns the outputs from both parts
    return torch.cat([output_prefill, output_decode], dim=0)
finally:
    # Resets forward_metadata and forward_batch properties after processing
    forward_batch.out_cache_loc = out_cache_loc
    forward_batch.num_token_non_padded_cpu = None
    self.forward_metadata.extend_seq_lens_cpu_int = None
    self.forward_metadata.seq_lens_list_cumsum = None
    self.forward_metadata.block_tables = None
    self.forward_metadata.seq_lens_cpu_int = None
    if not self.use_mla and self.use_fia:
        forward_batch.extend_seq_lens_cpu = extend_seq_lens_cpu_bak
```
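An alternative to the `try`/`finally` above, shown only as a sketch with a hypothetical `restored_attrs` helper, is a context manager that snapshots the mutated attributes and restores them even when an exception propagates:

```python
import contextlib

@contextlib.contextmanager
def restored_attrs(obj, *names):
    """Snapshot the named attributes, then restore them on exit, even on error."""
    saved = {name: getattr(obj, name) for name in names}
    try:
        yield
    finally:
        for name, value in saved.items():
            setattr(obj, name, value)

# Hypothetical usage inside forward_mixed:
#     with restored_attrs(forward_batch, "out_cache_loc", "extend_seq_lens_cpu"):
#         ...prefill and decode passes...
```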
The hunk's trailing context shows the next class in the file:

```diff
 class AscendAttnMultiStepDraftBackend:
     """
```
```diff
@@ -558,6 +558,7 @@ def free(self, free_index: torch.Tensor):
                 self.release_pages = torch.cat((free_page_indices, self.release_pages))
             else:
                 self.free_pages = torch.cat((free_page_indices, self.free_pages))
+                self.free_pages = torch.unique(torch.sort(self.free_pages)[0])
         else:
             self.free_group.append(free_index)
```

Review comment: the line `self.free_pages = torch.unique(torch.sort(self.free_pages)[0])` contains a redundant sort. `torch.unique` already returns its output in ascending order (its `sorted` argument defaults to `True`), so the inner `torch.sort` does extra work for no effect.

Suggested change:

```diff
-                self.free_pages = torch.unique(torch.sort(self.free_pages)[0])
+                self.free_pages = torch.unique(self.free_pages)
```
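A quick check of the claim that the inner sort is redundant; `torch.unique` sorts its output by default:

```python
import torch

pages = torch.tensor([7, 3, 3, 9, 1])
# torch.unique (sorted=True by default) already returns ascending values,
# so pre-sorting changes nothing:
assert torch.equal(torch.unique(torch.sort(pages)[0]), torch.unique(pages))
print(torch.unique(pages))  # tensor([1, 3, 7, 9])
```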
A further review comment on the mixed-mode metadata:

To improve code clarity and avoid redundant computations, you can store the results of `forward_batch.extend_seq_lens.cpu().int()`, `forward_batch.seq_lens_cpu.int()`, and `forward_batch.extend_seq_lens_cpu` in temporary variables before slicing them for the mixed-mode metadata.
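A minimal sketch of that suggestion (the `split_at` helper is hypothetical; the slicing convention follows the PR's `init_forward_metadata`): compute each conversion once, then slice both segments from the temporary:

```python
import torch

def split_at(values: torch.Tensor, bs_prefill: int):
    """Convert once (one device-to-host copy + cast), then slice twice."""
    cpu_int = values.cpu().int()
    return cpu_int[:bs_prefill], cpu_int[bs_prefill:]

seq_lens = torch.tensor([128, 64, 256, 17, 33])
prefill_lens, decode_lens = split_at(seq_lens, bs_prefill=3)
```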