diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h
index a3e38e177d42..2e4cbbc65142 100644
--- a/include/triton/Analysis/Utility.h
+++ b/include/triton/Analysis/Utility.h
@@ -161,6 +161,8 @@ class GatherLoweringHelper {
   // Get the shared memory scratch size required by this op.
   unsigned getScratchSizeInBytes();
+  // Determine if the gather can be performed completely within a warp.
+  bool isWarpLocal();
 
 private:
   triton::GatherOp gatherOp;
diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h
index 85c789635a96..e592a9d6d1ee 100644
--- a/include/triton/Dialect/TritonGPU/IR/Dialect.h
+++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -214,7 +214,12 @@ LinearLayout ensureLayoutNotSmallerThan(
     const LinearLayout &layout,
     const llvm::SmallDenseMap<StringAttr, int64_t> &shape);
 
+// Return a vector of the standard out dimension names for tensor layouts.
+// These are "dim0", "dim1", etc.
 SmallVector<StringAttr> standardOutDimNames(MLIRContext *ctx, int rank);
+// Return an identity mapping from `inDimName` to the standard out dimensions,
+// with the dimensions sized according to the shape. The bases are sorted
+// according to `order`, with the most minor dimension first.
 LinearLayout identityStandardND(StringAttr inDimName, ArrayRef<unsigned> shape,
                                 ArrayRef<unsigned> order);
 
diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp
index 3014245e61fb..69eb196a95a0 100644
--- a/lib/Analysis/Utility.cpp
+++ b/lib/Analysis/Utility.cpp
@@ -413,13 +413,93 @@ GatherLoweringHelper::GatherLoweringHelper(triton::GatherOp gatherOp)
     : gatherOp(gatherOp) {}
 
 unsigned GatherLoweringHelper::getScratchSizeInBytes() {
-  // For now, lower the gather op by writing the source tensor to shared memory.
-  // TODO(jeff): Leverage locality to avoid using scratch space when possible.
+  // If the gather is warp-local, no scratch space is needed.
+  if (isWarpLocal())
+    return 0;
+
+  // Otherwise, performing the gather will require scratch space to communicate
+  // the source tensor across threads. For now, assume the whole source tensor
+  // is written back to shared memory.
   RankedTensorType srcType = gatherOp.getSrc().getType();
   return product(srcType.getShape()) *
          ceil<unsigned>(srcType.getElementTypeBitWidth(), 8);
 }
 
+bool GatherLoweringHelper::isWarpLocal() {
+  // The gather is warp-local if for each column along the gather axis in the
+  // source and index tensors, all the elements are owned by the same warp.
+  RankedTensorType srcType = gatherOp.getSrc().getType();
+  RankedTensorType idxType = gatherOp.getIndices().getType();
+  std::optional<LinearLayout> srcLayout =
+      toLinearLayout(srcType.getShape(), srcType.getEncoding());
+  std::optional<LinearLayout> idxLayout =
+      toLinearLayout(idxType.getShape(), idxType.getEncoding());
+
+  // FIXME: If an unsupported layout was encountered, assume the gather is not
+  // warp-local.
+  if (!srcLayout || !idxLayout)
+    return false;
+
+  Builder b(gatherOp.getContext());
+  StringAttr kBlock = b.getStringAttr("block");
+  StringAttr kWarp = b.getStringAttr("warp");
+  StringAttr kLane = b.getStringAttr("lane");
+  StringAttr kGatherDim =
+      b.getStringAttr("dim" + std::to_string(gatherOp.getAxis()));
+
+  // The tensor layouts must be distributed layouts, where the basis matrix is
+  // a subpermutation matrix (permutation matrix plus zeros for broadcasting).
+  // FIXME(jeff): Check this invariant somehow.
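+  //
+  // As an illustrative example: with gather axis 0, a layout whose warp bases
+  // are [[0, 1], [0, 4]] (as in the lit tests) only moves along dim1 when the
+  // warp changes, so each dim0 column stays within one warp and the first
+  // query below is zero.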
+  //
+  // We want to know if all elements of a column along the gather axis are
+  // mapped to the same set of warps, which means the gather can be performed
+  // entirely within the warp. We need to query
+  //
+  //   srcLayout.invert().sublayoutIsZero({kGatherDim}, {kBlock, kWarp})
+  //
+  // But due to broadcasting, the matrix might not be invertible. But since the
+  // matrix is a permutation matrix (checked below), we can instead query
+  //
+  //   srcLayout.sublayoutIsZero({kBlock, kWarp}, {kGatherDim})
+  //
+  // Which implies that changing the warp will not change the gather dimension.
+  // And since there is no swizzling, this applies to all warps.
+  if (!srcLayout->sublayoutIsZero({kBlock, kWarp}, kGatherDim) ||
+      !idxLayout->sublayoutIsZero({kBlock, kWarp}, kGatherDim))
+    return false;
+
+  SmallVector<StringAttr> otherDims;
+  for (unsigned dim = 0, rank = srcType.getRank(); dim < rank; ++dim) {
+    if (dim != gatherOp.getAxis()) {
+      otherDims.push_back(b.getStringAttr("dim" + Twine(dim)));
+    }
+  }
+
+  // The gather axis `dimN` is invariant to the warp, but the `(block, warp)`
+  // mapping to all other dimensions must also be the same for both layouts. If
+  // so, then the warp that owns a particular index element also owns all the
+  // source elements it could index into.
+  if (srcLayout->sublayout({kBlock, kWarp}, otherDims) !=
+      idxLayout->sublayout({kBlock, kWarp}, otherDims))
+    return false;
+
+  // The two constraints above ensure that the data movement needed to perform
+  // the gather is contained within a warp. The subsequent constraints simplify
+  // codegen.
+
+  // Require that for any given gather column, the threads mapped to the column
+  // in the index and source tensors are the same. This means we don't need to
+  // xor shuffle across threads before emitting index shuffles; we push warp
+  // shuffling to layout conversions.
+  if (srcLayout->sublayout(kLane, otherDims) !=
+      idxLayout->sublayout(kLane, otherDims))
+    return false;
+
+  // Otherwise, the source layout has to be invertible. This primarily means
+  // the codegen path doesn't support broadcasted source layouts.
+  return srcLayout->isInvertible();
+}
+
 unsigned getNumScratchElements(ArrayRef<int64_t> shape) {
   if (shape.empty())
     return 0;
diff --git a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp
index 5ab81eff819c..faf781369e0a 100644
--- a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp
@@ -1,8 +1,10 @@
 #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"
+#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
 
 using namespace mlir;
 using namespace mlir::triton;
+using namespace mlir::triton::gpu;
 
 namespace {
 class GatherOpConversion : public ConvertOpToLLVMPattern<GatherOp> {
@@ -17,12 +19,51 @@ class GatherOpConversion : public ConvertOpToLLVMPattern<GatherOp> {
                   ConversionPatternRewriter &rewriter) const override;
 
 private:
+  // Codegen the gather by storing the source tensor into shared memory and
+  // then gathering directly from shared memory.
+  void emitGatherInShared(GatherOp op, OpAdaptor adaptor,
+                          ConversionPatternRewriter &rewriter) const;
+  // Codegen a warp-local gather by shuffling elements across the warp and
+  // selecting from them.
+  void emitWarpLocalGather(GatherOp op, OpAdaptor adaptor,
+                           ConversionPatternRewriter &rewriter) const;
+
   const TargetInfoBase &targetInfo;
 };
 
 LogicalResult
 GatherOpConversion::matchAndRewrite(GatherOp op, OpAdaptor adaptor,
                                     ConversionPatternRewriter &rewriter) const {
+  GatherLoweringHelper helper(op);
+  // Specialize the lowering based on the source layout. Given that the cost of
+  // a warp shuffle is approximately half the cost of a round trip to shared
+  // memory with zero bank conflicts, we will eventually need a more precise
+  // heuristic to choose between the two codegen paths; for now we rely on the
+  // middle end to pick the right layout.
+  if (helper.isWarpLocal()) {
+    emitWarpLocalGather(op, adaptor, rewriter);
+  } else {
+    emitGatherInShared(op, adaptor, rewriter);
+  }
+  return success();
+}
+
+static Value convertIndexToI32(Location loc, Value index,
+                               ConversionPatternRewriter &rewriter) {
+  unsigned idxWidth = index.getType().getIntOrFloatBitWidth();
+  // The LL index computations are performed with 32 bit integers. If the
+  // indices are something else, cast them to i32.
+  if (idxWidth > 32) {
+    index = trunc(i32_ty, index);
+  } else if (idxWidth < 32) {
+    // Negative indices don't make sense, so zero-extend.
+    index = zext(i32_ty, index);
+  }
+  return index;
+}
+
+void GatherOpConversion::emitGatherInShared(
+    GatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const {
   Location loc = op.getLoc();
   RankedTensorType srcType = op.getSrc().getType();
 
@@ -78,19 +119,10 @@ GatherOpConversion::matchAndRewrite(GatherOp op, OpAdaptor adaptor,
       emitIndices(loc, rewriter, targetInfo, dstType.getEncoding(), dstType,
                   /*withCTAOffset=*/true);
 
-  unsigned idxWidth = op.getIndices().getType().getElementTypeBitWidth();
   unsigned axis = op.getAxis();
   SmallVector<Value> results(dstIndices.size());
   for (auto [i, idx, indices] : llvm::enumerate(idxValues, dstIndices)) {
-    // The LL index computations are performed with 32 bit integers. If the
-    // indices are something else, cast them to i32.
-    if (idxWidth > 32) {
-      idx = trunc(i32_ty, idx);
-    } else if (idxWidth < 32) {
-      // Negative indices don't make sense, so zero-extend.
-      idx = zext(i32_ty, idx);
-    }
-    indices[axis] = idx;
+    indices[axis] = convertIndexToI32(loc, idx, rewriter);
     Value offset = LLVM::linearize(rewriter, loc, indices, srcShapePerCTA);
     Value ptr = gep(smemBase.getType(), elemType, smemBase, offset);
     results[i] = load(elemType, ptr);
@@ -99,7 +131,224 @@ GatherOpConversion::matchAndRewrite(GatherOp op, OpAdaptor adaptor,
   Value packed =
       packLLElements(loc, getTypeConverter(), results, rewriter, dstType);
   rewriter.replaceOp(op, packed);
-  return success();
+}
+
+// High-level description of the algorithm:
+//
+// `isWarpLocal` checks that it is possible to compute each output element
+// without data movement across warps.
+//
+// If the gather dim is `dimN`, then this means
+//
+//   ll^-1(dimN)[(block, warp)] == 0
+//
+// for both source and index tensors: moving along the gather axis does not
+// change the warp. Broadcasted layouts are not supported, so we know the
+// layouts are permutation matrices.
+//
+// We can check this with `ll((block, warp))[dimN] == 0`.
+//
+// Let `gatherCol` be a tuple of all dimensions except the gather dimension.
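+// (For example, for a 2D gather along `dim0`, `gatherCol` is just `dim1`.)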
+// We also check that the gather columns line up the same way with respect to
+// the warp between the source and index tensors with
+//
+//   ll_src((block, warp))[gatherCol] == ll_idx((block, warp))[gatherCol]
+//
+// This means that for all index columns, the corresponding column in the
+// source tensor is owned by the same warp.
+//
+// We also check
+//
+//   ll_src(lane)[gatherCol] == ll_idx(lane)[gatherCol]
+//
+// This boils down to the fact that the algorithm essentially emits a series of
+// index shuffles for each index value owned by each thread, and then a pile of
+// selects to pick the right value. We need to figure out, given an index value
+// in a particular column, which source register values it could read from and
+// which threads own them.
+//
+// If this relationship did not hold, then the possible source registers for
+// each index value would vary with the thread, meaning the value operand
+// provided to each shuffle index instruction would depend on the thread ID.
+// This isn't a big deal: it just means we would have to emit a pile of selects
+// before each shuffle as well, to pick the right source register value. But we
+// choose not to handle this.
+//
+// The codegen algorithm emits code as follows:
+// - Given the thread ID and a particular index tensor register, figure out
+//   which gather column it belongs to using a layout.
+// - Using the index value itself as the value for `dimN`, use another layout
+//   to figure out which lane in the warp owns the desired value and which
+//   register in that lane it is.
+// - For the gather column, figure out the source registers in that column, and
+//   for each of them, emit an index shuffle with the same computed lane ID.
+// - Use the register component to select the right value from the shuffle
+//   results.
+void GatherOpConversion::emitWarpLocalGather(
+    GatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const {
+  MLIRContext *ctx = op.getContext();
+  Location loc = op.getLoc();
+  RankedTensorType srcType = op.getSrc().getType();
+  RankedTensorType idxType = op.getIndices().getType();
+
+  // Layout dimension names.
+  StringAttr kBlock = str_attr("block");
+  StringAttr kWarp = str_attr("warp");
+  StringAttr kLane = str_attr("lane");
+  StringAttr kRegister = str_attr("register");
+  StringAttr kGatherDim = rewriter.getStringAttr("dim" + Twine(op.getAxis()));
+  SmallVector<StringAttr> allDims, otherDims;
+  for (unsigned dim = 0, rank = srcType.getRank(); dim < rank; ++dim) {
+    allDims.push_back(str_attr("dim" + Twine(dim)));
+    if (dim != op.getAxis()) {
+      otherDims.push_back(allDims.back());
+    }
+  }
+
+  // Compute the src and idx layouts.
+  LinearLayout srcLayout =
+      *toLinearLayout(srcType.getShape(), srcType.getEncoding());
+  LinearLayout idxLayout =
+      *toLinearLayout(idxType.getShape(), idxType.getEncoding());
+
+  // Let `ll_src` be the source layout and `ll_idx` be the index layout.
+  // Let `src_col` be a tuple of dimensions except the gather dimension,
+  // representing a specific column in the source tensor. Likewise for
+  // `idx_col`. Let `src_idx` be the index into the gather dimension in the
+  // source tensor.
+  //
+  // `(src_lane, src_reg) = ll_src^-1(src_col, src_idx)`, where `src_lane` is
+  // the thread that contains the required element and `src_reg` is the
+  // register within that thread.
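+  //
+  // For example, with the trivial 1D layout used in the lit tests (a single
+  // register per thread and lane bases [[1], [2], [4], [8], [16]]), the
+  // inverse simply maps a gather index `i` to `src_lane = i` and
+  // `src_reg = 0`, and the lowering below degenerates to one index shuffle
+  // per output element.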
+  //
+  // Because `ll_src(block=0, warp=0, lane=0)[otherDims] ==
+  // ll_idx(0, 0, 0)[otherDims]`, we know that, given any `idx_reg` (element in
+  // the index tensor), the thread will need to read from the same column in
+  // the source tensor.
+  //
+  // Thus, we can obtain
+  //
+  //   (src_lane, src_reg) = (ll_src^-1)(
+  //       ll_idx(block, warp, lane, idx_reg)[otherDims],
+  //       idxValues[idx_reg]
+  //   )[{"lane", "register"}]
+  //
+  // And the mapping will be correct for each thread.
+  //
+  // Given `src_reg \in [0, K*N)`, we just need to emit N index shuffles for
+  // each `idx_reg` (the number of index shuffles is quadratic!) and
+  // `llvm.select` using `src_reg` to get the right one. `K` is the number of
+  // elements per column owned by a thread.
+
+  // Fully invert the source layout. We know it is invertible because
+  // `isWarpLocal` checked this.
+  LinearLayout invSrcLayout = srcLayout.invert();
+
+  // Sanity check: the warp must be invariant to the index because otherwise
+  // the gather would need to read across warps!
+  assert(invSrcLayout.sublayoutIsZero(kGatherDim, {kBlock, kWarp}) &&
+         "expected a warp-local gather");
+  invSrcLayout = invSrcLayout.sublayout(allDims, {kLane, kRegister});
+
+  LinearLayout idxColLayout =
+      idxLayout.sublayout({kBlock, kWarp, kLane, kRegister}, otherDims);
+
+  SmallVector<Value> srcValues =
+      unpackLLElements(loc, adaptor.getSrc(), rewriter);
+  SmallVector<Value> idxValues =
+      unpackLLElements(loc, adaptor.getIndices(), rewriter);
+
+  auto [blockId, warpId, laneId] =
+      emitHardwareTuple(loc, rewriter, targetInfo, /*withCTAOffset=*/true,
+                        srcLayout.getInDimSize(kLane));
+
+  unsigned /*N=*/srcRegsPerThread = srcLayout.getInDimSize(kRegister);
+  assert(srcRegsPerThread == srcValues.size());
+
+  // Given an index value, we need to know which source register values it
+  // could index into. This is invariant to anything other than the register,
+  // which we checked already. Compute the full reverse map from
+  //
+  //   idx_reg -> gather_column -> (src_reg0, src_reg1, ...)
+  //
+  LinearLayout invertSrcRegMap = invSrcLayout.sublayout(allDims, {kRegister});
+  // Remove zero bases in the gather dimension to make the function injective
+  // (for a given column) over the same codomain.
+  LinearLayout::BasesT newInvertRegMapBases;
+  for (auto &[inDim, inDimBases] : invertSrcRegMap.getBases()) {
+    auto &newInDimBases = newInvertRegMapBases[inDim];
+    if (inDim != kGatherDim) {
+      newInDimBases = inDimBases;
+      continue;
+    }
+    for (auto &basis : inDimBases) {
+      if (llvm::any_of(basis, [](int32_t val) { return val != 0; })) {
+        newInDimBases.push_back(basis);
+      }
+    }
+  }
+  invertSrcRegMap = LinearLayout(
+      newInvertRegMapBases, llvm::to_vector(invertSrcRegMap.getOutDimNames()));
+  // We are left with only non-zero bases in the gather dimension, which means
+  // the number of registers per column is the size of the "gather dimension".
+  unsigned numRegsPerColumn = invertSrcRegMap.getInDimSize(kGatherDim);
+  // Get a map from idx_reg to the column it indexes into.
+  LinearLayout idxRegToCol = idxLayout.sublayout({kRegister}, otherDims);
+  // Now given `idx_reg`, we can compute the column it belongs to in both src
+  // and index tensors, then partially apply `invertSrcRegMap` with this to
+  // obtain a function that outputs the corresponding registers in the src
+  // tensor in the same column.
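+  //
+  // Because the map is linear, those registers are `srcBase ^ rest(i)` below.
+  // For example, with two source registers per column (as in the larger-input
+  // lit test), the candidates are `srcBase ^ 0` and `srcBase ^ rest(1)`: two
+  // index shuffles followed by one select per index value.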
+
+  // L(column, i) = L(column, 0) xor L(0, i)
+  LinearLayout invertSrcRegMapColPart =
+      invertSrcRegMap.sublayout(otherDims, {kRegister});
+  LinearLayout invertSrcRegMapRest =
+      invertSrcRegMap.sublayout({kGatherDim}, {kRegister});
+
+  SmallVector<Value> results;
+  for (auto [idxReg, idxVal] : llvm::enumerate(idxValues)) {
+    SmallVector<std::pair<StringAttr, Value>> column =
+        applyLinearLayout(loc, rewriter, idxColLayout,
+                          {{kBlock, blockId},
+                           {kWarp, warpId},
+                           {kLane, laneId},
+                           {kRegister, i32_val(idxReg)}});
+    assert(column.size() == otherDims.size());
+
+    // Combine the computed column with the data-dependent gather index.
+    column.emplace_back(kGatherDim, convertIndexToI32(loc, idxVal, rewriter));
+    SmallVector<std::pair<StringAttr, Value>> srcLaneAndReg =
+        applyLinearLayout(loc, rewriter, invSrcLayout, column);
+
+    auto [srcLaneName, srcLane] = srcLaneAndReg.front();
+    auto [srcRegName, srcReg] = srcLaneAndReg.back();
+    assert(srcLaneName == kLane && srcRegName == kRegister);
+
+    assert(!srcValues.empty() && "can't gather from an empty tensor");
+
+    // Figure out which src registers we need to index shuffle from. This
+    // depends only on the index register (i.e. the gather column), which we
+    // computed above.
+    SmallVector<std::pair<StringAttr, int32_t>> normalizedColumn =
+        idxRegToCol.apply({{kRegister, idxReg}});
+    int32_t srcBase =
+        invertSrcRegMapColPart.apply(normalizedColumn).front().second;
+
+    Value result = undef(srcValues.front().getType());
+    for (unsigned i = 0; i != numRegsPerColumn; ++i) {
+      int32_t rest =
+          invertSrcRegMapRest.apply({{kGatherDim, i}}).front().second;
+      int32_t srcRegIdx = srcBase ^ rest;
+
+      Value value =
+          targetInfo.shuffleIdx(rewriter, loc, srcValues[srcRegIdx], srcLane);
+      result = select(icmp_eq(i32_val(srcRegIdx), srcReg), value, result);
+    }
+
+    results.push_back(result);
+  }
+
+  rewriter.replaceOp(op, packLLElements(loc, getTypeConverter(), results,
+                                        rewriter, op.getType()));
 }
 
 } // namespace
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 5e60d9fd1448..3ca05bad506d 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -6201,6 +6201,24 @@ def kernel(In, Out,  #
     assert torch.all(ref == result)
 
 
+@triton.jit
+def gather_test_kernel(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, src_dim0: tl.constexpr, src_dim1: tl.constexpr,
+                       src_stride0: tl.constexpr, src_stride1: tl.constexpr, idx_dim0: tl.constexpr,
+                       idx_dim1: tl.constexpr, idx_stride0: tl.constexpr, idx_stride1: tl.constexpr,
+                       out_dim0: tl.constexpr, out_dim1: tl.constexpr, out_stride0: tl.constexpr,
+                       out_stride1: tl.constexpr):
+    src_offs = (tl.arange(0, src_dim0)[:, None] * src_stride0 + tl.arange(0, src_dim1)[None, :] * src_stride1)
+    src = tl.load(src_ptr + src_offs)
+
+    idx_offs = (tl.arange(0, idx_dim0)[:, None] * idx_stride0 + tl.arange(0, idx_dim1)[None, :] * idx_stride1)
+    idx = tl.load(idx_ptr + idx_offs)
+
+    out = tl.gather(src, idx, axis)
+
+    out_offs = (tl.arange(0, out_dim0)[:, None] * out_stride0 + tl.arange(0, out_dim1)[None, :] * out_stride1)
+    tl.store(out_ptr + out_offs, out)
+
+
 @pytest.mark.parametrize("src_shape, indices_shape, axis", [
     ([4, 4], [8, 4], 0),
     ([128, 64], [256, 64], 0),
@@ -6208,29 +6226,13 @@ def kernel(In, Out,  #
 ])
 def test_gather(src_shape, indices_shape, axis):
 
-    @triton.jit
-    def gather_kernel(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, src_dim0: tl.constexpr, src_dim1: tl.constexpr,
-                      src_stride0: tl.constexpr, src_stride1: tl.constexpr, idx_dim0: tl.constexpr,
-                      idx_dim1: tl.constexpr, idx_stride0: tl.constexpr, idx_stride1: tl.constexpr,
-                      out_dim0: tl.constexpr, out_dim1: tl.constexpr,
out_stride0: tl.constexpr, - out_stride1: tl.constexpr): - src_offs = (tl.arange(0, src_dim0)[:, None] * src_stride0 + tl.arange(0, src_dim1)[None, :] * src_stride1) - src = tl.load(src_ptr + src_offs) - - idx_offs = (tl.arange(0, idx_dim0)[:, None] * idx_stride0 + tl.arange(0, idx_dim1)[None, :] * idx_stride1) - idx = tl.load(idx_ptr + idx_offs) - - out = tl.gather(src, idx, axis) - - out_offs = (tl.arange(0, out_dim0)[:, None] * out_stride0 + tl.arange(0, out_dim1)[None, :] * out_stride1) - tl.store(out_ptr + out_offs, out) - def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor): output = torch.empty(indices.shape, dtype=src.dtype, device=src.device) - gather_kernel[(1, )](src, indices, output, axis, src.shape[0], src.shape[1], - src.stride(0), src.stride(1), indices.shape[0], indices.shape[1], indices.stride(0), - indices.stride(1), output.shape[0], output.shape[1], output.stride(0), output.stride(1)) + gather_test_kernel[(1, )](src, indices, output, axis, src.shape[0], + src.shape[1], src.stride(0), src.stride(1), indices.shape[0], indices.shape[1], + indices.stride(0), indices.stride(1), output.shape[0], output.shape[1], + output.stride(0), output.stride(1)) return output @@ -6239,3 +6241,76 @@ def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor): ref = torch.gather(src, axis, indices) result = triton_gather(src, axis, indices) torch.testing.assert_close(result, ref, rtol=0, atol=0) + + +# These layouts are specially chosen to trigger the warp shuffle codegen. +@pytest.mark.parametrize("src_shape, indices_shape, axis, src_layout, indices_layout", [ + ([32, 16], [32, 16], 0, + "linear<{register = [[0, 2], [2, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}>", + "linear<{register = [[2, 0], [0, 2]], lane = [[0, 8], [16, 0], [1, 0], [8, 0], [4, 0]], warp = [[0, 1], [0, 4]], block = []}>" + ), + ([128, 64], [256, 64], 0, + "linear<{register = [[0, 2], [32, 0], [2, 0], [0, 16], [0, 32], [64, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}>", + "linear<{register = [[0, 2], [32, 0], [0, 32], [2, 0], [0, 16], [64, 0], [128, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}>" + ), +]) +def test_gather_warp_shuffle(src_shape, indices_shape, axis, src_layout, indices_layout, tmp_path: pathlib.Path): + if is_hip(): + pytest.skip("warp-local gather has issues on HIP") + + def prepare_kernel(src: torch.Tensor, axis: int, indices: torch.Tensor): + output = torch.empty(indices.shape, dtype=src.dtype, device=src.device) + compiled = gather_test_kernel.warmup(src, indices, output, axis, src.shape[0], src.shape[1], src.stride(0), + src.stride(1), indices.shape[0], indices.shape[1], indices.stride(0), + indices.stride(1), output.shape[0], output.shape[1], output.stride(0), + output.stride(1), grid=(1, )) + return output, compiled + + def inject_layout(ir, src: torch.Tensor, axis, indices: torch.Tensor, src_layout, idx_layout): + ir = f""" +#src_layout = #ttg.{src_layout} +#idx_layout = #ttg.{idx_layout} +{ir}""" + + dtypes = {torch.int32: "i32", torch.float32: "f32", torch.int64: "i64", torch.float64: "f64"} + + src_spec = f"{src.shape[0]}x{src.shape[1]}x{dtypes[src.dtype]}" + indices_spec = f"{indices.shape[0]}x{indices.shape[1]}x{dtypes[indices.dtype]}" + output_spec = f"{indices.shape[0]}x{indices.shape[1]}x{dtypes[src.dtype]}" + + pat = r"(%[0-9]+) = tt.gather (%[0-9]+)\[(%[0-9]+)\] {axis = " + pat += str(axis) + pat += r" : i32} : 
\(tensor\<" + pat += src_spec + pat += r", (#[a-z]+[0-9]+)\>, tensor\<" + pat += indices_spec + pat += r", (#[a-z]+[0-9]+)\>\) -> tensor\<" + pat += output_spec + pat += r", (#[a-z]+[0-9]+)\>" + + repl = r""" + %src = ttg.convert_layout \2 : tensor<""" + src_spec + r""", \4> -> tensor<""" + src_spec + r""", #src_layout> + %idx = ttg.convert_layout \3 : tensor<""" + indices_spec + r""", \5> -> tensor<""" + indices_spec + r""", #idx_layout> + %out = tt.gather %src[%idx] {axis = """ + str( + axis + ) + r""" : i32} : (tensor<""" + src_spec + r""", #src_layout>, tensor<""" + indices_spec + r""", #idx_layout>) -> tensor<""" + output_spec + r""", #idx_layout> + \1 = ttg.convert_layout %out : tensor<""" + output_spec + r""", #idx_layout> -> tensor<""" + output_spec + r""", \6>""" + return re.sub(pat, repl, ir) + + src = torch.randn(src_shape, device='cuda') + indices = torch.randint(0, src.shape[axis], indices_shape, device='cuda') + ref = torch.gather(src, axis, indices) + + output, compiled = prepare_kernel(src, axis, indices) + ir = compiled.asm["ttgir"] + ir = inject_layout(ir, src, axis, indices, src_layout, indices_layout) + + temp_file = tmp_path / "test_warp_gather.ttgir" + temp_file.write_text(ir) + + kernel = triton.compile(str(temp_file)) + assert ("nvvm.shfl.sync.idx" in kernel.asm["llir"]) or ("llvm.amdgcn.ds.bpermute" in kernel.asm["llir"]) + + kernel[(1, 1, 1)](src, indices, output) + + torch.testing.assert_close(output, ref, rtol=0, atol=0) diff --git a/test/Conversion/allocate_shared_memory.mlir b/test/Conversion/allocate_shared_memory.mlir index 345714f5b2b3..f3c2ed703386 100644 --- a/test/Conversion/allocate_shared_memory.mlir +++ b/test/Conversion/allocate_shared_memory.mlir @@ -1,14 +1,16 @@ // RUN: triton-opt %s --allocate-shared-memory | FileCheck %s +#blocked = #ttg.blocked<{sizePerThread = [32, 32], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> + // CHECK-LABEL: module // CHECK-SAME: ttg.shared = 131072 : i32 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: @gather_op // TODO(jeff): Optimize the lowering to reduce shared memory usage. -tt.func @gather_op(%arg0: tensor<1024x256xi32>, %arg1: tensor<128x256xf32>) { +tt.func @gather_op(%arg0: tensor<1024x256xi32, #blocked>, %arg1: tensor<128x256xf32, #blocked>) { // CHECK-NEXT: allocation.offset = 0 : i32 - %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<128x256xf32>, tensor<1024x256xi32>) -> tensor<1024x256xf32> + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<128x256xf32, #blocked>, tensor<1024x256xi32, #blocked>) -> tensor<1024x256xf32, #blocked> tt.return } diff --git a/test/Conversion/gather_to_llvm.mlir b/test/Conversion/gather_to_llvm.mlir new file mode 100644 index 000000000000..28a8a7e6b246 --- /dev/null +++ b/test/Conversion/gather_to_llvm.mlir @@ -0,0 +1,271 @@ +// RUN: triton-opt %s --allocate-shared-memory --convert-triton-gpu-to-llvm --convert-nv-gpu-to-llvm | mlir-translate -mlir-to-llvmir | opt -S -O1 | FileCheck %s + +// Check the optimized LLVMIR, since InstCombine makes the linear layout +// logic understandable enough (in simple cases) to check correctness by eye. 
+ +#trivial_layout = #ttg.linear<{register = [], lane = [[1], [2], [4], [8], [16]], warp = [], block = []}> + +#trivial_layout_wider = #ttg.linear<{register = [[32]], lane = [[1], [2], [4], [8], [16]], warp = [], block = []}> + +#trivial_layout_wider_reg_stride_1 = #ttg.linear<{register = [[1]], lane = [[2], [4], [8], [16], [32]], warp = [], block = []}> + +#trivial_2d_one_col = #ttg.linear<{register = [[0, 1]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [], block = []}> + +#span_2d_cols = #ttg.linear<{register = [[1, 0]], lane = [[2, 0], [4, 0], [8, 0], [16, 0], [0, 1]], warp = [], block = []}> + +#crazy_2d_src = #ttg.linear<{register = [[0, 2], [2, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}> +#crazy_2d_idx = #ttg.linear<{register = [[2, 0], [0, 2]], lane = [[0, 8], [16, 0], [1, 0], [8, 0], [4, 0]], warp = [[0, 1], [0, 4]], block = []}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { + +// Each source element is mapped to a single thread, so we expect one index shuffle. +// CHECK-LABEL: @gather_warp_local_trivial +tt.func private @gather_warp_local_trivial(%arg0: tensor<32xi32, #trivial_layout>, %arg1: tensor<32xf32, #trivial_layout>) -> tensor<32xf32, #trivial_layout> { + // CHECK-NEXT: [[SRC:%.*]] = extractvalue { float } %1, 0 + // CHECK-NEXT: [[IDX:%.*]] = extractvalue { i32 } %0, 0 + + // CHECK-NEXT: [[LANEID:%.*]] = and i32 [[IDX]], 31 + + // CHECK-NEXT: [[VALUE:%.*]] = bitcast float [[SRC]] to i32 + // CHECK-NEXT: [[RES_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID]], i32 31) + // CHECK-NEXT: [[RES:%.*]] = bitcast i32 [[RES_i32]] to float + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<32xf32, #trivial_layout>, tensor<32xi32, #trivial_layout>) -> tensor<32xf32, #trivial_layout> + + // CHECK-NEXT: ret float [[RES]] + tt.return %0 : tensor<32xf32, #trivial_layout> +} + +// Same as above, but there are two index elements per thread. Expect 2 index shuffles +// with the results packed together. 
+// CHECK-LABEL: @gather_warp_local_larger_output +tt.func private @gather_warp_local_larger_output(%arg0: tensor<64xi32, #trivial_layout_wider>, %arg1: tensor<32xf32, #trivial_layout>) -> tensor<64xf32, #trivial_layout_wider> { + // CHECK-NEXT: [[SRC:%.*]] = extractvalue { float } %1, 0 + // CHECK-NEXT: [[IDX0:%.*]] = extractvalue { i32, i32 } %0, 0 + // CHECK-NEXT: [[IDX1:%.*]] = extractvalue { i32, i32 } %0, 1 + + // CHECK-NEXT: [[LANEID0:%.*]] = and i32 [[IDX0]], 31 + + // CHECK-NEXT: [[VALUE:%.*]] = bitcast float [[SRC]] to i32 + // CHECK-NEXT: [[RES0_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID0]], i32 31) + // CHECK-NEXT: [[RES0:%.*]] = bitcast i32 [[RES0_i32]] to float + + // CHECK-NEXT: [[LANEID1:%.*]] = and i32 [[IDX1]], 31 + // CHECK-NEXT: [[RES1_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID1]], i32 31) + // CHECK-NEXT: [[RES1:%.*]] = bitcast i32 [[RES1_i32]] to float + + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<32xf32, #trivial_layout>, tensor<64xi32, #trivial_layout_wider>) -> tensor<64xf32, #trivial_layout_wider> + + // CHECK-NEXT: [[PACKED0:%.*]] = insertvalue { float, float } undef, float [[RES0]], 0 + // CHECK-NEXT: [[PACKED1:%.*]] = insertvalue { float, float } [[PACKED0]], float [[RES1]], 1 + // CHECK-NEXT: ret { float, float } [[PACKED1]] + tt.return %0 : tensor<64xf32, #trivial_layout_wider> +} + +// Each thread has 2 elements of the source tensor, strided 32 apart, so we +// expect two index shuffles, using the MSB to select between the two. +// CHECK-LABEL: @gather_warp_local_larger_input +tt.func private @gather_warp_local_larger_input(%arg0: tensor<32xi32, #trivial_layout>, %arg1: tensor<64xf32, #trivial_layout_wider>) -> tensor<32xf32, #trivial_layout> { + // CHECK-NEXT: [[SRC0:%.*]] = extractvalue { float, float } %1, 0 + // CHECK-NEXT: [[SRC1:%.*]] = extractvalue { float, float } %1, 1 + // CHECK-NEXT: [[IDX:%.*]] = extractvalue { i32 } %0, 0 + + // CHECK-NEXT: [[LANEID:%.*]] = and i32 [[IDX]], 31 + // CHECK-NEXT: [[REGID:%.*]] = and i32 [[IDX]], 32 + + // CHECK-NEXT: [[VALUE:%.*]] = bitcast float [[SRC0]] to i32 + // CHECK-NEXT: [[RES0:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID]], i32 31) + + // CHECK-NEXT: [[VALUE:%.*]] = bitcast float [[SRC1]] to i32 + // CHECK-NEXT: [[RES1:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID]], i32 31) + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<64xf32, #trivial_layout_wider>, tensor<32xi32, #trivial_layout>) -> tensor<32xf32, #trivial_layout> + + // CHECK-NEXT: [[PICK:%.*]] = icmp eq i32 [[REGID]], 0 + // CHECK-NEXT: [[RES_i32:%.*]] = select i1 [[PICK]], i32 [[RES0]], i32 [[RES1]] + // CHECK-NEXT: [[RES:%.*]] = bitcast i32 [[RES_i32]] to float + + // CHECK-NEXT: ret float [[RES]] + tt.return %0 : tensor<32xf32, #trivial_layout> +} + +// Same as above, except the RegID comes from the LSB. 
+// CHECK-LABEL: @gather_warp_local_larger_input +tt.func private @gather_warp_local_larger_input_stride_1(%arg0: tensor<32xi32, #trivial_layout>, %arg1: tensor<64xf32, #trivial_layout_wider_reg_stride_1>) -> tensor<32xf32, #trivial_layout> { + // CHECK-NEXT: [[SRC0:%.*]] = extractvalue { float, float } %1, 0 + // CHECK-NEXT: [[SRC1:%.*]] = extractvalue { float, float } %1, 1 + // CHECK-NEXT: [[IDX:%.*]] = extractvalue { i32 } %0, 0 + + // CHECK-NEXT: [[REGID:%.*]] = and i32 [[IDX]], 1 + // CHECK-NEXT: [[TMP:%.*]] = lshr i32 [[IDX]], 1 + // CHECK-NEXT: [[LANEID:%.*]] = and i32 [[TMP]], 31 + + // CHECK-NEXT: [[VALUE:%.*]] = bitcast float [[SRC0]] to i32 + // CHECK-NEXT: [[RES0:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID]], i32 31) + + // CHECK-NEXT: [[VALUE:%.*]] = bitcast float [[SRC1]] to i32 + // CHECK-NEXT: [[RES1:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID]], i32 31) + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<64xf32, #trivial_layout_wider_reg_stride_1>, tensor<32xi32, #trivial_layout>) -> tensor<32xf32, #trivial_layout> + + // CHECK-NEXT: [[PICK:%.*]] = icmp eq i32 [[REGID]], 0 + // CHECK-NEXT: [[RES_i32:%.*]] = select i1 [[PICK]], i32 [[RES0]], i32 [[RES1]] + // CHECK-NEXT: [[RES:%.*]] = bitcast i32 [[RES_i32]] to float + + // CHECK-NEXT: ret float [[RES]] + tt.return %0 : tensor<32xf32, #trivial_layout> +} + +// Each thread has 1 element in 2 gather columns, so this is the same as the +// trivial case except now it's 2D. We expect 2 independent index shuffles. +// CHECK-LABEL: @gather_2d_trivial +tt.func private @gather_2d_trivial(%arg0: tensor<32x2xi32, #trivial_2d_one_col>, %arg1: tensor<32x2xf32, #trivial_2d_one_col>) -> tensor<32x2xf32, #trivial_2d_one_col> { + // CHECK-NEXT: [[SRC0:%.*]] = extractvalue { float, float } %1, 0 + // CHECK-NEXT: [[SRC1:%.*]] = extractvalue { float, float } %1, 1 + // CHECK-NEXT: [[IDX0:%.*]] = extractvalue { i32, i32 } %0, 0 + // CHECK-NEXT: [[IDX1:%.*]] = extractvalue { i32, i32 } %0, 1 + + // CHECK-NEXT: [[LANEID0:%.*]] = and i32 [[IDX0]], 31 + // CHECK-NEXT: [[VALUE0:%.*]] = bitcast float [[SRC0]] to i32 + // CHECK-NEXT: [[RES0_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]], i32 [[LANEID0]], i32 31) + // CHECK-NEXT: [[RES0:%.*]] = bitcast i32 [[RES0_i32]] to float + + // CHECK-NEXT: [[LANEID1:%.*]] = and i32 [[IDX1]], 31 + // CHECK-NEXT: [[VALUE1:%.*]] = bitcast float [[SRC1]] to i32 + // CHECK-NEXT: [[RES1_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]], i32 [[LANEID1]], i32 31) + // CHECK-NEXT: [[RES1:%.*]] = bitcast i32 [[RES1_i32]] to float + + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<32x2xf32, #trivial_2d_one_col>, tensor<32x2xi32, #trivial_2d_one_col>) -> tensor<32x2xf32, #trivial_2d_one_col> + + // CHECK-NEXT: [[PACKED0:%.*]] = insertvalue { float, float } undef, float [[RES0]], 0 + // CHECK-NEXT: [[PACKED1:%.*]] = insertvalue { float, float } [[PACKED0]], float [[RES1]], 1 + // CHECK-NEXT: ret { float, float } [[PACKED1]] + tt.return %0 : tensor<32x2xf32, #trivial_2d_one_col> +} + +// The single warp is split into two columns. Each column has half contiguous +// threads, each with 2 contiguous elements. Expect 4 index shuffles: two per +// column. Thus, the index should be dependent on the thread id, since the +// register alone is not enough to determine the column. 
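+// (Concretely: in #span_2d_cols the last lane basis [0, 1] is the only one
+// that moves along dim1, so lanes 0-15 own column 0 and lanes 16-31 own
+// column 1, which is why the shuffle index below folds in the thread ID.)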
+// CHECK-LABEL: @gather_2d_span_2 +tt.func private @gather_2d_span_2(%arg0: tensor<32x2xi32, #span_2d_cols>, %arg1: tensor<32x2xf32, #span_2d_cols>) -> tensor<32x2xf32, #span_2d_cols> { + // CHECK-NEXT: [[SRC0:%.*]] = extractvalue { float, float } %1, 0 + // CHECK-NEXT: [[SRC1:%.*]] = extractvalue { float, float } %1, 1 + // CHECK-NEXT: [[IDX0:%.*]] = extractvalue { i32, i32 } %0, 0 + // CHECK-NEXT: [[IDX1:%.*]] = extractvalue { i32, i32 } %0, 1 + + // This uses tid to select between the two columns: + // CHECK-NEXT: [[TID:%.*]] = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x() + // CHECK-NEXT: [[COL:%.*]] = and i32 [[TID]], 16 + + // Break the index into reg and thread (within column) components: + // CHECK-NEXT: [[REGID0:%.*]] = and i32 [[IDX0]], 1 + // CHECK-NEXT: [[TMP:%.*]] = lshr i32 [[IDX0]], 1 + // CHECK-NEXT: [[LANEID0:%.*]] = and i32 [[TMP]], 15 + + // CHECK-NEXT: [[SHUFFLE_IDX:%.*]] = or disjoint i32 [[LANEID0]], [[COL]] + + // CHECK-NEXT: [[VALUE0:%.*]] = bitcast float [[SRC0]] to i32 + // CHECK-NEXT: [[SRES0:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]], i32 [[SHUFFLE_IDX]], i32 31) + // CHECK-NEXT: [[VALUE1:%.*]] = bitcast float [[SRC1]] to i32 + // CHECK-NEXT: [[SRES1:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]], i32 [[SHUFFLE_IDX]], i32 31) + + // Use the reg id to select between the two results: + // CHECK-NEXT: [[PICK0:%.*]] = icmp eq i32 [[REGID0]], 0 + // CHECK-NEXT: [[RES0_i32:%.*]] = select i1 [[PICK0]], i32 [[SRES0]], i32 [[SRES1]] + // CHECK-NEXT: [[RES0:%.*]] = bitcast i32 [[RES0_i32]] to float + + // CHECK-NEXT: [[REGID1:%.*]] = and i32 [[IDX1]], 1 + // CHECK-NEXT: [[TMP:%.*]] = lshr i32 [[IDX1]], 1 + // CHECK-NEXT: [[LANEID1:%.*]] = and i32 [[TMP]], 15 + + // CHECK-NEXT: [[SHUFFLE_IDX:%.*]] = or disjoint i32 [[LANEID1]], [[COL]] + + // CHECK-NEXT: [[SRES0:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]], i32 [[SHUFFLE_IDX]], i32 31) + // CHECK-NEXT: [[SRES1:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]], i32 [[SHUFFLE_IDX]], i32 31) + + // CHECK-NEXT: [[PICK0:%.*]] = icmp eq i32 [[REGID1]], 0 + // CHECK-NEXT: [[RES1_i32:%.*]] = select i1 [[PICK0]], i32 [[SRES0]], i32 [[SRES1]] + // CHECK-NEXT: [[RES1:%.*]] = bitcast i32 [[RES1_i32]] to float + + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<32x2xf32, #span_2d_cols>, tensor<32x2xi32, #span_2d_cols>) -> tensor<32x2xf32, #span_2d_cols> + + // CHECK-NEXT: [[PACKED0:%.*]] = insertvalue { float, float } undef, float [[RES0]], 0 + // CHECK-NEXT: [[PACKED1:%.*]] = insertvalue { float, float } [[PACKED0]], float [[RES1]], 1 + // CHECK-NEXT: ret { float, float } [[PACKED1]] + tt.return %0 : tensor<32x2xf32, #span_2d_cols> +} + +// CHECK-LABEL: @gather_2d_crazy +tt.func private @gather_2d_crazy(%arg0: tensor<32x16xi32, #crazy_2d_idx>, %arg1: tensor<32x16xf32, #crazy_2d_src>) -> tensor<32x16xf32, #crazy_2d_idx> { + // The specific logic becomes hard to grasp here. Just check the shuffles. 
+ + // CHECK-NEXT: [[SRC0:%.*]] = extractvalue { float, float, float, float } %1, 0 + // CHECK-NEXT: [[SRC1:%.*]] = extractvalue { float, float, float, float } %1, 1 + // CHECK-NEXT: [[SRC2:%.*]] = extractvalue { float, float, float, float } %1, 2 + // CHECK-NEXT: [[SRC3:%.*]] = extractvalue { float, float, float, float } %1, 3 + + // CHECK: [[VALUE0:%.*]] = bitcast float [[SRC0]] to i32 + // CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]], + // CHECK-NEXT: [[VALUE2:%.*]] = bitcast float [[SRC2]] to i32 + // CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE2]], + + // CHECK: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]], + // CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE2]], + + // CHECK: [[VALUE1:%.*]] = bitcast float [[SRC1]] to i32 + // CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]], + // CHECK-NEXT: [[VALUE3:%.*]] = bitcast float [[SRC3]] to i32 + // CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE3]], + + // CHECK: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]], + // CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE3]], + + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<32x16xf32, #crazy_2d_src>, tensor<32x16xi32, #crazy_2d_idx>) -> tensor<32x16xf32, #crazy_2d_idx> + tt.return %0 : tensor<32x16xf32, #crazy_2d_idx> +} + +// Keep LLVM from DCE'ing the above functions. Use volatile stores to stop LLVM +// from removing unused function results. +tt.func @anchor(%ptr: !llvm.ptr, + %arg0: tensor<32xi32, #trivial_layout>, + %arg1: tensor<32xf32, #trivial_layout>, + %arg2: tensor<64xi32, #trivial_layout_wider>, + %arg3: tensor<64xf32, #trivial_layout_wider>, + %arg4: tensor<64xf32, #trivial_layout_wider_reg_stride_1>, + %arg5: tensor<32x2xi32, #trivial_2d_one_col>, + %arg6: tensor<32x2xf32, #trivial_2d_one_col>, + %arg7: tensor<32x2xi32, #span_2d_cols>, + %arg8: tensor<32x2xf32, #span_2d_cols>, + %arg9: tensor<32x16xi32, #crazy_2d_idx>, + %arg10: tensor<32x16xf32, #crazy_2d_src>) { + + %0 = tt.call @gather_warp_local_trivial(%arg0, %arg1) : (tensor<32xi32, #trivial_layout>, tensor<32xf32, #trivial_layout>) -> tensor<32xf32, #trivial_layout> + %1 = builtin.unrealized_conversion_cast %0 : tensor<32xf32, #trivial_layout> to !llvm.struct<(f32)> + llvm.store volatile %1, %ptr : !llvm.struct<(f32)>, !llvm.ptr + + %2 = tt.call @gather_warp_local_larger_output(%arg2, %arg1) : (tensor<64xi32, #trivial_layout_wider>, tensor<32xf32, #trivial_layout>) -> tensor<64xf32, #trivial_layout_wider> + %3 = builtin.unrealized_conversion_cast %2 : tensor<64xf32, #trivial_layout_wider> to !llvm.struct<(f32, f32)> + llvm.store volatile %3, %ptr : !llvm.struct<(f32, f32)>, !llvm.ptr + + %4 = tt.call @gather_warp_local_larger_input(%arg0, %arg3) : (tensor<32xi32, #trivial_layout>, tensor<64xf32, #trivial_layout_wider>) -> tensor<32xf32, #trivial_layout> + %5 = builtin.unrealized_conversion_cast %4 : tensor<32xf32, #trivial_layout> to !llvm.struct<(f32)> + llvm.store volatile %5, %ptr : !llvm.struct<(f32)>, !llvm.ptr + + %6 = tt.call @gather_warp_local_larger_input_stride_1(%arg0, %arg4) : (tensor<32xi32, #trivial_layout>, tensor<64xf32, #trivial_layout_wider_reg_stride_1>) -> tensor<32xf32, #trivial_layout> + %7 = builtin.unrealized_conversion_cast %6 : tensor<32xf32, #trivial_layout> to !llvm.struct<(f32)> + llvm.store volatile %7, %ptr : !llvm.struct<(f32)>, !llvm.ptr + + %8 = tt.call 
@gather_2d_trivial(%arg5, %arg6) : (tensor<32x2xi32, #trivial_2d_one_col>, tensor<32x2xf32, #trivial_2d_one_col>) -> tensor<32x2xf32, #trivial_2d_one_col> + %9 = builtin.unrealized_conversion_cast %8 : tensor<32x2xf32, #trivial_2d_one_col> to !llvm.struct<(f32, f32)> + llvm.store volatile %9, %ptr : !llvm.struct<(f32, f32)>, !llvm.ptr + + %10 = tt.call @gather_2d_span_2(%arg7, %arg8) : (tensor<32x2xi32, #span_2d_cols>, tensor<32x2xf32, #span_2d_cols>) -> tensor<32x2xf32, #span_2d_cols> + %11 = builtin.unrealized_conversion_cast %10 : tensor<32x2xf32, #span_2d_cols> to !llvm.struct<(f32, f32)> + llvm.store volatile %11, %ptr : !llvm.struct<(f32, f32)>, !llvm.ptr + + %12 = tt.call @gather_2d_crazy(%arg9, %arg10) : (tensor<32x16xi32, #crazy_2d_idx>, tensor<32x16xf32, #crazy_2d_src>) -> tensor<32x16xf32, #crazy_2d_idx> + %13 = builtin.unrealized_conversion_cast %12 : tensor<32x16xf32, #crazy_2d_idx> to !llvm.struct<(f32, f32, f32, f32)> + llvm.store volatile %13, %ptr : !llvm.struct<(f32, f32, f32, f32)>, !llvm.ptr + + tt.return +} + +}