
bugfix: fix failed unittest test_green_ctx and test_jit_example on spark (sm_121)#1951

Merged
yzh119 merged 6 commits into flashinfer-ai:main from yzh119:fix-spark-unittest
Nov 5, 2025
Conversation

Collaborator

@yzh119 yzh119 commented Oct 20, 2025

📌 Description

There are three failing unittests on Spark (sm_121):

  • tests/utils/test_green_ctx.py
  • tests/utils/test_jit_example.py
  • tests/utils/test_sampling.py

The first fails because Spark has a small number of SMs (48) and there was no guard on green-context splitting.
The second is an unknown issue (logits don't match the reference), probably related to barriers on sm_121; it is marked xfail for now and will be fixed later.

The last one will be fixed by another PR from @bkryu; this PR fixes the first two issues.
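As a rough sketch of the guard described for the first failure (the helper names and the `min_sm`/`alignment` defaults below are illustrative assumptions, not the actual flashinfer API), the pre-check boils down to comparing the SMs a split would require against what the device offers:

```python
import math


def round_up(x: int, alignment: int) -> int:
    """Round x up to the nearest multiple of alignment."""
    return math.ceil(x / alignment) * alignment


def required_sms(num_groups: int, min_count: int,
                 min_sm: int = 8, alignment: int = 8) -> int:
    """SMs a green-context split needs: each group gets max(min_count, min_sm)
    rounded up to the SM-partition alignment. min_sm/alignment are illustrative."""
    return num_groups * round_up(max(min_count, min_sm), alignment)


def should_skip(available_sms: int, num_groups: int, min_count: int) -> bool:
    """True when the device (e.g. Spark's 48 SMs) cannot satisfy the split."""
    return required_sms(num_groups, min_count) > available_sms
```

Under these assumed constraints, a request for 16 groups of at least 3 SMs each still needs 16 × 8 = 128 SMs after rounding, so on a 48-SM device the test skips rather than fails.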

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Tests

    • Tests now pre-check GPU resources and auto-skip with informative messages including available and requested SM counts to avoid spurious failures.
    • Added a conditional xfail for GPUs with compute capability 12.1 to avoid false negatives on that hardware.
    • Tightened a sampling test by adding a relative tolerance for more robust numerical validation.
  • Bug Fixes

    • Improved runtime error handling to surface clearer guidance when GPU SM resources are insufficient.

@gemini-code-assist
Contributor

Summary of Changes

Hello @yzh119, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses and resolves three specific unittest failures encountered on Spark environments (sm_121). It implements targeted adjustments to test logic, introduces conditional test skipping based on available hardware resources, and temporarily marks one test as an expected failure due to a known numerical accuracy issue. The overall goal is to enhance test suite stability and reliability on Spark without compromising the integrity of the tests.

Highlights

  • Green Context Tests Stability: Introduced checks to skip green_ctx tests on Spark environments if the required number of Streaming Multiprocessors (SMs) exceeds the available SMs, preventing failures on devices with limited SMs.
  • JIT Example Test XFAIL: Marked test_dump_logits as an expected failure (xfail) specifically for SM 121 (Spark) due to an unresolved numerical accuracy issue, allowing the CI to pass while deferring a full fix.
  • Sampling Test Tolerance Adjustment: Increased the numerical tolerance for test_softmax in test_sampling.py by adding a relative tolerance (rtol) to torch.allclose, resolving failures caused by reduction size differences.
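A minimal sketch of the conditional-xfail logic from the second highlight (the predicate name and reason string are illustrative; per the changes summary, the real test calls flashinfer.utils.get_compute_capability on torch.device("cuda:0")):

```python
def needs_xfail(compute_capability: tuple) -> bool:
    """SM 12.1 (Spark) currently produces logits that mismatch the reference,
    so test_dump_logits is expected to fail only on that architecture."""
    return compute_capability == (12, 1)


# In the test module this predicate would drive a marker along the lines of:
#   @pytest.mark.xfail(needs_xfail(cc), reason="numerical accuracy on SM 12.1")
```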
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature | Command | Description
Code Review | /gemini review | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state.
Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help | /gemini help | Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@coderabbitai

coderabbitai bot commented Oct 20, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Add CUDA-runtime guards and validation around green-context splitting to detect insufficient SMs and resource errors, raise clearer RuntimeError messages, and have tests skip on those conditions; add an xfail for SM 12.1 in a JIT test and tighten a softmax numeric assertion to include rtol.

Changes

Cohort / File(s) | Summary
Green context logic: flashinfer/green_ctx.py | Added try/except handling around split_device_green_ctx and split_device_green_ctx_by_sm_count to catch CUDA RuntimeError scenarios, re-raise clearer RuntimeError messages including available vs requested SMs and remediation hints, validate/round SM counts, and use split_resource_by_sm_count. No public API signature changes.
Green context tests: tests/utils/test_green_ctx.py | Wrapped calls to split functions in try/except catching RuntimeError; on CUDA resource errors the tests call pytest.skip(...) with device SMs and requested parameters. Applied to multiple tests: creation, kernel execution, split-by-sm-count creation/execution/alignment.
JIT example test: tests/utils/test_jit_example.py | Imported get_compute_capability and added an xfail marker for test_dump_logits when get_compute_capability(cuda:0) == (12, 1) (SM 12.1), due to numerical accuracy differences.
Sampling test: tests/utils/test_sampling.py | Tightened the numeric comparison in test_softmax to torch.allclose(..., atol=1e-5, rtol=1e-5) (added a relative tolerance).
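To illustrate what the sampling-test change does, here is the elementwise criterion torch.allclose applies, reimplemented in plain Python (note that torch.allclose already defaults to rtol=1e-5, a point raised later in this review thread):

```python
def allclose(a, b, rtol=1e-5, atol=1e-8):
    """Mirror of the torch.allclose test: every pair must satisfy
    |x - y| <= atol + rtol * |y|, so rtol scales the tolerance with magnitude."""
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))
```

With atol alone, every element must match to a fixed absolute bound; adding rtol relaxes the bound proportionally for larger magnitudes, which matters when reduction order differs across large softmax rows.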

Sequence Diagram(s)

sequenceDiagram
    participant Test as Test function
    participant GreenSplit as flashinfer.green_ctx
    participant DeviceQuery as runtime/device query
    participant PyTest as pytest

    Test->>GreenSplit: call split_device_green_ctx* (groups/min_count or sm_count)
    GreenSplit->>DeviceQuery: query device SMs / resource info
    DeviceQuery-->>GreenSplit: available_sms
    alt runtime error OR required_sms > available_sms
        GreenSplit-->>Test: raise RuntimeError("insufficient SMs / resource config …")
        Test->>PyTest: catch RuntimeError -> pytest.skip(message with device SMs & params)
    else
        GreenSplit-->>Test: return split contexts
        Test->>Test: run kernels and assertions
    end

    Note right of Test: Separate flow — test_jit_example queries compute capability\nand marks xfail for SM 12.1

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Review correctness of CUDA error parsing and message matching used by tests.
  • Verify rounding/validation logic for SM counts and the mapping to resource splitting.
  • Confirm no behavioral regressions in normal (non-error) split paths and that messages are robust across CUDA driver versions.

Suggested reviewers

  • yongwww
  • cyx-6
  • wenscarl

Poem

🐇 I counted SMs beneath the night,

I hopped where kernels lost their bite.
If cores are few, I skip the test,
Then nibble logs and take a rest.
A tiny thump — all's put to right.

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name | Status | Explanation | Resolution
Docstring Coverage | ⚠️ Warning | Docstring coverage is 28.57%, which is below the required threshold of 80.00%. | Run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name | Status | Explanation
Description check | ✅ Passed | The pull request description clearly explains the three failing tests on spark (sm_121), their root causes, and how each is addressed in the PR.
Title check | ✅ Passed | The title accurately describes the main change: fixing failed unittests on spark (sm_121), which is the primary objective of the PR.


Comment @coderabbitai help to get the list of available commands and usage tips.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses three failing unit tests on Spark (sm_121) by adding a guard for SM availability in test_green_ctx.py, marking a test as xfail in test_jit_example.py due to numerical issues, and increasing the tolerance in test_sampling.py. The changes are correct and effectively fix the described issues. I've provided a couple of suggestions for test_green_ctx.py to improve code clarity and reduce duplication.

Comment on lines +20 to +24
total = 0
for sm_count in sm_counts:
    rounded = round_up(max(sm_count, min_sm), alignment)
    total += rounded
return total


medium

This for-loop can be expressed more concisely using the built-in sum() function with a generator expression. This is a common Python idiom that improves readability.

Suggested change
-    total = 0
-    for sm_count in sm_counts:
-        rounded = round_up(max(sm_count, min_sm), alignment)
-        total += rounded
-    return total
+    return sum(round_up(max(sm_count, min_sm), alignment) for sm_count in sm_counts)


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (2)
tests/utils/test_green_ctx.py (2)

20-24: Consider using built-in sum() for improved readability.

As noted in previous reviews, this for-loop can be expressed more concisely using the built-in sum() function with a generator expression, which is a common Python idiom.

Apply this diff to refactor:

-    total = 0
-    for sm_count in sm_counts:
-        rounded = round_up(max(sm_count, min_sm), alignment)
-        total += rounded
-    return total
+    return sum(round_up(max(sm_count, min_sm), alignment) for sm_count in sm_counts)

36-42: Address the pre-commit formatting failure.

The pipeline indicates a formatting issue that needs to be resolved. Please run pre-commit run --all-files to apply the formatting changes.

Additionally, as noted in previous reviews, this pre-check logic is duplicated across multiple tests. Consider either:

  1. Extracting it into a pytest fixture or helper function
  2. Moving the check into the split_device_green_ctx API itself to raise an exception
🧹 Nitpick comments (1)
tests/utils/test_green_ctx.py (1)

43-45: Prefix unused variable with underscore.

The streams variable is unpacked but never used in this test function. Prefix it with an underscore to indicate it's intentionally unused.

Apply this diff:

-    streams, resources = green_ctx.split_device_green_ctx(
+    _streams, resources = green_ctx.split_device_green_ctx(
         dev, num_groups, min_count
     )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9a65c0e and db585e5.

📒 Files selected for processing (1)
  • tests/utils/test_green_ctx.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/utils/test_green_ctx.py (2)
flashinfer/utils.py (2)
  • get_compute_capability (251-254)
  • get_device_sm_count (595-596)
flashinfer/green_ctx.py (2)
  • get_sm_count_constraint (34-44)
  • split_device_green_ctx (126-178)
🪛 GitHub Actions: pre-commit
tests/utils/test_green_ctx.py

[error] 40-40: ruff-format: 1 file reformatted by this hook. The pre-commit hook failed; please re-run with 'pre-commit run --all-files' to apply formatting changes.


[error] 40-40: Code style formatting changed by ruff-format. Updated call should be: streams, resources = green_ctx.split_device_green_ctx(dev, num_groups, min_count).

🪛 Ruff (0.14.1)
tests/utils/test_green_ctx.py

43-43: Unpacked variable streams is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (6)
tests/utils/test_green_ctx.py (6)

5-5: LGTM!

The imports are necessary for the SM calculation helpers and are correctly placed.


8-13: LGTM!

The helper correctly calculates the total SM count required by rounding up the minimum count to meet alignment requirements and multiplying by the number of groups.


61-67: LGTM!

The pre-check logic correctly validates SM availability before running the test.


97-103: LGTM!

The pre-check correctly uses calculate_required_sms_by_counts to validate SM availability for tests with specific SM counts.


130-136: LGTM!

The pre-check correctly validates SM availability before running the kernel execution test.


165-171: LGTM!

The pre-check correctly validates SM availability before running the alignment test.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
tests/utils/test_green_ctx.py (1)

39-46: Consider consistency in device object creation.

Unlike test_green_ctx_creation (line 15), this test passes torch.device(device) directly without creating a dev variable first. While both approaches work, consistent usage across all tests would improve readability.

Apply this diff for consistency:

+    dev = torch.device(device)
     try:
         streams, resources = green_ctx.split_device_green_ctx(
-            torch.device(device), num_groups, min_count
+            dev, num_groups, min_count
         )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between db585e5 and 89eac51.

📒 Files selected for processing (2)
  • flashinfer/green_ctx.py (3 hunks)
  • tests/utils/test_green_ctx.py (5 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
flashinfer/green_ctx.py (1)
flashinfer/utils.py (3)
  • get_compute_capability (251-254)
  • get_device_sm_count (595-596)
  • round_up (589-591)
tests/utils/test_green_ctx.py (1)
flashinfer/green_ctx.py (2)
  • split_device_green_ctx (126-190)
  • split_device_green_ctx_by_sm_count (193-281)
🪛 Ruff (0.14.1)
flashinfer/green_ctx.py

180-183: Avoid specifying long messages outside the exception class

(TRY003)


264-264: Avoid specifying long messages outside the exception class

(TRY003)


272-275: Avoid specifying long messages outside the exception class

(TRY003)

tests/utils/test_green_ctx.py

17-17: Unpacked variable streams is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (5)
tests/utils/test_green_ctx.py (1)

15-23: Good error handling pattern for insufficient SMs.

The try-except block properly catches and skips tests when the device lacks sufficient SMs, which addresses the spark (sm_121) test failures mentioned in the PR objectives.

flashinfer/green_ctx.py (4)

31-31: LGTM! Required import for SM count validation.

The get_device_sm_count import is correctly added and used in both validation checks (lines 177 and 269).


173-184: Excellent early validation for SM availability.

The pre-check correctly computes the required SMs and fails fast before any CUDA operations, providing a clear error message that aligns with the test expectations.


261-261: Good optimization: constraint calculation moved outside loop.

Moving get_sm_count_constraint outside the loop avoids redundant calls, as the constraints don't change between iterations.


267-276: Proper SM validation with informative error message.

The validation correctly sums the rounded SM counts and raises a clear error if insufficient. The error message helpfully includes the actual rounded_sm_counts list to aid debugging.


@bkryu bkryu left a comment


I can confirm that test_jit_example.py now passes or xfails.
test_green_ctx.py still has 7 failures:

================================================================================================================================================= short test summary info =================================================================================================================================================
FAILED tests/utils/test_green_ctx.py::test_green_ctx_creation[16-3-cuda:0] - RuntimeError: CUDA error code=914(b'CUDA_ERROR_INVALID_RESOURCE_TYPE')
FAILED tests/utils/test_green_ctx.py::test_green_ctx_kernel_execution[16-3-cuda:0] - RuntimeError: CUDA error code=914(b'CUDA_ERROR_INVALID_RESOURCE_TYPE')
FAILED tests/utils/test_green_ctx.py::test_split_device_green_ctx_by_sm_count_creation[sm_counts0-cuda:0] - RuntimeError: CUDA error code=914(b'CUDA_ERROR_INVALID_RESOURCE_TYPE')
FAILED tests/utils/test_green_ctx.py::test_split_device_green_ctx_by_sm_count_creation[sm_counts1-cuda:0] - RuntimeError: CUDA error code=914(b'CUDA_ERROR_INVALID_RESOURCE_TYPE')
FAILED tests/utils/test_green_ctx.py::test_split_device_green_ctx_by_sm_count_kernel_execution[sm_counts0-cuda:0] - RuntimeError: CUDA error code=914(b'CUDA_ERROR_INVALID_RESOURCE_TYPE')
FAILED tests/utils/test_green_ctx.py::test_split_device_green_ctx_by_sm_count_kernel_execution[sm_counts1-cuda:0] - RuntimeError: CUDA error code=914(b'CUDA_ERROR_INVALID_RESOURCE_TYPE')
FAILED tests/utils/test_green_ctx.py::test_split_device_green_ctx_by_sm_count_alignment[sm_counts1-cuda:0] - RuntimeError: CUDA error code=914(b'CUDA_ERROR_INVALID_RESOURCE_TYPE')
=================================================================================================================================== 7 failed, 10 passed, 5 skipped, 1 warning in 0.91s ====================================================================================================================================

Please see my other comment for test_sampling.py. There might be nans coming from the kernel, at least in my local environment:

 probs_ref = torch.softmax(logits_scaled, dim=-1)

-assert torch.allclose(probs, probs_ref, atol=1e-5)
+assert torch.allclose(probs, probs_ref, rtol=1e-5, atol=1e-5)

@bkryu bkryu Oct 25, 2025


I cannot seem to repro the fix on Spark. It also seems that torch.allclose has a default rtol=1e-5, so this change may not actually have any effect.

In fact in my local env (cu130 container), when I change the tolerance and inject print statements as

    probs_ref = torch.softmax(logits_scaled, dim=-1)
    print(f"{torch.isnan(probs).sum().item() = }")
    print(f"{torch.isnan(probs_ref).sum().item() =}")
    assert torch.allclose(probs, probs_ref, rtol=100, atol=100)

I am seeing nans.

(py312) root@c661e6d696f6:/flashinfer# pytest tests/utils/test_sampling.py -x -s
=================================================================================================================================================== test session starts ===================================================================================================================================================
platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0
rootdir: /flashinfer
configfile: pytest.ini
collected 900 items                                                                                                                                                                                                                                                                                                       

tests/utils/test_sampling.py torch.isnan(probs).sum().item() = 0
torch.isnan(probs_ref).sum().item() =0
... (the zero-nan print pair repeats across the first eight passing tests) ...
.torch.isnan(probs).sum().item() = 4873728
torch.isnan(probs_ref).sum().item() =0
F

======================================================================================================================================================== FAILURES =========================================================================================================================================================
____________________________________________________________________________________________________________________________ test_softmax[True-True-1.0-normal_distribution(std=1)-128256-989] ____________________________________________________________________________________________________________________________
...
>       assert torch.allclose(probs, probs_ref, rtol=100, atol=100)
E       AssertionError: assert False
E        +  where False = <built-in method allclose of type object at 0x16bc850>(tensor([[0.0000e+00, 7.8481e-05, 0.0000e+00,  ..., 9.0452e-06, 8.5036e-06,\n         0.0000e+00],\n        [2.4505e-05, ...05],\n        [0.0000e+00, 0.0000e+00, 7.0366e-06,  ..., 0.0000e+00, 7.1824e-06,\n         2.0367e-06]], device='cuda:0'), tensor([[0.0000e+00, 7.8481e-05, 0.0000e+00,  ..., 9.0452e-06, 8.5036e-06,\n         0.0000e+00],\n        [2.4505e-05, ...05],\n        [0.0000e+00, 0.0000e+00, 7.0366e-06,  ..., 0.0000e+00, 7.1824e-06,\n         2.0367e-06]], device='cuda:0'), rtol=100, atol=100)
E        +    where <built-in method allclose of type object at 0x16bc850> = torch.allclose

tests/utils/test_sampling.py:76: AssertionError

...


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (1)
tests/utils/test_green_ctx.py (1)

24-32: Extract duplicated error handling to a helper function.

The CUDA error checking pattern is duplicated across all 5 test functions. This was previously flagged in review comments and should be addressed to improve maintainability.

Extract the error handling to a helper function:

def _skip_if_insufficient_sms(e: RuntimeError, device: str, context_msg: str):
    """Helper to skip tests when the device has insufficient SMs for green context splitting."""
    resource_errors = (
        "CUDA error code=914",
        "CUDA_ERROR_INVALID_RESOURCE_TYPE",
        "CUDA error code=915",
        "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION",
    )
    if any(msg in str(e) for msg in resource_errors):
        cu_dev = green_ctx.get_cudevice(torch.device(device))
        device_resource = green_ctx.get_device_resource(cu_dev)
        total_sms = device_resource.sm.smCount
        pytest.skip(f"Insufficient SMs on device. Total SMs available: {total_sms}. {context_msg}")
    raise

Then simplify each test's except block to:

    except RuntimeError as e:
        _skip_if_insufficient_sms(e, device, f"requested: num_groups={num_groups}, min_count={min_count}")

Based on learnings

Also applies to: 57-65, 94-102, 132-140, 170-178

🧹 Nitpick comments (1)
tests/utils/test_green_ctx.py (1)

15-18: Prefix unused variable with underscore.

The streams variable is unpacked but never used in this test. Prefix it with _ to indicate it's intentionally unused.

Apply this diff:

     try:
-        streams, resources = green_ctx.split_device_green_ctx(
+        _streams, resources = green_ctx.split_device_green_ctx(
             torch.device(device), num_groups, min_count
         )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 89eac51 and 9424eef.

📒 Files selected for processing (2)
  • flashinfer/green_ctx.py (2 hunks)
  • tests/utils/test_green_ctx.py (5 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
flashinfer/green_ctx.py (2)
flashinfer/utils.py (2)
  • get_compute_capability (251-254)
  • round_up (589-591)
flashinfer/comm/mnnvl.py (1)
  • round_up (55-57)
tests/utils/test_green_ctx.py (2)
flashinfer/green_ctx.py (5)
  • split_device_green_ctx (126-189)
  • get_cudevice (47-53)
  • get_device_resource (56-61)
  • split_device_green_ctx_by_sm_count (192-283)
  • get_sm_count_constraint (34-44)
flashinfer/utils.py (1)
  • get_compute_capability (251-254)
🪛 GitHub Actions: pre-commit
flashinfer/green_ctx.py

[error] 1-1: pre-commit: ruff-format reformatted 2 files. Review and commit changes. Command 'pre-commit run --all-files' reported modifications by hook 'ruff-format'.

tests/utils/test_green_ctx.py

[error] 1-1: pre-commit: ruff-format reformatted 2 files. Review and commit changes. Command 'pre-commit run --all-files' reported modifications by hook 'ruff-format'.

🪛 Ruff (0.14.1)
flashinfer/green_ctx.py

177-177: Consider [*results, remaining] instead of concatenation

Replace with [*results, remaining]

(RUF005)


179-179: Consider moving this statement to an else block

(TRY300)


183-188: Avoid specifying long messages outside the exception class

(TRY003)


266-266: Avoid specifying long messages outside the exception class

(TRY003)


271-271: Consider [*results, remaining] instead of concatenation

Replace with [*results, remaining]

(RUF005)


273-273: Consider moving this statement to an else block

(TRY300)


277-282: Avoid specifying long messages outside the exception class

(TRY003)

tests/utils/test_green_ctx.py

16-16: Unpacked variable streams is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🔇 Additional comments (1)
flashinfer/green_ctx.py (1)

255-283: LGTM! Consistent error handling with helpful validation.

The implementation correctly validates input SM counts and provides descriptive error messages for CUDA resource failures. The pattern is consistent with split_device_green_ctx.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
flashinfer/green_ctx.py (3)

64-78: Type annotation is incorrect; results is a list.

split_resource returns a list of CUdevResource and a single remaining CUdevResource. Adjust the return type to avoid misleading type checkers.

-def split_resource(
+def split_resource(
     resource: CUdevResource,
     num_groups: int,
     min_count: int,
-) -> Tuple[CUdevResource, CUdevResource]:
+) -> Tuple[List[CUdevResource], CUdevResource]:

103-106: Parameter type should be CUdevice, not CUdevResource.

create_green_ctx_streams receives cu_dev from get_cudevice (a CUdevice) and passes it to cuGreenCtxCreate. Fix the annotation.

-def create_green_ctx_streams(
-    cu_dev: CUdevResource, resources: List[CUdevResource]
+def create_green_ctx_streams(
+    cu_dev: CUdevice, resources: List[CUdevResource]
 ) -> List[torch.Stream]:

80-101: Green-context handle leak confirmed in two functions; refactor to eliminate unnecessary context creation.

The review is correct. The codebase creates green contexts but never destroys them—no cuGreenCtxDestroy calls exist anywhere. Two functions are affected:

  1. split_resource_by_sm_count() (lines 80–100): Creates a green context solely to extract a resource already returned by split_resource(). The "refresh" operation is unnecessary; the proposed fix (use remaining directly) is valid and eliminates the leak for this function.

  2. create_green_ctx_streams() (lines 103–123): Creates green contexts in a loop to generate streams, but never stores or destroys the contexts. They go out of scope immediately after stream extraction, creating a handle leak.

The proposed fix for split_resource_by_sm_count() is sound:

-        result, remaining = split_resource(resource, 1, sm_count)
-        results.extend(result)
-        # Refresh the remaining resource for the next iteration
-        desc = checkCudaErrors(driver.cuDevResourceGenerateDesc([remaining], 1))
-        green_ctx = checkCudaErrors(
-            driver.cuGreenCtxCreate(
-                desc, cu_dev, driver.CUgreenCtxCreate_flags.CU_GREEN_CTX_DEFAULT_STREAM
-            )
-        )
-        resource = checkCudaErrors(
-            driver.cuGreenCtxGetDevResource(
-                green_ctx, driver.CUdevResourceType.CU_DEV_RESOURCE_TYPE_SM
-            )
-        )
+        result, remaining = split_resource(resource, 1, sm_count)
+        results.extend(result)
+        resource = remaining

Additionally, review create_green_ctx_streams() to determine whether green contexts must remain alive for stream validity. If yes, contexts must be retained and properly destroyed; if no, context creation can be eliminated.

♻️ Duplicate comments (1)
tests/utils/test_green_ctx.py (1)

25-38: Deduplicate skip logic via a fixture/helper.

The same RuntimeError substring checks + SM-count fetch/skip are repeated across tests. Extract once (fixture/helper) to improve maintainability and keep messages consistent. This was raised earlier; repeating here for the new blocks.

Example fixture:

# conftest.py
import pytest
import flashinfer.green_ctx as green_ctx
import torch

CUDA_RES_ERR = (
    "CUDA error code=914",
    "CUDA_ERROR_INVALID_RESOURCE_TYPE",
    "CUDA error code=915",
    "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION",
)

def skip_if_insufficient_sms(device: str, err: Exception, extra: str) -> None:
    s = str(err)
    if any(sig in s for sig in CUDA_RES_ERR):
        cu_dev = green_ctx.get_cudevice(torch.device(device))
        total_sms = green_ctx.get_device_resource(cu_dev).sm.smCount
        pytest.skip(f"Insufficient SMs ({total_sms}). {extra}")
    raise err

Then in tests:

try:
    ...
except RuntimeError as e:
    skip_if_insufficient_sms(device, e, f"requested: num_groups={num_groups}, min_count={min_count}")

Also applies to: 64-77, 107-120, 151-164, 195-208

🧹 Nitpick comments (5)
tests/utils/test_green_ctx.py (2)

57-63: Remove prints from tests; assert instead.

print(...) adds noisy logs. Prefer simple assertions on shape to keep CI output clean.

-                print(z.shape)
+                assert z.shape == (8192, 8192)
-                print(f"Partition {i}: {z.shape}")
+                assert z.shape == (4096, 4096)

Optional: consider smaller matrices (e.g., 2048 or parametrize) to reduce CI time on small GPUs.

Also applies to: 144-150


180-194: Micro: avoid repeated device construction.

Compute dev = torch.device(device) once and reuse; minor readability and overhead win.

-        _, resources = green_ctx.split_device_green_ctx_by_sm_count(
-            torch.device(device), sm_counts
-        )
+        dev = torch.device(device)
+        _, resources = green_ctx.split_device_green_ctx_by_sm_count(dev, sm_counts)
...
-            min_sm_count, sm_alignment = green_ctx.get_sm_count_constraint(
-                *green_ctx.get_compute_capability(torch.device(device))
-            )
+            min_sm_count, sm_alignment = green_ctx.get_sm_count_constraint(
+                *green_ctx.get_compute_capability(dev)
+            )
flashinfer/green_ctx.py (3)

173-193: Style and lints: list-unpack concat; try/else; centralize error checks.

  • Use list-unpack for concat (RUF005).
  • Move return to else of try (TRY300).
  • Optional: centralize error signature checks to a helper constant.
     try:
         cu_dev = get_cudevice(dev)
         resource = get_device_resource(cu_dev)
         results, remaining = split_resource(resource, num_groups, min_count)
-        resources = results + [remaining]
+        resources = [*results, remaining]
         streams = create_green_ctx_streams(cu_dev, resources)
-        return streams, resources
     except RuntimeError as e:
-        if (
-            "CUDA error code=914" in str(e)
-            or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e)
-            or "CUDA error code=915" in str(e)
-            or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e)
-        ):
+        if any(sig in str(e) for sig in (
+            "CUDA error code=914",
+            "CUDA_ERROR_INVALID_RESOURCE_TYPE",
+            "CUDA error code=915",
+            "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION",
+        )):
             raise RuntimeError(
                 f"{e}\n"
                 f"Failed to split device into {num_groups} groups with min_count={min_count}. "
                 f"This is likely due to insufficient number of SMs available on the device. "
                 f"Please reduce the number of groups or the minimum SM count per group."
             ) from e
         raise
+    else:
+        return streams, resources

259-295: Hoist constraints; empty-input check; style/lints parity with above.

  • Compute (min_sm_count, sm_alignment) once per device.
  • Validate sm_counts is non-empty (docstring promises ValueError).
  • Apply list-unpack concat and try/else.
     try:
         cu_dev = get_cudevice(dev)
         resource = get_device_resource(cu_dev)
 
-        # Round sm counts to meet the alignment and granularity requirements
-        rounded_sm_counts = []
-        for sm_count in sm_counts:
-            min_sm_count, sm_alignment = get_sm_count_constraint(
-                *get_compute_capability(dev)
-            )
-            if sm_count <= 0:
-                raise ValueError(f"SM count must be positive, got {sm_count}")
-            rounded_sm_counts.append(
-                round_up(max(sm_count, min_sm_count), sm_alignment)
-            )
+        # Round sm counts to meet the alignment and granularity requirements
+        if not sm_counts:
+            raise ValueError("sm_counts must be non-empty")
+        min_sm_count, sm_alignment = get_sm_count_constraint(
+            *get_compute_capability(dev)
+        )
+        rounded_sm_counts = []
+        for sm_count in sm_counts:
+            if sm_count <= 0:
+                raise ValueError(f"SM count must be positive, got {sm_count}")
+            rounded_sm_counts.append(round_up(max(sm_count, min_sm_count), sm_alignment))
 
         # Split the device into multiple green contexts
         results, remaining = split_resource_by_sm_count(
             cu_dev, resource, rounded_sm_counts
         )
-        resources = results + [remaining]
+        resources = [*results, remaining]
         streams = create_green_ctx_streams(cu_dev, resources)
-        return streams, resources
     except RuntimeError as e:
-        if (
-            "CUDA error code=914" in str(e)
-            or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e)
-            or "CUDA error code=915" in str(e)
-            or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e)
-        ):
+        if any(sig in str(e) for sig in (
+            "CUDA error code=914",
+            "CUDA_ERROR_INVALID_RESOURCE_TYPE",
+            "CUDA error code=915",
+            "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION",
+        )):
             raise RuntimeError(
                 f"{e}\n"
                 f"Failed to split device with SM counts {sm_counts} (rounded to {rounded_sm_counts}). "
                 f"This is likely due to insufficient number of SMs available on the device. "
                 f"Please reduce the requested SM counts or use fewer partitions."
             ) from e
         raise
+    else:
+        return streams, resources
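The hoisted rounding reduces to `round_up(max(sm_count, min_sm_count), sm_alignment)`; a small sketch with hypothetical constraint values (`min_sm_count=8`, `sm_alignment=8` here for illustration — the real values come from `get_sm_count_constraint`):

```python
def round_up(x: int, y: int) -> int:
    # Same contract as flashinfer.utils.round_up: smallest multiple of y >= x.
    return ((x + y - 1) // y) * y

def round_sm_counts(sm_counts, min_sm_count=8, sm_alignment=8):
    # Constraints hoisted out of the loop; the default values are
    # hypothetical and stand in for get_sm_count_constraint().
    if not sm_counts:
        raise ValueError("sm_counts must be non-empty")
    rounded = []
    for sm_count in sm_counts:
        if sm_count <= 0:
            raise ValueError(f"SM count must be positive, got {sm_count}")
        rounded.append(round_up(max(sm_count, min_sm_count), sm_alignment))
    return rounded

print(round_sm_counts([1, 8, 9, 17]))  # [8, 8, 16, 24]
```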

187-193: Optional: avoid long message construction in except (TRY003).

Consider defining a small custom exception (e.g., SMAllocationError) or assembling the message inside the exception class/__str__ to satisfy linters and keep handlers catching by type, not by substrings.

Also applies to: 289-294
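A minimal sketch of such an exception type (the name `SMAllocationError` and its constructor shape are hypothetical):

```python
class SMAllocationError(RuntimeError):
    # Hypothetical exception type: the message is assembled inside the class
    # (satisfying TRY003) and handlers can catch by type, not by substring.
    def __init__(self, cause: Exception, detail: str):
        super().__init__(
            f"{cause}\n{detail} "
            "This is likely due to insufficient number of SMs available on the device."
        )
        self.cause = cause

try:
    raise SMAllocationError(
        RuntimeError("CUDA error code=914"), "Failed to split device into 8 groups."
    )
except SMAllocationError as e:
    print(type(e).__name__)  # SMAllocationError
```

Call sites then raise `SMAllocationError(e, detail) from e` and tests can use `pytest.raises(SMAllocationError)` instead of matching message substrings.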

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9424eef and a6ec87f.

📒 Files selected for processing (2)
  • flashinfer/green_ctx.py (2 hunks)
  • tests/utils/test_green_ctx.py (5 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/utils/test_green_ctx.py (2)
flashinfer/green_ctx.py (5)
  • split_device_green_ctx (126-193)
  • get_cudevice (47-53)
  • get_device_resource (56-61)
  • split_device_green_ctx_by_sm_count (196-295)
  • get_sm_count_constraint (34-44)
flashinfer/utils.py (1)
  • get_compute_capability (251-254)
flashinfer/green_ctx.py (1)
flashinfer/utils.py (2)
  • get_compute_capability (251-254)
  • round_up (589-591)
🪛 Ruff (0.14.2)
tests/utils/test_green_ctx.py

16-16: Unpacked variable streams is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

flashinfer/green_ctx.py

177-177: Consider [*results, remaining] instead of concatenation

Replace with [*results, remaining]

(RUF005)


179-179: Consider moving this statement to an else block

(TRY300)


187-192: Avoid specifying long messages outside the exception class

(TRY003)


270-270: Avoid specifying long messages outside the exception class

(TRY003)


279-279: Consider [*results, remaining] instead of concatenation

Replace with [*results, remaining]

(RUF005)


281-281: Consider moving this statement to an else block

(TRY300)


289-294: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs

Comment on lines +15 to +18
try:
    streams, resources = green_ctx.split_device_green_ctx(
        torch.device(device), num_groups, min_count
    )
Contributor

⚠️ Potential issue | 🟡 Minor

Fix unused variable per Ruff (RUF059).

streams is not used in this test. Use _ to silence the warning.

-        streams, resources = green_ctx.split_device_green_ctx(
+        _, resources = green_ctx.split_device_green_ctx(
             torch.device(device), num_groups, min_count
         )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
try:
    streams, resources = green_ctx.split_device_green_ctx(
        torch.device(device), num_groups, min_count
    )
try:
    _, resources = green_ctx.split_device_green_ctx(
        torch.device(device), num_groups, min_count
    )
🧰 Tools
🪛 Ruff (0.14.2)

16-16: Unpacked variable streams is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🤖 Prompt for AI Agents
In tests/utils/test_green_ctx.py around lines 15 to 18, the variable `streams`
from the tuple assignment is unused and triggers Ruff RUF059; change the
unpacking to use a throwaway name (e.g., `_, resources =
green_ctx.split_device_green_ctx(torch.device(device), num_groups, min_count)`)
so the test retains the same behavior while silencing the unused-variable
warning.

Collaborator

@bkryu bkryu left a comment

Update: I can now reproduce passing runs of test_green_ctx.py and test_jit_example.py, but the NaN issue in test_sampling.py persists.

I have a suggestion @yzh119: how about we revert the change to test_sampling.py and check the rest in, so that the other two tests can pass on Spark? Then I can also help investigate the softmax issue.

Collaborator

@bkryu bkryu left a comment

LGTM for the test_jit_example.py and test_green_ctx.py fixes. I will work on test_sampling.py in a separate PR.

@yzh119 yzh119 changed the title bugfix: fix failed unittest on spark (sm_121) bugfix: fix failed unittest test_green_ctx and test_jit_example on spark (sm_121) Nov 5, 2025
@yzh119 yzh119 enabled auto-merge (squash) November 5, 2025 01:14
@yzh119 yzh119 merged commit 9bc5bd5 into flashinfer-ai:main Nov 5, 2025
4 checks passed
wangbo981016 pushed a commit to meituan-longcat/flashinfer that referenced this pull request Feb 5, 2026
Update to v0.5.2 and opt cuda graph launch config for MTP situation
* fix q len for MTP;
* release: Bump version for v0.5.2 release (flashinfer-ai#2057)

* [BUG] Fix trtllm-gen fp4 moe renormalize routing (flashinfer-ai#2049)


## 📌 Description

Temporarily disable `routingIndicesBlockKernel` as it's not compatible
with the current packing format (topk-id and expert weights are packed
into a 32 bit tensor). This solves the issue
flashinfer-ai#2032

## Summary by CodeRabbit

* **Bug Fixes**
* Forced multi-block MoE execution to avoid sporadic single-block
selection and improve stability with certain workloads.

* **New Features**
* Added an alternative packed top‑k routing input path that propagates
routing scores when present.

* **Tests**
* Added a comprehensive parametrized test validating routed fused MoE
across token counts, model sizes, expert counts and multiple
quantization modes.

---------

Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Christina Zhang <83400082+ChristinaZ@users.noreply.github.com>
Co-authored-by: Christina Zhang <83400082+ChristinaZ@users.noreply.github.com>;
* test: Skip test_fp8_quantize.py on Hopper (flashinfer-ai#2052)


## 📌 Description

The unit test `test_fp8_quantize.py` currently fails on sm90. 

Root cause: The test file tests the accuracy of `mxfp8_quantize()`.
However, in
[fp8_quantization.py](https://github.com/flashinfer-ai/flashinfer/blob/adb0e89fdee0a3140a43982bc3bef4e79ce20046/flashinfer/fp8_quantization.py#L7),
the `mxfp8_quantize()`'s underlying module only exists for
`gen_mxfp8_quantization_sm100_module` with no sm90 support.

The current PR changes the test file to skip on pre-SM100 architectures, as they are not supported.

Results:
* Before current PR on SM90: `72 failed, 40 passed in 2.69s`
* After current PR on SM90: `40 passed, 72 skipped in 1.41s`
* Before current PR on SM120: `112 passed  in 1.59s`
* After current PR on SM120: `112 passed in 1.54s` (expected to be the
same as before)


## Summary by CodeRabbit

* **Tests**
* Added conditional checks to skip FP8 quantization tests on GPUs that
lack required computational capabilities.

* Add support for topkPacked input in block-level renormalize (flashinfer-ai#2051)


## 📌 Description

Add support for topkPacked input in block-level renormalize


## Summary by CodeRabbit

* **Performance**
* Optimized routing layer efficiency through improved index handling in
specialized processing configurations.


Signed-off-by: Christina Zhang <83400082+ChristinaZ@users.noreply.github.com>;
* chore: Update CODEOWNERS (flashinfer-ai#1984)

## Summary

This PR updates the CODEOWNERS file based on git commit history analysis
from the last 180 days.

## Changes

- Updated `.github/CODEOWNERS` with current code ownership based on:
  - Commit frequency
  - File coverage
  - Commit recency

## How to Review

1. Review the changes to `.github/CODEOWNERS`
2. Verify that the assigned owners are appropriate for each module
3. Make manual adjustments if needed before merging

## Notes

- This is an automated PR generated weekly
- Minimum commits threshold: 1
- Analysis period: 180 days
- Directory depth: 3 levels
- Top N owners per module: 5

---

🤖 This PR was automatically generated by the [update-codeowners
workflow](.github/workflows/update-codeowners.yml)


## Summary by CodeRabbit

* **Chores**
* Updated code ownership assignments and reorganized related section
mappings for internal development processes.


Co-authored-by: flashinfer-bot <flashinfer-bot@users.noreply.github.com>
Co-authored-by: Claude <noreply@anthropic.com>;
* Update trtllm-gen fused moe routing kernel and add more kernels (flashinfer-ai#1955)


## 📌 Description
co-work with @IwakuraRein 
- update the trtllm-gen fused moe headers
- add new kernels for trtllm-gen fused moe
  - for NvFp4, add tile 256
  - for MxFp8 x MxFp4, add 128, 256
  - for FP8 per-tensor, add 192, 256
  - for FP8 block scale, add 128
 - update the logics of `computeSelectedTileN`
 - add `tune_max_num_tokens` to FP8 per-tensor and FP8 block scale
 - rename `TLLM_GEN_BMM_CUBIN_PATH` to `TLLM_GEN_GEMM_CUBIN_PATH`
 - add `TLLM_GEN_EXPORT_FLASHINFER`

**NOTE: split-k kernels are temporarily disabled as they cause failure
in renormalize + expert 256 tests.**

## Summary by CodeRabbit

* **New Features**
* Expanded MoE tiling (adds 128/192/256), FP8 per‑tensor MoE path,
FP8/FP4 autotuner benchmark, and new tune_max_num_tokens tuning
parameter.

* **Improvements**
* Router now supports tile‑based (non‑power‑of‑two) layouts and
propagates explicit valid M/N/K for safer sizing; autotuner logs include
exception details; added export/compile flags and clearer kernel error
messages.

* **Bug Fixes**
* Relaxed strict padding/power‑of‑two checks and made log2 handling
safer.

* **Tests**
* Extended MoE tests to cover new FP8 block‑scale and routing scenarios.

---------

Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Co-authored-by: Siyuan Fu <siyuanf@nvidia.com>;
* Fix dtype of output scales from mnnvl_moe_alltoallv_prepare_without_allgather (flashinfer-ai#2048)


## 📌 Description

During flashinfer-ai#1641 the dtype
of output scales in
moePrepare(mnnvl_moe_alltoallv_prepare_without_allgather) was accidentally
changed from float to int32. This PR fixes that.

## 🔍 Related Issues

Fix flashinfer-ai#2040

## Summary by CodeRabbit

* **Bug Fixes**
* Corrected tensor type validation for mixture-of-experts scale
preparation so scales are validated and handled as float32, preventing
type mismatches with downstream float operations.
* Ensured scale tensors are created on the same device as expert
identifiers, keeping tensor placement consistent across distributed
processing and avoiding cross-device issues.

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>;
* test: Fix test_sampling.py on Spark (flashinfer-ai#2042)


## 📌 Description

Current PR fixes `test_sampling.py::test_softmax` on Spark by inserting
a `torch.cuda.synchronize()` before calling the softmax function.

tl; dr why it works: PDL is enabled in these tests. Investigation shows
that when PDL is enabled, `logits.view(-1).index_fill_(0, inf_idx,
float("-inf"))` that prepares the inputs overlaps with the `probs =
flashinfer.sampling.softmax(logits, temperature=temperature_arr)`
function itself. Hence, we need to ensure that the input preparation is
complete before running the softmax function to get the correct output.


#### Observations
`test_sampling.py::test_softmax` fails on select cases Spark. Example
output
```
# pytest tests/utils/test_sampling.py::test_softmax
=================================================================================================================================================== test session starts ===================================================================================================================================================
platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0
rootdir: /flashinfer
configfile: pytest.ini
collected 324 items                                    
...
================================================================================================================================================= short test summary info =================================================================================================================================================
FAILED tests/utils/test_sampling.py::test_softmax[True-True-1.0-normal_distribution(std=1)-128256-989] - AssertionError: assert False
FAILED tests/utils/test_sampling.py::test_softmax[True-True-1.0-normal_distribution(std=5)-128256-989] - AssertionError: assert False
FAILED tests/utils/test_sampling.py::test_softmax[True-True-1.0-gumbel_distribution(beta=0.1)-128256-989] - AssertionError: assert False
======================================================================================================================================== 3 failed, 321 passed, 1 warning in 10.33s
```

Observations from debugging:
* When outputs are printed, rows containing all `nan`s are produced in
the output of `probs = flashinfer.sampling.softmax(logits)`
* Surprisingly, the test passes with `CUDA_LAUNCH_BLOCKING=1 pytest
tests/utils/test_sampling.py::test_softmax`
* `compute-sanitizer` does not detect any IMAs
* Running only a failed test results in a pass:
```
$ pytest tests/utils/test_sampling.py::test_softmax[True-True-1.0-normal_distribution\(std=1\)-128256-989]
...
1 passed, 1 warning in 0.80s
```

Towards a fix:
* I empirically find that the test passes:
* when the reference `torch.softmax()` is called before
`flashinfer.sampling.softmax()` (currently reference is called after)
* when pdl is disabled in [line
67](https://github.com/flashinfer-ai/flashinfer/blob/main/tests/utils/test_sampling.py#L67)
with `probs = flashinfer.sampling.softmax(logits,
temperature=temperature_arr, enable_pdl=False)`
* when `torch.cuda.synchronize()` is inserted in the line 64 as in this
PR.
```
    if neg_inf_input:
        # assign random logits to -inf
        num_inf = torch.randint(0, logits.numel() - 1, (), device=logits.device).item()
        inf_idx = torch.randperm(logits.numel(), device=logits.device)[:num_inf]
        logits.view(-1).index_fill_(0, inf_idx, float("-inf"))
        torch.cuda.synchronize() ## This fixes the issue for some reason!

    if temperature_arr:
        temperature_arr = torch.full((batch_size,), temperature, device="cuda:0")
        probs = flashinfer.sampling.softmax(logits, temperature=temperature_arr)
        logits_scaled = logits / temperature_arr.unsqueeze(-1)
```
but **does not fix the issue if I place the synchronization any
earlier**

An nsys profile shows that surprisingly the
`logits.view(-1).index_fill_(0, inf_idx, float("-inf"))` and
`flashinfer.sampling.softmax(logits, temperature=temperature_arr)` can
overlap execution when pdl is enabled.
<img width="1243" height="640" alt="Screenshot 2025-11-04 at 5 49 50 PM"
src="https://github.com/user-attachments/assets/950ab8ab-0843-49c8-8411-ff81c00c34a6"
/>

This means that the softmax kernel is launching before inputs are done
being prepared when `neg_inf_input=True`. Hence, placing a
`torch.cuda.synchronize()` after the fill or disabling pdl can solve the
issue. With the current PR, the nsys timeline changes to:
<img width="1240" height="643" alt="Screenshot 2025-11-04 at 5 51 32 PM"
src="https://github.com/user-attachments/assets/aae63a88-d7cd-4661-8476-6d8c581879b2"
/>
and the unit test passes.


## Summary by CodeRabbit

## Release Notes

* **Bug Fixes**
* Improved synchronization of concurrent operations to ensure proper
execution order and prevent potential timing-related issues.

* fix: support both pip and uv pip for finding flashinfer-python package (flashinfer-ai#2043)

Update getJitIncludeDirs() to try pip first, then fallback to uv pip if
pip is not available. This ensures compatibility with both standard pip
and uv pip package managers when locating the flashinfer-python
installation for JIT compilation include paths.

The command now uses the shell OR operator (`||`) to attempt `pip` first, and
only falls back to `uv pip` if the first command fails.
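The fallback semantics can be sketched in Python (hypothetical helper for illustration; the actual change is the shell `||` inside the JIT include-path lookup):

```python
import subprocess
import sys

def first_successful(commands):
    # Mirror the shell `a || b` fallback: run each command in order and
    # return the stdout of the first one that exists and exits 0.
    for cmd in commands:
        try:
            return subprocess.run(
                cmd, capture_output=True, text=True, check=True
            ).stdout
        except (FileNotFoundError, subprocess.CalledProcessError):
            continue
    return None

# In the real lookup, ["pip", "show", "flashinfer-python"] would be tried
# first and ["uv", "pip", "show", "flashinfer-python"] only on failure.
out = first_successful(
    [["no-such-command-xyz"], [sys.executable, "-c", "print('ok')"]]
)
print(out, end="")  # ok
```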
```
pytest -xs tests/moe/test_trtllm_cutlass_fused_moe.py::test_moe_fp8_block_scaling
============================================================================================================================================================ test session starts =============================================================================================================================================================
platform linux -- Python 3.10.12, pytest-8.4.2, pluggy-1.6.0
rootdir: /home/scratch.dmoss_gpu_1/repos/flashinfer
configfile: pytest.ini
collected 1 item                                                                                                                                                                                                                                                                                                                             

tests/moe/test_trtllm_cutlass_fused_moe.py [TensorRT-LLM][INFO] Compiling JIT runtime gemm_swapAB_256_128_128_16_128_2_82_8_1_GroupedWithOffset with options: 
[TensorRT-LLM][INFO] -std=c++17 
[TensorRT-LLM][INFO] --gpu-architecture=sm_90a 
[TensorRT-LLM][INFO] --ptxas-options=-allow-expensive-optimizations=true 
[TensorRT-LLM][INFO] --ptxas-options=--register-usage-level=10 
[TensorRT-LLM][INFO] --diag-suppress=161,174,177,940 
[TensorRT-LLM][INFO] -D__FORCE_INCLUDE_CUDA_FP16_HPP_FROM_FP16_H__=1 
[TensorRT-LLM][INFO] -D__FORCE_INCLUDE_CUDA_BF16_HPP_FROM_BF16_H__=1 
[TensorRT-LLM][INFO] -O3 
[TensorRT-LLM][INFO] -cubin 
[TensorRT-LLM][INFO] --expt-relaxed-constexpr 
[TensorRT-LLM][INFO] --expt-extended-lambda 
[TensorRT-LLM][INFO] --compiler-options=-fPIC,-O3,-Wno-deprecated-declarations,-Wno-abi 
[TensorRT-LLM][INFO] -I/home/scratch.dmoss_gpu_1/repos/flashinfer/flashinfer/data/csrc/nv_internal/tensorrt_llm 
[TensorRT-LLM][INFO] 

[TensorRT-LLM][INFO] Generated kernel code:

#ifdef __CUDACC_RTC__
#ifndef NVRTC_JIT_COMPILATION
#define NVRTC_JIT_COMPILATION
#endif

#include <deep_gemm/nvrtc_std.cuh>

#else

#include <string>
#include <cuda.h>

#endif

#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <deep_gemm/nvrtc_cutlass.cuh>
#include <deep_gemm/fp8_gemm_impl.cuh>

using namespace deep_gemm;

using SchedulerType =
typename SchedulerSelectorSwapAB<GemmType::GroupedWithOffset, 256, 128, 128, 16, 128, 2, 1>::type;

__global__ void dummy_kernel() {
  void *ptr = (void *)&fp8_gemm_kernel_swapAB<256, 128, 128, 16, 128, 2, 8, 128, 128, 1, SchedulerType, GroupedWithOffsetSchedulerInputSwapAB>;
}

[TensorRT-LLM][INFO] NVCC compilation took 3064 ms
[TensorRT-LLM][INFO] Compilation log:

[TensorRT-LLM][INFO] Successfully copied kernel files to cache directory: /home/dmoss/.tensorrt_llm/cache/gemm_swapAB_256_128_128_16_128_2_82_8_1_GroupedWithOffset
[TensorRT-LLM][INFO] Compiling JIT runtime gemm_swapAB_128_128_128_16_128_2_82_8_1_GroupedWithOffset with options: 
[TensorRT-LLM][INFO] -std=c++17 
[TensorRT-LLM][INFO] --gpu-architecture=sm_90a 
[TensorRT-LLM][INFO] --ptxas-options=-allow-expensive-optimizations=true 
[TensorRT-LLM][INFO] --ptxas-options=--register-usage-level=10 
[TensorRT-LLM][INFO] --diag-suppress=161,174,177,940 
[TensorRT-LLM][INFO] -D__FORCE_INCLUDE_CUDA_FP16_HPP_FROM_FP16_H__=1 
[TensorRT-LLM][INFO] -D__FORCE_INCLUDE_CUDA_BF16_HPP_FROM_BF16_H__=1 
[TensorRT-LLM][INFO] -O3 
[TensorRT-LLM][INFO] -cubin 
[TensorRT-LLM][INFO] --expt-relaxed-constexpr 
[TensorRT-LLM][INFO] --expt-extended-lambda 
[TensorRT-LLM][INFO] --compiler-options=-fPIC,-O3,-Wno-deprecated-declarations,-Wno-abi 
[TensorRT-LLM][INFO] -I/home/scratch.dmoss_gpu_1/repos/flashinfer/flashinfer/data/csrc/nv_internal/tensorrt_llm 
[TensorRT-LLM][INFO] 

[TensorRT-LLM][INFO] Generated kernel code:

#ifdef __CUDACC_RTC__
#ifndef NVRTC_JIT_COMPILATION
#define NVRTC_JIT_COMPILATION
#endif

#include <deep_gemm/nvrtc_std.cuh>

#else

#include <string>
#include <cuda.h>

#endif

#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <deep_gemm/nvrtc_cutlass.cuh>
#include <deep_gemm/fp8_gemm_impl.cuh>

using namespace deep_gemm;

using SchedulerType =
typename SchedulerSelectorSwapAB<GemmType::GroupedWithOffset, 128, 128, 128, 16, 128, 2, 1>::type;

__global__ void dummy_kernel() {
  void *ptr = (void *)&fp8_gemm_kernel_swapAB<128, 128, 128, 16, 128, 2, 8, 128, 128, 1, SchedulerType, GroupedWithOffsetSchedulerInputSwapAB>;
}

[TensorRT-LLM][INFO] NVCC compilation took 1479 ms
[TensorRT-LLM][INFO] Compilation log:

[TensorRT-LLM][INFO] Successfully copied kernel files to cache directory: /home/dmoss/.tensorrt_llm/cache/gemm_swapAB_128_128_128_16_128_2_82_8_1_GroupedWithOffset
.

============================================================================================================================================================= 1 passed in 9.02s ==============================================================================================================================================================
```

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Bug Fixes**
* Improved package detection compatibility for alternative package
management tool installations.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
* use scalar for kv_scale in xqa (flashinfer-ai#2033)

<!-- .github/pull_request_template.md -->

## 📌 Description

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Breaking Changes**
* Public xqa/xqa_mla entry points now accept kv_scale as a plain float
(default 1.0) instead of a 1-element tensor. Update call sites
accordingly.

* **Documentation**
  * Docstrings updated to reflect kv_scale as float.

* **Tests**
* Tests updated to pass scalar kv_scale, with added parameterization and
conditional skip for FP8 kv-cache scenarios.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
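For call sites migrating to the scalar `kv_scale`, a small normalization shim can bridge old and new styles. This is a hedged sketch (the helper name is hypothetical, and tensor-likes are duck-typed via `.item()` rather than importing torch):

```python
def as_scalar_kv_scale(kv_scale) -> float:
    """Normalize kv_scale for the new float-based xqa API.

    Accepts a plain number (new style, default 1.0) or a 1-element
    tensor-like object exposing .item() (old style)."""
    if hasattr(kv_scale, "item"):  # covers torch.Tensor and numpy scalars
        return float(kv_scale.item())
    return float(kv_scale)
```

Callers holding a legacy 1-element tensor can pass it through this shim before invoking the updated entry points.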

---------

Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
* Support cc common check decorator for empty backends (flashinfer-ai#2015)

<!-- .github/pull_request_template.md -->

## 📌 Description

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Bug Fixes**
* Improved backend/compute-capability validation with clearer errors and
correct fallback when backend-specific checks are absent.

* **New Features**
* Decorated functions expose runtime attributes to query backend
availability and choices.
  * Default-backend behavior: kernels use a default when none is passed.

* **Compatibility**
* Expanded supported compute-capability set and raised minimum cuDNN
package requirements.

* **Tests**
* Added tests for empty-backend common-checks and default-backend
behavior.

* **Chores**
  * Version bumped to 0.5.1.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
* perf: Speed up fp4 quantization for small batch with swizzling for cutlass MoE (flashinfer-ai#2025)

<!-- .github/pull_request_template.md -->

## 📌 Description

Performance optimization for `fp4_quantize()` function. The performance
issue was raised in issues flashinfer-ai#1734 and flashinfer-ai#2021

The observed behavior was slow performance when `is_sf_swizzled_layout=True`
(as opposed to False). The root causes were:

* Excessive Padding Overhead: Swizzled layouts require row padding to
tile boundaries where `SWIZZLED_128x4` pads to multiples of 128 rows and
`SWIZZLED_8x4` pads to multiples of 8 rows
* This means that for `batch_size=1` with SWIZZLED_128x4, 127 out of 128 rows
are padding (99.2% wasted work)
* Sequential Processing: the original grid launch used `grid.x = min(m,
multiProcessorCount * numBlocksPerSM)`, so for batch_size=1 only 1 block was
launched
* This single block iterated sequentially over all 128 padded rows
* Each padding row still computed scale factors, checked bounds, and
performed conditional logic
* No Fast Path: Every row (real or padding) went through the same
expensive code path with multiple conditional branches

The fix:
1. Kernel-Level Early Exit Fast Path (`quantization.cuh`): Added branch
divergence optimization with separate handling for padding vs. data rows
- Padding rows now execute ~10× fewer instructions; Eliminates memory
loads/stores for input/output data on padding rows; Reduces register
pressure and divergence overhead

2. Host-Level Parallel Grid Launch (`quantization.cu`): Modified grid
calculation to launch blocks proportional to padded rows instead of
actual rows:
- For batch_size=1 with SWIZZLED_128x4: launches up to 128 blocks
instead of 1; Each block processes 1 row in parallel instead of
sequentially; overall tries to achieve full GPU occupancy even with
small batch sizes
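The host-level change can be illustrated with a small model of the grid calculation. This is a sketch of the sizing logic described above, not the actual `quantization.cu` code; function names and the `blocks_per_sm` parameter are illustrative:

```python
def padded_rows(m: int, layout: str) -> int:
    """Round the row count up to the layout's tile boundary:
    SWIZZLED_128x4 pads to multiples of 128, SWIZZLED_8x4 to multiples of 8."""
    tile = {"SWIZZLED_128x4": 128, "SWIZZLED_8x4": 8}[layout]
    return ((m + tile - 1) // tile) * tile

def grid_x(m: int, layout: str, sm_count: int, blocks_per_sm: int,
           swizzled: bool) -> int:
    """Blocks launched along grid.x: proportional to padded rows for
    swizzled layouts (the fix), actual rows otherwise (the old behavior)."""
    rows = padded_rows(m, layout) if swizzled else m
    return min(rows, sm_count * blocks_per_sm)
```

Under this model, `batch_size=1` with SWIZZLED_128x4 launches 128 blocks that each handle one padded row in parallel, instead of a single block iterating over all 128 rows sequentially.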

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->


`fp4_quantize()` performance before fix:
```
$ python3 bench_fp4_quantize.py 
+------------+---------------------+-------------------------+
| batch size | swizzled_times (us) | non_swizzled_times (us) |
+------------+---------------------+-------------------------+
|    1.0     |        71.52        |          3.136          |
|    2.0     |       37.152        |          3.168          |
|    4.0     |       19.904        |          3.168          |
|    8.0     |       11.296        |           3.2           |
|    16.0    |        7.103        |          3.296          |
|    32.0    |        4.96         |          3.376          |
|    64.0    |        4.128        |          3.487          |
|   128.0    |        3.808        |          3.648          |
|   256.0    |        4.32         |          4.161          |
|   512.0    |        5.472        |          5.184          |
+------------+---------------------+-------------------------+
```
After fix in current PR:
```
$ python3 bench_fp4_quantize.py 
+------------+---------------------+-------------------------+
| batch size | swizzled_times (us) | non_swizzled_times (us) |
+------------+---------------------+-------------------------+
|    1.0     |        3.456        |          3.264          |
|    2.0     |        3.488        |          3.296          |
|    4.0     |        3.536        |          3.296          |
|    8.0     |        3.52         |          3.296          |
|    16.0    |        3.52         |          3.456          |
|    32.0    |        3.696        |          3.488          |
|    64.0    |        3.744        |          3.584          |
|   128.0    |        3.936        |          3.776          |
|   256.0    |        4.384        |          4.288          |
|   512.0    |        5.568        |          5.248          |
+------------+---------------------+-------------------------+
```

where the `bench_fp4_quantize.py` script used to benchmark (adopted from
flashinfer-ai#1734) :
```
from flashinfer.testing.utils import bench_gpu_time_with_cupti
from flashinfer import fp4_quantize
import torch
import numpy as np
import pandas as pd
from tabulate import tabulate

A_scale = torch.randn(16).cuda().float()
bsz = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
swizzled_times = []
for bs in bsz:
    A = torch.randn(bs, 5120).cuda().to(torch.bfloat16)
    t = np.median(bench_gpu_time_with_cupti(
            lambda: fp4_quantize(A, A_scale, is_sf_swizzled_layout=True),
            dry_run_iters = 10, 
            repeat_iters = 100,
            )
        ) * 1000
    swizzled_times.append(t)

non_swizzled_times = []
for bs in bsz:
    A = torch.randn(bs, 5120).cuda().to(torch.bfloat16)
    t = np.median(bench_gpu_time_with_cupti(
        lambda: fp4_quantize(A, A_scale, is_sf_swizzled_layout=False),
            dry_run_iters = 10, 
            repeat_iters = 100,
            )
        ) * 1000
    non_swizzled_times.append(t)


summary_df = pd.DataFrame({
    "batch size": bsz,
    "swizzled_times (us)": swizzled_times,
    "non_swizzled_times (us)": non_swizzled_times,
})

# Round numeric columns to three decimals before printing
summary_df_rounded = summary_df.copy()
summary_df_rounded["batch size"] = summary_df_rounded["batch size"].astype(int)
summary_df_rounded["swizzled_times (us)"] = summary_df_rounded["swizzled_times (us)"].round(3)
summary_df_rounded["non_swizzled_times (us)"] = summary_df_rounded["non_swizzled_times (us)"].round(3)
print(tabulate(summary_df_rounded, headers='keys', tablefmt='pretty', showindex=False))
```

## 🔍 Related Issues

flashinfer-ai#1734 
flashinfer-ai#2021 

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Bug Fixes**
* Improved quantization for swizzled memory layouts by adjusting how
effective processing rows are computed to better utilize GPU resources.
* Added early-exit handling for padding-only rows so padding outputs are
zeroed without processing data.
* Ensured consistent zeroing of scale/format outputs for padded columns
across all quantization paths.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
* bugfix: fix failed unittest `test_green_ctx` and `test_jit_example` on spark (sm_121) (flashinfer-ai#1951)

<!-- .github/pull_request_template.md -->

## 📌 Description

There are three failed unittests on spark (sm_121):
* tests/utils/test_green_ctx.py
* tests/utils/test_jit_example.py
* tests/utils/test_sampling.py

The first one fails because Spark has a small number of SMs (48) and we
don't have a guard on green-context splitting.
The second one is an unknown issue (logits don't match the reference),
probably related to barriers on sm_121; it is marked xfail for now and
will be fixed later.

The last one will be fixed by another PR from @bkryu; this PR fixes the
first two issues.
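The green-context guard amounts to a pre-check of requested versus available SMs before attempting a split. A minimal sketch of that check, with hypothetical names (the real tests call into flashinfer's green-context utilities and `pytest.skip` with a similar message):

```python
def check_sm_resources(available_sms: int, num_groups: int,
                       sms_per_group: int, min_remaining: int = 0):
    """Return (ok, message) for a proposed green-context split.

    The split is feasible only if the total SMs requested across all
    groups, plus any minimum remainder, fit in the device's SM count."""
    requested = num_groups * sms_per_group
    if requested + min_remaining > available_sms:
        msg = (f"Insufficient SMs for green-context split: "
               f"requested {requested}, available {available_sms}")
        return False, msg
    return True, ""

# In a test, the guard would look roughly like:
#   ok, msg = check_sm_resources(sm_count, num_groups, sms_per_group)
#   if not ok:
#       pytest.skip(msg)
```

On a 48-SM Spark device, splits that assume datacenter-class SM counts fail this check and are skipped instead of erroring.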

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Tests**
* Tests now pre-check GPU resources and auto-skip with informative
messages including available and requested SM counts to avoid spurious
failures.
* Added a conditional xfail for GPUs with compute capability 12.1 to
avoid false negatives on that hardware.
* Tightened a sampling test by adding a relative tolerance for more
robust numerical validation.

* **Bug Fixes**
* Improved runtime error handling to surface clearer guidance when GPU
SM resources are insufficient.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
* Update Docker CI tags to 20251104-d528f0c (flashinfer-ai#2041)

This PR updates the Docker CI image tags to the latest version:
`20251104-d528f0c`

Updated images:
- flashinfer/flashinfer-ci-cu126:20251104-d528f0c
- flashinfer/flashinfer-ci-cu128:20251104-d528f0c
- flashinfer/flashinfer-ci-cu129:20251104-d528f0c
- flashinfer/flashinfer-ci-cu130:20251104-d528f0c

Auto-generated by [release-ci-docker
workflow](https://github.com/flashinfer-ai/flashinfer/actions/runs/19084098717)

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Chores**
* Updated Docker image tags to latest versions for CUDA 12.6, 12.8,
12.9, and 13.0 distributions.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Co-authored-by: yzh119 <11773619+yzh119@users.noreply.github.com>
* test: Mark test_fp8_prefill.py as xfail on SM90 (flashinfer-ai#2038)

<!-- .github/pull_request_template.md -->

## 📌 Description

`test_fp8_prefill.py` is currently failing on SM90, but consumes too
much time to run/fail, causing unit-tests to time out.

~~Current PR marks it as xfail so that unit tests can progress forward.~~

Update: Root cause of failure is because mixed precision attention is
not available on `fa3` backend, but the attention prefill wrapper
automatically selects `backend='fa3'` on SM90.

Fix is to explicitly specify the `backend='fa2'` so that fa2 is always
used.
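The backend-selection logic described above can be sketched as follows. Names and signature are hypothetical, not the actual flashinfer API; the point is only the SM90 fallback rule:

```python
def select_prefill_backend(requested, sm_major: int,
                           mixed_precision: bool) -> str:
    """Pick an attention prefill backend.

    fa3 is the automatic choice on SM90, but it lacks mixed-precision
    (FP8 x FP16) support, so mixed-precision cases must fall back to fa2.
    An explicit request always wins."""
    if requested is not None:
        return requested
    if sm_major == 9 and not mixed_precision:
        return "fa3"
    return "fa2"
```

The fix in this PR corresponds to always passing an explicit `"fa2"` for these tests rather than relying on the automatic choice.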

Status after fix:
```
$ pytest tests/attention/test_fp8_prefill.py
=================================================================================================================================================== test session starts ===================================================================================================================================================
...
collected 768 items                                                                                                                                                                                                                                                                                                       

tests/attention/test_fp8_prefill.py ............................................................................................................................................................................................................................................................................... [ 35%]
................................................................................................................................................................................................................................................................................................................... [ 75%]
..............................................................................................................................................................................................                                                                                                                      [100%]
======================================================================================================================================= 768 passed, 1 warning in 131.42s (0:02:11) ========================================================================================================================================

```

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Tests**
* Adjusted FP8/FP16 attention test configuration to explicitly select a
backend during prefill/decoding, stabilizing test behavior across
environments.

* **Public API**
* Constructors now accept an explicit backend parameter to allow
selecting the backend used for KV cache operations.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
* ci: Update cudnn version requirements in CI container (flashinfer-ai#2039)

<!-- .github/pull_request_template.md -->

## 📌 Description

cuDNN versions specified in CI container setup
(`docker/install/install_python_packages.sh`) are currently 9.11 and
9.12.

In unit testing, this causes issues as `mm_fp4(backend='cudnn')` is not
supported on Spark (sm121) for older cuDNN versions in cu130.

Failure is due to cuDNN version shipped with container being too old. In
the [latest container build pipeline
output](https://github.com/flashinfer-ai/flashinfer/actions/runs/18778064727/job/53577233568#step:6:727),
cudnn 9.13.0.50 is installed
```
flashinfer-ai#16 207.0 Requirement already satisfied: nvidia-cudnn-cu13>=9.12.0.46 in /opt/conda/envs/py312/lib/python3.12/site-packages (9.13.0.50)
flashinfer-ai#16 207.0 Requirement already satisfied: nvidia-cublas in /opt/conda/envs/py312/lib/python3.12/site-packages (from nvidia-cudnn-cu13>=9.12.0.46) (13.0.0.19)
```

Current PR updates the minimum cudnn version for both
[cu12](https://pypi.org/project/nvidia-cudnn-cu12/#history) and
[cu13](https://pypi.org/project/nvidia-cudnn-cu13/#history) to
9.14.0.64.
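A minimum-version gate like the one this PR enforces boils down to a component-wise version comparison. A minimal sketch (function names are illustrative; it assumes purely numeric dotted versions such as cuDNN's `9.14.0.64` scheme):

```python
def version_tuple(v: str) -> tuple:
    """Parse a dotted numeric version string into a comparable tuple."""
    return tuple(int(part) for part in v.split("."))

def cudnn_supported(installed: str, minimum: str = "9.14.0.64") -> bool:
    """True if the installed cuDNN version meets the required minimum.
    Tuple comparison handles each component numerically, so 9.14 > 9.9."""
    return version_tuple(installed) >= version_tuple(minimum)
```

For real requirement parsing, a dedicated library such as `packaging.version` handles pre-release and mixed-format versions more robustly.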

cudnn 9.13 --> unit test fails with 180 failed, 270 passed, 2790
skipped, 1 warning in 8.97s
```
# pytest tests/gemm/test_mm_fp4.py 
=================================================================================================================================================== test session starts ===================================================================================================================================================
platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0
rootdir: /flashinfer
configfile: pytest.ini
collected 3240 items    
...
FAILED tests/gemm/test_mm_fp4.py::test_mm_fp4[mxfp4_alpha-False-True-cudnn-res_dtype1-512-512-256] - cudnn._compiled_module.cudnnGraphNotSupportedError: No valid engine configs for Matmul_MUL_
FAILED tests/gemm/test_mm_fp4.py::test_mm_fp4[mxfp4_alpha-False-True-cudnn-res_dtype1-512-512-512] - cudnn._compiled_module.cudnnGraphNotSupportedError: No valid engine configs for Matmul_MUL_
================================================================================================================================ 180 failed, 270 passed, 2790 skipped, 1 warning in 8.97s =================================================================================================================================

```
cudnn 9.14 --> unit test passes with 450 passed, 2790 skipped, 1 warning
in 5.37s
```
# pytest tests/gemm/test_mm_fp4.py 
=================================================================================================================================================== test session starts ===================================================================================================================================================
platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0
rootdir: /flashinfer
configfile: pytest.ini
collected 3240 items                                                                                                                                                                                                                                                                                                      

tests/gemm/test_mm_fp4.py 
...
====================================================================================================================================== 450 passed, 2790 skipped, 1 warning in 5.37s =======================================================================================================================================

```

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Chores**
* Updated internal dependencies for improved system stability and
compatibility.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
* release: Bump version for v0.5.1 release (flashinfer-ai#2031)

<!-- .github/pull_request_template.md -->

## 📌 Description

Update `version.txt`

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Chores**
  * Version updated to 0.5.1

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
* Updated decorator to support unspecified default (flashinfer-ai#2026)

<!-- .github/pull_request_template.md -->

## 📌 Description

Updated the decorator to support an unspecified default. This was causing
issues when calling mm_fp4 without a backend specified.
Also added SM 110 as a supported compute capability for the cutlass
backend (mm_fp4).
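The shape of such a decorator can be sketched as follows. This is a hedged illustration of the described behavior, not flashinfer's actual decorator; the decorator name, the `supported_backends` attribute, and the demo function are assumptions:

```python
import functools

def backend_requirement(supported, default=None):
    """Decorator sketch: validate the `backend` kwarg and fall back to a
    configured default when the caller leaves it unspecified."""
    def wrap(fn):
        @functools.wraps(fn)
        def inner(*args, backend=None, **kwargs):
            chosen = backend if backend is not None else default
            if chosen is None:
                raise ValueError("no backend given and no default configured")
            if chosen not in supported:
                raise ValueError(f"unsupported backend: {chosen}")
            return fn(*args, backend=chosen, **kwargs)
        # runtime attribute so callers can query available choices
        inner.supported_backends = tuple(supported)
        return inner
    return wrap

@backend_requirement(("cutlass", "cudnn"), default="cudnn")
def mm_fp4_demo(backend=None):
    return backend  # stand-in for the real kernel dispatch
```

With this shape, `mm_fp4_demo()` with no backend resolves to the default instead of failing validation.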

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
  * FP4 Cutlass GEMM now supports the SM110 GPU compute capability.

* **Bug Fixes**
* Kernels called without an explicit backend now consistently use the
default backend.

* **Tests**
* Added a unit test to verify default backend selection and correct
results when backend is omitted.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
* test: Enable xfailed trtllm decode long seqlen tests and update microbenchmark (flashinfer-ai#2018)

<!-- .github/pull_request_template.md -->

## 📌 Description


[tests/attention/test_trtllm_gen_attention.py](https://github.com/flashinfer-ai/flashinfer/blob/v0.5.0rc2/tests/attention/test_trtllm_gen_attention.py#L1021-L1076)
was failing and therefore marked xfail.

PR flashinfer-ai#2002 fixed the underlying root cause. The current PR thus
removes the `xfail` marker so that these long-seqlen cases are exercised
going forward.

Additionally, PR flashinfer-ai#2002 revealed a bug in the microbenchmark script where
[trtllm_batch_decode_with_kv_cache](https://github.com/flashinfer-ai/flashinfer/blob/v0.5.0rc2/flashinfer/decode.py#L2082-L2083)
explicitly requires the workspace to
BingooYang pushed a commit to BingooYang/flashinfer that referenced this pull request Mar 13, 2026
…n spark (sm_121) (flashinfer-ai#1951)
