[Feature] Support E8M0 related type conversion and vectorized cast #1731
LeiWang1999 merged 3 commits into tile-ai:main
Conversation
📝 Walkthrough
This PR adds vectorized cast support for CUDA's FP8 E8M0 format, enabling efficient conversions between E8M0 and BFloat16, as well as conversions from float and double. The implementation spans code generation, utility checks, FP8 template functions, and corresponding test coverage.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Frontend as TileLang Frontend
    participant Codegen as CodeGenTileLangCUDA
    participant Utils as Vectorization Utils
    participant Templates as CUDA FP8 Templates
    participant GPU as GPU Execution
    Frontend->>Codegen: VisitExpr_(CastNode)
    activate Codegen
    Codegen->>Utils: IsCudaVectorizableCast(src, dst)
    activate Utils
    Utils->>Utils: Check if E8M0↔BFloat16<br/>or float/double→E8M0
    Utils-->>Codegen: vectorizable = true
    deactivate Utils
    alt Vectorized Path
        Codegen->>Codegen: Generate vectorized<br/>call to template function
        Codegen-->>Templates: e.g., __tl_cvt_e8m0x2_to_bfloat162()
        activate Templates
        Templates->>Templates: Reinterpret cast +<br/>NVIDIA intrinsics
        Templates-->>GPU: CUDA kernel code
        deactivate Templates
    else Non-Vectorized Path
        Codegen-->>GPU: Scalar cast code
    end
    deactivate Codegen
    GPU->>GPU: Execute cast conversion
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 4
🤖 Fix all issues with AI agents
In `@src/target/codegen_cuda.cc`:
- Around lines 1277-1299: The comments reference the NVIDIA `__nv_cvt_*` intrinsics rather than the TileLang wrappers actually used. Change the comment mentioning `__nv_cvt_float2_to_e8m0x2` to `__tl_cvt_float2_to_e8m0x2`, and the one mentioning `__nv_cvt_double2_to_e8m0x2` to `__tl_cvt_double2_to_e8m0x2`, so they match the call sites, which invoke `PrintVectorizedCast("__tl_cvt_float2_to_e8m0x2", "float2", "__nv_fp8x2_storage_t", "", false, true)` and `PrintVectorizedCast("__tl_cvt_double2_to_e8m0x2", "double2", "__nv_fp8x2_storage_t", "", false, true)`.
In `@src/target/utils.cc`:
- Around lines 185-191: The comments above the two conversion checks mention "E4M3/E5M2", but the code actually tests `target_ty.is_float8_e8m0fnu()`. Update both comments to reference the E8M0 format (float8_e8m0fnu) so they are consistent with the conditions built from `from_ty.is_bfloat16()`, `from_ty.is_float()`, and `target_ty.is_float8_e8m0fnu()`.
In `@src/tl_templates/cuda/cuda_fp8.h`:
- Line 319: The guard `#if defined(TL_HAS_FP8_E8M0)` is incorrect because `TL_HAS_FP8_E8M0` is always defined, as either 0 or 1, so the guard is always true. Test the macro's value instead (`#if TL_HAS_FP8_E8M0` or `#if TL_HAS_FP8_E8M0 == 1`) so the E8M0 blocks in `cuda_fp8.h` are compiled only when `TL_HAS_FP8_E8M0` is set to 1 (CUDA >= 12.6).
In `@testing/python/language/test_tilelang_language_vectorized_cast.py`:
- Around lines 114-118: The comment "E8M0 <-> FP16" is inaccurate; the test tuples pair `T.float8_e8m0fnu` with `T.bfloat16` (the entries using `__tl_cvt_e8m0x2_to_bfloat162`, `__tl_cvt_bfloat162_to_e8m0x2`, `__tl_cvt_float2_to_e8m0x2`, and `__tl_cvt_double2_to_e8m0x2`). Change the comment to "E8M0 <-> BF16" or otherwise mention bfloat16.
🧹 Nitpick comments (1)
examples/gemm/example_gemm_autotune.py (1)
222-222: Consider making the kernel source print conditional or removing it. This print statement outputs the full kernel source on every run, which can produce excessive output. If it is intended for debugging, guard it with a verbose flag or remove it for cleaner example output.
♻️ Suggested alternatives

Option 1: Remove if unintended:

```diff
- print(kernel.get_kernel_source())
```

Option 2: Make it conditional:

```diff
+ if os.environ.get("TILELANG_DEBUG"):
+     print(kernel.get_kernel_source())
- print(kernel.get_kernel_source())
```
Addresses #1710
Summary by CodeRabbit
New Features
Tests