Skip to content

Commit ec81250

Browse files
committed
add 16x8x16 mma store codegen
1 parent e08df2a commit ec81250

File tree

1 file changed

+15
-5
lines changed

1 file changed

+15
-5
lines changed

src/target/source/codegen_cuda.cc

Lines changed: 15 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -822,16 +822,26 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
822 822
this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, local_elem_offset,
823 823
smem_ptr, smem_elem_offset);
824 824
} else if (op->op.same_as(builtin::mma_store())) {
825+
int m = Downcast<Integer>(op->args[1])->value;
826+
int n = Downcast<Integer>(op->args[1])->value;
825 827
std::string dst = this->PrintExpr(op->args[2]);
826 828
std::string src = this->PrintExpr(op->args[3]);
827 829
std::string src_offset = this->PrintExpr(op->args[4]);
828 830
std::string stride = this->PrintExpr(op->args[5]);
829 831

830-
os << "for (int i = 0; i < 4; ++i) {\n";
831-
os << dst << "[(i / 2 * 8 + threadIdx.x / 4) * " << stride
832-
<< " + (threadIdx.x % 4) * 2 + i % 2]"
833-
<< " = " << src << "[" << src_offset << " + i];\n";
834-
os << "}\n";
832+
if (m == 16 && n == 8) {
833+
os << "for (int i = 0; i < 4; ++i) {\n";
834+
os << dst << "[(i / 2 * 8 + threadIdx.x / 4) * " << stride
835+
<< " + (threadIdx.x % 4) * 2 + i % 2]"
836+
<< " = " << src << "[" << src_offset << " + i];\n";
837+
os << "}\n";
838+
} else if (m == 16 && n == 16) {
839+
os << "for (int i = 0; i < 8; ++i) {\n";
840+
os << dst << "[(i / 4 * 8 + threadIdx.x / 4) * " << stride
841+
<< " + (threadIdx.x % 4) * 4 + i % 4]"
842+
<< " = " << src << "[" << src_offset << " + i];\n";
843+
os << "}\n";
844+
}
835 845
} else if (op->op.same_as(builtin::mma_fill())) {
836 846
std::string num_elem = this->PrintExpr(op->args[0]);
837 847
std::string dst = this->PrintExpr(op->args[1]);

0 commit comments

Comments
 (0)