
Conversation

@MichelleWu351

No description provided.

@gemini-code-assist
Contributor

Summary of Changes

Hello @MichelleWu351, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enables a 'mixed chunk' processing mode for Ascend devices, allowing for more efficient handling of concurrent prefill and decode operations within a single batch. It involves significant refactoring of the attention backend to support this new mode, alongside updates to batch metadata structures and scheduling policies to correctly manage and process these mixed workloads on NPU hardware.

Highlights

  • Ascend Mixed Chunk Support: Introduced a new forward_mixed method in AscendAttnBackend to handle mixed prefill and decode stages within a single batch, specifically for Ascend devices. This method orchestrates the splitting of input tensors and metadata, calling forward_extend for prefill and forward_decode for decode, then concatenating their outputs.
  • Batch Metadata and Scheduling Updates: Added running_decode_bs and prefill_input_ids to the ScheduleBatch, ModelWorkerBatch, and ForwardBatch classes to carry the information needed for mixed chunk processing (see the sketch after this list). The mix_with_running method in ScheduleBatch was updated to populate these new fields.
  • NPU-Specific Policy Adjustments: Modified the PolicyConfig initialization in schedule_policy.py to conditionally adjust token offsets (rem_input_tokens, rem_total_token_offset, cur_rem_token_offset) when running on Ascend (NPU) devices, ensuring correct scheduling behavior for mixed batches.
  • Memory Management Improvement: Enhanced the free method in allocator.py to ensure that free_pages are unique and sorted after memory deallocation, which can improve memory management efficiency.
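
A minimal sketch of how these new batch fields might look. The class below is heavily abridged and hypothetical; the types are inferred from how forward_mixed (quoted further down) uses the fields, not copied from the sglang source:

    from dataclasses import dataclass
    from typing import Optional

    import torch


    @dataclass
    class ForwardBatch:  # abridged, illustrative layout only
        batch_size: int
        out_cache_loc: torch.Tensor
        # Number of running decode requests folded into this mixed batch; forward_mixed
        # subtracts it from batch_size and len(out_cache_loc) to find the prefill/decode split.
        running_decode_bs: int = 0
        # Prefill-side token information; forward_mixed assigns it to
        # num_token_non_padded_cpu on the MLA path (exact type assumed here).
        prefill_input_ids: Optional[int] = None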
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review (/gemini review): Performs a code review for the current pull request in its current state.
  • Pull Request Summary (/gemini summary): Provides a summary of the current pull request in its current state.
  • Comment (@gemini-code-assist): Responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help (/gemini help): Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@ping1jing2 marked this pull request as draft on November 1, 2025 08:58
@gemini-code-assist bot left a comment

Code Review

This pull request adds support for mixed-chunk processing on Ascend hardware. The changes primarily involve updating attention backends and scheduling logic to handle mixed batches of prefill and decode requests. My review focuses on ensuring the correctness and robustness of these changes. I've identified a potential issue with state management in the new forward_mixed function that could lead to inconsistencies if not handled carefully, and I've also suggested some minor improvements for code clarity and efficiency.

Comment on lines +899 to +976
out_cache_loc = forward_batch.out_cache_loc

# Calculates the batch sizes for prefill and decode stages
out_cache_len_prefill = len(out_cache_loc) - forward_batch.running_decode_bs
bs_prefill = forward_batch.batch_size - forward_batch.running_decode_bs

# Splits the input tensors into two parts (prefill and decode stages) based on batch sizes
q_prefill, q_decode = q[:out_cache_len_prefill], q[out_cache_len_prefill:]
k_prefill, k_decode = k[:out_cache_len_prefill], k[out_cache_len_prefill:]
v_prefill, v_decode = v[:out_cache_len_prefill], v[out_cache_len_prefill:]
loc_prefill, loc_decode = (
    out_cache_loc[:out_cache_len_prefill],
    out_cache_loc[out_cache_len_prefill:],
)

forward_batch.out_cache_loc = loc_prefill
if not self.use_mla:
    if self.use_fia:
        extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu
        forward_batch.extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu[
            :bs_prefill
        ]
    else:
        self.forward_metadata.extend_seq_lens_cpu_int = (
            self.forward_metadata.extend_seq_lens_cpu_int_mix[0]
        )
        self.forward_metadata.block_tables = (
            self.forward_metadata.block_tables_mix[0]
        )
        self.forward_metadata.seq_lens_cpu_int = (
            self.forward_metadata.seq_lens_cpu_int_mix[0]
        )
else:
    forward_batch.num_token_non_padded_cpu = forward_batch.prefill_input_ids
    self.forward_metadata.seq_lens_list_cumsum = (
        self.forward_metadata.seq_lens_list_cumsum_mix[0]
    )

# Performs the forward pass for the prefill stage
output_prefill = self.forward_extend(
    q_prefill,
    k_prefill,
    v_prefill,
    layer,
    forward_batch,
    save_kv_cache=save_kv_cache,
    **kwargs,
)

forward_batch.out_cache_loc = loc_decode
self.forward_metadata.block_tables = self.forward_metadata.block_tables_mix[1]
self.forward_metadata.seq_lens_cpu_int = (
    self.forward_metadata.seq_lens_cpu_int_mix[1]
)

# Performs the forward pass for the decode stage
output_decode = self.forward_decode(
    q_decode,
    k_decode,
    v_decode,
    layer,
    forward_batch,
    save_kv_cache=save_kv_cache,
    **kwargs,
)

# Resets forward_metadata and forward_batch properties after processing
forward_batch.out_cache_loc = out_cache_loc
forward_batch.num_token_non_padded_cpu = None
self.forward_metadata.extend_seq_lens_cpu_int = None
self.forward_metadata.seq_lens_list_cumsum = None
self.forward_metadata.block_tables = None
self.forward_metadata.seq_lens_cpu_int = None
if not self.use_mla and self.use_fia:
    forward_batch.extend_seq_lens_cpu = extend_seq_lens_cpu

# Concatenates and returns the outputs from both parts
return torch.cat([output_prefill, output_decode], dim=0)

Severity: high

The function forward_mixed modifies shared state on forward_batch and self.forward_metadata. If an exception occurs during the forward calls, the cleanup code will not be executed, leaving objects in an inconsistent state. This could lead to hard-to-debug errors in subsequent operations. Additionally, there is a potential NameError because extend_seq_lens_cpu is defined within a conditional block but used in the cleanup section. Using a try...finally block and defining the backup variable before the try block will ensure the state is always restored correctly and fix the potential NameError.

        out_cache_loc = forward_batch.out_cache_loc
        extend_seq_lens_cpu_bak = None
        if not self.use_mla and self.use_fia:
            extend_seq_lens_cpu_bak = forward_batch.extend_seq_lens_cpu

        try:
            # Calculates the batch sizes for prefill and decode stages
            out_cache_len_prefill = len(out_cache_loc) - forward_batch.running_decode_bs
            bs_prefill = forward_batch.batch_size - forward_batch.running_decode_bs

            # Splits the input tensors into two parts (prefill and decode stages) based on batch sizes
            q_prefill, q_decode = q[:out_cache_len_prefill], q[out_cache_len_prefill:]
            k_prefill, k_decode = k[:out_cache_len_prefill], k[out_cache_len_prefill:]
            v_prefill, v_decode = v[:out_cache_len_prefill], v[out_cache_len_prefill:]
            loc_prefill, loc_decode = (
                out_cache_loc[:out_cache_len_prefill],
                out_cache_loc[out_cache_len_prefill:],
            )

            forward_batch.out_cache_loc = loc_prefill
            if not self.use_mla:
                if self.use_fia:
                    forward_batch.extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu[
                        :bs_prefill
                    ]
                else:
                    self.forward_metadata.extend_seq_lens_cpu_int = (
                        self.forward_metadata.extend_seq_lens_cpu_int_mix[0]
                    )
                    self.forward_metadata.block_tables = (
                        self.forward_metadata.block_tables_mix[0]
                    )
                    self.forward_metadata.seq_lens_cpu_int = (
                        self.forward_metadata.seq_lens_cpu_int_mix[0]
                    )
            else:
                forward_batch.num_token_non_padded_cpu = forward_batch.prefill_input_ids
                self.forward_metadata.seq_lens_list_cumsum = (
                    self.forward_metadata.seq_lens_list_cumsum_mix[0]
                )

            # Performs the forward pass for the prefill stage
            output_prefill = self.forward_extend(
                q_prefill,
                k_prefill,
                v_prefill,
                layer,
                forward_batch,
                save_kv_cache=save_kv_cache,
                **kwargs,
            )

            forward_batch.out_cache_loc = loc_decode
            self.forward_metadata.block_tables = self.forward_metadata.block_tables_mix[1]
            self.forward_metadata.seq_lens_cpu_int = (
                self.forward_metadata.seq_lens_cpu_int_mix[1]
            )

            # Performs the forward pass for the decode stage
            output_decode = self.forward_decode(
                q_decode,
                k_decode,
                v_decode,
                layer,
                forward_batch,
                save_kv_cache=save_kv_cache,
                **kwargs,
            )

            # Concatenates and returns the outputs from both parts
            return torch.cat([output_prefill, output_decode], dim=0)
        finally:
            # Resets forward_metadata and forward_batch properties after processing
            forward_batch.out_cache_loc = out_cache_loc
            forward_batch.num_token_non_padded_cpu = None
            self.forward_metadata.extend_seq_lens_cpu_int = None
            self.forward_metadata.seq_lens_list_cumsum = None
            self.forward_metadata.block_tables = None
            self.forward_metadata.seq_lens_cpu_int = None
            if not self.use_mla and self.use_fia:
                forward_batch.extend_seq_lens_cpu = extend_seq_lens_cpu_bak

Comment on lines +133 to +144
if forward_batch.extend_seq_lens is not None:
    self.forward_metadata.extend_seq_lens_cpu_int_mix = (
        forward_batch.extend_seq_lens.cpu().int()[:bs_prefill],
        forward_batch.extend_seq_lens.cpu().int()[bs_prefill:],
    )
self.forward_metadata.seq_lens_cpu_int_mix = (
    forward_batch.seq_lens_cpu.int()[:bs_prefill],
    forward_batch.seq_lens_cpu.int()[bs_prefill:],
)
self.forward_metadata.seq_lens_list_cumsum_mix = np.cumsum(
    forward_batch.extend_seq_lens_cpu[:bs_prefill]
), np.cumsum(forward_batch.extend_seq_lens_cpu[bs_prefill:])

Severity: medium

To improve code clarity and avoid redundant computations, you can store the results of forward_batch.extend_seq_lens.cpu().int(), forward_batch.seq_lens_cpu.int(), and forward_batch.extend_seq_lens_cpu in temporary variables before slicing them for the mixed-mode metadata.

Suggested change
-if forward_batch.extend_seq_lens is not None:
-    self.forward_metadata.extend_seq_lens_cpu_int_mix = (
-        forward_batch.extend_seq_lens.cpu().int()[:bs_prefill],
-        forward_batch.extend_seq_lens.cpu().int()[bs_prefill:],
-    )
-self.forward_metadata.seq_lens_cpu_int_mix = (
-    forward_batch.seq_lens_cpu.int()[:bs_prefill],
-    forward_batch.seq_lens_cpu.int()[bs_prefill:],
-)
-self.forward_metadata.seq_lens_list_cumsum_mix = np.cumsum(
-    forward_batch.extend_seq_lens_cpu[:bs_prefill]
-), np.cumsum(forward_batch.extend_seq_lens_cpu[bs_prefill:])
+if forward_batch.extend_seq_lens is not None:
+    extend_seq_lens_cpu_int = forward_batch.extend_seq_lens.cpu().int()
+    self.forward_metadata.extend_seq_lens_cpu_int_mix = (
+        extend_seq_lens_cpu_int[:bs_prefill],
+        extend_seq_lens_cpu_int[bs_prefill:],
+    )
+seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
+self.forward_metadata.seq_lens_cpu_int_mix = (
+    seq_lens_cpu_int[:bs_prefill],
+    seq_lens_cpu_int[bs_prefill:],
+)
+extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu
+self.forward_metadata.seq_lens_list_cumsum_mix = (
+    np.cumsum(extend_seq_lens_cpu[:bs_prefill]),
+    np.cumsum(extend_seq_lens_cpu[bs_prefill:]),
+)

    self.release_pages = torch.cat((free_page_indices, self.release_pages))
else:
    self.free_pages = torch.cat((free_page_indices, self.free_pages))
    self.free_pages = torch.unique(torch.sort(self.free_pages)[0])

Severity: medium

The line self.free_pages = torch.unique(torch.sort(self.free_pages)[0]) is a bit redundant. torch.unique by default returns sorted unique elements. You can simplify this to self.free_pages = torch.unique(self.free_pages).

Suggested change
-    self.free_pages = torch.unique(torch.sort(self.free_pages)[0])
+    self.free_pages = torch.unique(self.free_pages)
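
As a quick standalone check of the point above (not part of the PR), torch.unique already returns sorted, deduplicated values by default, so the extra torch.sort changes nothing:

    import torch

    # torch.unique defaults to sorted=True, so its output is deduplicated
    # and ascending without a preceding torch.sort.
    free_pages = torch.tensor([7, 3, 3, 1, 7])
    print(torch.unique(free_pages))                 # tensor([1, 3, 7])
    print(torch.unique(torch.sort(free_pages)[0]))  # tensor([1, 3, 7]), identical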

@ping1jing2 closed this on Nov 3, 2025