diff --git a/examples/attention_sink/example_gqa_sink_bwd_bhsd.py b/examples/attention_sink/example_gqa_sink_bwd_bhsd.py
index e465d946c..f8f970ea4 100644
--- a/examples/attention_sink/example_gqa_sink_bwd_bhsd.py
+++ b/examples/attention_sink/example_gqa_sink_bwd_bhsd.py
@@ -20,11 +20,9 @@ def get_bwd_configs():
 
 
 @tilelang.jit(
-    out_idx=[3, 4],
-    pass_configs={
+    out_idx=[3, 4], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn_fwd(
     batch,
     heads,
@@ -140,11 +138,9 @@ def flash_fwd(
 
 
 @tilelang.jit(
-    out_idx=[2],
-    pass_configs={
+    out_idx=[2], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
     accum_dtype = "float"
     shape = [batch, heads, seq_len, dim]
@@ -180,11 +176,9 @@ def make_dq_layout(dQ):
 
 
 @tilelang.jit(
-    out_idx=[1],
-    pass_configs={
+    out_idx=[1], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
     accum_dtype = "float"
     shape = [batch, heads, seq_len, dim]
@@ -205,11 +199,9 @@ def flash_bwd_post(
     return flash_bwd_post
 
 
-@tilelang.jit(
-    pass_configs={
-        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+@tilelang.jit(pass_configs={
+    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+})
 def flashattn_bwd(batch,
                   heads,
                   seq_len,
diff --git a/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py b/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py
index c33d5829b..49a3ecbd8 100644
--- a/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py
+++ b/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py
@@ -23,11 +23,9 @@ def get_configs():
     rep=100,
 )
 @tilelang.jit(
-    out_idx=[3],
-    pass_configs={
+    out_idx=[3], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn(
     batch,
     heads,
diff --git a/examples/attention_sink/example_mha_sink_bwd_bhsd.py b/examples/attention_sink/example_mha_sink_bwd_bhsd.py
index 3c99a89ea..ee1c35ece 100644
--- a/examples/attention_sink/example_mha_sink_bwd_bhsd.py
+++ b/examples/attention_sink/example_mha_sink_bwd_bhsd.py
@@ -20,11 +20,9 @@ def get_bwd_configs():
 
 
 @tilelang.jit(
-    out_idx=[3, 4],
-    pass_configs={
+    out_idx=[3, 4], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn_fwd(
     batch,
     heads,
@@ -137,11 +135,9 @@ def flash_fwd(
 
 
 @tilelang.jit(
-    out_idx=[2],
-    pass_configs={
+    out_idx=[2], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
     accum_dtype = "float"
     shape = [batch, heads, seq_len, dim]
@@ -177,11 +173,9 @@ def make_dq_layout(dQ):
 
 
 @tilelang.jit(
-    out_idx=[1],
-    pass_configs={
+    out_idx=[1], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
     accum_dtype = "float"
     shape = [batch, heads, seq_len, dim]
@@ -202,11 +196,9 @@ def flash_bwd_post(
     return flash_bwd_post
 
 
-@tilelang.jit(
-    pass_configs={
-        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+@tilelang.jit(pass_configs={
+    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+})
 def flashattn_bwd(
     batch,
     heads,
diff --git a/examples/attention_sink/example_mha_sink_fwd_bhsd.py b/examples/attention_sink/example_mha_sink_fwd_bhsd.py
index dec823102..7e59e277e 100644
--- a/examples/attention_sink/example_mha_sink_fwd_bhsd.py
+++ b/examples/attention_sink/example_mha_sink_fwd_bhsd.py
@@ -18,11 +18,9 @@ def get_configs():
 
 @autotune(configs=get_configs(), warmup=500, rep=100)
 @tilelang.jit(
-    out_idx=[3],
-    pass_configs={
+    out_idx=[3], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn(
     batch,
     heads,
diff --git a/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py b/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py
index 2936a9acd..eee2f3ac5 100644
--- a/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py
+++ b/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py
@@ -19,11 +19,9 @@ def get_configs():
 
 @autotune(configs=get_configs(), warmup=500, rep=100)
 @tilelang.jit(
-    out_idx=[3],
-    pass_configs={
+    out_idx=[3], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn(
     batch,
     heads,
diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc
index fdca036d2..e621276e9 100644
--- a/src/target/codegen_cuda.cc
+++ b/src/target/codegen_cuda.cc
@@ -900,56 +900,123 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
   stream << ' ' << sret << ";\n";
   std::string src = SSAGetID(PrintExpr(op->value), from_ty);
 
-  // Handle bfloat16 special cases with supported ops
-  bool used_bf16_op = false;
-  if (from_ty.is_bfloat16() || target_ty.is_bfloat16()) {
-    std::ostringstream func_name;
-    if (from_ty.is_bfloat16()) {
-      func_name << "bf16";
-    } else if (from_ty.is_float()) {
-      func_name << "float";
-    }
-    if (from_ty.lanes() > 1) {
-      func_name << from_ty.lanes();
-    }
-    func_name << "2";
-    if (target_ty.is_bfloat16()) {
-      func_name << "bf16";
-    } else if (target_ty.is_float()) {
-      func_name << "float";
-    } else if (target_ty == DataType::Int(16)) {
-      func_name << "int16";
-    }
-    if (target_ty.lanes() > 1) {
-      func_name << target_ty.lanes();
-    }
-
-    auto fname = func_name.str();
-    if (bf16_supported_ops_.count(fname)) {
-      used_bf16_op = true;
-      stream << "#ifdef ENABLE_BF16\n";
+  // Handle conversion between float16 and float32
+  if (from_ty.is_float16() && target_ty.is_float()) {
+    // Use __half22float2 for vectorized conversion (half2 -> float2)
+    if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
+      // half2 -> float2
       PrintIndent();
-      stream << "reinterpret_cast<";
-      if (target_ty.is_bfloat16()) {
-        stream << "__nv_bfloat16";
-      } else {
-        PrintType(target_ty.element_of(), stream);
-      }
-      if (target_ty.lanes() > 1) {
-        stream << target_ty.lanes();
-      }
-      stream << " &>(" << sret << ") = fastertransformer::" << fname
-             << "(reinterpret_cast<";
-      if (from_ty.is_bfloat16()) {
-        stream << "__nv_bfloat16";
-      } else {
-        PrintType(from_ty.element_of(), stream);
-      }
-      if (from_ty.lanes() > 1) {
-        stream << from_ty.lanes();
-      }
-      stream << " const &>(" << src << "));\n";
-      stream << "#else\n";
+      stream << sret << " = __half22float2(*(half2*)(&(" << src << ")));\n";
+      os << sret;
+      return;
+    } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {
+      // half4 -> float4
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[0] = "
+             << "__half22float2(*(half2*)(&(" << src << ")));\n";
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[1] = "
+             << "__half22float2(*((half2*)(&(" << src << "))+1));\n";
+      os << sret;
+      return;
+    }
+  } else if (from_ty.is_float() && target_ty.is_float16()) {
+    // Use __float22half2_rn for vectorized conversion (float2 -> half2)
+    if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
+      // float2 -> half2
+      PrintIndent();
+      stream << "*(half2*)(&(" << sret << ")) = __float22half2_rn(*(float2*)(&("
+             << src << ")));\n";
+      os << sret;
+      return;
+    } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {
+      // float4 -> half4
+      PrintIndent();
+      stream << "((half2*)(&" << sret << "))[0] = "
+             << "__float22half2_rn(*(float2*)(&(" << src << ")));\n";
+      PrintIndent();
+      stream << "((half2*)(&" << sret << "))[1] = "
+             << "__float22half2_rn(*((float2*)(&(" << src << "))+1));\n";
+      os << sret;
+      return;
+    }
+  }
+
+  // Handle conversion between bfloat16 and float32
+  if (from_ty.is_bfloat16() && target_ty.is_float()) {
+    // Use __bfloat1622float2 for vectorized conversion (bfloat162 -> float2)
+    if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
+      // bfloat162 -> float2
+      PrintIndent();
+      stream << sret
+             << " = __bfloat1622float2(*reinterpret_cast<__nv_bfloat162*>(&("
+             << src << ")));\n";
+      os << sret;
+      return;
+    } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {
+      // bfloat162x2 -> float4
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[0] = "
+             << "__bfloat1622float2(*reinterpret_cast<__nv_bfloat162*>(&("
+             << src << ")));\n";
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[1] = "
+             << "__bfloat1622float2(*(reinterpret_cast<__nv_bfloat162*>(&("
+             << src << "))+1));\n";
+      os << sret;
+      return;
+    }
+  } else if (from_ty.is_float() && target_ty.is_bfloat16()) {
+    // Use __float22bfloat162_rn for vectorized conversion (float2 -> bfloat162)
+    if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
+      // float2 -> bfloat162
+      PrintIndent();
+      stream << "*reinterpret_cast<__nv_bfloat162*>(&(" << sret
+             << ")) = __float22bfloat162_rn(*(float2*)(&(" << src << ")));\n";
+      os << sret;
+      return;
+    } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {
+      // float4 -> bfloat162x2
+      PrintIndent();
+      stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[0] = "
+             << "__float22bfloat162_rn(*(float2*)(&(" << src << ")));\n";
+      PrintIndent();
+      stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[1] = "
+             << "__float22bfloat162_rn(*((float2*)(&(" << src << "))+1));\n";
+      os << sret;
+      return;
+    }
+  }
+
+  // Handle conversion from float32 to float8 (E4M3/E5M2)
+  if (from_ty.is_float() &&
+      (target_ty.is_float8_e4m3() || target_ty.is_float8_e5m2())) {
+    // FP32 -> FP8: Use __nv_cvt_float2_to_fp8x2 for vectorized conversion
+    // (float2 -> fp8x2)
+    if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
+      // float2 -> fp8x2
+      PrintIndent();
+      stream << "*reinterpret_cast<__nv_fp8x2_storage_t*>(&(" << sret
+             << ")) = __nv_cvt_float2_to_fp8x2(*reinterpret_cast<float2*>(&("
+             << src << ")), __NV_SATFINITE, "
+             << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+             << ");\n";
+      os << sret;
+      return;
+    } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {
+      // float4 -> fp8x4
+      PrintIndent();
+      stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[0] = "
+             << "__nv_cvt_float2_to_fp8x2(*(float2*)(&(" << src
+             << ")), __NV_SATFINITE, "
+             << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+             << ");\n";
+      PrintIndent();
+      stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[1] = "
+             << "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src
+             << "))+1), __NV_SATFINITE, "
+             << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+             << ");\n";
     }
   }
 
@@ -964,9 +1031,6 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
     PrintVecElemStore(sret, target_ty, i, val.str());
   }
 
-  if (used_bf16_op) {
-    stream << "#endif\n";
-  }
   os << sret;
 }
 
diff --git a/testing/python/language/test_tilelang_language_vectorized_cast.py b/testing/python/language/test_tilelang_language_vectorized_cast.py
new file mode 100644
index 000000000..a1777c79f
--- /dev/null
+++ b/testing/python/language/test_tilelang_language_vectorized_cast.py
@@ -0,0 +1,81 @@
+import torch
+import tilelang.testing
+import tilelang.language as T
+
+str2dtype = {
+    "float32": torch.float32,
+    "float16": torch.float16,
+    "bfloat16": torch.bfloat16,
+    "float8_e4m3": torch.float8_e4m3fn,
+    "float8_e5m2": torch.float8_e5m2,
+}
+
+
+@tilelang.jit
+def vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str):
+    assert M % 256 == 0
+
+    @T.prim_func
+    def main(
+            A: T.Tensor[(M), dtype_A],  # noqa: F821
+            B: T.Tensor[(M), dtype_B],  # noqa: F821
+    ):
+        with T.Kernel(1, threads=128):
+            T.copy(A, B)
+
+    return main
+
+
+def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str, lanes: int = 2):
+    """Run the vectorized cast kernel and check the correctness.
+    Args:
+        src_dtype_str: The source data type string.
+        dst_dtype_str: The destination data type string.
+        check_str: Used to ensure vectorized cast is used.
+        lanes: The number of lanes of the source and destination data types.
+    """
+
+    M = 128 * lanes
+    kernel = vectorized_cast_kernel(M, src_dtype_str, dst_dtype_str)
+
+    A = torch.randn(M, dtype=str2dtype[src_dtype_str]).cuda()
+    B = torch.zeros(M, dtype=str2dtype[dst_dtype_str]).cuda()
+
+    kernel(A, B)
+
+    torch.testing.assert_close(A.to(str2dtype[dst_dtype_str]), B)
+
+    code = kernel.get_kernel_source()
+
+    assert check_str in code, \
+        f"Cast {src_dtype_str} to {dst_dtype_str} with {lanes=} is not vectorized!"
+
+
+def test_vectorized_cast():
+    # fp32 -> fp16
+    run_vectorized_cast("float32", "float16", "__float22half2_rn", 2)
+    run_vectorized_cast("float32", "float16", "__float22half2_rn", 4)
+
+    # fp16 -> fp32
+    run_vectorized_cast("float16", "float32", "__half22float2", 2)
+    run_vectorized_cast("float16", "float32", "__half22float2", 4)
+
+    # fp32 -> fp8_e4m3
+    run_vectorized_cast("float32", "float8_e4m3", "__nv_cvt_float2_to_fp8x2", 2)
+    run_vectorized_cast("float32", "float8_e4m3", "__nv_cvt_float2_to_fp8x2", 4)
+
+    # fp32 -> fp8_e5m2
+    run_vectorized_cast("float32", "float8_e5m2", "__nv_cvt_float2_to_fp8x2", 2)
+    run_vectorized_cast("float32", "float8_e5m2", "__nv_cvt_float2_to_fp8x2", 4)
+
+    # fp32 -> bf16
+    run_vectorized_cast("float32", "bfloat16", "__float22bfloat162_rn", 2)
+    run_vectorized_cast("float32", "bfloat16", "__float22bfloat162_rn", 4)
+
+    # bf16 -> fp32
+    run_vectorized_cast("bfloat16", "float32", "__bfloat1622float2", 2)
+    run_vectorized_cast("bfloat16", "float32", "__bfloat1622float2", 4)
+
+
+if __name__ == "__main__":
+    tilelang.testing.main()