-
Notifications
You must be signed in to change notification settings - Fork 10
LHS Registers Part 1 - DotOp Hoisting and SMEM-RF Copy Lowering #18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Addressed all comments in the original PR that are relevant to part 1 in this PR instead. |
| import torch | ||
| import os | ||
| os.environ['TRITON_ALWAYS_COMPILE'] = '1' | ||
| os.environ['MLIR_ENABLE_DUMP'] = '1' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks like these were leftover from debugging
04ed621 to
b7e2df0
Compare
3596dc5 to
10d3305
Compare
|
Ops, seems that updates we have to maintain the Triton integration will cause PR diffs to break because of force-updates. We might need to figure out a better way to handle this as we didn't intend for this repo to accept incoming PRs. Apologies for this, but you will need to rebase again for the diff to include the proper changes. |
1b95c9a to
942dad4
Compare
|
@Moerafaat np - I've reapplied my changes on the new main |
third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp
Outdated
Show resolved
Hide resolved
lezcano
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nb. I have not reviewed OptimizeDotOperands.cpp, @ThomasRaoux should probably review that part once you put up a PR against the triton repo.
First of all, hats off if you've managed to fix SharedToDotOperand for kWidth != 2! It was an absolute madness of indices.
Now, @Jokeren is working on completely removing this devilish path in favour of the cleaner and more correct conversion via linear layouts. I think a more long-term solution here would be to implement linear layout support for DotOperand for Hopper, but we can do this in a follow-up PR.
| // To unify the ordering conventions would potentially require touching | ||
| // `ConvertLayoutOpToLLVM.cpp`, `ElementwiseOpToLLVM.cpp`, `MMAv2.cpp`, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is going to be fixed (at last!) in triton-lang#4951. Hopefully it'll get merged today, tomorrow at latest.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It just landed: triton-lang#4979
| if (isHopperWidthChange) { | ||
| vecWidth = 4 / mmaElemBytes; | ||
| } else { | ||
| vecWidth = 4 / elemBytes; | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you mean vecWidth = kWidth here? This way, it would work for kWidth > 4.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just left this comment below (#18 (comment)) which I think also applies to this question.
In what case would we have kWidth > 4 though? For Ampere and Hopper we should have 4 as the maximum right (with int8/fp8 dtype)
| // width-changing casting is done later in DotOp Layout... then, in the case of | ||
| // Hopper, the number of bytes held by each thread after loading will no longer | ||
| // be 32B. Hence this flag is required to stipulate different logic. | ||
| bool isHopperWidthChange = isHopper && (mmaElemBytes != elemBytes); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This also applies to Ampere, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You're right, and it makes me wonder now why this was working for Ampere before without this flag.
I'll inspect this code a bit more and come up with an explanation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The answer is that it just didn't work before. That's one of the reasons why we started migrating everything to Linear Layouts, and it's that code written manually has very subtle bugs like the ones you fixed in this PR :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Update - I was testing int8 -> f16 for Ampere and was surprised at first to find that it was working correctly.
I've now come up with an explanation. You probably already know most of the following, but I'd just like to confirm we're on the same page:
If I have AxB where A is int8 and B is f16, with A cast to f16 before MMA, AccelerateMatmul.cpp will compute kWidth for both operands as 4. This is due to the logic in computeOrigBitWidth, which calculates the kWidth using the bitwidth of the smallest element along the dot-op chain. In my case this smallest element is int8, so kWidth = 32 / 8 = 4.
But the eventual layout right before MMA is f16, so the kWidth should be 2. So, it seems at first glance that with kWidth = 4, the results should be wrong. For example, for operand A, f16 layout expects thread 0 to hold elements
(0, {0, 1}), (8, {0, 1}), (0, {8, 9}), (8, {8, 9}) # (m_index, {k_indices...})
...but in reality, we load it with kWidth = 4, meaning
(0, {0, 1, 2, 3}), (8, {0, 1, 2, 3}), (0, {8, 9, 10, 11}), (8, {8, 9, 10, 11})
(I've attached Ampere dotOp layouts for int8 and f16 at the end)
So each "rep" of A here doesn't actually mean a "rep" of the actual f16 MMA instruction, but instead corresponds to 2 reps.
OTOH, B is also loaded with kWidth = 4, but with vecWidth = 2, so the first rep of B is loaded in this order:
(n_offset=0, k_offset={0, 1}), (n_offset=0, k_offset={8, 9})
and the second rep...
(n_offset=0, k_offset={2, 3}), (n_offset=0, k_offset={10, 11})
To match the ordering of elements in A and B, the lowering of int8 -> f16 will reorder the values of A, so that every 16 values of A can be split into 2x 8 values, i.e. 2 reps:
(m_offset=0, k_offset={0, 1}), (8, {0, 1}), (0, {8, 9}), (8, {8, 9})
(m_offset=0, k_offset={2, 3}), (8, {2, 3}), (0, {10, 11}), (8, {10, 11})
I think the key observation here was that the element ordering along K doesn't have to match what the PTX doc prescribes, as long as it's consistent between A and B.
To conclude, the logic here works for Ampere thanks to reorderValues (which has logic for 8b <> 16b and 16b <> 32b). OTOH, for WGMMA, I'd have to previously set kWidth = 2 and have special logic here to load the correct number of values; this is because operand B is always in shmem and doesn't have the kWidth/vecWidth logic above for Ampere's B.
My thought is that the reordering trick might allow for more vectorization and so might actually be worth implementing for Hopper. For this PR though, things look to be functionally correct for both Hopper and Ampere, so should I just leave the logic here as-is for now? (or unless there are other things, which I'm not aware of, that didn't work for Ampere that my Hopper fix here could extend to?)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's a pretty neat analysis!
The issue here is that I pretty much hit these bugs when doing fp4 x bf16, so, following the logic for fp8 x bf16, I chose kWidth=8, but then everything breaks as these computations break with kWidth > 4. Using kWidth=1 and kWidth=4 I was hitting a similar bug as well.
In general, the issue with most of these computations is that they are done in terms of element sizes (which is a tensor-level concept) instead of kWidth, which is a layout-level concept. For example, the same vectorisation optimisations could be done by looking at the kWidth themselves. cc @Jokeren who is trying to clean-up all this.
That being said, we have what we have, so our current attack route to clean all this mess is using LinearLayouts. The idea moving forward is to implement LinearLayout conversions for all the layouts we have. LinearLayouts are easy to prove to be correct, and all these optimisations can be implemented rather cleanly at a layout level.
With all this I want to say that it's probably fine for this code not to support Hopper with kWidth > 4, as we are aiming to delete it anyway in favour of our LinearLayout path.
| // matK (k-index of the "warp matrix") | ||
| // quadK (k-index of the "quad" in the core matrix) | ||
| // quadM (m-index of the "quad" in the core matrix) | ||
| // vecIdx (index of the element in the quad; this is always along the k-dim) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't you need to support kWidth in this file? Or are you expecting that kWidth = 4 / elemSize at this stage? If so, could you add an assert?
Note that this needn't be the case, and we could have a larger kWidth, and have this code emit kWidth / (4 / elemSize) wgmma ops.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
well, I guess that it's just the lhs that supports funny kWidths, so at this stage the kWidth should agree. It would be good to assert that though.
| int mmaInstrM = 16, mmaInstrN = 8, mmaInstrK = 4 * 64 / mmaBitwidth; | ||
| int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / mmaBitwidth; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably write 2 * kWidth rather than 64 / mmaBitwidth for both Hopper and Ampere, as this is in both cases the number of elements along K.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
At this point I believe kWidth may not equal 32 / mmaBitwidth for Ampere.
In the example in my comment above, operand B will have dtype = f16 but kWidth = 4. I think that in this case, matShapeK should be calculated based on the dtype bitwidth, so that the element ordering is correct.
| assert(isAmpere() || isHopper()); | ||
| auto rank = shape.size(); | ||
| auto warpsPerCTA = getWarpsPerCTA(); | ||
| SmallVector<int> shapePerWarp = {1, 16, 8, 4 * 64 / bitwidth}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What about the n value here in Hopper?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, after having a thought about this, I guess that in all these places that depend on instrN, it's fine to leave n = 8, as the n pattern repeats with period 8 for the different tile sizes.
It would be good to leave a note somewhere in the code and refer to that note in all the places where we use this "trick" (here, in the SharedEncodingAttr, in the shared to dot operand mma, etc).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For Hopper actually opIdx should always be 0 since WGMMA doesn't support operand B in registers, and so we shouldn't ever use the n value.
I can add an assert here for opIdx == 0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
right! Adding an assert would be great.
|
Also, could you |
|
@ggengnv do you know roughly when will you be able to rebase this on top of main and address the review? I think we would like to incorporate this work upstream sooner than later. If you are busy, I can help with the rebasing and addressing the review. |
|
@lezcano Hey, sorry for taking a bit on this. I was busy last week but I'm rebasing this now. Aiming to get this done soon today; will keep you updated. |
|
PR transferred to triton-lang: triton-lang#5003 |
da8895b to
c8f89a6
Compare
|
merged into upstream triton |
…leaveTMem.cpp (triton-lang#7924) `TritonNvidiaGPU/interleave_tmem.mlir` fails under address sanitizer. The `ConstantIntOp` operations were created without attachment to any block in https://github.com/triton-lang/triton/pull/7622, which caused a memory leak. This change addresses the problem by adding an insertion point. <details open> <summary>Full log</summary> ================================================================= ==3831==ERROR: LeakSanitizer: detected memory leaks Direct leak of 576 byte(s) in 6 object(s) allocated from: #0 0x55c3eca39164 in malloc [third_party/llvm/llvm-project/compiler-rt/lib/asan/asan_malloc_linux.cpp:67](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/compiler-rt/lib/asan/asan_malloc_linux.cpp?l=67&ws=tap-presubmit-server/421956858&snapshot=2):3 #1 0x55c3f176afb3 in mlir::Operation::create(mlir::Location, mlir::OperationName, mlir::TypeRange, mlir::ValueRange, mlir::DictionaryAttr, mlir::OpaqueProperties, mlir::BlockRange, unsigned int) [third_party/llvm/llvm-project/mlir/lib/IR/Operation.cpp:113](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/IR/Operation.cpp?l=113&ws=tap-presubmit-server/421956858&snapshot=2):46 #2 0x55c3f176a90c in create [third_party/llvm/llvm-project/mlir/lib/IR/Operation.cpp:74](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/IR/Operation.cpp?l=74&ws=tap-presubmit-server/421956858&snapshot=2):10 #3 0x55c3f176a90c in mlir::Operation::create(mlir::Location, mlir::OperationName, mlir::TypeRange, mlir::ValueRange, mlir::NamedAttrList&&, mlir::OpaqueProperties, mlir::BlockRange, mlir::RegionRange) [third_party/llvm/llvm-project/mlir/lib/IR/Operation.cpp:57](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/IR/Operation.cpp?l=57&ws=tap-presubmit-server/421956858&snapshot=2):7 #4 0x55c3f176a61b in mlir::Operation::create(mlir::OperationState const&) [third_party/llvm/llvm-project/mlir/lib/IR/Operation.cpp:35](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/IR/Operation.cpp?l=35&ws=tap-presubmit-server/421956858&snapshot=2):7 #5 0x55c3f1678a78 in mlir::OpBuilder::create(mlir::OperationState const&) [third_party/llvm/llvm-project/mlir/lib/IR/Builders.cpp:453](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/IR/Builders.cpp?l=453&ws=tap-presubmit-server/421956858&snapshot=2):17 #6 0x55c3ecf3668f in mlir::arith::ConstantIntOp mlir::OpBuilder::create<mlir::arith::ConstantIntOp, int, int>(mlir::Location, int&&, int&&) [third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h:507](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h?l=507&ws=tap-presubmit-server/421956858&snapshot=2):16 #7 0x55c3eefa690a in findBufferAccessMemdescSubview [third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp:75](https://cs.corp.google.com/piper///depot/google3/third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp?l=75&ws=tap-presubmit-server/421956858&snapshot=2):33 #8 0x55c3eefa690a in mlir::triton::nvidia_gpu::(anonymous namespace)::findBufferAccess(mlir::Value) [third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp:151](https://cs.corp.google.com/piper///depot/google3/third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp?l=151&ws=tap-presubmit-server/421956858&snapshot=2):12 #9 0x55c3eefa70e7 in mlir::triton::nvidia_gpu::(anonymous namespace)::findBufferAccess(mlir::Value) [third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp:156](https://cs.corp.google.com/piper///depot/google3/third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp?l=156&ws=tap-presubmit-server/421956858&snapshot=2):34 #10 0x55c3eefa4c0c in tmemMayAlias [third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp:173](https://cs.corp.google.com/piper///depot/google3/third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp?l=173&ws=tap-presubmit-server/421956858&snapshot=2):28 #11 0x55c3eefa4c0c in sinkOps [third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp:227](https://cs.corp.google.com/piper///depot/google3/third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp?l=227&ws=tap-presubmit-server/421956858&snapshot=2):36 #12 0x55c3eefa4c0c in trySinkOp [third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp:253](https://cs.corp.google.com/piper///depot/google3/third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp?l=253&ws=tap-presubmit-server/421956858&snapshot=2):10 #13 0x55c3eefa4c0c in mlir::triton::nvidia_gpu::TritonNvidiaGPUInterleaveTMemPass::runOnOperation() [third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp:275](https://cs.corp.google.com/piper///depot/google3/third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp?l=275&ws=tap-presubmit-server/421956858&snapshot=2):14 #14 0x55c3f1560ad1 in operator() [third_party/llvm/llvm-project/mlir/lib/Pass/Pass.cpp:553](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/Pass/Pass.cpp?l=553&ws=tap-presubmit-server/421956858&snapshot=2):17 #15 0x55c3f1560ad1 in void llvm::function_ref<void ()>::callback_fn<mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int)::$_1>(long) [third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:46](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h?l=46&ws=tap-presubmit-server/421956858&snapshot=2):12 #16 0x55c3f1559920 in operator() [third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:69](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h?l=69&ws=tap-presubmit-server/421956858&snapshot=2):12 #17 0x55c3f1559920 in executeAction<mlir::PassExecutionAction, mlir::Pass &> [third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h:280](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h?l=280&ws=tap-presubmit-server/421956858&snapshot=2):7 #18 0x55c3f1559920 in mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) [third_party/llvm/llvm-project/mlir/lib/Pass/Pass.cpp:547](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/Pass/Pass.cpp?l=547&ws=tap-presubmit-server/421956858&snapshot=2):21 #19 0x55c3f155d46f in runPipeline [third_party/llvm/llvm-project/mlir/lib/Pass/Pass.cpp:619](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/Pass/Pass.cpp?l=619&ws=tap-presubmit-server/421956858&snapshot=2):16 #20 0x55c3f155d46f in mlir::PassManager::runPasses(mlir::Operation*, mlir::AnalysisManager) [third_party/llvm/llvm-project/mlir/lib/Pass/Pass.cpp:933](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/Pass/Pass.cpp?l=933&ws=tap-presubmit-server/421956858&snapshot=2):10 #21 0x55c3f155d15b in mlir::PassManager::run(mlir::Operation*) [third_party/llvm/llvm-project/mlir/lib/Pass/Pass.cpp:913](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/Pass/Pass.cpp?l=913&ws=tap-presubmit-server/421956858&snapshot=2):60 #22 0x55c3ed0a8b20 in performActions(llvm::raw_ostream&, std::__u::shared_ptr<llvm::SourceMgr> const&, mlir::MLIRContext*, mlir::MlirOptMainConfig const&) [third_party/llvm/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:477](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp?l=477&ws=tap-presubmit-server/421956858&snapshot=2):17 #23 0x55c3ed0a8363 in processBuffer [third_party/llvm/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:553](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp?l=553&ws=tap-presubmit-server/421956858&snapshot=2):12 #24 0x55c3ed0a8363 in operator() [third_party/llvm/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:642](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp?l=642&ws=tap-presubmit-server/421956858&snapshot=2):12 #25 0x55c3ed0a8363 in llvm::LogicalResult llvm::function_ref<llvm::LogicalResult (std::__u::unique_ptr<llvm::MemoryBuffer, std::__u::default_delete<llvm::MemoryBuffer>>, llvm::MemoryBufferRef const&, llvm::raw_ostream&)>::callback_fn<mlir::MlirOptMain(llvm::raw_ostream&, std::__u::unique_ptr<llvm::MemoryBuffer, std::__u::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&)::$_0>(long, std::__u::unique_ptr<llvm::MemoryBuffer, std::__u::default_delete<llvm::MemoryBuffer>>, llvm::MemoryBufferRef const&, llvm::raw_ostream&) [third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:46](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h?l=46&ws=tap-presubmit-server/421956858&snapshot=2):12 triton-lang#26 0x55c3f17bd34f in operator() [third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:69](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h?l=69&ws=tap-presubmit-server/421956858&snapshot=2):12 triton-lang#27 0x55c3f17bd34f in mlir::splitAndProcessBuffer(std::__u::unique_ptr<llvm::MemoryBuffer, std::__u::default_delete<llvm::MemoryBuffer>>, llvm::function_ref<llvm::LogicalResult (std::__u::unique_ptr<llvm::MemoryBuffer, std::__u::default_delete<llvm::MemoryBuffer>>, llvm::MemoryBufferRef const&, llvm::raw_ostream&)>, llvm::raw_ostream&, llvm::StringRef, llvm::StringRef) [third_party/llvm/llvm-project/mlir/lib/Support/ToolUtilities.cpp:30](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/Support/ToolUtilities.cpp?l=30&ws=tap-presubmit-server/421956858&snapshot=2):12 triton-lang#28 0x55c3ed09d0c6 in mlir::MlirOptMain(llvm::raw_ostream&, std::__u::unique_ptr<llvm::MemoryBuffer, std::__u::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&) [third_party/llvm/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:647](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp?l=647&ws=tap-presubmit-server/421956858&snapshot=2):26 triton-lang#29 0x55c3ed09d67f in mlir::MlirOptMain(int, char**, llvm::StringRef, llvm::StringRef, mlir::DialectRegistry&) [third_party/llvm/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:693](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp?l=693&ws=tap-presubmit-server/421956858&snapshot=2):14 triton-lang#30 0x55c3ed09dc59 in mlir::MlirOptMain(int, char**, llvm::StringRef, mlir::DialectRegistry&) [third_party/llvm/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:709](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp?l=709&ws=tap-presubmit-server/421956858&snapshot=2):10 triton-lang#31 0x55c3eca74a70 in main [third_party/triton/bin/triton-opt.cpp:14](https://cs.corp.google.com/piper///depot/google3/third_party/triton/bin/triton-opt.cpp?l=14&ws=tap-presubmit-server/421956858&snapshot=2):33 triton-lang#32 0x7f1fd58613d3 in __libc_start_main (/usr/grte/v5/lib64/libc.so.6+0x613d3) (BuildId: 9a996398ce14a94560b0c642eb4f6e94) triton-lang#33 0x55c3ec995aa9 in _start /usr/grte/v5/debug-src/src/csu/../sysdeps/x86_64/start.S:120 </details> --------- Co-authored-by: Thomas Raoux <[email protected]>


(Part 2: #19)
Part 1 of "WGMMA with LHS operand in registers" feature.
Hopper has two kinds of WGMMAs, "SS" (both operands in shmem) and "RS" (LHS operand A in registers).
In cases where we apply elementwise operations on A before WGMMA, Triton previously will copy A from global memory (GMEM) into registers (RF), perform the elementwise ops, and then copy to shared memory (SMEM) to perform SS WGMMA.
This PR adds an optimization for the case above to use RS GEMM. This requires the following changes:
Being without pipelining, this PR is not expected to see perf gains. Pipelining for MMAv3 operand in registers is added in Part 2.