Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2786,15 +2786,15 @@ struct CanonicalizeConvertFromAlloc

mlir::LogicalResult
matchAndRewrite(triton::gpu::LocalAllocOp op,
PatternRewriter &rewriter) const override {
PatternRewriter &baseRewriter) const override {
if (!op.getSrc())
return failure();
auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
if (!convert)
return failure();
PatternRewriterWithAsyncTaskIds rewriter(baseRewriter, op);
auto newAlloc = rewriter.replaceOpWithNewOp<triton::gpu::LocalAllocOp>(
op, op->getResult(0).getType(), convert.getSrc());
newAlloc->setAttrs(op->getAttrs());
return mlir::success();
}
};
Expand All @@ -2806,13 +2806,13 @@ struct CanonicalizeConvertFromLocalStore

mlir::LogicalResult
matchAndRewrite(triton::gpu::LocalStoreOp op,
PatternRewriter &rewriter) const override {
PatternRewriter &baseRewriter) const override {
auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
if (!convert)
return failure();
auto store = rewriter.replaceOpWithNewOp<triton::gpu::LocalStoreOp>(op, convert.getSrc(),
op.getDst());
store->setAttrs(op->getAttrs());
PatternRewriterWithAsyncTaskIds rewriter(baseRewriter, op);
auto store = rewriter.replaceOpWithNewOp<triton::gpu::LocalStoreOp>(
op, convert.getSrc(), op.getDst());
return mlir::success();
}
};
Expand Down Expand Up @@ -2930,10 +2930,10 @@ struct CanonicalizeConvertFromConvert
// cvt(cvt(x, type1), type2) -> cvt(x, type2)
if (auto cvt = dyn_cast<ConvertLayoutOp>(arg)) {
auto srcType = op.getSrc().getType();
auto origAttrs = op->getAttrs();
auto newOp = rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
op, op->getResultTypes().front(), cvt.getSrc());
newOp->setAttrs(origAttrs);
PatternRewriterWithAsyncTaskIds rewriterTask(rewriter, cvt);
auto newOp =
rewriterTask.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
op, op->getResultTypes().front(), cvt.getSrc());
return success();
}

Expand Down