
Conversation

LJC00118 (Collaborator) commented Oct 29, 2025

Changes:

  • Fixed T.Parallel vectorization from local to local (see the sketch below)
  • Support 256-bit vectorization for Cast operations
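
Below is a minimal sketch (assumed, not taken from this PR's diff) of the kind of pattern the first item targets: a T.Parallel loop that casts between two register-resident buffers. The kernel name, shapes, thread count, and the use of fragment scope are illustrative only.

```python
import tilelang
import tilelang.language as T


@tilelang.jit
def local_to_local_cast(M: int, dtype_A: str, dtype_B: str):

    @T.prim_func
    def main(
            A: T.Tensor[(M,), dtype_A],
            B: T.Tensor[(M,), dtype_B],
    ):
        with T.Kernel(1, threads=128):
            A_local = T.alloc_fragment((M,), dtype_A)
            B_local = T.alloc_fragment((M,), dtype_B)
            T.copy(A, A_local)
            for i in T.Parallel(M):
                # local -> local cast: with this change the parallel loop can be
                # vectorized, lowering to 256-bit (8-lane) conversions such as
                # paired __float22half2_rn calls when the types allow it.
                B_local[i] = A_local[i].astype(dtype_B)
            T.copy(B_local, B)

    return main
```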

Summary by CodeRabbit

  • Improvements

    • Expanded CUDA vector conversions to support 8-lane operations across more numeric type pairs.
    • Added a deterministic tie-break in layout inference and broadened vectorization detection for additional optimization cases.
  • Tests

    • Added a parallel vectorized cast kernel test and strengthened validation to compare multiple vectorized implementations.

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

coderabbitai bot (Contributor) commented Oct 29, 2025

Walkthrough

Expanded CUDA CastNode codegen to add explicit 8-lane vector conversion paths for multiple type pairs and removed emission of device runtime asserts; added deterministic tie-break and broader vectorization trigger in layout inference; added a parallel vectorized cast test kernel and adjusted tensor shape annotations.

Changes

CUDA codegen vectorization (src/target/codegen_cuda.cc):
Added explicit 8-lane vectorized CastNode handling for multiple from/to type pairs by composing per-lane unpack/store sequences (reusing the 2-/4-lane patterns); retained the existing 2- and 4-lane paths. Removed emission of device_assert / device_assert_with_msg in EvaluateNode/codegen end paths.

Layout inference adjustments (src/transform/layout_inference.cc):
Added a deterministic tie-break in BufferUseDefCollector that prefers the smaller attempt_infer_root when register counts tie; expanded the vectorization trigger to also fire when non-local casts exist (if there is no reducer).

Vectorized cast tests & API tweak (testing/python/language/test_tilelang_language_vectorized_cast.py):
Added parallel_vectorized_cast_kernel (a second @tilelang.jit kernel); changed the tensor shape annotations from (M) to (M,); updated run_vectorized_cast to compile and run both kernels, allocate an extra output, and assert that both outputs match and that both kernel sources contain the expected vectorization checks (see the sketch below).

Sequence Diagram(s)

sequenceDiagram
  participant Test as Test runner
  participant JIT as @tilelang.jit
  participant KernelA as kernel (vectorized_cast_kernel)
  participant KernelB as kernel_parallel (parallel_vectorized_cast_kernel)
  participant Device as Device runtime

  Test->>JIT: compile KernelA
  Test->>JIT: compile KernelB
  JIT-->>Test: code_A, binary_A
  JIT-->>Test: code_B, binary_B
  Test->>Device: launch binary_A with A -> B
  Device-->>Test: B_out
  Test->>Device: launch binary_B with A -> C
  Device-->>Test: C_out
  Test->>Test: compare B_out == expected && C_out == expected
  Note over Test,Device: assert both code_A and code_B contain vectorization checks

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Focus review on src/target/codegen_cuda.cc for correctness of 8-lane unpack/store sequences across all type pairs and potential alignment/reinterpret-cast issues.
  • Verify removal of device_assert* emissions is intentional and covered by tests.
  • Check determinism and side effects of tie-break logic in src/transform/layout_inference.cc.
  • Confirm test updates in testing/python/... correctly validate both kernels and that shape annotation changes don't break test harnesses.

Possibly related PRs

Suggested reviewers

  • LeiWang1999
  • xwhzz
  • tzj-fxz

Poem

🐰 I hopped through lanes of eight and four,
I stitched the casts and checked the core,
I nudged the roots to pick a side,
I ran two kernels, side-by-side,
Vector hops — a CUDA score!

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)

Docstring Coverage (⚠️ Warning): Docstring coverage is 25.00% which is insufficient. The required threshold is 80.00%. Resolution: you can run @coderabbitai generate docstrings to improve docstring coverage.

✅ Passed checks (2 passed)

Description Check (✅ Passed): Check skipped - CodeRabbit’s high-level summary is enabled.

Title Check (✅ Passed): The pull request title "[Enhancement] Enhance Cast operations Vectorization" directly aligns with the main objective and changes in the PR. The changes across all three files focus on enhancing vectorization for cast operations: adding 8-lane vectorized cast handling in CUDA codegen, expanding vectorization triggers for non-local casts in layout inference, and adding comprehensive tests for vectorized cast kernels. The title is concise, specific to cast operations vectorization, and avoids vague terminology. A teammate reviewing the repository history would immediately understand that this PR improves vectorization support for type conversion operations.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ca8497a and 8902bbf.

📒 Files selected for processing (1)
  • src/target/codegen_cuda.cc (5 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/target/codegen_cuda.cc
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
  • GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/transform/layout_inference.cc (1)

645-671: Fix typo in fragment buffer scope checks ("local.framgent" → "local.fragment").

The typo disables the intended ICHECK for fragment buffer layouts. Fix all 3 occurrences in the documentation and code condition:

   * @brief Visit and mutate a Block node to attach inferred layout information.
   *
   * Converts the visited Block via the base visitor, asserts that every buffer
-  * allocated with scope "local.framgent" has an inferred layout in
+  * allocated with scope "local.fragment" has an inferred layout in
   * result_.layout_map, and attaches result_.layout_map to the Block's
   * annotations under attr::kLayoutMap.
   *
-  * If any "local.framgent" buffer lacks an entry in result_.layout_map an
+  * If any "local.fragment" buffer lacks an entry in result_.layout_map an
   * ICHECK will fail with the offending buffer printed.
   *
   * @return Stmt The (possibly modified) Block statement with the layout-map
   * annotation set.
   */
  Stmt VisitStmt_(const BlockNode *op) final {
    Block block = Downcast<Block>(IRMutatorWithAnalyzer::VisitStmt_(op));

    for (auto buffer : block->alloc_buffers) {
-      if (buffer.scope() == "local.framgent") {
+      if (buffer.scope() == "local.fragment") {
         ICHECK(result_.layout_map.count(buffer))
             << "Cannot inference fragment layout for " << buffer;
       }
     }
🧹 Nitpick comments (5)
src/transform/layout_inference.cc (2)

600-603: Deterministic tie-breaker: good addition; consider documenting.

Choosing the smaller attempt_infer_root on equal register usage makes selection stable. Please add a brief comment noting this deterministic tie-break. Optionally sort members before iteration for readability, though not required.

-      // Update the best plan if this one uses fewer registers
+      // Update the best plan if this one uses fewer registers.
+      // Tie-break deterministically by smaller attempt_infer_root.
       if (reg_num < min_reg_num ||
           (reg_num == min_reg_num &&
            attempt_infer_root < min_reg_num_infer_root)) {

792-804: Cast-triggered vectorization: broaden detection beyond top-level stores.

Current check only catches a CastNode as the direct BufferStore value; nested casts (e.g., within Select/Min) are missed. Scan for any CastNode in the loop body instead.

-      bool has_cast_operations = false;
-      PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
-        if (const auto* store = obj.as<BufferStoreNode>()) {
-          // Check if this is a non-reducer store with Cast operation
-          if (store->value.as<CastNode>()) {
-            has_cast_operations = true;
-          }
-        }
-      });
+      bool has_cast_operations = false;
+      PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
+        if (!has_cast_operations && obj.as<CastNode>()) {
+          has_cast_operations = true;
+        }
+      });

Add/adjust a unit test with an expression like B[i] = cast(f(min(A[i], 1.0))) to ensure vectorization still triggers.
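
Below is a hypothetical kernel for such a test, a sketch only: the kernel name and the exact expression are illustrative, and the T.min / T.Cast usage is assumed from standard TVMScript. The cast is nested inside T.min rather than being the direct BufferStore value, so the old top-level-only check misses it while the broadened scan still enables vectorization.

```python
import tilelang
import tilelang.language as T


@tilelang.jit
def nested_cast_kernel(M: int, dtype_A: str, dtype_B: str):

    @T.prim_func
    def main(
            A: T.Tensor[(M,), dtype_A],
            B: T.Tensor[(M,), dtype_B],
    ):
        with T.Kernel(1, threads=128):
            A_local = T.alloc_fragment((M,), dtype_A)
            B_local = T.alloc_fragment((M,), dtype_B)
            T.copy(A, A_local)
            for i in T.Parallel(M):
                # The CastNode sits inside T.min, not as the store value itself,
                # so only a scan over the whole loop body detects it.
                B_local[i] = T.min(A_local[i].astype(dtype_B),
                                   T.Cast(dtype_B, 1.0))
            T.copy(B_local, B)

    return main
```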

src/target/codegen_cuda.cc (1)

923-937: 8‑lane vectorized cast paths: solid, mirror existing 2/4‑lane logic; add tests.

Implementations for half8↔float8, bf16x8↔float8, and fp32x8→fp8x8 look correct and consistent with 2/4‑lane patterns. Please:

  • Add tests that exercise lanes=8 for these paths to lock behavior.
  • Optional: assert or comment on alignment/aliasing assumptions for the float2/half2 and bfloat162 pointer casts.

Add cases like:

  • run_vectorized_cast("float16","float32","__half22float2",8)
  • run_vectorized_cast("float32","float16","__float22half2_rn",8)
  • run_vectorized_cast("bfloat16","float32","__bfloat1622float2",8)
  • run_vectorized_cast("float32","bfloat16","__float22bfloat162_rn",8)
  • run_vectorized_cast("float32","float8_e4m3","__nv_cvt_float2_to_fp8x2",8)
  • run_vectorized_cast("float32","float8_e5m2","__nv_cvt_float2_to_fp8x2",8)

Consider adding a brief comment near each 8-lane branch noting aliasing relies on NVCC’s permissiveness, mirroring the 4‑lane patterns.

Also applies to: 959-973, 1001-1019, 1041-1055, 1091-1117

testing/python/language/test_tilelang_language_vectorized_cast.py (2)

29-48: Parallel kernel OK; consider covering true local→local as well.

This exercises local.fragment↔local.fragment. To fully validate the “Local‑to‑Local” vectorization fix, add a variant using scope="local".

+@tilelang.jit
+def parallel_vectorized_cast_kernel_local(M: int, dtype_A: str, dtype_B: str):
+    assert M % 256 == 0
+    @T.prim_func
+    def main(
+        A: T.Tensor[(M,), dtype_A],
+        B: T.Tensor[(M,), dtype_B],
+    ):
+        with T.Kernel(1, threads=128):
+            A_local = T.alloc_buffer((M,), dtype_A, scope="local")
+            B_local = T.alloc_buffer((M,), dtype_B, scope="local")
+            T.copy(A, A_local)
+            for i in T.Parallel(M):
+                B_local[i] = A_local[i]
+            T.copy(B_local, B)
+    return main

Call this kernel in run_vectorized_cast and assert parity with the others.


59-77: Add lanes=8 cases to lock new 256‑bit vectorization paths.

Extend tests to cover the new 8‑lane branches and check the same vectorization strings.

 def test_vectorized_cast():
   # fp32 -> fp16
   run_vectorized_cast("float32", "float16", "__float22half2_rn", 2)
   run_vectorized_cast("float32", "float16", "__float22half2_rn", 4)
+  run_vectorized_cast("float32", "float16", "__float22half2_rn", 8)
   # fp16 -> fp32
   run_vectorized_cast("float16", "float32", "__half22float2", 2)
   run_vectorized_cast("float16", "float32", "__half22float2", 4)
+  run_vectorized_cast("float16", "float32", "__half22float2", 8)
   # fp32 -> fp8_e4m3
   run_vectorized_cast("float32", "float8_e4m3", "__nv_cvt_float2_to_fp8x2", 2)
   run_vectorized_cast("float32", "float8_e4m3", "__nv_cvt_float2_to_fp8x2", 4)
+  run_vectorized_cast("float32", "float8_e4m3", "__nv_cvt_float2_to_fp8x2", 8)
   # fp32 -> fp8_e5m2
   run_vectorized_cast("float32", "float8_e5m2", "__nv_cvt_float2_to_fp8x2", 2)
   run_vectorized_cast("float32", "float8_e5m2", "__nv_cvt_float2_to_fp8x2", 4)
+  run_vectorized_cast("float32", "float8_e5m2", "__nv_cvt_float2_to_fp8x2", 8)
   # fp32 -> bf16
   run_vectorized_cast("float32", "bfloat16", "__float22bfloat162_rn", 2)
   run_vectorized_cast("float32", "bfloat16", "__float22bfloat162_rn", 4)
+  run_vectorized_cast("float32", "bfloat16", "__float22bfloat162_rn", 8)
   # bf16 -> fp32
   run_vectorized_cast("bfloat16", "float32", "__bfloat1622float2", 2)
   run_vectorized_cast("bfloat16", "float32", "__bfloat1622float2", 4)
+  run_vectorized_cast("bfloat16", "float32", "__bfloat1622float2", 8)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 198f22b and 2758fb4.

📒 Files selected for processing (3)
  • src/target/codegen_cuda.cc (5 hunks)
  • src/transform/layout_inference.cc (2 hunks)
  • testing/python/language/test_tilelang_language_vectorized_cast.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
testing/python/language/test_tilelang_language_vectorized_cast.py (4)
tilelang/language/copy.py (1)
  • copy (11-87)
tilelang/jit/__init__.py (1)
  • jit (233-306)
tilelang/language/allocate.py (1)
  • alloc_fragment (59-70)
tilelang/language/parallel.py (1)
  • Parallel (9-29)
🪛 Ruff (0.14.2)
testing/python/language/test_tilelang_language_vectorized_cast.py

20-20: Unused noqa directive (unused: F821)

Remove unused noqa directive

(RUF100)


21-21: Unused noqa directive (unused: F821)

Remove unused noqa directive

(RUF100)


35-35: Unused noqa directive (unused: F821)

Remove unused noqa directive

(RUF100)


36-36: Unused noqa directive (unused: F821)

Remove unused noqa directive

(RUF100)

Comment on lines +20 to +21
A: T.Tensor[(M,), dtype_A], # noqa: F821
B: T.Tensor[(M,), dtype_B], # noqa: F821

⚠️ Potential issue | 🟡 Minor

Remove unused noqa F821 directives.

They’re unnecessary with (M,) annotations and trip Ruff.

-            A: T.Tensor[(M,), dtype_A],  # noqa: F821
-            B: T.Tensor[(M,), dtype_B],  # noqa: F821
+            A: T.Tensor[(M,), dtype_A],
+            B: T.Tensor[(M,), dtype_B],

Apply the same change in the parallel kernel.

Also applies to: 35-36


🤖 Prompt for AI Agents
In testing/python/language/test_tilelang_language_vectorized_cast.py around
lines 20-21 (and also lines 35-36), remove the redundant "# noqa: F821"
annotations on the T.Tensor[(M,), dtype_...] lines — they are unnecessary with
the (M,) annotations and duplicate Ruff coverage; do the same edit in the
parallel kernel test file equivalent locations as noted.

