@@ -1749,10 +1749,19 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
17491749 " reinterpret_cast<const (ARegType)*>((A_ptr) + (A_offset)), "
17501750 " reinterpret_cast<const (BRegType)*>((B_ptr) + (B_offset)));\n " ;
17511751 tl::codegen::Replacer replacer;
1752+ std::string AType = tl::codegen::ptx::DTypeEnumToString (dtype_a_enum);
1753+ if (AType == " tl::DataType::kFloat32" ) {
1754+ AType = " tl::DataType::kTensorFloat32" ;
1755+ }
1756+ std::string BType = tl::codegen::ptx::DTypeEnumToString (dtype_b_enum);
1757+ if (BType == " tl::DataType::kFloat32" ) {
1758+ BType = " tl::DataType::kTensorFloat32" ;
1759+ }
1760+
17521761 replacer.register_rule (" (AType)" ,
1753- tl::codegen::ptx::DTypeEnumToString (dtype_a_enum ));
1762+ tl::codegen::ptx::DTypeEnumToString (AType ));
17541763 replacer.register_rule (" (BType)" ,
1755- tl::codegen::ptx::DTypeEnumToString (dtype_b_enum ));
1764+ tl::codegen::ptx::DTypeEnumToString (BType ));
17561765 replacer.register_rule (" (CType)" ,
17571766 tl::codegen::ptx::DTypeEnumToString (dtype_c_enum));
17581767 replacer.register_rule (" (M)" , std::to_string (m));
@@ -1838,16 +1847,12 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
18381847 std::string B_offset = this ->PrintExpr (op->args [9 ]);
18391848 std::string c_ref = this ->PrintExpr (op->args [10 ]);
18401849 std::string c_offset = this ->PrintExpr (op->args [11 ]);
1841- bool scale_out = Downcast<Bool> (op->args [12 ])-> value ;
1850+ std::string scale_out = this -> PrintExpr (op->args [12 ]);
18421851 bool scale_in_a = Downcast<Bool>(op->args [13 ])->value ;
18431852 bool scale_in_b = Downcast<Bool>(op->args [14 ])->value ;
18441853
18451854 const bool a_is_shared = true ;
18461855 this ->PrintIndent ();
1847- std::string asm_code = PrintWGMMAAssembly (
1848- shape, a_is_k_major, b_is_k_major, A_dtype, B_dtype, C_dtype, a_desc,
1849- A_offset, b_desc, B_offset, c_ref, c_offset, scale_out, scale_in_a,
1850- scale_in_b, a_is_shared, " " , " " , " " , false );
18511856 auto [m, n, k] = tl::codegen::ptx::ParseMMAShape (shape);
18521857 need_wgmma_instruction_h_ = true ;
18531858 std::string wgmma_asm_code =
@@ -1856,10 +1861,18 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
18561861 " uint64_t((desc_b) + (B_offset)), ((uint32_t*)((C))), (scale_out));\n " ;
18571862 // replace patterns
18581863 tl::codegen::Replacer replacer;
1859- replacer.register_rule (" (AType)" ,
1860- tl::codegen::ptx::DTypeEnumToString (A_dtype));
1861- replacer.register_rule (" (BType)" ,
1862- tl::codegen::ptx::DTypeEnumToString (B_dtype));
1864+
1865+ std::string AType = tl::codegen::ptx::DTypeEnumToString (A_dtype);
1866+ if (AType == " tl::DataType::kFloat32" ) {
1867+ AType = " tl::DataType::kTensorFloat32" ;
1868+ }
1869+ std::string BType = tl::codegen::ptx::DTypeEnumToString (B_dtype);
1870+ if (BType == " tl::DataType::kFloat32" ) {
1871+ BType = " tl::DataType::kTensorFloat32" ;
1872+ }
1873+
1874+ replacer.register_rule (" (AType)" , AType);
1875+ replacer.register_rule (" (BType)" , BType);
18631876 replacer.register_rule (" (CType)" ,
18641877 tl::codegen::ptx::DTypeEnumToString (C_dtype));
18651878 replacer.register_rule (" (M)" , std::to_string (m));
@@ -1874,7 +1887,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
18741887 replacer.register_rule (" (desc_b)" , b_desc);
18751888 replacer.register_rule (" (B_offset)" , B_offset);
18761889 replacer.register_rule (" (C)" , c_ref + " + " + c_offset);
1877- replacer.register_rule (" (scale_out)" , scale_out ? " true " : " false " );
1890+ replacer.register_rule (" (scale_out)" , scale_out);
18781891 wgmma_asm_code = replacer.rewrite (wgmma_asm_code);
18791892 this ->stream << wgmma_asm_code;
18801893 } else if (op->op .same_as (tl::ptx_wgmma_rs ())) {
@@ -1904,7 +1917,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
19041917 std::string B_offset = this ->PrintExpr (op->args [8 ]);
19051918 std::string c_ref = this ->PrintExpr (op->args [9 ]);
19061919 std::string c_offset = this ->PrintExpr (op->args [10 ]);
1907- bool scale_out = Downcast<Bool> (op->args [11 ])-> value ;
1920+ std::string scale_out = this -> PrintExpr (op->args [11 ]);
19081921 bool scale_in_a = Downcast<Bool>(op->args [12 ])->value ;
19091922 bool scale_in_b = Downcast<Bool>(op->args [13 ])->value ;
19101923
@@ -1924,10 +1937,17 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
19241937 " (scale_out));\n " ;
19251938
19261939 tl::codegen::Replacer replacer;
1927- replacer.register_rule (" (AType)" ,
1928- tl::codegen::ptx::DTypeEnumToString (dtype_a_enum));
1929- replacer.register_rule (" (BType)" ,
1930- tl::codegen::ptx::DTypeEnumToString (dtype_b_enum));
1940+ std::string AType = tl::codegen::ptx::DTypeEnumToString (A_dtype);
1941+ if (AType == " tl::DataType::kFloat32" ) {
1942+ AType = " tl::DataType::kTensorFloat32" ;
1943+ }
1944+ std::string BType = tl::codegen::ptx::DTypeEnumToString (B_dtype);
1945+ if (BType == " tl::DataType::kFloat32" ) {
1946+ BType = " tl::DataType::kTensorFloat32" ;
1947+ }
1948+
1949+ replacer.register_rule (" (AType)" , AType);
1950+ replacer.register_rule (" (BType)" , BType);
19311951 replacer.register_rule (" (CType)" ,
19321952 tl::codegen::ptx::DTypeEnumToString (dtype_c_enum));
19331953 replacer.register_rule (" (M)" , std::to_string (m));
@@ -1943,7 +1963,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
19431963 replacer.register_rule (" (B_offset)" , B_offset);
19441964 replacer.register_rule (" (C_ptr)" , c_ref);
19451965 replacer.register_rule (" (C_offset)" , c_offset);
1946- replacer.register_rule (" (scale_out)" , scale_out ? " true " : " false " );
1966+ replacer.register_rule (" (scale_out)" , scale_out);
19471967 wgmma_call = replacer.rewrite (wgmma_call);
19481968 this ->stream << wgmma_call;
19491969 } else if (op->op .same_as (tl::ptx_tcgen05_mma_ss ())) {
0 commit comments