[CUDA] Enhance Broadcast Codegen for Symbolic Value #1669
LeiWang1999 merged 5 commits into tile-ai:main from
Conversation
Enhance CUDA code generation for BroadcastNode by implementing compile-time constant folding and runtime broadcasting for various lane configurations. Improved handling for 4-bit and 8-bit integer types, ensuring correct replication and type casting in output expressions. This update increases performance and correctness in CUDA kernel generation.
👋 Hi! Thank you for contributing to the TileLang project. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
📝 Walkthrough
Refactors CUDA codegen broadcast handling to centralize constant-pointer retrieval and constant-folding paths for 8-bit and 4-bit broadcasts, adds a diagnostic message when a BroadcastNode value is not constant, and adds a CUDA-specific int8 vectorize broadcast test and its parametrized runner.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@testing/python/language/test_tilelang_language_vectorize.py`:
- Around line 151-159: The test function vectorize_broadcast_int8 is missing the
inner `@T.prim_func` decorator and the final return statement for the generated
prim func; update the definition so that the inner function that contains the
Kernel is annotated with `@T.prim_func` (matching the pattern used by
vectorize_test_all_dtypes) and ensure the outer `@tilelang.jit-decorated` function
returns that prim func object at the end.
🧹 Nitpick comments (1)
testing/python/language/test_tilelang_language_vectorize.py (1)
161-165: Test only validates compilation, not runtime correctness. The test calls `.compile()` but doesn't actually execute the kernel or verify that the broadcast produces correct results. While this may be intentional to test code generation (similar to the retrieved learning about pattern checking), consider adding a comment explaining the test's purpose, or adding a simple runtime check if feasible. Based on learnings, pattern-based testing for transformations is acceptable when the goal is to validate code generation paths rather than numerical correctness.
📝 Suggested improvement for clarity
```diff
 @tilelang.testing.requires_cuda
 @pytest.mark.parametrize("vec_num", [4, 32])
 def test_vectorize_broadcast_int8(vec_num):
-    """Test broadcasting a non-constant int8 value to a vectorized store."""
-    vectorize_broadcast_int8.compile(vec_num=vec_num)
+    """Test that non-constant int8 broadcast generates valid CUDA code.
+
+    This test validates the codegen path for runtime int8 broadcast
+    (not compile-time constant folding) by ensuring compilation succeeds.
+    """
+    kernel = vectorize_broadcast_int8(vec_num)
+    # Optionally verify the generated code contains the expected broadcast pattern
+    code = kernel.get_kernel_source()
+    assert "0x01010101" in code, "Expected runtime byte replication pattern"
```
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- src/target/codegen_cuda.cc
- testing/python/language/test_tilelang_language_vectorize.py
🧰 Additional context used
🧠 Learnings (5)
📓 Common learnings
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1606
File: testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:30-30
Timestamp: 2026-01-06T05:20:51.649Z
Learning: In `testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py`, the test validates that the `hoist_broadcast_values` transformation pass correctly identifies and hoists broadcast operations by checking for patterns in the generated kernel source code. The specific literal values used (e.g., 430) are not important for the test's purpose, as it does not validate numerical precision or actual stored tensor values.
📚 Learning: 2026-01-06T05:20:51.649Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1606
File: testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:30-30
Timestamp: 2026-01-06T05:20:51.649Z
Learning: In `testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py`, the test validates that the `hoist_broadcast_values` transformation pass correctly identifies and hoists broadcast operations by checking for patterns in the generated kernel source code. The specific literal values used (e.g., 430) are not important for the test's purpose, as it does not validate numerical precision or actual stored tensor values.
Applied to files:
testing/python/language/test_tilelang_language_vectorize.py
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.
Applied to files:
testing/python/language/test_tilelang_language_vectorize.py
📚 Learning: 2025-11-03T06:24:11.411Z
Learnt from: Rachmanino
Repo: tile-ai/tilelang PR: 1175
File: src/op/math.cc:44-52
Timestamp: 2025-11-03T06:24:11.411Z
Learning: In tilelang's `src/op/math.cc`, the `infinity_op` function uses `std::numeric_limits<float>::infinity()` as a placeholder for all float types (including float64 and bfloat16). The codegen layer (PrintConst:Inf) handles the correct infinity value based on the dtype field of the FloatImm node, so the specific C++ template argument doesn't matter.
Applied to files:
src/target/codegen_cuda.cc
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.
Applied to files:
src/target/codegen_cuda.cc
🧬 Code graph analysis (1)
testing/python/language/test_tilelang_language_vectorize.py (3)
tilelang/language/kernel.py (1)
- threads (215-219)
tilelang/language/allocate.py (1)
- alloc_local (57-68)
testing/python/language/test_tilelang_language_alloc.py (1)
- alloc_var (5-22)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
🔇 Additional comments (3)
src/target/codegen_cuda.cc (3)
3199-3217: LGTM - Compile-time constant folding and runtime broadcast for int8x4. The implementation correctly:
- Folds compile-time constants by replicating the byte value across all 4 positions
- Uses runtime multiplication by `0x01010101` to broadcast a single byte to all 4 positions
- Handles both signed and unsigned variants appropriately
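As a hedged illustration of this path (the helper name and exact form are hypothetical, not the code the backend emits), the runtime broadcast amounts to:

```cuda
// Hypothetical sketch of the runtime int8x4 broadcast described above:
// a non-constant byte is masked to 8 bits and multiplied by 0x01010101
// so the same byte lands in all four byte lanes of a 32-bit result.
__device__ int broadcast_i8x4(signed char v) {
  unsigned int b = static_cast<unsigned int>(v) & 0xFFu;  // avoid sign-extension bleed
  return static_cast<int>(b * 0x01010101u);               // b replicated into every byte
}
```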
3222-3248: Runtime lambda captures by reference but may evaluate `val` multiple times. The lambda approach is reasonable for complex construction, but using `[&]()` with `val` (which is `PrintExpr(op->value)`) already evaluated as a string should be safe here. The pattern correctly builds a 64-bit value from the byte and replicates it across all 4 components of the `longlong4`.
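For context, a minimal sketch of what this construction amounts to (identifiers are illustrative; the actual expression is assembled as a string by the codegen):

```cuda
// Hypothetical sketch of the int8x32 runtime broadcast: the byte is replicated
// into a 64-bit word inside an immediately-invoked lambda, then copied into all
// four components of a longlong4 (32 bytes total).
__device__ longlong4 broadcast_i8x32(signed char v) {
  long long word = [&]() {
    unsigned long long b = static_cast<unsigned long long>(v) & 0xFFull;
    return static_cast<long long>(b * 0x0101010101010101ull);  // 8 copies of the byte
  }();
  return make_longlong4(word, word, word, word);
}
```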
3316-3389: 4-bit integer broadcast implementation looks correct. The constant folding and runtime broadcast logic for 4-bit integers correctly handles:
- 4 lanes: 16-bit result with nibble replication via `0x1111`
- 8 lanes: 32-bit result with nibble replication via `0x11111111`
- 16/32 lanes: multiple 32-bit values via lambda construction

One minor observation: the lambda for lanes 16/32 at line 3388 has a semicolon inside the parentheses (`); }())`), which is syntactically valid but slightly unusual. The generated CUDA code will work correctly.
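To make the nibble-replication idea concrete, here is a hedged sketch for the 8-lane case (the function name is hypothetical; the real codegen emits an inline expression):

```cuda
// Hypothetical sketch of the runtime int4x8 broadcast: the low 4 bits of the
// value are replicated into all eight nibbles of a 32-bit word via 0x11111111.
__device__ unsigned int broadcast_i4x8(int v) {
  unsigned int nibble = static_cast<unsigned int>(v) & 0xFu;  // keep only the low nibble
  return nibble * 0x11111111u;                                // nibble in every 4-bit lane
}
```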
Enhance CUDA code generation for BroadcastNode by implementing compile-time constant folding and runtime broadcasting for various lane configurations. Improved handling for 4-bit and 8-bit integer types, ensuring correct replication and type casting in output expressions. This update increases performance and correctness in CUDA kernel generation.
Summary by CodeRabbit
Improvements
- Enhanced CUDA broadcast code generation with compile-time constant folding and runtime broadcasting for 4-bit and 8-bit integer values.

Tests
- Added a CUDA-specific int8 vectorize broadcast test with a parametrized runner.