[Feature] Ascend support enable-mixed-chunk #12490
Conversation
Added forward_mixed method for mixed forward processing.
Ensure free_pages contains unique sorted values.
Summary of Changes
Hello @MichelleWu351, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enables a 'mixed chunk' processing mode for Ascend devices, allowing for more efficient handling of concurrent prefill and decode operations within a single batch. It involves significant refactoring of the attention backend to support this new mode, alongside updates to batch metadata structures and scheduling policies to correctly manage and process these mixed workloads on NPU hardware.
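At a high level, a mixed batch lays out all prefill tokens ahead of the decode tokens, so the attention backend can split its flattened inputs at that boundary, run each part through the existing extend and decode paths, and concatenate the results in the original order. The minimal sketch below illustrates only that split; forward_mixed_sketch, num_prefill_tokens, run_prefill, and run_decode are illustrative placeholders, not the actual SGLang APIs.

import torch

def forward_mixed_sketch(q, k, v, num_prefill_tokens, run_prefill, run_decode):
    # Prefill tokens come first in the flattened batch; decode tokens follow.
    q_prefill, q_decode = q[:num_prefill_tokens], q[num_prefill_tokens:]
    k_prefill, k_decode = k[:num_prefill_tokens], k[num_prefill_tokens:]
    v_prefill, v_decode = v[:num_prefill_tokens], v[num_prefill_tokens:]

    # Each half reuses the existing single-mode attention path.
    output_prefill = run_prefill(q_prefill, k_prefill, v_prefill)
    output_decode = run_decode(q_decode, k_decode, v_decode)

    # Concatenate back into the original token order.
    return torch.cat([output_prefill, output_decode], dim=0)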
Code Review
This pull request adds support for mixed-chunk processing on Ascend hardware. The changes primarily involve updating attention backends and scheduling logic to handle mixed batches of prefill and decode requests. My review focuses on ensuring the correctness and robustness of these changes. I've identified a potential issue with state management in the new forward_mixed function that could lead to inconsistencies if not handled carefully, and I've also suggested some minor improvements for code clarity and efficiency.
out_cache_loc = forward_batch.out_cache_loc

# Calculates the batch sizes for prefill and decode stages
out_cache_len_prefill = len(out_cache_loc) - forward_batch.running_decode_bs
bs_prefill = forward_batch.batch_size - forward_batch.running_decode_bs

# Splits the input tensors into two parts (prefill and decode stages) based on batch sizes
q_prefill, q_decode = q[:out_cache_len_prefill], q[out_cache_len_prefill:]
k_prefill, k_decode = k[:out_cache_len_prefill], k[out_cache_len_prefill:]
v_prefill, v_decode = v[:out_cache_len_prefill], v[out_cache_len_prefill:]
loc_prefill, loc_decode = (
    out_cache_loc[:out_cache_len_prefill],
    out_cache_loc[out_cache_len_prefill:],
)

forward_batch.out_cache_loc = loc_prefill
if not self.use_mla:
    if self.use_fia:
        extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu
        forward_batch.extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu[
            :bs_prefill
        ]
    else:
        self.forward_metadata.extend_seq_lens_cpu_int = (
            self.forward_metadata.extend_seq_lens_cpu_int_mix[0]
        )
        self.forward_metadata.block_tables = (
            self.forward_metadata.block_tables_mix[0]
        )
        self.forward_metadata.seq_lens_cpu_int = (
            self.forward_metadata.seq_lens_cpu_int_mix[0]
        )
else:
    forward_batch.num_token_non_padded_cpu = forward_batch.prefill_input_ids
    self.forward_metadata.seq_lens_list_cumsum = (
        self.forward_metadata.seq_lens_list_cumsum_mix[0]
    )

# Performs the forward pass for the prefill stage
output_prefill = self.forward_extend(
    q_prefill,
    k_prefill,
    v_prefill,
    layer,
    forward_batch,
    save_kv_cache=save_kv_cache,
    **kwargs,
)

forward_batch.out_cache_loc = loc_decode
self.forward_metadata.block_tables = self.forward_metadata.block_tables_mix[1]
self.forward_metadata.seq_lens_cpu_int = (
    self.forward_metadata.seq_lens_cpu_int_mix[1]
)

# Performs the forward pass for the decode stage
output_decode = self.forward_decode(
    q_decode,
    k_decode,
    v_decode,
    layer,
    forward_batch,
    save_kv_cache=save_kv_cache,
    **kwargs,
)

# Resets forward_metadata and forward_batch properties after processing
forward_batch.out_cache_loc = out_cache_loc
forward_batch.num_token_non_padded_cpu = None
self.forward_metadata.extend_seq_lens_cpu_int = None
self.forward_metadata.seq_lens_list_cumsum = None
self.forward_metadata.block_tables = None
self.forward_metadata.seq_lens_cpu_int = None
if not self.use_mla and self.use_fia:
    forward_batch.extend_seq_lens_cpu = extend_seq_lens_cpu

# Concatenates and returns the outputs from both parts
return torch.cat([output_prefill, output_decode], dim=0)
The function forward_mixed modifies shared state on forward_batch and self.forward_metadata. If an exception occurs during the forward calls, the cleanup code will not be executed, leaving objects in an inconsistent state. This could lead to hard-to-debug errors in subsequent operations. Additionally, there is a potential NameError because extend_seq_lens_cpu is defined within a conditional block but used in the cleanup section. Using a try...finally block and defining the backup variable before the try block will ensure the state is always restored correctly and fix the potential NameError.
out_cache_loc = forward_batch.out_cache_loc
extend_seq_lens_cpu_bak = None
if not self.use_mla and self.use_fia:
    extend_seq_lens_cpu_bak = forward_batch.extend_seq_lens_cpu
try:
    # Calculates the batch sizes for prefill and decode stages
    out_cache_len_prefill = len(out_cache_loc) - forward_batch.running_decode_bs
    bs_prefill = forward_batch.batch_size - forward_batch.running_decode_bs
    # Splits the input tensors into two parts (prefill and decode stages) based on batch sizes
    q_prefill, q_decode = q[:out_cache_len_prefill], q[out_cache_len_prefill:]
    k_prefill, k_decode = k[:out_cache_len_prefill], k[out_cache_len_prefill:]
    v_prefill, v_decode = v[:out_cache_len_prefill], v[out_cache_len_prefill:]
    loc_prefill, loc_decode = (
        out_cache_loc[:out_cache_len_prefill],
        out_cache_loc[out_cache_len_prefill:],
    )
    forward_batch.out_cache_loc = loc_prefill
    if not self.use_mla:
        if self.use_fia:
            forward_batch.extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu[
                :bs_prefill
            ]
        else:
            self.forward_metadata.extend_seq_lens_cpu_int = (
                self.forward_metadata.extend_seq_lens_cpu_int_mix[0]
            )
            self.forward_metadata.block_tables = (
                self.forward_metadata.block_tables_mix[0]
            )
            self.forward_metadata.seq_lens_cpu_int = (
                self.forward_metadata.seq_lens_cpu_int_mix[0]
            )
    else:
        forward_batch.num_token_non_padded_cpu = forward_batch.prefill_input_ids
        self.forward_metadata.seq_lens_list_cumsum = (
            self.forward_metadata.seq_lens_list_cumsum_mix[0]
        )
    # Performs the forward pass for the prefill stage
    output_prefill = self.forward_extend(
        q_prefill,
        k_prefill,
        v_prefill,
        layer,
        forward_batch,
        save_kv_cache=save_kv_cache,
        **kwargs,
    )
    forward_batch.out_cache_loc = loc_decode
    self.forward_metadata.block_tables = self.forward_metadata.block_tables_mix[1]
    self.forward_metadata.seq_lens_cpu_int = (
        self.forward_metadata.seq_lens_cpu_int_mix[1]
    )
    # Performs the forward pass for the decode stage
    output_decode = self.forward_decode(
        q_decode,
        k_decode,
        v_decode,
        layer,
        forward_batch,
        save_kv_cache=save_kv_cache,
        **kwargs,
    )
    # Concatenates and returns the outputs from both parts
    return torch.cat([output_prefill, output_decode], dim=0)
finally:
    # Resets forward_metadata and forward_batch properties after processing
    forward_batch.out_cache_loc = out_cache_loc
    forward_batch.num_token_non_padded_cpu = None
    self.forward_metadata.extend_seq_lens_cpu_int = None
    self.forward_metadata.seq_lens_list_cumsum = None
    self.forward_metadata.block_tables = None
    self.forward_metadata.seq_lens_cpu_int = None
    if not self.use_mla and self.use_fia:
        forward_batch.extend_seq_lens_cpu = extend_seq_lens_cpu_bak

if forward_batch.extend_seq_lens is not None:
    self.forward_metadata.extend_seq_lens_cpu_int_mix = (
        forward_batch.extend_seq_lens.cpu().int()[:bs_prefill],
        forward_batch.extend_seq_lens.cpu().int()[bs_prefill:],
    )
self.forward_metadata.seq_lens_cpu_int_mix = (
    forward_batch.seq_lens_cpu.int()[:bs_prefill],
    forward_batch.seq_lens_cpu.int()[bs_prefill:],
)
self.forward_metadata.seq_lens_list_cumsum_mix = np.cumsum(
    forward_batch.extend_seq_lens_cpu[:bs_prefill]
), np.cumsum(forward_batch.extend_seq_lens_cpu[bs_prefill:])
To improve code clarity and avoid redundant computations, you can store the results of forward_batch.extend_seq_lens.cpu().int(), forward_batch.seq_lens_cpu.int(), and forward_batch.extend_seq_lens_cpu in temporary variables before slicing them for the mixed-mode metadata.
Current code:

if forward_batch.extend_seq_lens is not None:
    self.forward_metadata.extend_seq_lens_cpu_int_mix = (
        forward_batch.extend_seq_lens.cpu().int()[:bs_prefill],
        forward_batch.extend_seq_lens.cpu().int()[bs_prefill:],
    )
self.forward_metadata.seq_lens_cpu_int_mix = (
    forward_batch.seq_lens_cpu.int()[:bs_prefill],
    forward_batch.seq_lens_cpu.int()[bs_prefill:],
)
self.forward_metadata.seq_lens_list_cumsum_mix = np.cumsum(
    forward_batch.extend_seq_lens_cpu[:bs_prefill]
), np.cumsum(forward_batch.extend_seq_lens_cpu[bs_prefill:])

Suggested code:

if forward_batch.extend_seq_lens is not None:
    extend_seq_lens_cpu_int = forward_batch.extend_seq_lens.cpu().int()
    self.forward_metadata.extend_seq_lens_cpu_int_mix = (
        extend_seq_lens_cpu_int[:bs_prefill],
        extend_seq_lens_cpu_int[bs_prefill:],
    )
seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
self.forward_metadata.seq_lens_cpu_int_mix = (
    seq_lens_cpu_int[:bs_prefill],
    seq_lens_cpu_int[bs_prefill:],
)
extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu
self.forward_metadata.seq_lens_list_cumsum_mix = (
    np.cumsum(extend_seq_lens_cpu[:bs_prefill]),
    np.cumsum(extend_seq_lens_cpu[bs_prefill:]),
)
    self.release_pages = torch.cat((free_page_indices, self.release_pages))
else:
    self.free_pages = torch.cat((free_page_indices, self.free_pages))
    self.free_pages = torch.unique(torch.sort(self.free_pages)[0])
The line self.free_pages = torch.unique(torch.sort(self.free_pages)[0]) is a bit redundant. torch.unique by default returns sorted unique elements. You can simplify this to self.free_pages = torch.unique(self.free_pages).
Current code:
self.free_pages = torch.unique(torch.sort(self.free_pages)[0])

Suggested code:
self.free_pages = torch.unique(self.free_pages)
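As a quick sanity check (a standalone sketch, not part of this PR), torch.unique already returns its values sorted in ascending order by default, so the extra torch.sort call adds nothing:

import torch

free_pages = torch.tensor([7, 3, 3, 9, 1, 7])

# torch.unique defaults to sorted=True, so the explicit sort is redundant.
via_sort = torch.unique(torch.sort(free_pages)[0])
direct = torch.unique(free_pages)

assert torch.equal(via_sort, direct)  # both are tensor([1, 3, 7, 9])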