Skip to content

Commit 6debbb9

Browse files
committed
[Feature] Add vectorized float16 and float32 conversion support in CUDA codegen
* Implemented handling for conversions between float16 and float32 types, specifically for vectorized operations using __half22float2 and __float22half2_rn. * Enhanced the existing code to support both directions of conversion based on the lane count. * Improved overall type handling in the VisitExpr_ method for better compatibility with TileLang.
1 parent bddb125 commit 6debbb9

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

src/target/codegen_cuda.cc

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -900,6 +900,26 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
900900
stream << ' ' << sret << ";\n";
901901
std::string src = SSAGetID(PrintExpr(op->value), from_ty);
902902

903+
// Handle conversion between float16 and float32
904+
if (from_ty.is_float16() && target_ty.is_float()) {
905+
// Use __half22float2 for vectorized conversion (half2 -> float2)
906+
if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
907+
PrintIndent();
908+
stream << sret << " = __half22float2(*(half2*)(&(" << src << ")));\n";
909+
os << sret;
910+
return;
911+
}
912+
} else if (from_ty.is_float() && target_ty.is_float16()) {
913+
// Use __float22half2_rn for vectorized conversion (float2 -> half2)
914+
if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
915+
PrintIndent();
916+
stream << "*(half2*)(&(" << sret << ")) = __float22half2_rn(*(float2*)(&("
917+
<< src << ")));\n";
918+
os << sret;
919+
return;
920+
}
921+
}
922+
903923
// Handle bfloat16 special cases with supported ops
904924
bool used_bf16_op = false;
905925
if (from_ty.is_bfloat16() || target_ty.is_bfloat16()) {

0 commit comments

Comments
 (0)