-
Notifications
You must be signed in to change notification settings - Fork 2.9k
[AMD] Enable ds_read_tr* lowering for PartitionedSharedEncodingAttr #10062
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -31,10 +31,6 @@ class TransLocalLoadOpConversion | |
| MemDescType srcTy = op.getSrc().getType(); | ||
| RankedTensorType dstTy = op.getType(); | ||
|
|
||
| // Partitioned tensors have multiple bases; fall back to generic lowering. | ||
| if (isa<triton::gpu::PartitionedSharedEncodingAttr>(srcTy.getEncoding())) { | ||
| return failure(); | ||
| } | ||
| auto typeConverter = this->getTypeConverter(); | ||
| auto llvmElemTy = typeConverter->convertType(dstTy.getElementType()); | ||
| unsigned bitWidth = llvmElemTy.getIntOrFloatBitWidth(); | ||
|
|
@@ -48,18 +44,17 @@ class TransLocalLoadOpConversion | |
| if (!ldsParams) | ||
| return failure(); | ||
|
|
||
| auto paddedEnc = | ||
| dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(srcTy.getEncoding()); | ||
| LinearLayout cvtDstLL = LinearLayout::empty(); | ||
| if (paddedEnc) { | ||
| const auto &sharedLL = paddedEnc.getLinearComponent(); | ||
| cvtDstLL = triton::gpu::toLinearLayout(dstTy).invertAndCompose(sharedLL); | ||
| if (paddedEnc.getMinInterval() < ldsParams->tileSize) | ||
| LinearLayout sharedLL; | ||
| if (triton::gpu::isPaddedEncoding(srcTy.getEncoding())) { | ||
| sharedLL = triton::gpu::paddedLinearLayout(srcTy); | ||
| if (triton::gpu::getMinInterval(srcTy.getEncoding()) < | ||
| ldsParams->tileSize) | ||
| return failure(); | ||
| } else { | ||
| auto sharedLL = triton::gpu::toLinearLayout(srcTy); | ||
| cvtDstLL = triton::gpu::toLinearLayout(dstTy).invertAndCompose(sharedLL); | ||
| sharedLL = triton::gpu::toLinearLayout(srcTy); | ||
| } | ||
| LinearLayout cvtDstLL = | ||
| triton::gpu::toLinearLayout(dstTy).invertAndCompose(sharedLL); | ||
| auto kBlock = StringAttr::get(ctx, "block"); | ||
| auto maybeSublayout = cvtDstLL.quotient({kBlock}); | ||
| if (!maybeSublayout) { | ||
|
|
@@ -68,7 +63,7 @@ class TransLocalLoadOpConversion | |
| cvtDstLL = maybeSublayout.value(); | ||
| auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), | ||
| llvmElemTy, rewriter); | ||
| auto smemBase = smemObj.getBase(); | ||
| SmallVector<Value> smemBases = llvm::to_vector(smemObj.getBases()); | ||
| auto affineOffset = smemObj.getShmemOffset(loc, rewriter, srcTy); | ||
| auto maskSpanAffineOffset = smemObj.getMaskSpanOffsets(srcTy); | ||
| auto paddingShifts = getPaddedSharedShifts(srcTy.getEncoding(), | ||
|
|
@@ -77,7 +72,7 @@ class TransLocalLoadOpConversion | |
|
|
||
| llvm::SmallVector<Value> values; | ||
| auto result = lowerDsReadTr( | ||
| op, ldsParams.value(), loc, cvtDstLL, values, smemBase, affineOffset, | ||
| op, ldsParams.value(), loc, cvtDstLL, values, smemBases, affineOffset, | ||
| maskSpanAffineOffset, paddingShifts, llvmElemTy, rewriter, targetInfo); | ||
| if (failed(result)) { | ||
| return failure(); | ||
|
|
@@ -92,15 +87,15 @@ class TransLocalLoadOpConversion | |
| } | ||
|
|
||
| private: | ||
| LogicalResult lowerDsReadTr( | ||
| triton::gpu::LocalLoadOp op, | ||
| ::triton::AMD::TargetInfo::LDSTransLoadParams ldsParams, Location loc, | ||
| LinearLayout cvt, | ||
| SmallVector<Value> &vals, // Input for stmatrix, output for ldmatrix | ||
| Value smemBase, Value affineOffset, uint64_t maskSpanAffineOffset, | ||
| ArrayRef<std::pair<unsigned, unsigned>> paddingShifts, Type llvmElemTy, | ||
| ConversionPatternRewriter &rewriter, | ||
| const ::triton::AMD::TargetInfo &targetInfo) const { | ||
| LogicalResult | ||
| lowerDsReadTr(triton::gpu::LocalLoadOp op, | ||
| ::triton::AMD::TargetInfo::LDSTransLoadParams ldsParams, | ||
| Location loc, LinearLayout cvt, SmallVector<Value> &vals, | ||
| ArrayRef<Value> smemBases, Value affineOffset, | ||
| uint64_t maskSpanAffineOffset, | ||
| ArrayRef<std::pair<unsigned, unsigned>> paddingShifts, | ||
| Type llvmElemTy, ConversionPatternRewriter &rewriter, | ||
| const ::triton::AMD::TargetInfo &targetInfo) const { | ||
|
|
||
| auto b = TritonLLVMOpBuilder(loc, rewriter); | ||
| auto *ctx = rewriter.getContext(); | ||
|
|
@@ -111,9 +106,30 @@ class TransLocalLoadOpConversion | |
| auto kWarp = S("warp"); | ||
| auto kOffset = S("offset"); | ||
| auto kAddr = S("addr"); | ||
| auto kPartition = S("partition"); | ||
| auto smemPtrTy = ptr_ty(ctx, 3); | ||
| auto bitWidth = getIntOrFloatOrPtrBitWidth(llvmElemTy); | ||
|
|
||
| assert(!smemBases.empty() && "expected at least one smem base"); | ||
| LinearLayout cvtLayout = cvt; | ||
| LinearLayout partitionLayout; | ||
| Value basesVec; | ||
| const bool isPartitioned = smemBases.size() > 1; | ||
|
|
||
| if (isPartitioned) { | ||
| assert(cvtLayout.hasOutDim(kPartition) && | ||
| cvtLayout.getOutDimSize(kPartition) == | ||
| static_cast<int32_t>(smemBases.size()) && | ||
| "smemBases size must match partition dimension size"); | ||
| auto inDimNames = llvm::to_vector(cvtLayout.getInDimNames()); | ||
| partitionLayout = cvtLayout.sublayout(inDimNames, {kPartition}); | ||
| SmallVector<StringAttr> outDims = | ||
| llvm::to_vector(cvtLayout.getOutDimNames()); | ||
| llvm::erase(outDims, kPartition); | ||
| cvtLayout = cvtLayout.sublayout(inDimNames, outDims); | ||
| basesVec = LLVM::buildBasePtrVector(loc, rewriter, smemBases); | ||
| } | ||
|
|
||
| // Map onto offsets (contiguous part) and addr (non-contiguous part) | ||
| LinearLayout fullTile; | ||
| // Contiguous tile | ||
|
|
@@ -170,11 +186,11 @@ class TransLocalLoadOpConversion | |
| // Add warp dimension so we can invert and compose with reps later | ||
| fullTile *= LinearLayout::identity1D(1, kWarp, kAddr); | ||
|
|
||
| if (cvt.getInDimSize(kReg) < fullTile.getInDimSize(kReg)) { | ||
| if (cvtLayout.getInDimSize(kReg) < fullTile.getInDimSize(kReg)) { | ||
| return failure(); | ||
| } | ||
|
|
||
| auto maybeQuot = divideLeft(cvt, tile); | ||
| auto maybeQuot = divideLeft(cvtLayout, tile); | ||
| if (!maybeQuot.has_value()) { | ||
| return failure(); | ||
| } | ||
|
|
@@ -209,6 +225,22 @@ class TransLocalLoadOpConversion | |
| auto [nAdditive, permStrides] = | ||
| actionAdditiveStrides(reps, addrLayout, maskSpanAffineOffset); | ||
| reps = permStrides.apply(reps); | ||
| if (isPartitioned) { | ||
| partitionLayout = permStrides.apply(partitionLayout); | ||
|
|
||
| // One ds_read_tr* instruction produces `fullTile.getInDimSize(kReg)` | ||
| // consecutive register values from a single LDS base pointer. We only | ||
| // select a partition once per instruction, so all of those register | ||
| // positions must map to the same partition. For a LinearLayout that holds | ||
| // iff the low log2(elemsPerInstr) register bases contribute 0 to | ||
| // kPartition. Bail out if not, so a generic lowering can take over. | ||
| const unsigned numInstrRegBits = | ||
| llvm::Log2_32(fullTile.getInDimSize(kReg)); | ||
| for (unsigned pos = 0; pos < numInstrRegBits; ++pos) { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is just checking
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, because that would check all the register bases of the partition layout, which would include repetitions. The point is that we want to check only the first numInstrRegBits register bases, which cover the registers of the fullTile layout, i.e. a single instruction. It's fine for different repetitions (instructions) to be in different partitions, but we need to check whether registers from a single instruction end up in different partitions.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For reference, if you want it without looking at the bases, you can do that by reshaping
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah makes sense. Thanks for the explanation. |
||
| if (partitionLayout.getBasis(kReg, pos, kPartition) != 0) | ||
| return failure(); | ||
| } | ||
|
Comment on lines
+231
to
+242
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I should probably add the same check in the regular lowering path as well. |
||
| } | ||
|
|
||
| // Perform computation in bytes, LLVM optimises this better | ||
| assert(bitWidth >= 8); | ||
|
|
@@ -282,7 +314,7 @@ class TransLocalLoadOpConversion | |
| auto elemsPerInstr = fullTile.getInDimSize(kReg); | ||
| auto elemsPerVec = ldsParams.instBitWidth / bitWidth; | ||
| auto vecTy = vec_ty(llvmElemTy, elemsPerVec); | ||
| for (int i = 0; i < cvt.getInDimSize(kReg); i += nAdditive) { | ||
| for (int i = 0; i < cvtLayout.getInDimSize(kReg); i += nAdditive) { | ||
| auto regIdx = reps.apply({{kReg, i}, {kLane, 0}, {kWarp, 0}})[0].second; | ||
| auto regIdxI8 = regIdx * (bitWidth / 8); | ||
| Value offset = b.xor_(regBase, b.i32_val(regIdxI8)); | ||
|
|
@@ -297,14 +329,21 @@ class TransLocalLoadOpConversion | |
| // separately. | ||
| regIdxAddI8 = applyPadding(regIdxAddI8, paddingShifts); | ||
| Value innerOffset = b.add(offset, b.i32_val(regIdxAddI8)); | ||
| auto vecAddr = b.gep(smemPtrTy, i8_ty, smemBase, innerOffset, | ||
| Value smemBaseVal = smemBases[0]; | ||
| if (isPartitioned) { | ||
| auto partOut = applyLinearLayout( | ||
| loc, rewriter, partitionLayout, | ||
| {{kReg, b.i32_val(i + i2)}, {kLane, laneId}, {kWarp, warpId}}); | ||
| smemBaseVal = b.extract_element(basesVec, partOut[0].second); | ||
| } | ||
| auto vecAddr = b.gep(smemPtrTy, i8_ty, smemBaseVal, innerOffset, | ||
| LLVM::GEPNoWrapFlags::inbounds); | ||
| llvm::append_range(vals, | ||
| lowerInst(rewriter, loc, vecAddr, i + i2, vecTy)); | ||
| } | ||
| } | ||
| // apply all the inverse permutations in the reverse order | ||
| assert(vals.size() == cvt.getInDimSize(kReg)); | ||
| assert(vals.size() == cvtLayout.getInDimSize(kReg)); | ||
| vals = permStrides.inverse().apply(vals); | ||
|
|
||
| return success(); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed this comment on purpose, as it seems to be an accidental copy/paste from the NV path.