@@ -822,16 +822,26 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
822822 this ->stream << PrintLoadMatrixAssembly (trans, num, type, local_ptr, local_elem_offset,
823823 smem_ptr, smem_elem_offset);
824824 } else if (op->op .same_as (builtin::mma_store ())) {
825+ int m = Downcast<Integer>(op->args [0 ])->value ;
826+ int n = Downcast<Integer>(op->args [1 ])->value ;
825827 std::string dst = this ->PrintExpr (op->args [2 ]);
826828 std::string src = this ->PrintExpr (op->args [3 ]);
827829 std::string src_offset = this ->PrintExpr (op->args [4 ]);
828830 std::string stride = this ->PrintExpr (op->args [5 ]);
829831
830- os << " for (int i = 0; i < 4; ++i) {\n " ;
831- os << dst << " [(i / 2 * 8 + threadIdx.x / 4) * " << stride
832- << " + (threadIdx.x % 4) * 2 + i % 2]"
833- << " = " << src << " [" << src_offset << " + i];\n " ;
834- os << " }\n " ;
832+ if (m == 16 && n == 8 ) {
833+ os << " for (int i = 0; i < 4; ++i) {\n " ;
834+ os << dst << " [(i / 2 * 8 + threadIdx.x / 4) * " << stride
835+ << " + (threadIdx.x % 4) * 2 + i % 2]"
836+ << " = " << src << " [" << src_offset << " + i];\n " ;
837+ os << " }\n " ;
838+ } else if (m == 16 && n == 16 ) {
839+ os << " for (int i = 0; i < 8; ++i) {\n " ;
840+ os << dst << " [(i / 4 * 8 + threadIdx.x / 4) * " << stride
841+ << " + (threadIdx.x % 4) * 4 + i % 4]"
842+ << " = " << src << " [" << src_offset << " + i];\n " ;
843+ os << " }\n " ;
844+ }
835845 } else if (op->op .same_as (builtin::mma_fill ())) {
836846 std::string num_elem = this ->PrintExpr (op->args [0 ]);
837847 std::string dst = this ->PrintExpr (op->args [1 ]);
0 commit comments