
Added missing padding #2726

Merged
bkryu merged 4 commits into flashinfer-ai:main from nvjullin:fix-fp4-quant-padding
Mar 12, 2026

Conversation

Contributor

@nvjullin nvjullin commented Mar 9, 2026

📌 Description

The linear sf buffer is missing padding (or the kernel shouldn't try to write to non-existent padding), which is required when quantize_with_block_size_tma is called. There are a few requirements to hit this code path, the notable ones being m >= 1024, n % 512 == 0, and sm100.

The offending code, at csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh L454-L457, is

    if (threadRowIdxGlobal >= numRows || tidx.colIdx >= numCols) {
      if (sf_out != nullptr) {
        sf_out[0] = 0x00;
      }

Solves #2704.
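The fix pattern described in the summaries below — allocate the linear scale-factor buffer with a rounded-up size so the kernel can safely write into the padded tail, then slice back to the logical size before returning — can be sketched as follows. This is an illustrative sketch, not the FlashInfer code: names like `out_sf_size_padded` and the list stand-in for `torch.empty` are assumptions.

```python
# Hypothetical sketch of the padding fix: pad the backing allocation so the
# TMA kernel may write past the logical end, then return a logical-size view.

def round_up(x: int, multiple: int) -> int:
    """Round x up to the nearest multiple of `multiple`."""
    return (x + multiple - 1) // multiple * multiple

def allocate_linear_sf(out_sf_size: int, alignment: int = 16):
    # Backing buffer is padded; the caller only ever sees the logical slice.
    out_sf_size_padded = round_up(out_sf_size, alignment)
    backing = [0] * out_sf_size_padded   # stand-in for torch.empty(...)
    out_sf = backing[:out_sf_size]       # logical view returned to the caller
    return backing, out_sf

# A non-aligned size, as produced by an unaligned m such as m = 1025:
backing, out_sf = allocate_linear_sf(out_sf_size=1025 * 2)
```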

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Bug Fixes

    • Unified padding behavior for FP4 quantization by switching to a consistent round-up alignment across swizzled and non-swizzled paths, preventing misaligned allocations.
    • Internal buffers now use padded allocations and are trimmed or reshaped before return, preserving public interfaces and avoiding layout/memory issues.
  • Tests

    • Added tests that validate quantization padding, output shapes, and numerical correctness for unaligned inputs on CUDA across dtypes, with caching bypass and architecture/version guards.

Contributor

coderabbitai bot commented Mar 9, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Walkthrough

Replaced manual padding with consistent round_up-based alignment across FP4 quantization paths; allocations now use padded-size variables (e.g., out_sf_size_padded) and returned buffers are sliced back to logical shapes. Added imports for get_compute_capability and round_up. Public function signatures remain unchanged.

Changes

FP4 quantization core — flashinfer/fp4_quantization.py:
Replaced explicit padding math with round_up for rows/cols and scale-factor buffers across the SM100, NVFP4, and fake variants; added get_compute_capability and round_up imports; allocations now use padded sizes, and returns are sliced back to the original logical shapes (introducing variables like out_sf_size_padded).

CUDA quantization tests — tests/utils/test_fp4_quantize_padding.py:
New test validating fp4_quantize for unaligned M on CUDA (requires compute capability ≥ 10 and CUDA ≥ 12.8). Parameterized by dtype and shapes; bypasses CUDA caching and checks output shapes, scale metadata length, and numeric agreement against ref_fp4_quant/cast_from_fp4.
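The shape arithmetic such a test could assert can be sketched in plain Python. This is a sketch under assumptions taken from the review comments below (rows padded to a multiple of 128 and scale columns to a multiple of 4 in the swizzled layout, sf_vec_size = 16); the exact FlashInfer helpers may differ.

```python
# Expected output sizes for fp4_quantize with an unaligned m (illustrative).

def round_up(x: int, multiple: int) -> int:
    return (x + multiple - 1) // multiple * multiple

def expected_shapes(m: int, k: int, sf_vec_size: int = 16):
    packed_vals = (m, k // 2)                 # two FP4 values packed per byte
    scale_k = k // sf_vec_size                # one scale per sf_vec_size values
    swizzled_sf_len = round_up(m, 128) * round_up(scale_k, 4)  # padded layout
    linear_sf_len = m * scale_k               # logical (unswizzled) length
    return packed_vals, swizzled_sf_len, linear_sf_len

# The unaligned case discussed in this PR, m = 1025:
vals, swizzled, linear = expected_shapes(m=1025, k=512)
```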

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Suggested labels

op: comm

Suggested reviewers

  • yzh119
  • cyx-6
  • nvmbreughe

Poem

🐰 I rounded rows with careful paws,
Padded snug without a pause,
Sliced back tidy, snug and neat,
Scales aligned — a hopping feat,
Carrot-coded joy in tiny claws 🥕

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

  • Docstring Coverage — ⚠️ Warning: Docstring coverage is 50.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions that are missing them.
  • Title check — ❓ Inconclusive: The title is vague and generic and doesn't clearly convey the specific nature of the change. Resolution: replace it with a more specific title, such as "Fix missing padding in FP4 quantization for SM100 swizzled layout".

✅ Passed checks (1 passed)

  • Description check — ✅ Passed: The description includes the required 📌 Description section explaining the issue and links to related issue #2704, but the pre-commit and test checkboxes are incomplete.


@nvjullin nvjullin force-pushed the fix-fp4-quant-padding branch from c23f79d to ce48d4f on March 9, 2026 at 07:00
@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical padding issue within the fp4_quantize_sm100 function, specifically for the Linear sf (scale factor) tensor. The changes ensure that the out_sf tensor is allocated with sufficient padding, which is a requirement for the quantize_with_block_size_tma function, particularly on sm100 architectures and for larger input sizes. By correctly padding the internal memory allocation and then slicing the result to the expected size, the PR resolves a reported bug and improves the robustness of the FP4 quantization process.

Highlights

  • Padding for Scale Factors: Implemented necessary padding for the out_sf (scale factor) tensor in fp4_quantize_sm100 to meet requirements for quantize_with_block_size_tma on sm100 architectures.
  • Memory Allocation: Modified the allocation of out_sf to use a 16-byte padded size, ensuring proper memory alignment for specific GPU operations.
  • Output Correction: Adjusted the return value to slice the padded out_sf tensor back to its original, unpadded out_sf_size before returning, maintaining the expected output shape.
  • Bug Fix: Resolved issue #2704 ("FP4 quantization kernel with tma cause the warp illegal address exception"), which reported missing padding for linear scale factors, preventing potential errors in specific quantization paths.


Changelog
  • flashinfer/fp4_quantization.py
    • Added calculation for out_sf_size_padded to ensure 16-byte alignment for the scale factor tensor.
    • Modified the out_sf tensor allocation to use the newly calculated padded size.
    • Updated the return statement to slice out_sf to its original out_sf_size before returning, effectively hiding the internal padding.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds necessary padding for linear scale factors in FP4 quantization, a change required for certain hardware paths and crucial for preventing potential buffer overflows in the underlying CUDA kernels. However, the current implementation introduces a NameError: the padding variable out_sf_size_padded is not defined in all code paths, specifically when the default swizzled layout is used, so the function will crash in most scenarios. Please address the NameError to ensure the stability of the FP4 quantization process.

Contributor

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (1)
flashinfer/fp4_quantization.py (1)

215-230: Add a regression test for the linear-layout padding path.

This fix changes only the backing allocation while still returning the logical out_sf view, so it can regress silently. The existing FP4 tests shown here cover the swizzled reshape/unswizzle flow, not fp4_quantize(..., is_sf_swizzled_layout=False). Please add a case with a non-16-aligned m that is large enough to take the TMA path (per the PR notes, m >= 1024 on SM100) so this padding requirement stays locked in.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fp4_quantization.py` around lines 215 - 230, The tests miss
exercising the linear-layout padding path in fp4_quantize, so add a regression
test that calls module.fp4_quantize (or the public wrapper that returns out_val,
out_sf) with is_sf_swizzled_layout=False and a non-16-aligned m large enough to
trigger the TMA path (e.g., m >= 1024 on SM100) to ensure out_sf is a view of a
larger out_sf backing allocation (created via out_sf_size_padded) but returned
truncated to out_sf_size; the test should verify correct values and that
out_sf.shape equals the logical size (out_sf_size) and that no indexing/overflow
occurs when m is not 16-aligned.


📥 Commits

Reviewing files that changed from the base of the PR and between bcdf8d8 and ce48d4f.

📒 Files selected for processing (1)
  • flashinfer/fp4_quantization.py

Contributor

@rainj-me rainj-me left a comment


Contributor Author

nvjullin commented Mar 9, 2026

@rainj-me The unit test will be much more reliable when run with PYTORCH_NO_CUDA_MEMORY_CACHING=1. You'll only need m=1025 instead of range(1000, 8192) to reproduce.

Contributor

rainj-me commented Mar 9, 2026

@rainj-me The unit test will be much more reliable when run with PYTORCH_NO_CUDA_MEMORY_CACHING=1. You'll only need m=1025 instead of range(1000, 8192) to reproduce.

Probably reduce from 8192 to 4096 or something. But please add it to the unit tests, in case there is a further break.

Contributor

@coderabbitai coderabbitai bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/fp4_quantization.py (1)

233-245: ⚠️ Potential issue | 🔴 Critical

Fix fake op signature and output dtypes to match real op.

The fake op _fake_fp4_quantize_sm100 has multiple critical issues that break torch.compile shape/dtype inference:

  1. Missing parameters: The real op has is_sf_8x4_layout and enable_pdl parameters that the fake op lacks; the signatures must mirror each other exactly.
  2. Wrong return dtypes: The real op returns (torch.uint8, torch.uint8), but the fake op returns (torch.int64, torch.int32). torch.compile depends on fake ops for dtype inference.
  3. Incorrect scale factor shape: When is_sf_swizzled_layout=True (the default), the real op uses _compute_swizzled_layout_sf_size(), which pads the output significantly. The fake op ignores this and always returns [m * k // sf_vec_size], causing shape mismatches during tracing.

Add missing parameters to the fake op signature, correct the dtypes, and compute the proper swizzled layout size when needed.
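One generic way to guard against this kind of real/fake-op drift — a sketch, not FlashInfer or PyTorch machinery — is to assert that the two signatures compare equal via the standard library's inspect module. The parameter names below come from the review comment above; the function bodies are placeholders.

```python
import inspect

# Illustrative guard: a fake op's signature must exactly mirror the real op's,
# otherwise shape/dtype inference during tracing can silently diverge.

def real_op(input, sf_vec_size, is_sf_swizzled_layout=True,
            is_sf_8x4_layout=False, enable_pdl=False):
    ...

def fake_op(input, sf_vec_size, is_sf_swizzled_layout=True,
            is_sf_8x4_layout=False, enable_pdl=False):
    ...

def stale_fake_op(input, sf_vec_size, is_sf_swizzled_layout=True):
    ...  # missing is_sf_8x4_layout and enable_pdl, like the bug described above

def signatures_match(a, b) -> bool:
    # inspect.Signature compares parameter names, kinds, and defaults.
    return inspect.signature(a) == inspect.signature(b)
```

A test suite could assert `signatures_match(real_op, fake_op)` once at import time so any future parameter added to only one of the pair fails fast.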

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fp4_quantization.py` around lines 233 - 245, The fake op
_fake_fp4_quantize_sm100 must match the real op signature and outputs: add the
missing parameters is_sf_8x4_layout and enable_pdl to the function signature,
change the returned dtypes to torch.uint8 for both tensors, and compute the
scale-factor length using _compute_swizzled_layout_sf_size(m, k, sf_vec_size)
when is_sf_swizzled_layout is True (otherwise use m * k // sf_vec_size). Ensure
the first return has shape [m, k // 2] dtype=torch.uint8 and the second return
uses the computed swizzled/un-swizzled length dtype=torch.uint8.


📥 Commits

Reviewing files that changed from the base of the PR and between ce48d4f and 1030734.

📒 Files selected for processing (1)
  • flashinfer/fp4_quantization.py

Collaborator

yzh119 commented Mar 10, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !398 has been created, and the CI pipeline #45817934 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #45817934: 7/20 passed

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/utils/test_fp4_quantize_padding.py`:
- Around line 52-59: Remove the global torch.set_default_device(device) call to
avoid leaking global state; instead, create tensors with explicit device
arguments (e.g., change x = torch.randn((m, n), dtype=dtype) to pass
device=device). Ensure any other tensor factories in this test use device=device
as well; keep torch.manual_seed(seed) as is. This targets the
torch.set_default_device symbol and the torch.randn creation in the test.
- Around line 1-5: The module-level os.environ["PYTORCH_NO_CUDA_MEMORY_CACHING"]
assignment should be removed (set this in the test command/CI instead), and the
test's mutation of global device state via torch.set_default_device(device) must
be wrapped so the original default device is restored; capture the current
default device before calling torch.set_default_device(device) and use a
try/finally to call torch.set_default_device(original_device) in the finally
block (follow the pattern used in tests/utils/test_sampling.py), ensuring no
global torch state is left changed after the test.
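The capture/set/restore pattern the review asks for can be sketched generically. This uses a plain module-level default as a stand-in for torch's global default device (so the torch calls themselves are assumed, not shown); the shape — capture before mutating, restore in a finally block — is what matters.

```python
# Stand-in for torch's global default device and its accessors.
_default_device = "cpu"

def get_default_device():
    return _default_device

def set_default_device(device):
    global _default_device
    _default_device = device

def run_test_on(device):
    original = get_default_device()    # capture before mutating global state
    set_default_device(device)
    try:
        return get_default_device()    # the test body would run here
    finally:
        set_default_device(original)   # always restore, even if the body raises

used = run_test_on("cuda:0")
```

After the call, the global default is back to its original value regardless of whether the test body succeeded, so no state leaks into later tests.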


📥 Commits

Reviewing files that changed from the base of the PR and between 1030734 and 63c70b3.

📒 Files selected for processing (1)
  • tests/utils/test_fp4_quantize_padding.py

Collaborator

bkryu commented Mar 11, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !398 has been updated with latest changes, and the CI pipeline #45892595 is currently running. I'll report back once the pipeline job completes.

@bkryu bkryu added the run-ci label Mar 11, 2026
@flashinfer-bot
Collaborator

[CANCELING] Pipeline #45892595: canceled

@nvpohanh
Contributor

/workspace/3rdparty/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp(794): error: class "cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100<flashinfer::gemm::DeviceGemmFp4GemmSm120_half_128_128_256_1_1_1_1SM::ClusterShape, 3U>" has no member "is_last_tile"
          if (scheduler.is_last_tile(work_tile_info)) {
                        ^
          detected during:
            instantiation of "void cutlass::gemm::kernel::GemmUniversal<ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, TileSchedulerTag_, std::enable_if_t<std::is_base_of_v, void>>::operator()(const cutlass::gemm::kernel::GemmUniversal<ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, TileSchedulerTag_, std::enable_if_t<std::is_base_of_v, void>>::Params &, char *) [with ProblemShape_=cute::tuple<int, int, int, int>, CollectiveMainloop_=flashinfer::gemm::DeviceGemmFp4GemmSm120_half_128_128_256_1_1_1_1SM::CollectiveMainloop, CollectiveEpilogue_=flashinfer::gemm::DeviceGemmFp4GemmSm120_half_128_128_256_1_1_1_1SM::CollectiveEpilogue, TileSchedulerTag_=void]" at line 123 of /workspace/3rdparty/cutlass/include/cutlass/device_kernel.h

Known issue. @nvjullin Please rebase this PR. thanks

Collaborator

bkryu commented Mar 12, 2026

(quoting the is_last_tile build error from the previous comment)

FYI, must be rebased on top of #2737.

@nvjullin nvjullin force-pushed the fix-fp4-quant-padding branch from 2b58537 to d7033a2 on March 12, 2026 at 02:51
Contributor

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (1)
flashinfer/fp4_quantization.py (1)

440-447: Consider extracting the repeated NVFP4 scale-buffer sizing.

The scale_k / padded_k / padded_m sizing sequence now appears in four real/fake paths. Pulling it into a helper would make future padding changes much harder to miss in one branch.

Also applies to: 476-483, 524-531, 564-571

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fp4_quantization.py` around lines 440 - 447, Extract the repeated
NVFP4 scale-buffer sizing logic (scale_k = k // sf_vec_size; padded_k =
round_up(scale_k, 4); padded_k_int32 = padded_k // 4; padded_m = round_up(m,
128)) into a small helper function (e.g., compute_nvfp4_scale_dims or
nvfp4_scale_sizes) and return the computed values (scale_k, padded_k,
padded_k_int32, padded_m); then replace the duplicated blocks in the branches
that build output/output_scales (the sites using scale_k, padded_k,
padded_k_int32, padded_m such as where output = torch.empty(...) and
output_scales = torch.empty(...)) to call the helper and use its results,
keeping existing variable names (scale_k, padded_k_int32, padded_m) to minimize
downstream changes.
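A sketch of the suggested helper, using the sizing formulas quoted in the review prompt above (the name compute_nvfp4_scale_dims is the reviewer's suggestion, not an existing FlashInfer function):

```python
def round_up(x: int, multiple: int) -> int:
    return (x + multiple - 1) // multiple * multiple

def compute_nvfp4_scale_dims(m: int, k: int, sf_vec_size: int = 16):
    """Return (scale_k, padded_k, padded_k_int32, padded_m) for NVFP4 scales."""
    scale_k = k // sf_vec_size            # one scale per sf_vec_size elements
    padded_k = round_up(scale_k, 4)       # pad scale columns to a multiple of 4
    padded_k_int32 = padded_k // 4        # four uint8 scales packed per int32
    padded_m = round_up(m, 128)           # pad rows to a multiple of 128
    return scale_k, padded_k, padded_k_int32, padded_m

# Each real/fake path would then unpack the same names it already uses:
scale_k, padded_k, padded_k_int32, padded_m = compute_nvfp4_scale_dims(1025, 1000)
```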


📥 Commits

Reviewing files that changed from the base of the PR and between 2b58537 and d7033a2.

📒 Files selected for processing (2)
  • flashinfer/fp4_quantization.py
  • tests/utils/test_fp4_quantize_padding.py

Collaborator

bkryu commented Mar 12, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !398 has been updated with latest changes, and the CI pipeline #45942317 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #45942317: 1/20 passed

@bkryu bkryu merged commit f4d10d9 into flashinfer-ai:main Mar 12, 2026
71 of 102 checks passed
frankwang28 pushed a commit to frankwang28/flashinfer that referenced this pull request Mar 18, 2026
ameynaik-hub pushed a commit to ameynaik-hub/flashinfer that referenced this pull request Mar 18, 2026