From 6c9f09664881910525bfd8a40a34a41b99539acb Mon Sep 17 00:00:00 2001 From: gzhu Date: Thu, 22 Sep 2022 23:48:55 -0700 Subject: [PATCH] [Triton-MLIR][Backend] Revesmem allocation for non-scratch convert_layout --- lib/Analysis/Allocation.cpp | 14 +++----------- .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 17 ++++++----------- 2 files changed, 9 insertions(+), 22 deletions(-) diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 0ea29afdcf97..6afa7ea1aa6a 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -43,7 +43,6 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, return 0; } }; - // blocked -> blocked if (srcLayout.isa() && dstLayout.isa()) { auto srcBlockedLayout = srcLayout.cast(); @@ -66,14 +65,6 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, } paddedRepShape[outOrd[0]] += pad; } - // blocked -> shared - if (srcLayout.isa() && - dstLayout.isa()) { - auto sharedLayout = dstLayout.cast(); - for (int v : dstTy.getShape()) - paddedRepShape.push_back(v); - } - return paddedRepShape; } @@ -140,8 +131,9 @@ class AllocationAnalysis { auto dstTy = cvtLayout.result().getType().cast(); auto srcEncoding = srcTy.getEncoding(); auto dstEncoding = dstTy.getEncoding(); - if (srcEncoding.isa()) { - // only block->block and block->shared is supported now + if (srcEncoding.isa() || + dstEncoding.isa()) { + // Only blocked -> blocked conversion requires for scratch allocation return; } // ConvertLayoutOp with both input/output non-shared_layout diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index eec72293522c..7dbb01fd4c9c 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -585,12 +585,13 @@ class ConvertTritonGPUOpToLLVMPattern return multiDimIdx; } + template Value getSharedMemoryBase(Location loc, ConversionPatternRewriter &rewriter, Value smem, const Allocation *allocation, - Operation *op) const { + T value) const { auto ptrTy = LLVM::LLVMPointerType::get( this->getTypeConverter()->convertType(rewriter.getIntegerType(8)), 3); - auto bufferId = allocation->getBufferId(op); + auto bufferId = allocation->getBufferId(value); assert(bufferId != Allocation::InvalidBufferId && "BufferId not found"); size_t offset = allocation->getOffset(bufferId); auto llvmIndexTy = this->getTypeConverter()->getIndexType(); @@ -1399,8 +1400,6 @@ struct ConvertLayoutOpConversion if ((!srcLayout.isa()) || (!dstLayout.isa())) { // TODO: not implemented - llvm::errs() - << "convert_layout except for blocked -> blocked is not implemented"; return failure(); } auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType()); @@ -1996,12 +1995,6 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern { return failure(); } - Value getSmemAddr(Value value, Location loc, - ConversionPatternRewriter &rewriter) const { - return getSharedMemoryBase(loc, rewriter, smem, allocation, - value.getDefiningOp()); - } - const Allocation *allocation; Value smem; }; @@ -2340,7 +2333,9 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter, SmallVector ptrs(numPtrs); Type smemPtrTy = helper.getShemPtrTy(); - auto smemBase = getSmemAddr(tensor, loc, rewriter); + auto smemBase = + getSharedMemoryBase(loc, rewriter, smem, allocation, tensor); + for (int i = 0; i < numPtrs; i++) { ptrs[i] = bit_cast( smemPtrTy, gep(smemBase.getType(), smemBase, ValueRange({offs[i]})));