[Backend] Implement optimized gather codegen within a warp #5345

Mogball merged 49 commits into triton-lang:main from
Conversation
This is now ready for review! Would love eyes on the codegen algorithm.
```c++
  return LinearLayout({{inDimName, std::move(powersOf2)}}, {outDimName});
}

/*static*/ LinearLayout LinearLayout::identityND(
```
If we incorporate identityND this way as a function of LinearLayout, I think zeros1D and identity1D also have to be refactored such that only one dimName should be used in the arguments.

I went ahead and removed this since you're right that it is overly specific.
```c++
                  ConversionPatternRewriter &rewriter) const {
  GatherLoweringHelper helper(op);
  // Specialize the lowering based on the source layout.
  if (helper.isWarpLocal()) {
```
You may also want to add helper.isThreadLocal?

Yeah, that's a good idea. I'll do it as a follow-up since I haven't implemented this case.

Even more, I really think this is something that could be in a Utility.cpp for LinearLayouts, as this is the sort of stuff we might want to use in other lowerings. Not for this PR, though.
lezcano left a comment:

I have to go now and I didn't have time to review the whole thing, so I'm leaving some minor comments for now. Will review on Monday.
lezcano left a comment:
On the algorithmic side, I share similar concerns as Keren: I am not sure this is going to be performant vs. a round trip to shmem. A back-of-the-envelope computation, taking `T(round trip to shmem w/o bank conflicts) = 2 * T(warp shuffle)`, suggests this would only be more efficient when we have very few registers per column. Note that this wouldn't be terribly bad, as we can compensate for it with a rough heuristic, but regardless it'd be nice to see some perf numbers.
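For concreteness, a minimal sketch of that cost model (the per-column shuffle count is an assumption on my part, not something stated in the thread, and the constants are illustrative only):

```python
T_SHUFFLE = 1.0                       # cost of one warp shuffle (arbitrary unit)
T_SMEM_ROUND_TRIP = 2.0 * T_SHUFFLE   # the assumption quoted above

def shuffle_lowering_wins(regs_per_column: int) -> bool:
    """The shuffle path issues roughly one shuffle per source register in a
    gather column, so it beats the fixed-cost shmem round trip only when
    columns span very few registers."""
    return regs_per_column * T_SHUFFLE < T_SMEM_ROUND_TRIP

for regs in (1, 2, 4, 8):
    print(regs, shuffle_lowering_wins(regs))  # True only for regs == 1 here
```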
A way to make this logic higher level is to implement a couple of helper functions that work on layouts from dim0..n-1 to dim0..n-1. For example, here we could have a layout that takes (x_0, ..., x_{n-1}) -> (0, ..., x_i, ..., 0), where i is the kGatherDim dimension, and another that zeros out that dimension and is the identity everywhere else. Using these two linear layouts, you can select kGatherDim in the "logical matrix space", and the rest. This should simplify the layout handling quite a bit, I think.
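To illustrate the two suggested helpers, here is a toy Python model using plain integer tuples rather than Triton's LinearLayout class (the function names are hypothetical):

```python
from typing import Tuple

def select_gather_dim(coords: Tuple[int, ...], k: int) -> Tuple[int, ...]:
    """(x_0, ..., x_{n-1}) -> (0, ..., x_k, ..., 0): keep only the gather dim."""
    return tuple(x if i == k else 0 for i, x in enumerate(coords))

def zero_gather_dim(coords: Tuple[int, ...], k: int) -> Tuple[int, ...]:
    """Identity on every dimension except the gather dim, which is zeroed."""
    return tuple(0 if i == k else x for i, x in enumerate(coords))

# Both maps act independently on each bit of each coordinate, so they are
# linear over GF(2) and expressible as LinearLayouts; XOR-ing their outputs
# recovers the original point, i.e. together they split a coordinate into
# "position along the gather axis" and "which column".
c = (3, 5, 7)
split = zip(select_gather_dim(c, 1), zero_gather_dim(c, 1))
assert tuple(a ^ b for a, b in split) == c
```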
```c++
// Remove zero bases in the gather dimension to make the function injective
// (for a given column) over the same codomain.
LinearLayout::BasesT newInvertRegMapBases;
for (auto &[inDim, inDimBases] : invertSrcRegMap.getBases()) {
  auto &newInDimBases = newInvertRegMapBases[inDim];
  if (inDim != kGatherDim) {
    newInDimBases = inDimBases;
    continue;
  }
  for (auto &basis : inDimBases) {
    if (llvm::any_of(basis, [](int32_t val) { return val != 0; })) {
      newInDimBases.push_back(basis);
    }
  }
}
```
We are checking that the srcLayout is injective, so this is a no-op. Let's remove it for now. The better way of doing it is via a pseudoinverse.
This isn't a no-op. Even if srcLayout is injective, the sublayout computed here might not be, since the sublayout slices the output down to just the register component.
What this code is essentially doing is creating a layout that, given the coords of a column, computes the unique register IDs in that column.
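To see why pruning zero bases restores injectivity without changing the image, consider a toy 1D model (the bases here are made up; Triton's real bases are vectors with one entry per output dimension):

```python
def apply(bases, x):
    """A 1D linear layout: XOR together the bases at x's set bits."""
    out = 0
    for i, b in enumerate(bases):
        if (x >> i) & 1:
            out ^= b
    return out

bases = [1, 0, 2]                      # the middle basis contributes nothing
pruned = [b for b in bases if b != 0]  # what the quoted loop computes

# With the zero basis, inputs 0b000 and 0b010 collide, so the register map is
# not injective even if the full source layout is; pruning removes the
# collisions while keeping the set of register IDs in the column unchanged.
assert apply(bases, 0b000) == apply(bases, 0b010) == 0
assert {apply(pruned, x) for x in range(4)} == {apply(bases, x) for x in range(8)}
```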
I have some ideas for how to improve this, but we can discuss them offline, and they can be implemented in a follow-up PR.
Can you elaborate a bit on what you mean here? Would I construct these layouts and compose them with the other layouts to the same effect?
@Jokeren @lezcano I have dug into the performance of warp shuffles vs. shared memory more. The bank-conflict degree on the write of the source tensor to shared memory is equal to the number of threads that own sections of each gather column, and the same holds for the index tensor and the degree of the reads. For a gather with no bank conflicts, the layout has to be such that along the gather axis there is only one assigned thread, but this can differ between the index and source columns. The gather can then be divided into

This is fortunate, because it means the shared memory and warp shuffle implementations excel for different layouts. Since the shared memory codegen is naive, I had to pick shapes that just happen to land on ideal smem layouts:

So @lezcano, your intuition that the cost of smem is about twice that of a shuffle is pretty close! However, this does demonstrate that warp shuffles are faster than smem in certain cases, and sometimes much faster depending on the layout (the middle-end is responsible for picking the right layout). Importantly, there doesn't seem to be anything weird going on. I'll go ahead and clean up this PR. I have some pending middle-end changes for the layout picking that I will put up afterwards.
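The bank-conflict claim above can be sanity-checked on a concrete layout by counting the distinct lanes that own elements of each gather column (the layout below is made up for illustration; it is not one of the layouts from these benchmarks):

```python
from collections import defaultdict

LANES, REGS = 32, 8

def owner(lane: int, reg: int):
    """Toy layout: (lane, reg) -> (row, col), rows run along the gather axis."""
    row = (lane % 8) * 2 + reg % 2
    col = (lane // 8) * 4 + reg // 2
    return row, col

lanes_per_col = defaultdict(set)
for lane in range(LANES):
    for reg in range(REGS):
        _, col = owner(lane, reg)
        lanes_per_col[col].add(lane)

# 8 threads own sections of every column, so writing a column to shared
# memory is an 8-way bank conflict; one lane per column would be conflict-free.
print({col: len(lanes) for col, lanes in sorted(lanes_per_col.items())})
```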
Thanks. Can you describe briefly how you got the perf numbers shown above? More specifically, is it possible for you to attach the script you used for benchmarking?
This is the script I used, along with some hacks in the compiler to make sure the right layouts and lowering paths are picked. For convenience, here are the layouts:

```python
import triton
import triton.language as tl
import torch


@triton.jit
def gather_test_kernel(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, src_dim0: tl.constexpr, src_dim1: tl.constexpr,
                       src_stride0: tl.constexpr, src_stride1: tl.constexpr, idx_dim0: tl.constexpr,
                       idx_dim1: tl.constexpr, idx_stride0: tl.constexpr, idx_stride1: tl.constexpr,
                       out_dim0: tl.constexpr, out_dim1: tl.constexpr, out_stride0: tl.constexpr,
                       out_stride1: tl.constexpr):
    src_offs = (tl.arange(0, src_dim0)[:, None] * src_stride0 + tl.arange(0, src_dim1)[None, :] * src_stride1)
    src = tl.load(src_ptr + src_offs)

    idx_offs = (tl.arange(0, idx_dim0)[:, None] * idx_stride0 + tl.arange(0, idx_dim1)[None, :] * idx_stride1)
    idx = tl.load(idx_ptr + idx_offs)

    for i in range(200):
        src = tl.gather(src, idx, axis)
    out = src

    out_offs = (tl.arange(0, out_dim0)[:, None] * out_stride0 + tl.arange(0, out_dim1)[None, :] * out_stride1)
    tl.store(out_ptr + out_offs, out)


src_shape = [32, 32]
axis = 1
indices_shape = [32, 32]

src = torch.randn(src_shape, device='cuda')
indices = torch.randint(0, src.shape[axis], indices_shape, device='cuda')


def prepare_kernel(src: torch.Tensor, axis: int, indices: torch.Tensor):
    output = torch.empty(indices.shape, dtype=src.dtype, device=src.device)
    compiled = gather_test_kernel.warmup(src, indices, output, axis, src.shape[0], src.shape[1], src.stride(0),
                                         src.stride(1), indices.shape[0], indices.shape[1], indices.stride(0),
                                         indices.stride(1), output.shape[0], output.shape[1], output.stride(0),
                                         output.stride(1), grid=(1, ))
    return output, compiled


output, compiled = prepare_kernel(src, axis, indices)
#print(compiled.asm["ttgir"])
#print(compiled.asm["ttgir"].count("convert_layout"))


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['unused'],  # Argument names to use as an x-axis for the plot.
        x_vals=[50],  # Different possible values for `x_name`.
        x_log=False,  # x axis is logarithmic.
        styles=[('blue', '-'), ('green', '-'), ('red', '-')],  # Line styles.
        ylabel='GB/s',  # Label name for the y-axis.
        line_arg='unused2',  # Argument name whose value corresponds to a different line in the plot.
        line_vals=[64],  # Possible values for `line_arg`.
        line_names=['64'],  # Label name for the lines.
        plot_name='tmp',  # Name for the plot. Used also as a file name for saving the plot.
        args={},  # Values for function arguments not in `x_names` and `y_name`.
    ))
def benchmark(unused, unused2):

    def runner():
        compiled[(1, 1, 1)](src, indices, output)

    ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=[0.5, 0.2, 0.8])
    #gbps = lambda ms: 3 * A.numel() * A.element_size() * 1e-9 / (ms * 1e-3)
    #return gbps(ms), gbps(max_ms), gbps(min_ms)
    return ms


print("num_conversions:", compiled.asm["ttgir"].count("convert_layout"))
print(compiled.asm["ttgir"])
print(compiled.asm["llir"])
print("axis:", axis)

benchmark.run(print_data=True, show_plots=True)
```
Note that the smem implementation is actually cheating a bit because the stores get hoisted out of the loop, so in reality it should be more than 1.5x slower:

```llvm
189:                                              ; preds = %4, %189
  %190 = phi <1 x i32> [ %85, %4 ], [ %214, %189 ]
  ; ... phis and extractelements elided ...
  tail call void @llvm.nvvm.barrier0(), !dbg !18
  %207 = load <1 x i32>, ptr addrspace(3) %160, align 4, !dbg !18
  %208 = load <1 x i32>, ptr addrspace(3) %164, align 4, !dbg !18
  %209 = load <1 x i32>, ptr addrspace(3) %168, align 4, !dbg !18
  %210 = load <1 x i32>, ptr addrspace(3) %172, align 4, !dbg !18
  %211 = load <1 x i32>, ptr addrspace(3) %176, align 4, !dbg !18
  %212 = load <1 x i32>, ptr addrspace(3) %180, align 4, !dbg !18
  %213 = load <1 x i32>, ptr addrspace(3) %184, align 4, !dbg !18
  %214 = load <1 x i32>, ptr addrspace(3) %188, align 4, !dbg !18
  %215 = add nuw nsw i32 %198, 1, !dbg !19
  %exitcond.not = icmp eq i32 %215, 200, !dbg !19
  br i1 %exitcond.not, label %216, label %189, !dbg !19
```

I also observed similar 2x slowdowns for large shapes (128x128, 256x128, etc.).
lezcano left a comment:
The benchmarks sound reasonable to me, even more so after Thomas' point that we can always land this and continue optimising the lowerings by looking at actual examples.
Regarding all the LL stuff, we can discuss it offline and improve the code in a follow-up PR.
On my end this LGTM modulo the minor points in the open discussions, but let's wait for the others.
Discussed the linear layout stuff with @lezcano. I agree there are definitely more abstractions/helpers that can be shared between the lowerings, but it's not super clear what those are at the moment. I'll spend some time today trying to refactor this code against the linear layout code in the ConvertLayoutOp lowering. In any case, as more of these ops are ported to linear layouts, patterns will emerge.
This PR implements a specialized codegen for `tt.gather` when it satisfies the conditions of being "warp local": it is possible to compute the output tensor without data movement across warps. `isWarpLocal` is a new function that checks this condition and places additional restrictions to simplify codegen / separate concerns from `ttg.convert_layout`.

This enables `tt.gather` to generate better code when the layout is suitable. In a subsequent PR, a special pattern will be added to generate optimized layouts for `tt.gather` when possible/profitable to enable the lowering.
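As a conceptual sketch of the warp-local condition described above (a paraphrase of the description, not the actual isWarpLocal implementation): a gather column can be resolved with warp shuffles alone only if all of its elements, for both the source and index operands, live in a single warp.

```python
def is_warp_local(column_to_warps: dict[int, set[int]]) -> bool:
    """column_to_warps maps each gather column to the set of warps owning
    any of its elements (across both the source and index operands)."""
    return all(len(warps) == 1 for warps in column_to_warps.values())

print(is_warp_local({0: {0}, 1: {1}}))  # True: each column stays in one warp
print(is_warp_local({0: {0, 1}}))       # False: column 0 spans two warps
```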