
[Backend] Implement optimized gather codegen within a warp #5345

Merged
Mogball merged 49 commits into triton-lang:main from Mogball:mogball/faster_gather
Dec 10, 2024

Conversation

@Mogball
Collaborator

@Mogball Mogball commented Dec 5, 2024

This PR implements specialized codegen for tt.gather when it satisfies the condition of being "warp local": the output tensor can be computed without data movement across warps. isWarpLocal is a new function that checks this condition and places additional restrictions to simplify codegen and separate concerns from ttg.convert_layout.

This enables tt.gather to generate better code when the layout is suitable. A subsequent PR will add a special pattern that generates optimized layouts for tt.gather when possible and profitable, enabling this lowering.
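A minimal Python sketch of the condition, assuming a simplified 2D blocked-layout model (threadsPerWarp / warpsPerCTA); the real isWarpLocal operates on linear layouts and imposes further restrictions:

def is_warp_local(warps_per_cta, gather_axis):
    # A gather reads arbitrary elements of each column along gather_axis,
    # so every element of a column must live in a single warp: the layout
    # must not split the gather axis across warps.
    return warps_per_cta[gather_axis] == 1

# Warps laid out only along axis 1 keep axis-0 columns warp-local, so an
# axis-0 gather needs no cross-warp data movement.
assert is_warp_local(warps_per_cta=[1, 4], gather_axis=0)
assert not is_warp_local(warps_per_cta=[4, 1], gather_axis=0)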

Comment thread lib/Tools/LinearLayout.cpp
@Mogball Mogball changed the title from "[DRAFT][Backend] Implement gather within a warp" to "[Backend] Implement optimized gather codegen within a warp" Dec 5, 2024
@Mogball Mogball marked this pull request as ready for review December 5, 2024 21:18
@Mogball
Collaborator Author

Mogball commented Dec 5, 2024

This is now ready for review! Would love eyes on the codegen algorithm.

Comment thread lib/Analysis/Utility.cpp Outdated
Comment thread lib/Analysis/Utility.cpp Outdated
Comment thread python/test/unit/language/test_core.py
Comment thread lib/Tools/LinearLayout.cpp Outdated
return LinearLayout({{inDimName, std::move(powersOf2)}}, {outDimName});
}

/*static*/ LinearLayout LinearLayout::identityND(
Contributor

If we incorporate identityND this way as a function of LinearLayout, I think zeros1D and identity1D also have to be refactored so that only one dimName is used in the arguments.

Collaborator Author

I went ahead and removed this since you're right that it is overly specific.

Comment thread lib/Analysis/Utility.cpp Outdated
ConversionPatternRewriter &rewriter) const {
GatherLoweringHelper helper(op);
// Specialize the lowering based on the source layout.
if (helper.isWarpLocal()) {
Contributor

You may also want to add helper.isThreadLocal?

Collaborator Author

Yeah, that's a good idea. I'll do it as a follow-up since I haven't implemented this case.

Contributor

Even more, I really think this is something that could be in a Utility.cpp for LinearLayouts, as this is the sort of stuff we might want to use in other lowerings. Not for this PR tho.

Comment thread lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp
Comment thread lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp
Comment thread lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp Outdated
Contributor

@lezcano lezcano left a comment

I have to go now and I didn't have time to review the whole thing, so I'm leaving some minor comments for now. I will review on Monday.

Comment thread include/triton/Tools/LinearLayout.h Outdated
Comment thread lib/Analysis/Utility.cpp Outdated
Contributor

@lezcano lezcano left a comment

On the algorithmic side, I share similar concerns as Keren. I am not sure whether this is going to be performant vs. a round trip to shmem. Taking a back-of-the-envelope estimate of T(round trip to shmem w/o bank conflicts) = 2 * T(warp shuffle), this would only be more efficient when we have very few registers per column. Note that this wouldn't be terribly bad, as we can compensate for it with a rough heuristic, but regardless it'd be nice to see some perf numbers.

A way to make this logic higher level is by implementing a couple of helper functions that work on dim0..n-1 -> dim0..n-1. For example, here we could have a layout dim0..n-1 -> dim0..n-1 that takes (x_0, ..., x_{n-1}) -> (0, ..., x_i, ..., 0), where i is the kGatherDim dimension. Same for the one that zeros out that dimension and is the identity everywhere else. Using these two linear layouts, you can select kGatherDim in the "logical matrix space", and the rest. This should simplify the layout handling quite a bit, I think.
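For intuition, a small Python sketch of these two maps (illustrative only, not Triton's LinearLayout API); both are linear over the index components, so together they decompose an index exactly:

def select_dim(index, gather_dim):
    # (x_0, ..., x_{n-1}) -> (0, ..., x_i, ..., 0), with i = gather_dim
    return tuple(x if d == gather_dim else 0 for d, x in enumerate(index))

def zero_dim(index, gather_dim):
    # Zeros out the gather dimension; identity everywhere else.
    return tuple(0 if d == gather_dim else x for d, x in enumerate(index))

idx = (5, 3, 7)
assert select_dim(idx, 1) == (0, 3, 0)
assert zero_dim(idx, 1) == (5, 0, 7)
# The components recombine (XOR acts as addition on the index bits):
assert tuple(a ^ b for a, b in zip(select_dim(idx, 1), zero_dim(idx, 1))) == idx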

Comment on lines +271 to +285
  // Remove zero bases in the gather dimension to make the function injective
  // (for a given column) over the same codomain.
  LinearLayout::BasesT newInvertRegMapBases;
  for (auto &[inDim, inDimBases] : invertSrcRegMap.getBases()) {
    auto &newInDimBases = newInvertRegMapBases[inDim];
    if (inDim != kGatherDim) {
      newInDimBases = inDimBases;
      continue;
    }
    for (auto &basis : inDimBases) {
      if (llvm::any_of(basis, [](int32_t val) { return val != 0; })) {
        newInDimBases.push_back(basis);
      }
    }
  }
Contributor

We are checking that the srcLayout is injective, so this is a noop. Let's remove it for now. The better way of doing this is via the pseudoinverse.

Collaborator Author

This isn't a noop. Even if srcLayout is injective, the sublayout computed here might not be. The sublayout is slicing the output to just the register component.

Collaborator Author

What this code is essentially doing is creating a layout that, given the coordinates of a column, computes the unique register IDs in that column.
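As a plain-Python analog of that snippet (illustrative only): all-zero bases map distinct inputs to the same output, so dropping them in the gather dimension leaves a map that enumerates each register ID of a column exactly once:

def drop_zero_bases(bases, gather_dim):
    new_bases = {}
    for in_dim, in_dim_bases in bases.items():
        if in_dim != gather_dim:
            new_bases[in_dim] = in_dim_bases
        else:
            # Keep only bases with at least one nonzero component.
            new_bases[in_dim] = [b for b in in_dim_bases if any(v != 0 for v in b)]
    return new_bases

bases = {"register": [[1, 0], [0, 0], [2, 0]], "lane": [[0, 1]]}
assert drop_zero_bases(bases, "register") == {"register": [[1, 0], [2, 0]], "lane": [[0, 1]]}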

Contributor

I have some ideas for how to improve this, but we can discuss them offline, and they can be implemented in a follow-up PR.

@Mogball
Collaborator Author

Mogball commented Dec 9, 2024

A way to make this logic higher level is by implementing a couple of helper functions that work on dim0..n-1 -> dim0..n-1. For example, here we could have a layout dim0..n-1 -> dim0..n-1 that takes (x_0, ..., x_{n-1}) -> (0, ..., x_i, ..., 0), where i is the kGatherDim dimension. Same for the one that zeros out that dimension and is the identity everywhere else. Using these two linear layouts, you can select kGatherDim in the "logical matrix space", and the rest. This should simplify the layout handling quite a bit, I think.

Can you elaborate a bit on what you mean here? Would I construct these layouts and compose them with the other layouts to the same effect?

@Mogball
Collaborator Author

Mogball commented Dec 10, 2024

@Jokeren @lezcano I have dug into the performance of warp shuffles vs. shared memory more.

The bank-conflict degree on the write of the source tensor to shared memory is equal to the number of threads that own sections of each gather column, and the same holds for the index tensor and the degree of the reads. For a gather with no bank conflicts, the layout has to be such that along the gather axis there is only one assigned thread, though this can differ between the index and source columns. The gather can then be divided into totalNumCols / threadsPerWarp microkernels.
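As a toy model of that rule (assuming a simplified blocked layout with sizePerThread = [1, 1]; not the actual lowering), one can count the distinct threads owning each line along the gather axis:

def max_owners_per_line(shape, threads_per_warp, gather_axis):
    rows, cols = shape
    tpw0, tpw1 = threads_per_warp
    owner = lambda r, c: (r % tpw0) * tpw1 + (c % tpw1)  # lane ID of element (r, c)
    if gather_axis == 0:
        return max(len({owner(r, c) for r in range(rows)}) for c in range(cols))
    return max(len({owner(r, c) for c in range(cols)}) for r in range(rows))

# threadsPerWarp = [1, 32]: each lane owns a whole column along axis 0, so an
# axis-0 gather has one owner per column (conflict-free by the rule above),
# while an axis-1 gather has 32 owners per row (32-way conflicts).
assert max_owners_per_line((32, 32), (1, 32), gather_axis=0) == 1
assert max_owners_per_line((32, 32), (1, 32), gather_axis=1) == 32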

This is fortunate, because it means the shared memory and warp shuffle implementations excel at different layouts. Since the shared-memory codegen is naive, I had to pick shapes that happen to land on ideal smem layouts:

shape: 32x32

axis=0, warp shuffle, 0.00965 (ms)
axis=1, warp shuffle, 0.00976 (ms)
axis=0, smem, 0.016 (ms, no bank conflicts)
axis=1, smem, 0.179 (ms, maximum 32-way conflicts)

So @lezcano your intuition that the cost of smem is about twice that of a shuffle is pretty close! However, this does demonstrate that warp shuffles are faster than smem in certain cases, and sometimes much faster depending on the layout (the middle-end is responsible for picking the right layout). Importantly, there doesn't seem to be anything weird going on.
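For intuition about the shuffle path, a Python simulation (not GPU code) of the core exchange within one 32-wide column: each lane holds one source element, and a single shuffle-style exchange realizes out[t] = src[idx[t]] with no shared-memory traffic:

import random

WARP_SIZE = 32
src = [random.random() for _ in range(WARP_SIZE)]  # one element per lane
idx = [random.randrange(WARP_SIZE) for _ in range(WARP_SIZE)]

def shfl_sync(values, src_lane):
    # Like __shfl_sync: lane t reads the value held by lane src_lane[t].
    return [values[s] for s in src_lane]

out = shfl_sync(src, idx)
assert out == [src[i] for i in idx]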

I'll go ahead and clean up this PR. I have some pending middle-end changes for the layout picking that I will put up afterwards.

@Jokeren
Contributor

Jokeren commented Dec 10, 2024

Thanks. Can you describe briefly how you got the perf numbers shown above?

More specifically, is it possible for you to attach the script you used for benchmarking?

@Mogball
Collaborator Author

Mogball commented Dec 10, 2024

This is the script I used, along with some hacks in the compiler to make sure the right layouts and lowering paths are picked.

For convenience, here are the layouts:

shape: 32x32

axis=0, warp shuffle, 0.00965 (ms)
#ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>

axis=1, warp shuffle, 0.00976 (ms)
#ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>


axis=0, smem, 0.016 (ms, no bank conflicts)
#ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>

axis=1, smem, 0.179 (ms, maximum 32-way conflicts)
#ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
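
Here is the benchmark script itself:
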
import triton
import triton.language as tl
import torch


@triton.jit
def gather_test_kernel(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, src_dim0: tl.constexpr, src_dim1: tl.constexpr,
                       src_stride0: tl.constexpr, src_stride1: tl.constexpr, idx_dim0: tl.constexpr,
                       idx_dim1: tl.constexpr, idx_stride0: tl.constexpr, idx_stride1: tl.constexpr,
                       out_dim0: tl.constexpr, out_dim1: tl.constexpr, out_stride0: tl.constexpr,
                       out_stride1: tl.constexpr):
    src_offs = (tl.arange(0, src_dim0)[:, None] * src_stride0 + tl.arange(0, src_dim1)[None, :] * src_stride1)
    src = tl.load(src_ptr + src_offs)

    idx_offs = (tl.arange(0, idx_dim0)[:, None] * idx_stride0 + tl.arange(0, idx_dim1)[None, :] * idx_stride1)
    idx = tl.load(idx_ptr + idx_offs)

    for i in range(200):
        src = tl.gather(src, idx, axis)

    out = src

    out_offs = (tl.arange(0, out_dim0)[:, None] * out_stride0 + tl.arange(0, out_dim1)[None, :] * out_stride1)
    tl.store(out_ptr + out_offs, out)


src_shape = [32, 32]
axis = 1
indices_shape = [32, 32]

src = torch.randn(src_shape, device='cuda')
indices = torch.randint(0, src.shape[axis], indices_shape, device='cuda')


def prepare_kernel(src: torch.Tensor, axis: int, indices: torch.Tensor):
    output = torch.empty(indices.shape, dtype=src.dtype, device=src.device)
    compiled = gather_test_kernel.warmup(src, indices, output, axis, src.shape[0], src.shape[1], src.stride(0),
                                         src.stride(1), indices.shape[0], indices.shape[1], indices.stride(0),
                                         indices.stride(1), output.shape[0], output.shape[1], output.stride(0),
                                         output.stride(1), grid=(1, ))
    return output, compiled


output, compiled = prepare_kernel(src, axis, indices)

#print(compiled.asm["ttgir"])
#print(compiled.asm["ttgir"].count("convert_layout"))


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['unused'],  # Argument names to use as an x-axis for the plot.
        x_vals=[50],  # Different possible values for `x_name`.
        x_log=False,  # Whether the x axis is logarithmic.
        styles=[('blue', '-'), ('green', '-'), ('red', '-')],  # Line styles.
        ylabel='GB/s',  # Label name for the y-axis.
        line_arg='unused2',  # Argument name whose value corresponds to a different line in the plot.
        line_vals=[64],  # Possible values for `line_arg`.
        line_names=['64'],  # Label name for the lines.
        plot_name='tmp',  # Name for the plot. Used also as a file name for saving the plot.
        args={},  # Values for function arguments not in `x_names` and `y_name`.
    ))
def benchmark(unused, unused2):

    def runner():
        compiled[(1, 1, 1)](src, indices, output)

    ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=[0.5, 0.2, 0.8])

    #gbps = lambda ms: 3 * A.numel() * A.element_size() * 1e-9 / (ms * 1e-3)
    #return gbps(ms), gbps(max_ms), gbps(min_ms)
    return ms


print("num_conversions:", compiled.asm["ttgir"].count("convert_layout"))
print(compiled.asm["ttgir"])
print(compiled.asm["llir"])
print("axis:", axis)
benchmark.run(print_data=True, show_plots=True)

@Mogball
Collaborator Author

Mogball commented Dec 10, 2024

Note that the smem implementation is actually cheating a bit because the stores get hoisted out of the loop, so in reality it should be more than 1.5x slower:

189:                                              ; preds = %4, %189
  %190 = phi <1 x i32> [ %85, %4 ], [ %214, %189 ]
  ; phis and extractelements elided
  tail call void @llvm.nvvm.barrier0(), !dbg !18
  %207 = load <1 x i32>, ptr addrspace(3) %160, align 4, !dbg !18
  %208 = load <1 x i32>, ptr addrspace(3) %164, align 4, !dbg !18
  %209 = load <1 x i32>, ptr addrspace(3) %168, align 4, !dbg !18
  %210 = load <1 x i32>, ptr addrspace(3) %172, align 4, !dbg !18
  %211 = load <1 x i32>, ptr addrspace(3) %176, align 4, !dbg !18
  %212 = load <1 x i32>, ptr addrspace(3) %180, align 4, !dbg !18
  %213 = load <1 x i32>, ptr addrspace(3) %184, align 4, !dbg !18
  %214 = load <1 x i32>, ptr addrspace(3) %188, align 4, !dbg !18
  %215 = add nuw nsw i32 %198, 1, !dbg !19
  %exitcond.not = icmp eq i32 %215, 200, !dbg !19
  br i1 %exitcond.not, label %216, label %189, !dbg !19

I also observed similar 2x slowdowns for large shapes (128x128, 256x128, etc.)

Contributor

@lezcano lezcano left a comment

The benchmarks sound reasonable to me, even more so after Thomas' point that we can always land this and continue optimising the lowerings by looking at actual examples.

Regarding all the LL stuff, we can discuss it offline and improve the code in a follow-up PR.

On my end this LGTM modulo the minor points in the open discussions, but let's wait for the others.

@Mogball
Collaborator Author

Mogball commented Dec 10, 2024

Discussed the linear layout stuff with @lezcano. I agree there are definitely more abstractions/helpers that can be shared between the lowerings, but it's not super clear what those are at the moment. I'll spend some time today trying to refactor this code against the linear layout code in the ConvertLayoutOp lowering. In any case, as more of these ops are ported to linear layouts, patterns will emerge.

@Mogball Mogball merged commit 2ec1d17 into triton-lang:main Dec 10, 2024
@Mogball Mogball deleted the mogball/faster_gather branch December 10, 2024 18:07