
[perf][refactor] Refactor and optimize sfa_v1.py for dsv3.2/glm5#6874

Merged
wangxiyuan merged 4 commits into vllm-project:main from rjg-lyh:pr-sfa-multistream
Mar 5, 2026

Conversation

@rjg-lyh
Collaborator

@rjg-lyh rjg-lyh commented Feb 28, 2026

What this PR does / why we need it?

This PR refactors sfa_v1.py to improve code readability and usability, fixes a bug (a duplicated write to k_cache), and improves performance by replacing certain operators.

changes

  • improve code readability: restructures parts of sfa_v1.py, adds supplementary comments to key code blocks, removes unused variables, and improves the naming of certain functions and variables.

  • resolve a duplicated write to k_cache: k_cache was written twice in the indexer_select module (once in the forward function and again in indexer_select_post_process); removing the redundant write improves performance to some extent.

  • replace scatter ops with reshape_and_cache: replaces two separate cache-store operations, on k_nope and k_pe, with a single call to the reshape_and_cache operator. The original scatter operator reorders slot_mapping for generality, which introduces significant scalar computation; reshape_and_cache skips this redundant reordering, reducing unnecessary computation time and improving operator performance.
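The scatter-vs-fused idea in the last bullet can be sketched in plain Python. This is illustrative only: scatter_write and fused_reshape_and_cache below are hypothetical stand-ins for the real vllm_ascend operators, which operate on NPU tensors with a paged KV-cache layout.

```python
def scatter_write(cache, values, slot_mapping, offset, width):
    # Old path: one scatter per tensor; with two tensors (k_nope, k_pe) the
    # slot bookkeeping for slot_mapping is paid twice.
    for i, slot in enumerate(slot_mapping):
        cache[slot][offset:offset + width] = values[i]

def fused_reshape_and_cache(cache, k_nope, k_pe, slot_mapping):
    # New path: a single pass writes both k_nope and k_pe for each token,
    # resolving each slot only once.
    for i, slot in enumerate(slot_mapping):
        n = len(k_nope[i])
        cache[slot][:n] = k_nope[i]
        cache[slot][n:n + len(k_pe[i])] = k_pe[i]

# Toy data: 3 tokens, k_nope width 4, k_pe width 2, a cache with 8 slots.
k_nope = [[float(t)] * 4 for t in range(3)]
k_pe = [[float(t) + 0.5] * 2 for t in range(3)]
slot_mapping = [5, 0, 2]

cache_old = [[0.0] * 6 for _ in range(8)]
scatter_write(cache_old, k_nope, slot_mapping, 0, 4)
scatter_write(cache_old, k_pe, slot_mapping, 4, 2)

cache_new = [[0.0] * 6 for _ in range(8)]
fused_reshape_and_cache(cache_new, k_nope, k_pe, slot_mapping)

assert cache_old == cache_new  # same final layout, half the scatter passes
```

Both paths produce an identical cache; the fused variant just avoids the second pass over slot_mapping, which is where the real operator saves scalar work.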

performance comparison

4*A3, 1P1D, P dp2tp16, D dp8tp4, input/output: 64K/3K
origin:
TTFT: 28s, TPOT: 26ms, TPS: 820 token/s

fixed redundant double writes of k_cache:
TTFT: 24s, TPOT: 26ms, TPS: 840 token/s

replace scatter ops with reshape_and_cache:
TTFT: 24s, TPOT: 26ms, TPS: 850 token/s
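The relative gains implied by the numbers above can be computed directly (same units as reported):

```python
# Percentage improvements from origin to the final reshape_and_cache variant.
ttft_gain = (28 - 24) / 28 * 100    # TTFT: 28 s -> 24 s
tps_gain = (850 - 820) / 820 * 100  # TPS: 820 -> 850 token/s

print(f"TTFT reduced by {ttft_gain:.1f}%")  # ~14.3%
print(f"TPS improved by {tps_gain:.1f}%")   # ~3.7%
```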

Does this PR introduce any user-facing change?

No.

How was this patch tested?

CI passed with newly added and existing tests.

@github-actions
Contributor

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write a commit message that fulfills the PR description, to help reviewers and future developers understand the change.

If CI fails, you can run the linting and testing checks locally according to the Contributing and Testing guides.

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces multistream overlap support for SFA models, specifically targeting dsv3.2/glm5. It refactors the attention mechanism implementation by clarifying layer sharding configurations and streamlining data preprocessing, specifically by removing conditional all-gather operations to optimize performance and simplify the codebase.

Highlights

  • Refactored Layer Sharding Logic: The flag enable_dsa_cp_prefill_only was renamed to enable_layer_sharding_in_dsa_cp and its usage updated across the SFA implementation, improving clarity and scope of the layer sharding feature within DSA-CP.
  • Simplified Data Preprocessing: The need_gather_q_kv parameter and all associated torch.ops.vllm.maybe_all_gather_and_maybe_unpad calls were removed, streamlining the data flow for Q/K/V projections and weights within the attention mechanism.
  • Updated MLAPO Comments: Comments related to the MLAPO operator were clarified, detailing its function in fusing Q/K/V pre-processing steps and noting its limitations regarding input tokens and conflicts with FlashComm.


Changelog
  • vllm_ascend/attention/sfa_v1.py
    • Updated comments for the enable_mlapo configuration.
    • Renamed the enable_dsa_cp_prefill_only flag to enable_layer_sharding_in_dsa_cp and adjusted its usage throughout the class.
    • Removed the need_gather_q_kv parameter from several function signatures, including _sfa_preprocess_decode, indexer_select_pre_process, and indexer_select_post_process.
    • Eliminated conditional torch.ops.vllm.maybe_all_gather_and_maybe_unpad calls that previously depended on the need_gather_q_kv flag.
    • Adjusted the conditional logic for layer sharding setup and weight processing to utilize the newly named enable_layer_sharding_in_dsa_cp flag.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the AscendSFAImpl in vllm_ascend/attention/sfa_v1.py. The changes primarily involve renaming enable_dsa_cp_prefill_only to enable_layer_sharding_in_dsa_cp for improved clarity, and removing the need_gather_q_kv parameter along with its associated logic from several methods. These modifications enhance code readability and maintainability. My review includes one suggestion to complete the refactoring by removing a now-unused parameter from a method signature.

Comment thread vllm_ascend/attention/sfa_v1.py Outdated
Comment on lines -757 to -758
actual_seq_lengths_query = attn_metadata.cum_query_lens
actual_seq_lengths_key = attn_metadata.seq_lens
if self.enable_dsa_cp:
need_gather_q_kv = False
Contributor


Severity: high

With the removal of this logic, the need_gather_q_kv parameter is no longer used within the forward method. To improve code clarity and maintainability, it should also be removed from the method's signature.
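The suggestion amounts to a minimal before/after change like the following sketch (hypothetical toy signatures; the real forward in sfa_v1.py takes many more parameters):

```python
# Before: need_gather_q_kv is still accepted but never read after the refactor.
def forward_before(q, kv, need_gather_q_kv=False):
    return q, kv  # need_gather_q_kv is silently ignored

# After: the dead parameter is dropped, so callers can no longer rely on it.
def forward_after(q, kv):
    return q, kv

assert forward_before(1, 2) == forward_after(1, 2)
```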

@rjg-lyh rjg-lyh force-pushed the pr-sfa-multistream branch 3 times, most recently from 6222e0d to b4190fe Compare March 2, 2026 09:17
@rjg-lyh rjg-lyh changed the title from "[main][feature] multistream support for dsv3.2/glm5" to "[main][refactor] sfa_v1.py refactor for dsv3.2/glm5" Mar 2, 2026
@rjg-lyh rjg-lyh force-pushed the pr-sfa-multistream branch 2 times, most recently from fbf9816 to 68324cd Compare March 2, 2026 10:30
@rjg-lyh rjg-lyh requested a review from wangxiyuan as a code owner March 2, 2026 12:27
@github-actions
Contributor

github-actions bot commented Mar 2, 2026

This pull request has conflicts, please resolve those before we can evaluate the pull request.

@rjg-lyh rjg-lyh force-pushed the pr-sfa-multistream branch from 877a15d to c5870ed Compare March 3, 2026 01:54
Comment thread vllm_ascend/attention/sfa_v1.py Outdated
)

k = self._get_full_kv(k, attn_metadata)
if kv_cache is not None:
Contributor


This should also be executed in MLAPO case.

Collaborator Author


OK, I understand now.

@rjg-lyh rjg-lyh force-pushed the pr-sfa-multistream branch 2 times, most recently from 77da13a to 0eb68d9 Compare March 3, 2026 02:21
@rjg-lyh rjg-lyh force-pushed the pr-sfa-multistream branch 2 times, most recently from 834f333 to 8cda6b3 Compare March 3, 2026 02:37
@rjg-lyh rjg-lyh changed the title from "[main][refactor] sfa_v1.py refactor for dsv3.2/glm5" to "[perf][refactor] sfa_v1.py refactor and optim for dsv3.2/glm5" Mar 3, 2026
@rjg-lyh rjg-lyh changed the title from "[perf][refactor] sfa_v1.py refactor and optim for dsv3.2/glm5" to "[perf][refactor] Refactor and optimize sfa_v1.py for dsv3.2/glm5" Mar 3, 2026
@rjg-lyh rjg-lyh force-pushed the pr-sfa-multistream branch 5 times, most recently from b7db6c4 to 388f38a Compare March 3, 2026 08:29
rjg-lyh added 3 commits March 3, 2026 16:29
Signed-off-by: rjg-lyh <1318825571@qq.com>
Signed-off-by: rjg-lyh <1318825571@qq.com>
@rjg-lyh rjg-lyh force-pushed the pr-sfa-multistream branch 6 times, most recently from 0fbc470 to e4c7c99 Compare March 3, 2026 09:50
Signed-off-by: rjg-lyh <1318825571@qq.com>
@rjg-lyh rjg-lyh force-pushed the pr-sfa-multistream branch from e4c7c99 to a4b5285 Compare March 3, 2026 10:06
@rjg-lyh rjg-lyh added the `ready` (read for review) and `ready-for-test` (start test by label for PR) labels Mar 5, 2026
@wangxiyuan wangxiyuan merged commit 2bd9c35 into vllm-project:main Mar 5, 2026
62 of 64 checks passed
ZYang6263 pushed a commit to yydyzr/vllm-ascend that referenced this pull request Mar 5, 2026
…m-project#6874)

### What this PR does / why we need it?
This PR refactors sfa_v1.py to improve code readability and usability,
fixes a code bug, and enhances performance through the replacement of
certain operators.

### changes
- **improve code readability**: Optimizes parts of the code structure in
sfa_v1.py, supplementary comments for key code blocks, removes some
unused variables, and improves the naming of certain functions and
variables.

- **resolved a duplicated double write to k_cache**: Fixed redundant
double writes of k_cache in the indexer_select module (in both the
`forward` function and `indexer_select_post_process`), improving
performance to some extent.

- **replace `scatter` ops with `reshape_and_cache`**: This optimization
replaces two separate cache storage operations on `k_nope` and `k_pe`
with a single call to the `reshape_and_cache` operator, improving
performance. The original `scatter` operator involves reordering
slot_mapping for generality, introducing significant scalar
computations. In contrast, the `reshape_and_cache` operator eliminates
this redundant reordering step, thus reducing unnecessary computation
time and enhancing the operator's performance.

### performance comparison
4*A3, 1P1D, P dp2tp16, D dp8tp4, input/output: 64K/3K
origin:
TTFT: **28s**, TPOT: 26ms, TPS: **820 token/s**

fixed redundant double writes of k_cache:
TTFT: **24s**, TPOT: 26ms, TPS: **840 token/s**

replace scatter ops with reshape_and_cache:
TTFT: **24s**, TPOT: 26ms, TPS: **850 token/s**

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with new added/existing test.

- vLLM version: v0.16.0
- vLLM main:
vllm-project/vllm@15d76f7

---------

Signed-off-by: rjg-lyh <1318825571@qq.com>
ZYang6263 added a commit to yydyzr/vllm-ascend that referenced this pull request Mar 5, 2026
[perf][refactor] Refactor and optimize sfa_v1.py for dsv3.2/glm5 (vllm-project#6874)
liuchenbing2026 pushed a commit to liuchenbing2026/vllm-ascend that referenced this pull request Mar 5, 2026
liuchenbing2026 pushed a commit to liuchenbing2026/vllm-ascend that referenced this pull request Mar 5, 2026
LCAIZJ pushed a commit to LCAIZJ/vllm-ascend that referenced this pull request Mar 7, 2026
yydyzr added a commit to yydyzr/vllm-ascend that referenced this pull request Mar 9, 2026
yydyzr added a commit to yydyzr/vllm-ascend that referenced this pull request Mar 9, 2026
Revert "[perf][refactor] Refactor and optimize sfa_v1.py for dsv3.2/glm5 (vllm-project#6874)"
liuchenbing2026 pushed a commit to liuchenbing2026/vllm-ascend that referenced this pull request Mar 10, 2026
yydyzr pushed a commit to yydyzr/vllm-ascend that referenced this pull request Mar 10, 2026
yydyzr added a commit to yydyzr/vllm-ascend that referenced this pull request Mar 10, 2026
[perf][refactor] Refactor and optimize sfa_v1.py for dsv3.2/glm5 (vllm-project#6874)

Labels

`ready` (read for review), `ready-for-test` (start test by label for PR)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants