
unittest: Add SM arch checks to skip unsupported tests on Hopper #1998

Merged
bkryu merged 1 commit into flashinfer-ai:main from bkryu:test_script_sm_check
Oct 28, 2025

Conversation


@bkryu bkryu commented Oct 28, 2025

📌 Description

A number of unit tests fail on Hopper because they either lack a support check entirely or skip based on "what is not supported" while omitting SM90 from that list. This PR adds checks based on "what is supported" and skips any test whose GPU is not in the supported list of SMs.

A special case is mm_fp4, where mm_fp4.is_backend_supported(backend, compute_capability_number) now exists and is used to skip tests when the backend is unsupported for the given compute capability.

Impacted tests:

  • tests/attention/test_trtllm_gen_attention.py
  • tests/attention/test_trtllm_gen_mla.py
  • tests/gemm/test_bmm_fp8.py
  • tests/gemm/test_mm_fp4.py
  • tests/gemm/test_groupwise_scaled_gemm_fp8.py
  • tests/gemm/test_groupwise_scaled_gemm_mxfp4.py
  • tests/moe/test_trtllm_gen_fused_moe.py
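The "what is supported" pattern described above can be sketched as an allow-list check. This is a hypothetical minimal example: FlashInfer's `get_compute_capability` utility returns a `(major, minor)` tuple, but the helper names and SM lists below are illustrative, not the actual test code.

```python
import pytest


def should_skip(cc_major: int, supported_majors: list[int]) -> bool:
    """Allow-list check: skip unless the SM major version is supported."""
    return cc_major not in supported_majors


# Hopper is SM90 (major 9); the trtllm-gen tests allow only major 10
# (SM100/SM103), so they now skip cleanly on Hopper instead of failing.
def gate_test(cc_major: int) -> None:
    if should_skip(cc_major, supported_majors=[10]):
        pytest.skip(f"Not supported on SM{cc_major}x; requires SM100/SM103")
```

Compared to a deny-list ("skip if SM110/SM120/SM121"), the allow-list automatically covers architectures that were never enumerated, such as SM90.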

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes


bkryu commented Oct 28, 2025

/bot run

@bkryu bkryu marked this pull request as ready for review October 28, 2025 18:35

coderabbitai bot commented Oct 28, 2025

Walkthrough

This pull request restricts GPU compute capability filters across multiple test suites, narrowing test execution from a broader set of architectures to SM100/SM103 GPUs only. Changes include replacing skip conditions and adding new capability guards consistently across attention, GEMM, and MOE test modules.

Changes

Cohort / File(s) Summary
Attention Tests
tests/attention/test_trtllm_gen_attention.py, tests/attention/test_trtllm_gen_mla.py
Replaced GPU compute capability guards: instead of excluding SM110/SM120/SM121, these tests now skip unless the GPU is SM100/SM103.
GEMM Tests: Core Logic
tests/gemm/test_bmm_fp8.py, tests/gemm/test_mm_fp4.py
Introduced compute capability variables and added skip conditions for unsupported compute capabilities: for the cutlass backend in bmm_fp8, skip when the SM major version is not in [10, 11, 12]; for mm_fp4, skip when the backend is unsupported for the given compute capability.
GEMM Tests: Groupwise
tests/gemm/test_groupwise_scaled_gemm_fp8.py, tests/gemm/test_groupwise_scaled_gemm_mxfp4.py
Added runtime GPU capability checks early in test execution; skip tests for unsupported architectures. Blockscale test requires SM100/103, 110, or 120/121; groupwise tests restrict to SM100/103 for cutlass backend; mxfp4 test restricts to SM100/103 only.
MOE Tests
tests/moe/test_trtllm_gen_fused_moe.py
Updated skip condition in test_moe_quantization_classes from excluding SM110/SM120/SM121 to allowing only SM100/SM103 GPUs.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Areas requiring attention:

  • Verify SM compute capability constants are correctly mapped (SM100/SM103 → [10], SM110 → [11], SM120/SM121 → [12])
  • Cross-check that xfail behaviors (particularly in test_bmm_fp8) are preserved as intended for specific SM versions
  • Ensure skip messages are descriptive and accurate across all test files
  • Validate control flow in multi-test files like test_groupwise_scaled_gemm_fp8, where capability checks are positioned correctly relative to backend-specific logic
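The mapping the first bullet asks reviewers to verify follows directly from CUDA's `(major, minor)` compute capability tuple. A quick sanity check, using a hypothetical helper name:

```python
def cc_to_sm_name(major: int, minor: int) -> str:
    """Format a CUDA compute capability tuple as an SM architecture name."""
    return f"SM{major}{minor}"


# The guards in this PR compare only the major version, so SM100 and SM103
# both land in the allowed bucket [10], while SM110 maps to [11] and
# SM120/SM121 map to [12]. Hopper (SM90) falls outside all three.
```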

Suggested reviewers

  • cyx-6
  • yzh119

Poem

🐰 GPUs align, a narrower range,
SM100, SM103—no need to change!
Tests now focus, skip the rest with grace,
Compute caps dance in a tighter space! 🎯

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)

  • Description Check (⚠️ Warning): The PR description is almost entirely empty, consisting only of the template structure with placeholder comments and no substantive content. The critical sections (Description, Related Issues, and the verification checklist items) are unfilled or unchecked, providing no information about what the PR does, why the changes are needed, which issues are addressed, or whether pre-commit checks and tests have been verified. This fails to meet the basic requirements of a complete PR description.
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 10.00%, which is below the required threshold of 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.

✅ Passed checks (1 passed)

  • Title Check (✅ Passed): The PR title "unittest: Add SM arch checks to skip unsupported tests on Hopper" accurately and concisely describes the main objective of the changeset. The raw summary shows that GPU compute capability checks have been added across multiple test files to skip tests on unsupported GPU architectures, which aligns directly with the title's description of adding SM (streaming multiprocessor) architecture checks. The title is specific, clear, and avoids vague terminology, making it useful for scanning PR history.

@bkryu bkryu self-assigned this Oct 28, 2025
@flashinfer-bot
Copy link
Collaborator

GitLab MR !95 has been created, and the CI pipeline #37467216 is currently running. I'll report back once the pipeline job completes.

@bkryu bkryu assigned bkryu and unassigned bkryu Oct 28, 2025
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
tests/attention/test_trtllm_gen_attention.py (1)

635-637: Consider moving the SM check immediately after compute_capability retrieval.

For consistency with test_trtllm_batch_prefill (Lines 351-353) and to fail fast, consider placing the SM architecture check directly after Line 635 without the intervening blank line. This ensures early exit before any subsequent validation logic.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8ad64e2 and eafc9a9.

📒 Files selected for processing (7)
  • tests/attention/test_trtllm_gen_attention.py (2 hunks)
  • tests/attention/test_trtllm_gen_mla.py (1 hunks)
  • tests/gemm/test_bmm_fp8.py (1 hunks)
  • tests/gemm/test_groupwise_scaled_gemm_fp8.py (4 hunks)
  • tests/gemm/test_groupwise_scaled_gemm_mxfp4.py (1 hunks)
  • tests/gemm/test_mm_fp4.py (1 hunks)
  • tests/moe/test_trtllm_gen_fused_moe.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
tests/gemm/test_bmm_fp8.py (1)
flashinfer/utils.py (1)
  • get_compute_capability (251-254)
tests/gemm/test_mm_fp4.py (2)
flashinfer/gemm.py (1)
  • mm_fp4 (1858-2009)
flashinfer/utils.py (1)
  • is_backend_supported (953-964)
tests/gemm/test_groupwise_scaled_gemm_fp8.py (1)
flashinfer/utils.py (1)
  • get_compute_capability (251-254)
🔇 Additional comments (12)
tests/attention/test_trtllm_gen_attention.py (2)

352-353: Verify SM architecture restriction is intentional.

The test now only runs on SM100/SM103 GPUs (compute_capability[0] == 10), excluding all other architectures. This is a significant restriction compared to the previous logic that excluded only specific SM versions. Confirm this aligns with the supported hardware for trtllm_batch_context_with_kv_cache.


920-922: LGTM!

The SM architecture check is correctly placed for early exit, consistent with the pattern in test_trtllm_batch_prefill.

tests/gemm/test_mm_fp4.py (1)

29-34: LGTM!

Good addition of backend support validation using the centralized is_backend_supported() API. This provides early gating based on compute capability and backend combination, improving test clarity and preventing unsupported configurations from proceeding to more specific validation logic.
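A centralized gate like `is_backend_supported()` can be sketched as a lookup table keyed by backend. The real mapping lives in FlashInfer; the table below is purely illustrative and does not reflect the actual supported backend/SM combinations.

```python
# Illustrative support table, mirroring the allow-list idea the PR applies
# to mm_fp4. These backend names and SM lists are assumptions for the
# sketch, not FlashInfer's real support matrix.
_SUPPORTED_MAJORS = {
    "cudnn": [10, 11, 12],
    "trtllm": [10],
    "cutlass": [10, 12],
}


def is_backend_supported(backend: str, cc_major: int) -> bool:
    """Return True if `backend` supports the given SM major version."""
    return cc_major in _SUPPORTED_MAJORS.get(backend, [])
```

A test would then call `pytest.skip(...)` whenever this returns False, keeping the support logic in one place instead of scattering per-test SM lists.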

tests/attention/test_trtllm_gen_mla.py (1)

36-37: LGTM!

The SM architecture restriction is correctly implemented and consistent with other trtllm attention tests in this PR. The check appropriately gates execution to SM100/SM103 only.

tests/gemm/test_groupwise_scaled_gemm_mxfp4.py (1)

259-262: LGTM!

The SM architecture restriction correctly limits gemm_mxfp4_nt_groupwise to SM100/SM103 GPUs. This aligns with the broader PR pattern of tightening compute capability requirements for specialized kernels.

tests/moe/test_trtllm_gen_fused_moe.py (1)

2038-2039: LGTM!

The SM architecture check appropriately restricts the MoE test to SM100/SM103 GPUs, consistent with the compute capability gating pattern used throughout this PR.

tests/gemm/test_bmm_fp8.py (2)

20-21: LGTM!

Good refactoring to retrieve compute_capability once and reuse it, improving efficiency and readability.


29-32: LGTM!

The cutlass backend SM architecture gating is correctly implemented, restricting execution to SM100/103, SM110, and SM120/121 GPUs. This complements the existing xfail for known SM120/121 issues and provides clear skip messaging for unsupported architectures.
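The skip/xfail interplay this comment describes can be modeled as a small decision function. This is a sketch: the SM lists mirror those quoted in the review, but the helper name and the default behavior for other backends are assumptions.

```python
def bmm_fp8_gate(backend: str, cc_major: int) -> str:
    """Decide a test's disposition: 'skip' for unsupported SMs, 'xfail'
    for known-broken combinations, 'run' otherwise."""
    if backend == "cutlass" and cc_major not in [10, 11, 12]:
        return "skip"   # e.g. Hopper (SM90) skips with a clear message
    if backend == "cutlass" and cc_major == 12:
        return "xfail"  # known SM120/SM121 issue remains an expected failure
    return "run"
```

Ordering matters: the skip must run first so that unsupported architectures never reach the xfail, preserving the existing xfail semantics on SM120/SM121.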

tests/gemm/test_groupwise_scaled_gemm_fp8.py (4)

46-50: LGTM!

The SM architecture check appropriately gates gemm_fp8_nt_blockscaled to SM100/103, SM110, and SM120/121 GPUs with a clear skip message.


91-91: LGTM!

Good practice to retrieve compute_capability early for reuse in subsequent conditional logic.


101-104: LGTM!

The cutlass backend gating for gemm_fp8_nt_groupwise correctly restricts execution to supported SM architectures (SM100/103, SM110, and SM120/121).


157-167: Verify SM110 exclusion is intentional for group_gemm_fp8_nt_groupwise.

Line 164-167 restricts group_gemm_fp8_nt_groupwise to SM100/103 and SM120/121 only, excluding SM110 unlike the other FP8 tests in this file (Lines 47, 102). If this exclusion is intentional due to specific SM110 limitations for grouped GEMMs, consider adding a brief inline comment explaining the rationale.

@bkryu bkryu requested a review from yzh119 October 28, 2025 18:45
@bkryu bkryu merged commit c857f09 into flashinfer-ai:main Oct 28, 2025
4 checks passed
@bkryu bkryu deleted the test_script_sm_check branch October 28, 2025 21:21
@coderabbitai coderabbitai bot mentioned this pull request Dec 4, 2025
@coderabbitai coderabbitai bot mentioned this pull request Jan 30, 2026
@coderabbitai coderabbitai bot mentioned this pull request Feb 11, 2026
BingooYang pushed a commit to BingooYang/flashinfer that referenced this pull request Mar 13, 2026
…shinfer-ai#1998)