[Enhancement] Enhance Cast operations Vectorization #1156
Conversation
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the project's formatting and lint checks before pushing. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Walkthrough
Expanded the CUDA CastNode codegen to add explicit 8-lane vector conversion paths for multiple type pairs and removed emission of device runtime asserts; added a deterministic tie-break and a broader vectorization trigger in layout inference; added a parallel vectorized cast test kernel and adjusted tensor shape annotations.
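For orientation, the new parallel cast test kernel has roughly the shape sketched below. This is a reconstruction based on the diffs reviewed later in this page, so the thread configuration and exact body may differ from the actual test file.

```python
import tilelang
import tilelang.language as T


@tilelang.jit
def parallel_vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str):
    assert M % 256 == 0

    @T.prim_func
    def main(
            A: T.Tensor[(M,), dtype_A],
            B: T.Tensor[(M,), dtype_B],
    ):
        with T.Kernel(1, threads=128):
            # Fragment-scoped staging buffers; the element-wise loop below is
            # what the layout-inference pass should vectorize.
            A_local = T.alloc_fragment((M,), dtype_A)
            B_local = T.alloc_fragment((M,), dtype_B)
            T.copy(A, A_local)
            for i in T.Parallel(M):
                B_local[i] = A_local[i]  # implicit cast dtype_A -> dtype_B
            T.copy(B_local, B)

    return main
```

The review comments below cover both this fragment-scoped variant and a suggested local-scoped one.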
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Test as Test runner
    participant JIT as @tilelang.jit
    participant KernelA as kernel (vectorized_cast_kernel)
    participant KernelB as kernel_parallel (parallel_vectorized_cast_kernel)
    participant Device as Device runtime
    Test->>JIT: compile KernelA
    Test->>JIT: compile KernelB
    JIT-->>Test: code_A, binary_A
    JIT-->>Test: code_B, binary_B
    Test->>Device: launch binary_A with A -> B
    Device-->>Test: B_out
    Test->>Device: launch binary_B with A -> C
    Device-->>Test: C_out
    Test->>Test: compare B_out == expected && C_out == expected
    Note over Test,Device: assert both code_A and code_B contain vectorization checks
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/transform/layout_inference.cc (1)
645-671: Fix typo in fragment buffer scope checks ("local.framgent" → "local.fragment").
The typo disables the intended ICHECK for fragment buffer layouts. Fix all 3 occurrences in the documentation and code condition:
```diff
  * @brief Visit and mutate a Block node to attach inferred layout information.
  *
  * Converts the visited Block via the base visitor, asserts that every buffer
- * allocated with scope "local.framgent" has an inferred layout in
+ * allocated with scope "local.fragment" has an inferred layout in
  * result_.layout_map, and attaches result_.layout_map to the Block's
  * annotations under attr::kLayoutMap.
  *
- * If any "local.framgent" buffer lacks an entry in result_.layout_map an
+ * If any "local.fragment" buffer lacks an entry in result_.layout_map an
  * ICHECK will fail with the offending buffer printed.
  *
  * @return Stmt The (possibly modified) Block statement with the layout-map
  *         annotation set.
  */
 Stmt VisitStmt_(const BlockNode *op) final {
   Block block = Downcast<Block>(IRMutatorWithAnalyzer::VisitStmt_(op));
   for (auto buffer : block->alloc_buffers) {
-    if (buffer.scope() == "local.framgent") {
+    if (buffer.scope() == "local.fragment") {
       ICHECK(result_.layout_map.count(buffer))
           << "Cannot inference fragment layout for " << buffer;
     }
   }
```
🧹 Nitpick comments (5)
src/transform/layout_inference.cc (2)
600-603: Deterministic tie-breaker: good addition; consider documenting.
Choosing the smaller attempt_infer_root on equal register usage makes selection stable. Please add a brief comment noting this deterministic tie-break. Optionally sort members before iteration for readability, though not required.
```diff
-    // Update the best plan if this one uses fewer registers
+    // Update the best plan if this one uses fewer registers.
+    // Tie-break deterministically by smaller attempt_infer_root.
     if (reg_num < min_reg_num ||
         (reg_num == min_reg_num &&
          attempt_infer_root < min_reg_num_infer_root)) {
```
792-804: Cast-triggered vectorization: broaden detection beyond top-level stores.
The current check only catches a CastNode as the direct BufferStore value; nested casts (e.g., within Select/Min) are missed. Scan for any CastNode in the loop body instead.
```diff
-  bool has_cast_operations = false;
-  PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
-    if (const auto* store = obj.as<BufferStoreNode>()) {
-      // Check if this is a non-reducer store with Cast operation
-      if (store->value.as<CastNode>()) {
-        has_cast_operations = true;
-      }
-    }
-  });
+  bool has_cast_operations = false;
+  PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
+    if (!has_cast_operations && obj.as<CastNode>()) {
+      has_cast_operations = true;
+    }
+  });
```

Add/adjust a unit test with an expression like B[i] = cast(f(min(A[i], 1.0))) to ensure vectorization still triggers.
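As one hedged sketch of such a regression test (swapping the min for a multiply to keep dtypes simple, and assuming float-to-float dtypes such as float16 → float32), the kernel below places the Cast inside a larger expression. The kernel name is illustrative, not part of the PR.

```python
import tilelang
import tilelang.language as T


@tilelang.jit
def nested_cast_kernel(M: int, dtype_A: str, dtype_B: str):
    assert M % 256 == 0

    @T.prim_func
    def main(
            A: T.Tensor[(M,), dtype_A],
            B: T.Tensor[(M,), dtype_B],
    ):
        with T.Kernel(1, threads=128):
            A_local = T.alloc_fragment((M,), dtype_A)
            B_local = T.alloc_fragment((M,), dtype_B)
            T.copy(A, A_local)
            for i in T.Parallel(M):
                # The Cast sits under a Mul rather than being the top-level
                # store value, so only the broadened PostOrderVisit scan above
                # marks this loop as cast-bearing and eligible for vectorization.
                B_local[i] = A_local[i].astype(dtype_B) * 2
            T.copy(B_local, B)

    return main
```

A run helper could then compare the output against twice the input cast to the target dtype and assert that the expected conversion intrinsic still appears in the generated CUDA.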
src/target/codegen_cuda.cc (1)
923-937: 8‑lane vectorized cast paths: solid, mirror existing 2/4‑lane logic; add tests.
Implementations for half8↔float8, bf16x8↔float8, and fp32x8→fp8x8 look correct and consistent with 2/4‑lane patterns. Please:
- Add tests that exercise lanes=8 for these paths to lock behavior.
- Optional: assert or comment on alignment/aliasing assumptions for the float2/half2 and bfloat162 pointer casts.
Add cases like:
- run_vectorized_cast("float16","float32","__half22float2",8)
- run_vectorized_cast("float32","float16","__float22half2_rn",8)
- run_vectorized_cast("bfloat16","float32","__bfloat1622float2",8)
- run_vectorized_cast("float32","bfloat16","__float22bfloat162_rn",8)
- run_vectorized_cast("float32","float8_e4m3","__nv_cvt_float2_to_fp8x2",8)
- run_vectorized_cast("float32","float8_e5m2","__nv_cvt_float2_to_fp8x2",8)
Consider adding a brief comment near each 8-lane branch noting aliasing relies on NVCC’s permissiveness, mirroring the 4‑lane patterns.
Also applies to: 959-973, 1001-1019, 1041-1055, 1091-1117
testing/python/language/test_tilelang_language_vectorized_cast.py (2)
29-48: Parallel kernel OK; consider covering true local→local as well.
This exercises local.fragment↔local.fragment. To fully validate the "Local‑to‑Local" vectorization fix, add a variant using scope="local".
```diff
+@tilelang.jit
+def parallel_vectorized_cast_kernel_local(M: int, dtype_A: str, dtype_B: str):
+    assert M % 256 == 0
+
+    @T.prim_func
+    def main(
+            A: T.Tensor[(M,), dtype_A],
+            B: T.Tensor[(M,), dtype_B],
+    ):
+        with T.Kernel(1, threads=128):
+            A_local = T.alloc_buffer((M,), dtype_A, scope="local")
+            B_local = T.alloc_buffer((M,), dtype_B, scope="local")
+            T.copy(A, A_local)
+            for i in T.Parallel(M):
+                B_local[i] = A_local[i]
+            T.copy(B_local, B)
+    return main
```

Call this kernel in run_vectorized_cast and assert parity with the others.
59-77: Add lanes=8 cases to lock new 256‑bit vectorization paths.
Extend tests to cover the new 8‑lane branches and check the same vectorization strings.
```diff
 def test_vectorized_cast():
     # fp32 -> fp16
     run_vectorized_cast("float32", "float16", "__float22half2_rn", 2)
     run_vectorized_cast("float32", "float16", "__float22half2_rn", 4)
+    run_vectorized_cast("float32", "float16", "__float22half2_rn", 8)
     # fp16 -> fp32
     run_vectorized_cast("float16", "float32", "__half22float2", 2)
     run_vectorized_cast("float16", "float32", "__half22float2", 4)
+    run_vectorized_cast("float16", "float32", "__half22float2", 8)
     # fp32 -> fp8_e4m3
     run_vectorized_cast("float32", "float8_e4m3", "__nv_cvt_float2_to_fp8x2", 2)
     run_vectorized_cast("float32", "float8_e4m3", "__nv_cvt_float2_to_fp8x2", 4)
+    run_vectorized_cast("float32", "float8_e4m3", "__nv_cvt_float2_to_fp8x2", 8)
     # fp32 -> fp8_e5m2
     run_vectorized_cast("float32", "float8_e5m2", "__nv_cvt_float2_to_fp8x2", 2)
     run_vectorized_cast("float32", "float8_e5m2", "__nv_cvt_float2_to_fp8x2", 4)
+    run_vectorized_cast("float32", "float8_e5m2", "__nv_cvt_float2_to_fp8x2", 8)
     # fp32 -> bf16
     run_vectorized_cast("float32", "bfloat16", "__float22bfloat162_rn", 2)
     run_vectorized_cast("float32", "bfloat16", "__float22bfloat162_rn", 4)
+    run_vectorized_cast("float32", "bfloat16", "__float22bfloat162_rn", 8)
     # bf16 -> fp32
     run_vectorized_cast("bfloat16", "float32", "__bfloat1622float2", 2)
     run_vectorized_cast("bfloat16", "float32", "__bfloat1622float2", 4)
+    run_vectorized_cast("bfloat16", "float32", "__bfloat1622float2", 8)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- src/target/codegen_cuda.cc (5 hunks)
- src/transform/layout_inference.cc (2 hunks)
- testing/python/language/test_tilelang_language_vectorized_cast.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
testing/python/language/test_tilelang_language_vectorized_cast.py (4)
tilelang/language/copy.py (1)
- copy (11-87)

tilelang/jit/__init__.py (1)
- jit (233-306)

tilelang/language/allocate.py (1)
- alloc_fragment (59-70)

tilelang/language/parallel.py (1)
- Parallel (9-29)
🪛 Ruff (0.14.2)
testing/python/language/test_tilelang_language_vectorized_cast.py
20-20: Unused noqa directive (unused: F821)
Remove unused noqa directive
(RUF100)
21-21: Unused noqa directive (unused: F821)
Remove unused noqa directive
(RUF100)
35-35: Unused noqa directive (unused: F821)
Remove unused noqa directive
(RUF100)
36-36: Unused noqa directive (unused: F821)
Remove unused noqa directive
(RUF100)
```python
            A: T.Tensor[(M,), dtype_A],  # noqa: F821
            B: T.Tensor[(M,), dtype_B],  # noqa: F821
```
Remove unused noqa F821 directives.
They’re unnecessary with (M,) annotations and trip Ruff.
```diff
-            A: T.Tensor[(M,), dtype_A],  # noqa: F821
-            B: T.Tensor[(M,), dtype_B],  # noqa: F821
+            A: T.Tensor[(M,), dtype_A],
+            B: T.Tensor[(M,), dtype_B],
```

Apply the same change in the parallel kernel.
Also applies to: 35-36
🧰 Tools
🪛 Ruff (0.14.2)
20-20: Unused noqa directive (unused: F821)
Remove unused noqa directive
(RUF100)
21-21: Unused noqa directive (unused: F821)
Remove unused noqa directive
(RUF100)
🤖 Prompt for AI Agents
In testing/python/language/test_tilelang_language_vectorized_cast.py around lines 20-21 (and also lines 35-36), remove the redundant "# noqa: F821" annotations on the T.Tensor[(M,), dtype_...] lines; they are unnecessary with the (M,) annotations and duplicate Ruff coverage. Apply the same edit to the parallel kernel at the equivalent locations in the same file.
* Enhance Cast vectorized
* Add Parallel vectorized cast test
* code lint
* merge newest commit