
feat: Add output_both_sf_layouts option to add_rmsnorm_fp4quant API #2395

Merged
yzh119 merged 3 commits into flashinfer-ai:main from bkryu:add_rmsn_f4q_dual_sf on Jan 23, 2026

Conversation


@bkryu bkryu commented Jan 21, 2026

📌 Description

Must be merged after #2385

This PR extends the add_rmsnorm_fp4quant API to support outputting both swizzled and unswizzled scale factors simultaneously. This is useful for scenarios where the quantized output needs to be consumed by both GEMMs (experts) and All-to-All without requiring a separate layout conversion pass.

When output_both_sf_layouts=True, the function returns a 3-tuple (y_fp4, block_scale_swizzled, block_scale_unswizzled) instead of the standard 2-tuple. This flag overrides is_sf_swizzled_layout when set.
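For illustration, the output shapes implied by the new tests can be sketched in plain Python (the helper name is hypothetical; the swizzled-size formula mirrors the assertions in the new TestOutputBothSFLayouts test class):

```python
def dual_layout_shapes(batch_size, hidden_size, block_size=16):
    """Hypothetical helper: expected shapes of the 3-tuple
    (y_fp4, block_scale_swizzled, block_scale_unswizzled)."""
    y_fp4_shape = (batch_size, hidden_size // 2)                # two FP4 values packed per byte
    unswizzled_shape = (batch_size, hidden_size // block_size)  # one scale per block
    factor = block_size * 4                                     # SF columns grouped four per k-tile
    num_m_tiles = (batch_size + 127) // 128                     # rows padded to 128-row tiles
    num_k_tiles = (hidden_size + factor - 1) // factor
    swizzled_shape = (num_m_tiles * num_k_tiles * 32 * 4 * 4,)  # flat 1D swizzled buffer
    return y_fp4_shape, swizzled_shape, unswizzled_shape

print(dual_layout_shapes(16, 1024))  # ((16, 512), (8192,), (16, 64))
```

Note that the swizzled buffer is padded to whole 128-row tiles, so its size does not shrink for small batches.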

Changes Summary

  • flashinfer/cute_dsl/add_rmsnorm_fp4quant.py — Added output_both_sf_layouts and block_scale_unswizzled parameters; updated kernel to write both SF layouts
  • tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py — Added TestOutputBothSFLayouts test class with 10 test methods covering NVFP4/MXFP4, 2D/3D inputs, auto/pre-allocation, and large hidden sizes
  • benchmarks/routines/norm.py — Added --output_both_sf_layouts flag; adjusted bandwidth calculation to account for 2× SF writes
  • benchmarks/samples/sample_testlist.txt — Added example benchmark commands for dual SF output
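The adjusted bandwidth calculation can be sketched as a rough byte count (hypothetical helper; it ignores the weight read and any padding in the swizzled SF buffer, and assumes the in-place residual write described above):

```python
def approx_bytes_moved(batch, hidden, dtype_bytes=2, block_size=16, both_sf=True):
    """Hypothetical sketch of the benchmark's bandwidth accounting."""
    reads = 2 * batch * hidden * dtype_bytes         # input + residual read
    writes = batch * hidden * dtype_bytes            # residual written back in place
    writes += batch * hidden // 2                    # packed FP4 output (half a byte per value)
    sf_bytes = batch * (hidden // block_size)        # one 1-byte scale per block
    writes += 2 * sf_bytes if both_sf else sf_bytes  # 2x SF writes when both layouts are emitted
    return reads + writes
```

With both layouts enabled, the extra traffic is exactly one additional SF buffer: batch × (hidden // block_size) bytes.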

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added --output_both_sf_layouts to emit both swizzled and unswizzled scale-factor layouts and to enable in-place residual updates for the FP4 quantization add+RMSNorm flow.
    • API now returns an extra unswizzled scale tensor when requested.
  • Bug Fixes / UX

    • rmsnorm_fp4quant surfaces a user warning when the new flag is unsupported and ignores it.
    • Verbose output and result reporting now include the new flag state.
  • Tests

    • Added extensive tests validating dual-layout outputs, dtypes, shapes, in-place residual semantics, and cross-path consistency.


@gemini-code-assist

Summary of Changes

Hello @bkryu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the add_rmsnorm_fp4quant API by introducing the capability to output both swizzled and unswizzled scale factors in a single operation. This feature streamlines workflows for different downstream consumers of quantized data, improving efficiency by eliminating redundant layout conversions. Additionally, the underlying kernel has been optimized to perform an in-place update of the residual tensor, leading to more efficient memory usage. These changes are supported by extensive new tests and updated benchmarks to ensure correctness and performance.

Highlights

  • Dual Scale Factor Output: The add_rmsnorm_fp4quant API now supports an output_both_sf_layouts option, allowing it to return both swizzled and unswizzled scale factors simultaneously. This is beneficial for scenarios where the quantized output needs to be consumed by different components (e.g., GEMMs and All-to-All) without requiring separate layout conversions.
  • In-place Residual Update: The residual tensor passed to add_rmsnorm_fp4quant is now updated in-place with the sum of input + residual. This change affects the kernel's internal operation and how reference computations are handled in tests and benchmarks.
  • Enhanced Testing: Comprehensive new test cases have been added to verify the correctness of the dual scale factor output across various configurations (NVFP4/MXFP4, 2D/3D inputs, auto/pre-allocation, large hidden sizes). Existing tests were also updated to correctly account for the in-place residual modification.
  • Benchmark Adjustments: The benchmarking routine for add_rmsnorm_fp4quant has been updated to include a new --output_both_sf_layouts flag and to accurately calculate memory bandwidth, considering the potential for writing two scale factors.
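The in-place residual semantics described above can be captured in a tiny plain-Python reference (per-row, list-based; a simplification of what the fused kernel computes before the FP4 quantization step):

```python
import math

def add_rmsnorm_reference(x, residual, weight, eps=1e-6):
    """Reference for one row: residual is mutated in place with
    input + residual, then RMSNorm(residual) * weight is returned."""
    for i in range(len(residual)):
        residual[i] += x[i]  # step 1: residual = input + residual, in place
    rms = math.sqrt(sum(v * v for v in residual) / len(residual) + eps)
    return [v / rms * w for v, w in zip(residual, weight)]  # step 2: RMSNorm * weight
```

This is why tests and benchmarks must snapshot the residual (e.g. r.clone()) before calling the kernel.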



coderabbitai bot commented Jan 21, 2026

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough

Walkthrough

This PR adds an optional --output_both_sf_layouts flag to request both swizzled and unswizzled scale-factor layouts, threads the flag through benchmark CLI and the add_rmsnorm_fp4quant backend, implements dual-scale outputs and in-place residual updates in the fused kernel/host, and adds comprehensive tests and benchmark cases for the new behavior.

Changes

  • Benchmark & CLI — benchmarks/routines/norm.py: Added --output_both_sf_layouts CLI flag; propagate flag through testAddRmsnormFp4quant and verbose logging; warn/ignore in rmsnorm_fp4quant paths; record flag in cur_res; update memory-bandwidth notes for in-place residual and potential double SF writes.
  • Test Matrix Samples — benchmarks/samples/sample_testlist.txt: Added multiple test entries exercising output_both_sf_layouts, including combinations with use_global_scale and mxfp4 across new batch/hidden configs (entries duplicated in two sections).
  • Core Kernel & Host Implementation — flashinfer/cute_dsl/add_rmsnorm_fp4quant.py: Added output_both_sf_layouts parameter and optional block_scale_unswizzled path; kernel/host updated for in-place residual = residual + input semantics; kernel produces conditional dual SF outputs (swizzled + unswizzled) and updated signatures/returns.
  • Unit Tests — tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py: Added TestOutputBothSFLayouts with multiple tests validating dual SF outputs, dtype/shape checks, parity with single-layout references, 3D input handling, global-scale and preallocated output cases, and in-place residual semantics.

Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Host
    participant Kernel as Device Kernel
    participant Memory as Global Memory

    Host->>Kernel: launch add_rmsnorm_fp4quant(..., output_both_sf_layouts=true, inputs)

    rect rgba(100, 200, 150, 0.5)
    Note over Kernel: Fused operations in kernel
    Kernel->>Memory: read input, residual, weights
    Kernel->>Kernel: h = input + residual (compute)
    Kernel->>Memory: write h back to residual (in-place)
    Kernel->>Kernel: apply RMSNorm to h
    Kernel->>Kernel: quantize h -> y_fp4
    Kernel->>Kernel: compute block_scale (swizzled)
    Kernel->>Kernel: compute block_scale_unswizzled (unswizzled)
    end

    Kernel->>Memory: write y_fp4, block_scale, block_scale_unswizzled
    Kernel->>Host: return (y_fp4, block_scale, block_scale_unswizzled)
    Host->>Host: verify/consume outputs
```

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested reviewers

  • kaixih
  • aleozlx
  • jimmyzho
  • jiahanc
  • cyx-6
  • kahyunnam
  • yzh119

Poem

🐰 In memory's burrow two scales now play,

swizzled and straight at the break of day.
Residuals updated with a hop and a wink—
FP4 twinkles while tensors link.
Hooray for dual layouts! 🥕✨

🚥 Pre-merge checks | ✅ 3

✅ Passed checks (3 passed)
  • Title check — Passed. The PR title accurately summarizes the main change: adding an output_both_sf_layouts option to the add_rmsnorm_fp4quant API, which is the core feature introduced.
  • Description check — Passed. The PR description covers the key aspects including the feature rationale, return type changes, flag behavior, file-by-file changes, and checklist completion. However, the 'Tests have been added' and 'All tests are passing' checkboxes are unchecked.
  • Docstring Coverage — Passed. Docstring coverage is 86.96%, which is sufficient; the required threshold is 80.00%.



@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request successfully introduces the output_both_sf_layouts option to the add_rmsnorm_fp4quant API, enabling the simultaneous output of swizzled and unswizzled scale factors. A key related change is the modification of the kernel to perform an in-place update on the residual tensor (residual = input + residual), which is now correctly reflected in the documentation, benchmarks, and tests. The test suite has been significantly expanded with new classes to thoroughly validate both the in-place update behavior and the new dual-layout output feature. The changes are well-implemented, but there is an opportunity to improve maintainability by refactoring duplicated code within the kernel.

Comment on lines +1381 to +1398

```python
if cutlass.const_expr(self.output_both_sf_layouts):
    # Output both swizzled and unswizzled scale factors
    inner_k_idx = sf_idx % Int32(4)
    inner_m_idx = (actual_row_idx % Int32(128)) // Int32(32)
    outer_m_idx = actual_row_idx % Int32(32)
    k_tile_idx = sf_idx // Int32(4)
    m_tile_idx = actual_row_idx // Int32(128)
    m_tile_stride = self.num_k_tiles * self.k_tile_stride
    swizzled_offset = (
        m_tile_idx * m_tile_stride
        + k_tile_idx * self.k_tile_stride
        + outer_m_idx * Int32(16)
        + inner_m_idx * Int32(4)
        + inner_k_idx
    )
    mS[swizzled_offset] = scale_fp8
    mS_unswizzled[actual_row_idx, sf_idx] = scale_fp8
elif cutlass.const_expr(self.output_swizzled):
```

Severity: medium

The logic for calculating swizzled_offset is duplicated in the if block here and the elif block below, and this pattern is repeated in three other places in the kernel. This harms maintainability.

To improve this, you could extract the swizzled_offset calculation into a @cute.jit helper function and call it from both branches. This would centralize the swizzle logic and make the kernel easier to read and modify.
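A pure-Python mirror of that offset math could serve as a reference for such a shared helper. Note the function name here is illustrative, and k_tile_stride=512 (one 128×4 SF tile) is an assumption not stated in the excerpt:

```python
def swizzled_sf_offset(row, sf_col, num_k_tiles, k_tile_stride=512):
    """Illustrative mirror of the kernel's swizzled-offset computation.
    k_tile_stride=512 is assumed, not taken from the PR."""
    inner_k = sf_col % 4
    inner_m = (row % 128) // 32
    outer_m = row % 32
    k_tile = sf_col // 4
    m_tile = row // 128
    m_tile_stride = num_k_tiles * k_tile_stride
    return (m_tile * m_tile_stride + k_tile * k_tile_stride
            + outer_m * 16 + inner_m * 4 + inner_k)

# Within one 128-row x 4-column SF tile the mapping is a bijection onto [0, 512)
offsets = {swizzled_sf_offset(r, c, num_k_tiles=1) for r in range(128) for c in range(4)}
assert offsets == set(range(512))
```

A property test like the bijection check above is a cheap way to guard the helper once the four duplicated copies are consolidated.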


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (1)

2424-2545: Residual in-place update is silently broken for non-contiguous inputs in the 2D path.

The function promises in-place modification of residual (documented as "Modified in-place to contain residual + input"), but when residual is non-contiguous in the 2D case, calling .contiguous() before the kernel invocation creates a copy. The kernel modifies this copy, leaving the original residual unchanged and violating the API contract.

The 3D path explicitly calls .contiguous() on both tensors early (line 2427-2428), but the 2D path does not, creating inconsistent behavior. Add an explicit contiguity check to enforce the requirement:

Proposed fix

```diff
     is_3d = input.dim() == 3
+    if not residual.is_contiguous():
+        raise ValueError("residual must be contiguous for in-place update")
     if is_3d:
```
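The underlying pitfall is not torch-specific; a small NumPy sketch (an analogy, not the flashinfer code path) shows the same copy-not-view behavior that breaks in-place updates:

```python
import numpy as np

a = np.arange(8, dtype=np.float32).reshape(2, 4)
view = a.T                            # non-contiguous view of a
assert not view.flags["C_CONTIGUOUS"]

buf = np.ascontiguousarray(view)      # like torch .contiguous(): silently copies
buf += 1.0                            # the "in-place" update lands on the copy...
assert np.array_equal(a, np.arange(8, dtype=np.float32).reshape(2, 4))  # ...caller's array unchanged
```

Raising early on non-contiguous residuals, as the proposed fix does, turns this silent divergence into an explicit error.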
🤖 Fix all issues with AI agents
In `@tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py`:
- Around line 1621-2207: Several tests in the TestOutputBothSFLayouts class
unpack the 3-tuple returned by add_rmsnorm_fp4quant but don't use all elements,
which triggers Ruff RUF059; update those unpackings to mark unused values with _
(or prefix variable names with an underscore) so the linter ignores them. Locate
all uses of add_rmsnorm_fp4quant inside TestOutputBothSFLayouts (e.g.,
assignments like y_fp4_both, block_scale_swizzled, block_scale_unswizzled =
result or result1/result2 unpackings) and replace unused targets with _ (for
example y_fp4, _, _ = result or y_fp4, swizzled, _ = result when only the first
two are used). Ensure preallocated-return checks also use _ for any output
tensors not inspected (e.g., when only y_fp4 is asserted), and keep references
to add_rmsnorm_fp4quant, TestOutputBothSFLayouts, and the specific variable
names to guide the edits.
- Around line 1315-1620: Several test cases unpack the tuple returned by
add_rmsnorm_fp4quant but never use one or both outputs (e.g., in
TestResidualInPlaceUpdate:test_residual_inplace_update_2d,
test_residual_inplace_update_3d, test_residual_inplace_update_mxfp4,
test_residual_inplace_update_large_hidden,
test_residual_inplace_with_preallocated_outputs,
test_residual_inplace_swizzled_layout, test_residual_not_aliased_with_input).
Update those unpackings to mark unused values with a leading underscore (e.g.,
replace "y_fp4, block_scale =" with "_y_fp4, _block_scale =" or "_, block_scale
=" where only one is used) so the returned outputs from add_rmsnorm_fp4quant are
explicitly acknowledged as unused and silence Ruff RUF059.
🧹 Nitpick comments (2)
benchmarks/samples/sample_testlist.txt (1)

114-123: Consider adding a 3D input test case for comprehensive coverage.

The existing test suite includes a 3D input shape test (line 112). For completeness, consider adding a test that combines --output_both_sf_layouts with 3D inputs:

--routine add_rmsnorm_fp4quant --batch_size 32 --num_heads 32 --hidden_size 128 --input_dtype bfloat16 --output_both_sf_layouts -vv --generate_repro_command --case_tag "add_rmsnorm_fp4quant_3d_both_sf"

This would ensure the dual SF layout feature works correctly with both 2D and 3D input shapes.

tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py (1)

1315-1317: Prefer flashinfer.utils capability checks for skips.
Using the shared helpers keeps GPU-arch gating consistent across tests. As per coding guidelines, please switch to flashinfer.utils skip helpers.

Also applies to: 1621-1623

Comment on lines +1315 to +1620

```python
@cute_dsl_available
@blackwell_required
class TestResidualInPlaceUpdate:
    """Tests to verify that the residual tensor is updated in-place with input + residual."""

    @pytest.mark.parametrize("batch_size", [1, 4, 16, 128, 512])
    @pytest.mark.parametrize("hidden_size", [256, 512, 1024, 4096])
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
    def test_residual_inplace_update_2d(self, batch_size, hidden_size, dtype):
        """Test that residual is updated in-place for 2D input."""
        from flashinfer.cute_dsl.add_rmsnorm_fp4quant import add_rmsnorm_fp4quant

        torch.manual_seed(42)
        block_size = 16
        eps = 1e-6

        x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        weight = torch.randn(hidden_size, device="cuda", dtype=dtype)

        # Store original values before kernel call
        r_original = r.clone()
        expected_residual = x + r_original

        # Call kernel - residual should be modified in-place
        y_fp4, block_scale = add_rmsnorm_fp4quant(
            x, r, weight, eps=eps, block_size=block_size
        )

        # Verify residual is updated in-place to input + original_residual
        torch.testing.assert_close(
            r,
            expected_residual,
            rtol=0,
            atol=0,
            msg="Residual should be exactly input + original_residual",
        )

    @pytest.mark.parametrize("batch_size", [1, 4, 16])
    @pytest.mark.parametrize("seq_len", [16, 64, 128])
    @pytest.mark.parametrize("hidden_size", [256, 1024, 4096])
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
    def test_residual_inplace_update_3d(self, batch_size, seq_len, hidden_size, dtype):
        """Test that residual is updated in-place for 3D input."""
        from flashinfer.cute_dsl.add_rmsnorm_fp4quant import add_rmsnorm_fp4quant

        torch.manual_seed(42)
        block_size = 16
        eps = 1e-6

        x = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=dtype)
        r = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=dtype)
        weight = torch.randn(hidden_size, device="cuda", dtype=dtype)

        # Store original values before kernel call
        r_original = r.clone()
        expected_residual = x + r_original

        # Call kernel - residual should be modified in-place
        y_fp4, block_scale = add_rmsnorm_fp4quant(
            x, r, weight, eps=eps, block_size=block_size
        )

        # Verify residual is updated in-place to input + original_residual
        torch.testing.assert_close(
            r,
            expected_residual,
            rtol=0,
            atol=0,
            msg="Residual should be exactly input + original_residual",
        )

    @pytest.mark.parametrize("batch_size", [1, 16, 128])
    @pytest.mark.parametrize("hidden_size", [256, 1024, 2048])
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
    def test_residual_inplace_update_mxfp4(self, batch_size, hidden_size, dtype):
        """Test that residual is updated in-place for MXFP4 format (block_size=32)."""
        from flashinfer.cute_dsl.add_rmsnorm_fp4quant import add_rmsnorm_fp4quant

        torch.manual_seed(42)
        block_size = 32  # MXFP4
        eps = 1e-6

        x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        weight = torch.randn(hidden_size, device="cuda", dtype=dtype)

        # Store original values before kernel call
        r_original = r.clone()
        expected_residual = x + r_original

        # Call kernel - residual should be modified in-place
        y_fp4, block_scale = add_rmsnorm_fp4quant(
            x, r, weight, eps=eps, block_size=block_size, scale_format="ue8m0"
        )

        # Verify residual is updated in-place to input + original_residual
        torch.testing.assert_close(
            r,
            expected_residual,
            rtol=0,
            atol=0,
            msg="Residual should be exactly input + original_residual",
        )

    @pytest.mark.parametrize("batch_size", [1, 16, 128])
    @pytest.mark.parametrize("hidden_size", [16384, 32768])
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
    def test_residual_inplace_update_large_hidden(self, batch_size, hidden_size, dtype):
        """Test residual in-place update with large hidden sizes (cluster sync path)."""
        from flashinfer.cute_dsl.add_rmsnorm_fp4quant import add_rmsnorm_fp4quant

        torch.manual_seed(42)
        block_size = 16
        eps = 1e-6

        x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        weight = torch.randn(hidden_size, device="cuda", dtype=dtype)

        # Store original values before kernel call
        r_original = r.clone()
        expected_residual = x + r_original

        # Call kernel - residual should be modified in-place
        y_fp4, block_scale = add_rmsnorm_fp4quant(
            x, r, weight, eps=eps, block_size=block_size
        )

        # Verify residual is updated in-place to input + original_residual
        torch.testing.assert_close(
            r,
            expected_residual,
            rtol=0,
            atol=0,
            msg="Residual should be exactly input + original_residual",
        )

    @pytest.mark.parametrize("batch_size", [16, 128])
    @pytest.mark.parametrize("hidden_size", [512, 1024])
    def test_residual_used_for_rmsnorm(self, batch_size, hidden_size):
        """
        Test that the updated residual (input + original_residual) is used for RMSNorm.

        This verifies the correct sequence of operations:
        1. residual = input + residual (in-place)
        2. output = RMSNorm(residual) * weight
        3. quantize output to FP4
        """
        from flashinfer.cute_dsl.add_rmsnorm_fp4quant import add_rmsnorm_fp4quant

        torch.manual_seed(42)
        block_size = 16
        eps = 1e-6
        dtype = torch.float16

        x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        weight = torch.randn(hidden_size, device="cuda", dtype=dtype)

        # Store original residual and compute expected values
        r_original = r.clone()
        expected_h = x + r_original
        expected_rmsnorm = llama_rms_norm(expected_h, weight, eps=eps)

        # Call kernel
        y_fp4, block_scale = add_rmsnorm_fp4quant(
            x, r, weight, eps=eps, block_size=block_size
        )

        # Verify residual is updated
        torch.testing.assert_close(r, expected_h, rtol=0, atol=0)

        # Verify FP4 output matches RMSNorm of the updated residual
        y_dequant = dequantize_fp4_output(y_fp4, block_scale, block_size)
        assert_close_with_tiered_tolerance(
            y_dequant,
            expected_rmsnorm.float(),
            tight_rtol=0.3,
            tight_atol=0.5,
            loose_rtol=0.5,
            loose_atol=2.0,
            tight_pct=0.99,
            msg="FP4 output should match RMSNorm of updated residual",
        )

    @pytest.mark.parametrize("batch_size", [16, 128])
    @pytest.mark.parametrize("hidden_size", [512, 1024])
    def test_residual_inplace_with_preallocated_outputs(self, batch_size, hidden_size):
        """Test residual in-place update when using pre-allocated output tensors."""
        from flashinfer.cute_dsl.add_rmsnorm_fp4quant import add_rmsnorm_fp4quant

        torch.manual_seed(42)
        block_size = 16
        eps = 1e-6
        dtype = torch.float16

        x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        weight = torch.randn(hidden_size, device="cuda", dtype=dtype)

        # Pre-allocate output tensors
        y_fp4 = torch.empty(
            batch_size, hidden_size // 2, device="cuda", dtype=torch.float4_e2m1fn_x2
        )
        block_scale = torch.empty(
            batch_size,
            hidden_size // block_size,
            device="cuda",
            dtype=torch.float8_e4m3fn,
        )

        # Store original residual
        r_original = r.clone()
        expected_residual = x + r_original

        # Call kernel with pre-allocated outputs
        add_rmsnorm_fp4quant(
            x, r, weight, y_fp4, block_scale, eps=eps, block_size=block_size
        )

        # Verify residual is updated in-place
        torch.testing.assert_close(
            r,
            expected_residual,
            rtol=0,
            atol=0,
            msg="Residual should be updated in-place even with pre-allocated outputs",
        )

    @pytest.mark.parametrize("batch_size", [16, 128])
    @pytest.mark.parametrize("hidden_size", [512, 1024])
    def test_residual_inplace_swizzled_layout(self, batch_size, hidden_size):
        """Test residual in-place update with swizzled scale factor layout."""
        from flashinfer.cute_dsl.add_rmsnorm_fp4quant import add_rmsnorm_fp4quant

        torch.manual_seed(42)
        block_size = 16
        eps = 1e-6
        dtype = torch.float16

        x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        weight = torch.randn(hidden_size, device="cuda", dtype=dtype)

        # Store original residual
        r_original = r.clone()
        expected_residual = x + r_original

        # Call kernel with swizzled layout
        y_fp4, block_scale = add_rmsnorm_fp4quant(
            x, r, weight, eps=eps, block_size=block_size, is_sf_swizzled_layout=True
        )

        # Verify residual is updated in-place
        torch.testing.assert_close(
            r,
            expected_residual,
            rtol=0,
            atol=0,
            msg="Residual should be updated in-place with swizzled layout",
        )

    def test_residual_not_aliased_with_input(self):
        """Test that the kernel handles non-aliased input and residual correctly."""
        from flashinfer.cute_dsl.add_rmsnorm_fp4quant import add_rmsnorm_fp4quant

        torch.manual_seed(42)
        batch_size = 64
        hidden_size = 1024
        block_size = 16
        eps = 1e-6
        dtype = torch.float16

        x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        weight = torch.randn(hidden_size, device="cuda", dtype=dtype)

        # Ensure x and r are separate tensors (not views of each other)
        assert x.data_ptr() != r.data_ptr()

        # Store originals
        x_original = x.clone()
        r_original = r.clone()
        expected_residual = x_original + r_original

        # Call kernel
        y_fp4, block_scale = add_rmsnorm_fp4quant(
            x, r, weight, eps=eps, block_size=block_size
        )

        # Verify x is unchanged
        torch.testing.assert_close(
            x, x_original, rtol=0, atol=0, msg="Input tensor should not be modified"
        )

        # Verify r is updated
        torch.testing.assert_close(
            r,
            expected_residual,
            rtol=0,
            atol=0,
            msg="Residual should be updated in-place",
        )
```


⚠️ Potential issue | 🟡 Minor

Silence Ruff RUF059 by marking unused outputs.
Several tests unpack outputs they don’t use; prefix them with _ to avoid lint failures.

🧹 Example fix (apply similarly across this class)

```diff
-        y_fp4, block_scale = add_rmsnorm_fp4quant(
+        _y_fp4, _block_scale = add_rmsnorm_fp4quant(
             x, r, weight, eps=eps, block_size=block_size
         )
```
🧰 Tools
🪛 Ruff (0.14.13)

RUF059 ("Unpacked variable is never used; prefix it with an underscore or any other dummy variable pattern") reported for both y_fp4 and block_scale at lines 1340, 1374, 1407, 1440, 1565, and 1602.

Comment on lines +1621 to +2207

```python
@cute_dsl_available
@blackwell_required
class TestOutputBothSFLayouts:
    """Tests for output_both_sf_layouts=True which returns both swizzled and unswizzled SFs."""

    @pytest.mark.parametrize("batch_size", [1, 16, 128, 256])
    @pytest.mark.parametrize("hidden_size", [256, 512, 1024, 4096])
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
    def test_nvfp4_both_sf_layouts_basic(self, batch_size, hidden_size, dtype):
        """
        Test that output_both_sf_layouts=True returns 3 tensors and both SFs are correct.
        Uses NVFP4 format (block_size=16, E4M3 scales).
        """
        from flashinfer.cute_dsl import add_rmsnorm_fp4quant

        block_size = 16
        eps = 1e-6
        torch.manual_seed(42)

        x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        weight = torch.randn(hidden_size, device="cuda", dtype=dtype)

        # Clone r since kernel modifies it in-place
        r_clone1 = r.clone()

        # Call with output_both_sf_layouts=True
        result = add_rmsnorm_fp4quant(
            x,
            r,
            weight,
            eps=eps,
            block_size=block_size,
            output_both_sf_layouts=True,
        )

        # Should return 3 tensors
        assert len(result) == 3, f"Expected 3 tensors, got {len(result)}"
        y_fp4_both, block_scale_swizzled, block_scale_unswizzled = result

        # Verify shapes
        assert y_fp4_both.shape == (batch_size, hidden_size // 2)
        assert block_scale_unswizzled.shape == (batch_size, hidden_size // block_size)

        # Swizzled layout should be 1D
        factor = block_size * 4
        num_m_tiles = (batch_size + 127) // 128
        num_k_tiles = (hidden_size + factor - 1) // factor
        expected_swizzled_size = num_m_tiles * num_k_tiles * 32 * 4 * 4
        assert block_scale_swizzled.shape == (expected_swizzled_size,)

        # Verify dtypes
        assert y_fp4_both.dtype == torch.float4_e2m1fn_x2
        assert block_scale_swizzled.dtype == torch.float8_e4m3fn
        assert block_scale_unswizzled.dtype == torch.float8_e4m3fn

        # Compare against separate calls with is_sf_swizzled_layout=False
        y_fp4_unswizzled_only, block_scale_ref_unswizzled = add_rmsnorm_fp4quant(
            x,
            r_clone1,
            weight,
            eps=eps,
            block_size=block_size,
            is_sf_swizzled_layout=False,
        )

        # FP4 values should be identical
        torch.testing.assert_close(
            y_fp4_both.view(torch.uint8), y_fp4_unswizzled_only.view(torch.uint8)
        )

        # Unswizzled SF should match the reference unswizzled
        torch.testing.assert_close(
            block_scale_unswizzled.view(torch.uint8),
            block_scale_ref_unswizzled.view(torch.uint8),
        )

        # Verify swizzled SF is correct by unswizzling and comparing
        block_scale_from_swizzled = unswizzle_sf(
            block_scale_swizzled.view(torch.uint8), batch_size, hidden_size, block_size
        ).view(torch.float8_e4m3fn)
        torch.testing.assert_close(
            block_scale_from_swizzled.view(torch.uint8),
            block_scale_ref_unswizzled.view(torch.uint8),
        )

    @pytest.mark.parametrize("batch_size", [1, 16, 128, 256])
    @pytest.mark.parametrize("hidden_size", [256, 512, 1024, 4096])
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
    def test_mxfp4_both_sf_layouts_basic(self, batch_size, hidden_size, dtype):
        """
        Test that output_both_sf_layouts=True returns 3 tensors for MXFP4 format.
        Uses MXFP4 format (block_size=32, UE8M0 scales).
        """
        from flashinfer.cute_dsl import add_rmsnorm_fp4quant

        block_size = 32
        eps = 1e-6
        torch.manual_seed(42)

        x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        weight = torch.randn(hidden_size, device="cuda", dtype=dtype)

        # Clone r since kernel modifies it in-place
        r_clone1 = r.clone()

        # Call with output_both_sf_layouts=True
        result = add_rmsnorm_fp4quant(
            x,
            r,
            weight,
            eps=eps,
            block_size=block_size,
            scale_format="ue8m0",
            output_both_sf_layouts=True,
        )

        # Should return 3 tensors
        assert len(result) == 3, f"Expected 3 tensors, got {len(result)}"
        y_fp4_both, block_scale_swizzled, block_scale_unswizzled = result

        # Verify shapes
        assert y_fp4_both.shape == (batch_size, hidden_size // 2)
        assert block_scale_unswizzled.shape == (batch_size, hidden_size // block_size)

        # Verify dtypes (UE8M0 uses uint8)
        assert y_fp4_both.dtype == torch.float4_e2m1fn_x2
        assert block_scale_swizzled.dtype == torch.uint8
        assert block_scale_unswizzled.dtype == torch.uint8

        # Compare against separate calls with is_sf_swizzled_layout=False
        y_fp4_unswizzled_only, block_scale_ref_unswizzled = add_rmsnorm_fp4quant
```
x,
r_clone1,
weight,
eps=eps,
block_size=block_size,
scale_format="ue8m0",
is_sf_swizzled_layout=False,
)

# FP4 values should be identical
torch.testing.assert_close(
y_fp4_both.view(torch.uint8), y_fp4_unswizzled_only.view(torch.uint8)
)

# Unswizzled SF should match the reference unswizzled
torch.testing.assert_close(block_scale_unswizzled, block_scale_ref_unswizzled)

# Verify swizzled SF is correct by unswizzling and comparing
block_scale_from_swizzled = unswizzle_sf(
block_scale_swizzled, batch_size, hidden_size, block_size
)
torch.testing.assert_close(
block_scale_from_swizzled, block_scale_ref_unswizzled
)

@pytest.mark.parametrize("batch_size", [1, 16, 128])
@pytest.mark.parametrize("hidden_size", [512, 1024, 4096])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_both_sf_layouts_consistency(self, batch_size, hidden_size, dtype):
"""
Test that unswizzling the swizzled SF matches the unswizzled SF.
This verifies internal consistency of the dual output.
"""
from flashinfer.cute_dsl import add_rmsnorm_fp4quant

block_size = 16
eps = 1e-6
torch.manual_seed(42)

x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
weight = torch.randn(hidden_size, device="cuda", dtype=dtype)

# Call with output_both_sf_layouts=True
        _y_fp4, block_scale_swizzled, block_scale_unswizzled = add_rmsnorm_fp4quant(
x,
r,
weight,
eps=eps,
block_size=block_size,
output_both_sf_layouts=True,
)

# Unswizzle the swizzled SF
block_scale_unswizzled_from_swizzled = unswizzle_sf(
block_scale_swizzled.view(torch.uint8), batch_size, hidden_size, block_size
).view(torch.float8_e4m3fn)

# Should match the directly returned unswizzled SF
torch.testing.assert_close(
block_scale_unswizzled_from_swizzled.view(torch.uint8),
block_scale_unswizzled.view(torch.uint8),
)

@pytest.mark.parametrize("batch_size", [1, 4, 8])
@pytest.mark.parametrize("seq_len", [16, 64, 128])
@pytest.mark.parametrize("hidden_size", [256, 1024])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_both_sf_layouts_3d_input(self, batch_size, seq_len, hidden_size, dtype):
"""Test output_both_sf_layouts=True with 3D input tensors."""
from flashinfer.cute_dsl import add_rmsnorm_fp4quant

block_size = 16
eps = 1e-6
torch.manual_seed(42)

x = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=dtype)
r = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=dtype)
weight = torch.randn(hidden_size, device="cuda", dtype=dtype)

# Reference computation before kernel call
h_ref = x + r
ref_rmsnorm = llama_rms_norm(h_ref, weight, eps=eps)

# Call with output_both_sf_layouts=True
result = add_rmsnorm_fp4quant(
x,
r,
weight,
eps=eps,
block_size=block_size,
output_both_sf_layouts=True,
)

assert len(result) == 3
y_fp4, block_scale_swizzled, block_scale_unswizzled = result

# Verify shapes
assert y_fp4.shape == (batch_size, seq_len, hidden_size // 2)
assert block_scale_unswizzled.shape == (
batch_size,
seq_len,
hidden_size // block_size,
)

# Verify dtypes
assert y_fp4.dtype == torch.float4_e2m1fn_x2
assert block_scale_swizzled.dtype == torch.float8_e4m3fn
assert block_scale_unswizzled.dtype == torch.float8_e4m3fn

# Dequantize using unswizzled SF and verify values
y_dequant = dequantize_fp4_output(y_fp4, block_scale_unswizzled, block_size)
assert_close_with_tiered_tolerance(
y_dequant,
ref_rmsnorm.float(),
tight_rtol=0.3,
tight_atol=0.5,
loose_rtol=0.5,
loose_atol=2.0,
tight_pct=0.99,
)

@pytest.mark.parametrize("batch_size", [16, 128])
@pytest.mark.parametrize("hidden_size", [512, 1024])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_both_sf_layouts_with_global_scale(self, batch_size, hidden_size, dtype):
"""Test output_both_sf_layouts=True with global_scale applied."""
from flashinfer.cute_dsl import add_rmsnorm_fp4quant

block_size = 16
eps = 1e-6
torch.manual_seed(42)

x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
weight = torch.randn(hidden_size, device="cuda", dtype=dtype)

# Compute global_scale
global_scale = compute_global_scale(x, r, weight, eps=eps)

# Clone r since kernel modifies it in-place
r_clone = r.clone()

# Call with output_both_sf_layouts=True and global_scale
result = add_rmsnorm_fp4quant(
x,
r,
weight,
global_scale=global_scale,
eps=eps,
block_size=block_size,
output_both_sf_layouts=True,
)

assert len(result) == 3
y_fp4_both, block_scale_swizzled, block_scale_unswizzled = result

# Compare with is_sf_swizzled_layout=False (unswizzled only)
y_fp4_ref, block_scale_ref = add_rmsnorm_fp4quant(
x,
r_clone,
weight,
global_scale=global_scale,
eps=eps,
block_size=block_size,
is_sf_swizzled_layout=False,
)

# FP4 values should be identical
torch.testing.assert_close(
y_fp4_both.view(torch.uint8), y_fp4_ref.view(torch.uint8)
)

# Unswizzled SF should match
torch.testing.assert_close(
block_scale_unswizzled.view(torch.uint8),
block_scale_ref.view(torch.uint8),
)

@pytest.mark.parametrize("batch_size", [16, 128])
@pytest.mark.parametrize("hidden_size", [512, 1024])
def test_both_sf_layouts_with_preallocated_tensors(self, batch_size, hidden_size):
"""Test output_both_sf_layouts=True with pre-allocated output tensors."""
from flashinfer.cute_dsl import add_rmsnorm_fp4quant

block_size = 16
eps = 1e-6
dtype = torch.float16
torch.manual_seed(42)

x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
weight = torch.randn(hidden_size, device="cuda", dtype=dtype)

# Pre-allocate output tensors
y_fp4 = torch.empty(
batch_size, hidden_size // 2, device="cuda", dtype=torch.float4_e2m1fn_x2
)

# Swizzled scale factors
factor = block_size * 4
num_m_tiles = (batch_size + 127) // 128
num_k_tiles = (hidden_size + factor - 1) // factor
swizzled_size = num_m_tiles * num_k_tiles * 32 * 4 * 4
block_scale_swizzled = torch.empty(
swizzled_size, device="cuda", dtype=torch.float8_e4m3fn
)

# Unswizzled scale factors
block_scale_unswizzled = torch.empty(
batch_size,
hidden_size // block_size,
device="cuda",
dtype=torch.float8_e4m3fn,
)

# Clone r for comparison
r_clone = r.clone()

# Call with pre-allocated tensors
result = add_rmsnorm_fp4quant(
x,
r,
weight,
y_fp4=y_fp4,
block_scale=block_scale_swizzled,
block_scale_unswizzled=block_scale_unswizzled,
eps=eps,
block_size=block_size,
output_both_sf_layouts=True,
)

assert len(result) == 3
y_fp4_out, block_scale_swizzled_out, block_scale_unswizzled_out = result

# Verify the returned tensors are the same as pre-allocated
assert y_fp4_out.data_ptr() == y_fp4.data_ptr()
assert block_scale_swizzled_out.data_ptr() == block_scale_swizzled.data_ptr()
assert (
block_scale_unswizzled_out.data_ptr() == block_scale_unswizzled.data_ptr()
)

# Compare with auto-allocated version
y_fp4_auto, block_scale_swizzled_auto, block_scale_unswizzled_auto = (
add_rmsnorm_fp4quant(
x,
r_clone,
weight,
eps=eps,
block_size=block_size,
output_both_sf_layouts=True,
)
)

# FP4 and unswizzled SF should be identical
torch.testing.assert_close(
y_fp4.view(torch.uint8), y_fp4_auto.view(torch.uint8)
)
torch.testing.assert_close(
block_scale_unswizzled.view(torch.uint8),
block_scale_unswizzled_auto.view(torch.uint8),
)

# For swizzled SF, compare by unswizzling (to avoid padding differences)
block_scale_from_swizzled = unswizzle_sf(
block_scale_swizzled.view(torch.uint8), batch_size, hidden_size, block_size
).view(torch.float8_e4m3fn)
block_scale_from_swizzled_auto = unswizzle_sf(
block_scale_swizzled_auto.view(torch.uint8),
batch_size,
hidden_size,
block_size,
).view(torch.float8_e4m3fn)
torch.testing.assert_close(
block_scale_from_swizzled.view(torch.uint8),
block_scale_from_swizzled_auto.view(torch.uint8),
)

@pytest.mark.parametrize("batch_size", [1, 16, 128])
@pytest.mark.parametrize("hidden_size", [16384, 32768])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_both_sf_layouts_large_hidden(self, batch_size, hidden_size, dtype):
"""Test output_both_sf_layouts=True with large hidden sizes (cluster sync path)."""
from flashinfer.cute_dsl import add_rmsnorm_fp4quant

block_size = 16
eps = 1e-6
torch.manual_seed(42)

x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
weight = torch.randn(hidden_size, device="cuda", dtype=dtype)

# Sample first few rows for value comparison (full dequant is slow)
num_check = min(10, batch_size)
h_ref = x[:num_check] + r[:num_check]
ref_rmsnorm = llama_rms_norm(h_ref, weight, eps=eps)

# Clone r for comparison
r_clone1 = r.clone()

# Call with output_both_sf_layouts=True
result = add_rmsnorm_fp4quant(
x,
r,
weight,
eps=eps,
block_size=block_size,
output_both_sf_layouts=True,
)

assert len(result) == 3
y_fp4_both, block_scale_swizzled, block_scale_unswizzled = result

# Compare with separate call using is_sf_swizzled_layout=False
y_fp4_unswizzled, block_scale_ref_unswizzled = add_rmsnorm_fp4quant(
x,
r_clone1,
weight,
eps=eps,
block_size=block_size,
is_sf_swizzled_layout=False,
)

# FP4 values should be identical
torch.testing.assert_close(
y_fp4_both.view(torch.uint8), y_fp4_unswizzled.view(torch.uint8)
)

# Unswizzled scale factors should match
torch.testing.assert_close(
block_scale_unswizzled.view(torch.uint8),
block_scale_ref_unswizzled.view(torch.uint8),
)

# Verify swizzled SF by unswizzling and comparing
block_scale_from_swizzled = unswizzle_sf(
block_scale_swizzled.view(torch.uint8), batch_size, hidden_size, block_size
).view(torch.float8_e4m3fn)
torch.testing.assert_close(
block_scale_from_swizzled.view(torch.uint8),
block_scale_ref_unswizzled.view(torch.uint8),
)

# Verify dequantized values
y_dequant = dequantize_fp4_output(
y_fp4_both[:num_check], block_scale_unswizzled[:num_check], block_size
)
torch.testing.assert_close(
y_dequant,
ref_rmsnorm.float(),
rtol=0.3,
atol=0.5,
)

@pytest.mark.parametrize("batch_size", [16, 128])
@pytest.mark.parametrize("hidden_size", [512, 1024])
def test_both_sf_layouts_residual_inplace(self, batch_size, hidden_size):
"""Test that residual is updated in-place when output_both_sf_layouts=True."""
from flashinfer.cute_dsl import add_rmsnorm_fp4quant

block_size = 16
eps = 1e-6
dtype = torch.float16
torch.manual_seed(42)

x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
weight = torch.randn(hidden_size, device="cuda", dtype=dtype)

# Store original residual
r_original = r.clone()
expected_residual = x + r_original

# Call kernel
        _y_fp4, _block_scale_swizzled, _block_scale_unswizzled = add_rmsnorm_fp4quant(
x, r, weight, eps=eps, block_size=block_size, output_both_sf_layouts=True
)

# Verify residual is updated in-place
torch.testing.assert_close(
r,
expected_residual,
rtol=0,
atol=0,
msg="Residual should be updated in-place with output_both_sf_layouts=True",
)

def test_is_sf_swizzled_layout_ignored_when_output_both(self):
"""
Test that is_sf_swizzled_layout is effectively ignored when output_both_sf_layouts=True.
Both True and False values should produce identical results.
"""
from flashinfer.cute_dsl import add_rmsnorm_fp4quant

batch_size = 64
hidden_size = 1024
block_size = 16
eps = 1e-6
dtype = torch.float16
torch.manual_seed(42)

x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
weight = torch.randn(hidden_size, device="cuda", dtype=dtype)

# Clone r for both calls
r_clone = r.clone()

# Call with is_sf_swizzled_layout=False and output_both_sf_layouts=True
result1 = add_rmsnorm_fp4quant(
x,
r,
weight,
eps=eps,
block_size=block_size,
is_sf_swizzled_layout=False,
output_both_sf_layouts=True,
)

# Call with is_sf_swizzled_layout=True and output_both_sf_layouts=True
result2 = add_rmsnorm_fp4quant(
x,
r_clone,
weight,
eps=eps,
block_size=block_size,
is_sf_swizzled_layout=True,
output_both_sf_layouts=True,
)

# Both should return 3 tensors
assert len(result1) == 3
assert len(result2) == 3

y_fp4_1, swizzled_1, unswizzled_1 = result1
y_fp4_2, swizzled_2, unswizzled_2 = result2

# FP4 outputs should be identical
torch.testing.assert_close(y_fp4_1.view(torch.uint8), y_fp4_2.view(torch.uint8))

# Unswizzled outputs should be identical
torch.testing.assert_close(
unswizzled_1.view(torch.uint8), unswizzled_2.view(torch.uint8)
)

# For swizzled outputs, compare by unswizzling (to avoid padding differences)
swizzled_1_unswizzled = unswizzle_sf(
swizzled_1.view(torch.uint8), batch_size, hidden_size, block_size
)
swizzled_2_unswizzled = unswizzle_sf(
swizzled_2.view(torch.uint8), batch_size, hidden_size, block_size
)
torch.testing.assert_close(swizzled_1_unswizzled, swizzled_2_unswizzled)
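The swizzled scale-factor sizing repeated in these tests can be sketched in plain Python (an illustrative helper mirroring the test math for the 128-row × 4-column swizzle tiling; not part of the library API):

```python
def swizzled_sf_size(batch_size: int, hidden_size: int, block_size: int) -> int:
    """Size of the 1D swizzled scale-factor buffer, mirroring the test assertions."""
    factor = block_size * 4                            # SF columns covered per k-tile
    num_m_tiles = (batch_size + 127) // 128            # rows padded to 128-row tiles
    num_k_tiles = (hidden_size + factor - 1) // factor
    return num_m_tiles * num_k_tiles * 32 * 4 * 4      # 512 SF entries per tile

# e.g. NVFP4 (block_size=16): a 16x512 input pads to one 128-row tile with
# 8 k-tiles, giving 8 * 512 = 4096 scale-factor entries.
print(swizzled_sf_size(16, 512, 16))
```

This is why the pre-allocation test sizes `block_scale_swizzled` by tiles rather than by `batch_size * hidden_size // block_size`: the swizzled layout pads the row dimension up to a multiple of 128.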
Contributor

⚠️ Potential issue | 🟡 Minor

Unused outputs in dual-SF tests trigger Ruff RUF059.
Mark the unused values with _ to keep lint clean.

🧹 Example fix (apply similarly in this class)
-        y_fp4, block_scale_swizzled, block_scale_unswizzled = add_rmsnorm_fp4quant(
+        _y_fp4, block_scale_swizzled, block_scale_unswizzled = add_rmsnorm_fp4quant(
             x,
             r,
             weight,
             eps=eps,
             block_size=block_size,
             output_both_sf_layouts=True,
         )
🧰 Tools
🪛 Ruff (0.14.13)

1798-1798: Unpacked variable y_fp4 is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


1909-1909: Unpacked variable block_scale_swizzled is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


2129-2129: Unpacked variable y_fp4 is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


2129-2129: Unpacked variable block_scale_swizzled is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


2129-2129: Unpacked variable block_scale_unswizzled is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🤖 Prompt for AI Agents
In `@tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py` around lines 1621 - 2207,
Several tests in the TestOutputBothSFLayouts class unpack the 3-tuple returned
by add_rmsnorm_fp4quant but don't use all elements, which triggers Ruff RUF059;
update those unpackings to mark unused values with _ (or prefix variable names
with an underscore) so the linter ignores them. Locate all uses of
add_rmsnorm_fp4quant inside TestOutputBothSFLayouts (e.g., assignments like
y_fp4_both, block_scale_swizzled, block_scale_unswizzled = result or
result1/result2 unpackings) and replace unused targets with _ (for example
y_fp4, _, _ = result or y_fp4, swizzled, _ = result when only the first two are
used). Ensure preallocated-return checks also use _ for any output tensors not
inspected (e.g., when only y_fp4 is asserted), and keep references to
add_rmsnorm_fp4quant, TestOutputBothSFLayouts, and the specific variable names
to guide the edits.

@bkryu bkryu self-assigned this Jan 21, 2026
@coderabbitai coderabbitai bot left a comment (Contributor)

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (1)

2424-2431: Potential issue: In-place residual update may not work for non-contiguous tensors.

If residual is not contiguous, residual.view(B * S, H).contiguous() creates a copy, and the kernel's in-place modification won't propagate back to the original tensor. The docstring promises in-place modification.

🐛 Proposed fix to ensure in-place semantics
     if is_3d:
         B, S, H = input.shape
         input_2d = input.view(B * S, H).contiguous()
-        residual_2d = residual.view(B * S, H).contiguous()
+        # Ensure residual is contiguous before view to preserve in-place semantics
+        if not residual.is_contiguous():
+            raise ValueError(
+                "residual tensor must be contiguous for in-place update. "
+                "Call residual.contiguous() before passing if needed."
+            )
+        residual_2d = residual.view(B * S, H)
     else:
-        input_2d = input
-        residual_2d = residual
+        input_2d = input.contiguous()
+        if not residual.is_contiguous():
+            raise ValueError(
+                "residual tensor must be contiguous for in-place update. "
+                "Call residual.contiguous() before passing if needed."
+            )
+        residual_2d = residual

Also update line 2537:

     tensor_api(
         input_2d.contiguous(),
-        residual_2d.contiguous(),
+        residual_2d,  # Already verified contiguous
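The pitfall this fix targets can be demonstrated with NumPy as a stand-in for torch (an illustrative analogue, not the kernel code): on a non-contiguous view, `np.ascontiguousarray` — like `.contiguous()` — returns a fresh copy, so "in-place" writes to the copy never reach the original buffer.

```python
import numpy as np

base = np.zeros((4, 4), dtype=np.float32)
view = base.T                        # non-contiguous view of base
copy = np.ascontiguousarray(view)    # non-contiguous input -> a fresh copy
copy[0, 0] = 1.0                     # the "in-place" update hits the copy only

print(base[0, 0])                    # still 0.0: the write did not propagate
```

Raising an error for non-contiguous residuals, as proposed above, surfaces this silently-dropped update to the caller instead.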
♻️ Duplicate comments (1)
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (1)

1381-1414: Swizzled offset calculation is duplicated across multiple code paths.

The swizzled offset calculation logic (lines 1383-1395) is repeated in 8 places throughout the kernel. This was flagged in a previous review. Consider extracting this into a @cute.jit helper function for maintainability.

♻️ Suggested helper function
@cute.jit
def compute_swizzled_offset(
    actual_row_idx: Int32,
    sf_idx: Int32,
    num_k_tiles: int,
    k_tile_stride: int,
) -> Int32:
    """Compute swizzled scale factor offset using 128x4 tile pattern."""
    inner_k_idx = sf_idx % Int32(4)
    inner_m_idx = (actual_row_idx % Int32(128)) // Int32(32)
    outer_m_idx = actual_row_idx % Int32(32)
    k_tile_idx = sf_idx // Int32(4)
    m_tile_idx = actual_row_idx // Int32(128)
    m_tile_stride = num_k_tiles * k_tile_stride
    return (
        m_tile_idx * m_tile_stride
        + k_tile_idx * k_tile_stride
        + outer_m_idx * Int32(16)
        + inner_m_idx * Int32(4)
        + inner_k_idx
    )

Then replace all occurrences with:

swizzled_offset = compute_swizzled_offset(
    actual_row_idx, sf_idx, self.num_k_tiles, self.k_tile_stride
)
🧹 Nitpick comments (1)
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (1)

2490-2505: Consider lazy allocation for block_scale_unswizzled to avoid unnecessary memory usage.

When output_both_sf_layouts=False, a full (batch_size, num_sf_blocks_per_row) tensor is allocated but never written to. For large batch sizes, this could be significant.

You could allocate a minimal 1-element tensor when output_both_sf_layouts=False to satisfy TVM-FFI validation while avoiding the memory cost:

♻️ Proposed optimization
     if block_scale_unswizzled is None:
-        if is_3d:
-            block_scale_unswizzled = torch.empty(
-                (B, S, num_sf_blocks_per_row),
-                dtype=scale_dtype,
-                device=input.device,
-            )
-        else:
-            block_scale_unswizzled = torch.empty(
-                (batch_size, num_sf_blocks_per_row),
-                dtype=scale_dtype,
-                device=input.device,
-            )
+        if output_both_sf_layouts:
+            if is_3d:
+                block_scale_unswizzled = torch.empty(
+                    (B, S, num_sf_blocks_per_row),
+                    dtype=scale_dtype,
+                    device=input.device,
+                )
+            else:
+                block_scale_unswizzled = torch.empty(
+                    (batch_size, num_sf_blocks_per_row),
+                    dtype=scale_dtype,
+                    device=input.device,
+                )
+        else:
+            # Minimal allocation for TVM-FFI validation when not used
+            block_scale_unswizzled = torch.empty(
+                (1, 1), dtype=scale_dtype, device=input.device
+            )

Note: This requires verifying that TVM-FFI accepts differently-shaped tensors at runtime vs. compilation time.

@yzh119 (Collaborator) commented Jan 21, 2026

@yongwww would you mind checking CI errors such as:

error during connect: Get "http://%2Fvar%2Frun%2Fdocker.sock/_ping": read unix @->/run/docker.sock: read: connection reset by peer

@yongwww (Member) commented Jan 21, 2026

@yongwww would you mind checking CI errors such as:

error during connect: Get "http://%2Fvar%2Frun%2Fdocker.sock/_ping": read unix @->/run/docker.sock: read: connection reset by peer

I looked at the log; the error is due to AWS reclaiming the G5 spot instance for capacity needs. The Docker errors were symptoms of the instance dying mid-job. We can rerun the failed job as a workaround for now. For a long-term solution, I am thinking about adding retry logic (if the tests fail on a spot instance, rerun them on an on-demand instance).

@yongwww (Member) commented Jan 22, 2026

CI is green now.

@yzh119 yzh119 requested a review from nv-yunzheq as a code owner January 22, 2026 07:21
@yzh119 (Collaborator) commented Jan 22, 2026

/bot run

@flashinfer-bot (Collaborator)
GitLab MR !257 has been created, and the CI pipeline #42259044 is currently running. I'll report back once the pipeline job completes.

@nv-yunzheq nv-yunzheq left a comment (Collaborator)

LGTM, it would be better if we could get people from SGLang/vLLM/TensorRT-LLM to review

@yzh119 yzh119 left a comment (Collaborator)

Hi @bkryu, LGTM overall. A minor suggestion is to replace torch.testing.assert_close with torch.equal where we expect two tensors to be identical.
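The distinction can be sketched with NumPy analogues (`np.array_equal` playing the role of `torch.equal`, `np.allclose` that of `torch.testing.assert_close`; illustrative only):

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0])
b = a + 1e-6                         # tiny perturbation

print(np.allclose(a, b))             # True: within default tolerances
print(np.array_equal(a, b))          # False: exact comparison catches the drift
print(np.array_equal(a, a.copy()))   # True: identical values pass the strict check
```

For the dual-SF tests, the two layouts are written from the same computed scale values, so an exact check is the stronger and appropriate assertion.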

@bkryu (Collaborator, Author) commented Jan 22, 2026

Hi @bkryu LGTM overall, a minor suggestion is to replace torch.testing.assert_close with torch.equal in case we expect two tensors to be identical.

Thanks @yzh119, this is a good suggestion. I converted some of the assert_close calls to torch.equal for a tighter check. Will re-launch unit tests to make sure everything looks right.

@bkryu (Collaborator, Author) commented Jan 22, 2026

/bot run

@flashinfer-bot (Collaborator)
GitLab MR !257 has been updated with latest changes, and the CI pipeline #42301990 is currently running. I'll report back once the pipeline job completes.

@bkryu bkryu force-pushed the add_rmsn_f4q_dual_sf branch from 3cca6af to 1a4cdbd Compare January 23, 2026 03:27
@bkryu (Collaborator, Author) commented Jan 23, 2026

/bot stop

@flashinfer-bot (Collaborator)
The GitLab CI pipeline #42301990 has been cancelled.

@bkryu (Collaborator, Author) commented Jan 23, 2026

/bot run

@flashinfer-bot (Collaborator)
GitLab MR !257 has been updated with latest changes, and the CI pipeline #42322559 is currently running. I'll report back once the pipeline job completes.

@yzh119 yzh119 merged commit 446b2c5 into flashinfer-ai:main Jan 23, 2026
7 of 20 checks passed
@claude claude bot mentioned this pull request Jan 23, 2026
5 tasks
@bkryu bkryu deleted the add_rmsn_f4q_dual_sf branch January 29, 2026 22:29
