feat: Add output_both_sf_layouts option to add_rmsnorm_fp4quant API #2395
yzh119 merged 3 commits into flashinfer-ai:main from
Conversation
Summary of Changes
Hello @bkryu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the

Highlights
Note
Other AI code review bot(s) detected: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough
This PR adds an optional
Changes
Sequence Diagram(s)

sequenceDiagram
participant Host
participant Kernel as Device Kernel
participant Memory as Global Memory
Host->>Kernel: launch add_rmsnorm_fp4quant(..., output_both_sf_layouts=true, inputs)
rect rgba(100, 200, 150, 0.5)
Note over Kernel: Fused operations in kernel
Kernel->>Memory: read input, residual, weights
Kernel->>Kernel: h = input + residual (compute)
Kernel->>Memory: write h back to residual (in-place)
Kernel->>Kernel: apply RMSNorm to h
Kernel->>Kernel: quantize h -> y_fp4
Kernel->>Kernel: compute block_scale (swizzled)
Kernel->>Kernel: compute block_scale_unswizzled (unswizzled)
end
Kernel->>Memory: write y_fp4, block_scale, block_scale_unswizzled
Kernel->>Host: return (y_fp4, block_scale, block_scale_unswizzled)
Host->>Host: verify/consume outputs
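The fused pipeline in the diagram can be sketched as a pure-NumPy reference. This is a hedged illustration only: the helper name `ref_add_rmsnorm_fp4quant` is hypothetical, FP4 bit-packing is omitted, and the E2M1 maximum magnitude of 6.0 is an assumption about the target format, not taken from this PR.

```python
import numpy as np

def ref_add_rmsnorm_fp4quant(x, r, w, eps=1e-6, block_size=16):
    """NumPy sketch of the fused pipeline in the sequence diagram.

    r is updated in place to x + r, RMSNorm is applied to the sum,
    and the result is quantized with one scale per `block_size`
    values. Returns clipped float values (packing omitted) plus the
    per-block scales.
    """
    np.add(x, r, out=r)                    # residual = input + residual (in-place)
    rms = np.sqrt(np.mean(r * r, axis=-1, keepdims=True) + eps)
    h = (r / rms) * w                      # RMSNorm(residual) * weight
    blocks = h.reshape(h.shape[0], -1, block_size)
    amax = np.abs(blocks).max(axis=-1, keepdims=True)
    scale = np.maximum(amax / 6.0, 1e-12)  # 6.0: assumed E2M1 max magnitude
    y = np.clip(blocks / scale, -6.0, 6.0)
    return y.reshape(h.shape), scale.squeeze(-1)
```

Note that the caller observes the in-place residual update exactly as the tests below verify: after the call, `r` holds `input + original_residual`.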
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 3 passed
Code Review
This pull request successfully introduces the output_both_sf_layouts option to the add_rmsnorm_fp4quant API, enabling the simultaneous output of swizzled and unswizzled scale factors. A key related change is the modification of the kernel to perform an in-place update on the residual tensor (residual = input + residual), which is now correctly reflected in the documentation, benchmarks, and tests. The test suite has been significantly expanded with new classes to thoroughly validate both the in-place update behavior and the new dual-layout output feature. The changes are well-implemented, but there is an opportunity to improve maintainability by refactoring duplicated code within the kernel.
if cutlass.const_expr(self.output_both_sf_layouts):
    # Output both swizzled and unswizzled scale factors
    inner_k_idx = sf_idx % Int32(4)
    inner_m_idx = (actual_row_idx % Int32(128)) // Int32(32)
    outer_m_idx = actual_row_idx % Int32(32)
    k_tile_idx = sf_idx // Int32(4)
    m_tile_idx = actual_row_idx // Int32(128)
    m_tile_stride = self.num_k_tiles * self.k_tile_stride
    swizzled_offset = (
        m_tile_idx * m_tile_stride
        + k_tile_idx * self.k_tile_stride
        + outer_m_idx * Int32(16)
        + inner_m_idx * Int32(4)
        + inner_k_idx
    )
    mS[swizzled_offset] = scale_fp8
    mS_unswizzled[actual_row_idx, sf_idx] = scale_fp8
elif cutlass.const_expr(self.output_swizzled):
The logic for calculating swizzled_offset is duplicated in the if block here and the elif block below, and this pattern is repeated in three other places in the kernel. This harms maintainability.
To improve this, you could extract the swizzled_offset calculation into a @cute.jit helper function and call it from both branches. This would centralize the swizzle logic and make the kernel easier to read and modify.
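As a plain-Python illustration of the index math such a helper would centralize (the helper name is hypothetical, and the default `k_tile_stride` of 512, i.e. 128 rows times 4 scales per k-tile, is an assumption rather than a value taken from this diff):

```python
def swizzled_sf_offset(row, sf_idx, num_k_tiles, k_tile_stride=512):
    """Flat offset into the 128x4 swizzled scale-factor layout.

    Mirrors the index math in the kernel branch quoted above.
    """
    inner_k = sf_idx % 4             # position within a group of 4 scales
    inner_m = (row % 128) // 32      # 32-row sub-block within a 128-row tile
    outer_m = row % 32               # row within the 32-row sub-block
    k_tile = sf_idx // 4
    m_tile = row // 128
    m_tile_stride = num_k_tiles * k_tile_stride
    return (
        m_tile * m_tile_stride
        + k_tile * k_tile_stride
        + outer_m * 16
        + inner_m * 4
        + inner_k
    )
```

One nice property of factoring the math out is that it becomes testable on the host: for 128 rows and 8 scale columns (`num_k_tiles=2`), the mapping is a bijection onto `range(1024)`, which can be checked exhaustively.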
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (1)
2424-2545: Residual in-place update is silently broken for non-contiguous inputs in the 2D path.

The function promises in-place modification of `residual` (documented as "Modified in-place to contain `residual + input`"), but when `residual` is non-contiguous in the 2D case, calling `.contiguous()` before the kernel invocation creates a copy. The kernel modifies this copy, leaving the original `residual` unchanged and violating the API contract.

The 3D path explicitly calls `.contiguous()` on both tensors early (lines 2427-2428), but the 2D path does not, creating inconsistent behavior. Add an explicit contiguity check to enforce the requirement:

Proposed fix

    is_3d = input.dim() == 3
+   if not residual.is_contiguous():
+       raise ValueError("residual must be contiguous for in-place update")
    if is_3d:
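The pitfall is easy to reproduce outside the kernel. A minimal NumPy analogue, where `np.ascontiguousarray` plays the role of torch's `.contiguous()`:

```python
import numpy as np

# A non-contiguous view: every other column of a 4x8 buffer.
buf = np.zeros((4, 8), dtype=np.float32)
residual = buf[:, ::2]
assert not residual.flags["C_CONTIGUOUS"]

# For a non-contiguous array, ascontiguousarray returns a copy,
# so in-place writes to it never reach the caller's buffer.
# This is the silent-break failure mode described above.
work = np.ascontiguousarray(residual)
work += 1.0
assert float(residual.max()) == 0.0  # the original view is untouched
```

Raising on non-contiguous `residual`, as the proposed fix does, surfaces the problem at the call site instead of silently dropping the update.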
🤖 Fix all issues with AI agents
In `@tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py`:
- Around line 1621-2207: Several tests in the TestOutputBothSFLayouts class
unpack the 3-tuple returned by add_rmsnorm_fp4quant but don't use all elements,
which triggers Ruff RUF059; update those unpackings to mark unused values with _
(or prefix variable names with an underscore) so the linter ignores them. Locate
all uses of add_rmsnorm_fp4quant inside TestOutputBothSFLayouts (e.g.,
assignments like y_fp4_both, block_scale_swizzled, block_scale_unswizzled =
result or result1/result2 unpackings) and replace unused targets with _ (for
example y_fp4, _, _ = result or y_fp4, swizzled, _ = result when only the first
two are used). Ensure preallocated-return checks also use _ for any output
tensors not inspected (e.g., when only y_fp4 is asserted), and keep references
to add_rmsnorm_fp4quant, TestOutputBothSFLayouts, and the specific variable
names to guide the edits.
- Around line 1315-1620: Several test cases unpack the tuple returned by
add_rmsnorm_fp4quant but never use one or both outputs (e.g., in
TestResidualInPlaceUpdate:test_residual_inplace_update_2d,
test_residual_inplace_update_3d, test_residual_inplace_update_mxfp4,
test_residual_inplace_update_large_hidden,
test_residual_inplace_with_preallocated_outputs,
test_residual_inplace_swizzled_layout, test_residual_not_aliased_with_input).
Update those unpackings to mark unused values with a leading underscore (e.g.,
replace "y_fp4, block_scale =" with "_y_fp4, _block_scale =" or "_, block_scale
=" where only one is used) so the returned outputs from add_rmsnorm_fp4quant are
explicitly acknowledged as unused and silence Ruff RUF059.
🧹 Nitpick comments (2)
benchmarks/samples/sample_testlist.txt (1)
114-123: Consider adding a 3D input test case for comprehensive coverage.

The existing test suite includes a 3D input shape test (line 112). For completeness, consider adding a test that combines `--output_both_sf_layouts` with 3D inputs:

    --routine add_rmsnorm_fp4quant --batch_size 32 --num_heads 32 --hidden_size 128 --input_dtype bfloat16 --output_both_sf_layouts -vv --generate_repro_command --case_tag "add_rmsnorm_fp4quant_3d_both_sf"

This would ensure the dual SF layout feature works correctly with both 2D and 3D input shapes.
tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py (1)
1315-1317: Prefer `flashinfer.utils` capability checks for skips.
Using the shared helpers keeps GPU-arch gating consistent across tests. As per coding guidelines, please switch to the `flashinfer.utils` skip helpers.

Also applies to: 1621-1623
@cute_dsl_available
@blackwell_required
class TestResidualInPlaceUpdate:
    """Tests to verify that the residual tensor is updated in-place with input + residual."""

    @pytest.mark.parametrize("batch_size", [1, 4, 16, 128, 512])
    @pytest.mark.parametrize("hidden_size", [256, 512, 1024, 4096])
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
    def test_residual_inplace_update_2d(self, batch_size, hidden_size, dtype):
        """Test that residual is updated in-place for 2D input."""
        from flashinfer.cute_dsl.add_rmsnorm_fp4quant import add_rmsnorm_fp4quant

        torch.manual_seed(42)
        block_size = 16
        eps = 1e-6

        x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        weight = torch.randn(hidden_size, device="cuda", dtype=dtype)

        # Store original values before kernel call
        r_original = r.clone()
        expected_residual = x + r_original

        # Call kernel - residual should be modified in-place
        y_fp4, block_scale = add_rmsnorm_fp4quant(
            x, r, weight, eps=eps, block_size=block_size
        )

        # Verify residual is updated in-place to input + original_residual
        torch.testing.assert_close(
            r,
            expected_residual,
            rtol=0,
            atol=0,
            msg="Residual should be exactly input + original_residual",
        )

    @pytest.mark.parametrize("batch_size", [1, 4, 16])
    @pytest.mark.parametrize("seq_len", [16, 64, 128])
    @pytest.mark.parametrize("hidden_size", [256, 1024, 4096])
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
    def test_residual_inplace_update_3d(self, batch_size, seq_len, hidden_size, dtype):
        """Test that residual is updated in-place for 3D input."""
        from flashinfer.cute_dsl.add_rmsnorm_fp4quant import add_rmsnorm_fp4quant

        torch.manual_seed(42)
        block_size = 16
        eps = 1e-6

        x = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=dtype)
        r = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=dtype)
        weight = torch.randn(hidden_size, device="cuda", dtype=dtype)

        # Store original values before kernel call
        r_original = r.clone()
        expected_residual = x + r_original

        # Call kernel - residual should be modified in-place
        y_fp4, block_scale = add_rmsnorm_fp4quant(
            x, r, weight, eps=eps, block_size=block_size
        )

        # Verify residual is updated in-place to input + original_residual
        torch.testing.assert_close(
            r,
            expected_residual,
            rtol=0,
            atol=0,
            msg="Residual should be exactly input + original_residual",
        )

    @pytest.mark.parametrize("batch_size", [1, 16, 128])
    @pytest.mark.parametrize("hidden_size", [256, 1024, 2048])
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
    def test_residual_inplace_update_mxfp4(self, batch_size, hidden_size, dtype):
        """Test that residual is updated in-place for MXFP4 format (block_size=32)."""
        from flashinfer.cute_dsl.add_rmsnorm_fp4quant import add_rmsnorm_fp4quant

        torch.manual_seed(42)
        block_size = 32  # MXFP4
        eps = 1e-6

        x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        weight = torch.randn(hidden_size, device="cuda", dtype=dtype)

        # Store original values before kernel call
        r_original = r.clone()
        expected_residual = x + r_original

        # Call kernel - residual should be modified in-place
        y_fp4, block_scale = add_rmsnorm_fp4quant(
            x, r, weight, eps=eps, block_size=block_size, scale_format="ue8m0"
        )

        # Verify residual is updated in-place to input + original_residual
        torch.testing.assert_close(
            r,
            expected_residual,
            rtol=0,
            atol=0,
            msg="Residual should be exactly input + original_residual",
        )

    @pytest.mark.parametrize("batch_size", [1, 16, 128])
    @pytest.mark.parametrize("hidden_size", [16384, 32768])
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
    def test_residual_inplace_update_large_hidden(self, batch_size, hidden_size, dtype):
        """Test residual in-place update with large hidden sizes (cluster sync path)."""
        from flashinfer.cute_dsl.add_rmsnorm_fp4quant import add_rmsnorm_fp4quant

        torch.manual_seed(42)
        block_size = 16
        eps = 1e-6

        x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        weight = torch.randn(hidden_size, device="cuda", dtype=dtype)

        # Store original values before kernel call
        r_original = r.clone()
        expected_residual = x + r_original

        # Call kernel - residual should be modified in-place
        y_fp4, block_scale = add_rmsnorm_fp4quant(
            x, r, weight, eps=eps, block_size=block_size
        )

        # Verify residual is updated in-place to input + original_residual
        torch.testing.assert_close(
            r,
            expected_residual,
            rtol=0,
            atol=0,
            msg="Residual should be exactly input + original_residual",
        )

    @pytest.mark.parametrize("batch_size", [16, 128])
    @pytest.mark.parametrize("hidden_size", [512, 1024])
    def test_residual_used_for_rmsnorm(self, batch_size, hidden_size):
        """
        Test that the updated residual (input + original_residual) is used for RMSNorm.

        This verifies the correct sequence of operations:
        1. residual = input + residual (in-place)
        2. output = RMSNorm(residual) * weight
        3. quantize output to FP4
        """
        from flashinfer.cute_dsl.add_rmsnorm_fp4quant import add_rmsnorm_fp4quant

        torch.manual_seed(42)
        block_size = 16
        eps = 1e-6
        dtype = torch.float16

        x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        weight = torch.randn(hidden_size, device="cuda", dtype=dtype)

        # Store original residual and compute expected values
        r_original = r.clone()
        expected_h = x + r_original
        expected_rmsnorm = llama_rms_norm(expected_h, weight, eps=eps)

        # Call kernel
        y_fp4, block_scale = add_rmsnorm_fp4quant(
            x, r, weight, eps=eps, block_size=block_size
        )

        # Verify residual is updated
        torch.testing.assert_close(r, expected_h, rtol=0, atol=0)

        # Verify FP4 output matches RMSNorm of the updated residual
        y_dequant = dequantize_fp4_output(y_fp4, block_scale, block_size)
        assert_close_with_tiered_tolerance(
            y_dequant,
            expected_rmsnorm.float(),
            tight_rtol=0.3,
            tight_atol=0.5,
            loose_rtol=0.5,
            loose_atol=2.0,
            tight_pct=0.99,
            msg="FP4 output should match RMSNorm of updated residual",
        )

    @pytest.mark.parametrize("batch_size", [16, 128])
    @pytest.mark.parametrize("hidden_size", [512, 1024])
    def test_residual_inplace_with_preallocated_outputs(self, batch_size, hidden_size):
        """Test residual in-place update when using pre-allocated output tensors."""
        from flashinfer.cute_dsl.add_rmsnorm_fp4quant import add_rmsnorm_fp4quant

        torch.manual_seed(42)
        block_size = 16
        eps = 1e-6
        dtype = torch.float16

        x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        weight = torch.randn(hidden_size, device="cuda", dtype=dtype)

        # Pre-allocate output tensors
        y_fp4 = torch.empty(
            batch_size, hidden_size // 2, device="cuda", dtype=torch.float4_e2m1fn_x2
        )
        block_scale = torch.empty(
            batch_size,
            hidden_size // block_size,
            device="cuda",
            dtype=torch.float8_e4m3fn,
        )

        # Store original residual
        r_original = r.clone()
        expected_residual = x + r_original

        # Call kernel with pre-allocated outputs
        add_rmsnorm_fp4quant(
            x, r, weight, y_fp4, block_scale, eps=eps, block_size=block_size
        )

        # Verify residual is updated in-place
        torch.testing.assert_close(
            r,
            expected_residual,
            rtol=0,
            atol=0,
            msg="Residual should be updated in-place even with pre-allocated outputs",
        )

    @pytest.mark.parametrize("batch_size", [16, 128])
    @pytest.mark.parametrize("hidden_size", [512, 1024])
    def test_residual_inplace_swizzled_layout(self, batch_size, hidden_size):
        """Test residual in-place update with swizzled scale factor layout."""
        from flashinfer.cute_dsl.add_rmsnorm_fp4quant import add_rmsnorm_fp4quant

        torch.manual_seed(42)
        block_size = 16
        eps = 1e-6
        dtype = torch.float16

        x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        weight = torch.randn(hidden_size, device="cuda", dtype=dtype)

        # Store original residual
        r_original = r.clone()
        expected_residual = x + r_original

        # Call kernel with swizzled layout
        y_fp4, block_scale = add_rmsnorm_fp4quant(
            x, r, weight, eps=eps, block_size=block_size, is_sf_swizzled_layout=True
        )

        # Verify residual is updated in-place
        torch.testing.assert_close(
            r,
            expected_residual,
            rtol=0,
            atol=0,
            msg="Residual should be updated in-place with swizzled layout",
        )

    def test_residual_not_aliased_with_input(self):
        """Test that the kernel handles non-aliased input and residual correctly."""
        from flashinfer.cute_dsl.add_rmsnorm_fp4quant import add_rmsnorm_fp4quant

        torch.manual_seed(42)
        batch_size = 64
        hidden_size = 1024
        block_size = 16
        eps = 1e-6
        dtype = torch.float16

        x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        weight = torch.randn(hidden_size, device="cuda", dtype=dtype)

        # Ensure x and r are separate tensors (not views of each other)
        assert x.data_ptr() != r.data_ptr()

        # Store originals
        x_original = x.clone()
        r_original = r.clone()
        expected_residual = x_original + r_original

        # Call kernel
        y_fp4, block_scale = add_rmsnorm_fp4quant(
            x, r, weight, eps=eps, block_size=block_size
        )

        # Verify x is unchanged
        torch.testing.assert_close(
            x, x_original, rtol=0, atol=0, msg="Input tensor should not be modified"
        )

        # Verify r is updated
        torch.testing.assert_close(
            r,
            expected_residual,
            rtol=0,
            atol=0,
            msg="Residual should be updated in-place",
        )
Silence Ruff RUF059 by marking unused outputs.
Several tests unpack outputs they don’t use; prefix them with _ to avoid lint failures.
🧹 Example fix (apply similarly across this class)

-   y_fp4, block_scale = add_rmsnorm_fp4quant(
+   _y_fp4, _block_scale = add_rmsnorm_fp4quant(
        x, r, weight, eps=eps, block_size=block_size
    )

🧰 Tools
🪛 Ruff (0.14.13)

1340-1340: Unpacked variables y_fp4 and block_scale are never used. Prefix them with an underscore or any other dummy variable pattern. (RUF059)
1374-1374: Unpacked variables y_fp4 and block_scale are never used. Prefix them with an underscore or any other dummy variable pattern. (RUF059)
1407-1407: Unpacked variables y_fp4 and block_scale are never used. Prefix them with an underscore or any other dummy variable pattern. (RUF059)
1440-1440: Unpacked variables y_fp4 and block_scale are never used. Prefix them with an underscore or any other dummy variable pattern. (RUF059)
1565-1565: Unpacked variables y_fp4 and block_scale are never used. Prefix them with an underscore or any other dummy variable pattern. (RUF059)
1602-1602: Unpacked variables y_fp4 and block_scale are never used. Prefix them with an underscore or any other dummy variable pattern. (RUF059)
| @cute_dsl_available | ||
| @blackwell_required | ||
| class TestOutputBothSFLayouts: | ||
| """Tests for output_both_sf_layouts=True which returns both swizzled and unswizzled SFs.""" | ||
|
|
||
| @pytest.mark.parametrize("batch_size", [1, 16, 128, 256]) | ||
| @pytest.mark.parametrize("hidden_size", [256, 512, 1024, 4096]) | ||
| @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) | ||
| def test_nvfp4_both_sf_layouts_basic(self, batch_size, hidden_size, dtype): | ||
| """ | ||
| Test that output_both_sf_layouts=True returns 3 tensors and both SFs are correct. | ||
| Uses NVFP4 format (block_size=16, E4M3 scales). | ||
| """ | ||
| from flashinfer.cute_dsl import add_rmsnorm_fp4quant | ||
|
|
||
| block_size = 16 | ||
| eps = 1e-6 | ||
| torch.manual_seed(42) | ||
|
|
||
| x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) | ||
| r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) | ||
| weight = torch.randn(hidden_size, device="cuda", dtype=dtype) | ||
|
|
||
| # Clone r since kernel modifies it in-place | ||
| r_clone1 = r.clone() | ||
|
|
||
| # Call with output_both_sf_layouts=True | ||
| result = add_rmsnorm_fp4quant( | ||
| x, | ||
| r, | ||
| weight, | ||
| eps=eps, | ||
| block_size=block_size, | ||
| output_both_sf_layouts=True, | ||
| ) | ||
|
|
||
| # Should return 3 tensors | ||
| assert len(result) == 3, f"Expected 3 tensors, got {len(result)}" | ||
| y_fp4_both, block_scale_swizzled, block_scale_unswizzled = result | ||
|
|
||
| # Verify shapes | ||
| assert y_fp4_both.shape == (batch_size, hidden_size // 2) | ||
| assert block_scale_unswizzled.shape == (batch_size, hidden_size // block_size) | ||
|
|
||
| # Swizzled layout should be 1D | ||
| factor = block_size * 4 | ||
| num_m_tiles = (batch_size + 127) // 128 | ||
| num_k_tiles = (hidden_size + factor - 1) // factor | ||
| expected_swizzled_size = num_m_tiles * num_k_tiles * 32 * 4 * 4 | ||
| assert block_scale_swizzled.shape == (expected_swizzled_size,) | ||
|
|
||
| # Verify dtypes | ||
| assert y_fp4_both.dtype == torch.float4_e2m1fn_x2 | ||
| assert block_scale_swizzled.dtype == torch.float8_e4m3fn | ||
| assert block_scale_unswizzled.dtype == torch.float8_e4m3fn | ||
|
|
||
| # Compare against separate calls with is_sf_swizzled_layout=False | ||
| y_fp4_unswizzled_only, block_scale_ref_unswizzled = add_rmsnorm_fp4quant( | ||
| x, | ||
| r_clone1, | ||
| weight, | ||
| eps=eps, | ||
| block_size=block_size, | ||
| is_sf_swizzled_layout=False, | ||
| ) | ||
|
|
||
| # FP4 values should be identical | ||
| torch.testing.assert_close( | ||
| y_fp4_both.view(torch.uint8), y_fp4_unswizzled_only.view(torch.uint8) | ||
| ) | ||
|
|
||
| # Unswizzled SF should match the reference unswizzled | ||
| torch.testing.assert_close( | ||
| block_scale_unswizzled.view(torch.uint8), | ||
| block_scale_ref_unswizzled.view(torch.uint8), | ||
| ) | ||
|
|
||
| # Verify swizzled SF is correct by unswizzling and comparing | ||
| block_scale_from_swizzled = unswizzle_sf( | ||
| block_scale_swizzled.view(torch.uint8), batch_size, hidden_size, block_size | ||
| ).view(torch.float8_e4m3fn) | ||
| torch.testing.assert_close( | ||
| block_scale_from_swizzled.view(torch.uint8), | ||
| block_scale_ref_unswizzled.view(torch.uint8), | ||
| ) | ||
|
|
||
| @pytest.mark.parametrize("batch_size", [1, 16, 128, 256]) | ||
| @pytest.mark.parametrize("hidden_size", [256, 512, 1024, 4096]) | ||
| @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) | ||
| def test_mxfp4_both_sf_layouts_basic(self, batch_size, hidden_size, dtype): | ||
| """ | ||
| Test that output_both_sf_layouts=True returns 3 tensors for MXFP4 format. | ||
| Uses MXFP4 format (block_size=32, UE8M0 scales). | ||
| """ | ||
| from flashinfer.cute_dsl import add_rmsnorm_fp4quant | ||
|
|
||
| block_size = 32 | ||
| eps = 1e-6 | ||
| torch.manual_seed(42) | ||
|
|
||
| x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) | ||
| r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) | ||
| weight = torch.randn(hidden_size, device="cuda", dtype=dtype) | ||
|
|
||
| # Clone r since kernel modifies it in-place | ||
| r_clone1 = r.clone() | ||
|
|
||
| # Call with output_both_sf_layouts=True | ||
| result = add_rmsnorm_fp4quant( | ||
| x, | ||
| r, | ||
| weight, | ||
| eps=eps, | ||
| block_size=block_size, | ||
| scale_format="ue8m0", | ||
| output_both_sf_layouts=True, | ||
| ) | ||
|
|
||
| # Should return 3 tensors | ||
| assert len(result) == 3, f"Expected 3 tensors, got {len(result)}" | ||
| y_fp4_both, block_scale_swizzled, block_scale_unswizzled = result | ||
|
|
||
| # Verify shapes | ||
| assert y_fp4_both.shape == (batch_size, hidden_size // 2) | ||
| assert block_scale_unswizzled.shape == (batch_size, hidden_size // block_size) | ||
|
|
||
| # Verify dtypes (UE8M0 uses uint8) | ||
| assert y_fp4_both.dtype == torch.float4_e2m1fn_x2 | ||
| assert block_scale_swizzled.dtype == torch.uint8 | ||
| assert block_scale_unswizzled.dtype == torch.uint8 | ||
|
|
||
| # Compare against separate calls with is_sf_swizzled_layout=False | ||
| y_fp4_unswizzled_only, block_scale_ref_unswizzled = add_rmsnorm_fp4quant( | ||
| x, | ||
| r_clone1, | ||
| weight, | ||
| eps=eps, | ||
| block_size=block_size, | ||
| scale_format="ue8m0", | ||
| is_sf_swizzled_layout=False, | ||
| ) | ||
|
|
||
| # FP4 values should be identical | ||
| torch.testing.assert_close( | ||
| y_fp4_both.view(torch.uint8), y_fp4_unswizzled_only.view(torch.uint8) | ||
| ) | ||
|
|
||
| # Unswizzled SF should match the reference unswizzled | ||
| torch.testing.assert_close(block_scale_unswizzled, block_scale_ref_unswizzled) | ||
|
|
||
| # Verify swizzled SF is correct by unswizzling and comparing | ||
| block_scale_from_swizzled = unswizzle_sf( | ||
| block_scale_swizzled, batch_size, hidden_size, block_size | ||
| ) | ||
| torch.testing.assert_close( | ||
| block_scale_from_swizzled, block_scale_ref_unswizzled | ||
| ) | ||
|
|
||
| @pytest.mark.parametrize("batch_size", [1, 16, 128]) | ||
| @pytest.mark.parametrize("hidden_size", [512, 1024, 4096]) | ||
| @pytest.mark.parametrize("dtype", [torch.float16]) | ||
| def test_both_sf_layouts_consistency(self, batch_size, hidden_size, dtype): | ||
| """ | ||
| Test that unswizzling the swizzled SF matches the unswizzled SF. | ||
| This verifies internal consistency of the dual output. | ||
| """ | ||
| from flashinfer.cute_dsl import add_rmsnorm_fp4quant | ||
|
|
||
| block_size = 16 | ||
| eps = 1e-6 | ||
| torch.manual_seed(42) | ||
|
|
||
| x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) | ||
| r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) | ||
| weight = torch.randn(hidden_size, device="cuda", dtype=dtype) | ||
|
|
||
| # Call with output_both_sf_layouts=True | ||
| y_fp4, block_scale_swizzled, block_scale_unswizzled = add_rmsnorm_fp4quant( | ||
| x, | ||
| r, | ||
| weight, | ||
| eps=eps, | ||
| block_size=block_size, | ||
| output_both_sf_layouts=True, | ||
| ) | ||
|
|
||
| # Unswizzle the swizzled SF | ||
| block_scale_unswizzled_from_swizzled = unswizzle_sf( | ||
| block_scale_swizzled.view(torch.uint8), batch_size, hidden_size, block_size | ||
| ).view(torch.float8_e4m3fn) | ||
|
|
||
| # Should match the directly returned unswizzled SF | ||
| torch.testing.assert_close( | ||
| block_scale_unswizzled_from_swizzled.view(torch.uint8), | ||
| block_scale_unswizzled.view(torch.uint8), | ||
| ) | ||
|
|
||
    @pytest.mark.parametrize("batch_size", [1, 4, 8])
    @pytest.mark.parametrize("seq_len", [16, 64, 128])
    @pytest.mark.parametrize("hidden_size", [256, 1024])
    @pytest.mark.parametrize("dtype", [torch.float16])
    def test_both_sf_layouts_3d_input(self, batch_size, seq_len, hidden_size, dtype):
        """Test output_both_sf_layouts=True with 3D input tensors."""
        from flashinfer.cute_dsl import add_rmsnorm_fp4quant

        block_size = 16
        eps = 1e-6
        torch.manual_seed(42)

        x = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=dtype)
        r = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=dtype)
        weight = torch.randn(hidden_size, device="cuda", dtype=dtype)

        # Reference computation before kernel call
        h_ref = x + r
        ref_rmsnorm = llama_rms_norm(h_ref, weight, eps=eps)

        # Call with output_both_sf_layouts=True
        result = add_rmsnorm_fp4quant(
            x,
            r,
            weight,
            eps=eps,
            block_size=block_size,
            output_both_sf_layouts=True,
        )

        assert len(result) == 3
        y_fp4, block_scale_swizzled, block_scale_unswizzled = result

        # Verify shapes
        assert y_fp4.shape == (batch_size, seq_len, hidden_size // 2)
        assert block_scale_unswizzled.shape == (
            batch_size,
            seq_len,
            hidden_size // block_size,
        )

        # Verify dtypes
        assert y_fp4.dtype == torch.float4_e2m1fn_x2
        assert block_scale_swizzled.dtype == torch.float8_e4m3fn
        assert block_scale_unswizzled.dtype == torch.float8_e4m3fn

        # Dequantize using unswizzled SF and verify values
        y_dequant = dequantize_fp4_output(y_fp4, block_scale_unswizzled, block_size)
        assert_close_with_tiered_tolerance(
            y_dequant,
            ref_rmsnorm.float(),
            tight_rtol=0.3,
            tight_atol=0.5,
            loose_rtol=0.5,
            loose_atol=2.0,
            tight_pct=0.99,
        )
    @pytest.mark.parametrize("batch_size", [16, 128])
    @pytest.mark.parametrize("hidden_size", [512, 1024])
    @pytest.mark.parametrize("dtype", [torch.float16])
    def test_both_sf_layouts_with_global_scale(self, batch_size, hidden_size, dtype):
        """Test output_both_sf_layouts=True with global_scale applied."""
        from flashinfer.cute_dsl import add_rmsnorm_fp4quant

        block_size = 16
        eps = 1e-6
        torch.manual_seed(42)

        x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        weight = torch.randn(hidden_size, device="cuda", dtype=dtype)

        # Compute global_scale
        global_scale = compute_global_scale(x, r, weight, eps=eps)

        # Clone r since kernel modifies it in-place
        r_clone = r.clone()

        # Call with output_both_sf_layouts=True and global_scale
        result = add_rmsnorm_fp4quant(
            x,
            r,
            weight,
            global_scale=global_scale,
            eps=eps,
            block_size=block_size,
            output_both_sf_layouts=True,
        )

        assert len(result) == 3
        y_fp4_both, block_scale_swizzled, block_scale_unswizzled = result

        # Compare with is_sf_swizzled_layout=False (unswizzled only)
        y_fp4_ref, block_scale_ref = add_rmsnorm_fp4quant(
            x,
            r_clone,
            weight,
            global_scale=global_scale,
            eps=eps,
            block_size=block_size,
            is_sf_swizzled_layout=False,
        )

        # FP4 values should be identical
        torch.testing.assert_close(
            y_fp4_both.view(torch.uint8), y_fp4_ref.view(torch.uint8)
        )

        # Unswizzled SF should match
        torch.testing.assert_close(
            block_scale_unswizzled.view(torch.uint8),
            block_scale_ref.view(torch.uint8),
        )
    @pytest.mark.parametrize("batch_size", [16, 128])
    @pytest.mark.parametrize("hidden_size", [512, 1024])
    def test_both_sf_layouts_with_preallocated_tensors(self, batch_size, hidden_size):
        """Test output_both_sf_layouts=True with pre-allocated output tensors."""
        from flashinfer.cute_dsl import add_rmsnorm_fp4quant

        block_size = 16
        eps = 1e-6
        dtype = torch.float16
        torch.manual_seed(42)

        x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        weight = torch.randn(hidden_size, device="cuda", dtype=dtype)

        # Pre-allocate output tensors
        y_fp4 = torch.empty(
            batch_size, hidden_size // 2, device="cuda", dtype=torch.float4_e2m1fn_x2
        )

        # Swizzled scale factors
        factor = block_size * 4
        num_m_tiles = (batch_size + 127) // 128
        num_k_tiles = (hidden_size + factor - 1) // factor
        swizzled_size = num_m_tiles * num_k_tiles * 32 * 4 * 4
        block_scale_swizzled = torch.empty(
            swizzled_size, device="cuda", dtype=torch.float8_e4m3fn
        )

        # Unswizzled scale factors
        block_scale_unswizzled = torch.empty(
            batch_size,
            hidden_size // block_size,
            device="cuda",
            dtype=torch.float8_e4m3fn,
        )

        # Clone r for comparison
        r_clone = r.clone()

        # Call with pre-allocated tensors
        result = add_rmsnorm_fp4quant(
            x,
            r,
            weight,
            y_fp4=y_fp4,
            block_scale=block_scale_swizzled,
            block_scale_unswizzled=block_scale_unswizzled,
            eps=eps,
            block_size=block_size,
            output_both_sf_layouts=True,
        )

        assert len(result) == 3
        y_fp4_out, block_scale_swizzled_out, block_scale_unswizzled_out = result

        # Verify the returned tensors are the same as pre-allocated
        assert y_fp4_out.data_ptr() == y_fp4.data_ptr()
        assert block_scale_swizzled_out.data_ptr() == block_scale_swizzled.data_ptr()
        assert (
            block_scale_unswizzled_out.data_ptr() == block_scale_unswizzled.data_ptr()
        )

        # Compare with auto-allocated version
        y_fp4_auto, block_scale_swizzled_auto, block_scale_unswizzled_auto = (
            add_rmsnorm_fp4quant(
                x,
                r_clone,
                weight,
                eps=eps,
                block_size=block_size,
                output_both_sf_layouts=True,
            )
        )

        # FP4 and unswizzled SF should be identical
        torch.testing.assert_close(
            y_fp4.view(torch.uint8), y_fp4_auto.view(torch.uint8)
        )
        torch.testing.assert_close(
            block_scale_unswizzled.view(torch.uint8),
            block_scale_unswizzled_auto.view(torch.uint8),
        )

        # For swizzled SF, compare by unswizzling (to avoid padding differences)
        block_scale_from_swizzled = unswizzle_sf(
            block_scale_swizzled.view(torch.uint8), batch_size, hidden_size, block_size
        ).view(torch.float8_e4m3fn)
        block_scale_from_swizzled_auto = unswizzle_sf(
            block_scale_swizzled_auto.view(torch.uint8),
            batch_size,
            hidden_size,
            block_size,
        ).view(torch.float8_e4m3fn)
        torch.testing.assert_close(
            block_scale_from_swizzled.view(torch.uint8),
            block_scale_from_swizzled_auto.view(torch.uint8),
        )
    @pytest.mark.parametrize("batch_size", [1, 16, 128])
    @pytest.mark.parametrize("hidden_size", [16384, 32768])
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
    def test_both_sf_layouts_large_hidden(self, batch_size, hidden_size, dtype):
        """Test output_both_sf_layouts=True with large hidden sizes (cluster sync path)."""
        from flashinfer.cute_dsl import add_rmsnorm_fp4quant

        block_size = 16
        eps = 1e-6
        torch.manual_seed(42)

        x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        weight = torch.randn(hidden_size, device="cuda", dtype=dtype)

        # Sample first few rows for value comparison (full dequant is slow)
        num_check = min(10, batch_size)
        h_ref = x[:num_check] + r[:num_check]
        ref_rmsnorm = llama_rms_norm(h_ref, weight, eps=eps)

        # Clone r for comparison
        r_clone1 = r.clone()

        # Call with output_both_sf_layouts=True
        result = add_rmsnorm_fp4quant(
            x,
            r,
            weight,
            eps=eps,
            block_size=block_size,
            output_both_sf_layouts=True,
        )

        assert len(result) == 3
        y_fp4_both, block_scale_swizzled, block_scale_unswizzled = result

        # Compare with separate call using is_sf_swizzled_layout=False
        y_fp4_unswizzled, block_scale_ref_unswizzled = add_rmsnorm_fp4quant(
            x,
            r_clone1,
            weight,
            eps=eps,
            block_size=block_size,
            is_sf_swizzled_layout=False,
        )

        # FP4 values should be identical
        torch.testing.assert_close(
            y_fp4_both.view(torch.uint8), y_fp4_unswizzled.view(torch.uint8)
        )

        # Unswizzled scale factors should match
        torch.testing.assert_close(
            block_scale_unswizzled.view(torch.uint8),
            block_scale_ref_unswizzled.view(torch.uint8),
        )

        # Verify swizzled SF by unswizzling and comparing
        block_scale_from_swizzled = unswizzle_sf(
            block_scale_swizzled.view(torch.uint8), batch_size, hidden_size, block_size
        ).view(torch.float8_e4m3fn)
        torch.testing.assert_close(
            block_scale_from_swizzled.view(torch.uint8),
            block_scale_ref_unswizzled.view(torch.uint8),
        )

        # Verify dequantized values
        y_dequant = dequantize_fp4_output(
            y_fp4_both[:num_check], block_scale_unswizzled[:num_check], block_size
        )
        torch.testing.assert_close(
            y_dequant,
            ref_rmsnorm.float(),
            rtol=0.3,
            atol=0.5,
        )
    @pytest.mark.parametrize("batch_size", [16, 128])
    @pytest.mark.parametrize("hidden_size", [512, 1024])
    def test_both_sf_layouts_residual_inplace(self, batch_size, hidden_size):
        """Test that residual is updated in-place when output_both_sf_layouts=True."""
        from flashinfer.cute_dsl import add_rmsnorm_fp4quant

        block_size = 16
        eps = 1e-6
        dtype = torch.float16
        torch.manual_seed(42)

        x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        weight = torch.randn(hidden_size, device="cuda", dtype=dtype)

        # Store original residual
        r_original = r.clone()
        expected_residual = x + r_original

        # Call kernel
        y_fp4, block_scale_swizzled, block_scale_unswizzled = add_rmsnorm_fp4quant(
            x, r, weight, eps=eps, block_size=block_size, output_both_sf_layouts=True
        )

        # Verify residual is updated in-place
        torch.testing.assert_close(
            r,
            expected_residual,
            rtol=0,
            atol=0,
            msg="Residual should be updated in-place with output_both_sf_layouts=True",
        )
    def test_is_sf_swizzled_layout_ignored_when_output_both(self):
        """
        Test that is_sf_swizzled_layout is effectively ignored when output_both_sf_layouts=True.
        Both True and False values should produce identical results.
        """
        from flashinfer.cute_dsl import add_rmsnorm_fp4quant

        batch_size = 64
        hidden_size = 1024
        block_size = 16
        eps = 1e-6
        dtype = torch.float16
        torch.manual_seed(42)

        x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype)
        weight = torch.randn(hidden_size, device="cuda", dtype=dtype)

        # Clone r for both calls
        r_clone = r.clone()

        # Call with is_sf_swizzled_layout=False and output_both_sf_layouts=True
        result1 = add_rmsnorm_fp4quant(
            x,
            r,
            weight,
            eps=eps,
            block_size=block_size,
            is_sf_swizzled_layout=False,
            output_both_sf_layouts=True,
        )

        # Call with is_sf_swizzled_layout=True and output_both_sf_layouts=True
        result2 = add_rmsnorm_fp4quant(
            x,
            r_clone,
            weight,
            eps=eps,
            block_size=block_size,
            is_sf_swizzled_layout=True,
            output_both_sf_layouts=True,
        )

        # Both should return 3 tensors
        assert len(result1) == 3
        assert len(result2) == 3

        y_fp4_1, swizzled_1, unswizzled_1 = result1
        y_fp4_2, swizzled_2, unswizzled_2 = result2

        # FP4 outputs should be identical
        torch.testing.assert_close(y_fp4_1.view(torch.uint8), y_fp4_2.view(torch.uint8))

        # Unswizzled outputs should be identical
        torch.testing.assert_close(
            unswizzled_1.view(torch.uint8), unswizzled_2.view(torch.uint8)
        )

        # For swizzled outputs, compare by unswizzling (to avoid padding differences)
        swizzled_1_unswizzled = unswizzle_sf(
            swizzled_1.view(torch.uint8), batch_size, hidden_size, block_size
        )
        swizzled_2_unswizzled = unswizzle_sf(
            swizzled_2.view(torch.uint8), batch_size, hidden_size, block_size
        )
        torch.testing.assert_close(swizzled_1_unswizzled, swizzled_2_unswizzled)
Unused outputs in dual-SF tests trigger Ruff RUF059.
Mark the unused values with _ to keep lint clean.
🧹 Example fix (apply similarly in this class)

```diff
-    y_fp4, block_scale_swizzled, block_scale_unswizzled = add_rmsnorm_fp4quant(
+    _y_fp4, block_scale_swizzled, block_scale_unswizzled = add_rmsnorm_fp4quant(
         x,
         r,
         weight,
         eps=eps,
         block_size=block_size,
         output_both_sf_layouts=True,
     )
```

🧰 Tools
🪛 Ruff (0.14.13)
1798-1798: Unpacked variable y_fp4 is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
1909-1909: Unpacked variable block_scale_swizzled is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2129-2129: Unpacked variable y_fp4 is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2129-2129: Unpacked variable block_scale_swizzled is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2129-2129: Unpacked variable block_scale_unswizzled is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
🤖 Prompt for AI Agents
In `@tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py` around lines 1621 - 2207,
Several tests in the TestOutputBothSFLayouts class unpack the 3-tuple returned
by add_rmsnorm_fp4quant but don't use all elements, which triggers Ruff RUF059;
update those unpackings to mark unused values with _ (or prefix variable names
with an underscore) so the linter ignores them. Locate all uses of
add_rmsnorm_fp4quant inside TestOutputBothSFLayouts (e.g., assignments like
y_fp4_both, block_scale_swizzled, block_scale_unswizzled = result or
result1/result2 unpackings) and replace unused targets with _ (for example
y_fp4, _, _ = result or y_fp4, swizzled, _ = result when only the first two are
used). Ensure preallocated-return checks also use _ for any output tensors not
inspected (e.g., when only y_fp4 is asserted), and keep references to
add_rmsnorm_fp4quant, TestOutputBothSFLayouts, and the specific variable names
to guide the edits.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (1)
2424-2431: Potential issue: in-place residual update may not work for non-contiguous tensors. If `residual` is not contiguous, `residual.view(B * S, H).contiguous()` creates a copy, and the kernel's in-place modification won't propagate back to the original tensor. The docstring promises in-place modification.

🐛 Proposed fix to ensure in-place semantics

```diff
 if is_3d:
     B, S, H = input.shape
     input_2d = input.view(B * S, H).contiguous()
-    residual_2d = residual.view(B * S, H).contiguous()
+    # Ensure residual is contiguous before view to preserve in-place semantics
+    if not residual.is_contiguous():
+        raise ValueError(
+            "residual tensor must be contiguous for in-place update. "
+            "Call residual.contiguous() before passing if needed."
+        )
+    residual_2d = residual.view(B * S, H)
 else:
-    input_2d = input
-    residual_2d = residual
+    input_2d = input.contiguous()
+    if not residual.is_contiguous():
+        raise ValueError(
+            "residual tensor must be contiguous for in-place update. "
+            "Call residual.contiguous() before passing if needed."
+        )
+    residual_2d = residual
```

Also update line 2537:

```diff
 tensor_api(
     input_2d.contiguous(),
-    residual_2d.contiguous(),
+    residual_2d,  # Already verified contiguous
```
♻️ Duplicate comments (1)
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (1)
1381-1414: Swizzled offset calculation is duplicated across multiple code paths. The swizzled offset calculation logic (lines 1383-1395) is repeated in 8 places throughout the kernel. This was flagged in a previous review. Consider extracting this into a `@cute.jit` helper function for maintainability.

♻️ Suggested helper function

```python
@cute.jit
def compute_swizzled_offset(
    actual_row_idx: Int32,
    sf_idx: Int32,
    num_k_tiles: int,
    k_tile_stride: int,
) -> Int32:
    """Compute swizzled scale factor offset using 128x4 tile pattern."""
    inner_k_idx = sf_idx % Int32(4)
    inner_m_idx = (actual_row_idx % Int32(128)) // Int32(32)
    outer_m_idx = actual_row_idx % Int32(32)
    k_tile_idx = sf_idx // Int32(4)
    m_tile_idx = actual_row_idx // Int32(128)
    m_tile_stride = num_k_tiles * k_tile_stride
    return (
        m_tile_idx * m_tile_stride
        + k_tile_idx * k_tile_stride
        + outer_m_idx * Int32(16)
        + inner_m_idx * Int32(4)
        + inner_k_idx
    )
```

Then replace all occurrences with:

```python
swizzled_offset = compute_swizzled_offset(
    actual_row_idx, sf_idx, self.num_k_tiles, self.k_tile_stride
)
```
🧹 Nitpick comments (1)
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (1)
2490-2505: Consider lazy allocation for `block_scale_unswizzled` to avoid unnecessary memory usage. When `output_both_sf_layouts=False`, a full `(batch_size, num_sf_blocks_per_row)` tensor is allocated but never written to. For large batch sizes, this could be significant.

You could allocate a minimal 1-element tensor when `output_both_sf_layouts=False` to satisfy TVM-FFI validation while avoiding the memory cost:

♻️ Proposed optimization

```diff
 if block_scale_unswizzled is None:
-    if is_3d:
-        block_scale_unswizzled = torch.empty(
-            (B, S, num_sf_blocks_per_row),
-            dtype=scale_dtype,
-            device=input.device,
-        )
-    else:
-        block_scale_unswizzled = torch.empty(
-            (batch_size, num_sf_blocks_per_row),
-            dtype=scale_dtype,
-            device=input.device,
-        )
+    if output_both_sf_layouts:
+        if is_3d:
+            block_scale_unswizzled = torch.empty(
+                (B, S, num_sf_blocks_per_row),
+                dtype=scale_dtype,
+                device=input.device,
+            )
+        else:
+            block_scale_unswizzled = torch.empty(
+                (batch_size, num_sf_blocks_per_row),
+                dtype=scale_dtype,
+                device=input.device,
+            )
+    else:
+        # Minimal allocation for TVM-FFI validation when not used
+        block_scale_unswizzled = torch.empty(
+            (1, 1), dtype=scale_dtype, device=input.device
+        )
```

Note: This requires verifying that TVM-FFI accepts differently-shaped tensors at runtime vs. compilation time.
|
@yongwww would you mind checking CI errors such as: error during connect: Get "http://%2Fvar%2Frun%2Fdocker.sock/_ping": read unix @->/run/docker.sock: read: connection reset by peer |
I looked at the log; the error is because AWS reclaimed the G5 spot instance due to capacity needs. The Docker errors were symptoms of the instance dying mid-job. We can rerun the failed job as a workaround for now. As a long-term solution, I am thinking about adding retry logic (if the tests fail on a spot instance, rerun them on an on-demand instance).
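The retry idea amounts to a spot-first, on-demand-fallback wrapper. This is an illustrative sketch only, not the actual FlashInfer CI configuration; the function name and labels are made up, and the second invocation stands in for re-dispatching the job to an on-demand runner:

```shell
# Hypothetical wrapper: run the test command once (the spot attempt) and,
# on failure, run it again (standing in for the on-demand fallback).
run_with_fallback() {
  if "$@"; then
    echo "passed-on-spot"
  elif "$@"; then
    echo "passed-on-demand-retry"
  else
    echo "failed-on-both"
  fi
}

run_with_fallback true    # command that succeeds on the first attempt
run_with_fallback false   # command that fails on both attempts
```

A real implementation would distinguish infrastructure failures (instance reclaimed) from genuine test failures before retrying, so flaky infrastructure doesn't mask real regressions.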
|
ci is green now |
|
/bot run |
Thanks @yzh119, this is a good suggestion. I converted some of the
|
/bot run |
3cca6af to 1a4cdbd
|
/bot stop |
|
The GitLab CI pipeline #42301990 has been cancelled. |
|
/bot run |
📌 Description
Must be merged after #2385
This PR extends the `add_rmsnorm_fp4quant` API to support outputting both swizzled and unswizzled scale factors simultaneously. This is useful for scenarios where the quantized output needs to be consumed by both GEMMs (experts) and All-to-All without requiring a separate layout conversion pass.

When `output_both_sf_layouts=True`, the function returns a 3-tuple `(y_fp4, block_scale_swizzled, block_scale_unswizzled)` instead of the standard 2-tuple. This flag overrides `is_sf_swizzled_layout` when set.

Changes Summary
- Added `output_both_sf_layouts` and `block_scale_unswizzled` parameters; updated the kernel to write both SF layouts
- Added a `TestOutputBothSFLayouts` test class with 10 test methods covering NVFP4/MXFP4, 2D/3D inputs, auto/pre-allocation, and large hidden sizes
- Added an `--output_both_sf_layouts` flag; adjusted the bandwidth calculation to account for 2× SF writes

🔍 Related Issues
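For a concrete sense of the three outputs' sizes, the allocation math used in the tests can be sketched in plain Python. This is a sketch derived from the test code in this PR, not the library's own allocator; the constant 512 = 128 padded rows × 4 padded SF columns per swizzled tile is an assumption carried over from that test math:

```python
import math

def dual_layout_shapes(batch_size: int, hidden_size: int, block_size: int = 16):
    """Expected sizes of (y_fp4, block_scale_swizzled, block_scale_unswizzled)
    for 2D inputs, mirroring the pre-allocation math in the tests."""
    y_fp4_shape = (batch_size, hidden_size // 2)  # two FP4 values packed per byte
    unswizzled_shape = (batch_size, hidden_size // block_size)
    num_m_tiles = math.ceil(batch_size / 128)                # rows padded to 128
    num_k_tiles = math.ceil(hidden_size / (block_size * 4))  # SF columns padded to 4
    swizzled_numel = num_m_tiles * num_k_tiles * 32 * 4 * 4  # 512 entries per tile
    return y_fp4_shape, unswizzled_shape, swizzled_numel

# batch 16, hidden 512: 16x32 unswizzled SFs, padded to 4096 swizzled entries
print(dual_layout_shapes(16, 512))  # ((16, 256), (16, 32), 4096)
```

This also makes the cost of the feature visible: the swizzled buffer is padded up to full 128-row tiles, so for small batches the swizzled SF output can be much larger than the unswizzled one.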
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed and are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Bug Fixes / UX
Tests