Lower linalg.copy to direct global load #20568
Conversation
Side note, I'm still poking at getting the buffer fat pointer to LDS intrinsic set up - it's caught up in bikeshedding on the compiler team.
krzysz00 left a comment:
High-level observation: at the point this is being called, shouldn't we know the subgroup size, so that we don't need the subgroup_id op?
Like, you can just look at the workgroup sizes to see which subgroup you're in
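For illustration, a minimal sketch of that suggestion, assuming a 1-D workgroup and a statically known `subgroupSize` (the variable names are hypothetical, and `rewriter`/`loc` are assumed in scope):

```cpp
// Derive the subgroup index from the flat thread id instead of emitting a
// gpu.subgroup_id op, since the subgroup size is already known here.
Value tid = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
Value size = rewriter.create<arith::ConstantIndexOp>(loc, subgroupSize);
Value subgroupId = rewriter.create<arith::DivUIOp>(loc, tid, size);
```

A multi-dimensional workgroup would first linearize the x/y/z thread ids before the division.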
```cpp
//===----------------------------------------------------------------------===//

SmallVector<int64_t>
UseGlobalLoadDMAAttr::getStaticTilingLevelSizes(unsigned level,
```
Why's there an unsigned in here?
This is how the LoweringConfigAttrInterface exposes tiling levels. It's up to the backend + lowering config to interpret the level consistently.
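For context, a hedged sketch of one way the backend side could interpret the level; the `TilingLevel` check and the helper are illustrative, not the actual IREE definitions:

```cpp
// Sketch: levels this attribute does not participate in yield empty tile
// sizes, so only the agreed-upon level drives tiling.
SmallVector<int64_t>
UseGlobalLoadDMAAttr::getStaticTilingLevelSizes(unsigned level,
                                                Operation *op) const {
  if (level != static_cast<unsigned>(IREE::GPU::TilingLevel::Thread))
    return {};
  return deriveThreadTileSizes(op); // hypothetical helper
}
```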
Changed the title: iree_gpu.global_load_dma op → linalg.copy to direct global load
krzysz00 left a comment:
A few notes, some of which I apparently failed to submit on Friday x.x
fixed.
```cpp
  return numElements;
}

static bool distributeLinalgCopyToThreads(RewriterBase &rewriter,
```
Also, this might want to be a LogicalResult?
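That is, something along these lines (a sketch of the suggested signature change, not the final code):

```cpp
// Returning LogicalResult lets the caller propagate failure through the
// rewriter instead of interpreting a bare bool.
static LogicalResult distributeLinalgCopyToThreads(RewriterBase &rewriter,
                                                   linalg::CopyOp copy) {
  if (!isEligibleForGlobalLoadDMA(copy)) // hypothetical predicate
    return rewriter.notifyMatchFailure(copy, "copy not eligible");
  // ... emit the distributed loop ...
  return success();
}
```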
This test is failing on Windows:
- https://github.com/iree-org/iree/actions/runs/15320695495/job/43103617380#step:10:338
- https://github.com/iree-org/iree/actions/runs/15343607885/job/43174932290#step:10:342
Logs snippet:

```
108/1604 Test #116: iree/compiler/Codegen/Common/GPU/test/gpu_lower_to_global_loads.mlir.test ........................................***Failed 1.59 sec
-- Testing: 1 tests, 1 workers --
FAIL: IREE :: src/iree/compiler/Codegen/Common/GPU/test/gpu_lower_to_global_loads.mlir (1 of 1)
******************** TEST 'IREE :: src/iree/compiler/Codegen/Common/GPU/test/gpu_lower_to_global_loads.mlir' FAILED ********************
Exit Code: 1

Command Output (stderr):
--
iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-lower-to-global-loads))" C:/home/runner/_work/iree/iree/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_lower_to_global_loads.mlir | FileCheck C:/home/runner/_work/iree/iree/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_lower_to_global_loads.mlir # RUN: at line 1
+ iree-opt --split-input-file '--pass-pipeline=builtin.module(func.func(iree-codegen-gpu-lower-to-global-loads))' C:/home/runner/_work/iree/iree/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_lower_to_global_loads.mlir
+ FileCheck C:/home/runner/_work/iree/iree/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_lower_to_global_loads.mlir
C:/home/runner/_work/iree/iree/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_lower_to_global_loads.mlir:25:11: error: CHECK: expected string not found in input
// CHECK: %[[C4:.*]] = arith.constant 4 : index
^
<stdin>:6:32: note: scanning from here
%c1 = arith.constant 1 : index
^
<stdin>:8:2: note: possible intended match here
%1 = gpu.subgroup_id : index
^
C:/home/runner/_work/iree/iree/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_lower_to_global_loads.mlir:55:11: error: CHECK: expected string not found in input
// CHECK: %[[C4:.*]] = arith.constant 4 : index
^
<stdin>:28:32: note: scanning from here
%c1 = arith.constant 1 : index
^
<stdin>:30:2: note: possible intended match here
%1 = gpu.subgroup_id : index
^
```

Might need to use CHECK-DAG instead of CHECK or change the compiler code to be more deterministic across platforms.
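Concretely, the CHECK-DAG route would look something like this in the test file (a sketch; the surrounding check lines are abbreviated):

```mlir
// The two constants are created in a platform-dependent order, so check them
// order-independently; the structural check below stays ordered.
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: gpu.subgroup_id : index
```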
Will submit a fix for it soon.
The usual issue here is evaluation order of function arguments when you're nesting build() calls if that helps
you're probably right @krzysz00 - I suspect it is:

```cpp
scf::ForOp forOp = rewriter.create<scf::ForOp>(
    loc, /*lb=*/rewriter.create<arith::ConstantIndexOp>(loc, 0),
    /*ub=*/rewriter.create<arith::ConstantIndexOp>(loc, numCopiesPerThread),
    /*steps=*/rewriter.create<arith::ConstantIndexOp>(loc, 1));
```
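A minimal sketch of the corresponding fix, hoisting the operands into named locals so the op creation order no longer depends on C++ argument evaluation order:

```cpp
// Build the constants in a fixed, explicit order, then the loop.
Value lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value ub = rewriter.create<arith::ConstantIndexOp>(loc, numCopiesPerThread);
Value step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
auto forOp = rewriter.create<scf::ForOp>(loc, lb, ub, step);
```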
## Summary
This PR sets the foundation for using the `global_load_lds` instruction to
load values from global to LDS memory. The pipeline is as follows:
* Only convert `linalg.copy` ops emitted by `GPUPromoteMatmulOperands`. When
it sees fit, that pass inserts a distinct attribute
(`#iree_gpu.use_global_load_dma`) on the `linalg.copy` to tag it through the
pipeline.
* A tagged `linalg.copy` is not decomposed/tiled until bufferization.
* After distribution to threads and bufferization, the tagged
`linalg.copy` is lowered to a sequence of code built around the
subgroup-coalesced loading op `iree_gpu.global_load_dma`.
* `iree_gpu.global_load_dma` is mapped to the `amdgpu.gather_to_lds`
op, which is in turn mapped to the corresponding ROCDL op.
* The pass that pads allocations to reduce bank conflicts is disabled,
because the destination workgroup memory has to be contiguous.
## Lowering `linalg.copy`
After bufferization and distribution to threads, the tagged `linalg.copy`
still exists in the IR:
```mlir
linalg.copy {lowering_config = #iree_gpu.use_global_load_dma}
ins(%subview_12 : memref<64x128xi8, strided<[256, 1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>)
outs(%alloc_4 : memref<64x128xi8, #gpu.address_space<workgroup>>)
```
Note that this `linalg.copy` is kept in the thread's code. The op itself
is then converted into a `for` loop, in which each subgroup of threads loads
a coalesced chunk of values. For example, assume there are N subgroups
loading from `tensor<a x b x c>`:
* The `i`-th subgroup will load a subtensor of size `[a/N, b, c]`, so
each slice is consecutive.
* For now, assume row-major layout and tile only the outermost dim.
* We currently handle only `linalg.copy` ops emitted by
`GPUPromoteMatmulOperands`, because for those we know the destination is
allocated contiguously.
* TODO: expand to arbitrary memref slices.
* Given `gpu.subgroup_id` and `gpu.lane_id`, each thread calculates the
consecutive data chunk its subgroup is responsible for loading:
* The chunk's indices are the delinearized indices of the input tensor,
ranging from
* `affine.delinearize_index[gpu.subgroup_id * (num_elems_of(tensor) /
num_subgroups)]` to
* `affine.delinearize_index[(gpu.subgroup_id + 1) *
(num_elems_of(tensor) / num_subgroups) - 1]`.
* Assuming each subgroup loads `n` values spanning linearized indices `[N_f,
N_b]`, the thread with lane id `i` loads `N_f + subgroup_size * iter +
(i - 1)` for `iter = 0 to n`.
Then it will be converted to something like the following (in the
example, assume `workgroup size = 256`, `subgroup_size = 64`, loading
`64x128xi8`):
```mlir
scf.for %indvar = %c0 to %c32 step %c1 {
  // thread-specific gathering address from the global address
  %17 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 2048 + s2 * 64)>()[%lane_id, %subgroup_id, %indvar]
  %18:2 = affine.delinearize_index %17 into (128, 64) : index, index
  // this iteration's base storing index
  %19 = affine.apply affine_map<()[s0, s1] -> (s0 * 2048 + s1 * 64)>()[%subgroup_id, %indvar]
  %20:2 = affine.delinearize_index %19 into (128, 64) : index, index
  iree_gpu.global_load_dma %subview_13[%18#0, %18#1] -> %alloc_5[%20#0, %20#1] : memref<128x64xi8, strided<[256, 1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>> -> memref<128x64xi8, #gpu.address_space<workgroup>>
}
// if there are residual elements (subgroup_copy_region_size % subgroup_size != 0), copy them here
gpu.barrier
```
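As a sanity check, the constants in this example follow directly from the stated sizes:

```latex
\[
N = \tfrac{256}{64} = 4 \text{ subgroups}, \qquad
\tfrac{64 \cdot 128}{4} = 2048 \text{ elements per subgroup}, \qquad
\tfrac{2048}{64} = 32 \text{ iterations},
\]
```

which matches the `s1 * 2048` (per-subgroup base) and `s2 * 64` (per-iteration stride) terms in the affine maps, and the loop bound `%c32`.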
## Dependent PRs:
* design doc: https://hackmd.io/N0RitxPzT9GPhM0jEPtOCg?view
* upstream changes required:
* llvm/llvm-project#133498
* llvm/llvm-project#136405
* llvm/llvm-project#137671
* llvm/llvm-project#137425
* #20800 (review)
---------
Signed-off-by: Alan Li <me@alanli.org>