
[DSR1] Added MLA test#2100

Merged
nvmbreughe merged 4 commits into flashinfer-ai:main from nvmbreughe:mbreughe/dsr1_mla_test
Nov 19, 2025

Conversation

@nvmbreughe
Contributor

@nvmbreughe nvmbreughe commented Nov 17, 2025

📌 Description

Added a DSR1 MLA test and split up the trtllm_batch_decode_mla function.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Tests
    • Improved test suite for batch decoding by making maximum sequence length configurable, adding parameterized runs across short and long lengths, and introducing a compatibility wrapper to preserve legacy behavior. This enhances coverage and validation across varied sequence-length scenarios.

@coderabbitai
Contributor

coderabbitai bot commented Nov 17, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

The test in tests/attention/test_trtllm_gen_mla.py was refactored: the core test now accepts a MAX_SEQ_LEN parameter, the original test became a wrapper calling the core with MAX_SEQ_LEN=1024, and a new parameterized test runs the core for MAX_SEQ_LEN values 1024 and 8960.

Changes

Cohort / File(s) Summary
Test function refactor & parameterization
tests/attention/test_trtllm_gen_mla.py
Renamed core test to trtllm_batch_decode_mla and added MAX_SEQ_LEN: int parameter; added wrapper test_trtllm_batch_decode_mla() that calls core with MAX_SEQ_LEN=1024; added test_dsr1_trtllm_mla which parameterizes MAX_SEQ_LEN over [1024, 8960].
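The refactor described above can be sketched as follows. This is a minimal stand-in: the real helper in tests/attention/test_trtllm_gen_mla.py takes many more parameters (batch_size, scale, dtype, page_size, q_len_per_request, dynamic_scale, enable_pdl, backend), and the PR names the parameter MAX_SEQ_LEN; the lowercase max_seq_len here follows the reviewers' PEP 8 suggestion.

```python
import pytest


def trtllm_batch_decode_mla(max_seq_len: int) -> int:
    # Hypothetical stand-in body: the real helper builds KV caches of up to
    # max_seq_len tokens and checks the decode output against a reference.
    assert max_seq_len > 0
    return max_seq_len


def test_trtllm_batch_decode_mla():
    # Compatibility wrapper preserving the original test's behavior.
    assert trtllm_batch_decode_mla(max_seq_len=1024) == 1024


@pytest.mark.parametrize("max_seq_len", [1024, 8960])
def test_dsr1_trtllm_mla(max_seq_len: int):
    # New DSR1 test: runs the shared core for both the original short
    # length (1024) and the long DSR1-relevant length (8960).
    assert trtllm_batch_decode_mla(max_seq_len=max_seq_len) == max_seq_len
```

Extracting the shared core this way keeps the legacy test name (and its CI coverage) intact while letting new tests reuse the same logic with different sequence lengths.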

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~8 minutes

  • Single test file changed with a consistent refactor and added parameterization.
  • Check call sites in the same file and any test discovery/CI implications.

Possibly related PRs

  • feat: add xqa mla backend #2053: Modifies the same test file and related test function signatures, indicating related changes to test_trtllm_batch_decode_mla/trtllm_batch_decode_mla.

Suggested reviewers

  • cyx-6
  • wenscarl
  • bkryu

Poem

🐰 Hop, hop, the tests now bend and flex,
With MAX_SEQ_LEN joining the specs,
One wraps at 1024, another roams far,
Parametrized hops reach 8960 stars!

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 0.00%, which is below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.

✅ Passed checks (2 passed)

  • Title check: ✅ Passed. The title '[DSR1] Added MLA test' is specific and directly describes the main changes: adding a DSR1 MLA test and refactoring the trtllm_batch_decode_mla function.
  • Description check: ✅ Passed. The description uses the required template and includes a clear summary of changes. The Description section explains what was added and why. The pre-commit and test checklists are mostly completed with relevant items checked.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 936fc69 and 3615d50.

📒 Files selected for processing (1)
  • tests/attention/test_trtllm_gen_mla.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/attention/test_trtllm_gen_mla.py (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (1)
  • enable_pdl (220-220)
🔇 Additional comments (3)
tests/attention/test_trtllm_gen_mla.py (3)

12-22: Refactoring approach looks good.

The extraction of the core test logic into a reusable helper function is a clean approach that enables testing with different MAX_SEQ_LEN values while maintaining backward compatibility through the wrapper.

Note: A previous review comment correctly identified that MAX_SEQ_LEN should follow PEP 8 naming conventions (max_seq_len). Since that feedback is already provided, I won't duplicate it here.


229-262: LGTM: Wrapper preserves original test coverage.

The test wrapper correctly delegates to the helper function with MAX_SEQ_LEN=1024, maintaining the original test behavior while enabling the refactored structure.


265-298: LGTM: DSR1 test adds appropriate coverage.

The new test correctly exercises the MLA decode path with larger sequence lengths (8960) relevant to DSR1 scenarios. The reduced parameter space (smaller batch_size range, single backend) is reasonable for keeping test execution time manageable.

Note: A previous review comment correctly identified the PEP 8 naming convention issue with MAX_SEQ_LEN. Since that feedback is already provided, I won't duplicate it here.



@gemini-code-assist
Contributor

Summary of Changes

Hello @nvmbreughe, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the testing infrastructure for TRT-LLM batch decode MLA by refactoring an existing test into a more modular and reusable helper function. This change facilitates the addition of a new, dedicated test for DSR1 MLA, enabling comprehensive validation across a wider range of sequence lengths and batch configurations. The primary goal is to ensure the robustness and correctness of the DSR1 MLA implementation within the TRT-LLM framework.

Highlights

  • Test Refactoring: The original test_trtllm_batch_decode_mla function has been refactored into a reusable helper function named trtllm_batch_decode_mla, removing its direct pytest.mark.parametrize decorators.
  • New DSR1 MLA Test: A new test function, test_dsr1_trtllm_mla, has been introduced to specifically validate DSR1 MLA (Multi-head Latent Attention) functionality across different MAX_SEQ_LEN values (1024 and 8960) and varying batch sizes.
  • Parameterization Update: The MAX_SEQ_LEN parameter is now explicitly passed to the core trtllm_batch_decode_mla function, allowing for more flexible testing scenarios.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review (/gemini review): Performs a code review for the current pull request in its current state.
  • Pull Request Summary (/gemini summary): Provides a summary of the current pull request in its current state.
  • Comment (@gemini-code-assist): Responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help (/gemini help): Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the test_trtllm_batch_decode_mla test by splitting it into a helper function and a test wrapper. It also introduces a new test, test_dsr1_trtllm_mla, to validate Multi-head Latent Attention (MLA) with different sequence lengths. The changes are well-structured and improve the test suite's coverage and maintainability. My feedback focuses on adhering to Python's naming conventions for consistency and readability.

```python
    dynamic_scale: bool,
    enable_pdl: bool,
    backend: str,
    MAX_SEQ_LEN: int,
```

Severity: medium

Per PEP 8, function parameter names should be in lowercase_with_underscores. Please rename MAX_SEQ_LEN to max_seq_len for consistency. You will also need to update its usage within the function body.

Suggested change:

```diff
-    MAX_SEQ_LEN: int,
+    max_seq_len: int,
```
Comment on lines +295 to +317
```python
@pytest.mark.parametrize("MAX_SEQ_LEN", [1024, 8960])
def test_dsr1_trtllm_mla(
    batch_size: int,
    scale: float,
    dtype: torch.dtype,
    page_size: int,
    q_len_per_request: int,
    dynamic_scale: bool,
    enable_pdl: bool,
    backend: str,
    MAX_SEQ_LEN: int,
):
    trtllm_batch_decode_mla(
        batch_size,
        scale,
        dtype,
        page_size,
        q_len_per_request,
        dynamic_scale,
        enable_pdl,
        backend,
        MAX_SEQ_LEN,
    )
```

Severity: medium

Following PEP 8 style guidelines, parameter and variable names should be in lowercase_with_underscores. Please rename MAX_SEQ_LEN to max_seq_len in the parametrize decorator, the test function signature, and the call to the helper function.

```python
@pytest.mark.parametrize("max_seq_len", [1024, 8960])
def test_dsr1_trtllm_mla(
    batch_size: int,
    scale: float,
    dtype: torch.dtype,
    page_size: int,
    q_len_per_request: int,
    dynamic_scale: bool,
    enable_pdl: bool,
    backend: str,
    max_seq_len: int,
):
    trtllm_batch_decode_mla(
        batch_size,
        scale,
        dtype,
        page_size,
        q_len_per_request,
        dynamic_scale,
        enable_pdl,
        backend,
        max_seq_len,
    )
```

Collaborator

@yzh119 yzh119 left a comment


LGTM, cc @qsang-nv for viz.

@yzh119
Collaborator

yzh119 commented Nov 18, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !145 has been created, and the CI pipeline #38686503 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #38686503: 6/18 passed

@yzh119
Collaborator

yzh119 commented Nov 18, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !145 has been updated with latest changes, and the CI pipeline #38743391 is currently running. I'll report back once the pipeline job completes.

@nvmbreughe nvmbreughe enabled auto-merge (squash) November 18, 2025 23:12
@nvmbreughe nvmbreughe merged commit 219592b into flashinfer-ai:main Nov 19, 2025
4 checks passed
BingooYang pushed a commit to BingooYang/flashinfer that referenced this pull request Mar 13, 2026
Co-authored-by: Zihao Ye <expye@outlook.com>
