From 9059222dab0b32381de9cf04660dfc37a7001abe Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 23 Dec 2025 17:29:31 +0800 Subject: [PATCH 1/5] fp4 related update, require_cu13 --- src/target/codegen_cuda.cc | 44 ++++++ src/tl_templates/cuda/cuda_fp4.h | 135 +++++++++++++----- .../test_tilelang_language_vectorized_cast.py | 11 +- 3 files changed, 157 insertions(+), 33 deletions(-) diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 517e12094..88e499b24 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -1101,6 +1101,50 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { } } + // Handle conversion from double to float4 (E2M1) + if (from_ty.is_float64() && target_ty.is_float4_e2m1fn()) { + // Use __tl_cvt_double2_to_fp4x2 for vectorized conversion (double2 -> + // fp4x2) + if (lanes == 2 || lanes == 4 || lanes == 8) { + PrintVectorizedCast("__tl_cvt_double2_to_fp4x2", "double2", "uint8_t", "", + false, true); + return; + } + } + + // Handle conversion from float4 (E2M1) to double + if (from_ty.is_float4_e2m1fn() && target_ty.is_float64()) { + // Use __tl_cvt_fp4x2_to_double2 for vectorized conversion (fp4x2 -> + // double2) + if (lanes == 2 || lanes == 4 || lanes == 8) { + PrintVectorizedCast("__tl_cvt_fp4x2_to_double2", "uint8_t", "double2", "", + true, false); + return; + } + } + + // Handle conversion from bfloat16 to float4 (E2M1) + if (from_ty.is_bfloat16() && target_ty.is_float4_e2m1fn()) { + // Use __tl_cvt_bfloat162_to_fp4x2 for vectorized conversion (bfloat162 -> + // fp4x2) + if (lanes == 2 || lanes == 4 || lanes == 8) { + PrintVectorizedCast("__tl_cvt_bfloat162_to_fp4x2", "__nv_bfloat162", + "uint8_t", "", false, true); + return; + } + } + + // Handle conversion from float4 (E2M1) to bfloat16 + if (from_ty.is_float4_e2m1fn() && target_ty.is_bfloat16()) { + // Use __tl_cvt_fp4x2_to_bfloat162 for vectorized conversion (fp4x2 -> + // bfloat162) + if (lanes == 2 || lanes == 4 || lanes == 8) { + PrintVectorizedCast("__tl_cvt_fp4x2_to_bfloat162", "uint8_t", + "__nv_bfloat162", "", true, false); + return; + } + } + // Fallback: elementwise cast for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) { std::ostringstream val; diff --git a/src/tl_templates/cuda/cuda_fp4.h b/src/tl_templates/cuda/cuda_fp4.h index b76246442..a3684c69f 100644 --- a/src/tl_templates/cuda/cuda_fp4.h +++ b/src/tl_templates/cuda/cuda_fp4.h @@ -154,24 +154,50 @@ TL_DEVICE fp4_e2_32_t make_fp4_e2_32_t( return result; } +// ============================================================================ +// FP4 <-> Half Precision Conversions +// ============================================================================ +// https://docs.nvidia.com/cuda/cuda-math-api/cuda_math_api/group__CUDA__MATH__FP4__MISC.html + +// fp4_e2m1 -> half +TL_DEVICE __half __tl_cvt_fp4_to_half(const __nv_fp4_storage_t src) { + __half_raw raw = __nv_cvt_fp4_to_halfraw(src, __NV_E2M1); + __half result; + result = *reinterpret_cast<__half *>(&raw); + return result; +} + // fp4_e2m1x2 (1 byte) -> half2 -// Uses PTX cvt.rn.f16x2.e2m1x2 instruction -TL_DEVICE half2 __tl_cvt_fp4x2_to_half2(const uint8_t src) { - half2 out; - uint32_t *out_ptr = reinterpret_cast(&out); - uint16_t src_packed = static_cast(src); - asm volatile("{\n" - ".reg .b8 byte0, byte1;\n" - "mov.b16 {byte0, byte1}, %1;\n" - "cvt.rn.f16x2.e2m1x2 %0, byte0;\n" - "}\n" - : "=r"(*out_ptr) - : "h"(src_packed)); - return out; +TL_DEVICE half2 __tl_cvt_fp4x2_to_half2(const __nv_fp4x2_storage_t src) { + __half2_raw raw = __nv_cvt_fp4x2_to_halfraw2(src, __NV_E2M1); + half2 result; + result = *reinterpret_cast(&raw); + return result; +} + +// half -> fp4_e2m1 +TL_DEVICE __nv_fp4_storage_t __tl_cvt_half_to_fp4(const __half src) { + __half_raw raw = *reinterpret_cast(&src); + return __nv_cvt_halfraw_to_fp4(raw, __NV_SATFINITE, __NV_E2M1); +} + +// half2 -> fp4_e2m1x2 (1 byte) +TL_DEVICE __nv_fp4x2_storage_t __tl_cvt_half2_to_fp4x2(const half2 src) { + __half2_raw raw = *reinterpret_cast(&src); + return __nv_cvt_halfraw2_to_fp4x2(raw, __NV_SATFINITE, __NV_E2M1); +} + +// ============================================================================ +// FP4 <-> Float Conversions +// ============================================================================ + +// fp4_e2m1 -> float +TL_DEVICE float __tl_cvt_fp4_to_float(const __nv_fp4_storage_t src) { + return __half2float(__tl_cvt_fp4_to_half(src)); } // fp4_e2m1x2 (1 byte) -> float2 -TL_DEVICE float2 __tl_cvt_fp4x2_to_float2(const uint8_t src) { +TL_DEVICE float2 __tl_cvt_fp4x2_to_float2(const __nv_fp4x2_storage_t src) { half2 tmp = __tl_cvt_fp4x2_to_half2(src); float2 result; result.x = __half2float(tmp.x); @@ -179,27 +205,72 @@ TL_DEVICE float2 __tl_cvt_fp4x2_to_float2(const uint8_t src) { return result; } -// half2 -> fp4_e2m1x2 (1 byte) -// Uses PTX cvt.rn.satfinite.e2m1x2.f16x2 instruction -TL_DEVICE uint8_t __tl_cvt_half2_to_fp4x2(const half2 src) { - uint16_t out; - uint32_t const *src_ptr = reinterpret_cast(&src); - asm volatile("{\n" - ".reg .b8 result_byte;\n" - "cvt.rn.satfinite.e2m1x2.f16x2 result_byte, %1;\n" - "mov.b16 %0, {result_byte, 0};\n" - "}\n" - : "=h"(out) - : "r"(*src_ptr)); - return static_cast(out); +// float -> fp4_e2m1 +TL_DEVICE __nv_fp4_storage_t __tl_cvt_float_to_fp4(const float src) { + return __nv_cvt_float_to_fp4(src, __NV_SATFINITE, __NV_E2M1); } // float2 -> fp4_e2m1x2 (1 byte) -TL_DEVICE uint8_t __tl_cvt_float2_to_fp4x2(const float2 src) { - half2 tmp; - tmp.x = __float2half(src.x); - tmp.y = __float2half(src.y); - return __tl_cvt_half2_to_fp4x2(tmp); +TL_DEVICE __nv_fp4x2_storage_t __tl_cvt_float2_to_fp4x2(const float2 src) { + return __nv_cvt_float2_to_fp4x2(src, __NV_SATFINITE, __NV_E2M1); +} + +// ============================================================================ +// FP4 <-> Double Conversions +// ============================================================================ + +// fp4_e2m1 -> double +TL_DEVICE double __tl_cvt_fp4_to_double(const __nv_fp4_storage_t src) { + return static_cast(__tl_cvt_fp4_to_float(src)); +} + +// fp4_e2m1x2 -> double2 +TL_DEVICE double2 __tl_cvt_fp4x2_to_double2(const __nv_fp4x2_storage_t src) { + float2 tmp = __tl_cvt_fp4x2_to_float2(src); + double2 result; + result.x = static_cast(tmp.x); + result.y = static_cast(tmp.y); + return result; +} + +// double -> fp4_e2m1 +TL_DEVICE __nv_fp4_storage_t __tl_cvt_double_to_fp4(const double src) { + return __nv_cvt_double_to_fp4(src, __NV_SATFINITE, __NV_E2M1); +} + +// double2 -> fp4_e2m1x2 +TL_DEVICE __nv_fp4x2_storage_t __tl_cvt_double2_to_fp4x2(const double2 src) { + return __nv_cvt_double2_to_fp4x2(src, __NV_SATFINITE, __NV_E2M1); +} + +// ============================================================================ +// FP4 <-> BFloat16 Conversions +// ============================================================================ + +// fp4_e2m1 -> bfloat16 +TL_DEVICE __nv_bfloat16 __tl_cvt_fp4_to_bfloat16(const __nv_fp4_storage_t src) { + return __float2bfloat16(__tl_cvt_fp4_to_float(src)); +} + +// fp4_e2m1x2 -> bfloat162 +TL_DEVICE __nv_bfloat162 +__tl_cvt_fp4x2_to_bfloat162(const __nv_fp4x2_storage_t src) { + float2 tmp = __tl_cvt_fp4x2_to_float2(src); + return __floats2bfloat162_rn(tmp.x, tmp.y); +} + +// bfloat16 -> fp4_e2m1 +TL_DEVICE __nv_fp4_storage_t +__tl_cvt_bfloat16_to_fp4(const __nv_bfloat16 src) { + __nv_bfloat16_raw raw = *reinterpret_cast(&src); + return __nv_cvt_bfloat16raw_to_fp4(raw, __NV_SATFINITE, __NV_E2M1); +} + +// bfloat162 -> fp4_e2m1x2 +TL_DEVICE __nv_fp4x2_storage_t +__tl_cvt_bfloat162_to_fp4x2(const __nv_bfloat162 src) { + __nv_bfloat162_raw raw = *reinterpret_cast(&src); + return __nv_cvt_bfloat16raw2_to_fp4x2(raw, __NV_SATFINITE, __NV_E2M1); } #endif diff --git a/testing/python/language/test_tilelang_language_vectorized_cast.py b/testing/python/language/test_tilelang_language_vectorized_cast.py index 991b2a8eb..2710ee685 100644 --- a/testing/python/language/test_tilelang_language_vectorized_cast.py +++ b/testing/python/language/test_tilelang_language_vectorized_cast.py @@ -116,10 +116,18 @@ def test_vectorized_cast_fp8(src_dtype, dst_dtype, check_str, lanes): @pytest.mark.parametrize( "src_dtype, dst_dtype, check_str, lanes", [ + # FP4 <-> Half (T.float4_e2m1fn, T.float16, "__tl_cvt_fp4x2_to_half2", 2), (T.float16, T.float4_e2m1fn, "__tl_cvt_half2_to_fp4x2", 2), + # FP4 <-> Float (T.float4_e2m1fn, T.float32, "__tl_cvt_fp4x2_to_float2", 2), (T.float32, T.float4_e2m1fn, "__tl_cvt_float2_to_fp4x2", 2), + # FP4 <-> Double + (T.float4_e2m1fn, T.float64, "__tl_cvt_fp4x2_to_double2", 2), + (T.float64, T.float4_e2m1fn, "__tl_cvt_double2_to_fp4x2", 2), + # FP4 <-> BFloat16 + (T.float4_e2m1fn, T.bfloat16, "__tl_cvt_fp4x2_to_bfloat162", 2), + (T.bfloat16, T.float4_e2m1fn, "__tl_cvt_bfloat162_to_fp4x2", 2), ], ) def test_vectorized_cast_fp4(src_dtype, dst_dtype, check_str, lanes): @@ -127,4 +135,5 @@ def test_vectorized_cast_fp4(src_dtype, dst_dtype, check_str, lanes): if __name__ == "__main__": - tilelang.testing.main() + # tilelang.testing.main() + test_vectorized_cast_fp4(T.float4_e2m1fn, T.float32, "__tl_cvt_fp4x2_to_float2", 2) From 6a778a53ae43ff7645dac66cc88be5d907ec4e45 Mon Sep 17 00:00:00 2001 From: Zhiwen Mo Date: Wed, 24 Dec 2025 13:57:25 +0800 Subject: [PATCH 2/5] Enhance CUDA type conversion handling and optimize dtype management - Updated CUDA vectorized cast functions to ensure proper handling of float16, float32, bfloat16, and float8 conversions, adding checks for bit sizes. - Refactored dtype conversion logic in `cuda_fp4.h` to utilize `cudaRoundZero` for improved accuracy in floating-point conversions. - Introduced a new method in `KernelParam` to convert TVM DataType to TileLang dtype. - Adjusted argument binding logic in `arg_binder.cc` to allow for better subtype matching based on total bit counts. - Enhanced dtype handling in `dtypes.py` to accommodate new float4_e2m1fn types and ensure compatibility with PyTorch. This update aims to improve type safety and conversion accuracy across the codebase. --- src/target/codegen_cuda.cc | 17 ++++++---- src/tl_templates/cuda/cuda_fp4.h | 19 +++++------ src/transform/arg_binder.cc | 57 ++++++++++++++++++++------------ tilelang/engine/param.py | 9 +++++ tilelang/jit/adapter/tvm_ffi.py | 6 ++++ tilelang/language/v2/dtypes.py | 9 +++-- 6 files changed, 77 insertions(+), 40 deletions(-) diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 88e499b24..013e151f5 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -995,7 +995,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { }; // Handle conversion from float16 to float32 - if (from_ty.is_float16() && target_ty.is_float()) { + if (from_ty.is_float16() && target_ty.is_float() && target_ty.bits() == 32) { // Use __half22float2 for vectorized conversion (half2 -> float2) if (lanes == 2 || lanes == 4 || lanes == 8) { PrintVectorizedCast("__half22float2", "half2", "float2"); @@ -1004,7 +1004,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { } // Handle conversion from float32 to float16 - if (from_ty.is_float() && target_ty.is_float16()) { + if (from_ty.is_float() && from_ty.bits() == 32 && target_ty.is_float16()) { // Use __float22half2_rn for vectorized conversion (float2 -> half2) if (lanes == 2 || lanes == 4 || lanes == 8) { PrintVectorizedCast("__float22half2_rn", "float2", "half2"); @@ -1013,7 +1013,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { } // Handle conversion from bfloat16 to float32 - if (from_ty.is_bfloat16() && target_ty.is_float()) { + if (from_ty.is_bfloat16() && target_ty.is_float() && target_ty.bits() == 32) { // Use __bfloat1622float2 for vectorized conversion (bfloat162 -> float2) if (lanes == 2 || lanes == 4 || lanes == 8) { PrintVectorizedCast("__bfloat1622float2", "__nv_bfloat162", "float2", "", @@ -1023,7 +1023,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { } // Handle conversion from float32 to bfloat16 - if (from_ty.is_float() && target_ty.is_bfloat16()) { + if (from_ty.is_float() && from_ty.bits() == 32 && target_ty.is_bfloat16()) { // Use __float22bfloat162_rn for vectorized conversion (float2 -> bfloat162) if (lanes == 2 || lanes == 4 || lanes == 8) { PrintVectorizedCast("__float22bfloat162_rn", "float2", "__nv_bfloat162", @@ -1033,7 +1033,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { } // Handle conversion from float32 to float8 (E4M3/E5M2) - if (from_ty.is_float() && tl::IsCudaVectorizableFP8(target_ty)) { + if (from_ty.is_float() && from_ty.bits() == 32 && + tl::IsCudaVectorizableFP8(target_ty)) { bool target_type_is_e4m3 = target_ty.is_float8_e4m3() || target_ty.is_float8_e4m3fn(); std::string type_suffix = target_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2"; @@ -1102,7 +1103,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { } // Handle conversion from double to float4 (E2M1) - if (from_ty.is_float64() && target_ty.is_float4_e2m1fn()) { + if (from_ty.is_float() && from_ty.bits() == 64 && + target_ty.is_float4_e2m1fn()) { // Use __tl_cvt_double2_to_fp4x2 for vectorized conversion (double2 -> // fp4x2) if (lanes == 2 || lanes == 4 || lanes == 8) { @@ -1113,7 +1115,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { } // Handle conversion from float4 (E2M1) to double - if (from_ty.is_float4_e2m1fn() && target_ty.is_float64()) { + if (from_ty.is_float4_e2m1fn() && target_ty.is_float() && + target_ty.bits() == 64) { // Use __tl_cvt_fp4x2_to_double2 for vectorized conversion (fp4x2 -> // double2) if (lanes == 2 || lanes == 4 || lanes == 8) { diff --git a/src/tl_templates/cuda/cuda_fp4.h b/src/tl_templates/cuda/cuda_fp4.h index a3684c69f..22cc0460c 100644 --- a/src/tl_templates/cuda/cuda_fp4.h +++ b/src/tl_templates/cuda/cuda_fp4.h @@ -178,13 +178,13 @@ TL_DEVICE half2 __tl_cvt_fp4x2_to_half2(const __nv_fp4x2_storage_t src) { // half -> fp4_e2m1 TL_DEVICE __nv_fp4_storage_t __tl_cvt_half_to_fp4(const __half src) { __half_raw raw = *reinterpret_cast(&src); - return __nv_cvt_halfraw_to_fp4(raw, __NV_SATFINITE, __NV_E2M1); + return __nv_cvt_halfraw_to_fp4(raw, __NV_E2M1, cudaRoundZero); } // half2 -> fp4_e2m1x2 (1 byte) TL_DEVICE __nv_fp4x2_storage_t __tl_cvt_half2_to_fp4x2(const half2 src) { __half2_raw raw = *reinterpret_cast(&src); - return __nv_cvt_halfraw2_to_fp4x2(raw, __NV_SATFINITE, __NV_E2M1); + return __nv_cvt_halfraw2_to_fp4x2(raw, __NV_E2M1, cudaRoundZero); } // ============================================================================ @@ -207,12 +207,12 @@ TL_DEVICE float2 __tl_cvt_fp4x2_to_float2(const __nv_fp4x2_storage_t src) { // float -> fp4_e2m1 TL_DEVICE __nv_fp4_storage_t __tl_cvt_float_to_fp4(const float src) { - return __nv_cvt_float_to_fp4(src, __NV_SATFINITE, __NV_E2M1); + return __nv_cvt_float_to_fp4(src, __NV_E2M1, cudaRoundZero); } // float2 -> fp4_e2m1x2 (1 byte) TL_DEVICE __nv_fp4x2_storage_t __tl_cvt_float2_to_fp4x2(const float2 src) { - return __nv_cvt_float2_to_fp4x2(src, __NV_SATFINITE, __NV_E2M1); + return __nv_cvt_float2_to_fp4x2(src, __NV_E2M1, cudaRoundZero); } // ============================================================================ @@ -235,12 +235,12 @@ TL_DEVICE double2 __tl_cvt_fp4x2_to_double2(const __nv_fp4x2_storage_t src) { // double -> fp4_e2m1 TL_DEVICE __nv_fp4_storage_t __tl_cvt_double_to_fp4(const double src) { - return __nv_cvt_double_to_fp4(src, __NV_SATFINITE, __NV_E2M1); + return __nv_cvt_double_to_fp4(src, __NV_E2M1, cudaRoundZero); } // double2 -> fp4_e2m1x2 TL_DEVICE __nv_fp4x2_storage_t __tl_cvt_double2_to_fp4x2(const double2 src) { - return __nv_cvt_double2_to_fp4x2(src, __NV_SATFINITE, __NV_E2M1); + return __nv_cvt_double2_to_fp4x2(src, __NV_E2M1, cudaRoundZero); } // ============================================================================ @@ -260,17 +260,16 @@ __tl_cvt_fp4x2_to_bfloat162(const __nv_fp4x2_storage_t src) { } // bfloat16 -> fp4_e2m1 -TL_DEVICE __nv_fp4_storage_t -__tl_cvt_bfloat16_to_fp4(const __nv_bfloat16 src) { +TL_DEVICE __nv_fp4_storage_t __tl_cvt_bfloat16_to_fp4(const __nv_bfloat16 src) { __nv_bfloat16_raw raw = *reinterpret_cast(&src); - return __nv_cvt_bfloat16raw_to_fp4(raw, __NV_SATFINITE, __NV_E2M1); + return __nv_cvt_bfloat16raw_to_fp4(raw, __NV_E2M1, cudaRoundZero); } // bfloat162 -> fp4_e2m1x2 TL_DEVICE __nv_fp4x2_storage_t __tl_cvt_bfloat162_to_fp4x2(const __nv_bfloat162 src) { __nv_bfloat162_raw raw = *reinterpret_cast(&src); - return __nv_cvt_bfloat16raw2_to_fp4x2(raw, __NV_SATFINITE, __NV_E2M1); + return __nv_cvt_bfloat16raw2_to_fp4x2(raw, __NV_E2M1, cudaRoundZero); } #endif diff --git a/src/transform/arg_binder.cc b/src/transform/arg_binder.cc index 294c9f6bc..badcfa620 100644 --- a/src/transform/arg_binder.cc +++ b/src/transform/arg_binder.cc @@ -333,9 +333,7 @@ void ArgBinder::BindDLTensors( // Scan buffer shape for symbolic variables for (size_t k = 0; k < buffer->shape.size(); ++k) { - if (buffer->dtype == DataType::Int(4) || - buffer->dtype == DataType::UInt(4) || - buffer->dtype == DataType::Int(1)) { + if (buffer->dtype.bits() < 8) { break; } @@ -524,21 +522,40 @@ void ArgBinder::BindDLTensors( cond = cond || int8_ok || uint8_ok || kdlbool8_ok || kdlbool1_ok || bit1_ok; } - // Allow float4 to match int8 at runtime (PyTorch uses int8 as storage for - // FP4). - if (buffer->dtype.is_float4()) { - PrimExpr code_int = IntImm(DataType::UInt(8), DataType::kInt); - PrimExpr bits8 = IntImm(DataType::UInt(8), 8); - // For FP4, we pack 2 elements per byte, but we still use same lanes at - // storage level Accept int8 with same lanes as the fp4 type - PrimExpr fp4_lanes_ok = (v_type_lanes == expect_lanes); - PrimExpr int8_ok = - (v_type_code == code_int && v_type_bits == bits8 && fp4_lanes_ok); - cond = cond || int8_ok; + // Allow with bits < 8 to match any type with the same total bit count at + // runtime (PyTorch uses int8 as storage for FP4). + bool data_is_subtype = buffer->dtype.bits() < 8; + if (data_is_subtype) { + // Get the pre-created shape buffer for reading runtime shape + Buffer buf_shape = shape_buffer_map[arg_name]; + + // Calculate expected total bits using compile-time buffer->shape + PrimExpr expect_total_bits = + cast(DataType::UInt(64), expect_bits) * + cast(DataType::UInt(64), expect_lanes) * + cast(DataType::UInt(64), + buffer->shape.size() == 0 + ? make_const(DataType::UInt(64), 1) + : foldl([](PrimExpr a, PrimExpr b, Span) { return a * b; }, + make_const(DataType::UInt(64), 1), buffer->shape)); + + // Calculate actual total bits using runtime shape from DLTensor + PrimExpr actual_total_bits = cast(DataType::UInt(64), v_type_bits) * + cast(DataType::UInt(64), v_type_lanes); + for (size_t k = 0; k < buffer->shape.size(); ++k) { + PrimExpr dim_val = + cast(DataType::UInt(64), + BufferLoad(buf_shape, + {IntImm(DataType::Int(32), static_cast(k))})); + actual_total_bits = actual_total_bits * dim_val; + } + + PrimExpr bits_match = (actual_total_bits == expect_total_bits); + BinderAddAssert(&analyzer_, bits_match, + arg_name + " is a subtype, but total bits mismatch", + &asserts_, is_null); } - if (!(buffer->dtype == DataType::Int(1) || - buffer->dtype == DataType::Int(4) || - buffer->dtype == DataType::UInt(4) || buffer->dtype.is_float4())) { + if (!data_is_subtype) { // Build FFI packed call to __tvm_error_dtype_mismatch when mismatch // occurs. Only issue the call when handle is non-NULL and cond is false. ffi::Array packed_args; @@ -578,9 +595,7 @@ void ArgBinder::BindDLTensors( for (size_t k = 0; k < buffer->shape.size(); ++k) { // These packed-bit dtype shapes were not bound in the original // implementation, so we just use them as is. - if (buffer->dtype == DataType::Int(4) || - buffer->dtype == DataType::UInt(4) || - buffer->dtype == DataType::Int(1)) { + if (data_is_subtype) { break; } @@ -925,4 +940,4 @@ void ArgBinder::BindDLTensors( } } // namespace tl -} // namespace tvm \ No newline at end of file +} // namespace tvm diff --git a/tilelang/engine/param.py b/tilelang/engine/param.py index fe023f83f..98ef6f0e1 100644 --- a/tilelang/engine/param.py +++ b/tilelang/engine/param.py @@ -140,6 +140,15 @@ def torch_dtype(self) -> torch.dtype: """ return T.dtype(self.dtype).as_torch() + def tilelang_dtype(self) -> T.dtype: + """ + Converts the TVM DataType to TileLang dtype. + + Returns: + T.dtype: Corresponding TileLang dtype + """ + return T.dtype(self.dtype) + @dataclass class CompiledArtifact: diff --git a/tilelang/jit/adapter/tvm_ffi.py b/tilelang/jit/adapter/tvm_ffi.py index fdba92c21..cd473efad 100644 --- a/tilelang/jit/adapter/tvm_ffi.py +++ b/tilelang/jit/adapter/tvm_ffi.py @@ -19,6 +19,7 @@ from tilelang.jit.adapter.base import BaseKernelAdapter from tilelang.utils.language import retrieve_func_from_module from tilelang.engine.param import KernelParam +from tilelang.language.v2.dtypes import dtype class TVMFFIKernelAdapter(BaseKernelAdapter): @@ -149,6 +150,11 @@ def _convert_torch_func(self) -> Callable[..., Any]: native_shape.append(dim) # Keep tir.Var for dynamic dimensions else: native_shape.append(dim) + tl_dtype = param.dtype + if tl_dtype.bits < 8: + stroage_dtype: dtype = dtype(param.torch_dtype()) + # last dim divide by bits to get the actual shape + native_shape[-1] = native_shape[-1] * tl_dtype.bits * tl_dtype.lanes // (stroage_dtype.bits * stroage_dtype.lanes) param_shapes.append(native_shape) if self.executable is None: diff --git a/tilelang/language/v2/dtypes.py b/tilelang/language/v2/dtypes.py index 1649da6e4..a29c57ff9 100644 --- a/tilelang/language/v2/dtypes.py +++ b/tilelang/language/v2/dtypes.py @@ -79,7 +79,12 @@ def as_torch(self) -> torch.dtype: ... ] for dtype_name_tuple in _extended_torch_dtypes: dtype_name = dtype_name_tuple[0] - torch_dtype = getattr(torch, dtype_name, None) + torch_dtype = None + if dtype_name == "float4_e2m1fnx2": + torch_dtype = getattr(torch, "float4_e2m1fn_x2", None) + else: + torch_dtype = getattr(torch, dtype_name, None) + if torch_dtype is not None: _TORCH_DTYPE_TO_STR[torch_dtype] = dtype_name @@ -193,7 +198,7 @@ def __dtype_as_torch__(self: dtype) -> torch.dtype: return torch.float4_e2m1fn_x2 elif dtype_str == "float4_e2m1fn": logger.info("torch doesn't support float4_e2m1fn, using float4_e2m1fnx2 as storage dtype.") - return torch.float4_e2m1fn_x2 + return torch.float4_e2m1fn_x2 if hasattr(torch, "float4_e2m1fn_x2") else torch.int8 elif dtype_str in _STR_TO_TORCH_DTYPE: return _STR_TO_TORCH_DTYPE[dtype_str] From 49bc41d18d4c75f11ad0ed555044835ae9f73663 Mon Sep 17 00:00:00 2001 From: Zhiwen Mo Date: Wed, 24 Dec 2025 14:09:48 +0800 Subject: [PATCH 3/5] lint fix --- src/transform/arg_binder.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transform/arg_binder.cc b/src/transform/arg_binder.cc index badcfa620..4dc92f9d8 100644 --- a/src/transform/arg_binder.cc +++ b/src/transform/arg_binder.cc @@ -534,7 +534,7 @@ void ArgBinder::BindDLTensors( cast(DataType::UInt(64), expect_bits) * cast(DataType::UInt(64), expect_lanes) * cast(DataType::UInt(64), - buffer->shape.size() == 0 + buffer->shape.empty() ? make_const(DataType::UInt(64), 1) : foldl([](PrimExpr a, PrimExpr b, Span) { return a * b; }, make_const(DataType::UInt(64), 1), buffer->shape)); From d7efd1d605ce06aca3bb01e9e5272e111dfb918f Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Wed, 24 Dec 2025 14:41:15 +0800 Subject: [PATCH 4/5] lint fix --- .../python/language/test_tilelang_language_vectorized_cast.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/testing/python/language/test_tilelang_language_vectorized_cast.py b/testing/python/language/test_tilelang_language_vectorized_cast.py index 2710ee685..33d40e679 100644 --- a/testing/python/language/test_tilelang_language_vectorized_cast.py +++ b/testing/python/language/test_tilelang_language_vectorized_cast.py @@ -135,5 +135,4 @@ def test_vectorized_cast_fp4(src_dtype, dst_dtype, check_str, lanes): if __name__ == "__main__": - # tilelang.testing.main() - test_vectorized_cast_fp4(T.float4_e2m1fn, T.float32, "__tl_cvt_fp4x2_to_float2", 2) + tilelang.testing.main() From d25be3282e276fc97e1426953ae148d0826bc776 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Wed, 24 Dec 2025 14:46:06 +0800 Subject: [PATCH 5/5] typo fix --- examples/flash_decoding/example_gqa_decode_varlen_logits.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/flash_decoding/example_gqa_decode_varlen_logits.py b/examples/flash_decoding/example_gqa_decode_varlen_logits.py index 8f26a59c3..8e2db8727 100644 --- a/examples/flash_decoding/example_gqa_decode_varlen_logits.py +++ b/examples/flash_decoding/example_gqa_decode_varlen_logits.py @@ -197,7 +197,6 @@ def get_configs(): return configs -@autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit(out_idx=[-2, -1]) def flashattn( batch, heads, k_heads, max_seqlen_kv, total_seqlen_k, dim, has_sink, block_N=128, block_H=64, num_split=1, num_stages=1, threads=128