[Enhancement] Refactor CUDA vectorized cast generation and remove unsupported FP8 type #1474
LeiWang1999 merged 6 commits into tile-ai:main from
Conversation
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the project's pre-submission checks. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Centralizes CUDA cast vectorization: adds IsCudaVectorizableFP8/IsCudaVectorizableCast predicates, introduces a PrintVectorizedCast emitter, refactors fp16/bf16/fp8/fp4 ↔ float32 cast emission to use vectorized paths while preserving scalar fallbacks, and updates related tests and dtype mappings.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
📜 Recent review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
🚧 Files skipped from review as they are similar to previous changes (1)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-03T06:24:11.411Z
Applied to files:
🧬 Code graph analysis (2)
testing/python/language/test_tilelang_language_vectorized_cast.py (3)
src/target/codegen_cuda.cc (1)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
🔇 Additional comments (8)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Actionable comments posted: 0
🧹 Nitpick comments (1)
src/target/codegen_cuda.cc (1)
975-995: Excellent refactoring with the `generate_vector_conversion` helper. The lambda function consolidates repetitive per-lane vector conversion logic into a single reusable helper, significantly improving code maintainability and reducing duplication. The parameterization is well-designed to handle different conversion scenarios.
For future consideration, you might extract this lambda as a class method to improve testability and potentially reuse it across other conversion contexts, though the current implementation is acceptable.
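For readers who want a concrete picture, here is a minimal sketch of what such a per-chunk emitter could look like; the function name, parameter list, and string-building style are assumptions for illustration and are not the actual `codegen_cuda.cc` implementation.

```cpp
// Sketch only: names and the surrounding codegen interface are assumed, not
// copied from codegen_cuda.cc. It emits one intrinsic call per 2-lane chunk.
#include <sstream>
#include <string>

std::string EmitVectorizedCast(const std::string &cast_func, const std::string &src_type,
                               const std::string &dst_type, const std::string &src_var,
                               const std::string &dst_var, int lanes,
                               const std::string &extra_args = "",
                               bool reinterpret_dst = false) {
  std::ostringstream os;
  const int num_chunks = lanes / 2;  // each intrinsic converts a 2-element chunk
  for (int i = 0; i < num_chunks; ++i) {
    // Destination chunk: reinterpret_cast for types such as __nv_bfloat162,
    // plain C-style pointer cast otherwise.
    if (reinterpret_dst) {
      os << "reinterpret_cast<" << dst_type << "*>(&" << dst_var << ")[" << i << "]";
    } else {
      os << "((" << dst_type << "*)(&" << dst_var << "))[" << i << "]";
    }
    // Source chunk plus any trailing intrinsic arguments (e.g. a saturation mode).
    os << " = " << cast_func << "(((" << src_type << "*)(&" << src_var << "))[" << i << "]"
       << (extra_args.empty() ? "" : ", " + extra_args) << ");\n";
  }
  return os.str();
}
```

Called as, say, `EmitVectorizedCast("__float22half2_rn", "float2", "half2", "v_src", "v_dst", 4)`, it would emit two chunk assignments for a 4-lane float → half cast.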
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- src/target/codegen_cuda.cc (3 hunks)
- src/target/utils.cc (1 hunks)
- src/target/utils.h (1 hunks)
- src/transform/layout_inference.cc (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
src/transform/layout_inference.cc (2)
src/target/utils.cc (4)
IsCudaVectorizableCast (142-167), IsCudaVectorizableCast (142-142), TargetIsCuda (14-16), TargetIsCuda (14-14)
tilelang/language/kernel.py (1)
Current (134-140)
src/target/codegen_cuda.cc (1)
src/target/utils.cc (2)
IsCudaVectorizableFP8 (137-140), IsCudaVectorizableFP8 (137-137)
src/target/utils.h (1)
src/target/utils.cc (4)
IsCudaVectorizableFP8 (137-140), IsCudaVectorizableFP8 (137-137), IsCudaVectorizableCast (142-167), IsCudaVectorizableCast (142-142)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
🔇 Additional comments (6)
src/target/utils.h (1)
33-34: LGTM! Clean function declarations. The new helper functions centralize FP8 vectorization and cast validation logic, improving code maintainability.
src/target/utils.cc (1)
137-167: LGTM! Well-structured helper functions. The implementations correctly:
- Identify vectorizable FP8 types (E4M3, E4M3FN, E5M2) excluding unsupported variants
- Cover all bidirectional cast combinations for float16, bfloat16, float32, and FP8 types
- Provide a centralized validation point for CUDA vectorizable casts
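To make the described behavior concrete, here is a self-contained sketch of predicates with this shape; it uses a stand-in `DType` enum rather than the real `tvm::DataType` API, and the actual signatures in `src/target/utils.cc` may differ.

```cpp
#include <cstdint>

// Stand-in dtype enum for illustration; the real code operates on DataType objects.
enum class DType : uint8_t {
  kFloat32, kFloat16, kBFloat16,
  kFloat8E4M3, kFloat8E4M3FN, kFloat8E5M2,
  kFloat8E4M3FNUZ,  // deliberately excluded below: unsupported on the CUDA backend
};

// E4M3 / E4M3FN / E5M2 have packed CUDA conversion intrinsics; the *fnuz variants do not.
inline bool IsCudaVectorizableFP8(DType t) {
  return t == DType::kFloat8E4M3 || t == DType::kFloat8E4M3FN || t == DType::kFloat8E5M2;
}

// Bidirectional fp16/bf16/fp8 <-> fp32 casts are the vectorizable combinations.
inline bool IsCudaVectorizableCast(DType dst, DType src) {
  auto half_like = [](DType t) { return t == DType::kFloat16 || t == DType::kBFloat16; };
  auto vectorizable_pair = [&](DType a, DType b) {
    return a == DType::kFloat32 && (half_like(b) || IsCudaVectorizableFP8(b));
  };
  return vectorizable_pair(dst, src) || vectorizable_pair(src, dst);
}
```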
src/target/codegen_cuda.cc (3)
18-18: LGTM! Correct include for new utilities.
132-134: Good consolidation of FP8 type handling. Merging the `float8_e4m3` and `float8_e4m3fn` branches simplifies the code while maintaining correct type mapping. The removal of the `float8_e4m3fnuz` branch aligns with the PR objective to remove unsupported FP8 types.
1036-1064: Verify the custom `__tl_cvt_fp8x2_to_float2` function and the type-constant handling. The CUDA intrinsic `__nv_cvt_float2_to_fp8x2`, the saturation constant `__NV_SATFINITE`, and the interpretation constants `__NV_E4M3` / `__NV_E5M2` are all provided by the CUDA `<cuda_fp8.h>` header. However:
- `__tl_cvt_fp8x2_to_float2` is not a standard CUDA intrinsic; verify that it exists in the TileLang CUDA templates and matches the expected signature.
- Confirm that the type constants emitted as strings by the codegen line up with the enum values the underlying CUDA intrinsics expect.
- Ensure FP8 code paths are guarded by appropriate compile-time checks for CUDA architecture >= 890.
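For context, the emitted pattern being reviewed corresponds roughly to the following hand-written CUDA; only the `<cuda_fp8.h>` intrinsic and enum names are standard, while the wrapper, variable names, and the architecture guard (shown to mirror the request above) are assumptions.

```cpp
#include <cuda_fp8.h>  // __nv_cvt_float2_to_fp8x2, __NV_SATFINITE, __NV_E4M3

// Sketch of one 2-lane float -> e4m3 chunk; names are illustrative, not codegen output.
__device__ void cast_float2_to_fp8x2_e4m3(const float2 &src, __nv_fp8x2_e4m3 &dst) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
  // Packed path: convert both lanes at once with finite saturation.
  reinterpret_cast<__nv_fp8x2_storage_t &>(dst) =
      __nv_cvt_float2_to_fp8x2(src, __NV_SATFINITE, __NV_E4M3);
#else
  // Element-wise fallback via the __nv_fp8x2_e4m3 constructor (emulated conversion).
  dst = __nv_fp8x2_e4m3(src);
#endif
}
```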
src/transform/layout_inference.cc (1)
1180-1183: Good refactoring to centralize cast validation logic. The change replaces explicit per-field type checks with a call to `IsCudaVectorizableCast`, improving maintainability. `Target::Current()` is properly defined and correctly returns the current target context, as verified by consistent usage throughout the codebase.
Actionable comments posted: 2
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- src/target/codegen_cuda.cc
- testing/python/debug/test_tilelang_debug_print.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.
Applied to files:
testing/python/debug/test_tilelang_debug_print.py
🧬 Code graph analysis (2)
src/target/codegen_cuda.cc (1)
src/target/utils.cc (2)
IsCudaVectorizableFP8 (137-140), IsCudaVectorizableFP8 (137-137)
testing/python/debug/test_tilelang_debug_print.py (2)
tilelang/jit/__init__.py (2)
compile (49-115), compile (347-373)
tilelang/language/v2/dtypes.py (3)
float8_e4m3fn (332-332), float8_e5m2 (346-346), float8_e4m3fnuz (339-339)
🔇 Additional comments (4)
testing/python/debug/test_tilelang_debug_print.py (1)
35-44: Well-structured FP8 test coverage. The separation of CUDA and ROCm FP8 tests with appropriate hardware guards is excellent:
- NVIDIA formats (`float8_e4m3fn`, `float8_e5m2`) are correctly isolated under `@requires_cuda`
- AMD formats (`float8_e4m3fnuz`, `float8_e5m2fnuz`) are correctly isolated under `@requires_rocm`

This approach ensures tests only run when the required hardware is available and properly exercises the FP8 vectorization paths introduced in the PR.
src/target/codegen_cuda.cc (3)
18-18: LGTM! Clean dependency addition. The new include provides the `IsCudaVectorizableFP8` and `IsCudaVectorizableCast` predicates used for FP8 vectorization checks later in the file.
973-1033: Excellent refactoring to eliminate duplication! The introduction of the `PrintVectorizedCast` lambda effectively consolidates the vectorized cast pattern across float16, bfloat16, and FP8 conversions. The implementation correctly:
- Computes `num_chunks = lanes / 2`, matching the 2-element intrinsics
- Uses `reinterpret_cast` appropriately for bfloat16 (`__nv_bfloat162`) types
- Maintains consistent lane count checks (2, 4, 8) across all conversions
- Preserves the element-wise fallback at line 1067 for other cases
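To picture the emitted pattern, a hand-written equivalent of a 4-lane float32 → bfloat16 cast would look roughly like this; the variable names are assumptions, while `__float22bfloat162_rn` is the standard `cuda_bf16.h` intrinsic.

```cpp
#include <cuda_bf16.h>  // __nv_bfloat162, __float22bfloat162_rn

// Hand-written equivalent of a 4-lane float -> bfloat16 vectorized cast:
// lanes = 4, so num_chunks = lanes / 2 = 2 intrinsic calls over 2-lane chunks.
__device__ void cast_float4_to_bf16x4(const float4 &src, __nv_bfloat162 (&dst)[2]) {
  const float2 *src2 = reinterpret_cast<const float2 *>(&src);
  dst[0] = __float22bfloat162_rn(src2[0]);  // lanes 0 and 1
  dst[1] = __float22bfloat162_rn(src2[1]);  // lanes 2 and 3
}
```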
1036-1064: FP8 vectorized conversions are correct. The FP8 conversion paths properly:
- Gate vectorization with `IsCudaVectorizableFP8` checks
- Consolidate `float8_e4m3` and `float8_e4m3fn` to the same `__NV_E4M3` constant (consistent with line 132 and justified by their identical E4M3 binary representation)
- Use `reinterpret_cast` for the `__nv_fp8x2_storage_t` storage type
- Pass the saturation mode (`__NV_SATFINITE`) and type constants to the CUDA intrinsics

The consolidation is correct: both e4m3 and e4m3fn share the same binary layout (1 sign bit, 4 exponent bits, 3 fraction bits) and are treated as compatible by NVIDIA's CUDA ecosystem and ONNX standards.
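As a side note on the layout claim, the following reference decoder (exposition only, not code from the PR) shows the 1-4-3 bit split and the e4m3fn convention of reserving a single NaN pattern instead of infinities.

```cpp
#include <cmath>
#include <cstdint>

// Decode one e4m3(fn) byte: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits.
float DecodeE4M3(uint8_t bits) {
  const int sign = (bits >> 7) & 0x1;
  const int exp = (bits >> 3) & 0xF;
  const int man = bits & 0x7;
  float value;
  if (exp == 0xF && man == 0x7) {
    value = NAN;                                     // only all-ones payload encodes NaN; no infinities
  } else if (exp == 0) {
    value = std::ldexp(man / 8.0f, -6);              // subnormal: 2^-6 * (m / 8)
  } else {
    value = std::ldexp(1.0f + man / 8.0f, exp - 7);  // normal: 2^(e-7) * (1 + m / 8)
  }
  return sign ? -value : value;
}
```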
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tilelang/language/v2/dtypes.py (1)
189-196: Fix assertion to match the correct torch attribute name with underscore. Lines 190-192 assert for `torch.float4_e2m1fnx2` (no underscore) but line 193 returns `torch.float4_e2m1fn_x2` (with underscore). Update the assertion to check for the correct attribute name: `torch.float4_e2m1fn_x2`.
♻️ Duplicate comments (1)
src/target/codegen_cuda.cc (1)
132-134: Acknowledged: float8_e4m3fnuz removal is an intentional CUDA backend change. Per the past review comment, the removal of `float8_e4m3fnuz` support from CUDA codegen is backend-specific; HIP/AMD codegen still supports fnuz. Ensure migration documentation or release notes reflect this change for CUDA users who may be using fnuz types.
🧹 Nitpick comments (2)
src/tl_templates/cuda/cuda_fp4.h (1)
197-203: Consider potential precision implications of double conversion. The float2 → fp4x2 conversion goes through half2 as an intermediate step. While this matches the reverse path (fp4x2 → half2 → float2), the double conversion (float → half → fp4) may introduce different rounding behavior compared to a direct float → fp4 conversion. This is likely acceptable for FP4's limited precision, but worth documenting if precision guarantees matter.
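For reference, the two-step path under discussion is roughly the following; `__tl_cvt_half2_to_fp4x2` and the one-byte storage alias are placeholders standing in for the project's helpers in `cuda_fp4.h`, and only `__float22half2_rn` is a standard `cuda_fp16.h` intrinsic.

```cpp
#include <cuda_fp16.h>  // __half2, __float22half2_rn

using fp4x2_storage_t = unsigned char;  // placeholder: two packed e2m1 values in one byte

// Placeholder declaration standing in for the project's half2 -> fp4x2 helper.
__device__ fp4x2_storage_t __tl_cvt_half2_to_fp4x2(__half2 h2);

// float2 -> fp4x2 via a half2 intermediate: rounding to half happens first,
// then rounding to e2m1, which can differ from a single direct float -> fp4 rounding.
__device__ fp4x2_storage_t cvt_float2_to_fp4x2(float2 f2) {
  __half2 h2 = __float22half2_rn(f2);  // step 1: float32 pair -> half pair
  return __tl_cvt_half2_to_fp4x2(h2);  // step 2: half pair -> packed FP4 pair
}
```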
testing/python/language/test_tilelang_language_vectorized_cast.py (1)
97-100: Verify FP4 test coverage for higher lane counts. The new FP4 test cases only use `lanes=2`, while other dtype conversions test both `lanes=2` and `lanes=4`. Consider adding `lanes=4` test cases for FP4 conversions to ensure the vectorized path works correctly for larger vector widths.
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- src/target/codegen_cuda.cc
- src/target/utils.cc
- src/tl_templates/cuda/cuda_fp4.h
- testing/python/language/test_tilelang_language_vectorized_cast.py
- tilelang/language/v2/dtypes.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-03T06:24:11.411Z
Learnt from: Rachmanino
Repo: tile-ai/tilelang PR: 1175
File: src/op/math.cc:44-52
Timestamp: 2025-11-03T06:24:11.411Z
Learning: In tilelang's `src/op/math.cc`, the `infinity_op` function uses `std::numeric_limits<float>::infinity()` as a placeholder for all float types (including float64 and bfloat16). The codegen layer (PrintConst:Inf) handles the correct infinity value based on the dtype field of the FloatImm node, so the specific C++ template argument doesn't matter.
Applied to files:
src/target/codegen_cuda.cc
🧬 Code graph analysis (3)
src/tl_templates/cuda/cuda_fp4.h (2)
src/tl_templates/cuda/common.h (1)
uint32_t (152-154)
src/tl_templates/cuda/cuda_fp8.h (1)
float2 (294-302)
src/target/codegen_cuda.cc (1)
src/target/utils.cc (2)
IsCudaVectorizableFP8 (137-140), IsCudaVectorizableFP8 (137-137)
testing/python/language/test_tilelang_language_vectorized_cast.py (1)
tilelang/language/v2/dtypes.py (5)
dtype (14-15), float4_e2m1fn (385-385), float32 (295-295), as_torch (15-15), float16 (294-294)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
🔇 Additional comments (5)
src/target/utils.cc (1)
137-176: LGTM! Well-structured vectorization predicates. The new `IsCudaVectorizableFP8` and `IsCudaVectorizableCast` helper functions provide clean, centralized logic for determining vectorizable cast combinations. The symmetric handling of bidirectional casts (e.g., float16 ↔ float32) is consistent and easy to maintain.

src/tl_templates/cuda/cuda_fp4.h (1)
157-171: LGTM! Clean PTX-based FP4x2 conversion implementation. The conversion uses appropriate PTX instructions (`cvt.rn.f16x2.e2m1x2`) and follows the established pattern from the FP8 conversion utilities in `cuda_fp8.h`. The intermediate register handling is correctly managed.
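For readers curious about the PTX pattern, a minimal illustrative sketch of an e2m1x2 → f16x2 unpack is below; only the `cvt.rn.f16x2.e2m1x2` mnemonic comes from the code under review, while the register plumbing, the operand widening, and the architecture guard value are assumptions, so treat it as a sketch rather than a copy of `cuda_fp4.h`.

```cpp
#include <cuda_fp16.h>  // __half2, __floats2half2_rn
#include <cstdint>

// Illustrative sketch: unpack one byte holding two e2m1 (FP4) values into __half2.
__device__ __half2 fp4x2_to_half2_sketch(uint8_t packed) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)  // guard value assumed
  uint32_t out;
  uint16_t in = packed;  // widen so the byte can travel through a 16-bit asm operand
  asm volatile(
      "{\n"
      "  .reg .b8 b;\n"
      "  cvt.u8.u16 b, %1;\n"           // move the packed e2m1 pair into a byte register
      "  cvt.rn.f16x2.e2m1x2 %0, b;\n"  // expand e2m1x2 -> f16x2
      "}\n"
      : "=r"(out)
      : "h"(in));
  return *reinterpret_cast<__half2 *>(&out);
#else
  return __floats2half2_rn(0.0f, 0.0f);  // no FP4 hardware path assumed on older archs
#endif
}
```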
src/target/codegen_cuda.cc (2)
975-995: LGTM! Clean helper reduces duplication significantly. The `PrintVectorizedCast` lambda centralizes the vectorized cast emission logic, handling both reinterpret_cast and C-style casts appropriately. The parameterization (cast_func, src_type, dst_type, extra_args, reinterpret flags) provides good flexibility for the various conversion paths.
1064-1102: LGTM! Complete FP4 vectorized cast coverage. The new FP4 ↔ float16/float32 conversion paths are comprehensive, covering all four directions. The code correctly uses the `__tl_cvt_*` functions defined in `cuda_fp4.h`.

testing/python/language/test_tilelang_language_vectorized_cast.py (1)
43-62: LGTM! Cleaner API with dtype objects. The refactoring from string-based dtype specification to direct dtype objects is a good improvement. The early return for `float4_e2m1fn` correctly handles the case where torch doesn't support the type directly, allowing the vectorization code path to still be verified via source inspection.
@codex review
💡 Codex Review
Here are some automated review suggestions for this pull request.
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
```diff
     assert hasattr(torch, "float4_e2m1fnx2"), (
         "torch.float4_e2m1fnx2 is not supported in this version of torch. Please upgrade torch >= 2.8.0"
     )
-    return torch.float4_e2m1fnx2
+    return torch.float4_e2m1fn_x2
 elif dtype_str == "float4_e2m1fn":
     logger.info("torch doesn't support float4_e2m1fn, using float4_e2m1fnx2 as storage dtype.")
     return torch.float4_e2m1fn_x2
```
Use the correct torch dtype attribute name consistently
The new float4 mapping checks hasattr(torch, "float4_e2m1fnx2") (no underscore) but returns torch.float4_e2m1fn_x2 (with underscore). Because the two names differ, the assertion fails even on torch versions that provide the underscored float4_e2m1fn_x2 dtype, and this affects both the float4_e2m1fnx2 path and the fallback for float4_e2m1fn. This breaks any code paths that request those dtypes on supported torch versions.
- Implemented conversions between float4_e2m1fn and float32, half2, and float2 in utils.cc and cuda_fp4.h.
- Updated test_tilelang_language_vectorized_cast.py to validate new conversions and ensure correctness.
- Enhanced dtype conversion in dtypes.py to handle float4_e2m1fn appropriately, logging a warning for unsupported types in PyTorch.
- Added tests for vectorized casting of float8 and float4 data types, ensuring compatibility with CUDA compute versions.
- Refactored existing test functions to improve clarity and organization, separating tests for different data types.
- Updated parameterization to include additional test cases for new conversions.
Actionable comments posted: 0
♻️ Duplicate comments (1)
testing/python/language/test_tilelang_language_vectorized_cast.py (1)
65-67: Remove duplicate assignment of `A`. Lines 65 and 67 both assign `A = A_float.to(src_dtype.as_torch())`. The first assignment is redundant.

🔎 Proposed fix
```diff
 A_float = torch.randn(M, dtype=torch.float32, device="cuda")
-A = A_float.to(src_dtype.as_torch())
 A = A_float.to(src_dtype.as_torch())
 B = torch.zeros(M, dtype=dst_dtype.as_torch(), device="cuda")
 C = torch.zeros(M, dtype=dst_dtype.as_torch(), device="cuda")
```
🧹 Nitpick comments (1)
testing/python/language/test_tilelang_language_vectorized_cast.py (1)
56-59: Consider printing both kernel sources for debugging consistency. While both `code` and `code_parallel` are checked for vectorization, only `code` is printed. For debugging failures in the parallel kernel path, printing both would be helpful.

🔎 Proposed enhancement
```diff
 code = kernel.get_kernel_source()
 code_parallel = kernel_parallel.get_kernel_source()
 print(code)
+print(code_parallel)
 assert check_str in code and check_str in code_parallel, f"Cast {src_dtype} to {dst_dtype} with {lanes=} is not vectorized!"
```
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
testing/python/language/test_tilelang_language_vectorized_cast.py
🧰 Additional context used
🧬 Code graph analysis (1)
testing/python/language/test_tilelang_language_vectorized_cast.py (2)
tilelang/language/v2/dtypes.py (6)
dtype (14-15), float4_e2m1fn (385-385), as_torch (15-15), float8_e4m3fn (336-336), float8_e5m2 (350-350), float16 (294-294)
tilelang/testing/__init__.py (1)
requires_cuda_compute_version_ge (104-105)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (4)
testing/python/language/test_tilelang_language_vectorized_cast.py (4)
43-50: LGTM! Type-safe dtype objects improve API clarity. The refactor from string-based to dtype object parameters enhances type safety and aligns with the broader dtype-based invocation path introduced in this PR.
68-69: LGTM! Tensor creation and assertions correctly use the dtype API. The tensor allocations and correctness checks properly use the `dtype.as_torch()` method, maintaining consistency with the refactored dtype-based approach.

Also applies to: 74-75
78-127: LGTM! Well-organized test structure with appropriate version guards. The three-tier test organization clearly separates:
- General vectorized casts (FP16, BF16, FP8, FP32)
- FP8-to-FP32 conversions requiring CUDA compute 8.9+
- FP4 conversions requiring CUDA compute 10.0+
The version guards ensure tests run only on compatible hardware, and parametrization provides comprehensive coverage.
61-62: FP4 runtime execution is intentionally skipped until ml_dtypes support is available. The early return prevents runtime kernel execution for `float4_e2m1fn` because ml_dtypes does not yet have the `float4_e2m1fn` attribute, which is needed to convert PyTorch tensors to FP4 dtype for testing. Codegen validation (confirming vectorized intrinsics are generated correctly) still runs; runtime correctness cannot be validated until the ml_dtypes library adds FP4 support.
Summary by CodeRabbit
Refactor
New Features
Tests