
[Refactor] Reduce direct dependency on PyTorch due to its limited type support#1444

Merged
LeiWang1999 merged 5 commits into tile-ai:main from LeiWang1999:torch_1216
Dec 16, 2025

Conversation


@LeiWang1999 (Member) commented Dec 16, 2025

This pull request introduces improved support for FP4 and FP8 types in CUDA code generation, adds a new header for FP4 types, and refactors vector load/store utilities for better generalization and future extensibility. The changes enhance the handling of sub-byte floating-point types, enable larger vector widths, and provide more robust and extensible code for future type additions.

CUDA code generation improvements:

  • Refactored FP4/FP6/FP8 type handling: Renamed type helper functions (e.g., GetFP4Type to GetTileLangFP4Type) and expanded support for more vector widths (2, 4, 8, 16, 32, 64) and new FP8 types (e.g., e8m0fnu). Improved error messages for unsupported types and vector widths.
  • Enhanced buffer load/store logic: Updated the logic for vectorized loads and stores to correctly handle sub-byte types (like FP4) with various lane counts, including new logic for matching ramp patterns and supporting larger vector widths.
  • Improved code generation output: Now includes the correct headers for FP4 types when needed and adds more informative logging for debugging vector store operations.

CUDA utility and type definitions:

  • Added a new header cuda_fp4.h: Defines new FP4 types and vector structures (e.g., fp4_e2_2_t, fp4_e2_4_t, ..., fp4_e2_64_t) and provides utility functions for packing/unpacking these types, mirroring the FP8 handling; a short host-side packing sketch follows this list.
  • Extended FP8 support: Added new FP8 type definitions (e.g., fp8_e8_t, fp8_e8_2_t, ..., fp8_e8_32_t) and their associated pack/unpack logic in cuda_fp8.h.
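
To make the packing concrete, here is a small host-side sketch in plain Python (not the CUDA header itself) of how two FP4 E2M1 codes share a byte. The helper names are hypothetical illustrations rather than functions from cuda_fp4.h, and the low-nibble-first ordering is an assumption consistent with the shift-by-4*i extraction discussed later in this thread.

# Hypothetical host-side illustration of FP4 nibble packing; the real packing
# lives in the device-side structs (fp4_e2_2_t, ..., fp4_e2_64_t).
def pack_fp4_pair(lo: int, hi: int) -> int:
    """Pack two 4-bit FP4 codes (0..15) into one byte, low nibble first."""
    assert 0 <= lo < 16 and 0 <= hi < 16
    return (hi << 4) | lo

def unpack_fp4_pair(byte: int) -> tuple:
    """Recover the two 4-bit codes from a packed byte."""
    return byte & 0xF, (byte >> 4) & 0xF

packed = pack_fp4_pair(0b0101, 0b1110)  # two arbitrary FP4 bit patterns
assert unpack_fp4_pair(packed) == (0b0101, 0b1110)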

Vector load/store utility refactoring:

  • Generalized 256-bit load/store functions: Refactored copy_sm100.h to provide generic 256-bit load/store templates for arbitrary types, including specializations for FP8 and new support for FP4 types. Improved overloads to handle both const and non-const references; a host-side layout sketch follows.
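
As a rough illustration of the layout assumption behind a type-generic 256-bit transfer, the following host-side PyTorch snippet (illustrative only, not the copy_sm100.h device code) views 32 bytes of packed one-byte codes as four 64-bit lanes, mirroring the ulonglong4-shaped return value described in the reviews below.

# Host-side sketch: 32 bytes (256 bits) of packed low-precision data can be
# reinterpreted as four 64-bit words without touching the element type.
import torch

payload = torch.arange(32, dtype=torch.uint8)      # 32 one-byte codes, e.g. FP8 values
as_u64 = payload.view(torch.int64)                 # reinterpret as four 64-bit lanes
assert as_u64.numel() == 4                         # 256 bits / 64 bits per lane
assert payload.numel() * 8 == as_u64.numel() * 64  # same total bit width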

These changes collectively make the codebase more robust and extensible for emerging low-precision floating-point types and vectorized operations on CUDA.

Summary by CodeRabbit

  • New Features

    • FP4 CUDA support added (packed vector types and 2/4/8/16/32/64-lane helpers) and an additional FP8 variant.
    • 256-bit generic load/store utilities for wider vector transfers.
  • Refactor

    • Unified low-precision type handling and vectorization logic for FP4/FP6/FP8, improving mixed-type load/store and ramp/vector decisions.
  • Bug Fixes / Compatibility

    • Dtype handling harmonized across runtimes/adapters; FP4 storage accepted where compatible.
  • Tests

    • Copy tests expanded for FP4/FP8; one matmul cache test removed.


… torch_dtype conversion method

- Changed dtype in KernelParam from torch.dtype to tvm.DataType to support a wider range of data types and prevent information loss during conversions.
- Added a new method, torch_dtype, to convert tvm.DataType back to torch.dtype for tensor creation.
- Updated various adapters to utilize the new torch_dtype method for parameter type conversion during initialization.
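
A minimal sketch of the conversion idea follows, assuming a simple name-based mapping from tvm.DataType to torch.dtype; the helper and its fallback choices (for example, uint8 storage for float4) are illustrative assumptions, not the project's actual torch_dtype implementation.

# Illustrative only: map a tvm.DataType to a torch dtype by name, with a
# storage fallback for sub-byte types that PyTorch cannot represent natively.
import torch

_TORCH_DTYPE_BY_NAME = {
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
    "int8": torch.int8,
    "uint8": torch.uint8,
    "int32": torch.int32,
}

def torch_dtype_from_tvm(tvm_dtype):
    name = str(tvm_dtype)
    if name.startswith("float8_e4m3"):
        return torch.float8_e4m3fn   # available in recent PyTorch releases
    if name.startswith("float8_e5m2"):
        return torch.float8_e5m2
    if name.startswith("float4"):
        return torch.uint8           # assumed storage fallback: packed 4-bit codes
    return _TORCH_DTYPE_BY_NAME[name]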
… FP8 types

- Renamed functions for clarity: GetFP8Type, GetFP6Type, and GetFP4Type are now GetTileLangFP8Type, GetTileLangFP6Type, and GetTileLangFP4Type respectively.
- Enhanced FP4 type handling to support additional lane sizes (2, 4, 8, 16, 32, 64).
- Updated CUDA code generation to include new FP8 and FP4 types, ensuring proper type handling in PrintType and related functions.
- Introduced new structures for FP8 types in cuda_fp8.h to facilitate better memory management and type packing.
- Added methods in KernelParam and tensor utilities to recognize and handle float4 types, improving compatibility with PyTorch.
- Enhanced logging for debugging purposes in various CUDA functions to track type handling and memory operations more effectively.
@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai bot (Contributor) commented Dec 16, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Refactored FP8/FP6/FP4 type helpers to TileLang-prefixed names; added FP4 CUDA types and 256-bit load/store templates; extended CUDA codegen for sub-byte (FP4/FP8) vectorization and BufferStore handling; adjusted runtime dtype acceptance for FP4 storage; switched KernelParam.dtype to tvm.DataType with torch_dtype()/is_float4() accessors.

Changes

  • CUDA Codegen Core (src/target/codegen_cuda.cc, src/target/codegen_cuda.h): Renamed FP helper functions to TileLang-prefixed variants; updated PrintType and vector element load/store to support FP4/FP8/FP16/bfloat16 mixtures and lanes up to 64; added ramp-based lane calculations and extra safety checks; added a VisitStmt_(const BufferStoreNode *op) override; include cuda_fp4.h when enable_fp4_ is set.
  • CUDA Type Templates (src/tl_templates/cuda/copy_sm100.h, src/tl_templates/cuda/cuda_fp4.h, src/tl_templates/cuda/cuda_fp8.h): Added generic 256-bit ld/st templates and replaced FP8-specific specializations in copy_sm100.h; added full FP4 wrapper types, vector packs, and factory helpers in cuda_fp4.h; added fp8_e8 types, hierarchical packing structs, assignment from ulonglong4, and make_fp8_e8_N_t helpers in cuda_fp8.h.
  • Runtime dtype binding (src/transform/arg_binder.cc): Extended BindDLTensor dtype-compatibility checks to treat FP4 storage as compatible with int8-like representations (accept float4 as a compatible alternative) to avoid dtype-mismatch errors for sub-byte storage.
  • Vectorization constraints (src/transform/loop_vectorize.cc): Tightened the vector width calculation in UpdateVectorSize to divide by buffer->dtype.bits() * buffer->dtype.lanes(), accounting for lane count for sub-byte types.
  • Kernel param / adapters (tilelang/engine/param.py, tilelang/jit/adapter/ctypes/adapter.py, tilelang/jit/adapter/cython/cython_wrapper.pyx, tilelang/jit/adapter/nvrtc/adapter.py, tilelang/jit/adapter/tvm_ffi.py): Changed KernelParam.dtype from torch.dtype to tvm.DataType; added is_float4() and torch_dtype() methods; adapters and wrappers now call param.torch_dtype() when constructing cached param dtype lists.
  • Tensor utilities & supplies (tilelang/utils/tensor.py): map_torch_type accepts non-string inputs and recognizes FP8/FP4 variants (maps float4 fallback to storage int8); get_tensor_supply uses param.torch_dtype(); integer supply paths include is_float4-aware generation ranges.
  • Tests & examples (testing/python/language/test_tilelang_language_copy.py, testing/python/cache/test_tilelang_cache_matmul.py): Expanded copy tests to cover FP8 and FP4 (new test helpers and cases, dtype parameters split into src/dst); removed an older cache matmul test file.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Areas requiring extra attention:

  • src/target/codegen_cuda.cc: ramp lane arithmetic, PrintVecElemLoad/Store logic for mixed sub-byte types, and buffer-ref packing/division factors.
  • src/tl_templates/cuda/{cuda_fp4.h,cuda_fp8.h,copy_sm100.h}: struct alignments, assignment-from-ulonglong4 semantics, and correctness of generic 256-bit templates.
  • src/transform/arg_binder.cc: runtime dtype-compatibility changes for FP4 storage — confirm no unintended acceptances.
  • KernelParam / adapters: tvm.DataType ↔ torch.dtype conversions (torch_dtype()) used consistently across adapters.


Poem

🐰
I nibble bits both small and spry,
FP4 and FP8 now leap and fly.
Packets, lanes, a tidy store—
256-bit dreams and one-byte more.
Hooray — the tiles hop, codegen winks an eye!

Pre-merge checks

❌ Failed checks (1 warning)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 30.56%, which is insufficient; the required threshold is 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The title accurately captures the main refactoring objective: reducing PyTorch dependency by supporting low-precision types that PyTorch has limited support for.


@LeiWang1999 (Member, Author) commented:

@codex review

@chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 742 to 745
} else if (t.is_float4_e2m1fn()) {
  os << "([](__nv_fp4_storage_t v) { __nv_fp4_e2m1 t; t.__x = v; return t; "
        "})(("
     << vec << ".__x >> " << i * 4 << ") & 0xF)";


P1: Update FP4 vector element access for new struct types

The FP4 type printer now emits custom structs such as fp4_e2_2_t/fp4_e2_4_t from cuda_fp4.h, whose members are named x, y, etc., but PrintVecElemLoad still extracts elements via vec.__x and bit‑shifts as if the value were a packed __nv_fp4*_e2m1 storage object. With the new struct definitions lacking a __x field, any codegen path that needs to read individual lanes from an FP4 vector will fail to compile (e.g., vector buffer loads/stores of float4 types), blocking FP4 kernels. The element access logic needs to be updated to match the new struct layout.


@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/target/codegen_cuda.cc (1)

145-174: Update error message to reflect actual supported widths.

The code supports vector widths 2, 4, 8, and 16 (lines 151-158), but the error message at lines 160-161 claims "Only support scalar and vector types of width (2, 4) for FP6".

-    LOG(FATAL)
-        << "Only support scalar and vector types of width (2, 4) for FP6";
+    LOG(FATAL)
+        << "Only support scalar and vector types of width (2, 4, 8, 16) for FP6";
📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between 869f021 and 8fc0268.

📒 Files selected for processing (11)
  • src/target/codegen_cuda.cc (11 hunks)
  • src/target/codegen_cuda.h (1 hunks)
  • src/tl_templates/cuda/copy_sm100.h (3 hunks)
  • src/tl_templates/cuda/cuda_fp4.h (1 hunks)
  • src/tl_templates/cuda/cuda_fp8.h (3 hunks)
  • tilelang/engine/param.py (5 hunks)
  • tilelang/jit/adapter/ctypes/adapter.py (2 hunks)
  • tilelang/jit/adapter/cython/cython_wrapper.pyx (1 hunks)
  • tilelang/jit/adapter/nvrtc/adapter.py (2 hunks)
  • tilelang/jit/adapter/tvm_ffi.py (1 hunks)
  • tilelang/utils/tensor.py (4 hunks)
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2025-09-12T09:47:46.474Z
Learnt from: kurisu6912
Repo: tile-ai/tilelang PR: 794
File: tilelang/transform/add_bufstore_wrapper.py:30-33
Timestamp: 2025-09-12T09:47:46.474Z
Learning: In TVM's PyStmtExprMutator, visit_block_ methods typically call super().visit_block_(op) to process child nodes and update internal state, but return the original op when the block itself doesn't need transformation. The pattern `return op` is correct for blocks that serve as containers where mutations happen at deeper levels.

Applied to files:

  • src/target/codegen_cuda.h
  • src/target/codegen_cuda.cc
📚 Learning: 2025-11-03T06:24:11.411Z
Learnt from: Rachmanino
Repo: tile-ai/tilelang PR: 1175
File: src/op/math.cc:44-52
Timestamp: 2025-11-03T06:24:11.411Z
Learning: In tilelang's `src/op/math.cc`, the `infinity_op` function uses `std::numeric_limits<float>::infinity()` as a placeholder for all float types (including float64 and bfloat16). The codegen layer (PrintConst:Inf) handles the correct infinity value based on the dtype field of the FloatImm node, so the specific C++ template argument doesn't matter.

Applied to files:

  • src/target/codegen_cuda.cc
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.

Applied to files:

  • src/tl_templates/cuda/cuda_fp4.h
🧬 Code graph analysis (5)
src/target/codegen_cuda.h (1)
src/transform/warp_specialized_rewriter.cc (10)
  • op (38-43)
  • op (38-38)
  • op (73-83)
  • op (73-73)
  • op (85-93)
  • op (85-85)
  • op (95-100)
  • op (95-95)
  • op (102-110)
  • op (102-102)
tilelang/jit/adapter/ctypes/adapter.py (1)
tilelang/engine/param.py (1)
  • torch_dtype (127-141)
tilelang/engine/param.py (2)
src/tl_templates/cuda/common.h (1)
  • DataType (214-259)
tilelang/utils/tensor.py (1)
  • map_torch_type (35-51)
tilelang/utils/tensor.py (1)
tilelang/engine/param.py (4)
  • torch_dtype (127-141)
  • is_float4 (103-113)
  • is_unsigned (79-89)
  • is_float8 (91-101)
tilelang/jit/adapter/tvm_ffi.py (1)
tilelang/engine/param.py (1)
  • torch_dtype (127-141)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
🔇 Additional comments (34)
src/target/codegen_cuda.h (1)

60-60: LGTM!

The new VisitStmt_ override for BufferStoreNode follows the established visitor pattern in this class and is consistent with the existing VisitExpr_ override for BufferLoadNode on line 59.

tilelang/jit/adapter/ctypes/adapter.py (2)

79-80: LGTM!

The change correctly uses param.torch_dtype() to convert from TVM's DataType to PyTorch dtype, aligning with the updated KernelParam API.


143-144: LGTM!

Consistent with the __init__ method change, ensuring both initialization paths use the same dtype conversion.

tilelang/jit/adapter/tvm_ffi.py (1)

138-139: LGTM!

The change correctly converts TVM DataType to PyTorch dtype for tensor creation, consistent with other adapter implementations.

tilelang/jit/adapter/nvrtc/adapter.py (2)

55-56: LGTM!

Consistent with the dtype handling refactor across adapters.


122-123: LGTM!

The from_database classmethod is updated consistently with the __init__ method.

tilelang/jit/adapter/cython/cython_wrapper.pyx (1)

35-36: LGTM!

The Cython wrapper is updated consistently with the Python adapters to use param.torch_dtype().

tilelang/utils/tensor.py (4)

47-49: LGTM!

Using torch.uint8 as storage for float4 is appropriate since PyTorch lacks native float4 support. This aligns with the CUDA-side FP4 handling in the PR.


59-60: LGTM!

Consistent with the dtype handling refactor across the codebase.


81-88: LGTM!

The float4 tensor generation using randint(low=0, high=16, ...) correctly produces 4-bit values (0-15) stored in uint8. The logic placement after is_float8 check ensures proper type priority.
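
For illustration, a host-side equivalent of that supply path (shape and device are assumptions here, not the tilelang code):

# Assumed illustration: FP4 codes 0..15 generated into uint8 storage.
import torch

fp4_codes = torch.randint(low=0, high=16, size=(128, 128), dtype=torch.uint8)
assert int(fp4_codes.max()) < 16  # every value fits in a 4-bit lane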


105-112: LGTM!

Consistent with the Auto supply type handling above.

tilelang/engine/param.py (5)

19-23: LGTM!

Good documentation explaining the rationale for using tvm.DataType directly to preserve full type information including specialized types like float8 and float4.


40-42: LGTM!

Directly using buffer.dtype preserves the TVM DataType without premature conversion.


65-67: LGTM!

Consistent with from_buffer - uses var.dtype directly.


103-113: LGTM!

The is_float4() method follows the same pattern as is_float8() and other type-checking methods.


127-141: LGTM!

The torch_dtype() method provides a clean API for converting TVM DataType to PyTorch dtype on-demand, enabling consistent tensor creation across all adapters.

src/tl_templates/cuda/copy_sm100.h (4)

8-24: LGTM!

The 256-bit load implementations for longlong4 and ulonglong4 use correct PTX v4 load instructions.


26-34: LGTM!

The generic template load enables 256-bit loads for FP8/FP4 types by returning the data as ulonglong4.


36-59: LGTM!

The 256-bit store overloads cover both signed and unsigned 64-bit vector types, with both const and non-const reference variants. The comment explaining the need for const &val is helpful.


61-77: Pointer aliasing pattern is acceptable for CUDA device code.

The cast on line 72 (ulonglong4 &val_u64 = *((ulonglong4 *)&val)) assumes T has the same 256-bit layout as ulonglong4. The code assertion at the call site (ICHECK_EQ(t.bits() * t.lanes(), 256)) enforces this constraint. FP8 types passed to this template (e.g., fp8_e4_32_t, fp8_e5_32_t) are 32-byte structures with matching layout and use the same pointer aliasing pattern in their own operator= implementations, confirming this is the established design pattern in the codebase.

src/tl_templates/cuda/cuda_fp8.h (3)

81-114: LGTM!

The fp8_e8 struct definitions correctly mirror the existing fp8_e4 and fp8_e5 patterns with appropriate alignment and hierarchical composition.


234-291: LGTM!

The make_fp8_e8_* factory functions are correctly implemented following the established pattern for fp8_e4 and fp8_e5 types.


9-9: No action required. The __nv_fp8_e8m0 type is a valid, documented CUDA type available in cuda_fp8.h (supported in CUDA 11.8+). The type alias is correct and properly supported.

src/tl_templates/cuda/cuda_fp4.h (5)

5-7: LGTM!

The CUDA architecture guard (__CUDA_ARCH__ >= 800) appropriately restricts FP4 support to Ampere (SM 8.0) and newer architectures.


8-18: LGTM!

Type aliases correctly map to NVIDIA FP4 types, with fallback to array-based structs for wider vector types (x8, x16) where native CUDA types aren't available.


55-58: LGTM, but depends on fixing fp4_e2_32_t.

The fp4_e2_64_t struct definition is correct with proper size and alignment (32 bytes). However, if code relies on the broken fp4_e2_32_t::operator=, issues will cascade to this type.


61-117: LGTM!

The make_fp4_e2_* factory functions correctly pack FP4 values into composite types following established patterns. These functions are independent of the operator= issue.


119-119: LGTM!

The #endif correctly closes the CUDA architecture guard.

src/target/codegen_cuda.cc (6)

110-143: LGTM!

The rename to GetTileLangFP8Type and addition of float8_e8m0fnu support correctly extends FP8 type handling. The updated vector suffix convention (_2, _4, _8, _16, _32) aligns with the new header definitions.


176-208: LGTM!

The GetTileLangFP4Type function correctly generates type names matching the cuda_fp4.h definitions, with proper support for lanes 2, 4, 8, 16, 32, and 64.


288-290: LGTM!

The conditional inclusion of cuda_fp4.h correctly follows the established pattern for FP8 support.


742-748: LGTM!

The FP4 element load correctly extracts individual 4-bit values from packed storage using bit manipulation and constructs proper __nv_fp4_e2m1 instances.


2918-2922: LGTM!

The ramp_lanes adjustment correctly handles sub-byte types where multiple elements are packed into a single byte, ensuring proper vectorization pattern matching.
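
A hedged sketch of the adjustment being described (the helper and the exact formula are assumptions for illustration, not a copy of codegen_cuda.cc):

# Illustration: a ramp over N logical sub-byte values spans fewer packed
# buffer elements when the buffer's element dtype already carries several lanes.
def packed_ramp_lanes(value_lanes: int, element_lanes: int) -> int:
    assert value_lanes % element_lanes == 0
    return value_lanes // element_lanes

# e.g. 16 FP4 values over a float4x2 buffer (2 lanes per byte) -> a ramp of 8 elements
assert packed_ramp_lanes(16, 2) == 8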


2968-3029: LGTM!

The new VisitStmt_(BufferStoreNode) implementation correctly handles scalar, vectorized, and element-wise stores with proper sub-byte type support. The ramp pattern matching aligns with the BufferLoadNode visitor.

Comment on lines +46 to +52
TL_DEVICE fp4_e2_32_t &operator=(const ulonglong4 &rhs) {
  x.x = *(fp4_e2_8_t *)&rhs.x;
  x.y = *(fp4_e2_8_t *)&rhs.y;
  y.x = *(fp4_e2_8_t *)&rhs.z;
  y.y = *(fp4_e2_8_t *)&rhs.w;
  return *this;
}
@coderabbitai bot (Contributor) commented:

⚠️ Potential issue | 🔴 Critical

Critical: Incorrect operator= signature causes size mismatch.

The fp4_e2_32_t::operator= takes ulonglong4 (32 bytes) but the struct is only 16 bytes. Each cast *(fp4_e2_8_t *)&rhs.x reinterprets 8 bytes as 4 bytes, discarding half the data.

The correct signature should use ulonglong2:

-  TL_DEVICE fp4_e2_32_t &operator=(const ulonglong4 &rhs) {
-    x.x = *(fp4_e2_8_t *)&rhs.x;
-    x.y = *(fp4_e2_8_t *)&rhs.y;
-    y.x = *(fp4_e2_8_t *)&rhs.z;
-    y.y = *(fp4_e2_8_t *)&rhs.w;
+  TL_DEVICE fp4_e2_32_t &operator=(const ulonglong2 &rhs) {
+    x = *(fp4_e2_16_t *)&rhs.x;
+    y = *(fp4_e2_16_t *)&rhs.y;
     return *this;
   }
🤖 Prompt for AI Agents
In src/tl_templates/cuda/cuda_fp4.h around lines 46 to 52, the operator=
currently accepts a ulonglong4 (32 bytes) while the struct is 16 bytes and
reinterprets 8-byte members as 4-byte types, discarding data; change the
signature to accept ulonglong2 and copy only the two 8-byte components (rhs.x
and rhs.y), casting each to the correct fp4_e2_8_t pointer/assignment so sizes
align and remove any references to rhs.z or rhs.w.

LeiWang1999 and others added 2 commits December 16, 2025 11:18
…elete obsolete matrix multiplication test file.
- Enhanced PrintVecElemLoad and PrintVecElemStore functions to handle new FP4 types.
- Updated arg_binder to allow float4 to match int8 at runtime, improving compatibility with PyTorch.
- Modified loop_vectorize to account for buffer dtype lanes in vectorization calculations.
- Refactored tensor type mapping to support new float4 and float8 types, ensuring correct type handling in tensor operations.
- Added tests for FP4 and FP8 copy operations to validate functionality and integration with existing workflows.
@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 2

♻️ Duplicate comments (1)
src/tl_templates/cuda/cuda_fp4.h (1)

58-96: Verify struct sizes and operator= correctness.

A previous review flagged a potential size mismatch in fp4_e2_32_t::operator=(const ulonglong4&). However, the sizes appear correct upon closer analysis:

  • fp4_e2_32_t = 2 × fp4_e2_16_t = 2 × 16 = 32 bytes
  • ulonglong4 = 4 × 8 = 32 bytes ✓

The reinterpret casts also appear valid:

  • rhs.x/y/z/w are each 8 bytes (unsigned long long)
  • fp4_e2_8_t = 2 × fp4_e2_4_t = 2 × 4 = 8 bytes ✓
#!/bin/bash
# Verify struct sizes match expectations by checking member sizes
# Each fp4_e2_t contains __nv_fp4_storage_t which should be 1 byte

# Search for __nv_fp4_storage_t definition to confirm size
rg -n "__nv_fp4_storage_t" --type-add 'cuda:*.h' --type-add 'cuda:*.cuh' -g '!build/**'
🧹 Nitpick comments (1)
testing/python/language/test_tilelang_language_copy.py (1)

6-7: Remove debug print statement.

This print(torch.__version__) appears to be a debugging leftover. Consider removing it to avoid cluttering test output.

-print(torch.__version__)
-
📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between 8fc0268 and 30f7ecf.

📒 Files selected for processing (7)
  • src/target/codegen_cuda.cc (11 hunks)
  • src/tl_templates/cuda/cuda_fp4.h (1 hunks)
  • src/transform/arg_binder.cc (1 hunks)
  • src/transform/loop_vectorize.cc (1 hunks)
  • testing/python/cache/test_tilelang_cache_matmul.py (0 hunks)
  • testing/python/language/test_tilelang_language_copy.py (2 hunks)
  • tilelang/utils/tensor.py (5 hunks)
💤 Files with no reviewable changes (1)
  • testing/python/cache/test_tilelang_cache_matmul.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tilelang/utils/tensor.py
🧰 Additional context used
🧠 Learnings (4)
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.

Applied to files:

  • src/transform/loop_vectorize.cc
  • src/target/codegen_cuda.cc
  • src/tl_templates/cuda/cuda_fp4.h
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.

Applied to files:

  • src/transform/loop_vectorize.cc
  • src/tl_templates/cuda/cuda_fp4.h
📚 Learning: 2025-11-03T06:24:11.411Z
Learnt from: Rachmanino
Repo: tile-ai/tilelang PR: 1175
File: src/op/math.cc:44-52
Timestamp: 2025-11-03T06:24:11.411Z
Learning: In tilelang's `src/op/math.cc`, the `infinity_op` function uses `std::numeric_limits<float>::infinity()` as a placeholder for all float types (including float64 and bfloat16). The codegen layer (PrintConst:Inf) handles the correct infinity value based on the dtype field of the FloatImm node, so the specific C++ template argument doesn't matter.

Applied to files:

  • src/target/codegen_cuda.cc
  • src/tl_templates/cuda/cuda_fp4.h
📚 Learning: 2025-09-12T09:47:46.474Z
Learnt from: kurisu6912
Repo: tile-ai/tilelang PR: 794
File: tilelang/transform/add_bufstore_wrapper.py:30-33
Timestamp: 2025-09-12T09:47:46.474Z
Learning: In TVM's PyStmtExprMutator, visit_block_ methods typically call super().visit_block_(op) to process child nodes and update internal state, but return the original op when the block itself doesn't need transformation. The pattern `return op` is correct for blocks that serve as containers where mutations happen at deeper levels.

Applied to files:

  • src/target/codegen_cuda.cc
🧬 Code graph analysis (1)
testing/python/language/test_tilelang_language_copy.py (2)
tilelang/language/copy.py (1)
  • copy (14-95)
tilelang/testing/__init__.py (1)
  • requires_cuda_compute_version_ge (104-105)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (13)
src/transform/loop_vectorize.cc (1)

195-196: LGTM: Correctly accounts for vector lanes in the vectorization constraint.

The updated calculation properly multiplies buffer->dtype.bits() by buffer->dtype.lanes() to compute the total bits per buffer element. This ensures that for vector dtypes (e.g., float32x4), the vectorization constraint reflects the full element width rather than treating each lane as a separate element. For scalar types (lanes=1), behavior is unchanged.
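
A small numeric sketch of the constraint, assuming a 128-bit access width purely for illustration (the width is not taken from loop_vectorize.cc):

# Illustrative arithmetic: maximum vectorize factor given total bits per buffer element.
def max_vectorize_factor(access_bits: int, dtype_bits: int, dtype_lanes: int) -> int:
    # Dividing by bits * lanes treats a packed element (e.g. float4x2 = 8 bits)
    # as one unit, so sub-byte packed types no longer over-count lanes.
    return access_bits // (dtype_bits * dtype_lanes)

assert max_vectorize_factor(128, 32, 1) == 4   # float32: 4 elements per 128-bit access
assert max_vectorize_factor(128, 4, 2) == 16   # float4x2 (one packed byte): 16 elements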

src/transform/arg_binder.cc (1)

446-460: LGTM! FP4-to-int8 compatibility logic is consistent with existing patterns.

The implementation correctly mirrors the bool-to-int8 compatibility pattern above. The logic allows FP4 types to accept int8 storage (since PyTorch uses int8 as storage for FP4 packed data), and the guard at line 460 correctly skips dtype mismatch errors for FP4 types.

Minor: The comment on lines 451-452 has a grammatical hiccup ("Accept int8 with same lanes as the fp4 type" reads as incomplete). Consider rewording to clarify intent, e.g., "Accept int8 storage with matching lane count."

testing/python/language/test_tilelang_language_copy.py (2)

11-22: LGTM! Clean refactor to use T.copy.

The signature change to support separate src_dtype and dst_dtype enables more flexible type testing, and using T.copy instead of a manual element-wise loop is cleaner and aligns with the copy semantics being tested.


165-185: LGTM! FP4 test correctly uses int8 as storage.

The test correctly uses torch.int8 as storage for FP4 data, which aligns with the runtime compatibility changes in arg_binder.cc. The mixed dtype tests (FP4 → float16/bfloat16) exercise the type conversion paths.

src/target/codegen_cuda.cc (6)

110-143: LGTM! FP8 type helper extended with e8m0fnu support.

The rename to GetTileLangFP8Type and addition of is_float8_e8m0fnu() handling extends FP8 support correctly. Vector width support is comprehensive (2, 4, 8, 16, 32).


176-208: LGTM! FP4 type helper supports extended vector widths.

The implementation correctly extends FP4 support to include 32 and 64-lane vector types, matching the struct definitions in cuda_fp4.h. The naming convention (fp4 + suffix + vec + _t) is consistent.


288-290: LGTM! FP4 header inclusion wired correctly.

The enable_fp4_ flag properly gates the inclusion of the FP4 header, following the same pattern as FP8.


724-739: LGTM! FP4 vector element load access follows FP8 pattern.

The hierarchical access pattern (vec.x, vec.x.y, etc.) correctly mirrors the FP8 implementation and aligns with the struct layout in cuda_fp4.h.


843-858: LGTM! FP4 vector element store mirrors load pattern.

The store implementation correctly mirrors the load implementation, maintaining consistency in element access.


2944-3055: LGTM! BufferStoreNode visitor handles sub-byte vectorization.

The implementation correctly handles:

  1. Scalar stores when value_dtype.lanes() == element_dtype.lanes()
  2. Ramp-based vector stores with adjusted lane counts for sub-byte types
  3. Fallback element-wise stores for non-ramp indices

The ramp_lanes calculation at lines 3014-3016 correctly adjusts for sub-byte types with multiple element lanes, matching the same logic in VisitExpr_ for BufferLoadNode.

src/tl_templates/cuda/cuda_fp4.h (3)

1-6: LGTM! Header guard and CUDA architecture check.

The header correctly guards for __CUDA_ARCH__ >= 800 which is required for FP4 support.


9-45: LGTM! fp4_e2_t wrapper provides clean conversion interface.

The wrapper around __nv_fp4_e2m1 with implicit conversions to/from float, half, and storage types follows good CUDA type wrapper patterns.


98-155: LGTM! Factory functions follow consistent patterns.

The make_fp4_e2_N_t functions correctly compose larger types from smaller ones, following the hierarchical struct layout.

Comment on lines 134 to 136
} else if (type.is_float8_e5m2() || type.is_float8_e5m2fnuz() ||
           type.is_float8_e5m2()) {
  stream << "fp8_e5" << vec << "_t";
@coderabbitai bot (Contributor) commented:

⚠️ Potential issue | 🟡 Minor

Fix duplicate condition check.

Line 135 checks type.is_float8_e5m2() twice (before and after the ||). The second occurrence should likely be removed or replaced with a different check.

-  } else if (type.is_float8_e5m2() || type.is_float8_e5m2fnuz() ||
-             type.is_float8_e5m2()) {
+  } else if (type.is_float8_e5m2() || type.is_float8_e5m2fnuz()) {
🤖 Prompt for AI Agents
In src/target/codegen_cuda.cc around lines 134 to 136, there's a duplicated
condition type.is_float8_e5m2() in the OR chain; replace the second occurrence
with the correct float8 e4m3 check (e.g., type.is_float8_e4m3() or
type.is_float8_e4m3fnuz() as appropriate for the intended variant) so the
condition reads for the distinct float8 variants rather than repeating the same
check.

Comment on lines +146 to +163
def run_tilelang_copy_fp8_e8m0(M=1024, N=1024, block_M=128, block_N=128, src_dtype="float8_e8m0fnu", dst_dtype="float8_e8m0fnu"):
    program = tilelang_copy(M, N, block_M, block_N, src_dtype=src_dtype, dst_dtype=dst_dtype)
    kernel = tilelang.compile(
        program,
        out_idx=[1],
    )
    source = kernel.get_kernel_source()
    assert "fp8_e8_t" in source
    dummy_input = torch.randint(0, 100, (M, N), device="cuda", dtype=torch.int8).view(torch.float8_e8m0fnu)
    output = kernel(dummy_input)
    assert output is not None


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(10, 0)
def test_tilelang_copy_fp8_e8m0():
    run_tilelang_copy_fp8_e8m0(src_dtype="float8_e8m0fnu", dst_dtype="float8_e8m0fnu")

@coderabbitai bot (Contributor) commented:

⚠️ Potential issue | 🟠 Major


Add PyTorch version guard for torch.float8_e8m0fnu dtype support.

The torch.float8_e8m0fnu dtype was added in PyTorch 2.7.0, but this test lacks a version guard. On PyTorch versions < 2.7.0, the test will fail when attempting to access this dtype. Add @tilelang.testing.requires_package("torch", "2.7.0") or implement a version check decorator to skip the test gracefully on incompatible PyTorch versions.
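
A minimal sketch of such a guard, assuming a runtime capability check is acceptable; requires_package is the project's own decorator, while the skipif variant below is a generic alternative with a hypothetical marker name.

# Sketch: skip the fp8_e8m0 test on PyTorch builds without torch.float8_e8m0fnu.
import pytest
import torch

requires_torch_fp8_e8m0 = pytest.mark.skipif(
    not hasattr(torch, "float8_e8m0fnu"),  # dtype introduced in PyTorch 2.7.0
    reason="torch.float8_e8m0fnu requires PyTorch >= 2.7.0",
)

@requires_torch_fp8_e8m0
def test_tilelang_copy_fp8_e8m0_guarded():
    ...  # placeholder body; the real test would call run_tilelang_copy_fp8_e8m0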

🤖 Prompt for AI Agents
In testing/python/language/test_tilelang_language_copy.py around lines 146-163,
the test uses torch.float8_e8m0fnu which was introduced in PyTorch 2.7.0 so add
a PyTorch version guard to skip on older versions: annotate the test function
(and the helper if it may be called directly) with
@tilelang.testing.requires_package("torch", "2.7.0") or perform an equivalent
runtime version check that skips the test when torch.__version__ < "2.7.0".

@LeiWang1999 (Member, Author) commented:

@codex review

@chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines +1 to +6
#pragma once

#include "common.h"

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
#include <cuda_fp4.h>


P1: Float4 types hidden from host compilation

The new cuda_fp4.h wraps every fp4 definition inside #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) (lines 1‑6). NVCC compiles each CU file twice: a host pass where __CUDA_ARCH__ is not defined and a device pass where it is. Because this header becomes empty in the host pass, any generated kernel that includes it (enabled by enable_fp4_ in codegen_cuda.cc) will have undefined parameter types on the host side, causing compilation of fp4 kernels to fail even on supported GPUs. The guard should not remove the type definitions from the host compilation path.

