-
Notifications
You must be signed in to change notification settings - Fork 2.9k
[AMD] Add initial support for scaled_dot(mxfp8, fp8) #4994
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
antiagainst
merged 5 commits into
triton-lang:main
from
antiagainst:amd-mxfp-scaled-dot
Oct 28, 2024
Merged
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
305e1ba
[AMD] Add initial support for scaled_dot(mxfp8, fp8)
antiagainst f1349db
Enable certain tests
antiagainst 27d403b
Drop packing logic given we process standalone elements
antiagainst 88d2739
Adjust UpcastMXFPOp op verification
antiagainst 06e1982
Test more configurations
antiagainst File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
146 changes: 146 additions & 0 deletions
146
third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,146 @@ | ||
| #include "PatternTritonGPUOpToLLVM.h" | ||
|
|
||
| #include "mlir/Conversion/LLVMCommon/Pattern.h" | ||
| #include "mlir/IR/BuiltinOps.h" | ||
| #include "mlir/IR/TypeUtilities.h" | ||
| #include "mlir/IR/ValueRange.h" | ||
| #include "mlir/Transforms/DialectConversion.h" | ||
| #include "triton/Conversion/TritonGPUToLLVM/Utility.h" | ||
| #include "triton/Dialect/Triton/IR/Dialect.h" | ||
| #include "triton/Dialect/TritonGPU/IR/Attributes.h" | ||
| #include "llvm/ADT/STLExtras.h" | ||
| #include "llvm/ADT/SmallVector.h" | ||
| #include "llvm/Support/Debug.h" | ||
| #include <array> | ||
|
|
||
| using namespace mlir; | ||
| using namespace mlir::triton; | ||
| using namespace mlir::triton::gpu; | ||
|
|
||
| namespace { | ||
|
|
||
| Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, | |
| Value scale) { /* Returns v multiplied by the scale, or NaN when the scale byte is 0xff; v is presumably a bf16 value -- TODO confirm against callers. */ | |
| Value nanBf16 = bitcast(i16_val(0x7fff), bf16_ty); /* 0x7fff: exponent and mantissa bits all set, i.e. a bf16 NaN */ | |
| Value scaleIsNan = icmp_eq(scale, i8_val(0xff)); | |
| Value scaleBf16 = bitcast(shl(zext(i16_ty, scale), i16_val(7)), bf16_ty); /* shifting left by 7 moves the 8-bit scale into the bf16 exponent field (bits 7..14) */ | |
| Value scaledBf16 = fmul(v, scaleBf16); | |
| // A scale byte of 0xff denotes NaN as per the mxfp specification, so the product is overridden with NaN in that case. | |
| return select(scaleIsNan, nanBf16, scaledBf16); | |
| }; | |
|
|
||
| class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> { /* Conversion pattern that lowers UpcastMXFPOp by scaling each source element with its shared scale value. */ | |
| private: | |
| const TargetInfoBase &targetInfo; | |
|
|
||
| public: | |
| UpcastMXFPOpPattern(LLVMTypeConverter &typeConverter, | |
| const TargetInfoBase &targetInfo, PatternBenefit benefit) | |
| : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { | |
| } | |
|
|
||
| LogicalResult | |
| matchAndRewrite(UpcastMXFPOp op, OpAdaptor adaptor, | |
| ConversionPatternRewriter &rewriter) const override { | |
| auto fpType = op.getFpType(); | |
| if (!(fpType == F8F6F4Type::E4M3 || fpType == F8F6F4Type::E5M2)) | |
| return rewriter.notifyMatchFailure(op, "NYI: non-mxfp8 cases"); | |
|
|
||
| Location loc = op.getLoc(); | |
| auto xVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); | |
| auto scaleVals = unpackLLElements(loc, adaptor.getScale(), rewriter); | |
| LDBG("x: " << xVals.size() << " x " << xVals.front().getType()); /* NOTE(review): front() assumes at least one unpacked element */ | |
| LDBG("scale: " << scaleVals.size() << " x " << scaleVals.front().getType()); | |
|
|
||
| // When we lower the scaled dot op, we made sure to distribute K only on one | |
| // warp. The MXFP spec mandates 1 scale value for every 32 consecutive values | |
| // along the K dimension. So in total each thread should hold 32x as many | |
| // main element values as scale values; anything else is rejected below. | |
| if (xVals.size() != scaleVals.size() * 32) | |
| return rewriter.notifyMatchFailure(op, "unsupported problem size"); | |
|
|
||
| auto dotEncoding = | |
| cast<DotOperandEncodingAttr>(op.getSrc().getType().getEncoding()); | |
| if (dotEncoding.getOpIdx() == 1) | |
| return rewriter.notifyMatchFailure(op, "NYI: dot RHS"); | |
| auto mfmaEncoding = dyn_cast<AMDMfmaEncodingAttr>(dotEncoding.getParent()); | |
| if (!mfmaEncoding) | |
| return rewriter.notifyMatchFailure(op, "NYI: non-mfma dot operand"); | |
| LDBG("mfma: " << mfmaEncoding); | |
|
|
||
| int mDim = mfmaEncoding.getMDim(); | |
| if (mDim != 32 && mDim != 16) | |
| return rewriter.notifyMatchFailure(op, "NYI: non-mfma32/16 intrinsics"); | |
|
|
||
| int numThreads = triton::gpu::TritonGPUDialect::getThreadsPerWarp( | |
| op->getParentOfType<ModuleOp>()); | |
| Value warpSize = i32_val(numThreads); | |
| Value tid = tid_val(); | |
| Value warpId = udiv(tid, warpSize); /* NOTE(review): warpId is not referenced anywhere below */ | |
| Value laneId = urem(tid, warpSize); | |
|
|
||
| // Given that the MFMA layout for the A tensor arranges threads in a column-major | |
| // manner, the current lane is at row (laneId % mDim). When we set up the | |
| // blocked layout for the A scale tensor, we made sure that it has a | |
| // threadsPerWarp = [M=mDim, K=64/mDim]. So the threads holding scale values | |
| // for the current thread start at ((laneId % mDim) * (numThreads / mDim)). | |
| Value offset = mul(urem(laneId, i32_val(mDim)), i32_val(numThreads / mDim)); | |
|
|
||
| if (mDim == 32) { | |
| // One mfma32 intrinsic processes a 32x8 A tensor slice. Due to how we | |
| // tile, the same warp owns the whole K dim. Inside a warp, each thread | |
| // only holds 4 consecutive elements along K--a 1x4 vector. We need to | |
| // tile the warp 4 times to cover 32 values along K. So for a thread, the | |
| // first 4 1x4 vectors it holds share the first scale value at row (tid % | |
| // mDim); the second 4 1x4 vectors share the second scale value at row | |
| // (tid % mDim); and so forth. | |
| std::array<Value, 2> scaleThreads = {offset, add(offset, i32_val(1))}; | |
|
|
||
| for (auto [i, scaleVal] : llvm::enumerate(scaleVals)) { | |
| std::array<Value, 2> si = { | |
| targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[0]), | |
| targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[1]), | |
| }; | |
|
|
||
| for (int j = 0; j < 32; ++j) { | |
| int index = 32 * i + j; | |
| xVals[index] = mxfpScaleBf16(rewriter, loc, xVals[index], si[j / 16]); /* elements 0-15 use si[0], elements 16-31 use si[1] */ | |
| } | |
| } | |
| } else { | |
| assert(mDim == 16); | |
| // One mfma16 intrinsic processes a 16x16 A tensor slice. Similarly, we | |
| // need to tile the warp 2 times to cover 32 values. So for a thread, the | |
| // first 2 1x4 vectors share the first scale value at row (tid % mDim). | |
| std::array<Value, 4> scaleThreads = {offset, add(offset, i32_val(1)), | |
| add(offset, i32_val(2)), | |
| add(offset, i32_val(3))}; | |
|
|
||
| for (auto [i, scaleVal] : llvm::enumerate(scaleVals)) { | |
| auto si = std::array<Value, 4>{ | |
| targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[0]), | |
| targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[1]), | |
| targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[2]), | |
| targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[3]), | |
| }; | |
|
|
||
| for (int j = 0; j < 32; ++j) { | |
| int index = 32 * i + j; | |
| xVals[index] = mxfpScaleBf16(rewriter, loc, xVals[index], si[j / 8]); /* each run of 8 consecutive elements uses the next of the 4 scale values */ | |
| } | |
| } | |
| } | |
|
|
||
| Value result = | |
| packLLElements(loc, getTypeConverter(), xVals, rewriter, op.getType()); | |
| rewriter.replaceOp(op, result); | |
| return success(); | |
| } | |
| }; | |
| } // anonymous namespace | ||
|
|
||
| void mlir::triton::AMD::populateUpcastMXFPToLLVMPatterns( | |
| LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, | |
| const TargetInfo &targetInfo, PatternBenefit benefit) { /* Adds UpcastMXFPOpPattern to the given pattern set, forwarding the type converter, target info and benefit. */ | |
| patterns.add<UpcastMXFPOpPattern>(typeConverter, targetInfo, benefit); | |
| } |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.