
Conversation

@wenscarl
Collaborator

@wenscarl wenscarl commented Sep 25, 2025

📌 Description

This PR quantizes a batched tensor to NVFP4, producing the associated block scales, so the result is directly usable by FlashInfer's grouped_gemm_nt_masked.

Note: the permutation is left to the framework.
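For context, a rough usage sketch of the intended flow. Function and parameter names follow the Python API discussed in the review comments below (nvfp4_batched_quantize with an optional mask of valid-row counts) and may not match the final implementation exactly:

```python
import torch
from flashinfer import nvfp4_batched_quantize  # exposed via flashinfer/__init__.py in this PR

b, m, k = 8, 128, 4096
a = torch.randn(b, m, k, dtype=torch.bfloat16, device="cuda")
a_global_sf = torch.tensor([1.0], dtype=torch.float32, device="cuda")  # global scale, shape [1]
mask = torch.full((b,), m, dtype=torch.int32, device="cuda")           # valid rows per batch item

# a_fp4: packed FP4 data of shape [B, M, K/2] (uint8, two lanes per byte);
# a_sf: block scale factors. Both are meant to feed grouped_gemm_nt_masked,
# with any required permutation left to the calling framework.
a_fp4, a_sf = nvfp4_batched_quantize(a, a_global_sf, sf_vec_size=16, mask=mask)
```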

cc @kaixih

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

@gemini-code-assist
Contributor

Summary of Changes

Hello @wenscarl, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly extends FlashInfer's NVFP4 quantization capabilities by introducing support for masked batch quantization and a combined SiLU activation with element-wise multiplication prior to quantization. These enhancements are designed to make quantized tensors directly compatible with FlashInfer's grouped_gemm_nt_masked function, thereby improving efficiency and adaptability for various model architectures. The changes span across CUDA kernels, C++ bindings, and Python interfaces to seamlessly integrate these new, more flexible quantization options.

Highlights

  • Masked Batch NVFP4 Quantization: Introduced the capability to perform NVFP4 quantization on batched tensors with an optional mask, allowing selective quantization based on a provided mask tensor. This enhances flexibility for operations like FlashInfer's grouped_gemm_nt_masked. A simplified sketch of the masked block-scale layout follows this list.
  • SiLU and Multiplication with NVFP4 Quantization: Added a new silu_and_mul_fp4_batched_quantize operation. This function performs a SiLU activation and element-wise multiplication on the input tensor before quantizing it to NVFP4, streamlining common computational patterns.
  • Kernel Refactoring and Policy-Based Execution: The core CUDA quantization kernel (quantize_with_block_size) was refactored into a generic quantize_with_block_size_impl that accepts a Policy template parameter. This allows for flexible execution paths (e.g., with or without SiLU/multiplication) and integrates mask handling efficiently.
  • Python API and Testing Updates: The Python bindings (flashinfer/__init__.py, flashinfer/fp4_quantization.py) were updated to expose the new masked and SiLU-enabled quantization functionalities. Comprehensive unit tests (tests/test_fp4_quantize.py) were added or modified to ensure the correctness and robustness of these new features.
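To make the first highlight concrete, here is a simplified, purely illustrative PyTorch reference for the masked block-scale layout. It is not the kernel's exact recipe (the real path also applies the global scale, converts the scales to their storage dtype, and writes them in a layout-specific order); it only shows that each group of sf_vec_size elements along K shares one scale factor and that rows at or beyond mask[b] are ignored:

```python
import torch

def reference_block_scales(x: torch.Tensor, mask: torch.Tensor,
                           sf_vec_size: int = 16, fp4_max: float = 6.0) -> torch.Tensor:
    """x: [B, M, K] with K divisible by sf_vec_size; mask: [B] valid-row counts.
    fp4_max = 6.0 is the largest magnitude representable in E2M1."""
    b, m, k = x.shape
    blocks = x.float().reshape(b, m, k // sf_vec_size, sf_vec_size)
    scales = blocks.abs().amax(dim=-1) / fp4_max          # one scale per sf_vec_size-element block
    row_ids = torch.arange(m, device=x.device).view(1, m, 1)
    valid = row_ids < mask.view(b, 1, 1)                  # rows >= mask[b] are masked out
    return torch.where(valid, scales, torch.zeros_like(scales))
```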


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces masked batch quantization for NVFP4 and adds a new fused silu_and_mul quantization kernel. The changes are well-structured, particularly the use of a policy-based design in the CUDA kernel to handle different quantization paths. My review focuses on improving correctness, maintainability, and testing. I've identified a missing validation check in the C++ code, incorrect dtypes in Python fake ops that could affect memory profiling, and some inconsistencies in the new Python API and its documentation. Additionally, I've found a couple of issues in the new tests that should be addressed to ensure correctness and full coverage of the new functionality.

int64_t b = inputShape[0];
int64_t m = inputShape[1];
int64_t k_by_2 = inputShape[2];
int64_t k = k_by_2 / 2;


high

The silu_and_mul operation requires splitting the last dimension of the input tensor in half. To ensure this is always possible and to prevent potential runtime errors, you should add a check to verify that the last dimension (k_by_2) is even.

  TORCH_CHECK(k_by_2 % 2 == 0, "The last dimension of the input tensor for silu_and_mul must be even.");
  int64_t k = k_by_2 / 2;

Comment on lines +414 to +418
single_out, single_scale = fp4_quantize(
x_silu_mul, global_scale, 16, False, True
)


high

The global_scale tensor has a shape of (b,), containing per-batch item scales. However, you are passing the entire tensor to fp4_quantize inside a loop that iterates through batches. The fp4_quantize function expects a scalar-like tensor (shape [1]) for the scale and will only use the first element, leading to incorrect test results for batches other than the first one. You should pass the scale corresponding to the current batch item i.

Suggested change
single_out, single_scale = fp4_quantize(
x_silu_mul, global_scale, 16, False, True
)
single_out, single_scale = fp4_quantize(
x_silu_mul, global_scale[i:i+1], 16, False, True
)

Comment on lines +319 to +374
input.new_empty([b, m, k // 2], dtype=torch.int64), # float4_e2m1_x2
input.new_empty(
[b, m * k // sf_vec_size], dtype=torch.int32
), # Scale factors


medium

The fake op _fp4_batched_quantize_sm100 allocates tensors with dtype=torch.int64 and torch.int32, which use significantly more memory than the actual torch.uint8 output tensors. This can be misleading for memory profiling and might cause out-of-memory issues during model tracing or exporting. Please use the correct dtype torch.uint8 for both output tensors, as this is what the underlying C++ extension returns.

Suggested change
input.new_empty([b, m, k // 2], dtype=torch.int64), # float4_e2m1_x2
input.new_empty(
[b, m * k // sf_vec_size], dtype=torch.int32
), # Scale factors
input.new_empty([b, m, k // 2], dtype=torch.uint8), # float4_e2m1_x2
input.new_empty(
[b, m * k // sf_vec_size], dtype=torch.uint8
), # Scale factors

Comment on lines 385 to 454
input.new_empty([b, m, k // 4], dtype=torch.int64), # float4_e2m1_x2
input.new_empty(
[b, m * k // sf_vec_size], dtype=torch.int32
), # Scale factors


medium

The fake op _silu_and_mul_fp4_batched_quantize_sm100 allocates tensors with dtype=torch.int64 and torch.int32, which use significantly more memory than the actual torch.uint8 output tensors. This can be misleading for memory profiling and might cause out-of-memory issues during model tracing. Please use the correct dtype torch.uint8 for both output tensors to match the behavior of the C++ kernel.

Suggested change
input.new_empty([b, m, k // 4], dtype=torch.int64), # float4_e2m1_x2
input.new_empty(
[b, m * k // sf_vec_size], dtype=torch.int32
), # Scale factors
input.new_empty([b, m, k // 4], dtype=torch.uint8), # float4_e2m1_x2
input.new_empty(
[b, m * k // (sf_vec_size * 2)], dtype=torch.uint8
), # Scale factors

Comment on lines 798 to 811
"""
Quantize batched input tensor to NVFP4 format.
Parameters:
a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16.
a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
mask (torch.Tensor): Mask tensor to apply before quantization.
sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2
- Scale factors tensor with shape determined by layout and sf_vec_size
"""


medium

The docstring for silu_and_mul_fp4_batched_quantize appears to be copied from nvfp4_batched_quantize and is inaccurate. It doesn't mention the fused silu_and_mul operation, which is a key part of this function's behavior. Please update the docstring to accurately describe what the function does. You can use the more detailed docstring from silu_and_mul_fp4_batched_quantize_sm100 as a reference.

Comment on lines 792 to 797
def silu_and_mul_fp4_batched_quantize(
a,
mask,
a_global_sf,
sf_vec_size=16,
):


medium

The function signature for silu_and_mul_fp4_batched_quantize is inconsistent with nvfp4_batched_quantize. Here, mask is a required positional argument, whereas in nvfp4_batched_quantize it's an optional keyword argument. For better API consistency and user experience, consider making mask an optional keyword argument. This would likely require changes in the C++ backend to accept an optional mask, similar to how fp4_batched_quantize is implemented.

Suggested change
def silu_and_mul_fp4_batched_quantize(
a,
mask,
a_global_sf,
sf_vec_size=16,
):
def silu_and_mul_fp4_batched_quantize(
a,
a_global_sf,
sf_vec_size=16,
mask=None,
):


b, m, n = batch_shape
x = torch.randn((b, m, n * 2), dtype=dtype)
mask = torch.randint(low=m, high=m + 1, size=(b,), dtype=torch.int32, device=device)


medium

The mask for this test is generated with low=m and high=m+1, which means all values in the mask tensor will be m. This doesn't effectively test the masking functionality, as no rows are actually masked out. To properly test the masking logic, you should generate masks with varying lengths.

Suggested change
mask = torch.randint(low=m, high=m + 1, size=(b,), dtype=torch.int32, device=device)
mask = torch.randint(low=1, high=m + 1, size=(b,), dtype=torch.int32, device=device)
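With varying-length masks, the per-item comparison in the test's per-batch loop can then restrict itself to the valid rows. A sketch, where batched_out is an illustrative name for the batched kernel's packed output and single_out is the per-item reference computed in the loop (as in the excerpt reviewed earlier):

```python
# Inside the existing per-batch-item loop, after computing single_out for item i:
valid = int(mask[i].item())
# Rows at or beyond mask[i] are masked out, so only the valid region is compared.
assert torch.equal(batched_out[i, :valid], single_out[:valid])
```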

@wenscarl wenscarl requested review from kaixih and yzh119 September 25, 2025 17:56

@kaixih kaixih left a comment


Thx for the PR.

Do you think we should put the silu_and_mul_fp4_batched_quantize into the activation module, like here?

The `nvfp4_batched_quantize` naturally lives here.

The main reason to separate their locations is that the former assumes a `[..., 2*hidden] -> [..., hidden]` mapping, while the latter assumes `[..., hidden] -> [..., hidden]`.
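For reference, the mapping described above written out in plain PyTorch, assuming the common convention that SiLU is applied to the first half of the last dimension:

```python
import torch
import torch.nn.functional as F

def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # [..., 2 * hidden] -> [..., hidden]: SiLU on the first half, elementwise
    # product with the second half. nvfp4_batched_quantize, by contrast, keeps
    # the hidden dimension unchanged ([..., hidden] -> [..., hidden]).
    hidden = x.shape[-1] // 2
    return F.silu(x[..., :hidden]) * x[..., hidden:]
```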

quantize_with_block_size<BlockScaleQuantizationType::FP16_TO_MXFP8, T, SF_VEC_SIZE, true>, b,
m, n, padded_n, input, nullptr, reinterpret_cast<uint32_t*>(output),
reinterpret_cast<uint32_t*>(SFOuput), layout);
reinterpret_cast<uint32_t*>(SFOuput), layout, nullptr);


nit: ..., /*mask=*/nullptr); for clarity.

@kaixih
Collaborator

kaixih commented Sep 25, 2025

So, IIUC, the main motivation of this PR is to support masked FP4 quantization and to better maintain the different variants related to quantization methods:

  1. Add masking support to nvfp4_batched_quantize.
  2. Add a new function (or extend the existing silu_and_mul) to support NVFP4 output with masking.

The intended use case is to prepare inputs for the CuteDSL masked GEMM.

fyi. @zihaoye

@yzh119
Collaborator

yzh119 commented Sep 26, 2025

Hi @kaixih, I think you tagged another @zihaoye, who is also a flashinfer contributor :)

Returns:
Tuple[torch.Tensor, torch.Tensor]:
- self_fp4 (torch.Tensor): Packed FP4 tensor in E2M1x2 format of shape
[B, M, K // 2] with dtype torch.uint8 (two FP4 lanes per byte).


As frameworks start using torch 2.8, where torch.float4_e2m1_x2 is available, we should consider moving to the native fp4x2 data type at some point (not necessarily in this PR).
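As a side note (not part of this PR), the packed E2M1x2 bytes referenced in the docstring hold two FP4 lanes each. A minimal dequantization sketch, assuming the low nibble stores the even-indexed element and ignoring block/global scale factors:

```python
import torch

# Standard E2M1 value table: sign in the nibble's high bit, magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6}.
_E2M1_LUT = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]
)

def unpack_e2m1x2(packed: torch.Tensor) -> torch.Tensor:
    """Unpack uint8 of shape [..., K // 2] into float32 of shape [..., K] (scales not applied)."""
    lo = (packed & 0x0F).long()         # assumed: even-indexed element in the low nibble
    hi = ((packed >> 4) & 0x0F).long()  # assumed: odd-indexed element in the high nibble
    lut = _E2M1_LUT.to(packed.device)
    return torch.stack((lut[lo], lut[hi]), dim=-1).flatten(start_dim=-2)
```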

@wenscarl wenscarl force-pushed the masked_batch_fp4_quant branch from d6cb24b to bd01434 on September 29, 2025 19:19
@wenscarl wenscarl requested review from kaixih and yzh119 September 29, 2025 19:20
@wenscarl wenscarl requested a review from kaixih September 30, 2025 02:33
@kaixih
Collaborator

kaixih commented Sep 30, 2025

@wenscarl seems there are some conflicts; can you rebase?

wip

Pass test

fix precommit
@wenscarl wenscarl force-pushed the masked_batch_fp4_quant branch from c7eaba6 to 37c8f88 on September 30, 2025 22:30

@kaixih kaixih left a comment


Approving this for now.
Note: there is more follow-up cleanup work to do. @yzh119, let us know if you need us to do that.


@yzh119 yzh119 left a comment


LGTM

@yzh119 yzh119 merged commit d50cfbc into flashinfer-ai:main Oct 1, 2025
3 checks passed
yzh119 pushed a commit that referenced this pull request Oct 23, 2025

## 📌 Description

This PR reverts #1774 and #1835, which have issues with some shapes under CUDA graph. The kernels ported in this PR come from SGLang: [[NVIDIA] [1/N] Nvfp4 Masked Gemm: Add quant op for the flashinfer grouped gemm](https://github.com/sgl-project/sglang/pull/9200/files) and [[NVIDIA] [2/N] Optimize silu_and_mul_scaled_fp4_grouped_quant perf](https://github.com/sgl-project/sglang/pull/9556/files) by @kaixih.

## Summary by CodeRabbit

* **New Features**
- Added grouped FP4 quantization (scaled_fp4_grouped_quantize) and an
NV-focused Silu+Mul expert quantization entry
(silu_and_mul_scaled_nvfp4_experts_quantize).

* **API Changes**
- Replaced legacy batched APIs with new expert/grouped APIs; removed
legacy mask parameter from FP4/MXFP8 quantization signatures and
adjusted FP4 output layouts/types.

* **Documentation**
  - Updated docs to list new functions and remove deprecated symbols.

* **Tests**
- Updated tests to validate new quantization paths, shapes, dtypes, and
layouts.

---------

Signed-off-by: Shu Wang. <[email protected]>
murphymatt pushed a commit to fw-ai/flashinfer that referenced this pull request Nov 6, 2025
murphymatt added a commit to fw-ai/flashinfer that referenced this pull request Nov 6, 2025
* chore: rename FLASHINFER_JIT_VERBOSE to FLASHINFER_JIT_DEBUG for clarity (#1946)

<!-- .github/pull_request_template.md -->

## 📌 Description

Rename environment variable `FLASHINFER_JIT_VERBOSE` to
`FLASHINFER_JIT_DEBUG` to better reflect its actual behavior.

- `FLASHINFER_JIT_DEBUG`: Enable debug mode during compilation (disable
optimization, add debug symbols)
- The previous name `FLASHINFER_JIT_VERBOSE` implied "showing more
compilation info", which was confusing
- Maintained backward compatibility: falls back to
`FLASHINFER_JIT_VERBOSE` if `FLASHINFER_JIT_DEBUG` is not set

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Refactor**
* Introduced FLASHINFER_JIT_DEBUG environment variable for controlling
JIT debug builds with backward compatibility for legacy
FLASHINFER_JIT_VERBOSE.
* Enhanced debug build configuration with improved compiler and CUDA
debugging flags. Non-debug builds continue using -O3 optimizations.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* fix: Fix trtllm-gen prefill IMA when batch_size==1 (#1912)

<!-- .github/pull_request_template.md -->

## 📌 Description

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

Current PR fixes the test and benchmark codes IMAs when running
trtllm-gen paged & ragged prefill with batch size 1 -- the issue was
described in https://github.com/flashinfer-ai/flashinfer/issues/1898

Root cause of the issue:
`flashinfer.prefill.trtllm_ragged_attention_deepseek` and
`flashinfer.prefill.trtllm_batch_context_with_kv_cache` both require
`max_q_len` to match the length of the query when batch size is 1.

**Updated PR:**
Issue has been addressed from the kernel-side so that the "*`max_q_len`
to match the length of the query when batch size is 1*" is no longer
required.

Current PR updates trtllm-gen FMHA cubins to latest and brings minor
updates to kernel metadata.

Unit test results after PR: 
```
$ pytest tests/attention/test_trtllm_gen_attention.py 
...
platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0
rootdir: /flashinfer
configfile: pytest.ini
collected 2320 items   
...
2055 passed, 264 skipped, 1 xfailed in 224.43s (0:03:44)
```

**Description of previous solution:**
~~Updating `max_q_len` to `cum_seq_lens_q[-1].item()` within the
`trtllm_ragged_attention_deepseek` or
`trtllm_batch_context_with_kv_cache` functions are not a viable option
because the CPU-side synchronization breaks the deterministic and fully
device-side execution required during CUDA graph capture. The workaround
was thus to update the test & benchmark codes that call the trtllm
prefill functions, and clearly state in the docstring that when
batch_size == 1, max_q_len must match the query size.~~

## 🔍 Related Issues

https://github.com/flashinfer-ai/flashinfer/issues/1898

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Bug Fixes**
* Removed the automatic batch_size=1 restriction for a native backend,
enabling its use in more scenarios while other constraints remain.

* **New Features**
* Added configurable block-sparse attention support to kernel
parameters.

* **Documentation**
* Clarified supported attention optimizations and backend capabilities
in the benchmarks docs.

* **Tests**
* Expanded tests with configurable sequence lengths and added dedicated
batch-size-1 test coverage.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Zihao Ye <[email protected]>

* Feature: Support Relu2 activation in fused MoE (#1954)

## 📌 Description
Added support for Relu2 activation in cutlass fp8 FusedMoE path.
`Relu2(x) = Relu(x)^2`.

Validated this works correctly on H100 and B200.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Added Relu2 as a selectable activation across MOE operations and
exposed activation_type configuration to public MOE APIs and runner
interfaces (Swiglu remains the default).
* **Behavior**
* Certain GEMM execution paths now explicitly reject Relu2 and raise a
clear runtime error instead of silently failing.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Amir Klein <[email protected]>

* fix: Add cutlass as an mm_fp4 backend in compute capability 12.0 in benchmark code (#1959)

<!-- .github/pull_request_template.md -->

## 📌 Description

Previously `backend='cutlass'` was not available to be benchmarked in
`flashinfer_benchmark.py` for compute capability 12.0 while the kernel
actually has been available. Current PR marks the backend as available.

Example output of being runnable after PR:
```
# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 512 --out_dtype bfloat16 --backends cudnn cutlass trtllm --use_128x4_sf_layout --use_nvfp4 --refcheck -vv                                                  
[INFO] args = Namespace(routine='mm_fp4', no_cuda_graph=False, use_cupti=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag=None, generate_repro_command=False, repro_command='', batch_size=1, m=1024, n=7168, k=512, tile_size=128, group_size=1, scale_major_mode='MN', input_dtype='fp8_e4m3', mat2_dtype='fp8_e4m3', out_dtype='bfloat16', mma_sm=1, backends=['cudnn', 'cutlass', 'trtllm'], use_128x4_sf_layout=True, use_nvfp4=True, autotune=False)
[INFO] Running testMmFp4
[INFO] FlashInfer version: 0.4.1
[VVERBOSE] gpu_name = 'NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition'
[WARNING] trtllm for routine mm_fp4 is not supported on compute capability 12.0. Skipping.
[VVERBOSE] input_fp4.shape = torch.Size([1024, 256])
[VVERBOSE] input_fp4.dtype = torch.uint8
[VVERBOSE] mat2_fp4.shape = torch.Size([7168, 256])
[VVERBOSE] mat2_fp4.dtype = torch.uint8
[PERF] cudnn          :: median time 0.014 ms; std 0.000 ms; achieved tflops 535.891 TFLOPs/sec; achieved tb_per_sec 1.196 TB/sec
[PERF] cutlass        :: median time 0.015 ms; std 0.000 ms; achieved tflops 515.203 TFLOPs/sec; achieved tb_per_sec 1.150 TB/sec
```

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Chores**
* Expanded backend support for benchmarking routines on compute
capability 12.0, adding compatibility with additional processing
backends.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* rebase on fw repo branch

* unittest: fix deepgemm sha256 (#1953)

<!-- .github/pull_request_template.md -->

## 📌 Description

Deepgemm unittest failed because of out-dated sha256, this PR fixes the
issue.

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Chores**
* Updated internal artifact version information to support latest
optimizations and improvements.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* misc: Update artifacts docstring and MetaInfoHash (#1967)

<!-- .github/pull_request_template.md -->

## 📌 Description

Amendment to [PR
1761](https://github.com/flashinfer-ai/flashinfer/pull/1761), appending
docstring to two artifactory path classes and deprecating need to update
MetaInfoHash by directly accessing the checksum.txt file.

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Added runtime integrity checks for compiled artifacts that verify and
use checksum data during loading to prevent missing or mismatched
artifact headers.

* **Refactor**
* Switched artifact hash resolution to compute hashes dynamically from
provided checksums, improving validation, reliability, and resilience
when loading precompiled components.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* silu_and_mul nvfp4 quanization fusion rework (#1927)

<!-- .github/pull_request_template.md -->

## 📌 Description

This PR reverts https://github.com/flashinfer-ai/flashinfer/pull/1774
and https://github.com/flashinfer-ai/flashinfer/pull/1835 which have
some issues with some shapes under cuda graph. The kernels ported in
this PR comes from SGLANG. [[NVIDA] [1/N] Nvfp4 Masked Gemm: Add quant
op for the flashinfer grouped
gemm](https://github.com/sgl-project/sglang/pull/9200/files) and
[[NVIDIA] [2/N] Optimize silu_and_mul_scaled_fp4_grouped_quant
perf](https://github.com/sgl-project/sglang/pull/9556/files) by @kaixih
.

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
- Added grouped FP4 quantization (scaled_fp4_grouped_quantize) and an
NV-focused Silu+Mul expert quantization entry
(silu_and_mul_scaled_nvfp4_experts_quantize).

* **API Changes**
- Replaced legacy batched APIs with new expert/grouped APIs; removed
legacy mask parameter from FP4/MXFP8 quantization signatures and
adjusted FP4 output layouts/types.

* **Documentation**
  - Updated docs to list new functions and remove deprecated symbols.

* **Tests**
- Updated tests to validate new quantization paths, shapes, dtypes, and
layouts.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Shu Wang. <[email protected]>

* unittest: fix test_artifacts.py (#1950)

* chore: update the list of authorized codeowners (#1970)

<!-- .github/pull_request_template.md -->

## 📌 Description

Add @djmmoss @jiahanc to the authorized codeowner list.

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Chores**
  * Updated internal codeowner authorization configuration.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* Added heuristic for trtllm_allreduce_fusion (#1972)

<!-- .github/pull_request_template.md -->

## 📌 Description
The original heuristic does not accurately reflect the performance of
oneshot/twoshot. Updated with heuristics based on this benchmark
[allreduce_test.py](https://github.com/user-attachments/files/23094671/allreduce_test.py).
The benchmark uses hidden_dim of Llama3, LLama4 and GPT-OSS and
combinations of token_num, fusion patterns and fp32_acc.

The results are at the bottom. TL;DR token_num is a bad predictor of
whether to use oneshot or twoshot. Using the communication size of
oneshot is a good predictor, but only if we treat each TP separately.
Fusion patterns and fp32_acc is irrelevant to the choice.

# Full size results
<img width="1800" height="3600" alt="comm_size_TP=2"
src="https://github.com/user-attachments/assets/2874157e-6268-421a-8f45-00491b652702"
/>
<img width="1800" height="3600" alt="comm_size_TP=4"
src="https://github.com/user-attachments/assets/2cdfdb9d-569e-401b-89ad-787f8d755ac1"
/>
<img width="1800" height="3600" alt="comm_size_TP=8"
src="https://github.com/user-attachments/assets/fbb147da-3479-4dbc-85b8-c27a735d0cd6"
/>

# Results zoomed in on small comm_size
<img width="1800" height="3600" alt="comm_size_Enlarge_TP=2"
src="https://github.com/user-attachments/assets/e070c81f-edf9-4d7f-ab95-fa6dea9f42f2"
/>
<img width="1800" height="3600" alt="comm_size_Enlarge_TP=4"
src="https://github.com/user-attachments/assets/3b1c51d2-56ca-4d34-9bfd-8082390cc95e"
/>
<img width="1800" height="3600" alt="comm_size_Enlarge_TP=8"
src="https://github.com/user-attachments/assets/9a8095b4-11bc-4021-80c6-f2be69b33021"
/>

# Mixing TP=2/4/8 makes the choice noisy
<img width="1800" height="3600" alt="comm_size_TP=248"
src="https://github.com/user-attachments/assets/66956ebe-6cf0-43e8-93ce-950b1079148a"
/>
<img width="1800" height="3600" alt="comm_size_Enlarge_TP=248"
src="https://github.com/user-attachments/assets/0cd6982c-da42-4f42-b0ad-5ef564b2e78e"
/>

# token_num is a bad predictor
<img width="1800" height="3600" alt="token_num_TP=248"
src="https://github.com/user-attachments/assets/2968ca7c-2059-4305-8e4d-5b70a32faaee"
/>
<img width="1800" height="3600" alt="token_num_Enlarge_TP=248"
src="https://github.com/user-attachments/assets/881ba86d-fc71-4cbc-b5a6-c050f255d618"
/>


<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

---------

Co-authored-by: yzh119 <[email protected]>

* Bump tvm ffi to stable version 0.1.0 (#1960)

<!-- .github/pull_request_template.md -->

## 📌 Description

This PR bumps the tvm-ffi to stable version 0.1.0 and update the
flashinfer code base.

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

## 🔍 Related Issues

https://github.com/flashinfer-ai/flashinfer/pull/1939 

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Chores**
* Relaxed build dependency pins for apache-tvm-ffi and setuptools across
project configs; removed installation of multiple build packages from
the nightly CI step.
* **Refactor**
* Modernized internal CUDA/tensor access patterns to a consistent
accessor API across many modules.
* **Bug Fixes**
* GEMM runner now returns the output tensor in the correct
(non‑transposed) orientation.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Zihao Ye <[email protected]>
Co-authored-by: yzh119 <[email protected]>

* Update Docker CI tags to 20251024-0e48aaf (#1975)

This PR updates the Docker CI image tags to the latest version:
`20251024-0e48aaf`

Updated images:
- flashinfer/flashinfer-ci-cu126:20251024-0e48aaf
- flashinfer/flashinfer-ci-cu128:20251024-0e48aaf
- flashinfer/flashinfer-ci-cu129:20251024-0e48aaf
- flashinfer/flashinfer-ci-cu130:20251024-0e48aaf

Auto-generated by [release-ci-docker
workflow](https://github.com/flashinfer-ai/flashinfer/actions/runs/18778064727)

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Chores**
* Updated CI/CD Docker image configurations to ensure consistency and
reliability across build environments.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Co-authored-by: yzh119 <[email protected]>

* fix: Make attention microbenchmark correctly use page table (#1976)

<!-- .github/pull_request_template.md -->

## 📌 Description

Current microbenchmark code does not provides instantiated
`block_tables` to all backends. The omission had no impact to
correctness or perf because page tables are instantiated linearly when
not provided, but will manifest as mismatches if it is shuffled.

The current PR simply calls the FlashInfer APIs in their intended way.

**No changes to library code**

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Refactor**
* Enhanced consistency in attention computation by aligning page-table
parameter handling across different inference backend implementations
for improved paged key-value cache operations.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* fix: Skipping attention sink Blackwell test outside of Blackwell (#1978)

<!-- .github/pull_request_template.md -->

## 📌 Description

`test_attention_sink_blackwell.py` checks
`flashinfer.prefill.trtllm_batch_context_with_kv_cache` and
`flashinfer.decode.trtllm_batch_decode_with_kv_cache` which are only
supported on Blackwell SM100 and SM103.

Existing check only skips testing of SM 11x or 12x, which causes
failures on Hopper SM90.

Test outputs:
* H200:
   * Before Fix: `144 failed, 1 warning in 9.20s`
   * After Fix: `144 skipped, 1 warning in 0.42s`
* B200: 
   * After Fix: `144 passed in 34.64s `

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Tests**
* Updated GPU compatibility checks for attention sink tests to target
specific GPU architectures (SM100/SM103). Tests now run exclusively on
supported GPU models with updated filtering criteria.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* feat: enable deepgemm jit for fp8 block-scale on SM90 (#1969)

<!-- .github/pull_request_template.md -->

## 📌 Description

Enable JIT compile for the FP8 DeepGEMM kernels, NVRTC is currently
disabled it uses NVCC by default.

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).




<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Refactor**
* JIT include directory discovery now uses the flashinfer-python package
instead of the previous package.
  * Updated resolved include path to the flashinfer data location.
* Runtime compilation now consistently uses NVCC; the prior
environment-variable toggle was removed.
* Updated warning text when the expected package installation cannot be
found.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Duncan Moss <[email protected]>

* chore: Update CODEOWNERS (#1949)

## Summary

This PR updates the CODEOWNERS file based on git commit history analysis
from the last 180 days.

## Changes

- Updated `.github/CODEOWNERS` with current code ownership based on:
  - Commit frequency
  - File coverage
  - Commit recency

## How to Review

1. Review the changes to `.github/CODEOWNERS`
2. Verify that the assigned owners are appropriate for each module
3. Make manual adjustments if needed before merging

## Notes

- This is an automated PR generated weekly
- Minimum commits threshold: 1
- Analysis period: 180 days
- Directory depth: 3 levels
- Top N owners per module: 5

---

🤖 This PR was automatically generated by the [update-codeowners
workflow](.github/workflows/update-codeowners.yml)

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Chores**
  * Updated internal code ownership assignments.

---

**Note:** This update contains no user-facing changes or feature
updates. It is an internal administrative modification.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Co-authored-by: flashinfer-bot <[email protected]>
Co-authored-by: Claude <[email protected]>

* fix: correct PDL parameter handling in RopeQuantize kernel (#1982)

<!-- .github/pull_request_template.md -->

## 📌 Description

### 1. Fixed Parameter Alignment
- **Issue**: The `stream` parameter was being passed to the wrong
position in the `RopeQuantize` function call due to missing `enable_pdl`
parameter. SGLang will hang before this pr.
- **Fix**: Added the `enable_pdl` parameter to the function signature
and properly aligned all parameters

### 2. Fixed PDL Launch Configuration
- **Issue**: When `enable_pdl=true`, the kernel would throw CUDA errors
due to incorrect PDL attribute handling
- **Fix**: Aligned the implementation with `csrc/fmhaReduction.cu`.

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **New Features**
* Added PDL (Programmatic Dynamic Launch) benchmarking capability for
rope quantization operations.
* Extended configuration options to enable or disable PDL functionality.

* **Tests**
* Updated test suite to validate PDL enabled and disabled scenarios in
rope quantization workflows.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* Fix: Verify scales are not None for Cutlass FP8 FusedMoE (#1961)

## 📌 Description
Verify quant scales for fp8 are non null in cutlass FusedMoE path.
Currently, if these tensors are passed as None from python it will
result in segmentation fault.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Bug Fixes**
* Enhanced validation for FP8 quantization parameters to improve system
robustness and prevent potential null reference issues during
quantization operations, reducing the risk of runtime errors when
processing quantized model data.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Amir Klein <[email protected]>

* add xqa fp8 mha and fp8 kv cache (#1769)

<!-- .github/pull_request_template.md -->

## 📌 Description

Add xqa fp8 mha and fp8 kv cache. Add fp8 mla for sm120. Use vllm kv
layout.

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
  * MLA-based attention path and dedicated MLA entrypoints (SM120/121)
* FP8 KV-cache support with optional paged KV layout and separate K/V
cache inputs
* Asynchronous tensor-map/TMA and matrix-descriptor primitives for
high-throughput GPU transfers
  * Dtype-driven config and expanded GPU SM gating for builds/runtimes

* **Bug Fixes**
  * Improved numerical stability for attention mask initialization

* **Tests**
  * Expanded coverage for MLA, FP8, FP16/BF16, and new cache layouts

* **Documentation**
  * Added XQA API docs and new public symbols
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Qidi Sang <[email protected]>
Co-authored-by: yzh119 <[email protected]>

* unittest: fix failed unittest on hopper (#1952)

<!-- .github/pull_request_template.md -->

## 📌 Description

Some invalid configuration are generated in JIT warmup (mixed precision)
function `gen_prefill_attention_modules`.

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Tests**
* Updated test infrastructure to enhance compatibility handling for
specific hardware acceleration scenarios, improving test robustness for
mixed-precision configurations.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* docs: Update documented versioning scheme to right-shifted semver (#1990)

<!-- .github/pull_request_template.md -->

## 📌 Description

Based on discussion with @yzh119 and others, we're planning to follow
the vLLM "right-shifted" versioning scheme. This PR updates the docs to
reflect that.

## 🔍 Related Issues

Previously we said we would follow Semantic Versioning (see #1553).
However, we recently re-considered this approach, to better match the
conventions followed by vLLM and PyTorch.

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

Docs only, so no new tests are needed. Did not verify passing unit
tests.

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Documentation**
* Updated release versioning scheme to a "right-shifted" format
(major.minor.patch[.post1]) with an optional post-release suffix for
expedited follow-up releases.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* Bugfix: Change get() -> GetDLTensorPtr() in cutlass FusedMoE validations (#1995)

## 📌 Description
Using different API after `apach-tvm-ffi` version bump.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Bug Fixes**
* Improved null-pointer validation for FP8 quantization tensors used
during inference, increasing robustness and reducing risk of runtime
errors.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Amir Klein <[email protected]>

* unittest: Add SM arch checks to skip unsupported tests on Hopper (#1998)

<!-- .github/pull_request_template.md -->

## 📌 Description

A number of unit tests fail on Hopper because they either lack a support
check or skip based on "what is not supported" while missing SM90. The
current PR adds checks based on "what is supported" and skips a test if
the current SM is not in its supported list (a sketch of the skip pattern
follows the test list below).

A special case is `mm_fp4`, where `mm_fp4.is_backend_supported(backend,
compute_capability_number)` now exists and is used to skip tests when the
backend is not supported.

Impacted tests:
* tests/attention/test_trtllm_gen_attention.py
* tests/attention/test_trtllm_gen_mla.py
* tests/gemm/test_bmm_fp8.py
* tests/gemm/test_mm_fp4.py
* tests/gemm/test_groupwise_scaled_gemm_fp8.py
* tests/gemm/test_groupwise_scaled_gemm_mxfp4.py
* tests/moe/test_trtllm_gen_fused_moe.py
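
For reference, a minimal sketch of the "skip if not in the supported list" pattern (the supported-SM tuple below is illustrative, not the exact list used by every impacted test):

```python
import pytest
import torch


def skip_if_sm_unsupported(supported=((10, 0), (10, 3))):
    # Skip based on "what is supported": any compute capability outside the
    # list is skipped instead of enumerating everything that is unsupported.
    cc = torch.cuda.get_device_capability()
    if cc not in supported:
        pytest.skip(f"compute capability {cc} not in supported list {supported}")
```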


<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

* Added workspace check and reflected this in test (#1991)

<!-- .github/pull_request_template.md -->

## 📌 Description

This PR attempts to fix #1986 (to be confirmed by requester)

The issue is that `num_tokens` was larger than `MAX_TOKEN_NUM`, which
results in an IMA, or even a hang. To address this, I added a
validation check. This required a non-breaking API change:
* create_ipc_workspace_for_all_reduce_fusion now has an optional
"create_metadata" bool, which results in an additional return value
  * it is made optional because an additional return value could break the API
* trtllm_allreduce_fusion now takes an optional metadata dictionary
  * when provided, this will run the validation check
  * again, this is also optional, to avoid breaking the API


In addition, this PR deprecates the older AllReduce functionality so it can be removed in a major version bump.

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **API Changes**
* Workspace creation can optionally return metadata describing the
workspace configuration (create_metadata flag).
* Allreduce fusion operations accept optional metadata to validate
runtime parameters against the workspace and raise clear errors on
mismatch.
  * A workspace destruction endpoint was renamed for naming consistency.
* Legacy wrappers were marked deprecated and now point users toward the
newer fusion variants.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* minor fix for xqa (#1994)

<!-- .github/pull_request_template.md -->

## 📌 Description

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->
1. Change xqa_mla comments to be consistent with MLA instead of MHA.
2. Move cudaMemcpyFromSymbol/cudaFuncSetAttribute outside of the launch
function to avoid breaking CUDA graph capture.
3. Use int32 as the page table index.

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **New Features**
* Added MLA variant documentation clarifying SM120 GPU requirement and
fixed head group ratio configuration.

* **Documentation**
* Updated data type specifications for XQA operations; page table now
requires int32 instead of uint32.
* Added max sequence length derivation notes for page-table-based
configurations.
* Clarified MLA variant input/output data types (float8_e4m3fn and
bfloat16).

* **Bug Fixes**
* Corrected data type handling in page table processing to ensure
compatibility.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Signed-off-by: Qidi Sang <[email protected]>
Co-authored-by: yzh119 <[email protected]>

* Feature: Add support for L40 FusedMoE in cutlass path (#1973)

## 📌 Description
Fixed a few compilation issues for L40, and removed one GEMM tactic for
`sm == 89` that crashes due to:
```
Assertion failed: GPU lacks the shared memory resources to run GroupedGEMM kernel
```

## 🧪 Tests

Ran `pytest tests/moe/test_trtllm_cutlass_fused_moe.py` manually on an
L40 GPU and verified all tests passed.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Official support for SM89 target: build/JIT flags and a public
generation path to target it.

* **Bug Fixes / Compatibility**
* Clarified FP8/FP4 dispatch: FP8 paths enabled for SM89; FP4 usage
remains gated and now requires explicit enablement.

* **Performance**
* Adjusted kernel/tile selection order for certain FP8 paths to prefer
SM89-optimized options.

* **Chores**
  * Reduced logging severity for failed tactic profiling to warn/debug.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Amir Klein <[email protected]>

* unittest: Add head dim 256 test cases and mark as xfail (#1999)

* feat: autotune tile_tokens_dim in trtllm-gen MOE (#1980)

<!-- .github/pull_request_template.md -->

## 📌 Description
- Update the autotune logic in trtllm-gen MoE. Instead of using a fixed
`tile_tokens_dim`, tune over the range
`[max(8, tile_tokens_dim/2), tile_tokens_dim, min(128, tile_tokens_dim*2),
min(128, tile_tokens_dim*4)]` (see the sketch after this list).
- Add FP8 MoE autotune logic; initial PR
https://github.com/flashinfer-ai/flashinfer/pull/1494 from @aleozlx,
with the logic updated to sync with the new autotuner.
- Update logic in `test_trtllm_gen_fused_moe.py`.
- Update `conftest.py` to speed up tests; it previously used `try_first`,
which introduced duplicate runs.
- Add log_once to the logger.
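
A minimal sketch of the candidate set described in the first bullet (the clamping to [8, 128] follows the description; deduplication is an assumption for readability, not a claim about the autotuner's exact code):

```python
def tile_tokens_dim_candidates(tile_tokens_dim: int) -> list[int]:
    # Candidates: half, current, 2x and 4x of tile_tokens_dim, kept within [8, 128].
    cands = [
        max(8, tile_tokens_dim // 2),
        tile_tokens_dim,
        min(128, tile_tokens_dim * 2),
        min(128, tile_tokens_dim * 4),
    ]
    return sorted(set(cands))  # dedup, e.g. when tile_tokens_dim is already 128


print(tile_tokens_dim_candidates(32))  # [16, 32, 64, 128]
```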
<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Runtime autotuning with per-tile dynamic routing and selectable MoE
runner options (gated activation, shuffled-weight, weight-layout).
  * One-time (deduplicated) logging helpers added to JIT logger.

* **Deprecations**
* tile_tokens_dim removed from new paths and marked deprecated in legacy
entry points; new tuning parameters introduced for autotuning.

* **Tests**
* Tests refactored for autotuning/routing with new helpers and improved
handling/reporting for missing JIT cache.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: jiahanc <[email protected]>
Co-authored-by: yzh119 <[email protected]>

* Fix trtllm-gen attention illegal memory access (#2002)

<!-- .github/pull_request_template.md -->

## 📌 Description

This PR fixes an illegal memory access in the trtllm-gen attention kernels. It
changes the workspace buffer from `int_workspace_buffer` to
`float_workspace_buffer`. `int_workspace_buffer` is a fixed-size buffer
that is not initialized to zero and should not be used here.

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

## 🔍 Related Issues

Issue #1928 

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Bug Fixes**
* Fixed memory allocation in the decode module to improve computation
accuracy and stability during text generation.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* release: Bump version for v0.5.0rc1 release; (#2008)

<!-- .github/pull_request_template.md -->

## 📌 Description

Update version in `version.txt` to v0.5.0 as we prepare for v0.5.0rc1
release.

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Chores**
  * Version bump to 0.5.0 (no functional changes)

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* bugfix: fix regex in update wheel index script (#2009)

<!-- .github/pull_request_template.md -->

## 📌 Description

The regex cannot recognize release candidates (`v0.5.0rc1`) or post
releases (`v1.2.3.post1`):
https://github.com/flashinfer-ai/flashinfer/actions/runs/18929490991/job/54049304551

This PR fixes the issue.
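
Not the script's actual regex, but an illustration of the tag formats it now has to accept, using PEP 440 parsing from `packaging`:

```python
from packaging.version import Version

for tag in ["v0.5.0rc1", "v1.2.3.post1", "v0.4.1"]:
    v = Version(tag.removeprefix("v"))
    print(tag, "->", v, "prerelease:", v.is_prerelease, "postrelease:", v.is_postrelease)
```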

## 🔍 Related Issues


https://github.com/flashinfer-ai/flashinfer/actions/runs/18929490991/job/54049304551

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Chores**
* Enhanced version string parsing in the wheel package indexing process
to support more complex version formats, including pre-release,
post-release, and development versions, ensuring compatibility with PEP
440 versioning standards.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* fix: Enable SM121 for mm_fp4 (#2012)

<!-- .github/pull_request_template.md -->

## 📌 Description

In #1809 we previously added a compute-capability-based support check
for `mm_fp4`.

However, we missed enabling SM121 for the `cudnn` and `cutlass` backends.
Additionally, we marked `trtllm` as supported on SM120 when it is not.

The current PR fixes this. Example benchmark and pytest output on SM121 after
the fix:
```
(py312) root@f414f262f02a:/flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 8192 --n 7168 --k 512 --out_dtype bfloat16 --backends cudnn cutlass --use_128x4_sf_layout --use_nvfp4 --refcheck --use_cupti
/opt/conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:285: UserWarning: 
    Found GPU0 NVIDIA GB10 which is of cuda capability 12.1.
    Minimum and Maximum cuda capability supported by this version of PyTorch is
    (8.0) - (12.0)
    
  warnings.warn(
[PERF] cudnn          :: median time 0.656 ms; std 0.025 ms; achieved tflops 91.701 TFLOPs/sec; achieved tb_per_sec 0.185 TB/sec
[PERF] cutlass        :: median time 0.669 ms; std 0.022 ms; achieved tflops 89.859 TFLOPs/sec; achieved tb_per_sec 0.181 TB/sec

(py312) root@f414f262f02a:/flashinfer# pytest tests/gemm/test_mm_fp4.py 
====================================================================================================================== test session starts ======================================================================================================================
platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0
rootdir: /flashinfer
configfile: pytest.ini
collected 3240 items     
...
======================================================================================================================= warnings summary ========================================================================================================================
../opt/conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:285
  /opt/conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:285: UserWarning: 
      Found GPU0 NVIDIA GB10 which is of cuda capability 12.1.
      Minimum and Maximum cuda capability supported by this version of PyTorch is
      (8.0) - (12.0)
      
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
========================================================================================================= 450 passed, 2790 skipped, 1 warning in 8.24s ==========================================================================================================


```

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **New Features**
* Expanded hardware compatibility by adding support for newer NVIDIA GPU
architectures.
* FP4 quantized operations now available across multiple backends on
supported devices.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* fix: ensure SM120/121 SFA/SFB contiguity (#1963)

<!-- .github/pull_request_template.md -->

## 📌 Description

Fix the bmm_fp8 regression seen in vLLM and SGLang with FlashInfer 0.4.0.

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

cc: @yzh119


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Bug Fixes**
* Fixed memory layout handling for tensor operations in GPU computations
to ensure proper alignment, improving stability and performance.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* More realistic bench for POD Attn (#2013)

<!-- .github/pull_request_template.md -->

## 📌 Description

Use realistic head sizes and sequence lengths, and add a comparison with
sequential prefill + decode.
Results on H100 (without overlap, which only adds ~150GB/s for
persistent):
<img width="433" height="571" alt="image"
src="https://github.com/user-attachments/assets/50de01cd-e5ca-450c-9cc0-521d83b7e487"
/>
cc @yzh119 
## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

## Release Notes

* **New Features**
* Added comprehensive performance benchmarking for batch attention
operations with detailed timing measurements.
* Introduced sequential dual-kernel benchmark path with extended memory
bandwidth reporting.

* **Tests**
* Updated benchmark test configurations to use deterministic, fixed
values for improved reproducibility.
* Adjusted benchmark parameters for consistency across test iterations.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* Feature: Support non-gated activation in cutlass fused MoE nvfp4 (#2011)

## 📌 Description

This PR removes an assertion in the cutlass fused moe bindings to enable
non-gated activations in nvfp4.
It also adds a test for this path with relu2 activation.

## 🔍 Related Issues

N/A

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

N/A

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **New Features**
* Enhanced quantized Mixture of Experts models to support configurable
activation types (Swiglu and ReLU2) in the NVFP4 quantization path.
* Improved parameter handling to correctly adapt weight shapes and
quantization settings based on the selected activation type.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Omer Ullman Argov <[email protected]>

* feat: add xqa backend and completes NHD/HND coverage for trtllm-gen/xqa backend (#2001)

<!-- .github/pull_request_template.md -->

## 📌 Description
Expose the xqa backend through the trtllm attention interface, and improve the
layout coverage of the trtllm-gen and xqa backends.

Now both trtllm-gen and xqa support the NHD and HND kv-cache layouts (a layout
sketch follows the list below).
* support NHD layout for trtllm-gen
* refactor xqa
(https://github.com/flashinfer-ai/flashinfer/commit/869c0c1c6bc199f82f30c23ab78a1b4aa9a1bd3a)
    * allow user-passed stride_page/head/token
    * support both HND and NHD
    * remove macros such as PAGED_KV_CACHE_LAYOUT and USE_PAGED_KV_CACHE
* add unit tests for both trtllm-gen/xqa on NHD/HND
* add a unified API for trtllm-gen/xqa, and a unified unittest
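
For reference, a sketch of how the two paged KV-cache layouts differ in per-page tensor shape (the sizes below are illustrative):

```python
import torch

num_pages, page_size, num_kv_heads, head_dim = 32, 16, 8, 128

# NHD: each page is laid out as [page_size (N), num_kv_heads (H), head_dim (D)]
k_cache_nhd = torch.empty(num_pages, page_size, num_kv_heads, head_dim,
                          dtype=torch.bfloat16, device="cuda")

# HND: each page is laid out as [num_kv_heads (H), page_size (N), head_dim (D)]
k_cache_hnd = torch.empty(num_pages, num_kv_heads, page_size, head_dim,
                          dtype=torch.bfloat16, device="cuda")
```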

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Added xqa-based batch decode API and public kv_layout option
(NHD/HND); added enable_pdl toggle to inference wrappers.

* **Improvements**
* Automatic backend selection for decoding, consistent KV-layout
normalization across paths, and unified stride-aware paged-KV handling
with layout-aware shapes, scales, and workspace handling.

* **Tests**
* Expanded tests to cover both KV layouts, enable_pdl, new batch-decode
workflows, backend/layout permutations, and fp8/mixed-dtype scenarios.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Qidi Sang <[email protected]>
Co-authored-by: yzh119 <[email protected]>
Co-authored-by: Zihao Ye <[email protected]>

* test: Enable xfailed trtllm decode long seqlen tests and update microbenchmark (#2018)

<!-- .github/pull_request_template.md -->

## 📌 Description


[tests/attention/test_trtllm_gen_attention.py](https://github.com/flashinfer-ai/flashinfer/blob/v0.5.0rc2/tests/attention/test_trtllm_gen_attention.py#L1021-L1076)
was failing and therefore marked xfail.

PR #2002 fixed the underlying root cause. The current PR thus removes the
`xfail` marker so that these long-seqlen cases are exercised going
forward.

Additionally, PR #2002 revealed a bug in the microbenchmark script where
[trtllm_batch_decode_with_kv_cache](https://github.com/flashinfer-ai/flashinfer/blob/v0.5.0rc2/flashinfer/decode.py#L2082-L2083)
explicitly requires the workspace to be zeroed before first use:
```
    workspace_buffer : torch.Tensor. Must be initialized to 0 for its first use.
        workspace
```
while the microbenchmark code did not zero it out, causing undefined
behavior such as IMAs that depend on the order in which backends are tested.
The current PR fixes the issue by explicitly calling
`workspace_buffer.zero_()` between testing different backends.
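
A minimal sketch of the zeroing pattern applied in the microbenchmark (the buffer size below is illustrative):

```python
import torch

# The workspace must be zero-initialized before its first use by the kernel.
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")

# ... run and time backend A with workspace_buffer ...

workspace_buffer.zero_()  # reset before switching to the next backend

# ... run and time backend B with workspace_buffer ...
```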


<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Bug Fixes**
* Improved stability of performance benchmarks by properly resetting
workspace buffer between backend invocations.

* **Tests**
  * Enabled previously skipped test for long sequence length handling.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* Updated decorator to support unspecified default (#2026)

<!-- .github/pull_request_template.md -->

## 📌 Description

Updated the decorator to support an unspecified default. This was causing
issues when calling mm_fp4 without a backend specified.
Also added SM110 as supported for the cutlass backend (mm_fp4).

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, s…
murphymatt added a commit to fw-ai/flashinfer that referenced this pull request Nov 13, 2025
* chore: rename FLASHINFER_JIT_VERBOSE to FLASHINFER_JIT_DEBUG for clarity (#1946)

<!-- .github/pull_request_template.md -->

Rename environment variable `FLASHINFER_JIT_VERBOSE` to
`FLASHINFER_JIT_DEBUG` to better reflect its actual behavior.

- `FLASHINFER_JIT_DEBUG`: Enable debug mode during compilation (disable
optimization, add debug symbols)
- The previous name `FLASHINFER_JIT_VERBOSE` implied "showing more
compilation info", which was confusing
- Maintained backward compatibility: falls back to
`FLASHINFER_JIT_VERBOSE` if `FLASHINFER_JIT_DEBUG` is not set
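
A sketch of the fallback behavior described above (how the value is interpreted afterwards is an assumption, not the exact JIT-config code):

```python
import os

# Prefer the new name; fall back to the legacy FLASHINFER_JIT_VERBOSE if unset.
jit_debug = os.environ.get(
    "FLASHINFER_JIT_DEBUG",
    os.environ.get("FLASHINFER_JIT_VERBOSE", "0"),
)
debug_build = jit_debug == "1"  # assumed truthy convention, for illustration
```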

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Refactor**
* Introduced FLASHINFER_JIT_DEBUG environment variable for controlling
JIT debug builds with backward compatibility for legacy
FLASHINFER_JIT_VERBOSE.
* Enhanced debug build configuration with improved compiler and CUDA
debugging flags. Non-debug builds continue using -O3 optimizations.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* fix: Fix trtllm-gen prefill IMA when batch_size==1 (#1912)

<!-- .github/pull_request_template.md -->

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

The current PR fixes the IMAs in the test and benchmark code when running
trtllm-gen paged & ragged prefill with batch size 1 -- the issue was
described in https://github.com/flashinfer-ai/flashinfer/issues/1898

Root cause of the issue:
`flashinfer.prefill.trtllm_ragged_attention_deepseek` and
`flashinfer.prefill.trtllm_batch_context_with_kv_cache` both require
`max_q_len` to match the length of the query when batch size is 1.

**Updated PR:**
The issue has been addressed on the kernel side, so the requirement that
*`max_q_len` match the length of the query when batch size is 1* no
longer applies.

Current PR updates trtllm-gen FMHA cubins to latest and brings minor
updates to kernel metadata.

Unit test results after PR:
```
$ pytest tests/attention/test_trtllm_gen_attention.py
...
platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0
rootdir: /flashinfer
configfile: pytest.ini
collected 2320 items
...
2055 passed, 264 skipped, 1 xfailed in 224.43s (0:03:44)
```

**Description of previous solution:**
~~Updating `max_q_len` to `cum_seq_lens_q[-1].item()` within the
`trtllm_ragged_attention_deepseek` or
`trtllm_batch_context_with_kv_cache` functions are not a viable option
because the CPU-side synchronization breaks the deterministic and fully
device-side execution required during CUDA graph capture. The workaround
was thus to update the test & benchmark codes that call the trtllm
prefill functions, and clearly state in the docstring that when
batch_size == 1, max_q_len must match the query size.~~

https://github.com/flashinfer-ai/flashinfer/issues/1898

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Bug Fixes**
* Removed the automatic batch_size=1 restriction for a native backend,
enabling its use in more scenarios while other constraints remain.

* **New Features**
* Added configurable block-sparse attention support to kernel
parameters.

* **Documentation**
* Clarified supported attention optimizations and backend capabilities
in the benchmarks docs.

* **Tests**
* Expanded tests with configurable sequence lengths and added dedicated
batch-size-1 test coverage.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Zihao Ye <[email protected]>

* Feature: Support Relu2 activation in fused MoE (#1954)

Added support for the Relu2 activation in the cutlass fp8 FusedMoE path.
`Relu2(x) = Relu(x)^2`.
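
A reference sketch of the activation's semantics (not the fused CUDA kernel):

```python
import torch


def relu2(x: torch.Tensor) -> torch.Tensor:
    # Relu2(x) = Relu(x)^2, i.e. squared ReLU
    return torch.relu(x).square()
```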

Validated this works correctly on H100 and B200.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **New Features**
* Added Relu2 as a selectable activation across MOE operations and
exposed activation_type configuration to public MOE APIs and runner
interfaces (Swiglu remains the default).
* **Behavior**
* Certain GEMM execution paths now explicitly reject Relu2 and raise a
clear runtime error instead of silently failing.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Amir Klein <[email protected]>

* fix: Add cutlass as an mm_fp4 backend in compute capability 12.0 in benchmark code (#1959)

<!-- .github/pull_request_template.md -->

Previously `backend='cutlass'` was not available for benchmarking in
`flashinfer_benchmark.py` on compute capability 12.0, even though the kernel
has actually been available. The current PR marks the backend as available.

Example output showing the backend runs after the PR:
```
[INFO] args = Namespace(routine='mm_fp4', no_cuda_graph=False, use_cupti=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag=None, generate_repro_command=False, repro_command='', batch_size=1, m=1024, n=7168, k=512, tile_size=128, group_size=1, scale_major_mode='MN', input_dtype='fp8_e4m3', mat2_dtype='fp8_e4m3', out_dtype='bfloat16', mma_sm=1, backends=['cudnn', 'cutlass', 'trtllm'], use_128x4_sf_layout=True, use_nvfp4=True, autotune=False)
[INFO] Running testMmFp4
[INFO] FlashInfer version: 0.4.1
[VVERBOSE] gpu_name = 'NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition'
[WARNING] trtllm for routine mm_fp4 is not supported on compute capability 12.0. Skipping.
[VVERBOSE] input_fp4.shape = torch.Size([1024, 256])
[VVERBOSE] input_fp4.dtype = torch.uint8
[VVERBOSE] mat2_fp4.shape = torch.Size([7168, 256])
[VVERBOSE] mat2_fp4.dtype = torch.uint8
[PERF] cudnn          :: median time 0.014 ms; std 0.000 ms; achieved tflops 535.891 TFLOPs/sec; achieved tb_per_sec 1.196 TB/sec
[PERF] cutlass        :: median time 0.015 ms; std 0.000 ms; achieved tflops 515.203 TFLOPs/sec; achieved tb_per_sec 1.150 TB/sec
```

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Chores**
* Expanded backend support for benchmarking routines on compute
capability 12.0, adding compatibility with additional processing
backends.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* rebase on fw repo branch

* unittest: fix deepgemm sha256 (#1953)

<!-- .github/pull_request_template.md -->

The DeepGEMM unittest failed because of an outdated sha256; this PR fixes the
issue.

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Chores**
* Updated internal artifact version information to support latest
optimizations and improvements.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* misc: Update artifacts docstring and MetaInfoHash (#1967)

<!-- .github/pull_request_template.md -->

Amendment to [PR
1761](https://github.com/flashinfer-ai/flashinfer/pull/1761), appending
docstrings to two artifactory path classes and deprecating the need to update
MetaInfoHash by directly accessing the checksum.txt file.

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **New Features**
* Added runtime integrity checks for compiled artifacts that verify and
use checksum data during loading to prevent missing or mismatched
artifact headers.

* **Refactor**
* Switched artifact hash resolution to compute hashes dynamically from
provided checksums, improving validation, reliability, and resilience
when loading precompiled components.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* silu_and_mul nvfp4 quanization fusion rework (#1927)

<!-- .github/pull_request_template.md -->

This PR reverts https://github.com/flashinfer-ai/flashinfer/pull/1774
and https://github.com/flashinfer-ai/flashinfer/pull/1835, which have
issues with some shapes under CUDA graph capture. The kernels ported in
this PR come from SGLang: [[NVIDIA] [1/N] Nvfp4 Masked Gemm: Add quant
op for the flashinfer grouped
gemm](https://github.com/sgl-project/sglang/pull/9200/files) and
[[NVIDIA] [2/N] Optimize silu_and_mul_scaled_fp4_grouped_quant
perf](https://github.com/sgl-project/sglang/pull/9556/files) by @kaixih.
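
For reference, a sketch of the unquantized SiLU-and-mul semantics that the ported kernels fuse with NVFP4 quantization (not the kernel code itself):

```python
import torch
import torch.nn.functional as F


def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half: SiLU-gate the first half, multiply by the second.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]
```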

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **New Features**
- Added grouped FP4 quantization (scaled_fp4_grouped_quantize) and an
NV-focused Silu+Mul expert quantization entry
(silu_and_mul_scaled_nvfp4_experts_quantize).

* **API Changes**
- Replaced legacy batched APIs with new expert/grouped APIs; removed
legacy mask parameter from FP4/MXFP8 quantization signatures and
adjusted FP4 output layouts/types.

* **Documentation**
  - Updated docs to list new functions and remove deprecated symbols.

* **Tests**
- Updated tests to validate new quantization paths, shapes, dtypes, and
layouts.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Shu Wang. <[email protected]>

* unittest: fix test_artifacts.py (#1950)

* chore: update the list of authorized codeowners (#1970)

<!-- .github/pull_request_template.md -->

Add @djmmoss @jiahanc to the authorized codeowner list.

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Chores**
  * Updated internal codeowner authorization configuration.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* Added heuristic for trtllm_allreduce_fusion (#1972)

<!-- .github/pull_request_template.md -->

The original heuristic does not accurately reflect the performance of
oneshot/twoshot. Updated with heuristics based on this benchmark:
[allreduce_test.py](https://github.com/user-attachments/files/23094671/allreduce_test.py).
The benchmark uses the hidden_dim of Llama3, Llama4 and GPT-OSS and
combinations of token_num, fusion patterns and fp32_acc.

The results are at the bottom. TL;DR: token_num is a bad predictor of
whether to use oneshot or twoshot. The communication size of oneshot is a
good predictor, but only if we treat each TP size separately.
Fusion patterns and fp32_acc are irrelevant to the choice.
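
A sketch of the kind of per-TP, communication-size-based rule this suggests (the thresholds below are placeholders, not the tuned values from the benchmark):

```python
def use_oneshot(token_num: int, hidden_dim: int, elem_bytes: int, tp_size: int) -> bool:
    # Decide oneshot vs. twoshot from the oneshot communication size, per TP size.
    comm_bytes = token_num * hidden_dim * elem_bytes
    thresholds = {2: 1 << 21, 4: 1 << 20, 8: 1 << 19}  # placeholder cutoffs per TP
    return comm_bytes <= thresholds.get(tp_size, 1 << 19)
```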

<img width="1800" height="3600" alt="comm_size_TP=2"
src="https://github.com/user-attachments/assets/2874157e-6268-421a-8f45-00491b652702"
/>
<img width="1800" height="3600" alt="comm_size_TP=4"
src="https://github.com/user-attachments/assets/2cdfdb9d-569e-401b-89ad-787f8d755ac1"
/>
<img width="1800" height="3600" alt="comm_size_TP=8"
src="https://github.com/user-attachments/assets/fbb147da-3479-4dbc-85b8-c27a735d0cd6"
/>

<img width="1800" height="3600" alt="comm_size_Enlarge_TP=2"
src="https://github.com/user-attachments/assets/e070c81f-edf9-4d7f-ab95-fa6dea9f42f2"
/>
<img width="1800" height="3600" alt="comm_size_Enlarge_TP=4"
src="https://github.com/user-attachments/assets/3b1c51d2-56ca-4d34-9bfd-8082390cc95e"
/>
<img width="1800" height="3600" alt="comm_size_Enlarge_TP=8"
src="https://github.com/user-attachments/assets/9a8095b4-11bc-4021-80c6-f2be69b33021"
/>

<img width="1800" height="3600" alt="comm_size_TP=248"
src="https://github.com/user-attachments/assets/66956ebe-6cf0-43e8-93ce-950b1079148a"
/>
<img width="1800" height="3600" alt="comm_size_Enlarge_TP=248"
src="https://github.com/user-attachments/assets/0cd6982c-da42-4f42-b0ad-5ef564b2e78e"
/>

<img width="1800" height="3600" alt="token_num_TP=248"
src="https://github.com/user-attachments/assets/2968ca7c-2059-4305-8e4d-5b70a32faaee"
/>
<img width="1800" height="3600" alt="token_num_Enlarge_TP=248"
src="https://github.com/user-attachments/assets/881ba86d-fc71-4cbc-b5a6-c050f255d618"
/>

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

---------

Co-authored-by: yzh119 <[email protected]>

* Bump tvm ffi to stable version 0.1.0 (#1960)

<!-- .github/pull_request_template.md -->

This PR bumps tvm-ffi to stable version 0.1.0 and updates the
flashinfer code base.

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

https://github.com/flashinfer-ai/flashinfer/pull/1939

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Chores**
* Relaxed build dependency pins for apache-tvm-ffi and setuptools across
project configs; removed installation of multiple build packages from
the nightly CI step.
* **Refactor**
* Modernized internal CUDA/tensor access patterns to a consistent
accessor API across many modules.
* **Bug Fixes**
* GEMM runner now returns the output tensor in the correct
(non‑transposed) orientation.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Zihao Ye <[email protected]>
Co-authored-by: yzh119 <[email protected]>

* Update Docker CI tags to 20251024-0e48aaf (#1975)

This PR updates the Docker CI image tags to the latest version:
`20251024-0e48aaf`

Updated images:
- flashinfer/flashinfer-ci-cu126:20251024-0e48aaf
- flashinfer/flashinfer-ci-cu128:20251024-0e48aaf
- flashinfer/flashinfer-ci-cu129:20251024-0e48aaf
- flashinfer/flashinfer-ci-cu130:20251024-0e48aaf

Auto-generated by [release-ci-docker
workflow](https://github.com/flashinfer-ai/flashinfer/actions/runs/18778064727)

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Chores**
* Updated CI/CD Docker image configurations to ensure consistency and
reliability across build environments.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Co-authored-by: yzh119 <[email protected]>

* fix: Make attention microbenchmark correctly use page table (#1976)

<!-- .github/pull_request_template.md -->

The current microbenchmark code does not provide instantiated
`block_tables` to all backends. The omission had no impact on
correctness or perf because page tables are instantiated linearly when
not provided, but it would manifest as mismatches if the page table were
shuffled.
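
A sketch of what an explicitly instantiated, linear page table looks like (shapes are illustrative; this mirrors the implicit default rather than any specific backend call):

```python
import torch

batch_size, pages_per_seq = 4, 8

# Linear block table: page i of sequence b maps to global page b * pages_per_seq + i.
block_tables = torch.arange(
    batch_size * pages_per_seq, dtype=torch.int32, device="cuda"
).reshape(batch_size, pages_per_seq)
```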

The current PR simply calls the FlashInfer APIs in their intended way.

**No changes to library code**

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Refactor**
* Enhanced consistency in attention computation by aligning page-table
parameter handling across different inference backend implementations
for improved paged key-value cache operations.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* fix: Skipping attention sink Blackwell test outside of Blackwell (#1978)

<!-- .github/pull_request_template.md -->

`test_attention_sink_blackwell.py` checks
`flashinfer.prefill.trtllm_batch_context_with_kv_cache` and
`flashinfer.decode.trtllm_batch_decode_with_kv_cache` which are only
supported on Blackwell SM100 and SM103.

The existing check only skips testing on SM 11x or 12x, which causes
failures on Hopper SM90.

Test outputs:
* H200:
   * Before Fix: `144 failed, 1 warning in 9.20s`
   * After Fix: `144 skipped, 1 warning in 0.42s`
* B200:
   * After Fix: `144 passed in 34.64s `

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Tests**
* Updated GPU compatibility checks for attention sink tests to target
specific GPU architectures (SM100/SM103). Tests now run exclusively on
supported GPU models with updated filtering criteria.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* feat: enable deepgemm jit for fp8 block-scale on SM90 (#1969)

<!-- .github/pull_request_template.md -->

Enable JIT compilation for the FP8 DeepGEMM kernels. NVRTC is currently
disabled; NVCC is used by default.

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Refactor**
* JIT include directory discovery now uses the flashinfer-python package
instead of the previous package.
  * Updated resolved include path to the flashinfer data location.
* Runtime compilation now consistently uses NVCC; the prior
environment-variable toggle was removed.
* Updated warning text when the expected package installation cannot be
found.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Duncan Moss <[email protected]>

* chore: Update CODEOWNERS (#1949)

This PR updates the CODEOWNERS file based on git commit history analysis
from the last 180 days.

- Updated `.github/CODEOWNERS` with current code ownership based on:
  - Commit frequency
  - File coverage
  - Commit recency

1. Review the changes to `.github/CODEOWNERS`
2. Verify that the assigned owners are appropriate for each module
3. Make manual adjustments if needed before merging

- This is an automated PR generated weekly
- Minimum commits threshold: 1
- Analysis period: 180 days
- Directory depth: 3 levels
- Top N owners per module: 5

---

🤖 This PR was automatically generated by the [update-codeowners
workflow](.github/workflows/update-codeowners.yml)

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Chores**
  * Updated internal code ownership assignments.

---

**Note:** This update contains no user-facing changes or feature
updates. It is an internal administrative modification.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Co-authored-by: flashinfer-bot <[email protected]>
Co-authored-by: Claude <[email protected]>

* fix: correct PDL parameter handling in RopeQuantize kernel (#1982)

<!-- .github/pull_request_template.md -->

- **Issue**: The `stream` parameter was being passed in the wrong
position in the `RopeQuantize` function call due to the missing `enable_pdl`
parameter. SGLang would hang before this PR.
- **Fix**: Added the `enable_pdl` parameter to the function signature
and properly aligned all parameters.

- **Issue**: When `enable_pdl=true`, the kernel would throw CUDA errors
due to incorrect PDL attribute handling
- **Fix**: Aligned the implementation with `csrc/fmhaReduction.cu`.

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **New Features**
* Added PDL (Programmatic Dynamic Launch) benchmarking capability for
rope quantization operations.
* Extended configuration options to enable or disable PDL functionality.

* **Tests**
* Updated test suite to validate PDL enabled and disabled scenarios in
rope quantization workflows.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* Fix: Verify scales are not None for Cutlass FP8 FusedMoE (#1961)

Verify that the FP8 quant scales are non-null in the cutlass FusedMoE path.
Currently, if these tensors are passed as None from Python, it results in a
segmentation fault.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Bug Fixes**
* Enhanced validation for FP8 quantization parameters to improve system
robustness and prevent potential null reference issues during
quantization operations, reducing the risk of runtime errors when
processing quantized model data.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Amir Klein <[email protected]>

* add xqa fp8 mha and fp8 kv cache (#1769)

<!-- .github/pull_request_template.md -->

Add XQA FP8 MHA and FP8 KV cache. Add FP8 MLA for SM120. Use the vLLM KV
layout.

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **New Features**
  * MLA-based attention path and dedicated MLA entrypoints (SM120/121)
* FP8 KV-cache support with optional paged KV layout and separate K/V
cache inputs
* Asynchronous tensor-map/TMA and matrix-descriptor primitives for
high-throughput GPU transfers
  * Dtype-driven config and expanded GPU SM gating for builds/runtimes

* **Bug Fixes**
  * Improved numerical stability for attention mask initialization

* **Tests**
  * Expanded coverage for MLA, FP8, FP16/BF16, and new cache layouts

* **Documentation**
  * Added XQA API docs and new public symbols
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Qidi Sang <[email protected]>
Co-authored-by: yzh119 <[email protected]>

* unittest: fix failed unittest on hopper (#1952)

<!-- .github/pull_request_template.md -->

Some invalid configurations are generated by the JIT warmup (mixed precision)
function `gen_prefill_attention_modules`.

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Tests**
* Updated test infrastructure to enhance compatibility handling for
specific hardware acceleration scenarios, improving test robustness for
mixed-precision configurations.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* docs: Update documented versioning scheme to right-shifted semver (#1990)

<!-- .github/pull_request_template.md -->

Based on discussion with @yzh119 and others, we're planning to follow
the vLLM "right-shifted" versioning scheme. This PR updates the docs to
reflect that.

Previously we said we would follow Semantic Versioning (see #1553).
However, we recently re-considered this approach, to better match the
conventions followed by vLLM and PyTorch.

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

Docs only, so no new tests are needed. Did not verify passing unit
tests.

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Documentation**
* Updated release versioning scheme to a "right-shifted" format
(major.minor.patch[.post1]) with an optional post-release suffix for
expedited follow-up releases.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* Bugfix: Change get() -> GetDLTensorPtr() in cutlass FusedMoE validations (#1995)

Use a different API after the `apache-tvm-ffi` version bump.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Bug Fixes**
* Improved null-pointer validation for FP8 quantization tensors used
during inference, increasing robustness and reducing risk of runtime
errors.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Amir Klein <[email protected]>

* unittest: Add SM arch checks to skip unsupported tests on Hopper (#1998)

<!-- .github/pull_request_template.md -->

A number of unit tests fail on Hopper because they either do not have a
support check or gate on "what is not supported" while missing SM90. The
current PR adds checks based on "what is supported" and skips a test if the
device is not in the supported list of SMs.

As a special case, `mm_fp4.is_backend_supported(backend,
compute_capability_number)` now exists and is used to skip `mm_fp4` tests when
the backend is not supported (see the sketch after the list below).

Impacted tests:
* tests/attention/test_trtllm_gen_attention.py
* tests/attention/test_trtllm_gen_mla.py
* tests/gemm/test_bmm_fp8.py
* tests/gemm/test_mm_fp4.py
* tests/gemm/test_groupwise_scaled_gemm_fp8.py
* tests/gemm/test_groupwise_scaled_gemm_mxfp4.py
* tests/moe/test_trtllm_gen_fused_moe.py
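
A minimal sketch of how such a capability gate can be used in a test. The helper name `mm_fp4.is_backend_supported` comes from the description above; the import path and the encoding of the compute capability as `major * 10 + minor` are assumptions for illustration:

```
import pytest
import torch

from flashinfer import mm_fp4


@pytest.mark.parametrize("backend", ["cudnn", "cutlass", "trtllm"])
def test_mm_fp4_capability_gate(backend):
    major, minor = torch.cuda.get_device_capability()
    compute_capability_number = major * 10 + minor  # e.g. 90 for SM90 (assumed encoding)
    if not mm_fp4.is_backend_supported(backend, compute_capability_number):
        pytest.skip(f"{backend} is not supported on SM{compute_capability_number}")
    ...  # run the actual mm_fp4 correctness check here
```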

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

* Added workspace check and reflected this in test (#1991)

<!-- .github/pull_request_template.md -->

This PR attempts to fix #1986 (to be confirmed by requester)

The issue is that num_tokens was larger than MAX_TOKEN_NUM, which
results in an IMA, or even a hang. To address this, I added a
validation check. This required a non-breaking API change:
* create_ipc_workspace_for_all_reduce_fusion now has an optional
"create_metadata" bool, which results in an additional return value
  * it is made optional because an additional return value could break the API
* trtllm_allreduce_fusion now takes an optional metadata dictionary
  * when provided, this runs the validation check
  * again, this is optional to avoid breaking the API

In addition, this PR deprecates the older AllReduce functionality so that it can be removed in a future major version bump.
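
For illustration only, here is a plain-Python sketch of the kind of check the optional metadata enables; the key name `max_token_num` is an assumption, and the real validation lives inside `trtllm_allreduce_fusion`:

```
def validate_allreduce_fusion_args(metadata: dict, num_tokens: int) -> None:
    # Reject calls whose token count exceeds what the workspace was sized for,
    # instead of letting them run into an IMA or a hang.
    max_token_num = metadata["max_token_num"]  # assumed key name
    if num_tokens > max_token_num:
        raise ValueError(
            f"num_tokens={num_tokens} exceeds workspace MAX_TOKEN_NUM={max_token_num}; "
            "recreate the workspace with a larger size."
        )


meta = {"max_token_num": 2048}
validate_allreduce_fusion_args(meta, num_tokens=1024)    # passes
# validate_allreduce_fusion_args(meta, num_tokens=4096)  # would raise ValueError
```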

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **API Changes**
* Workspace creation can optionally return metadata describing the
workspace configuration (create_metadata flag).
* Allreduce fusion operations accept optional metadata to validate
runtime parameters against the workspace and raise clear errors on
mismatch.
  * A workspace destruction endpoint was renamed for naming consistency.
* Legacy wrappers were marked deprecated and now point users toward the
newer fusion variants.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* minor fix for xqa (#1994)

<!-- .github/pull_request_template.md -->

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->
1. Change xqa_mla comments to be consistent with MLA instead of MHA.
2. Move cudaMemcpyFromSymbol/cudaFuncSetAttribute outside of the launch
function to avoid breaking CUDA graph capture.
3. Use int32 as the page-table index.

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **New Features**
* Added MLA variant documentation clarifying SM120 GPU requirement and
fixed head group ratio configuration.

* **Documentation**
* Updated data type specifications for XQA operations; page table now
requires int32 instead of uint32.
* Added max sequence length derivation notes for page-table-based
configurations.
* Clarified MLA variant input/output data types (float8_e4m3fn and
bfloat16).

* **Bug Fixes**
* Corrected data type handling in page table processing to ensure
compatibility.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Signed-off-by: Qidi Sang <[email protected]>
Co-authored-by: yzh119 <[email protected]>

* Feature: Add support for L40 FusedMoE in cutlass path (#1973)

Fixed a few compilation issues for L40, and removed one GEMM tactic for
`sm == 89` that crashed due to:
```
Assertion failed: GPU lacks the shared memory resources to run GroupedGEMM kernel
```

Ran `pytest tests/moe/test_trtllm_cutlass_fused_moe.py` manually on an
L40 GPU and verified all tests passed.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **New Features**
* Official support for SM89 target: build/JIT flags and a public
generation path to target it.

* **Bug Fixes / Compatibility**
* Clarified FP8/FP4 dispatch: FP8 paths enabled for SM89; FP4 usage
remains gated and now requires explicit enablement.

* **Performance**
* Adjusted kernel/tile selection order for certain FP8 paths to prefer
SM89-optimized options.

* **Chores**
  * Reduced logging severity for failed tactic profiling to warn/debug.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Amir Klein <[email protected]>

* unittest: Add head dim 256 test cases and mark as xfail (#1999)

* feat: autotune tile_tokens_dim in trtllm-gen MOE (#1980)

<!-- .github/pull_request_template.md -->

- Update the autotune logic in trtllm-gen MoE. Instead of using a fixed
`tile_tokens_dim`, tune over the candidate set
`[max(8, tile_tokens_dim/2), tile_tokens_dim, min(128, tile_tokens_dim*2),
min(128, tile_tokens_dim*4)]` (a small sketch of this candidate set follows
after this list).
- Add FP8 MoE autotune logic; the initial PR was
https://github.com/flashinfer-ai/flashinfer/pull/1494 from @aleozlx,
with the logic updated to sync with the new autotuner.
- Update the logic in `test_trtllm_gen_fused_moe.py`.
- Update `conftest.py` to speed up tests; the previous `try_first` approach
introduced duplicate runs.
- Add log_once to the logger.
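
A small sketch of the candidate set described in the first bullet, assuming `tile_tokens_dim` is the previously fixed tile size:

```
def tile_tokens_dim_candidates(tile_tokens_dim: int) -> list[int]:
    # Candidate tile sizes: half, current, double, and quadruple, clamped to [8, 128].
    candidates = [
        max(8, tile_tokens_dim // 2),
        tile_tokens_dim,
        min(128, tile_tokens_dim * 2),
        min(128, tile_tokens_dim * 4),
    ]
    # Deduplicate while preserving order (the clamps can make entries collide).
    return list(dict.fromkeys(candidates))


print(tile_tokens_dim_candidates(8))   # [8, 16, 32]
print(tile_tokens_dim_candidates(64))  # [32, 64, 128]
```
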
<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **New Features**
* Runtime autotuning with per-tile dynamic routing and selectable MoE
runner options (gated activation, shuffled-weight, weight-layout).
  * One-time (deduplicated) logging helpers added to JIT logger.

* **Deprecations**
* tile_tokens_dim removed from new paths and marked deprecated in legacy
entry points; new tuning parameters introduced for autotuning.

* **Tests**
* Tests refactored for autotuning/routing with new helpers and improved
handling/reporting for missing JIT cache.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: jiahanc <[email protected]>
Co-authored-by: yzh119 <[email protected]>

* Fix trtllm-gen attention illegal memory access (#2002)

<!-- .github/pull_request_template.md -->

This PR fixes an illegal memory access in the trtllm-gen attention kernels. It
changes the workspace buffer from `int_workspace_buffer` to
`float_workspace_buffer`. `int_workspace_buffer` is a fixed-size buffer that is
not initialized to zero, so it should not be used here.

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

Issue #1928

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Bug Fixes**
* Fixed memory allocation in the decode module to improve computation
accuracy and stability during text generation.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* release: Bump version for v0.5.0rc1 release; (#2008)

<!-- .github/pull_request_template.md -->

Update version in `version.txt` to v0.5.0 as we prepare for v0.5.0rc1
release.

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Chores**
  * Version bump to 0.5.0 (no functional changes)

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* bugfix: fix regex in update wheel index script (#2009)

<!-- .github/pull_request_template.md -->

The regex cannot recognize release candidates (`v0.5.0rc1`) or
post-releases (`v1.2.3.post1`):
https://github.com/flashinfer-ai/flashinfer/actions/runs/18929490991/job/54049304551

This PR fixes the issue.
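
A minimal sketch, not the actual script, of a regex that accepts release candidates and post-releases such as `v0.5.0rc1` and `v1.2.3.post1` alongside plain versions:

```
import re

VERSION_RE = re.compile(
    r"^v?(?P<release>\d+(?:\.\d+)*)"  # 0.5.0 / 1.2.3
    r"(?P<pre>(?:a|b|rc)\d+)?"        # rc1, a1, b2
    r"(?P<post>\.post\d+)?"           # .post1
    r"(?P<dev>\.dev\d+)?$"            # .dev20251024
)

for tag in ["v0.5.0", "v0.5.0rc1", "v1.2.3.post1", "0.4.1.dev5"]:
    assert VERSION_RE.match(tag), tag
```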

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Chores**
* Enhanced version string parsing in the wheel package indexing process
to support more complex version formats, including pre-release,
post-release, and development versions, ensuring compatibility with PEP
440 versioning standards.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* fix: Enable SM121 for mm_fp4 (#2012)

<!-- .github/pull_request_template.md -->

In #1809 we previously added a compute-capability-based support check
for `mm_fp4`.

However, we missed enabling SM121 for the `cudnn` and `cutlass` backends.
Additionally, we marked `trtllm` as supported on SM120 when it is not.

The current PR fixes this. Example benchmark and pytest output on SM121 after
the fix:
```
(py312) root@f414f262f02a:/flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 8192 --n 7168 --k 512 --out_dtype bfloat16 --backends cudnn cutlass --use_128x4_sf_layout --use_nvfp4 --refcheck --use_cupti
/opt/conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:285: UserWarning:
    Found GPU0 NVIDIA GB10 which is of cuda capability 12.1.
    Minimum and Maximum cuda capability supported by this version of PyTorch is
    (8.0) - (12.0)

  warnings.warn(
[PERF] cudnn          :: median time 0.656 ms; std 0.025 ms; achieved tflops 91.701 TFLOPs/sec; achieved tb_per_sec 0.185 TB/sec
[PERF] cutlass        :: median time 0.669 ms; std 0.022 ms; achieved tflops 89.859 TFLOPs/sec; achieved tb_per_sec 0.181 TB/sec

(py312) root@f414f262f02a:/flashinfer# pytest tests/gemm/test_mm_fp4.py
====================================================================================================================== test session starts ======================================================================================================================
platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0
rootdir: /flashinfer
configfile: pytest.ini
collected 3240 items
...
======================================================================================================================= warnings summary ========================================================================================================================
../opt/conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:285
  /opt/conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:285: UserWarning:
      Found GPU0 NVIDIA GB10 which is of cuda capability 12.1.
      Minimum and Maximum cuda capability supported by this version of PyTorch is
      (8.0) - (12.0)

    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
========================================================================================================= 450 passed, 2790 skipped, 1 warning in 8.24s ==========================================================================================================

```

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **New Features**
* Expanded hardware compatibility by adding support for newer NVIDIA GPU
architectures.
* FP4 quantized operations now available across multiple backends on
supported devices.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* fix: ensure SM120/121 SFA/SFB contiguity (#1963)

<!-- .github/pull_request_template.md -->

Fix the regression in vLLM and SGLang with FI 0.4.0 in bmm_fp8

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

cc: @yzh119

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Bug Fixes**
* Fixed memory layout handling for tensor operations in GPU computations
to ensure proper alignment, improving stability and performance.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* More realistic bench for POD Attn (#2013)

<!-- .github/pull_request_template.md -->

Use realistic head sizes and sequence lengths, and add a comparison with
sequential prefill + decode.
Results on H100 (without overlap, which only adds ~150GB/s for
persistent):
<img width="433" height="571" alt="image"
src="https://github.com/user-attachments/assets/50de01cd-e5ca-450c-9cc0-521d83b7e487"
/>
cc @yzh119

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **New Features**
* Added comprehensive performance benchmarking for batch attention
operations with detailed timing measurements.
* Introduced sequential dual-kernel benchmark path with extended memory
bandwidth reporting.

* **Tests**
* Updated benchmark test configurations to use deterministic, fixed
values for improved reproducibility.
* Adjusted benchmark parameters for consistency across test iterations.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* Feature: Support non-gated activation in cutlass fused MoE nvfp4 (#2011)

This PR removes an assertion in the cutlass fused MoE bindings to enable
non-gated activations in NVFP4.
It also adds a test for this path with the relu2 activation.

N/A

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

N/A

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **New Features**
* Enhanced quantized Mixture of Experts models to support configurable
activation types (Swiglu and ReLU2) in the NVFP4 quantization path.
* Improved parameter handling to correctly adapt weight shapes and
quantization settings based on the selected activation type.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Omer Ullman Argov <[email protected]>

* feat: add xqa backend and completes NHD/HND coverage for trtllm-gen/xqa backend (#2001)

<!-- .github/pull_request_template.md -->

Expose the xqa backend through the trtllm attention interface, and improve
layout coverage of the trtllm-gen and xqa backends.

Now both trtllm-gen and xqa support the NHD and HND kv-cache layouts.
* support NHD layout for trtllm-gen
* refactor xqa
(https://github.com/flashinfer-ai/flashinfer/commit/869c0c1c6bc199f82f30c23ab78a1b4aa9a1bd3a)
    * allow user-passed stride_page/head/token (see the stride sketch after this list)
    * support both HND and NHD
    * remove macros such as PAGED_KV_CACHE_LAYOUT and USE_PAGED_KV_CACHE
* add unittests for both trtllm-gen/xqa on NHD/HND
* add a unified API for trtllm-gen/xqa, and a unified unittest
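
A small illustration of how stride_page/head/token differ between the two paged-KV layouts. The page shapes below follow the usual FlashInfer convention (NHD: `[page_size, num_heads, head_dim]`, HND: `[num_heads, page_size, head_dim]`); the exact argument names of the refactored xqa entry point may differ:

```
import torch

num_pages, page_size, num_heads, head_dim = 4, 16, 8, 128

k_nhd = torch.zeros(num_pages, page_size, num_heads, head_dim)
k_hnd = torch.zeros(num_pages, num_heads, page_size, head_dim)

# Elements to skip per page / per head / per token in each layout.
print(k_nhd.stride(0), k_nhd.stride(2), k_nhd.stride(1))  # NHD: 16384 128 1024
print(k_hnd.stride(0), k_hnd.stride(1), k_hnd.stride(2))  # HND: 16384 2048 128
```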

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **New Features**
* Added xqa-based batch decode API and public kv_layout option
(NHD/HND); added enable_pdl toggle to inference wrappers.

* **Improvements**
* Automatic backend selection for decoding, consistent KV-layout
normalization across paths, and unified stride-aware paged-KV handling
with layout-aware shapes, scales, and workspace handling.

* **Tests**
* Expanded tests to cover both KV layouts, enable_pdl, new batch-decode
workflows, backend/layout permutations, and fp8/mixed-dtype scenarios.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Qidi Sang <[email protected]>
Co-authored-by: yzh119 <[email protected]>
Co-authored-by: Zihao Ye <[email protected]>

* test: Enable xfailed trtllm decode long seqlen tests and update microbenchmark (#2018)

<!-- .github/pull_request_template.md -->

[tests/attention/test_trtllm_gen_attention.py](https://github.com/flashinfer-ai/flashinfer/blob/v0.5.0rc2/tests/attention/test_trtllm_gen_attention.py#L1021-L1076)
was failing and therefore marked xfail.

PR #2002 fixed the underlying root cause. The current PR therefore removes the
`xfail` marker so that these long-seqlen cases stay covered going forward.

Additionally, PR #2002 revealed a bug in the microbenchmark script where
[trtllm_batch_decode_with_kv_cache](https://github.com/flashinfer-ai/flashinfer/blob/v0.5.0rc2/flashinfer/decode.py#L2082-L2083)
explicitly requires the workspace to be zeroed before first use:
```
    workspace_buffer : torch.Tensor. Must be initialized to 0 for its first use.
        workspace
```
while the microbenchmark code did not zero it out, causing undefined
behavior such as IMAs that depend on the ordering of the backends tested.
The current PR fixes the issue by explicitly calling
`workspace_buffer.zero_()` between testing different backends.
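
A minimal sketch of the fix; the per-backend benchmark helper and the backend names are hypothetical, and the point is only that the shared workspace is zeroed before each backend's first use, as the docstring above requires:

```
import torch


def run_decode_benchmark(backend: str, workspace_buffer: torch.Tensor) -> None:
    """Placeholder for the per-backend decode benchmark body."""
    ...


# In the real benchmark this is a CUDA buffer; CPU is used here so the sketch runs anywhere.
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8)

for backend in ["backend_a", "backend_b"]:  # illustrative backend names
    workspace_buffer.zero_()  # reset between backends to avoid order-dependent IMAs
    run_decode_benchmark(backend, workspace_buffer)
```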

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Bug Fixes**
* Improved stability of performance benchmarks by properly resetting
workspace buffer between backend invocations.

* **Tests**
  * Enabled previously skipped test for long sequence length handling.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* Updated decorator to support unspecified default (#2026)

<!-- .github/pull_request_template.md -->

Updated the decorator to support an unspecified default. This was causing
issues when calling `mm_fp4` without a backend specified.
Also added SM110 to the supported compute capabilities of the cutlass backend for `mm_fp4`.
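
Illustrative only, not the actual FlashInfer decorator: a sketch of the behavior described above, where an omitted backend falls back to a default instead of failing:

```
import functools


def with_backend_check(supported, default):
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, backend=None, **kwargs):
            backend = backend or default  # unspecified -> fall back to the default backend
            if backend not in supported:
                raise ValueError(f"unsupported backend: {backend}")
            return fn(*args, backend=backend, **kwargs)
        return wrapper
    return decorator


@with_backend_check(supported={"cudnn", "cutlass"}, default="cudnn")
def mm_fp4_like(a, b, backend=None):
    return f"running with {backend}"


print(mm_fp4_like(1, 2))                     # uses the default backend
print(mm_fp4_like(1, 2, backend="cutlass"))  # explicit backend still works
```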

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [ ] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **New Features**
  * FP4 Cutlass GEMM now supports the SM110 GPU compute capability.

* **Bug Fixes**
* Kernels called without an explicit backend now consistently use the
default backend.

* **Tests**
* Added a unit test to verify default backend selection and correct
results when backend is omitted.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* release: Bump version for v0.5.1 release (#2031)

<!-- .github/pull_request_template.md -->

Update `version.txt`

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Chores**
  * Version updated to 0.5.1

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* ci: Update cudnn version requirements in CI container (#2039)

<!-- .github/pull_request_template.md -->

cuDNN versions specified in CI container setup
(`docker/install/install_python_packages.sh`) are currently 9.11 and
9.12.

In unit testing, this causes issues as `mm_fp4(backend='cudnn')` is not
supported on Spark (sm121) for older cuDNN versions in cu130.

The failure is due to the cuDNN version shipped with the container being too
old. In the [latest container build pipeline
output](https://github.com/flashinfer-ai/flashinfer/actions/runs/18778064727/job/53577233568#step:6:727),
cuDNN 9.13.0.50 is installed.

The current PR updates the minimum cuDNN version for both
[cu12](https://pypi.org/project/nvidia-cudnn-cu12/#history) and
[cu13](https://pypi.org/project/nvidia-cudnn-cu13/#history) to
9.14.0.64.

cudnn 9.13 --> unit test fails with 180 failed, 270 passed, 2790
skipped, 1 warning in 8.97s
```
=================================================================================================================================================== test session starts ===================================================================================================================================================
platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0
rootdir: /flashinfer
configfile: pytest.ini
collected 3240 items
...
FAILED tests/gemm/test_mm_fp4.py::test_mm_fp4[mxfp4_alpha-False-True-cudnn-res_dtype1-512-512-256] - cudnn._compiled_module.cudnnGraphNotSupportedError: No valid engine configs for Matmul_MUL_
FAILED tests/gemm/test_mm_fp4.py::test_mm_fp4[mxfp4_alpha-False-True-cudnn-res_dtype1-512-512-512] - cudnn._compiled_module.cudnnGraphNotSupportedError: No valid engine configs for Matmul_MUL_
================================================================================================================================ 180 failed, 270 passed, 2790 skipped, 1 warning in 8.97s =================================================================================================================================

```
cudnn 9.14 --> unit test passes with 450 passed, 2790 skipped, 1 warning
in 5.37s
```
=================================================================================================================================================== test session starts ===================================================================================================================================================
platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0
rootdir: /flashinfer
configfile: pytest.ini
collected 3240 items

tests/gemm/test_mm_fp4.py
...
====================================================================================================================================== 450 passed, 2790 skipped, 1 warning in 5.37s =======================================================================================================================================

```

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, e…
murphymatt added a commit to fw-ai/flashinfer that referenced this pull request Nov 16, 2025
* chore: rename FLASHINFER_JIT_VERBOSE to FLASHINFER_JIT_DEBUG for clarity (#1946)

<!-- .github/pull_request_template.md -->

Rename environment variable `FLASHINFER_JIT_VERBOSE` to
`FLASHINFER_JIT_DEBUG` to better reflect its actual behavior.

- `FLASHINFER_JIT_DEBUG`: Enable debug mode during compilation (disable
optimization, add debug symbols)
- The previous name `FLASHINFER_JIT_VERBOSE` implied "showing more
compilation info", which was confusing
- Maintained backward compatibility: falls back to
`FLASHINFER_JIT_VERBOSE` if `FLASHINFER_JIT_DEBUG` is not set (a minimal
sketch of this fallback follows below)
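
A minimal sketch of the fallback, assuming both variables hold string values such as "1" when set; the exact truthiness rules in the real JIT config may differ:

```
import os


def jit_debug_enabled() -> bool:
    value = os.environ.get("FLASHINFER_JIT_DEBUG")
    if value is None:
        # Backward compatibility: fall back to the legacy variable name.
        value = os.environ.get("FLASHINFER_JIT_VERBOSE", "0")
    return value not in ("", "0", "false", "False")


print(jit_debug_enabled())
```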

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Refactor**
* Introduced FLASHINFER_JIT_DEBUG environment variable for controlling
JIT debug builds with backward compatibility for legacy
FLASHINFER_JIT_VERBOSE.
* Enhanced debug build configuration with improved compiler and CUDA
debugging flags. Non-debug builds continue using -O3 optimizations.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* fix: Fix trtllm-gen prefill IMA when batch_size==1 (#1912)

<!-- .github/pull_request_template.md -->

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

The current PR fixes the test and benchmark code IMAs when running
trtllm-gen paged & ragged prefill with batch size 1 -- the issue was
described in https://github.com/flashinfer-ai/flashinfer/issues/1898

Root cause of the issue:
`flashinfer.prefill.trtllm_ragged_attention_deepseek` and
`flashinfer.prefill.trtllm_batch_context_with_kv_cache` both require
`max_q_len` to match the length of the query when batch size is 1.

**Updated PR:**
The issue has been addressed on the kernel side, so the requirement that
*`max_q_len` match the length of the query when batch size is 1* no longer
applies.

The current PR updates the trtllm-gen FMHA cubins to the latest version and
brings minor updates to the kernel metadata.

Unit test results after PR:
```
$ pytest tests/attention/test_trtllm_gen_attention.py
...
platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0
rootdir: /flashinfer
configfile: pytest.ini
collected 2320 items
...
2055 passed, 264 skipped, 1 xfailed in 224.43s (0:03:44)
```

**Description of previous solution:**
~~Updating `max_q_len` to `cum_seq_lens_q[-1].item()` within the
`trtllm_ragged_attention_deepseek` or
`trtllm_batch_context_with_kv_cache` functions is not a viable option
because the CPU-side synchronization breaks the deterministic and fully
device-side execution required during CUDA graph capture. The workaround
was thus to update the test & benchmark codes that call the trtllm
prefill functions, and to clearly state in the docstring that when
batch_size == 1, max_q_len must match the query size.~~

https://github.com/flashinfer-ai/flashinfer/issues/1898

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Bug Fixes**
* Removed the automatic batch_size=1 restriction for a native backend,
enabling its use in more scenarios while other constraints remain.

* **New Features**
* Added configurable block-sparse attention support to kernel
parameters.

* **Documentation**
* Clarified supported attention optimizations and backend capabilities
in the benchmarks docs.

* **Tests**
* Expanded tests with configurable sequence lengths and added dedicated
batch-size-1 test coverage.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Zihao Ye <[email protected]>

* Feature: Support Relu2 activation in fused MoE (#1954)

Added support for the Relu2 activation in the cutlass FP8 FusedMoE path,
where `Relu2(x) = Relu(x)^2`.

Validated that this works correctly on H100 and B200.
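
A reference-only definition matching the formula above:

```
import torch


def relu2(x: torch.Tensor) -> torch.Tensor:
    # Relu2(x) = Relu(x)^2
    return torch.relu(x).square()


x = torch.tensor([-2.0, -0.5, 0.0, 1.5, 3.0])
print(relu2(x))  # tensor([0.0000, 0.0000, 0.0000, 2.2500, 9.0000])
```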

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **New Features**
* Added Relu2 as a selectable activation across MOE operations and
exposed activation_type configuration to public MOE APIs and runner
interfaces (Swiglu remains the default).
* **Behavior**
* Certain GEMM execution paths now explicitly reject Relu2 and raise a
clear runtime error instead of silently failing.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Amir Klein <[email protected]>

* fix: Add cutlass as an mm_fp4 backend in compute capability 12.0 in benchmark code (#1959)

<!-- .github/pull_request_template.md -->

Previously `backend='cutlass'` could not be benchmarked in
`flashinfer_benchmark.py` for compute capability 12.0, even though the kernel
has actually been available. The current PR marks the backend as available.

Example output showing the backend running after this PR:
```
[INFO] args = Namespace(routine='mm_fp4', no_cuda_graph=False, use_cupti=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag=None, generate_repro_command=False, repro_command='', batch_size=1, m=1024, n=7168, k=512, tile_size=128, group_size=1, scale_major_mode='MN', input_dtype='fp8_e4m3', mat2_dtype='fp8_e4m3', out_dtype='bfloat16', mma_sm=1, backends=['cudnn', 'cutlass', 'trtllm'], use_128x4_sf_layout=True, use_nvfp4=True, autotune=False)
[INFO] Running testMmFp4
[INFO] FlashInfer version: 0.4.1
[VVERBOSE] gpu_name = 'NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition'
[WARNING] trtllm for routine mm_fp4 is not supported on compute capability 12.0. Skipping.
[VVERBOSE] input_fp4.shape = torch.Size([1024, 256])
[VVERBOSE] input_fp4.dtype = torch.uint8
[VVERBOSE] mat2_fp4.shape = torch.Size([7168, 256])
[VVERBOSE] mat2_fp4.dtype = torch.uint8
[PERF] cudnn          :: median time 0.014 ms; std 0.000 ms; achieved tflops 535.891 TFLOPs/sec; achieved tb_per_sec 1.196 TB/sec
[PERF] cutlass        :: median time 0.015 ms; std 0.000 ms; achieved tflops 515.203 TFLOPs/sec; achieved tb_per_sec 1.150 TB/sec
```

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Chores**
* Expanded backend support for benchmarking routines on compute
capability 12.0, adding compatibility with additional processing
backends.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* rebase on fw repo branch

* unittest: fix deepgemm sha256 (#1953)

<!-- .github/pull_request_template.md -->

The DeepGEMM unit test failed because of an outdated sha256 checksum; this PR
fixes the issue.

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Chores**
* Updated internal artifact version information to support latest
optimizations and improvements.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* misc: Update artifacts docstring and MetaInfoHash (#1967)

<!-- .github/pull_request_template.md -->

Amendment to [PR
1761](https://github.com/flashinfer-ai/flashinfer/pull/1761), appending
docstrings to two artifact path classes and deprecating the need to update
MetaInfoHash by reading the checksum.txt file directly.
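
A sketch of the idea, assuming checksum.txt stores `sha256 relative/path` pairs; the real loader's file format and directory layout may differ:

```
import hashlib
from pathlib import Path


def verify_artifacts(artifact_dir: Path, checksum_file: Path) -> None:
    # Compare each listed file's sha256 against the recorded checksum.
    for line in checksum_file.read_text().splitlines():
        if not line.strip():
            continue
        expected, rel_path = line.split()
        actual = hashlib.sha256((artifact_dir / rel_path).read_bytes()).hexdigest()
        if actual != expected:
            raise RuntimeError(f"checksum mismatch for {rel_path}")
```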

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **New Features**
* Added runtime integrity checks for compiled artifacts that verify and
use checksum data during loading to prevent missing or mismatched
artifact headers.

* **Refactor**
* Switched artifact hash resolution to compute hashes dynamically from
provided checksums, improving validation, reliability, and resilience
when loading precompiled components.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* silu_and_mul nvfp4 quanization fusion rework (#1927)

<!-- .github/pull_request_template.md -->

This PR reverts https://github.com/flashinfer-ai/flashinfer/pull/1774
and https://github.com/flashinfer-ai/flashinfer/pull/1835, which have
issues with certain shapes under CUDA graph capture. The kernels ported in
this PR come from SGLang: [[NVIDA] [1/N] Nvfp4 Masked Gemm: Add quant
op for the flashinfer grouped
gemm](https://github.com/sgl-project/sglang/pull/9200/files) and
[[NVIDIA] [2/N] Optimize silu_and_mul_scaled_fp4_grouped_quant
perf](https://github.com/sgl-project/sglang/pull/9556/files) by @kaixih.
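
For reference, a plain-PyTorch sketch of the silu_and_mul half of the fused op; the NVFP4 block quantization that follows it in the ported kernels is omitted, and the gate/up split convention below (first half gated by SiLU) is the usual vLLM/SGLang layout, which may not match the kernel's internal ordering exactly:

```
import torch
import torch.nn.functional as F


def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # x has shape [..., 2 * d]; gate the first half with SiLU, multiply by the second half.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]


x = torch.randn(4, 2 * 256)             # [tokens, 2 * intermediate_dim]
print(silu_and_mul_reference(x).shape)  # torch.Size([4, 256])
```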

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **New Features**
- Added grouped FP4 quantization (scaled_fp4_grouped_quantize) and an
NV-focused Silu+Mul expert quantization entry
(silu_and_mul_scaled_nvfp4_experts_quantize).

* **API Changes**
- Replaced legacy batched APIs with new expert/grouped APIs; removed
legacy mask parameter from FP4/MXFP8 quantization signatures and
adjusted FP4 output layouts/types.

* **Documentation**
  - Updated docs to list new functions and remove deprecated symbols.

* **Tests**
- Updated tests to validate new quantization paths, shapes, dtypes, and
layouts.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Shu Wang. <[email protected]>

* unittest: fix test_artifacts.py (#1950)

* chore: update the list of authorized codeowners (#1970)

<!-- .github/pull_request_template.md -->

Add @djmmoss @jiahanc to the authorized codeowner list.

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Chores**
  * Updated internal codeowner authorization configuration.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* Added heuristic for trtllm_allreduce_fusion (#1972)

<!-- .github/pull_request_template.md -->

The original heuristic does not accurately reflect the performance of
oneshot/twoshot. This PR updates it with heuristics based on this benchmark:
[allreduce_test.py](https://github.com/user-attachments/files/23094671/allreduce_test.py).
The benchmark uses the hidden_dim of Llama3, Llama4, and GPT-OSS and
combinations of token_num, fusion patterns, and fp32_acc.

The results are at the bottom. TL;DR: token_num is a bad predictor of
whether to use oneshot or twoshot. The oneshot communication size is a good
predictor, but only if we treat each TP size separately.
Fusion patterns and fp32_acc are irrelevant to the choice.
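
Illustrative only: the general shape of a heuristic keyed on the oneshot communication size with per-TP cutoffs; the threshold values here are placeholders, not the ones chosen in this PR:

```
def use_oneshot(token_num: int, hidden_dim: int, dtype_bytes: int, tp_size: int) -> bool:
    # Oneshot payload grows with the activation size being reduced.
    comm_size_bytes = token_num * hidden_dim * dtype_bytes
    # Hypothetical per-TP cutoffs; each TP size gets its own threshold.
    max_oneshot_bytes = {2: 1 << 20, 4: 1 << 19, 8: 1 << 18}
    return comm_size_bytes <= max_oneshot_bytes.get(tp_size, 1 << 18)


print(use_oneshot(token_num=128, hidden_dim=8192, dtype_bytes=2, tp_size=8))  # False (2 MiB payload)
```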

<img width="1800" height="3600" alt="comm_size_TP=2"
src="https://github.com/user-attachments/assets/2874157e-6268-421a-8f45-00491b652702"
/>
<img width="1800" height="3600" alt="comm_size_TP=4"
src="https://github.com/user-attachments/assets/2cdfdb9d-569e-401b-89ad-787f8d755ac1"
/>
<img width="1800" height="3600" alt="comm_size_TP=8"
src="https://github.com/user-attachments/assets/fbb147da-3479-4dbc-85b8-c27a735d0cd6"
/>

<img width="1800" height="3600" alt="comm_size_Enlarge_TP=2"
src="https://github.com/user-attachments/assets/e070c81f-edf9-4d7f-ab95-fa6dea9f42f2"
/>
<img width="1800" height="3600" alt="comm_size_Enlarge_TP=4"
src="https://github.com/user-attachments/assets/3b1c51d2-56ca-4d34-9bfd-8082390cc95e"
/>
<img width="1800" height="3600" alt="comm_size_Enlarge_TP=8"
src="https://github.com/user-attachments/assets/9a8095b4-11bc-4021-80c6-f2be69b33021"
/>

<img width="1800" height="3600" alt="comm_size_TP=248"
src="https://github.com/user-attachments/assets/66956ebe-6cf0-43e8-93ce-950b1079148a"
/>
<img width="1800" height="3600" alt="comm_size_Enlarge_TP=248"
src="https://github.com/user-attachments/assets/0cd6982c-da42-4f42-b0ad-5ef564b2e78e"
/>

<img width="1800" height="3600" alt="token_num_TP=248"
src="https://github.com/user-attachments/assets/2968ca7c-2059-4305-8e4d-5b70a32faaee"
/>
<img width="1800" height="3600" alt="token_num_Enlarge_TP=248"
src="https://github.com/user-attachments/assets/881ba86d-fc71-4cbc-b5a6-c050f255d618"
/>

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

---------

Co-authored-by: yzh119 <[email protected]>

* Bump tvm ffi to stable version 0.1.0 (#1960)

<!-- .github/pull_request_template.md -->

This PR bumps tvm-ffi to stable version 0.1.0 and updates the
flashinfer code base.

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

https://github.com/flashinfer-ai/flashinfer/pull/1939

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Chores**
* Relaxed build dependency pins for apache-tvm-ffi and setuptools across
project configs; removed installation of multiple build packages from
the nightly CI step.
* **Refactor**
* Modernized internal CUDA/tensor access patterns to a consistent
accessor API across many modules.
* **Bug Fixes**
* GEMM runner now returns the output tensor in the correct
(non‑transposed) orientation.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Zihao Ye <[email protected]>
Co-authored-by: yzh119 <[email protected]>

* Update Docker CI tags to 20251024-0e48aaf (#1975)

This PR updates the Docker CI image tags to the latest version:
`20251024-0e48aaf`

Updated images:
- flashinfer/flashinfer-ci-cu126:20251024-0e48aaf
- flashinfer/flashinfer-ci-cu128:20251024-0e48aaf
- flashinfer/flashinfer-ci-cu129:20251024-0e48aaf
- flashinfer/flashinfer-ci-cu130:20251024-0e48aaf

Auto-generated by [release-ci-docker
workflow](https://github.com/flashinfer-ai/flashinfer/actions/runs/18778064727)

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Chores**
* Updated CI/CD Docker image configurations to ensure consistency and
reliability across build environments.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Co-authored-by: yzh119 <[email protected]>

* fix: Make attention microbenchmark correctly use page table (#1976)

<!-- .github/pull_request_template.md -->

The current microbenchmark code does not provide an instantiated
`block_tables` to all backends. The omission had no impact on
correctness or performance, because page tables are instantiated linearly when
not provided, but it would manifest as mismatches if the page table were shuffled.

The current PR simply calls the FlashInfer APIs in their intended way.

**No changes to library code**
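For illustration, providing `block_tables` explicitly might look like the sketch below; the shapes and names (`batch_size`, `pages_per_seq`, the shuffled variant) are assumptions for this example, not the benchmark's actual code.

```python
import torch

# Illustrative only: explicit page-table construction for a paged-KV benchmark.
batch_size, pages_per_seq = 16, 32
num_pages = batch_size * pages_per_seq

# Linear mapping (what backends fall back to when block_tables is omitted).
block_tables = torch.arange(num_pages, dtype=torch.int32, device="cuda").reshape(
    batch_size, pages_per_seq
)

# Shuffled mapping: same pages, permuted order. Correct backends must still agree
# on the output; a backend that ignores block_tables will now mismatch.
perm = torch.randperm(num_pages, device="cuda").to(torch.int32)
block_tables_shuffled = perm.reshape(batch_size, pages_per_seq)
```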

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Refactor**
* Enhanced consistency in attention computation by aligning page-table
parameter handling across different inference backend implementations
for improved paged key-value cache operations.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* fix: Skipping attention sink Blackwell test outside of Blackwell (#1978)

<!-- .github/pull_request_template.md -->

`test_attention_sink_blackwell.py` checks
`flashinfer.prefill.trtllm_batch_context_with_kv_cache` and
`flashinfer.decode.trtllm_batch_decode_with_kv_cache` which are only
supported on Blackwell SM100 and SM103.

The existing check only skips tests on SM 11x and 12x, which causes
failures on Hopper SM90.
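A minimal sketch of the allow-list style skip, assuming a plain compute-capability check; the helper name and marker placement are illustrative, not the test file's actual code.

```python
import pytest
import torch

def is_blackwell_sm100_or_sm103() -> bool:
    # SM100 / SM103 report compute capability (10, 0) / (10, 3).
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability() in [(10, 0), (10, 3)]

# Skip unless the GPU is on the supported list, rather than skipping only on
# known-unsupported architectures (which silently lets SM90 through).
pytestmark = pytest.mark.skipif(
    not is_blackwell_sm100_or_sm103(),
    reason="trtllm-gen prefill/decode kernels require SM100 or SM103",
)
```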

Test outputs:
* H200:
   * Before Fix: `144 failed, 1 warning in 9.20s`
   * After Fix: `144 skipped, 1 warning in 0.42s`
* B200:
   * After Fix: `144 passed in 34.64s `

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Tests**
* Updated GPU compatibility checks for attention sink tests to target
specific GPU architectures (SM100/SM103). Tests now run exclusively on
supported GPU models with updated filtering criteria.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* feat: enable deepgemm jit for fp8 block-scale on SM90 (#1969)

<!-- .github/pull_request_template.md -->

Enable JIT compilation for the FP8 DeepGEMM kernels. NVRTC is currently
disabled; NVCC is used by default.

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Refactor**
* JIT include directory discovery now uses the flashinfer-python package
instead of the previous package.
  * Updated resolved include path to the flashinfer data location.
* Runtime compilation now consistently uses NVCC; the prior
environment-variable toggle was removed.
* Updated warning text when the expected package installation cannot be
found.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Duncan Moss <[email protected]>

* chore: Update CODEOWNERS (#1949)

This PR updates the CODEOWNERS file based on git commit history analysis
from the last 180 days.

- Updated `.github/CODEOWNERS` with current code ownership based on:
  - Commit frequency
  - File coverage
  - Commit recency

1. Review the changes to `.github/CODEOWNERS`
2. Verify that the assigned owners are appropriate for each module
3. Make manual adjustments if needed before merging

- This is an automated PR generated weekly
- Minimum commits threshold: 1
- Analysis period: 180 days
- Directory depth: 3 levels
- Top N owners per module: 5

---

🤖 This PR was automatically generated by the [update-codeowners
workflow](.github/workflows/update-codeowners.yml)

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Chores**
  * Updated internal code ownership assignments.

---

**Note:** This update contains no user-facing changes or feature
updates. It is an internal administrative modification.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Co-authored-by: flashinfer-bot <[email protected]>
Co-authored-by: Claude <[email protected]>

* fix: correct PDL parameter handling in RopeQuantize kernel (#1982)

<!-- .github/pull_request_template.md -->

- **Issue**: The `stream` parameter was being passed to the wrong
position in the `RopeQuantize` function call due to a missing `enable_pdl`
parameter. SGLang would hang before this PR.
- **Fix**: Added the `enable_pdl` parameter to the function signature
and properly aligned all parameters.

- **Issue**: When `enable_pdl=true`, the kernel would throw CUDA errors
due to incorrect PDL attribute handling
- **Fix**: Aligned the implementation with `csrc/fmhaReduction.cu`.

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **New Features**
* Added PDL (Programmatic Dynamic Launch) benchmarking capability for
rope quantization operations.
* Extended configuration options to enable or disable PDL functionality.

* **Tests**
* Updated test suite to validate PDL enabled and disabled scenarios in
rope quantization workflows.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* Fix: Verify scales are not None for Cutlass FP8 FusedMoE (#1961)

Verify that FP8 quantization scales are non-null in the cutlass FusedMoE path.
Currently, if these tensors are passed as None from Python, it results in a
segmentation fault.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Bug Fixes**
* Enhanced validation for FP8 quantization parameters to improve system
robustness and prevent potential null reference issues during
quantization operations, reducing the risk of runtime errors when
processing quantized model data.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Amir Klein <[email protected]>

* add xqa fp8 mha and fp8 kv cache (#1769)

<!-- .github/pull_request_template.md -->

Add XQA FP8 MHA and an FP8 KV cache. Add FP8 MLA for SM120. Use the vLLM KV
layout.

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **New Features**
  * MLA-based attention path and dedicated MLA entrypoints (SM120/121)
* FP8 KV-cache support with optional paged KV layout and separate K/V
cache inputs
* Asynchronous tensor-map/TMA and matrix-descriptor primitives for
high-throughput GPU transfers
  * Dtype-driven config and expanded GPU SM gating for builds/runtimes

* **Bug Fixes**
  * Improved numerical stability for attention mask initialization

* **Tests**
  * Expanded coverage for MLA, FP8, FP16/BF16, and new cache layouts

* **Documentation**
  * Added XQA API docs and new public symbols
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Qidi Sang <[email protected]>
Co-authored-by: yzh119 <[email protected]>

* unittest: fix failed unittest on hopper (#1952)

<!-- .github/pull_request_template.md -->

Some invalid configurations were generated by the JIT warmup (mixed precision)
function `gen_prefill_attention_modules`.

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Tests**
* Updated test infrastructure to enhance compatibility handling for
specific hardware acceleration scenarios, improving test robustness for
mixed-precision configurations.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* docs: Update documented versioning scheme to right-shifted semver (#1990)

<!-- .github/pull_request_template.md -->

Based on discussion with @yzh119 and others, we're planning to follow
the vLLM "right-shifted" versioning scheme. This PR updates the docs to
reflect that.

Previously we said we would follow Semantic Versioning (see #1553).
However, we recently reconsidered this approach to better match the
conventions followed by vLLM and PyTorch.

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

Docs only, so no new tests are needed. Did not verify passing unit
tests.

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Documentation**
* Updated release versioning scheme to a "right-shifted" format
(major.minor.patch[.post1]) with an optional post-release suffix for
expedited follow-up releases.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* Bugfix: Change get() -> GetDLTensorPtr() in cutlass FusedMoE validations (#1995)

Use a different API after the `apache-tvm-ffi` version bump.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Bug Fixes**
* Improved null-pointer validation for FP8 quantization tensors used
during inference, increasing robustness and reducing risk of runtime
errors.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Amir Klein <[email protected]>

* unittest: Add SM arch checks to skip unsupported tests on Hopper (#1998)

<!-- .github/pull_request_template.md -->

A number of unit tests fail on Hopper because they either do not have a
support check or gate based on "what is not supported" while missing
SM90. The current PR adds checks based on "what is supported" and skips a test if
the current SM is not in the supported list.

A special case is `mm_fp4`, where `mm_fp4.is_backend_supported(backend,
compute_capability_number)` now exists and is used to skip tests on
unsupported configurations.
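A sketch of the skip pattern described above; the parametrization and test body are placeholders, not the actual test suite code.

```python
import pytest
import torch
from flashinfer import mm_fp4

# Derive the compute capability number, e.g. SM90 -> 90, SM121 -> 121.
major, minor = torch.cuda.get_device_capability()
compute_capability_number = major * 10 + minor

@pytest.mark.parametrize("backend", ["cudnn", "cutlass", "trtllm"])
def test_mm_fp4_backend_gating(backend):
    if not mm_fp4.is_backend_supported(backend, compute_capability_number):
        pytest.skip(
            f"mm_fp4 backend {backend} unsupported on SM{compute_capability_number}"
        )
    ...  # run the actual mm_fp4 correctness check here
```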

Impacted tests:
* tests/attention/test_trtllm_gen_attention.py
* tests/attention/test_trtllm_gen_mla.py
* tests/gemm/test_bmm_fp8.py
* tests/gemm/test_mm_fp4.py
* tests/gemm/test_groupwise_scaled_gemm_fp8.py
* tests/gemm/test_groupwise_scaled_gemm_mxfp4.py
* tests/moe/test_trtllm_gen_fused_moe.py

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

* Added workspace check and reflected this in test (#1991)

<!-- .github/pull_request_template.md -->

This PR attempts to fix #1986 (to be confirmed by requester)

The issue is that num_tokens was larger than MAX_TOKEN_NUM, which
results in an IMA or even a hang. To address this, I added a
validation check. This required a non-breaking API change:
* create_ipc_workspace_for_all_reduce_fusion now has an optional
"create_metadata" bool, which results in an additional return value
  * it is made optional because an additional return value could break the API
* trtllm_allreduce_fusion now takes an optional metadata dictionary
  * when provided, this runs the validation check (a sketch of that check follows below)
  * again, this is optional to avoid breaking the API

In addition, this PR deprecates the older AllReduce functionality so it can be removed in a major version bump.
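For illustration, the validation enabled by the metadata boils down to a bounds check of roughly this shape; the key name ("max_token_num") and the dictionary layout are assumptions for this sketch, not the real metadata format.

```python
# Minimal sketch of the token-count check the workspace metadata enables.
def validate_allreduce_fusion_args(metadata: dict, num_tokens: int) -> None:
    max_token_num = metadata["max_token_num"]  # hypothetical key name
    if num_tokens > max_token_num:
        raise ValueError(
            f"num_tokens={num_tokens} exceeds the workspace limit "
            f"MAX_TOKEN_NUM={max_token_num}; re-create the IPC workspace "
            "with a larger max token count."
        )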

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **API Changes**
* Workspace creation can optionally return metadata describing the
workspace configuration (create_metadata flag).
* Allreduce fusion operations accept optional metadata to validate
runtime parameters against the workspace and raise clear errors on
mismatch.
  * A workspace destruction endpoint was renamed for naming consistency.
* Legacy wrappers were marked deprecated and now point users toward the
newer fusion variants.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* minor fix for xqa (#1994)

<!-- .github/pull_request_template.md -->

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->
1. Change xqa_mla comments to be consistent with MLA instead of MHA.
2. Move cudaMemcpyFromSymbol/cudaFuncSetAttribute outside of the launch
function to avoid breaking CUDA graph capture.
3. Use int32 as the page table index.

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **New Features**
* Added MLA variant documentation clarifying SM120 GPU requirement and
fixed head group ratio configuration.

* **Documentation**
* Updated data type specifications for XQA operations; page table now
requires int32 instead of uint32.
* Added max sequence length derivation notes for page-table-based
configurations.
* Clarified MLA variant input/output data types (float8_e4m3fn and
bfloat16).

* **Bug Fixes**
* Corrected data type handling in page table processing to ensure
compatibility.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Signed-off-by: Qidi Sang <[email protected]>
Co-authored-by: yzh119 <[email protected]>

* Feature: Add support for L40 FusedMoE in cutlass path (#1973)

Fixed a few compilation issues for L40, and removed one GEMM tactic for
`sm == 89` that crashes with:
```
Assertion failed: GPU lacks the shared memory resources to run GroupedGEMM kernel
```

Ran `pytest tests/moe/test_trtllm_cutlass_fused_moe.py` manually on an
L40 GPU and verified all tests passed.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **New Features**
* Official support for SM89 target: build/JIT flags and a public
generation path to target it.

* **Bug Fixes / Compatibility**
* Clarified FP8/FP4 dispatch: FP8 paths enabled for SM89; FP4 usage
remains gated and now requires explicit enablement.

* **Performance**
* Adjusted kernel/tile selection order for certain FP8 paths to prefer
SM89-optimized options.

* **Chores**
  * Reduced logging severity for failed tactic profiling to warn/debug.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Amir Klein <[email protected]>

* unittest: Add head dim 256 test cases and mark as xfail (#1999)

* feat: autotune tile_tokens_dim in trtllm-gen MOE (#1980)

<!-- .github/pull_request_template.md -->

- Update the autotune logic in trtllm-gen MoE. Instead of using a fixed
`tile_tokens_dim`, tune over the candidate set
`[max(8, tile_tokens_dim/2), tile_tokens_dim, min(128, tile_tokens_dim*2),
min(128, tile_tokens_dim*4)]` (see the sketch after this list).
- Add FP8 MoE autotune logic; the initial PR was
https://github.com/flashinfer-ai/flashinfer/pull/1494 from @aleozlx,
with the logic updated to sync with the new autotuner.
- Update the logic in `test_trtllm_gen_fused_moe.py`.
- Update `conftest.py` to speed up tests; it previously used `try_first`,
which introduced duplicate runs.
- Add `log_once` in the logger.
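A small sketch of the candidate set described in the first bullet; the helper name is illustrative.

```python
def tile_tokens_dim_candidates(tile_tokens_dim: int) -> list[int]:
    candidates = [
        max(8, tile_tokens_dim // 2),
        tile_tokens_dim,
        min(128, tile_tokens_dim * 2),
        min(128, tile_tokens_dim * 4),
    ]
    # Deduplicate while preserving order (e.g. tile_tokens_dim=128 collapses).
    return list(dict.fromkeys(candidates))

print(tile_tokens_dim_candidates(16))  # [8, 16, 32, 64]
```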
<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **New Features**
* Runtime autotuning with per-tile dynamic routing and selectable MoE
runner options (gated activation, shuffled-weight, weight-layout).
  * One-time (deduplicated) logging helpers added to JIT logger.

* **Deprecations**
* tile_tokens_dim removed from new paths and marked deprecated in legacy
entry points; new tuning parameters introduced for autotuning.

* **Tests**
* Tests refactored for autotuning/routing with new helpers and improved
handling/reporting for missing JIT cache.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: jiahanc <[email protected]>
Co-authored-by: yzh119 <[email protected]>

* Fix trtllm-gen attention illegal memory access (#2002)

<!-- .github/pull_request_template.md -->

This PR fixes an illegal memory access in the trtllm-gen attention kernels. It
changes the workspace buffer from `int_workspace_buffer` to
`float_workspace_buffer`. `int_workspace_buffer` is a fixed-size buffer
that is not initialized to zero and should not be used here.

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

Issue #1928

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Bug Fixes**
* Fixed memory allocation in the decode module to improve computation
accuracy and stability during text generation.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* release: Bump version for v0.5.0rc1 release; (#2008)

<!-- .github/pull_request_template.md -->

Update version in `version.txt` to v0.5.0 as we prepare for v0.5.0rc1
release.

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Chores**
  * Version bump to 0.5.0 (no functional changes)

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* bugfix: fix regex in update wheel index script (#2009)

<!-- .github/pull_request_template.md -->

The regex cannot recognize release candidates (`v0.5.0rc1`) or post
releases (`v1.2.3.post1`):
https://github.com/flashinfer-ai/flashinfer/actions/runs/18929490991/job/54049304551

This PR fixes the issue.

https://github.com/flashinfer-ai/flashinfer/actions/runs/18929490991/job/54049304551
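For illustration, a pattern along these lines accepts the tag formats mentioned above (plain releases, release candidates, and post-releases, a subset of PEP 440); this is not the script's actual regex.

```python
import re

VERSION_RE = re.compile(
    r"^v(?P<release>\d+\.\d+\.\d+)"   # v0.5.0
    r"(?:(?P<pre>rc\d+))?"            # optional rc suffix, e.g. rc1
    r"(?:\.(?P<post>post\d+))?$"      # optional post-release, e.g. .post1
)

for tag in ["v0.5.0", "v0.5.0rc1", "v1.2.3.post1"]:
    m = VERSION_RE.match(tag)
    print(tag, m.groupdict() if m else "no match")
```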

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Chores**
* Enhanced version string parsing in the wheel package indexing process
to support more complex version formats, including pre-release,
post-release, and development versions, ensuring compatibility with PEP
440 versioning standards.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* fix: Enable SM121 for mm_fp4 (#2012)

<!-- .github/pull_request_template.md -->

In #1809 we previously added a compute-capability-based support check
for `mm_fp4`.

However, we missed enabling SM121 for the `cudnn` and `cutlass` backends.
Additionally, we marked `trtllm` as supported on SM120 when it is not.

The current PR fixes this. Example benchmark and pytest commands on SM121 after
the fix:
```
(py312) root@f414f262f02a:/flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 8192 --n 7168 --k 512 --out_dtype bfloat16 --backends cudnn cutlass --use_128x4_sf_layout --use_nvfp4 --refcheck --use_cupti
/opt/conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:285: UserWarning:
    Found GPU0 NVIDIA GB10 which is of cuda capability 12.1.
    Minimum and Maximum cuda capability supported by this version of PyTorch is
    (8.0) - (12.0)

  warnings.warn(
[PERF] cudnn          :: median time 0.656 ms; std 0.025 ms; achieved tflops 91.701 TFLOPs/sec; achieved tb_per_sec 0.185 TB/sec
[PERF] cutlass        :: median time 0.669 ms; std 0.022 ms; achieved tflops 89.859 TFLOPs/sec; achieved tb_per_sec 0.181 TB/sec

(py312) root@f414f262f02a:/flashinfer# pytest tests/gemm/test_mm_fp4.py
====================================================================================================================== test session starts ======================================================================================================================
platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0
rootdir: /flashinfer
configfile: pytest.ini
collected 3240 items
...
======================================================================================================================= warnings summary ========================================================================================================================
../opt/conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:285
  /opt/conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:285: UserWarning:
      Found GPU0 NVIDIA GB10 which is of cuda capability 12.1.
      Minimum and Maximum cuda capability supported by this version of PyTorch is
      (8.0) - (12.0)

    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
========================================================================================================= 450 passed, 2790 skipped, 1 warning in 8.24s ==========================================================================================================

```

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **New Features**
* Expanded hardware compatibility by adding support for newer NVIDIA GPU
architectures.
* FP4 quantized operations now available across multiple backends on
supported devices.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* fix: ensure SM120/121 SFA/SFB contiguity (#1963)

<!-- .github/pull_request_template.md -->

Fix the regression in vLLM and SGLang with FI 0.4.0 in bmm_fp8

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

cc: @yzh119

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Bug Fixes**
* Fixed memory layout handling for tensor operations in GPU computations
to ensure proper alignment, improving stability and performance.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* More realistic bench for POD Attn (#2013)

<!-- .github/pull_request_template.md -->

Use realistic head sizes and sequence lengths, and add a comparison with
sequential prefill + decode.
Results on H100 (without overlap, which only adds ~150GB/s for
persistent):
<img width="433" height="571" alt="image"
src="https://github.com/user-attachments/assets/50de01cd-e5ca-450c-9cc0-521d83b7e487"
/>
cc @yzh119

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **New Features**
* Added comprehensive performance benchmarking for batch attention
operations with detailed timing measurements.
* Introduced sequential dual-kernel benchmark path with extended memory
bandwidth reporting.

* **Tests**
* Updated benchmark test configurations to use deterministic, fixed
values for improved reproducibility.
* Adjusted benchmark parameters for consistency across test iterations.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* Feature: Support non-gated activation in cutlass fused MoE nvfp4 (#2011)

This PR removes an assertion in the cutlass fused MoE bindings to enable
non-gated activations in NVFP4.
It also adds a test for this path with the relu2 activation.

N/A

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

N/A

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **New Features**
* Enhanced quantized Mixture of Experts models to support configurable
activation types (Swiglu and ReLU2) in the NVFP4 quantization path.
* Improved parameter handling to correctly adapt weight shapes and
quantization settings based on the selected activation type.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Omer Ullman Argov <[email protected]>

* feat: add xqa backend and completes NHD/HND coverage for trtllm-gen/xqa backend (#2001)

<!-- .github/pull_request_template.md -->

Expose the xqa backend through the trtllm attention interface, and improve layout
coverage of the trtllm-gen and xqa backends.

Now both trtllm-gen and xqa support the NHD and HND kv-cache layouts (see the
layout sketch after the list below):
* support NHD layout for trtllm-gen
* refactor xqa
(https://github.com/flashinfer-ai/flashinfer/commit/869c0c1c6bc199f82f30c23ab78a1b4aa9a1bd3a)
    * allow user passed stride_page/head/token
    * support both HND and NHD
    * remove macros such as PAGED_KV_CACHE_LAYOUT and USE_PAGED_KV_CACHE
* adding unittests for both trtllm-gen/xqa on NHD/HND
* adding unified API for trtllm-gen/xqa, and unified unittest
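For reference, the two layouts differ only in the order of the token and head dimensions within a page. The sketch below uses one common paged-KV shape convention (with a combined K/V dimension of size 2) purely for illustration; the exact tensor layout used by each code path may differ.

```python
import torch

max_num_pages, page_size, num_kv_heads, head_dim = 256, 16, 8, 128

# NHD: tokens-major within a page -> (page_size, num_kv_heads, head_dim)
kv_cache_nhd = torch.empty(
    max_num_pages, 2, page_size, num_kv_heads, head_dim,
    dtype=torch.bfloat16, device="cuda",
)

# HND: heads-major within a page -> (num_kv_heads, page_size, head_dim)
kv_cache_hnd = torch.empty(
    max_num_pages, 2, num_kv_heads, page_size, head_dim,
    dtype=torch.bfloat16, device="cuda",
)
```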

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **New Features**
* Added xqa-based batch decode API and public kv_layout option
(NHD/HND); added enable_pdl toggle to inference wrappers.

* **Improvements**
* Automatic backend selection for decoding, consistent KV-layout
normalization across paths, and unified stride-aware paged-KV handling
with layout-aware shapes, scales, and workspace handling.

* **Tests**
* Expanded tests to cover both KV layouts, enable_pdl, new batch-decode
workflows, backend/layout permutations, and fp8/mixed-dtype scenarios.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Qidi Sang <[email protected]>
Co-authored-by: yzh119 <[email protected]>
Co-authored-by: Zihao Ye <[email protected]>

* test: Enable xfailed trtllm decode long seqlen tests and update microbenchmark (#2018)

<!-- .github/pull_request_template.md -->

[tests/attention/test_trtllm_gen_attention.py](https://github.com/flashinfer-ai/flashinfer/blob/v0.5.0rc2/tests/attention/test_trtllm_gen_attention.py#L1021-L1076)
was failing and therefore marked xfail.

PR #2002 fixed the underlying root cause. The current PR therefore removes the
`xfail` marker so that these long-seqlen cases stay covered going
forward.

Additionally, PR #2002 revealed a bug in the microbenchmark script where
[trtllm_batch_decode_with_kv_cache](https://github.com/flashinfer-ai/flashinfer/blob/v0.5.0rc2/flashinfer/decode.py#L2082-L2083)
explicitly requires the workspace to be zeroed before first use:
```
    workspace_buffer : torch.Tensor. Must be initialized to 0 for its first use.
        workspace
```
while the microbenchmark code did not zero it, causing undefined
behavior such as IMAs that depend on the order in which backends are tested.
The current PR fixes the issue by explicitly calling
`workspace_buffer.zero_()` between backends.
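A minimal sketch of the benchmark-side fix; the buffer size and backend names are placeholders, not the actual microbenchmark configuration.

```python
import torch

# Zero-initialized for first use, as the docstring requires.
workspace_buffer = torch.zeros(
    128 * 1024 * 1024, dtype=torch.uint8, device="cuda"
)

for backend in ["trtllm-gen", "xqa"]:
    workspace_buffer.zero_()  # reset between backends to avoid stale state / IMAs
    # ... run trtllm_batch_decode_with_kv_cache(...) or the backend under test ...
```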

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Bug Fixes**
* Improved stability of performance benchmarks by properly resetting
workspace buffer between backend invocations.

* **Tests**
  * Enabled previously skipped test for long sequence length handling.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* Updated decorator to support unspecified default (#2026)

<!-- .github/pull_request_template.md -->

Updated the decorator to support an unspecified default. This was causing
issues when calling mm_fp4 without a backend specified.
Also added SM110 as supported on the cutlass backend (mm_fp4).

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [ ] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **New Features**
  * FP4 Cutlass GEMM now supports the SM110 GPU compute capability.

* **Bug Fixes**
* Kernels called without an explicit backend now consistently use the
default backend.

* **Tests**
* Added a unit test to verify default backend selection and correct
results when backend is omitted.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* release: Bump version for v0.5.1 release (#2031)

<!-- .github/pull_request_template.md -->

Update `version.txt`

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

* **Chores**
  * Version updated to 0.5.1

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

* ci: Update cudnn version requirements in CI container (#2039)

<!-- .github/pull_request_template.md -->

The cuDNN versions specified in the CI container setup
(`docker/install/install_python_packages.sh`) are currently 9.11 and
9.12.

In unit testing, this causes issues because `mm_fp4(backend='cudnn')` is not
supported on Spark (SM121) with older cuDNN versions on cu130.

The failure is due to the cuDNN version shipped with the container being too old. In
the [latest container build pipeline
output](https://github.com/flashinfer-ai/flashinfer/actions/runs/18778064727/job/53577233568#step:6:727),
cuDNN 9.13.0.50 is installed.

The current PR updates the minimum cuDNN version for both
[cu12](https://pypi.org/project/nvidia-cudnn-cu12/#history) and
[cu13](https://pypi.org/project/nvidia-cudnn-cu13/#history) to
9.14.0.64.
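A hedged sketch of gating cuDNN-backed tests on the runtime cuDNN version; the version-encoding assumption is noted in the comment and should be verified against your cuDNN build before relying on it.

```python
import pytest
import torch

# Assumes cuDNN >= 9 encodes its version as major*10000 + minor*100 + patch,
# so 9.14.0 -> 91400; verify this encoding for your build.
MIN_CUDNN_VERSION = 91400

cudnn_too_old = (torch.backends.cudnn.version() or 0) < MIN_CUDNN_VERSION

@pytest.mark.skipif(cudnn_too_old, reason="mm_fp4 cudnn backend needs cuDNN >= 9.14")
def test_mm_fp4_cudnn_backend():
    ...  # call mm_fp4(..., backend="cudnn") here
```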

With cuDNN 9.13, the unit tests fail: 180 failed, 270 passed, 2790
skipped, 1 warning in 8.97s
```
=================================================================================================================================================== test session starts ===================================================================================================================================================
platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0
rootdir: /flashinfer
configfile: pytest.ini
collected 3240 items
...
FAILED tests/gemm/test_mm_fp4.py::test_mm_fp4[mxfp4_alpha-False-True-cudnn-res_dtype1-512-512-256] - cudnn._compiled_module.cudnnGraphNotSupportedError: No valid engine configs for Matmul_MUL_
FAILED tests/gemm/test_mm_fp4.py::test_mm_fp4[mxfp4_alpha-False-True-cudnn-res_dtype1-512-512-512] - cudnn._compiled_module.cudnnGraphNotSupportedError: No valid engine configs for Matmul_MUL_
================================================================================================================================ 180 failed, 270 passed, 2790 skipped, 1 warning in 8.97s =================================================================================================================================

```
With cuDNN 9.14, the unit tests pass: 450 passed, 2790 skipped, 1 warning
in 5.37s
```
=================================================================================================================================================== test session starts ===================================================================================================================================================
platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0
rootdir: /flashinfer
configfile: pytest.ini
collected 3240 items

tests/gemm/test_mm_fp4.py
...
====================================================================================================================================== 450 passed, 2790 skipped, 1 warning in 5.37s =======================================================================================================================================

```

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

<!-- Link any related issues here -->

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->