Skip to content

Commit 18e8d73

Browse files
committed
fixed mma store codegen for 16x8x16
1 parent ec81250 commit 18e8d73

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

src/target/source/codegen_cuda.cc

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -836,10 +836,12 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
836 836
<< " = " << src << "[" << src_offset << " + i];\n";
837 837
os << "}\n";
838 838
} else if (m == 16 && n == 16) {
839-
os << "for (int i = 0; i < 8; ++i) {\n";
840-
os << dst << "[(i / 4 * 8 + threadIdx.x / 4) * " << stride
841-
<< " + (threadIdx.x % 4) * 4 + i % 4]"
842-
<< " = " << src << "[" << src_offset << " + i];\n";
839+
os << "for (int outer = 0; outer < 2; ++outer) {\n";
840+
os << "for (int i = 0; i < 4; ++i) {\n";
841+
os << dst << "[(i / 2 * 8 + threadIdx.x / 4) * " << stride
842+
<< " + outer * 8 + (threadIdx.x % 4) * 2 + i % 2]"
843+
<< " = " << src << "[" << src_offset << " + i * outer * 4];\n";
844+
os << "}\n";
843 845
os << "}\n";
844 846
}
845 847
} else if (op->op.same_as(builtin::mma_fill())) {
@@ -848,7 +850,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
848 850
std::string dst_offset = this->PrintExpr(op->args[2]);
849 851

850 852
os << "for (int i = 0; i < " << num_elem << "; ++i) {\n";
851-
os << dst << "[" << dst_offset << " + i] = 0.0;" ;
853+
os << dst << "[" << dst_offset << " + i] = 0.0;";
852 854
os << "}\n";
853 855
} else {
854 856
CodeGenC::VisitExpr_(op, os);

0 commit comments

Comments
 (0)