From 9923cd8bf451ed0af4311b0ce9b34fce903b3f03 Mon Sep 17 00:00:00 2001 From: nicunxiao Date: Tue, 21 Oct 2025 15:36:49 +0800 Subject: [PATCH 01/17] [BugFix] Correct direct copy from bf16 to fp8 --- src/op/copy.cc | 7 ++++- .../python/issue/test_tilelang_issue_1046.py | 28 +++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) create mode 100644 testing/python/issue/test_tilelang_issue_1046.py diff --git a/src/op/copy.cc b/src/op/copy.cc index a16d09dad..c5b39619e 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -324,8 +324,13 @@ For CopyNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1); PrimExpr value = BufferLoad(src, src_indices); - if (src->dtype != dst->dtype) + if (src->dtype != dst->dtype) { + // If dst is fp8 and src is bf16, first cast dst to fp32. + if (src->dtype.is_bfloat16() && dst->dtype.is_float8_e4m3()) { + value = Cast(DataType::Float(32), value); + } value = Cast(dst->dtype, value); + } if (src_predicate.defined()) value = if_then_else(src_predicate, value, make_zero(dst->dtype)); diff --git a/testing/python/issue/test_tilelang_issue_1046.py b/testing/python/issue/test_tilelang_issue_1046.py new file mode 100644 index 000000000..26a30bb1b --- /dev/null +++ b/testing/python/issue/test_tilelang_issue_1046.py @@ -0,0 +1,28 @@ +import tilelang +import tilelang.language as T + +tilelang.disable_cache() + +FP8 = "float8_e4m3" +BF16 = "bfloat16" + + +@tilelang.jit +def test_kernel(N, in_dtype=BF16, out_dtype=FP8): + M = T.symbolic("M") + blk_m = 128 + group_size = 128 + + @T.prim_func + def test_kernel_(X: T.Tensor[(M, N), in_dtype], Y: T.Tensor[(M, N), out_dtype]): + with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (pid_m, pid_n): + x_shared = T.alloc_shared((blk_m, group_size), in_dtype) + T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared) + T.copy(x_shared, Y[pid_m * blk_m, pid_n * group_size]) + + return test_kernel_ + + +kernel = test_kernel(128) + +print(kernel.get_kernel_source()) From cb907dd6092f3fa98eef30642b000ee91b83cc26 Mon Sep 17 00:00:00 2001 From: nicunxiao Date: Tue, 21 Oct 2025 15:39:16 +0800 Subject: [PATCH 02/17] fix lint --- testing/python/issue/test_tilelang_issue_1046.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/testing/python/issue/test_tilelang_issue_1046.py b/testing/python/issue/test_tilelang_issue_1046.py index 26a30bb1b..b7f15d2da 100644 --- a/testing/python/issue/test_tilelang_issue_1046.py +++ b/testing/python/issue/test_tilelang_issue_1046.py @@ -9,7 +9,7 @@ @tilelang.jit def test_kernel(N, in_dtype=BF16, out_dtype=FP8): - M = T.symbolic("M") + M = T.dynamic("M") blk_m = 128 group_size = 128 From 8281b05490b91b587b5378e57d41b71c6f4e9bce Mon Sep 17 00:00:00 2001 From: nicunxiao Date: Wed, 22 Oct 2025 02:42:00 +0800 Subject: [PATCH 03/17] implement overloaded cast codegen for type conversion --- src/op/copy.cc | 4 ---- src/target/codegen_cuda.cc | 5 ++++- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/op/copy.cc b/src/op/copy.cc index c5b39619e..fbebd2d81 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -325,10 +325,6 @@ For CopyNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { PrimExpr value = BufferLoad(src, src_indices); if (src->dtype != dst->dtype) { - // If dst is fp8 and src is bf16, first cast dst to fp32. - if (src->dtype.is_bfloat16() && dst->dtype.is_float8_e4m3()) { - value = Cast(DataType::Float(32), value); - } value = Cast(dst->dtype, value); } if (src_predicate.defined()) diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index d06e7170d..0fb9a7c60 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -953,12 +953,15 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { } } + const char *convert_part = + (from_ty.is_bfloat16() || target_ty.is_float8_e4m3()) ? ")(half)(" : ")("; + // Fallback: elementwise cast for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) { std::ostringstream val; val << "("; PrintType(target_ty.element_of(), val); - val << ")("; + val << convert_part; PrintVecElemLoad(src, from_ty, i, val); val << ")"; PrintVecElemStore(sret, target_ty, i, val.str()); From 3051cf4a0bba6a988990ef9c8fc4ceac839ed019 Mon Sep 17 00:00:00 2001 From: nicunxiao Date: Wed, 22 Oct 2025 11:25:12 +0800 Subject: [PATCH 04/17] fix lint --- src/op/copy.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/op/copy.cc b/src/op/copy.cc index fbebd2d81..a16d09dad 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -324,9 +324,8 @@ For CopyNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1); PrimExpr value = BufferLoad(src, src_indices); - if (src->dtype != dst->dtype) { + if (src->dtype != dst->dtype) value = Cast(dst->dtype, value); - } if (src_predicate.defined()) value = if_then_else(src_predicate, value, make_zero(dst->dtype)); From 37804b3750fb796a4d44786d8d33165bcc2e41a9 Mon Sep 17 00:00:00 2001 From: nicunxiao Date: Wed, 22 Oct 2025 12:17:30 +0800 Subject: [PATCH 05/17] remove test --- .../python/issue/test_tilelang_issue_1046.py | 28 ------------------- 1 file changed, 28 deletions(-) delete mode 100644 testing/python/issue/test_tilelang_issue_1046.py diff --git a/testing/python/issue/test_tilelang_issue_1046.py b/testing/python/issue/test_tilelang_issue_1046.py deleted file mode 100644 index b7f15d2da..000000000 --- a/testing/python/issue/test_tilelang_issue_1046.py +++ /dev/null @@ -1,28 +0,0 @@ -import tilelang -import tilelang.language as T - -tilelang.disable_cache() - -FP8 = "float8_e4m3" -BF16 = "bfloat16" - - -@tilelang.jit -def test_kernel(N, in_dtype=BF16, out_dtype=FP8): - M = T.dynamic("M") - blk_m = 128 - group_size = 128 - - @T.prim_func - def test_kernel_(X: T.Tensor[(M, N), in_dtype], Y: T.Tensor[(M, N), out_dtype]): - with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (pid_m, pid_n): - x_shared = T.alloc_shared((blk_m, group_size), in_dtype) - T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared) - T.copy(x_shared, Y[pid_m * blk_m, pid_n * group_size]) - - return test_kernel_ - - -kernel = test_kernel(128) - -print(kernel.get_kernel_source()) From 999e74e744a525e0d1e02e33b9cea3ba497853e4 Mon Sep 17 00:00:00 2001 From: nicunxiao Date: Wed, 22 Oct 2025 14:10:06 +0800 Subject: [PATCH 06/17] fix lint --- src/target/codegen_cuda.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 0fb9a7c60..296ca5655 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -954,7 +954,10 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { } const char *convert_part = - (from_ty.is_bfloat16() || target_ty.is_float8_e4m3()) ? ")(half)(" : ")("; + (from_ty.is_bfloat16() && + (target_ty.is_float8_e4m3() || target_ty.is_float8_e5m2())) + ? ")(half)(" + : ")("; // Fallback: elementwise cast for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) { From 900ae67a24f7caaf1a48ba78cf26de16f143b071 Mon Sep 17 00:00:00 2001 From: nicunxiao Date: Wed, 22 Oct 2025 20:20:54 +0800 Subject: [PATCH 07/17] trigger CI From 5c251475baf7453a30045a9710edc984ad68a4fb Mon Sep 17 00:00:00 2001 From: nicunxiao Date: Thu, 23 Oct 2025 13:33:55 +0800 Subject: [PATCH 08/17] Overload fp8 for implicit conversion --- src/target/codegen_cuda.cc | 9 ++------- src/tl_templates/cuda/common.h | 24 ++++++++++++++++++++++++ src/tl_templates/cuda/cuda_fp8.h | 5 +++-- 3 files changed, 29 insertions(+), 9 deletions(-) diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 296ca5655..e1950269c 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -953,23 +953,18 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { } } - const char *convert_part = - (from_ty.is_bfloat16() && - (target_ty.is_float8_e4m3() || target_ty.is_float8_e5m2())) - ? ")(half)(" - : ")("; - // Fallback: elementwise cast for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) { std::ostringstream val; val << "("; PrintType(target_ty.element_of(), val); - val << convert_part; + val << ")("; PrintVecElemLoad(src, from_ty, i, val); val << ")"; PrintVecElemStore(sret, target_ty, i, val.str()); } + if (used_bf16_op) { stream << "#endif\n"; } diff --git a/src/tl_templates/cuda/common.h b/src/tl_templates/cuda/common.h index 34a30821b..cfa65ba0e 100644 --- a/src/tl_templates/cuda/common.h +++ b/src/tl_templates/cuda/common.h @@ -10,6 +10,9 @@ #include #include +#include +#include + using cutlass::bfloat16_t; using cutlass::half_t; using cutlass::tfloat32_t; @@ -318,6 +321,27 @@ TL_DEVICE void increase_descriptor_offset(GmmaDescriptor &descriptor, descriptor.reg32_[0] += (offset >> 4); } +// and add the desired implicit conversion from bfloat16_t. +struct float_e4m3_t : public cutlass::float_e4m3_t { + using cutlass::float_e4m3_t::float_e4m3_t; + CUTLASS_HOST_DEVICE + float_e4m3_t() = default; + + CUTLASS_HOST_DEVICE + explicit float_e4m3_t(__nv_bfloat16 x) : float_e4m3_t(static_cast(x)) { + } +}; + +struct float_e5m2_t : public cutlass::float_e5m2_t { + using cutlass::float_e5m2_t::float_e5m2_t; + CUTLASS_HOST_DEVICE + float_e5m2_t() = default; + + CUTLASS_HOST_DEVICE + explicit float_e5m2_t(__nv_bfloat16 x) : float_e5m2_t(static_cast(x)) { + } +}; + } // namespace tl namespace cutlass { diff --git a/src/tl_templates/cuda/cuda_fp8.h b/src/tl_templates/cuda/cuda_fp8.h index 8d2165822..b161f3a7f 100644 --- a/src/tl_templates/cuda/cuda_fp8.h +++ b/src/tl_templates/cuda/cuda_fp8.h @@ -2,9 +2,10 @@ #include #include +#include "common.h" -using fp8_e4_t = cute::float_e4m3_t; -using fp8_e5_t = cute::float_e5m2_t; +using fp8_e4_t = tl::float_e4m3_t; +using fp8_e5_t = tl::float_e5m2_t; struct __CUDA_ALIGN__(2) fp8_e4_2_t { fp8_e4_t x; From c49edeeb7318a80e17388dd12d51bd31f15f4981 Mon Sep 17 00:00:00 2001 From: nicunxiao Date: Thu, 23 Oct 2025 13:35:44 +0800 Subject: [PATCH 09/17] format --- src/target/codegen_cuda.cc | 1 - src/tl_templates/cuda/common.h | 10 +++++----- src/tl_templates/cuda/cuda_fp8.h | 2 +- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index e1950269c..d06e7170d 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -964,7 +964,6 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { PrintVecElemStore(sret, target_ty, i, val.str()); } - if (used_bf16_op) { stream << "#endif\n"; } diff --git a/src/tl_templates/cuda/common.h b/src/tl_templates/cuda/common.h index cfa65ba0e..d47b59ef8 100644 --- a/src/tl_templates/cuda/common.h +++ b/src/tl_templates/cuda/common.h @@ -10,8 +10,8 @@ #include #include -#include #include +#include using cutlass::bfloat16_t; using cutlass::half_t; @@ -328,8 +328,8 @@ struct float_e4m3_t : public cutlass::float_e4m3_t { float_e4m3_t() = default; CUTLASS_HOST_DEVICE - explicit float_e4m3_t(__nv_bfloat16 x) : float_e4m3_t(static_cast(x)) { - } + explicit float_e4m3_t(__nv_bfloat16 x) + : float_e4m3_t(static_cast(x)) {} }; struct float_e5m2_t : public cutlass::float_e5m2_t { @@ -338,8 +338,8 @@ struct float_e5m2_t : public cutlass::float_e5m2_t { float_e5m2_t() = default; CUTLASS_HOST_DEVICE - explicit float_e5m2_t(__nv_bfloat16 x) : float_e5m2_t(static_cast(x)) { - } + explicit float_e5m2_t(__nv_bfloat16 x) + : float_e5m2_t(static_cast(x)) {} }; } // namespace tl diff --git a/src/tl_templates/cuda/cuda_fp8.h b/src/tl_templates/cuda/cuda_fp8.h index b161f3a7f..2efb8f111 100644 --- a/src/tl_templates/cuda/cuda_fp8.h +++ b/src/tl_templates/cuda/cuda_fp8.h @@ -1,8 +1,8 @@ #pragma once +#include "common.h" #include #include -#include "common.h" using fp8_e4_t = tl::float_e4m3_t; using fp8_e5_t = tl::float_e5m2_t; From 0aad65168ece3eba65f0c7455f387d3aef8d6aa8 Mon Sep 17 00:00:00 2001 From: nicunxiao Date: Thu, 23 Oct 2025 13:50:39 +0800 Subject: [PATCH 10/17] new format --- tilelang/language/allocate.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tilelang/language/allocate.py b/tilelang/language/allocate.py index 55e1fdfd5..de86f270a 100644 --- a/tilelang/language/allocate.py +++ b/tilelang/language/allocate.py @@ -1,3 +1,4 @@ +from __future__ import annotations """Memory allocation utilities for Tile-AI programs. This module provides a set of functions for allocating different types of memory buffers @@ -67,7 +68,7 @@ def alloc_fragment(shape, dtype, scope="local.fragment"): return T.alloc_buffer(shape, dtype, scope=scope) -def alloc_var(dtype, *args, scope="local.var", init: Union[PrimExpr] = None): +def alloc_var(dtype, *args, scope="local.var", init: Union[PrimExpr] = None): # noqa: UP007 """Allocate a single-element variable buffer. Args: From e448754145105e95d52fad27db96c730ba9284ee Mon Sep 17 00:00:00 2001 From: nicunxiao Date: Thu, 23 Oct 2025 15:14:14 +0800 Subject: [PATCH 11/17] fix: Reinterpret types to cute types in GEMM --- src/tl_templates/cuda/common.h | 8 ++++---- src/tl_templates/cuda/gemm_mma.h | 10 ++++++++-- src/tl_templates/cuda/gemm_sm90.h | 13 +++++++++---- 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/src/tl_templates/cuda/common.h b/src/tl_templates/cuda/common.h index 8c6e00cf7..9ed0b26f9 100644 --- a/src/tl_templates/cuda/common.h +++ b/src/tl_templates/cuda/common.h @@ -322,8 +322,8 @@ TL_DEVICE void increase_descriptor_offset(GmmaDescriptor &descriptor, } // and add the desired implicit conversion from bfloat16_t. -struct float_e4m3_t : public cutlass::float_e4m3_t { - using cutlass::float_e4m3_t::float_e4m3_t; +struct float_e4m3_t : public cute::float_e4m3_t { + using cute::float_e4m3_t::float_e4m3_t; CUTLASS_HOST_DEVICE float_e4m3_t() = default; @@ -332,8 +332,8 @@ struct float_e4m3_t : public cutlass::float_e4m3_t { : float_e4m3_t(static_cast(x)) {} }; -struct float_e5m2_t : public cutlass::float_e5m2_t { - using cutlass::float_e5m2_t::float_e5m2_t; +struct float_e5m2_t : public cute::float_e5m2_t { + using cute::float_e5m2_t::float_e5m2_t; CUTLASS_HOST_DEVICE float_e5m2_t() = default; diff --git a/src/tl_templates/cuda/gemm_mma.h b/src/tl_templates/cuda/gemm_mma.h index 9462514f8..4445688ad 100644 --- a/src/tl_templates/cuda/gemm_mma.h +++ b/src/tl_templates/cuda/gemm_mma.h @@ -257,18 +257,24 @@ struct OperandTraits<64, N, K, false, num_warp_n, leading_dim, using Copy = DefaultCopy; }; +template struct to_cute_type {using type = T;}; +template<> struct to_cute_type {using type = cute::float_e4m3_t;}; +template<> struct to_cute_type {using type = cute::float_e5m2_t;}; + template class GemmTensorOp { public: + using A_type_cute = typename to_cute_type::type; + using B_type_cute = typename to_cute_type::type; using A_type = - typename std::conditional::value, + typename std::conditional::value, tfloat32_t, A_type_raw>::type; using B_type = typename std::conditional::value, - tfloat32_t, A_type_raw>::type; + tfloat32_t, B_type_cute>::type; using C_type = C_type_raw; using Instruction = diff --git a/src/tl_templates/cuda/gemm_sm90.h b/src/tl_templates/cuda/gemm_sm90.h index 1aa3ecff9..a270559ee 100644 --- a/src/tl_templates/cuda/gemm_sm90.h +++ b/src/tl_templates/cuda/gemm_sm90.h @@ -15,16 +15,21 @@ using namespace SM90; namespace tl_wgmma { using namespace cutlass::gemm::collective::detail; // ss_smem_selector +template struct to_cute_type {using type = T;}; +template<> struct to_cute_type {using type = cute::float_e4m3_t;}; +template<> struct to_cute_type {using type = cute::float_e5m2_t;}; template class GemmTensorOp { public: - using A_type = conditional_t::value, - tfloat32_t, A_type_raw>; - using B_type = conditional_t::value, - tfloat32_t, B_type_raw>; + using A_type_cute = typename to_cute_type::type; + using B_type_cute = typename to_cute_type::type; + using A_type = conditional_t::value, + tfloat32_t, A_type_cute>; + using B_type = conditional_t::value, + tfloat32_t, A_type_cute>; using C_type = C_type_raw; static constexpr GMMA::Major GmmaMajorA = From dad541c1c904262709ddb369c16eb05917e62d30 Mon Sep 17 00:00:00 2001 From: nicunxiao Date: Thu, 23 Oct 2025 15:15:16 +0800 Subject: [PATCH 12/17] new format --- src/tl_templates/cuda/gemm_mma.h | 12 +++++++++--- src/tl_templates/cuda/gemm_sm90.h | 12 +++++++++--- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/src/tl_templates/cuda/gemm_mma.h b/src/tl_templates/cuda/gemm_mma.h index 4445688ad..17bcddd43 100644 --- a/src/tl_templates/cuda/gemm_mma.h +++ b/src/tl_templates/cuda/gemm_mma.h @@ -257,9 +257,15 @@ struct OperandTraits<64, N, K, false, num_warp_n, leading_dim, using Copy = DefaultCopy; }; -template struct to_cute_type {using type = T;}; -template<> struct to_cute_type {using type = cute::float_e4m3_t;}; -template<> struct to_cute_type {using type = cute::float_e5m2_t;}; +template struct to_cute_type { + using type = T; +}; +template <> struct to_cute_type { + using type = cute::float_e4m3_t; +}; +template <> struct to_cute_type { + using type = cute::float_e5m2_t; +}; template struct to_cute_type {using type = T;}; -template<> struct to_cute_type {using type = cute::float_e4m3_t;}; -template<> struct to_cute_type {using type = cute::float_e5m2_t;}; +template struct to_cute_type { + using type = T; +}; +template <> struct to_cute_type { + using type = cute::float_e4m3_t; +}; +template <> struct to_cute_type { + using type = cute::float_e5m2_t; +}; template Date: Thu, 23 Oct 2025 16:46:02 +0800 Subject: [PATCH 13/17] fix lint --- src/tl_templates/cuda/common.h | 10 ++++++++++ src/tl_templates/cuda/gemm_mma.h | 14 ++------------ src/tl_templates/cuda/gemm_sm100.h | 10 ++++++---- src/tl_templates/cuda/gemm_sm90.h | 13 ++----------- src/tl_templates/cuda/gemm_sp_sm90.h | 10 ++++++---- 5 files changed, 26 insertions(+), 31 deletions(-) diff --git a/src/tl_templates/cuda/common.h b/src/tl_templates/cuda/common.h index 9ed0b26f9..524149f97 100644 --- a/src/tl_templates/cuda/common.h +++ b/src/tl_templates/cuda/common.h @@ -342,6 +342,16 @@ struct float_e5m2_t : public cute::float_e5m2_t { : float_e5m2_t(static_cast(x)) {} }; +template struct to_cute_type { + using type = T; +}; +template <> struct to_cute_type { + using type = cute::float_e4m3_t; +}; +template <> struct to_cute_type { + using type = cute::float_e5m2_t; +}; + } // namespace tl namespace cutlass { diff --git a/src/tl_templates/cuda/gemm_mma.h b/src/tl_templates/cuda/gemm_mma.h index 17bcddd43..025d99662 100644 --- a/src/tl_templates/cuda/gemm_mma.h +++ b/src/tl_templates/cuda/gemm_mma.h @@ -257,24 +257,14 @@ struct OperandTraits<64, N, K, false, num_warp_n, leading_dim, using Copy = DefaultCopy; }; -template struct to_cute_type { - using type = T; -}; -template <> struct to_cute_type { - using type = cute::float_e4m3_t; -}; -template <> struct to_cute_type { - using type = cute::float_e5m2_t; -}; - template class GemmTensorOp { public: - using A_type_cute = typename to_cute_type::type; - using B_type_cute = typename to_cute_type::type; + using A_type_cute = typename tl::to_cute_type::type; + using B_type_cute = typename tl::to_cute_type::type; using A_type = typename std::conditional::value, tfloat32_t, A_type_raw>::type; diff --git a/src/tl_templates/cuda/gemm_sm100.h b/src/tl_templates/cuda/gemm_sm100.h index 5b50fe72a..856d37dd1 100644 --- a/src/tl_templates/cuda/gemm_sm100.h +++ b/src/tl_templates/cuda/gemm_sm100.h @@ -289,12 +289,14 @@ template class GemmTensorOp { public: + using A_type_cute = typename tl::to_cute_type::type; + using B_type_cute = typename tl::to_cute_type::type; using A_type = - typename std::conditional::value, - tfloat32_t, A_type_raw>::type; + typename std::conditional::value, + tfloat32_t, A_type_cute>::type; using B_type = - typename std::conditional::value, - tfloat32_t, B_type_raw>::type; + typename std::conditional::value, + tfloat32_t, B_type_cute>::type; using C_type = C_type_raw; static_assert(AtomM == 128 || AtomM == 64 || AtomM == 32); diff --git a/src/tl_templates/cuda/gemm_sm90.h b/src/tl_templates/cuda/gemm_sm90.h index eb74de8f0..543a29d09 100644 --- a/src/tl_templates/cuda/gemm_sm90.h +++ b/src/tl_templates/cuda/gemm_sm90.h @@ -15,23 +15,14 @@ using namespace SM90; namespace tl_wgmma { using namespace cutlass::gemm::collective::detail; // ss_smem_selector -template struct to_cute_type { - using type = T; -}; -template <> struct to_cute_type { - using type = cute::float_e4m3_t; -}; -template <> struct to_cute_type { - using type = cute::float_e5m2_t; -}; template class GemmTensorOp { public: - using A_type_cute = typename to_cute_type::type; - using B_type_cute = typename to_cute_type::type; + using A_type_cute = typename tl::to_cute_type::type; + using B_type_cute = typename tl::to_cute_type::type; using A_type = conditional_t::value, tfloat32_t, A_type_cute>; using B_type = conditional_t::value, diff --git a/src/tl_templates/cuda/gemm_sp_sm90.h b/src/tl_templates/cuda/gemm_sp_sm90.h index db55a21ec..6184f9be7 100644 --- a/src/tl_templates/cuda/gemm_sp_sm90.h +++ b/src/tl_templates/cuda/gemm_sp_sm90.h @@ -13,10 +13,12 @@ class GemmTensorOp { public: static_assert(num_warp_m % 4 == 0, "num_warp_m must be a multiple of 4"); - using A_type = conditional_t::value, - tfloat32_t, A_type_raw>; - using B_type = conditional_t::value, - tfloat32_t, B_type_raw>; + using A_type_cute = typename tl::to_cute_type::type; + using B_type_cute = typename tl::to_cute_type::type; + using A_type = conditional_t::value, + tfloat32_t, A_type_cute>; + using B_type = conditional_t::value, + tfloat32_t, B_type_cute>; using C_type = C_type_raw; static constexpr bool need_tfloat32_cast = From 6d885a4d952e45a55c7afc1edd5b10144063660d Mon Sep 17 00:00:00 2001 From: nicunxiao Date: Fri, 24 Oct 2025 01:05:27 +0800 Subject: [PATCH 14/17] new format --- tilelang/language/allocate.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tilelang/language/allocate.py b/tilelang/language/allocate.py index 11fce9dfe..2c8fbb297 100644 --- a/tilelang/language/allocate.py +++ b/tilelang/language/allocate.py @@ -15,7 +15,6 @@ with the appropriate memory scope. """ -from __future__ import annotations from tilelang import tvm as tvm from tvm.script import tir as T from tvm.tir import PrimExpr From 7f1a507e0daa5b737a4ae2a70cf85be3dc8592b7 Mon Sep 17 00:00:00 2001 From: nicunxiao Date: Sat, 25 Oct 2025 17:52:19 +0800 Subject: [PATCH 15/17] fix lint --- src/tl_templates/cuda/gemm_mma.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tl_templates/cuda/gemm_mma.h b/src/tl_templates/cuda/gemm_mma.h index 025d99662..c22854c0b 100644 --- a/src/tl_templates/cuda/gemm_mma.h +++ b/src/tl_templates/cuda/gemm_mma.h @@ -267,9 +267,9 @@ class GemmTensorOp { using B_type_cute = typename tl::to_cute_type::type; using A_type = typename std::conditional::value, - tfloat32_t, A_type_raw>::type; + tfloat32_t, A_type_cute>::type; using B_type = - typename std::conditional::value, + typename std::conditional::value, tfloat32_t, B_type_cute>::type; using C_type = C_type_raw; From 521c0bb8905d6c5d516ea53313021f3f20f90f4d Mon Sep 17 00:00:00 2001 From: nicunxiao Date: Wed, 29 Oct 2025 12:30:59 +0800 Subject: [PATCH 16/17] format --- tilelang/language/allocate.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tilelang/language/allocate.py b/tilelang/language/allocate.py index 64d20ccb2..445e212ac 100644 --- a/tilelang/language/allocate.py +++ b/tilelang/language/allocate.py @@ -1,4 +1,3 @@ -from __future__ import annotations """Memory allocation utilities for Tile-AI programs. This module provides a set of functions for allocating different types of memory buffers From e9b95011f74782cbdc258498e1c3fde667001df3 Mon Sep 17 00:00:00 2001 From: nicunxiao Date: Wed, 29 Oct 2025 13:33:08 +0800 Subject: [PATCH 17/17] trigger ci