Commit a148d62

[Feature] Enhance vectorized conversion support in CUDA codegen (#1095)
* [Feature] Add vectorized float16 and float32 conversion support in CUDA codegen
  * Implemented handling for conversions between float16 and float32 types, specifically for vectorized operations using __half22float2 and __float22half2_rn.
  * Enhanced the existing code to support both directions of conversion based on the lane count.
  * Improved overall type handling in the VisitExpr_ method for better compatibility with TileLang.
* [Feature] Add float32 to float8 conversion support in CUDA codegen
  * Implemented handling for conversion from float32 to float8 (E4M3/E5M2) in the VisitExpr_ method.
  * Added vectorized conversion support using __nv_cvt_float2_to_fp8x2 for float2 to fp8x2 transformations.
  * Enhanced type handling for better compatibility with TileLang, particularly for float8 types.
* lint
* fix a bug
* [Enhancement] Support lanes=4 cases and add a unit test for vectorized cast
* lint
* [Feature] Refactor bf16 conversion operations and remove legacy compile flags
* lint
1 parent 86c8bb4 commit a148d62
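For orientation before the diff, here is a minimal CUDA sketch of the vectorized fp16 <-> fp32 cast pattern the updated codegen emits for the lanes == 2 case. Only the __half22float2 / __float22half2_rn intrinsics and the half2/float2 reinterpret pattern come from the change itself; the kernel names, indexing, and bounds check are illustrative assumptions, not code from this commit.

// Sketch only: mirrors the half2/float2 pattern generated by codegen_cuda.cc below.
// Kernel names and launch shapes are hypothetical.
#include <cuda_fp16.h>

__global__ void cast_fp16_to_fp32_x2(const __half *in, float *out, int n) {
  int i = 2 * (blockIdx.x * blockDim.x + threadIdx.x);  // two lanes per thread
  if (i + 1 < n) {
    // lanes == 2 path: one half2 load, one __half22float2 conversion
    float2 v = __half22float2(*reinterpret_cast<const __half2 *>(in + i));
    reinterpret_cast<float2 *>(out + i)[0] = v;
  }
}

__global__ void cast_fp32_to_fp16_x2(const float *in, __half *out, int n) {
  int i = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
  if (i + 1 < n) {
    // reverse direction: float2 -> half2 with round-to-nearest
    __half2 v = __float22half2_rn(*reinterpret_cast<const float2 *>(in + i));
    reinterpret_cast<__half2 *>(out + i)[0] = v;
  }
}

The lanes == 4 branches in the diff simply repeat the same two-lane conversion for the low and high halves of the vector.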

File tree

7 files changed (+221, -98 lines)


examples/attention_sink/example_gqa_sink_bwd_bhsd.py

Lines changed: 9 additions & 17 deletions
@@ -20,11 +20,9 @@ def get_bwd_configs():
 
 
 @tilelang.jit(
-    out_idx=[3, 4],
-    pass_configs={
+    out_idx=[3, 4], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn_fwd(
     batch,
     heads,
@@ -140,11 +138,9 @@ def flash_fwd(
 
 
 @tilelang.jit(
-    out_idx=[2],
-    pass_configs={
+    out_idx=[2], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
     accum_dtype = "float"
     shape = [batch, heads, seq_len, dim]
@@ -180,11 +176,9 @@ def make_dq_layout(dQ):
 
 
 @tilelang.jit(
-    out_idx=[1],
-    pass_configs={
+    out_idx=[1], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
     accum_dtype = "float"
     shape = [batch, heads, seq_len, dim]
@@ -205,11 +199,9 @@ def flash_bwd_post(
     return flash_bwd_post
 
 
-@tilelang.jit(
-    pass_configs={
-        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+@tilelang.jit(pass_configs={
+    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+})
 def flashattn_bwd(batch,
                   heads,
                   seq_len,

examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py

Lines changed: 2 additions & 4 deletions
@@ -23,11 +23,9 @@ def get_configs():
     rep=100,
 )
 @tilelang.jit(
-    out_idx=[3],
-    pass_configs={
+    out_idx=[3], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn(
     batch,
     heads,

examples/attention_sink/example_mha_sink_bwd_bhsd.py

Lines changed: 9 additions & 17 deletions
@@ -20,11 +20,9 @@ def get_bwd_configs():
 
 
 @tilelang.jit(
-    out_idx=[3, 4],
-    pass_configs={
+    out_idx=[3, 4], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn_fwd(
     batch,
     heads,
@@ -137,11 +135,9 @@ def flash_fwd(
 
 
 @tilelang.jit(
-    out_idx=[2],
-    pass_configs={
+    out_idx=[2], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
     accum_dtype = "float"
     shape = [batch, heads, seq_len, dim]
@@ -177,11 +173,9 @@ def make_dq_layout(dQ):
 
 
 @tilelang.jit(
-    out_idx=[1],
-    pass_configs={
+    out_idx=[1], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
     accum_dtype = "float"
     shape = [batch, heads, seq_len, dim]
@@ -202,11 +196,9 @@ def flash_bwd_post(
     return flash_bwd_post
 
 
-@tilelang.jit(
-    pass_configs={
-        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+@tilelang.jit(pass_configs={
+    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+})
 def flashattn_bwd(
         batch,
         heads,

examples/attention_sink/example_mha_sink_fwd_bhsd.py

Lines changed: 2 additions & 4 deletions
@@ -18,11 +18,9 @@ def get_configs():
 
 @autotune(configs=get_configs(), warmup=500, rep=100)
 @tilelang.jit(
-    out_idx=[3],
-    pass_configs={
+    out_idx=[3], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn(
     batch,
     heads,

examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py

Lines changed: 2 additions & 4 deletions
@@ -19,11 +19,9 @@ def get_configs():
 
 @autotune(configs=get_configs(), warmup=500, rep=100)
 @tilelang.jit(
-    out_idx=[3],
-    pass_configs={
+    out_idx=[3], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn(
     batch,
     heads,

src/target/codegen_cuda.cc

Lines changed: 116 additions & 52 deletions
@@ -900,56 +900,123 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
   stream << ' ' << sret << ";\n";
   std::string src = SSAGetID(PrintExpr(op->value), from_ty);
 
-  // Handle bfloat16 special cases with supported ops
-  bool used_bf16_op = false;
-  if (from_ty.is_bfloat16() || target_ty.is_bfloat16()) {
-    std::ostringstream func_name;
-    if (from_ty.is_bfloat16()) {
-      func_name << "bf16";
-    } else if (from_ty.is_float()) {
-      func_name << "float";
-    }
-    if (from_ty.lanes() > 1) {
-      func_name << from_ty.lanes();
-    }
-    func_name << "2";
-    if (target_ty.is_bfloat16()) {
-      func_name << "bf16";
-    } else if (target_ty.is_float()) {
-      func_name << "float";
-    } else if (target_ty == DataType::Int(16)) {
-      func_name << "int16";
-    }
-    if (target_ty.lanes() > 1) {
-      func_name << target_ty.lanes();
-    }
-
-    auto fname = func_name.str();
-    if (bf16_supported_ops_.count(fname)) {
-      used_bf16_op = true;
-      stream << "#ifdef ENABLE_BF16\n";
+  // Handle conversion between float16 and float32
+  if (from_ty.is_float16() && target_ty.is_float()) {
+    // Use __half22float2 for vectorized conversion (half2 -> float2)
+    if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
+      // half2 -> float2
       PrintIndent();
-      stream << "reinterpret_cast<";
-      if (target_ty.is_bfloat16()) {
-        stream << "__nv_bfloat16";
-      } else {
-        PrintType(target_ty.element_of(), stream);
-      }
-      if (target_ty.lanes() > 1) {
-        stream << target_ty.lanes();
-      }
-      stream << " &>(" << sret << ") = fastertransformer::" << fname
-             << "(reinterpret_cast<";
-      if (from_ty.is_bfloat16()) {
-        stream << "__nv_bfloat16";
-      } else {
-        PrintType(from_ty.element_of(), stream);
-      }
-      if (from_ty.lanes() > 1) {
-        stream << from_ty.lanes();
-      }
-      stream << " const &>(" << src << "));\n";
-      stream << "#else\n";
+      stream << sret << " = __half22float2(*(half2*)(&(" << src << ")));\n";
+      os << sret;
+      return;
+    } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {
+      // half4 -> float4
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[0] = "
+             << "__half22float2(*(half2*)(&(" << src << ")));\n";
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[1] = "
+             << "__half22float2(*((half2*)(&(" << src << "))+1));\n";
+      os << sret;
+      return;
+    }
+  } else if (from_ty.is_float() && target_ty.is_float16()) {
+    // Use __float22half2_rn for vectorized conversion (float2 -> half2)
+    if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
+      // float2 -> half2
+      PrintIndent();
+      stream << "*(half2*)(&(" << sret << ")) = __float22half2_rn(*(float2*)(&("
+             << src << ")));\n";
+      os << sret;
+      return;
+    } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {
+      // float4 -> half4
+      PrintIndent();
+      stream << "((half2*)(&" << sret << "))[0] = "
+             << "__float22half2_rn(*(float2*)(&(" << src << ")));\n";
+      PrintIndent();
+      stream << "((half2*)(&" << sret << "))[1] = "
+             << "__float22half2_rn(*((float2*)(&(" << src << "))+1));\n";
+      os << sret;
+      return;
+    }
+  }
+
+  // Handle conversion between bfloat16 and float32
+  if (from_ty.is_bfloat16() && target_ty.is_float()) {
+    // Use __bfloat1622float2 for vectorized conversion (bfloat162 -> float2)
+    if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
+      // bfloat162 -> float2
+      PrintIndent();
+      stream << sret
+             << " = __bfloat1622float2(*reinterpret_cast<__nv_bfloat162*>(&("
+             << src << ")));\n";
+      os << sret;
+      return;
+    } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {
+      // bfloat162x2 -> float4
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[0] = "
+             << "__bfloat1622float2(*reinterpret_cast<__nv_bfloat162*>(&("
+             << src << ")));\n";
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[1] = "
+             << "__bfloat1622float2(*(reinterpret_cast<__nv_bfloat162*>(&("
+             << src << "))+1));\n";
+      os << sret;
+      return;
+    }
+  } else if (from_ty.is_float() && target_ty.is_bfloat16()) {
+    // Use __float22bfloat162_rn for vectorized conversion (float2 -> bfloat162)
+    if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
+      // float2 -> bfloat162
+      PrintIndent();
+      stream << "*reinterpret_cast<__nv_bfloat162*>(&(" << sret
+             << ")) = __float22bfloat162_rn(*(float2*)(&(" << src << ")));\n";
+      os << sret;
+      return;
+    } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {
+      // float4 -> bfloat162x2
+      PrintIndent();
+      stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[0] = "
+             << "__float22bfloat162_rn(*(float2*)(&(" << src << ")));\n";
+      PrintIndent();
+      stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[1] = "
+             << "__float22bfloat162_rn(*((float2*)(&(" << src << "))+1));\n";
+      os << sret;
+      return;
+    }
+  }
+
+  // Handle conversion from float32 to float8 (E4M3/E5M2)
+  if (from_ty.is_float() &&
+      (target_ty.is_float8_e4m3() || target_ty.is_float8_e5m2())) {
+    // FP32 -> FP8: Use __nv_cvt_float2_to_fp8x2 for vectorized conversion
+    // (float2 -> fp8x2)
+    if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
+      // float2 -> fp8x2
+      PrintIndent();
+      stream << "*reinterpret_cast<__nv_fp8x2_storage_t*>(&(" << sret
+             << ")) = __nv_cvt_float2_to_fp8x2(*reinterpret_cast<float2*>(&("
+             << src << ")), __NV_SATFINITE, "
+             << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+             << ");\n";
+      os << sret;
+      return;
+    } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {
+      // float4 -> fp8x4
+      PrintIndent();
+      stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[0] = "
+             << "__nv_cvt_float2_to_fp8x2(*(float2*)(&(" << src
+             << ")), __NV_SATFINITE, "
+             << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+             << ");\n";
+      PrintIndent();
+      stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[1] = "
+             << "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src
+             << "))+1), __NV_SATFINITE, "
+             << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+             << ");\n";
     }
   }
 
@@ -964,9 +1031,6 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
     PrintVecElemStore(sret, target_ty, i, val.str());
   }
 
-  if (used_bf16_op) {
-    stream << "#endif\n";
-  }
   os << sret;
 }
 
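To make the new fp32 -> fp8 branch concrete, here is a hedged CUDA sketch of what the generated lanes == 2 conversion boils down to. The __nv_cvt_float2_to_fp8x2 call with __NV_SATFINITE and the __NV_E4M3 / __NV_E5M2 selector is taken from the diff above; the kernel wrapper, buffer names, and indexing are assumptions made only for illustration.

// Sketch only: shows the fp8x2 store pattern emitted by the new codegen branch.
// Requires a toolkit that ships cuda_fp8.h; kernel name and shapes are hypothetical.
#include <cuda_fp8.h>

__global__ void cast_fp32_to_fp8_e4m3_x2(const float *in,
                                         __nv_fp8_storage_t *out, int n) {
  int i = 2 * (blockIdx.x * blockDim.x + threadIdx.x);  // two lanes per thread
  if (i + 1 < n) {
    float2 v = *reinterpret_cast<const float2 *>(in + i);
    // float2 -> packed fp8x2, saturating to the finite range, E4M3 interpretation
    __nv_fp8x2_storage_t packed =
        __nv_cvt_float2_to_fp8x2(v, __NV_SATFINITE, __NV_E4M3);
    *reinterpret_cast<__nv_fp8x2_storage_t *>(out + i) = packed;
  }
}

The lanes == 4 branch in the diff applies the same conversion twice, once for each float2 half of the float4 value.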

0 commit comments