
[CUDA] Enhance Broadcast Codegen for Symbolic Value #1669

Merged
LeiWang1999 merged 5 commits into tile-ai:main from LeiWang1999:cuda_0114 on Jan 14, 2026
Conversation

@LeiWang1999 (Member) commented on Jan 14, 2026

Enhance CUDA code generation for BroadcastNode by implementing compile-time constant folding and runtime broadcasting for various lane configurations. Improved handling of 4-bit and 8-bit integer types ensures correct replication and type casting in the output expressions. This update improves both performance and correctness of CUDA kernel generation.
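As a rough illustration of the two paths, here is a minimal CUDA sketch, assuming an int8 value broadcast across 4 lanes packed into one 32-bit word; the function name and exact form are illustrative, not the actual generated code:

```cuda
// Minimal sketch of the two broadcast paths (illustrative, not the
// verbatim codegen output).

// Compile-time path: for a constant value such as 0x05, the replicated
// word can be folded during codegen:
//   (0x05 << 24) | (0x05 << 16) | (0x05 << 8) | 0x05 == 0x05050505

// Runtime path: for a symbolic value, replication happens on the GPU by
// multiplying the zero-extended byte with 0x01010101.
__device__ int broadcast_int8x4(signed char v) {
  return (int)((unsigned int)(unsigned char)v * 0x01010101u);
}
```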

Summary by CodeRabbit

  • Improvements

    • Refined CUDA int8/4-bit broadcast handling to centralize constant-path logic, preserving prior behavior while improving consistency and efficiency.
    • Increased diagnostic verbosity for non-constant broadcast values to make failures clearer.
  • Tests

    • Added CUDA-specific vectorized broadcast test coverage exercising the new broadcast paths.


@github-actions commented

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run `pre-commit run --all-files` in the root directory of the project so that your changes are properly linted and formatted; this helps your contribution pass the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai coderabbitai bot commented on Jan 14, 2026

📝 Walkthrough

Refactors CUDA codegen broadcast handling to centralize constant-pointer retrieval and constant-folding paths for 8-bit and 4-bit broadcasts, adds a diagnostic message when a BroadcastNode value is not constant, and adds a CUDA-specific int8 vectorize broadcast test and its parametrized runner.
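In outline, the centralized pre-check described above has roughly this shape (condensed from the suggested cleanup quoted later in the thread; not the verbatim source):

```cpp
// Condensed sketch of the centralized constant-path dispatch in the
// BroadcastNode handling of codegen_cuda.cc (simplified, not verbatim).
if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8) {
  const int64_t *p = as_const_int(op->value);  // single pre-check
  if (p) {
    // constant path: fold the replicated word at codegen time
  } else {
    // symbolic path: emit runtime replication code
  }
}
```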

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| **CUDA codegen change**<br>src/target/codegen_cuda.cc | Centralizes retrieval of the constant pointer `p` before the lane checks; preserves and routes `lanes == 4` and `lanes == 32` through constant-folding paths using the pre-check; adds an ICHECK diagnostic message for non-constant BroadcastNode values in the 4-bit integer handling; leaves a minor redundant re-acquire of `p` inside the 32-lane branch. |
| **New tests**<br>testing/python/language/test_tilelang_language_vectorize.py | Adds `vectorize_broadcast_int8(vec_num)` (tilelang.jit), which performs a vectorized int8 broadcast, and a parametrized `test_vectorize_broadcast_int8(vec_num)` decorated for CUDA that compiles and exercises the kernel. |

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Poem

🐇 I checked the bytes before lanes take flight,
Folded constants snugly, tucked in tight.
If values wander, a message I send,
Tests hop in CUDA to dance and defend. 🥕✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 25.00%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped: CodeRabbit's high-level summary is enabled. |
| Title check | ✅ Passed | The title '[CUDA] Enhance Broadcast Codegen for Symbolic Value' accurately captures the main focus of the changes: improving CUDA broadcast code generation with enhanced handling of symbolic/non-constant values. |


🧹 Recent nitpick comments
src/target/codegen_cuda.cc (1)

3196-3224: Redundant variable re-declaration and checks inside guarded block.

Within the `if (p)` block (line 3197), both branches contain redundant code:

  • Line 3200: `ICHECK(p)` is unnecessary, since `p` is guaranteed non-null by the enclosing `if`.
  • Lines 3211-3212: Re-declaring `const int64_t *p = as_const_int(op->value)` shadows the outer `p` and is redundant, since execution is already inside the `if (p)` block.
♻️ Suggested cleanup
   if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8) {
     const int64_t *p = as_const_int(op->value);
     if (p) {
       if (lanes == 4) {
         // make_int8x4
-        ICHECK(p);
         int64_t v = *p & 0xFF;
         v = (v << 24) | (v << 16) | (v << 8) | v;
         if (op->dtype.is_uint()) {
           os << "(uint)" << v;
         } else {
           os << "(int)" << v;
         }
         return;
       } else if (lanes == 32) {
         // make_int8x32
-        const int64_t *p = as_const_int(op->value);
-        ICHECK(p);
         int64_t v = *p & 0xFF;
         v = (v << 24) | (v << 16) | (v << 8) | v;

📜 Recent review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3e707a6 and 6cc682a.

📒 Files selected for processing (1)
  • src/target/codegen_cuda.cc
🧰 Additional context used
🧠 Learnings (1)
📓 Common learnings
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1606
File: testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:30-30
Timestamp: 2026-01-06T05:20:51.649Z
Learning: In `testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py`, the test validates that the `hoist_broadcast_values` transformation pass correctly identifies and hoists broadcast operations by checking for patterns in the generated kernel source code. The specific literal values used (e.g., 430) are not important for the test's purpose, as it does not validate numerical precision or actual stored tensor values.
🔇 Additional comments (1)
src/target/codegen_cuda.cc (1)

3288-3290: Good addition of diagnostic message for non-constant 4-bit broadcasts.

The added diagnostic message improves debuggability by clearly indicating when a 4-bit BroadcastNode value fails the constant requirement.




@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@testing/python/language/test_tilelang_language_vectorize.py`:
- Around lines 151-159: The test function vectorize_broadcast_int8 is missing
the inner `@T.prim_func` decorator and does not return the generated prim func.
Update the definition so that the inner function containing the kernel is
annotated with `@T.prim_func` (matching the pattern used by
vectorize_test_all_dtypes), and make the outer `@tilelang.jit`-decorated
function return that prim func object at the end.
🧹 Nitpick comments (1)
testing/python/language/test_tilelang_language_vectorize.py (1)

161-165: Test only validates compilation, not runtime correctness.

The test calls .compile() but doesn't actually execute the kernel or verify that the broadcast produces correct results. While this may be intentional to test code generation (similar to the retrieved learning about pattern checking), consider adding a comment explaining the test's purpose, or adding a simple runtime check if feasible.

Based on learnings, pattern-based testing for transformations is acceptable when the goal is to validate code generation paths rather than numerical correctness.

📝 Suggested improvement for clarity
 @tilelang.testing.requires_cuda
 @pytest.mark.parametrize("vec_num", [4, 32])
 def test_vectorize_broadcast_int8(vec_num):
-    """Test broadcasting a non-constant int8 value to a vectorized store."""
-    vectorize_broadcast_int8.compile(vec_num=vec_num)
+    """Test that non-constant int8 broadcast generates valid CUDA code.
+    
+    This test validates the codegen path for runtime int8 broadcast
+    (not compile-time constant folding) by ensuring compilation succeeds.
+    """
+    kernel = vectorize_broadcast_int8(vec_num)
+    # Optionally verify the generated code contains the expected broadcast pattern
+    code = kernel.get_kernel_source()
+    assert "0x01010101" in code, "Expected runtime byte replication pattern"
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 732971a and 3e707a6.

📒 Files selected for processing (2)
  • src/target/codegen_cuda.cc
  • testing/python/language/test_tilelang_language_vectorize.py
🧰 Additional context used
🧠 Learnings (5)
📓 Common learnings
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1606
File: testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:30-30
Timestamp: 2026-01-06T05:20:51.649Z
Learning: In `testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py`, the test validates that the `hoist_broadcast_values` transformation pass correctly identifies and hoists broadcast operations by checking for patterns in the generated kernel source code. The specific literal values used (e.g., 430) are not important for the test's purpose, as it does not validate numerical precision or actual stored tensor values.
📚 Learning: 2026-01-06T05:20:51.649Z (same learning as above: silentCoder-dev, tile-ai/tilelang PR 1606, test_tilelang_transform_hoist_broadcast_values.py)

Applied to files:

  • testing/python/language/test_tilelang_language_vectorize.py
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.

Applied to files:

  • testing/python/language/test_tilelang_language_vectorize.py
📚 Learning: 2025-11-03T06:24:11.411Z
Learnt from: Rachmanino
Repo: tile-ai/tilelang PR: 1175
File: src/op/math.cc:44-52
Timestamp: 2025-11-03T06:24:11.411Z
Learning: In tilelang's `src/op/math.cc`, the `infinity_op` function uses `std::numeric_limits<float>::infinity()` as a placeholder for all float types (including float64 and bfloat16). The codegen layer (PrintConst:Inf) handles the correct infinity value based on the dtype field of the FloatImm node, so the specific C++ template argument doesn't matter.

Applied to files:

  • src/target/codegen_cuda.cc
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.

Applied to files:

  • src/target/codegen_cuda.cc
🧬 Code graph analysis (1)
testing/python/language/test_tilelang_language_vectorize.py (3)
tilelang/language/kernel.py (1)
  • threads (215-219)
tilelang/language/allocate.py (1)
  • alloc_local (57-68)
testing/python/language/test_tilelang_language_alloc.py (1)
  • alloc_var (5-22)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
🔇 Additional comments (3)
src/target/codegen_cuda.cc (3)

3199-3217: LGTM - Compile-time constant folding and runtime broadcast for int8x4.

The implementation correctly:

  • Folds compile-time constants by replicating the byte value across all 4 positions
  • Uses runtime multiplication by 0x01010101 to broadcast a single byte to all 4 positions
  • Handles both signed and unsigned variants appropriately
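For concreteness, a small standalone check of the replication identity behind both paths (illustrative; not part of the change):

```cpp
// Sanity check: multiplying a zero-extended byte by 0x01010101 equals
// the explicit shift-and-or replication used on the constant path.
#include <cassert>
#include <cstdint>

int main() {
  uint8_t v = 0x5A;
  uint32_t packed = (uint32_t)v * 0x01010101u;   // runtime-path form
  uint32_t folded = ((uint32_t)v << 24) | ((uint32_t)v << 16) |
                    ((uint32_t)v << 8) | (uint32_t)v;  // folded form
  assert(packed == 0x5A5A5A5Au && packed == folded);
  return 0;
}
```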

3222-3248: Runtime lambda captures by reference but may evaluate val multiple times.

The lambda approach is reasonable for complex construction, and since val (the result of PrintExpr(op->value)) has already been evaluated to a string, capturing it by reference with [&]() is safe here. The pattern correctly builds a 64-bit value from the byte and replicates it across all 4 components of the longlong4.
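A plausible shape of the emitted int8x32 expression, assuming the 64-bit word is built by multiplying the zero-extended byte with a 64-bit replication constant (the constant and the wrapper function are assumptions, not the verbatim generated code):

```cuda
// Hypothetical sketch of the generated int8x32 broadcast: an
// immediately-invoked lambda builds one replicated 64-bit word and
// copies it into all four components of a longlong4.
__device__ longlong4 broadcast_int8x32(signed char val) {
  return [&]() {
    long long w =
        (long long)((unsigned long long)(unsigned char)val *
                    0x0101010101010101ULL);  // byte -> eight bytes
    return make_longlong4(w, w, w, w);
  }();
}
```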


3316-3389: 4-bit integer broadcast implementation looks correct.

The constant folding and runtime broadcast logic for 4-bit integers correctly handles:

  • 4 lanes: 16-bit result with nibble replication via 0x1111
  • 8 lanes: 32-bit result with nibble replication via 0x11111111
  • 16/32 lanes: Multiple 32-bit values via lambda construction

One minor observation: the lambda for lanes 16/32 at line 3388 closes with `(); }())`, i.e. a semicolon just before the end of the immediately-invoked lambda, which is syntactically valid but slightly unusual. The generated CUDA code will work correctly.
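The replication constants above translate into patterns like the following sketch (constants taken from the review notes; the wrapper functions are illustrative):

```cuda
// Illustrative nibble-replication helpers for the 4-bit broadcast paths.
__device__ unsigned short broadcast_int4x4(int v) {
  // 4 lanes -> 16-bit result: one nibble replicated via 0x1111
  return (unsigned short)((v & 0xF) * 0x1111);
}

__device__ unsigned int broadcast_int4x8(int v) {
  // 8 lanes -> 32-bit result: one nibble replicated via 0x11111111
  return (unsigned int)(v & 0xF) * 0x11111111u;
}
// 16/32 lanes build multiple such 32-bit words via an immediately-invoked
// lambda, as noted above.
```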


@LeiWang1999 merged commit 2d8d367 into tile-ai:main on Jan 14, 2026. 6 checks passed.