diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 530f9b8928..a59dcf805b 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -2786,15 +2786,15 @@ struct CanonicalizeConvertFromAlloc mlir::LogicalResult matchAndRewrite(triton::gpu::LocalAllocOp op, - PatternRewriter &rewriter) const override { + PatternRewriter &baseRewriter) const override { if (!op.getSrc()) return failure(); auto convert = op.getSrc().getDefiningOp(); if (!convert) return failure(); + PatternRewriterWithAsyncTaskIds rewriter(baseRewriter, op); auto newAlloc = rewriter.replaceOpWithNewOp( op, op->getResult(0).getType(), convert.getSrc()); - newAlloc->setAttrs(op->getAttrs()); return mlir::success(); } }; @@ -2806,13 +2806,13 @@ struct CanonicalizeConvertFromLocalStore mlir::LogicalResult matchAndRewrite(triton::gpu::LocalStoreOp op, - PatternRewriter &rewriter) const override { + PatternRewriter &baseRewriter) const override { auto convert = op.getSrc().getDefiningOp(); if (!convert) return failure(); - auto store = rewriter.replaceOpWithNewOp(op, convert.getSrc(), - op.getDst()); - store->setAttrs(op->getAttrs()); + PatternRewriterWithAsyncTaskIds rewriter(baseRewriter, op); + auto store = rewriter.replaceOpWithNewOp( + op, convert.getSrc(), op.getDst()); return mlir::success(); } }; @@ -2930,10 +2930,10 @@ struct CanonicalizeConvertFromConvert // cvt(cvt(x, type1), type2) -> cvt(x, type2) if (auto cvt = dyn_cast(arg)) { auto srcType = op.getSrc().getType(); - auto origAttrs = op->getAttrs(); - auto newOp = rewriter.replaceOpWithNewOp( - op, op->getResultTypes().front(), cvt.getSrc()); - newOp->setAttrs(origAttrs); + PatternRewriterWithAsyncTaskIds rewriterTask(rewriter, cvt); + auto newOp = + rewriterTask.replaceOpWithNewOp( + op, op->getResultTypes().front(), cvt.getSrc()); return success(); }