Skip to content

Commit 7a962cd

Browse files
committed
fixed warp_coeff
1 parent a0afb56 commit 7a962cd

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

src/tir/transforms/lower_warp_memory.cc

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -115,8 +115,7 @@ class WarpStoreCoeffFinder : private StmtExprVisitor {
115115
/// Visitor implementation
116116
void VisitExpr_(const CallNode* op) final {
117117
if (op->op.same_as(builtin::ptx_ldmatrix()) && op->args[3].as<VarNode>() == buffer_) {
118-
int num_matrix = op->args[1].as<IntImmNode>()->value;
119-
warp_coeff_ = num_matrix * 2;
118+
UpdatePattern(op->args[4]);
120119
} else if (op->op.same_as(builtin::mma_fill()) && op->args[1].as<VarNode>() == buffer_) {
121120
auto* ptr = op->args[0].as<IntImmNode>();
122121
CHECK(ptr);
@@ -499,7 +498,7 @@ Pass LowerWarpMemory() {
499498
WarpMemoryRewriter warp_memory_rewriter(warp_size);
500499
auto stmt = warp_memory_rewriter.Rewrite(std::move(n->body));
501500
n->body = UpdatePointerStorageScope(warp_memory_rewriter.new_storage_scopes_)(stmt);
502-
// LOG(INFO) << f;
501+
LOG(INFO) << f;
503502
return f;
504503
};
505504
return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {});

0 commit comments

Comments
 (0)