@@ -744,7 +744,8 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
744744 // arg 10: C accumulator
745745 // arg 11: C accumulator index
746746 // arg 12: saturate
747- ICHECK_EQ (op->args .size (), 13U );
747+ // arg 13: (optional) 1-bit operator (xor or and)
748+ ICHECK (op->args .size () == 13U || op->args .size () == 14U );
748749 std::string shape = Downcast<StringImm>(op->args [0 ])->value ;
749750 std::string A_layout = Downcast<StringImm>(op->args [1 ])->value ;
750751 std::string B_layout = Downcast<StringImm>(op->args [2 ])->value ;
@@ -757,11 +758,51 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
757758 std::string b_bias = this ->PrintExpr (op->args [9 ]);
758759 std::string c_ref = this ->PrintExpr (op->args [10 ]);
759760 std::string c_bias = this ->PrintExpr (op->args [11 ]);
760- bool saturate = (Downcast<IntImm>(op->args [12 ])->value != 0 );
761- std::string asm_code = PrintMMAAssembly (shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype,
762- a_ref, a_bias, b_ref, b_bias, c_ref, c_bias, saturate);
761+ bool saturate = Downcast<Bool>(op->args [12 ])->value ;
762+ std::string bit_op = op->args .size () > 13 ? Downcast<StringImm>(op->args [13 ])->value : " " ;
763+ std::string asm_code =
764+ PrintMMAAssembly (shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, b_ref,
765+ b_bias, c_ref, c_bias, " " , " " , " " , bit_op, false , saturate);
763766
764767 this ->stream << asm_code;
768+ } else if (op->op .same_as (builtin::ptx_mma_sp ())) {
769+ // arg 0: shape: mXnXkX
770+ // arg 1: A layout: row/col
771+ // arg 2: B layout: row/col
772+ // arg 3: A precision: fp16, fp32, ...
773+ // arg 4: B precision: fp16, fp32, ...
774+ // arg 5: C precision: fp16, fp32, ...
775+ // arg 6: A multiplicand
776+ // arg 7: A multiplicand index
777+ // arg 8: B multiplicand
778+ // arg 9: B multiplicand index
779+ // arg 10: C accumulator
780+ // arg 11: C accumulator index
781+ // arg 12: metadata
782+ // arg 13: metadata index
783+ // arg 14: sparse_selector
784+ // arg 15: saturate
785+ ICHECK_EQ (op->args .size (), 16U );
786+ std::string shape = Downcast<StringImm>(op->args [0 ])->value ;
787+ std::string A_layout = Downcast<StringImm>(op->args [1 ])->value ;
788+ std::string B_layout = Downcast<StringImm>(op->args [2 ])->value ;
789+ std::string A_dtype = Downcast<StringImm>(op->args [3 ])->value ;
790+ std::string B_dtype = Downcast<StringImm>(op->args [4 ])->value ;
791+ std::string C_dtype = Downcast<StringImm>(op->args [5 ])->value ;
792+ std::string a_ref = this ->PrintExpr (op->args [6 ]);
793+ std::string a_offset = this ->PrintExpr (op->args [7 ]);
794+ std::string b_ref = this ->PrintExpr (op->args [8 ]);
795+ std::string b_offset = this ->PrintExpr (op->args [9 ]);
796+ std::string c_ref = this ->PrintExpr (op->args [10 ]);
797+ std::string c_offset = this ->PrintExpr (op->args [11 ]);
798+ std::string metadata = this ->PrintExpr (op->args [12 ]);
799+ std::string metadata_offset = this ->PrintExpr (op->args [13 ]);
800+ std::string sparse_selector = this ->PrintExpr (op->args [14 ]);
801+ bool saturate = Downcast<Bool>(op->args [15 ])->value ;
802+ std::string asm_code = PrintMMAAssembly (
803+ shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_offset, b_ref, b_offset,
804+ c_ref, c_offset, metadata, metadata_offset, sparse_selector, " " , true , saturate);
805+ this ->stream << asm_code;
765806 } else {
766807 CodeGenC::VisitExpr_ (op, os);
767808 }
0 commit comments