
[Enhancement] Refactor CUDA vectorized cast generation and remove unsupported FP8 type #1474

Merged
LeiWang1999 merged 6 commits into tile-ai:main from LJC00118:qwq-1
Dec 22, 2025

Conversation

@LJC00118 (Collaborator) commented Dec 19, 2025

Summary by CodeRabbit

  • Refactor

    • Consolidated and vectorized GPU cast emission and detection across many numeric types, reducing duplication and improving consistency for vectorized casts.
  • New Features

    • Added compact FP4x2 conversion utilities and broader FP8/FP4 vectorized conversion support.
    • Improved runtime dtype handling and logging for float4/float8 storage variants.
  • Tests

    • Added explicit CUDA/ROCm FP8/FP4 test coverage and simplified test compilation invocation.


@github-actions bot commented

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run `pre-commit run --all-files` in the root directory of the project so your changes are properly linted and formatted. This helps your contribution pass the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai bot commented Dec 19, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Centralizes CUDA cast vectorization: adds IsCudaVectorizableFP8/IsCudaVectorizableCast predicates, introduces a PrintVectorizedCast emitter, refactors fp16/bf16/fp8/fp4 ↔ float32 cast emission to use vectorized paths while preserving scalar fallbacks, and updates related tests and dtype mappings. (32 words)

Changes

  • Vectorization helpers (src/target/utils.h, src/target/utils.cc): Added IsCudaVectorizableFP8(DataType) and IsCudaVectorizableCast(DataType, DataType) predicates to detect CUDA-vectorizable FP8 types and cast pairs (fp16, bf16, fp8, fp4 rules).
  • CUDA codegen refactor (src/target/codegen_cuda.cc): Added #include "utils.h"; narrowed FP8 dispatch in GetTileLangFP8Type; added a PrintVectorizedCast helper and replaced many per-lane CastNode emissions with vectorized emission calls for fp16/bf16/fp8/fp4 ↔ float32 (lanes 2/4/8) with type-suffix logic; retained scalar fallbacks.
  • Layout inference (src/transform/layout_inference.cc): Replaced ad-hoc per-field cast checks with IsCudaVectorizableCast(from_ty, target_ty) and made minor variable renames for clarity when detecting cast operations.
  • CUDA FP4 helpers (src/tl_templates/cuda/cuda_fp4.h): Added fp4x2 conversion helpers: __tl_cvt_fp4x2_to_half2, __tl_cvt_fp4x2_to_float2, __tl_cvt_half2_to_fp4x2, __tl_cvt_float2_to_fp4x2.
  • Tests / Python harness (testing/python/debug/test_tilelang_debug_print.py, testing/python/language/test_tilelang_language_vectorized_cast.py): Simplified the tilelang.compile call in one test; converted vectorized-cast tests to use T.dtype objects (removing the string map), updated data conversions/assertions, added CUDA/ROCm-guarded FP8/FP4 test variants, and added an early skip for float4_e2m1fn.
  • Dtype mapping (tilelang/language/v2/dtypes.py): Adjusted Torch dtype mapping names for float8/float4 storage and added an informative log branch for float4_e2m1fn mapping to torch.float4_e2m1fn_x2.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

  • Files needing extra attention:
    • src/target/codegen_cuda.cc — lane-width gating, FP8 type_suffix selection, vectorized vs scalar correctness.
    • src/target/utils.cc/.h — completeness and symmetry of IsCudaVectorizableCast rules.
    • src/tl_templates/cuda/cuda_fp4.h — PTX conversion correctness, packing/unpacking, and endianness.
    • Tests — dtype-to-torch mapping changes and CUDA/ROCm gating.

Suggested reviewers

  • xwhzz

Poem

🐇
I nibble bits and hop the lanes,
I braid the casts through vector chains.
FP8, fp4, half and float,
I pack, I cast — then cheer and gloat. 🥕

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 30.77%, which is below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: The title directly and clearly summarizes the main change: refactoring CUDA vectorized cast generation and removing an unsupported FP8 type, which aligns with the substantive changes across multiple files.

📜 Recent review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f96a988 and e68e073.

📒 Files selected for processing (5)
  • src/target/codegen_cuda.cc
  • src/target/utils.cc
  • src/tl_templates/cuda/cuda_fp4.h
  • testing/python/language/test_tilelang_language_vectorized_cast.py
  • tilelang/language/v2/dtypes.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/tl_templates/cuda/cuda_fp4.h
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-03T06:24:11.411Z
Learnt from: Rachmanino
Repo: tile-ai/tilelang PR: 1175
File: src/op/math.cc:44-52
Timestamp: 2025-11-03T06:24:11.411Z
Learning: In tilelang's `src/op/math.cc`, the `infinity_op` function uses `std::numeric_limits<float>::infinity()` as a placeholder for all float types (including float64 and bfloat16). The codegen layer (PrintConst:Inf) handles the correct infinity value based on the dtype field of the FloatImm node, so the specific C++ template argument doesn't matter.

Applied to files:

  • src/target/codegen_cuda.cc
🧬 Code graph analysis (2)
testing/python/language/test_tilelang_language_vectorized_cast.py (3)
tilelang/language/v2/dtypes.py (7)
  • dtype (14-15)
  • float4_e2m1fn (385-385)
  • float32 (295-295)
  • as_torch (15-15)
  • float8_e4m3fn (336-336)
  • float8_e5m2 (350-350)
  • float16 (294-294)
examples/bitnet-1.58b/eval_utils.py (1)
  • device (97-98)
tilelang/testing/__init__.py (1)
  • requires_cuda_compute_version_ge (104-105)
src/target/codegen_cuda.cc (1)
src/target/utils.cc (2)
  • IsCudaVectorizableFP8 (137-140)
  • IsCudaVectorizableFP8 (137-137)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
🔇 Additional comments (8)
testing/python/language/test_tilelang_language_vectorized_cast.py (2)

61-62: Early return prevents FP4 testing.

Lines 61-62 skip execution for float4_e2m1fn types. This prevents validation of FP4 conversions in tests, though the kernel source is still checked (Line 59). Confirm this is intentional and related to torch dtype availability issues flagged in dtypes.py.


95-126: LGTM: Well-organized test structure.

The separation of test variants (standard, FP8, FP4) with appropriate CUDA compute version gates is clear and maintainable. FP8 tests require CUDA 8.9+ and FP4 tests require CUDA 10.0+, which aligns with hardware capabilities.

src/target/utils.cc (2)

137-140: LGTM: Correct FP8 type coverage for CUDA.

The function correctly identifies FP8 types supported for CUDA vectorized operations: e4m3, e4m3fn, and e5m2. The intentional exclusion of e4m3fnuz (AMD/HIP variant) is consistent with CUDA-specific vectorization paths in codegen_cuda.cc.


142-176: LGTM: Comprehensive vectorizable cast coverage.

The function correctly identifies all CUDA-vectorizable cast pairs:

  • Standard precision conversions (float16, bfloat16, float32)
  • FP8 conversions (delegated to IsCudaVectorizableFP8)
  • FP4 conversions (float4_e2m1fn ↔ float32)

Logic is symmetric for bidirectional casts and aligns with hardware capabilities.
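For illustration, here is a minimal, self-contained C++ sketch of the predicate logic described above. It uses a stand-in ElemKind enum rather than TVM's DataType, so the names are illustrative; the actual rules in src/target/utils.cc may differ in detail.

enum class ElemKind {
  kFloat16, kBFloat16, kFloat32,
  kFloat8E4M3, kFloat8E4M3FN, kFloat8E5M2,
  kFloat8E4M3FNUZ, kFloat4E2M1FN, kOther
};

// FP8 variants with CUDA vectorized-cast support (the AMD-oriented fnuz type is excluded).
inline bool IsCudaVectorizableFP8Sketch(ElemKind k) {
  return k == ElemKind::kFloat8E4M3 || k == ElemKind::kFloat8E4M3FN ||
         k == ElemKind::kFloat8E5M2;
}

// Cast pairs the CUDA backend can lower to 2-wide conversion intrinsics.
inline bool IsCudaVectorizableCastSketch(ElemKind from, ElemKind to) {
  auto is_f32 = [](ElemKind k) { return k == ElemKind::kFloat32; };
  auto is_f16_or_bf16 = [](ElemKind k) {
    return k == ElemKind::kFloat16 || k == ElemKind::kBFloat16;
  };
  // fp16/bf16 <-> fp32
  if ((is_f16_or_bf16(from) && is_f32(to)) || (is_f32(from) && is_f16_or_bf16(to)))
    return true;
  // fp8 <-> fp32 (delegated to the FP8 predicate)
  if ((IsCudaVectorizableFP8Sketch(from) && is_f32(to)) ||
      (is_f32(from) && IsCudaVectorizableFP8Sketch(to)))
    return true;
  // fp4 (e2m1) <-> fp32
  if ((from == ElemKind::kFloat4E2M1FN && is_f32(to)) ||
      (is_f32(from) && to == ElemKind::kFloat4E2M1FN))
    return true;
  return false;
}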

src/target/codegen_cuda.cc (3)

975-995: LGTM: Well-designed vectorization helper.

The PrintVectorizedCast lambda provides a clean abstraction for emitting vectorized cast operations. It correctly handles:

  • Reinterpret vs. regular casts via bool flags
  • Extra arguments for intrinsic-specific parameters
  • Chunked processing (pairs of lanes)

This significantly reduces code duplication across multiple cast paths.
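As a rough illustration of the pattern (not the repository's actual lambda, whose parameters are codegen-specific), the self-contained sketch below emits one 2-wide conversion call per pair of lanes, with optional reinterpret_cast on either side and optional extra intrinsic arguments:

#include <sstream>
#include <string>

// Emits generated source text such as:
//   ((float2*)(&dst))[0] = __half22float2(((half2*)(&src))[0]);
//   ((float2*)(&dst))[1] = __half22float2(((half2*)(&src))[1]);
// for a hypothetical float16x4 -> float32x4 cast.
std::string PrintVectorizedCastSketch(const std::string& dst, const std::string& src,
                                      int lanes, const std::string& cast_func,
                                      const std::string& src2_ty, const std::string& dst2_ty,
                                      const std::string& extra_args,
                                      bool reinterpret_src, bool reinterpret_dst) {
  std::ostringstream os;
  const int num_chunks = lanes / 2;  // each intrinsic converts one pair of lanes
  for (int i = 0; i < num_chunks; ++i) {
    const std::string idx = "[" + std::to_string(i) + "]";
    const std::string src_elem =
        (reinterpret_src ? "reinterpret_cast<" + src2_ty + "*>(&" + src + ")"
                         : "((" + src2_ty + "*)(&" + src + "))") + idx;
    const std::string dst_elem =
        (reinterpret_dst ? "reinterpret_cast<" + dst2_ty + "*>(&" + dst + ")"
                         : "((" + dst2_ty + "*)(&" + dst + "))") + idx;
    os << dst_elem << " = " << cast_func << "(" << src_elem;
    if (!extra_args.empty()) os << ", " << extra_args;
    os << ");\n";
  }
  return os.str();
}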


997-1103: LGTM: Comprehensive vectorized cast coverage.

The vectorized cast implementations correctly handle all supported dtype pairs:

  • Standard precision: float16, bfloat16, float32 (2/4/8 lanes)
  • FP8: e4m3/e4m3fn/e5m2 ↔ float32 with type_suffix logic
  • FP4: float4_e2m1fn ↔ float16/float32 with uint8_t storage

All paths use the PrintVectorizedCast helper consistently and check for 2/4/8 lane support.


1104-1116: LGTM: Proper fallback for non-vectorizable casts.

The element-wise fallback (lines 1104-1113) ensures correctness for unsupported vector widths or dtype combinations. This maintains backward compatibility while enabling vectorized paths where available.

tilelang/language/v2/dtypes.py (1)

194-196: Verify the log message text and attribute name match.

Line 195's log message mentions "float4_e2m1fnx2" but should reference "float4_e2m1fn_x2" (with underscore) to match the actual dtype attribute being returned on line 196. While torch.float4_e2m1fn_x2 is a valid PyTorch dtype, the log message contains inconsistent naming that could confuse users about the correct attribute name. Update the message to use the correct underscore-separated name: "float4_e2m1fn_x2".



@coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
src/target/codegen_cuda.cc (1)

975-995: Excellent refactoring with the generate_vector_conversion helper.

The lambda function consolidates repetitive per-lane vector conversion logic into a single reusable helper, significantly improving code maintainability and reducing duplication. The parameterization is well-designed to handle different conversion scenarios.

For future consideration, you might extract this lambda as a class method to improve testability and potentially reuse it across other conversion contexts, though the current implementation is acceptable.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3516f1e and a60139f.

📒 Files selected for processing (4)
  • src/target/codegen_cuda.cc (3 hunks)
  • src/target/utils.cc (1 hunks)
  • src/target/utils.h (1 hunks)
  • src/transform/layout_inference.cc (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
src/transform/layout_inference.cc (2)
src/target/utils.cc (4)
  • IsCudaVectorizableCast (142-167)
  • IsCudaVectorizableCast (142-142)
  • TargetIsCuda (14-16)
  • TargetIsCuda (14-14)
tilelang/language/kernel.py (1)
  • Current (134-140)
src/target/codegen_cuda.cc (1)
src/target/utils.cc (2)
  • IsCudaVectorizableFP8 (137-140)
  • IsCudaVectorizableFP8 (137-137)
src/target/utils.h (1)
src/target/utils.cc (4)
  • IsCudaVectorizableFP8 (137-140)
  • IsCudaVectorizableFP8 (137-137)
  • IsCudaVectorizableCast (142-167)
  • IsCudaVectorizableCast (142-142)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
🔇 Additional comments (6)
src/target/utils.h (1)

33-34: LGTM! Clean function declarations.

The new helper functions centralize FP8 vectorization and cast validation logic, improving code maintainability.

src/target/utils.cc (1)

137-167: LGTM! Well-structured helper functions.

The implementations correctly:

  • Identify vectorizable FP8 types (E4M3, E4M3FN, E5M2) excluding unsupported variants
  • Cover all bidirectional cast combinations for float16, bfloat16, float32, and FP8 types
  • Provide a centralized validation point for CUDA vectorizable casts
src/target/codegen_cuda.cc (3)

18-18: LGTM! Correct include for new utilities.


132-134: Good consolidation of FP8 type handling.

Merging float8_e4m3 and float8_e4m3fn branches simplifies the code while maintaining correct type mapping. The removal of the float8_e4m3fnuz branch aligns with the PR objective to remove unsupported FP8 types.


1036-1064: Verify implementation of custom __tl_cvt_fp8x2_to_float2 function and type suffix handling.

The CUDA intrinsic __nv_cvt_float2_to_fp8x2 and the saturation constant __NV_SATFINITE are provided by CUDA's cuda_fp8.h header. However:

  • __tl_cvt_fp8x2_to_float2 is not a standard CUDA intrinsic and must be verified to exist and match the expected signature; a possible shape is sketched after this list.
  • __NV_E4M3 and __NV_E5M2 are __nv_fp8_interpretation_t enumerators from cuda_fp8.h; confirm the string-based parameter passing in the emitter matches the underlying intrinsic signatures.
  • Ensure FP8 code paths are guarded by appropriate compile-time checks for CUDA architecture >= 890.
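For reference during that verification, here is a hypothetical sketch of how a __tl_cvt_fp8x2_to_float2-style helper could be written purely from documented cuda_fp8.h/cuda_fp16.h intrinsics; the actual helper in src/tl_templates/cuda/cuda_fp8.h may be implemented differently (for example via inline PTX):

#include <cuda_fp16.h>
#include <cuda_fp8.h>

// Unpack two FP8 values (e4m3 or e5m2, selected by `interp`) to half2,
// then widen to float2. Illustrative only; not the repository's code.
__device__ __forceinline__ float2
tl_cvt_fp8x2_to_float2_sketch(__nv_fp8x2_storage_t packed,
                              __nv_fp8_interpretation_t interp) {
  __half2_raw h2r = __nv_cvt_fp8x2_to_halfraw2(packed, interp);
  return __half22float2(__half2(h2r));
}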
src/transform/layout_inference.cc (1)

1180-1183: Good refactoring to centralize cast validation logic.

The change replaces explicit per-field type checks with a call to IsCudaVectorizableCast, improving maintainability. Target::Current() is properly defined and correctly returns the current target context, as verified by consistent usage throughout the codebase.

LJC00118 changed the title from "Refactor CUDA vectorized cast generation and remove unsupported FP8 type" to "[Enhancement] Refactor CUDA vectorized cast generation and remove unsupported FP8 type" on Dec 19, 2025
@coderabbitai bot left a comment

Actionable comments posted: 2

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a60139f and 0e264c9.

📒 Files selected for processing (2)
  • src/target/codegen_cuda.cc
  • testing/python/debug/test_tilelang_debug_print.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.

Applied to files:

  • testing/python/debug/test_tilelang_debug_print.py
🧬 Code graph analysis (2)
src/target/codegen_cuda.cc (1)
src/target/utils.cc (2)
  • IsCudaVectorizableFP8 (137-140)
  • IsCudaVectorizableFP8 (137-137)
testing/python/debug/test_tilelang_debug_print.py (2)
tilelang/jit/__init__.py (2)
  • compile (49-115)
  • compile (347-373)
tilelang/language/v2/dtypes.py (3)
  • float8_e4m3fn (332-332)
  • float8_e5m2 (346-346)
  • float8_e4m3fnuz (339-339)
🔇 Additional comments (4)
testing/python/debug/test_tilelang_debug_print.py (1)

35-44: Well-structured FP8 test coverage.

The separation of CUDA and ROCm FP8 tests with appropriate hardware guards is excellent:

  • NVIDIA formats (float8_e4m3fn, float8_e5m2) are correctly isolated under @requires_cuda
  • AMD formats (float8_e4m3fnuz, float8_e5m2fnuz) are correctly isolated under @requires_rocm

This approach ensures tests only run when the required hardware is available and properly exercises the FP8 vectorization paths introduced in the PR.

src/target/codegen_cuda.cc (3)

18-18: LGTM! Clean dependency addition.

The new include provides the IsCudaVectorizableFP8 and IsCudaVectorizableCast predicates used for FP8 vectorization checks later in the file.


973-1033: Excellent refactoring to eliminate duplication!

The introduction of the PrintVectorizedCast lambda effectively consolidates the vectorized cast pattern across float16, bfloat16, and FP8 conversions. The implementation correctly:

  • Computes num_chunks = lanes / 2 matching the 2-element intrinsics
  • Uses reinterpret_cast appropriately for bfloat16 (__nv_bfloat162) types
  • Maintains consistent lane count checks (2, 4, 8) across all conversions
  • Preserves the element-wise fallback at line 1067 for other cases

1036-1064: FP8 vectorized conversions are correct.

The FP8 conversion paths properly:

  • Gate vectorization with IsCudaVectorizableFP8 checks
  • Consolidate float8_e4m3 and float8_e4m3fn to the same __NV_E4M3 constant (consistent with line 132 and justified by their identical E4M3 binary representation)
  • Use reinterpret_cast for __nv_fp8x2_storage_t storage type
  • Pass saturation mode (__NV_SATFINITE) and type constants to CUDA intrinsics

The consolidation is correct: both e4m3 and e4m3fn share the same binary layout (1 sign, 4 exponent, 3 fraction bits) and are treated as compatible by NVIDIA's CUDA ecosystem and ONNX standards.
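To make the shared E4M3 bit layout concrete, here is an illustrative software reference decode of one e4m3fn byte (1 sign, 4 exponent, 3 mantissa bits, bias 7, no infinities, NaN encoded as S.1111.111); it is independent of the CUDA intrinsics used by the codegen:

#include <cmath>
#include <cstdint>

inline float DecodeE4M3FN(uint8_t v) {
  const int sign = (v >> 7) & 0x1;
  const int exp  = (v >> 3) & 0xF;
  const int man  = v & 0x7;
  if (exp == 0xF && man == 0x7) return std::nanf("");  // only NaN pattern; no infinities
  const float mag = (exp == 0)
      ? std::ldexp(static_cast<float>(man) / 8.0f, -6)            // subnormals
      : std::ldexp(1.0f + static_cast<float>(man) / 8.0f, exp - 7);
  return sign ? -mag : mag;
}
// Example: DecodeE4M3FN(0x7E) == 448.0f, the largest finite e4m3fn value.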

@coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tilelang/language/v2/dtypes.py (1)

189-196: Fix assertion to match the correct torch attribute name with underscore.

Lines 190-192 assert on torch.float4_e2m1fnx2 (no underscore), but line 193 returns torch.float4_e2m1fn_x2 (with underscore). Update the assertion to check for the correct attribute name: torch.float4_e2m1fn_x2.

♻️ Duplicate comments (1)
src/target/codegen_cuda.cc (1)

132-134: Acknowledged: float8_e4m3fnuz removal is an intentional CUDA backend change.

Per the past review comment, the removal of float8_e4m3fnuz support from CUDA codegen is backend-specific. HIP/AMD codegen still supports fnuz. Ensure migration documentation or release notes reflect this change for CUDA users who may be using fnuz types.

🧹 Nitpick comments (2)
src/tl_templates/cuda/cuda_fp4.h (1)

197-203: Consider potential precision implications of double conversion.

The float2 → fp4x2 conversion goes through half2 as an intermediate step. While this matches the reverse path (fp4x2 → half2 → float2), the double conversion (float → half → fp4) may introduce different rounding behavior compared to a direct float → fp4 conversion. This is likely acceptable for FP4's limited precision, but worth documenting if precision guarantees matter.
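For context on how coarse the target grid is, e2m1 (FP4) has only eight magnitudes per sign, which limits (though does not eliminate) double-rounding effects. The illustrative decode below, independent of the PTX-based helpers in cuda_fp4.h, lists those magnitudes:

#include <cstdint>

// Low 4 bits: 1 sign, 2 exponent, 1 mantissa (e2m1). Representable magnitudes:
// 0, 0.5, 1, 1.5, 2, 3, 4, 6.
inline float DecodeE2M1(uint8_t nibble) {
  static const float kMag[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
  const float mag = kMag[nibble & 0x7];
  return (nibble & 0x8) ? -mag : mag;
}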

testing/python/language/test_tilelang_language_vectorized_cast.py (1)

97-100: Verify FP4 test coverage for higher lane counts.

The new FP4 test cases only use lanes=2, while other dtype conversions test both lanes=2 and lanes=4. Consider adding lanes=4 test cases for FP4 conversions to ensure the vectorized path works correctly for larger vector widths.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0e264c9 and 91da49d.

📒 Files selected for processing (5)
  • src/target/codegen_cuda.cc
  • src/target/utils.cc
  • src/tl_templates/cuda/cuda_fp4.h
  • testing/python/language/test_tilelang_language_vectorized_cast.py
  • tilelang/language/v2/dtypes.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-03T06:24:11.411Z
Learnt from: Rachmanino
Repo: tile-ai/tilelang PR: 1175
File: src/op/math.cc:44-52
Timestamp: 2025-11-03T06:24:11.411Z
Learning: In tilelang's `src/op/math.cc`, the `infinity_op` function uses `std::numeric_limits<float>::infinity()` as a placeholder for all float types (including float64 and bfloat16). The codegen layer (PrintConst:Inf) handles the correct infinity value based on the dtype field of the FloatImm node, so the specific C++ template argument doesn't matter.

Applied to files:

  • src/target/codegen_cuda.cc
🧬 Code graph analysis (3)
src/tl_templates/cuda/cuda_fp4.h (2)
src/tl_templates/cuda/common.h (1)
  • uint32_t (152-154)
src/tl_templates/cuda/cuda_fp8.h (1)
  • float2 (294-302)
src/target/codegen_cuda.cc (1)
src/target/utils.cc (2)
  • IsCudaVectorizableFP8 (137-140)
  • IsCudaVectorizableFP8 (137-137)
testing/python/language/test_tilelang_language_vectorized_cast.py (1)
tilelang/language/v2/dtypes.py (5)
  • dtype (14-15)
  • float4_e2m1fn (385-385)
  • float32 (295-295)
  • as_torch (15-15)
  • float16 (294-294)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
🔇 Additional comments (5)
src/target/utils.cc (1)

137-176: LGTM! Well-structured vectorization predicates.

The new IsCudaVectorizableFP8 and IsCudaVectorizableCast helper functions provide clean, centralized logic for determining vectorizable cast combinations. The symmetric handling of bidirectional casts (e.g., float16 ↔ float32) is consistent and easy to maintain.

src/tl_templates/cuda/cuda_fp4.h (1)

157-171: LGTM! Clean PTX-based FP4x2 conversion implementation.

The conversion uses appropriate PTX instructions (cvt.rn.f16x2.e2m1x2) and follows the established pattern from the FP8 conversion utilities in cuda_fp8.h. The intermediate register handling is correctly managed.

src/target/codegen_cuda.cc (2)

975-995: LGTM! Clean helper reduces duplication significantly.

The PrintVectorizedCast lambda centralizes the vectorized cast emission logic, handling both reinterpret_cast and C-style casts appropriately. The parameterization (cast_func, src_type, dst_type, extra_args, reinterpret flags) provides good flexibility for the various conversion paths.


1064-1102: LGTM! Complete FP4 vectorized cast coverage.

The new FP4 ↔ float16/float32 conversion paths are comprehensive, covering all four directions. The code correctly uses the __tl_cvt_* functions defined in cuda_fp4.h.

testing/python/language/test_tilelang_language_vectorized_cast.py (1)

43-62: LGTM! Cleaner API with dtype objects.

The refactoring from string-based dtype specification to direct dtype objects is a good improvement. The early return for float4_e2m1fn correctly handles the case where torch doesn't support the type directly, allowing the vectorization code path to still be verified via source inspection.

@LeiWang1999 (Member) commented:

@codex review

@chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 190 to +196 of tilelang/language/v2/dtypes.py:

        assert hasattr(torch, "float4_e2m1fnx2"), (
            "torch.float4_e2m1fnx2 is not supported in this version of torch. Please upgrade torch >= 2.8.0"
        )
-       return torch.float4_e2m1fnx2
+       return torch.float4_e2m1fn_x2
    elif dtype_str == "float4_e2m1fn":
+       logger.info("torch doesn't support float4_e2m1fn, using float4_e2m1fnx2 as storage dtype.")
+       return torch.float4_e2m1fn_x2


P2: Make the hasattr check and the returned torch dtype attribute name consistent

The new float4 mapping checks hasattr(torch, "float4_e2m1fnx2") but then returns torch.float4_e2m1fn_x2 (note the extra underscore). Since the actual attribute on recent torch is float4_e2m1fn_x2, the hasattr check fails and the assertion fires even on torch versions that do support the dtype; it affects both the float4_e2m1fnx2 path and the fallback for float4_e2m1fn. Align the assertion (and its error message) with torch.float4_e2m1fn_x2 so these dtype requests work on supported torch versions.


- Implemented conversions between float4_e2m1fn and float32, half2, and float2 in utils.cc and cuda_fp4.h.
- Updated test_tilelang_language_vectorized_cast.py to validate new conversions and ensure correctness.
- Enhanced dtype conversion in dtypes.py to handle float4_e2m1fn appropriately, logging a warning for unsupported types in PyTorch.
- Added tests for vectorized casting of float8 and float4 data types, ensuring compatibility with CUDA compute versions.
- Refactored existing test functions to improve clarity and organization, separating tests for different data types.
- Updated parameterization to include additional test cases for new conversions.
@coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (1)
testing/python/language/test_tilelang_language_vectorized_cast.py (1)

65-67: Remove duplicate assignment of A.

Lines 65 and 67 both assign A = A_float.to(src_dtype.as_torch()). The first assignment is redundant.

🔎 Proposed fix
     A_float = torch.randn(M, dtype=torch.float32, device="cuda")
-    A = A_float.to(src_dtype.as_torch())
-
     A = A_float.to(src_dtype.as_torch())
     B = torch.zeros(M, dtype=dst_dtype.as_torch(), device="cuda")
     C = torch.zeros(M, dtype=dst_dtype.as_torch(), device="cuda")
🧹 Nitpick comments (1)
testing/python/language/test_tilelang_language_vectorized_cast.py (1)

56-59: Consider printing both kernel sources for debugging consistency.

While both code and code_parallel are checked for vectorization, only code is printed. For debugging failures in the parallel kernel path, printing both would be helpful.

🔎 Proposed enhancement
     code = kernel.get_kernel_source()
     code_parallel = kernel_parallel.get_kernel_source()
     print(code)
+    print(code_parallel)
     assert check_str in code and check_str in code_parallel, f"Cast {src_dtype} to {dst_dtype} with {lanes=} is not vectorized!"
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 91da49d and f96a988.

📒 Files selected for processing (1)
  • testing/python/language/test_tilelang_language_vectorized_cast.py
🧰 Additional context used
🧬 Code graph analysis (1)
testing/python/language/test_tilelang_language_vectorized_cast.py (2)
tilelang/language/v2/dtypes.py (6)
  • dtype (14-15)
  • float4_e2m1fn (385-385)
  • as_torch (15-15)
  • float8_e4m3fn (336-336)
  • float8_e5m2 (350-350)
  • float16 (294-294)
tilelang/testing/__init__.py (1)
  • requires_cuda_compute_version_ge (104-105)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (4)
testing/python/language/test_tilelang_language_vectorized_cast.py (4)

43-50: LGTM! Type-safe dtype objects improve API clarity.

The refactor from string-based to dtype object parameters enhances type safety and aligns with the broader dtype-based invocation path introduced in this PR.


68-69: LGTM! Tensor creation and assertions correctly use dtype API.

The tensor allocations and correctness checks properly use the dtype.as_torch() method, maintaining consistency with the refactored dtype-based approach.

Also applies to: 74-75


78-127: LGTM! Well-organized test structure with appropriate version guards.

The three-tier test organization clearly separates:

  • General vectorized casts (FP16, BF16, FP8, FP32)
  • FP8-to-FP32 conversions requiring CUDA compute 8.9+
  • FP4 conversions requiring CUDA compute 10.0+

The version guards ensure tests run only on compatible hardware, and parametrization provides comprehensive coverage.


61-62: FP4 runtime execution is intentionally skipped until ml_dtypes support is available.

The early return prevents runtime kernel execution for float4_e2m1fn because ml_dtypes does not yet have the float4_e2m1fn attribute, which is needed to convert PyTorch tensors to FP4 dtype for testing. Codegen validation (confirming vectorized intrinsics are generated correctly) still runs; runtime correctness cannot be validated until the ml_dtypes library adds FP4 support.
