Skip to content

Support generic reduction and scan cases.#14

Merged
minjang merged 1 commit intotriton-lang:mainfrom
ienkovich:ienkovich/cpu/scan-reduce
Jun 10, 2024
Merged

Support generic reduction and scan cases.#14
minjang merged 1 commit intotriton-lang:mainfrom
ienkovich:ienkovich/cpu/scan-reduce

Conversation

@ienkovich
Copy link
Copy Markdown
Collaborator

No description provided.

@ienkovich ienkovich requested a review from minjang June 4, 2024 21:49
@ienkovich ienkovich requested a review from ptillet as a code owner June 4, 2024 21:49
Copy link
Copy Markdown
Collaborator

@minjang minjang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me! I just have a naming convention suggestion (ScanReduce -> ReduceScan) and some minor code clarity. I will accept soon.

This is a general question for reduce/scan with vectorization. So, basically is the scan algorithm like std::inclusive_scan? And use vector::shuffle to move the data inside SIMD registers and then finally accumulate.

Comment thread third_party/cpu/lib/TritonToTritonCPU/ConvertScanOp.cpp Outdated
kind == vector::CombiningKind::MAXIMUMF) {
if (elemTy.isF32())
initVal =
rewriter.getF32FloatAttr(std::numeric_limits<float>::quiet_NaN());
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good. It's like std::fmin and std::min.

else if (elemTy.isF64())
initVal =
rewriter.getF64FloatAttr(std::numeric_limits<double>::quiet_NaN());
else
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not urgent, maybe we can support F16/BF16 using its raw binary representations for quite_NaN?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this needs to be added. I couldn't yet find examples of how such constants are created, used, and lowered. Did you see any examples?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just saw one F16 testing cases in the test_core.py regarding reduction. But not urgent at all.

Comment thread third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp Outdated
Comment thread third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp Outdated
Comment thread third_party/cpu/lib/TritonToTritonCPU/ScanReduceCommon.h Outdated
Comment thread third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp Outdated
Comment thread third_party/cpu/lib/TritonToTritonCPU/ConvertScanOp.cpp Outdated
Comment thread third_party/cpu/lib/TritonToTritonCPU/ConvertScanOp.cpp Outdated
Comment thread third_party/cpu/lib/TritonToTritonCPU/ScanReduceCommon.h
Signed-off-by: Ilya Enkovich <ilya.enkovich@intel.com>
@ienkovich ienkovich force-pushed the ienkovich/cpu/scan-reduce branch from bbed4e5 to 9869e84 Compare June 10, 2024 18:54
@ienkovich
Copy link
Copy Markdown
Collaborator Author

This is a general question for reduce/scan with vectorization. So, basically is the scan algorithm like std::inclusive_scan? And use vector::shuffle to move the data inside SIMD registers and then finally accumulate.

There are three general cases here. For scans on trailing dimension, we use shuffle to accumulate neighbors, then increase shuffle stride and use masks to keep already computed elements intact. For reductions on trailing dimensions we use shuffle to swap to vector parts and accumulate, then swap halves of halves, etc. For cases when a scan/reduction goes on a non-trailing dimension, we iterate through all sub-vectors and accumulate. Works like a fully unrolled loop.

@minjang minjang merged commit 5adb663 into triton-lang:main Jun 10, 2024
@ienkovich ienkovich deleted the ienkovich/cpu/scan-reduce branch June 10, 2024 22:19
minjang pushed a commit to minjang/triton-cpu that referenced this pull request Jun 22, 2024
Signed-off-by: Ilya Enkovich <ilya.enkovich@intel.com>
minjang pushed a commit that referenced this pull request Jun 24, 2024
When running
[convert_blocked1d_to_slice0](https://github.com/triton-lang/triton/blob/0ba5f0c3cd029d5c3d1f01b9bf29dac32c27345e/test/Conversion/tritongpu_to_llvm.mlir#L924)
Triton ends up computing a rank of a matrix with 0 columns during linear
layout lowering, which trips up f2reduce, and causes undefined behavior,
detectable through
[UBSAN](https://clang.llvm.org/docs/UndefinedBehaviorSanitizer.html).

Fix this by returning the rank (0) early in these cases, without calling
f2reduce.

<details><summary>Stack trace</summary>
<p>

```
third_party/triton/third_party/f2reduce/f2reduce.cpp:421:30: runtime error: shift exponent 18446744073709551615 is too large for 64-bit type 'unsigned long long'
    #0 0x556ee2fea3be in inplace_rref_small third_party/triton/third_party/f2reduce/f2reduce.cpp:421:30
    #1 0x556ee2fea3be in f2reduce::inplace_rref_strided(unsigned long*, unsigned long, unsigned long, unsigned long) third_party/triton/third_party/f2reduce/f2reduce.cpp:470:9
    #2 0x556ee2ea70da in getMatrixRank third_party/triton/lib/Tools/LinearLayout.cpp:125:3
    #3 0x556ee2ea70da in mlir::triton::LinearLayout::checkInvariants(bool) third_party/triton/lib/Tools/LinearLayout.cpp:299:7
    #4 0x556ee2ea656d in mlir::triton::LinearLayout::tryCreate(llvm::MapVector<mlir::StringAttr, std::__u::vector<std::__u::vector<int, std::__u::allocator<int>>, std::__u::allocator<std::__u::vector<int, std::__u::allocator<int>>>>, llvm::DenseMap<mlir::StringAttr, unsigned int, llvm::DenseMapInfo<mlir::StringAttr, void>, llvm::detail::DenseMapPair<mlir::StringAttr, unsigned int>>, llvm::SmallVector<std::__u::pair<mlir::StringAttr, std::__u::vector<std::__u::vector<int, std::__u::allocator<int>>, std::__u::allocator<std::__u::vector<int, std::__u::allocator<int>>>>>, 0u>>, llvm::ArrayRef<std::__u::pair<mlir::StringAttr, int>>, bool) third_party/triton/lib/Tools/LinearLayout.cpp:190:41
    #5 0x556ee2eb2150 in mlir::triton::LinearLayout::divideRight(mlir::triton::LinearLayout const&) third_party/triton/lib/Tools/LinearLayout.cpp:654:51
    #6 0x556ee2ee1c39 in mlir::cvtNeedsSharedMemory(mlir::RankedTensorType, mlir::RankedTensorType) third_party/triton/lib/Analysis/Utility.cpp:652:14
    #7 0x556ee2cf38fd in mlir::triton::getRepShapeForCvtLayout(mlir::triton::gpu::ConvertLayoutOp) third_party/triton/lib/Analysis/Allocation.cpp:66:8
    #8 0x556ee2cf3efa in mlir::triton::getScratchConfigForCvtLayout(mlir::triton::gpu::ConvertLayoutOp, unsigned int&, unsigned int&) third_party/triton/lib/Analysis/Allocation.cpp:95:19
    #9 0x556ee2cf6057 in mlir::triton::AllocationAnalysis::getScratchValueSize(mlir::Operation*) third_party/triton/lib/Analysis/Allocation.cpp:272:24
    #10 0x556ee2cf5499 in operator() third_party/triton/lib/Analysis/Allocation.cpp:343:7
    #11 0x556ee2cf5499 in void llvm::function_ref<void (mlir::Operation*)>::callback_fn<mlir::triton::AllocationAnalysis::getValuesAndSizes()::'lambda'(mlir::Operation*)>(long, mlir::Operation*) third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:45:12
    #12 0x556edeeee7a9 in operator() third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:68:12
    #13 0x556edeeee7a9 in void mlir::detail::walk<mlir::ForwardIterator>(mlir::Operation*, llvm::function_ref<void (mlir::Operation*)>, mlir::WalkOrder) third_party/llvm/llvm-project/mlir/include/mlir/IR/Visitors.h:174:5
    #14 0x556edeeee87c in void mlir::detail::walk<mlir::ForwardIterator>(mlir::Operation*, llvm::function_ref<void (mlir::Operation*)>, mlir::WalkOrder) third_party/llvm/llvm-project/mlir/include/mlir/IR/Visitors.h:182:9
    #15 0x556ee2cf49e7 in walk<(mlir::WalkOrder)0, mlir::ForwardIterator, (lambda at third_party/triton/lib/Analysis/Allocation.cpp:341:42), mlir::Operation *, void> third_party/llvm/llvm-project/mlir/include/mlir/IR/Visitors.h:313:10
    #16 0x556ee2cf49e7 in walk<(mlir::WalkOrder)0, mlir::ForwardIterator, (lambda at third_party/triton/lib/Analysis/Allocation.cpp:341:42), void> third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h:794:12
    #17 0x556ee2cf49e7 in mlir::triton::AllocationAnalysis::getValuesAndSizes() third_party/triton/lib/Analysis/Allocation.cpp:341:16
    #18 0x556ee2cf4852 in run third_party/triton/lib/Analysis/Allocation.cpp:182:5
    #19 0x556ee2cf4852 in AllocationAnalysis third_party/triton/lib/Analysis/Allocation.cpp:169:5
    #20 0x556ee2cf4852 in mlir::Allocation::run(llvm::DenseMap<mlir::FunctionOpInterface, mlir::Allocation, llvm::DenseMapInfo<mlir::FunctionOpInterface, void>, llvm::detail::DenseMapPair<mlir::FunctionOpInterface, mlir::Allocation>>&) third_party/triton/lib/Analysis/Allocation.cpp:627:3
    #21 0x556ee1677402 in operator() third_party/triton/include/triton/Analysis/Allocation.h:227:26
    #22 0x556ee1677402 in void mlir::CallGraph<mlir::Allocation>::doWalk<(mlir::WalkOrder)0, (mlir::WalkOrder)1, mlir::ModuleAllocation::ModuleAllocation(mlir::ModuleOp)::'lambda'(mlir::CallOpInterface, mlir::FunctionOpInterface), mlir::ModuleAllocation::ModuleAllocation(mlir::ModuleOp)::'lambda'(mlir::FunctionOpInterface)>(mlir::FunctionOpInterface, llvm::DenseSet<mlir::FunctionOpInterface, llvm::DenseMapInfo<mlir::FunctionOpInterface, void>>&, mlir::ModuleAllocation::ModuleAllocation(mlir::ModuleOp)::'lambda'(mlir::CallOpInterface, mlir::FunctionOpInterface), mlir::ModuleAllocation::ModuleAllocation(mlir::ModuleOp)::'lambda'(mlir::FunctionOpInterface)) third_party/triton/include/triton/Analysis/Utility.h:350:7
    #23 0x556ee16756b3 in walk<(mlir::WalkOrder)0, (mlir::WalkOrder)1, (lambda at third_party/triton/include/triton/Analysis/Allocation.h:222:9), (lambda at third_party/triton/include/triton/Analysis/Allocation.h:224:9)> third_party/triton/include/triton/Analysis/Utility.h:242:7
    #24 0x556ee16756b3 in mlir::ModuleAllocation::ModuleAllocation(mlir::ModuleOp) third_party/triton/include/triton/Analysis/Allocation.h:220:5
    #25 0x556ee2c2bf18 in (anonymous namespace)::AllocateSharedMemory::runOnOperation() third_party/triton/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp:26:22
...
UndefinedBehaviorSanitizer: invalid-shift-exponent third_party/triton/third_party/f2reduce/f2reduce.cpp:421:30 
```
</p>
</details>
minjang pushed a commit that referenced this pull request Jun 24, 2024
Signed-off-by: Ilya Enkovich <ilya.enkovich@intel.com>
Devjiu pushed a commit to Devjiu/triton-cpu that referenced this pull request Aug 13, 2024
Signed-off-by: Ilya Enkovich <ilya.enkovich@intel.com>
int3 pushed a commit that referenced this pull request Aug 29, 2024
Signed-off-by: Ilya Enkovich <ilya.enkovich@intel.com>
minjang pushed a commit that referenced this pull request Sep 22, 2024
Signed-off-by: Ilya Enkovich <ilya.enkovich@intel.com>
minjang pushed a commit that referenced this pull request Oct 22, 2024
Signed-off-by: Ilya Enkovich <ilya.enkovich@intel.com>
minjang pushed a commit that referenced this pull request Oct 24, 2024
Signed-off-by: Ilya Enkovich <ilya.enkovich@intel.com>
Devjiu pushed a commit to Devjiu/triton-cpu that referenced this pull request Nov 13, 2024
Adds extra optional padding that can be use to ensure that input
matrices' strides are non-power-of-two to improve cache behavior.

Currently, it is most useful with DYNAMIC_K_BLOCK enabled.
Devjiu pushed a commit to Devjiu/triton-cpu that referenced this pull request Nov 13, 2024
Adds extra optional padding that can be use to ensure that input
matrices' strides are non-power-of-two to improve cache behavior.

Currently, it is most useful with DYNAMIC_K_BLOCK enabled.
int3 pushed a commit that referenced this pull request Dec 6, 2024
Signed-off-by: Ilya Enkovich <ilya.enkovich@intel.com>
ienkovich added a commit that referenced this pull request Dec 6, 2024
Signed-off-by: Ilya Enkovich <ilya.enkovich@intel.com>
Devjiu pushed a commit to Devjiu/triton-cpu that referenced this pull request Jan 20, 2025
Adds extra optional padding that can be use to ensure that input
matrices' strides are non-power-of-two to improve cache behavior.

Currently, it is most useful with DYNAMIC_K_BLOCK enabled.
Devjiu pushed a commit to Devjiu/triton-cpu that referenced this pull request Feb 20, 2025
Signed-off-by: Ilya Enkovich <ilya.enkovich@intel.com>
Devjiu pushed a commit to Devjiu/triton-cpu that referenced this pull request Feb 24, 2025
Signed-off-by: Ilya Enkovich <ilya.enkovich@intel.com>
Devjiu pushed a commit to Devjiu/triton-cpu that referenced this pull request Feb 28, 2025
Signed-off-by: Ilya Enkovich <ilya.enkovich@intel.com>
Devjiu pushed a commit to Devjiu/triton-cpu that referenced this pull request Mar 3, 2025
Signed-off-by: Ilya Enkovich <ilya.enkovich@intel.com>
Devjiu pushed a commit to Devjiu/triton-cpu that referenced this pull request Apr 3, 2025
Signed-off-by: Ilya Enkovich <ilya.enkovich@intel.com>
jopperm pushed a commit to jopperm/triton-cpu that referenced this pull request Feb 20, 2026
…lang#7796)

Getting a crash internally when running `09-persistent-matmul.py`
tutorial, and ASAN reports the following:

```
==7854==ERROR: AddressSanitizer: heap-use-after-free on address 0x7c884c02e800 at pc 0x557f344112d9 bp 0x7b35908a1840 sp 0x7b35908a1838
READ of size 8 at 0x7c884c02e800 thread T1128
    #0 0x557f344112d8 in getNextOperandUsingThisValue third_party/llvm/llvm-project/mlir/include/mlir/IR/UseDefLists.h:43:58
    triton-lang#1 0x557f344112d8 in operator++ third_party/llvm/llvm-project/mlir/include/mlir/IR/UseDefLists.h:322:39
    triton-lang#2 0x557f344112d8 in mlir::ResultRange::UseIterator::operator++() third_party/llvm/llvm-project/mlir/lib/IR/OperationSupport.cpp:613:5
    triton-lang#3 0x557f2ab70625 in mlir::lowerTokenOperations(mlir::Operation*, int, int) third_party/triton/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSLowerToken.cpp:269:27
    triton-lang#4 0x557f2ab70de8 in mlir::doTokenLowering(mlir::triton::FuncOp&, unsigned int) third_party/triton/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSLowerToken.cpp:321:3
    triton-lang#5 0x557f2ab2d018 in mlir::NVGPUWarpSpecializationPass::runOnFuncOp(mlir::triton::FuncOp) third_party/triton/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization.cpp:99:5
    triton-lang#6 0x557f2ab2c5d6 in operator() third_party/triton/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization.cpp:108:55
    triton-lang#7 0x557f2ab2c5d6 in operator() third_party/llvm/llvm-project/mlir/include/mlir/IR/Visitors.h:304:7
    triton-lang#8 0x557f2ab2c5d6 in void llvm::function_ref<void (mlir::Operation*)>::callback_fn<std::__u::enable_if<!llvm::is_one_of<mlir::triton::FuncOp, mlir::Operation*, mlir::Region*, mlir::Block*>::value && std::is_same<void, void>::value, void>::type mlir::detail::walk<(mlir::WalkOrder)1, mlir::ForwardIterator, mlir::NVGPUWarpSpecializationPass::runOnOperation()::'lambda'(mlir::triton::FuncOp), mlir::triton::FuncOp, void>(mlir::Operation*, mlir::NVGPUWarpSpecializationPass::runOnOperation()::'lambda'(mlir::triton::FuncOp)&&)::'lambda'(mlir::Operation*)>(long, mlir::Operation*) third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:46:12
    triton-lang#9 0x557f2820ce45 in operator() third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:69:12
    triton-lang#10 0x557f2820ce45 in void mlir::detail::walk<mlir::ForwardIterator>(mlir::Operation*, llvm::function_ref<void (mlir::Operation*)>, mlir::WalkOrder) third_party/llvm/llvm-project/mlir/include/mlir/IR/Visitors.h:152:5
    triton-lang#11 0x557f2820ce2c in void mlir::detail::walk<mlir::ForwardIterator>(mlir::Operation*, llvm::function_ref<void (mlir::Operation*)>, mlir::WalkOrder) third_party/llvm/llvm-project/mlir/include/mlir/IR/Visitors.h:147:9
    triton-lang#12 0x557f2ab2c0c9 in walk<(mlir::WalkOrder)1, mlir::ForwardIterator, (lambda at third_party/triton/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization.cpp:108:26), mlir::triton::FuncOp, void> third_party/llvm/llvm-project/mlir/include/mlir/IR/Visitors.h:306:10
    triton-lang#13 0x557f2ab2c0c9 in walk<(mlir::WalkOrder)1, mlir::ForwardIterator, (lambda at third_party/triton/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization.cpp:108:26), void> third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h:798:12
    triton-lang#14 0x557f2ab2c0c9 in mlir::NVGPUWarpSpecializationPass::runOnOperation() third_party/triton/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization.cpp:108:21
...
```

The problem seems to be that we are iterating through uses, and then
removing some of them inside the loop, which invalidates the iterator.
jopperm pushed a commit to jopperm/triton-cpu that referenced this pull request Feb 20, 2026
…leaveTMem.cpp (triton-lang#7924)

`TritonNvidiaGPU/interleave_tmem.mlir` fails under address sanitizer. 

The `ConstantIntOp` operations were created without attachment to any
block in https://github.com/triton-lang/triton/pull/7622, which caused a
memory leak. This change addresses the problem by adding an insertion
point.

<details open>
  <summary>Full log</summary>

=================================================================
==3831==ERROR: LeakSanitizer: detected memory leaks

Direct leak of 576 byte(s) in 6 object(s) allocated from:
#0 0x55c3eca39164 in malloc
[third_party/llvm/llvm-project/compiler-rt/lib/asan/asan_malloc_linux.cpp:67](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/compiler-rt/lib/asan/asan_malloc_linux.cpp?l=67&ws=tap-presubmit-server/421956858&snapshot=2):3
triton-lang#1 0x55c3f176afb3 in mlir::Operation::create(mlir::Location,
mlir::OperationName, mlir::TypeRange, mlir::ValueRange,
mlir::DictionaryAttr, mlir::OpaqueProperties, mlir::BlockRange, unsigned
int)
[third_party/llvm/llvm-project/mlir/lib/IR/Operation.cpp:113](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/IR/Operation.cpp?l=113&ws=tap-presubmit-server/421956858&snapshot=2):46
triton-lang#2 0x55c3f176a90c in create
[third_party/llvm/llvm-project/mlir/lib/IR/Operation.cpp:74](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/IR/Operation.cpp?l=74&ws=tap-presubmit-server/421956858&snapshot=2):10
triton-lang#3 0x55c3f176a90c in mlir::Operation::create(mlir::Location,
mlir::OperationName, mlir::TypeRange, mlir::ValueRange,
mlir::NamedAttrList&&, mlir::OpaqueProperties, mlir::BlockRange,
mlir::RegionRange)
[third_party/llvm/llvm-project/mlir/lib/IR/Operation.cpp:57](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/IR/Operation.cpp?l=57&ws=tap-presubmit-server/421956858&snapshot=2):7
triton-lang#4 0x55c3f176a61b in mlir::Operation::create(mlir::OperationState
const&)
[third_party/llvm/llvm-project/mlir/lib/IR/Operation.cpp:35](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/IR/Operation.cpp?l=35&ws=tap-presubmit-server/421956858&snapshot=2):7
triton-lang#5 0x55c3f1678a78 in mlir::OpBuilder::create(mlir::OperationState
const&)
[third_party/llvm/llvm-project/mlir/lib/IR/Builders.cpp:453](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/IR/Builders.cpp?l=453&ws=tap-presubmit-server/421956858&snapshot=2):17
triton-lang#6 0x55c3ecf3668f in mlir::arith::ConstantIntOp
mlir::OpBuilder::create<mlir::arith::ConstantIntOp, int,
int>(mlir::Location, int&&, int&&)
[third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h:507](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h?l=507&ws=tap-presubmit-server/421956858&snapshot=2):16
triton-lang#7 0x55c3eefa690a in findBufferAccessMemdescSubview
[third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp:75](https://cs.corp.google.com/piper///depot/google3/third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp?l=75&ws=tap-presubmit-server/421956858&snapshot=2):33
triton-lang#8 0x55c3eefa690a in mlir::triton::nvidia_gpu::(anonymous
namespace)::findBufferAccess(mlir::Value)
[third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp:151](https://cs.corp.google.com/piper///depot/google3/third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp?l=151&ws=tap-presubmit-server/421956858&snapshot=2):12
triton-lang#9 0x55c3eefa70e7 in mlir::triton::nvidia_gpu::(anonymous
namespace)::findBufferAccess(mlir::Value)
[third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp:156](https://cs.corp.google.com/piper///depot/google3/third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp?l=156&ws=tap-presubmit-server/421956858&snapshot=2):34
triton-lang#10 0x55c3eefa4c0c in tmemMayAlias
[third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp:173](https://cs.corp.google.com/piper///depot/google3/third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp?l=173&ws=tap-presubmit-server/421956858&snapshot=2):28
triton-lang#11 0x55c3eefa4c0c in sinkOps
[third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp:227](https://cs.corp.google.com/piper///depot/google3/third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp?l=227&ws=tap-presubmit-server/421956858&snapshot=2):36
triton-lang#12 0x55c3eefa4c0c in trySinkOp
[third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp:253](https://cs.corp.google.com/piper///depot/google3/third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp?l=253&ws=tap-presubmit-server/421956858&snapshot=2):10
triton-lang#13 0x55c3eefa4c0c in
mlir::triton::nvidia_gpu::TritonNvidiaGPUInterleaveTMemPass::runOnOperation()
[third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp:275](https://cs.corp.google.com/piper///depot/google3/third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp?l=275&ws=tap-presubmit-server/421956858&snapshot=2):14
triton-lang#14 0x55c3f1560ad1 in operator()
[third_party/llvm/llvm-project/mlir/lib/Pass/Pass.cpp:553](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/Pass/Pass.cpp?l=553&ws=tap-presubmit-server/421956858&snapshot=2):17
triton-lang#15 0x55c3f1560ad1 in void llvm::function_ref<void
()>::callback_fn<mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*,
mlir::Operation*, mlir::AnalysisManager, bool, unsigned int)::$_1>(long)
[third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:46](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h?l=46&ws=tap-presubmit-server/421956858&snapshot=2):12
triton-lang#16 0x55c3f1559920 in operator()
[third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:69](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h?l=69&ws=tap-presubmit-server/421956858&snapshot=2):12
triton-lang#17 0x55c3f1559920 in executeAction<mlir::PassExecutionAction,
mlir::Pass &>
[third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h:280](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h?l=280&ws=tap-presubmit-server/421956858&snapshot=2):7
triton-lang#18 0x55c3f1559920 in mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*,
mlir::Operation*, mlir::AnalysisManager, bool, unsigned int)
[third_party/llvm/llvm-project/mlir/lib/Pass/Pass.cpp:547](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/Pass/Pass.cpp?l=547&ws=tap-presubmit-server/421956858&snapshot=2):21
triton-lang#19 0x55c3f155d46f in runPipeline
[third_party/llvm/llvm-project/mlir/lib/Pass/Pass.cpp:619](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/Pass/Pass.cpp?l=619&ws=tap-presubmit-server/421956858&snapshot=2):16
triton-lang#20 0x55c3f155d46f in mlir::PassManager::runPasses(mlir::Operation*,
mlir::AnalysisManager)
[third_party/llvm/llvm-project/mlir/lib/Pass/Pass.cpp:933](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/Pass/Pass.cpp?l=933&ws=tap-presubmit-server/421956858&snapshot=2):10
triton-lang#21 0x55c3f155d15b in mlir::PassManager::run(mlir::Operation*)
[third_party/llvm/llvm-project/mlir/lib/Pass/Pass.cpp:913](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/Pass/Pass.cpp?l=913&ws=tap-presubmit-server/421956858&snapshot=2):60
triton-lang#22 0x55c3ed0a8b20 in performActions(llvm::raw_ostream&,
std::__u::shared_ptr<llvm::SourceMgr> const&, mlir::MLIRContext*,
mlir::MlirOptMainConfig const&)
[third_party/llvm/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:477](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp?l=477&ws=tap-presubmit-server/421956858&snapshot=2):17
triton-lang#23 0x55c3ed0a8363 in processBuffer
[third_party/llvm/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:553](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp?l=553&ws=tap-presubmit-server/421956858&snapshot=2):12
triton-lang#24 0x55c3ed0a8363 in operator()
[third_party/llvm/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:642](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp?l=642&ws=tap-presubmit-server/421956858&snapshot=2):12
triton-lang#25 0x55c3ed0a8363 in llvm::LogicalResult
llvm::function_ref<llvm::LogicalResult
(std::__u::unique_ptr<llvm::MemoryBuffer,
std::__u::default_delete<llvm::MemoryBuffer>>, llvm::MemoryBufferRef
const&,
llvm::raw_ostream&)>::callback_fn<mlir::MlirOptMain(llvm::raw_ostream&,
std::__u::unique_ptr<llvm::MemoryBuffer,
std::__u::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&,
mlir::MlirOptMainConfig const&)::$_0>(long,
std::__u::unique_ptr<llvm::MemoryBuffer,
std::__u::default_delete<llvm::MemoryBuffer>>, llvm::MemoryBufferRef
const&, llvm::raw_ostream&)
[third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:46](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h?l=46&ws=tap-presubmit-server/421956858&snapshot=2):12
triton-lang#26 0x55c3f17bd34f in operator()
[third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:69](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h?l=69&ws=tap-presubmit-server/421956858&snapshot=2):12
triton-lang#27 0x55c3f17bd34f in
mlir::splitAndProcessBuffer(std::__u::unique_ptr<llvm::MemoryBuffer,
std::__u::default_delete<llvm::MemoryBuffer>>,
llvm::function_ref<llvm::LogicalResult
(std::__u::unique_ptr<llvm::MemoryBuffer,
std::__u::default_delete<llvm::MemoryBuffer>>, llvm::MemoryBufferRef
const&, llvm::raw_ostream&)>, llvm::raw_ostream&, llvm::StringRef,
llvm::StringRef)
[third_party/llvm/llvm-project/mlir/lib/Support/ToolUtilities.cpp:30](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/Support/ToolUtilities.cpp?l=30&ws=tap-presubmit-server/421956858&snapshot=2):12
triton-lang#28 0x55c3ed09d0c6 in mlir::MlirOptMain(llvm::raw_ostream&,
std::__u::unique_ptr<llvm::MemoryBuffer,
std::__u::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&,
mlir::MlirOptMainConfig const&)
[third_party/llvm/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:647](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp?l=647&ws=tap-presubmit-server/421956858&snapshot=2):26
triton-lang#29 0x55c3ed09d67f in mlir::MlirOptMain(int, char**, llvm::StringRef,
llvm::StringRef, mlir::DialectRegistry&)
[third_party/llvm/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:693](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp?l=693&ws=tap-presubmit-server/421956858&snapshot=2):14
triton-lang#30 0x55c3ed09dc59 in mlir::MlirOptMain(int, char**, llvm::StringRef,
mlir::DialectRegistry&)
[third_party/llvm/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:709](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp?l=709&ws=tap-presubmit-server/421956858&snapshot=2):10
triton-lang#31 0x55c3eca74a70 in main
[third_party/triton/bin/triton-opt.cpp:14](https://cs.corp.google.com/piper///depot/google3/third_party/triton/bin/triton-opt.cpp?l=14&ws=tap-presubmit-server/421956858&snapshot=2):33
triton-lang#32 0x7f1fd58613d3 in __libc_start_main
(/usr/grte/v5/lib64/libc.so.6+0x613d3) (BuildId:
9a996398ce14a94560b0c642eb4f6e94)
triton-lang#33 0x55c3ec995aa9 in _start
/usr/grte/v5/debug-src/src/csu/../sysdeps/x86_64/start.S:120

</details>

---------

Co-authored-by: Thomas Raoux <thomas.raoux@openai.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants