From 6b5374d0ace25ad14757ead3cfcac701eef4f3cc Mon Sep 17 00:00:00 2001 From: Mogball Date: Sat, 23 Nov 2024 08:05:29 -0500 Subject: [PATCH 01/38] Add GatherOp with lit tests --- include/triton/Dialect/Triton/IR/TritonOps.td | 25 +++++++++++ lib/Dialect/Triton/IR/Ops.cpp | 31 +++++++++++++ test/Triton/invalid.mlir | 44 +++++++++++++++++++ test/Triton/ops.mlir | 7 +++ 4 files changed, 107 insertions(+) diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index 197b9df7cf78..e31d6b03e62b 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -869,6 +869,31 @@ def TT_HistogramOp : TT_Op<"histogram", [Pure]> { }]; } +// +// Gather Op +// +def TT_GatherOp : TT_Op<"gather", [Pure]> { + let summary = "local gather operation"; + let description = [{ + Gather elements from the input tensor using the indices tensor along a + single specified dimension. The output tensor has the same shape as the + indices tensor. The input and indices tensors must have the same number of + dimension, and each dimension of the indices tensor that is not the gather + dimension cannot be greater than the corresponding dimension in the input + tensor. + }]; + + let arguments = (ins TT_Tensor:$src, TT_IntTensor:$indices, I32Attr:$dim); + let results = (outs TT_Tensor:$result); + + let assemblyFormat = [{ + $src `[` $indices `]` attr-dict `:` + functional-type(operands, results) + }]; + + let hasVerifier = 1; +} + // // Print Op // diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 269c32553eff..d4bbb9fa819c 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -1068,6 +1068,37 @@ Speculation::Speculatability ExternElementwiseOp::getSpeculatability() { return Speculation::NotSpeculatable; } +// -- GatherOp -- +LogicalResult GatherOp::verify() { + RankedTensorType indicesTy = getIndices().getType(); + RankedTensorType srcTy = getSrc().getType(); + RankedTensorType resTy = getResult().getType(); + + if (indicesTy.getShape() != resTy.getShape()) { + return emitOpError("indices and output shapes must match"); + } + if (indicesTy.getEncoding() != resTy.getEncoding()) { + return emitOpError("indices and output encodings must match"); + } + if (srcTy.getElementType() != resTy.getElementType()) { + return emitOpError("input and output element types must match"); + } + if (srcTy.getRank() != indicesTy.getRank()) { + return emitOpError("input and indices ranks must match"); + } + for (int dim = 0; dim < indicesTy.getRank(); ++dim) { + if (dim == getDim()) + continue; + if (indicesTy.getShape()[dim] > srcTy.getShape()[dim]) { + return emitOpError("indices dimension ") + << dim + << " cannot be greater than the corresponding input dimension"; + } + } + + return success(); +} + // -- ExperimentalTensormapCreateOp -- LogicalResult ExperimentalTensormapCreateOp::verify() { auto rank = getBoxDim().size(); diff --git a/test/Triton/invalid.mlir b/test/Triton/invalid.mlir index c7fb41707e1c..58896a752872 100644 --- a/test/Triton/invalid.mlir +++ b/test/Triton/invalid.mlir @@ -352,3 +352,47 @@ tt.func public @fn(%arg0: tensor<16x32x64xf32, #shared>) { tt.return } } // end module + +// ----- + +tt.func @gather_op(%arg0: tensor<128x16xf32>, %arg1: tensor<512x4xi32>) { + // expected-error @below {{indices and output shapes must match}} + %0 = tt.gather %arg0[%arg1] {dim = 0 : i32} : (tensor<128x16xf32>, tensor<512x4xi32>) -> tensor<512xf32> + tt.return +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [0, 1]}> +module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +tt.func @gather_op(%arg0: tensor<128x16xf32>, %arg1: tensor<512x4xi32, #blocked>) { + // expected-error @below {{indices and output encodings must match}} + %0 = tt.gather %arg0[%arg1] {dim = 0 : i32} : (tensor<128x16xf32>, tensor<512x4xi32, #blocked>) -> tensor<512x4xf32, #blocked1> + tt.return +} +} + +// ----- + +tt.func @gather_op(%arg0: tensor<128x16xf16>, %arg1: tensor<512x4xi32>) { + // expected-error @below {{input and output element types must match}} + %0 = tt.gather %arg0[%arg1] {dim = 0 : i32} : (tensor<128x16xf16>, tensor<512x4xi32>) -> tensor<512x4xf32> + tt.return +} + +// ----- + +tt.func @gather_op(%arg0: tensor<128xf32>, %arg1: tensor<512x4xi32>) { + // expected-error @below {{input and indices ranks must match}} + %0 = tt.gather %arg0[%arg1] {dim = 0 : i32} : (tensor<128xf32>, tensor<512x4xi32>) -> tensor<512x4xf32> + tt.return +} + +// ----- + +tt.func @gather_op(%arg0: tensor<128x16xf32>, %arg1: tensor<512x32xi32>) { + // expected-error @below {{indices dimension 1 cannot be greater than the corresponding input dimension}} + %0 = tt.gather %arg0[%arg1] {dim = 0 : i32} : (tensor<128x16xf32>, tensor<512x32xi32>) -> tensor<512x32xf32> + tt.return +} \ No newline at end of file diff --git a/test/Triton/ops.mlir b/test/Triton/ops.mlir index 9dec1e9c481e..5c402f78ca63 100644 --- a/test/Triton/ops.mlir +++ b/test/Triton/ops.mlir @@ -250,3 +250,10 @@ tt.func @experimental_descriptor_load(%0: !tt.tensordesc>) { %1 = tt.experimental_descriptor_load %0[%c0_i32] : !tt.tensordesc> -> tensor<128xf32> tt.return } + +// CHECK-LABEL: @gather_op +tt.func @gather_op(%arg0: tensor<128x16xf32>, %arg1: tensor<512x4xi32>) -> tensor<512x4xf32> { + // CHECK-NEXT: %0 = tt.gather %arg0[%arg1] {dim = 0 : i32} : (tensor<128x16xf32>, tensor<512x4xi32>) -> tensor<512x4xf32> + %0 = tt.gather %arg0[%arg1] {dim = 0 : i32} : (tensor<128x16xf32>, tensor<512x4xi32>) -> tensor<512x4xf32> + tt.return %0 : tensor<512x4xf32> +} \ No newline at end of file From f9bfec329e9d280f622902339ed0d7ec0949d379 Mon Sep 17 00:00:00 2001 From: Mogball Date: Mon, 25 Nov 2024 13:45:33 -0800 Subject: [PATCH 02/38] implement gather op through to LLVM --- include/triton/Analysis/Utility.h | 13 +++ .../PatternTritonGPUOpToLLVM.h | 4 + .../Conversion/TritonGPUToLLVM/Utility.h | 5 + lib/Analysis/Allocation.cpp | 4 + lib/Analysis/Utility.cpp | 11 ++ lib/Conversion/TritonGPUToLLVM/CMakeLists.txt | 1 + .../TritonGPUToLLVM/GatherOpToLLVM.cpp | 103 ++++++++++++++++++ .../TritonToTritonGPUPass.cpp | 1 + .../Transforms/RemoveLayoutConversions.cpp | 2 + lib/Dialect/TritonGPU/Transforms/Utility.cpp | 3 + test/Conversion/allocate_shared_memory.mlir | 15 +++ test/Conversion/triton_to_tritongpu.mlir | 11 ++ test/TritonGPU/combine.mlir | 22 ++++ .../TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp | 2 + .../TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp | 2 + 15 files changed, 199 insertions(+) create mode 100644 lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp create mode 100644 test/Conversion/allocate_shared_memory.mlir diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h index df6029db0de2..e06db19c6d5a 100644 --- a/include/triton/Analysis/Utility.h +++ b/include/triton/Analysis/Utility.h @@ -153,6 +153,19 @@ class ScanLoweringHelper { SmallVector srcElementTypes; }; +// Helper class for lowering `tt.gather` operations. This class shares lowering +// logic between shared memory allocation and LLVM codegen. +class GatherLoweringHelper { +public: + GatherLoweringHelper(triton::GatherOp gatherOp); + + // Get the shared memory scratch size required by this op. + unsigned getScratchSizeInBytes(); + +private: + triton::GatherOp gatherOp; +}; + // Decomposes a reshape into simpler pieces. // // As an example, suppose we have a reshape from [4,4,4] to [2,2,8,2]. diff --git a/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h b/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h index b6d2fbeff94f..d6530b093346 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -92,6 +92,10 @@ void populateScanOpToLLVMPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, const TargetInfoBase &targetInfo, PatternBenefit benefit); +void populateGatherOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); void populateConvertLayoutOpToLLVMPatterns(LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index ba24461a1f6d..77fb64189a3f 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -1125,6 +1125,11 @@ emitBaseIndexForLayout(Location loc, RewriterBase &rewriter, // Emit indices calculation within each ConversionPattern, and returns a // [elemsPerThread X rank] index matrix. +// +// For example, for a thread a owns `elemsPerThread` elements of a tensor with +// type `type` and layout `layout`, the result will contain `elemsPerThread` +// vectors. Each vector contains the SSA values of the indices required to +// access the corresponding element, starting from the inner dimension. SmallVector> emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, Attribute layout, RankedTensorType type, bool withCTAOffset); diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 53897578aa4a..8dc03aa81e74 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -125,6 +125,10 @@ unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) { ScanLoweringHelper helper(scanOp); return helper.getScratchSizeInBytes(); } + if (auto gatherOp=dyn_cast(op)) { + GatherLoweringHelper helper(gatherOp); + return helper.getScratchSizeInBytes(); + } if (auto histogram = dyn_cast(op)) { auto dstTy = histogram.getType(); int threadsPerWarp = gpu::TritonGPUDialect::getThreadsPerWarp( diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 6166e1019901..6a4353bde8d1 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -408,6 +408,17 @@ unsigned ScanLoweringHelper::getAxisBlockStride() { llvm_unreachable("Axis not found in order"); } +GatherLoweringHelper::GatherLoweringHelper(triton::GatherOp gatherOp) + : gatherOp(gatherOp) {} + +unsigned GatherLoweringHelper::getScratchSizeInBytes() { + // For now, lower the gather op by writing the source tensor to shared memory. + // TODO(jeff): Leverage locality to avoid using scratch space when possible. + RankedTensorType srcType = gatherOp.getSrc().getType(); + return product(srcType.getShape()) * + ceil(srcType.getElementTypeBitWidth(), 8); +} + unsigned getNumScratchElements(ArrayRef shape) { if (shape.empty()) return 0; diff --git a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt index 0a39d403101c..d6cc4387f79e 100644 --- a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt +++ b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -13,6 +13,7 @@ add_triton_library(TritonGPUToLLVM AllocateSharedMemory.cpp ReduceOpToLLVM.cpp ScanOpToLLVM.cpp + GatherOpToLLVM.cpp ConvertLayoutOpToLLVM.cpp ControlFlowOpToLLVM.cpp FuncOpToLLVM.cpp diff --git a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp new file mode 100644 index 000000000000..f29313b589c5 --- /dev/null +++ b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp @@ -0,0 +1,103 @@ +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace { +class GatherOpConversion : public ConvertOpToLLVMPattern { +public: + GatherOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, PatternBenefit benefit) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + LogicalResult + matchAndRewrite(GatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; + +private: + const TargetInfoBase &targetInfo; +}; + +LogicalResult +GatherOpConversion::matchAndRewrite(GatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = op.getLoc(); + RankedTensorType srcType = op.getSrc().getType(); + + // Compute the src subtensor shape owned by this CTA. + SmallVector srcShapePerCTA = + convertType(triton::gpu::getShapePerCTA(srcType)); + + // Grab the src values in this thread. + SmallVector srcValues = + unpackLLElements(loc, adaptor.getSrc(), rewriter); + + // Emit the indices of the src values owned by this thread. + SmallVector> srcIndices = + emitIndices(loc, rewriter, targetInfo, srcType.getEncoding(), + op.getSrc().getType(), /*withCTAOffset=*/true); + + // Store the src values owned by the thread into their respective location in + // the scratch memory. + assert(srcValues.size() == srcIndices.size()); + + // Get the base pointer to the scratch memory. + Value smemBase = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op); + + // For each src element owned by the thread, index into the scratch memory and + // then store it. + Type elemType = getTypeConverter()->convertType(srcType.getElementType()); + for (auto [value, indices] : llvm::zip(srcValues, srcIndices)) { + // Convert the index at each dim into a single offset given the shape of the + // tensor. + Value offset = LLVM::linearize(rewriter, loc, indices, srcShapePerCTA); + // Emit the offset into the shared memory and then store the value. + Value ptr = gep(smemBase.getType(), elemType, smemBase, offset); + store(value, ptr); + } + + // Synchronize the whole CTA. + // TODO(jeff): Should we teach Membar that gather synchronizes? + barrier(); + + // Grab the index values owned by this thread. + SmallVector idxValues = + unpackLLElements(loc, adaptor.getIndices(), rewriter); + + // I = LL(pid) + // idx = indices[I] + // I_gather = [I[d] if d != axis else idx for d in range(len(I))] + // out[I] = src[I_gather] + RankedTensorType dstType = op.getType(); + SmallVector> dstIndices = + emitIndices(loc, rewriter, targetInfo, dstType.getEncoding(), dstType, + /*withCTAOffset=*/true); + + unsigned axis = op.getDim(); + SmallVector dstShapePerCTA = + convertType(triton::gpu::getShapePerCTA(dstType)); + + SmallVector results(dstIndices.size()); + for (auto [i, idx, indices] : llvm::enumerate(idxValues, dstIndices)) { + indices[axis] = idx; + Value offset = LLVM::linearize(rewriter, loc, indices, dstShapePerCTA); + Value ptr = gep(smemPtrType, elemType, smemBase, offset); + results[i] = load(elemType, ptr); + } + + Value packed = packLLElements(loc, getTypeConverter(), results, rewriter, + dstType); + rewriter.replaceOp(op, packed); + return success(); +} + +} // namespace + +void triton::populateGatherOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) { + patterns.insert(typeConverter, targetInfo, benefit); +} \ No newline at end of file diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index ebc302c93cd7..12ebdf49a621 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -540,6 +540,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, GenericOpPattern, TritonExpandDimsPattern, TritonTransPattern, TritonDotPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index 70f5219111ab..9c05ae5bb46d 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -282,6 +282,7 @@ SmallVector LayoutPropagation::propagateToUsers(Value value, setEncoding(user->getResults(), info, changed, user); continue; } + // TODO(jeff): Propagate tt.gather indices layout to dst. } return changed; } @@ -709,6 +710,7 @@ Operation *LayoutPropagation::rewriteOp(Operation *op) { } return newOp; } + // TODO(jeff): Handle tt.gather once it supports layout propagation. llvm::report_fatal_error("unexpected op in rewrite"); return nullptr; } diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index b8f3abfcaca8..52ddf12b9f1e 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -472,6 +472,8 @@ std::optional inferSrcEncoding(Operation *op, Attribute encoding) { return inferSrcEncoding(trans, encoding); if (auto reshape = dyn_cast(op)) return inferSrcEncoding(reshape, encoding); + // TODO(jeff): Handle progagating tt.gather indices -> dst layout. + // This requires updating the API to specify the exact operands and results. return std::nullopt; } @@ -499,6 +501,7 @@ std::optional inferDstEncoding(Operation *op, Attribute encoding) { return inferDstEncoding(trans, encoding); if (auto reshape = dyn_cast(op)) return inferDstEncoding(reshape, encoding); + // TODO(jeff): Handle progagating tt.gather indices -> dst layout. return std::nullopt; } diff --git a/test/Conversion/allocate_shared_memory.mlir b/test/Conversion/allocate_shared_memory.mlir new file mode 100644 index 000000000000..8a8151fcc9e6 --- /dev/null +++ b/test/Conversion/allocate_shared_memory.mlir @@ -0,0 +1,15 @@ +// RUN: triton-opt %s --allocate-shared-memory | FileCheck %s + +// CHECK-LABEL: module +// CHECK-SAME: triton_gpu.shared = 131072 : i32 +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + +// CHECK-LABEL: @gather_op +// TODO(jeff): Optimize the lowering to reduce shared memory usage. +tt.func @gather_op(%arg0: tensor<1024x4xi32>, %arg1: tensor<128x256xf32>) { + // CHECK-NEXT: allocation.offset = 0 : i32 + %0 = tt.gather %arg1[%arg0] {dim = 0 : i32} : (tensor<128x256xf32>, tensor<1024x4xi32>) -> tensor<1024x4xf32> + tt.return +} + +} \ No newline at end of file diff --git a/test/Conversion/triton_to_tritongpu.mlir b/test/Conversion/triton_to_tritongpu.mlir index 96482b2298e1..86fdcf3585c6 100644 --- a/test/Conversion/triton_to_tritongpu.mlir +++ b/test/Conversion/triton_to_tritongpu.mlir @@ -108,3 +108,14 @@ tt.func @arith_splat_bool(%ptr: !tt.ptr {tt.divisibility = 16 : i32}) { tt.return } } + +// ----- + +// CHECK-LABEL: gather_op +tt.func @gather_op() { + %cst = arith.constant dense<1.0> : tensor<128x4xf32> + %cst_0 = arith.constant dense<1> : tensor<256x4xi32> + // CHECK: tt.gather %{{.*}}[%{{.*}}] {dim = 0 : i32} : (tensor<128x4xf32, #blocked>, tensor<256x4xi32, #blocked>) -> tensor<256x4xf32, #blocked> + %0 = tt.gather %cst[%cst_0] {dim = 0 : i32} : (tensor<128x4xf32>, tensor<256x4xi32>) -> tensor<256x4xf32> + tt.return +} \ No newline at end of file diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index 5e1cad52af90..ae2ff756aeaf 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -2685,3 +2685,25 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war tt.return } } + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [2, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + +// TODO(jeff): Support indices -> dst layout propagation to remove both +// layout conversions here. +tt.func @propagate_layout_gather(%arg0: tensor<1024x4xi32, #blocked>, %arg1: tensor<128x256xf32, #blocked>) -> tensor<1024x4xf32, #blocked2> { + // CHECK-LABEL: propagate_layout_gather + + // XCHECK-NOT: convert_layout + %0 = triton_gpu.convert_layout %arg0 : tensor<1024x4xi32, #blocked> -> tensor<1024x4xi32, #blocked1> + %1 = tt.gather %arg1[%0] {dim = 0 : i32} : (tensor<128x256xf32, #blocked>, tensor<1024x4xi32, #blocked1>) -> tensor<1024x4xf32, #blocked1> + %2 = triton_gpu.convert_layout %1 : tensor<1024x4xf32, #blocked1> -> tensor<1024x4xf32, #blocked2> + tt.return %2 : tensor<1024x4xf32, #blocked2> +} + +} \ No newline at end of file diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp index f99cd50b0d27..b27189aac635 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp @@ -196,6 +196,8 @@ struct ConvertTritonAMDGPUToLLVM commonBenefit); populatePatterns7(mlir::triton::populateHistogramOpToLLVMPatterns, commonBenefit); + populatePatterns7(mlir::triton::populateGatherOpToLLVMPatterns, + commonBenefit); mlir::triton::BackendCallbacks callbacks; callbacks.localStoreOpConversion = storeOpConversionCallback; diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp index 6674c9a81012..a5cb57e980e9 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp @@ -145,6 +145,8 @@ struct ConvertTritonGPUToLLVM targetInfo, benefit); mlir::triton::populateScanOpToLLVMPatterns(typeConverter, patterns, targetInfo, benefit); + mlir::triton::populateGatherOpToLLVMPatterns(typeConverter, patterns, + targetInfo, benefit); populateBarrierOpToLLVMPatterns(typeConverter, patterns, benefit); populateTensorPtrOpsToLLVMPatterns(typeConverter, patterns, benefit); populateClusterOpsToLLVMPatterns(typeConverter, patterns, benefit); From c091f3166980df312b85888bbc3df0dd4d75e811 Mon Sep 17 00:00:00 2001 From: Mogball Date: Tue, 26 Nov 2024 12:20:47 -0800 Subject: [PATCH 03/38] expose through frontend and add unit tests --- include/triton/Dialect/Triton/IR/TritonOps.td | 3 +- .../TritonGPUToLLVM/GatherOpToLLVM.cpp | 21 ++++++--- lib/Dialect/Triton/IR/Ops.cpp | 18 ++++++++ python/src/ir.cc | 3 ++ python/test/unit/language/test_core.py | 43 +++++++++++++++++++ python/triton/language/__init__.py | 2 + python/triton/language/core.py | 20 +++++++++ python/triton/language/semantic.py | 15 +++++++ test/Triton/invalid.mlir | 7 +++ 9 files changed, 124 insertions(+), 8 deletions(-) diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index e31d6b03e62b..ac2321e5e456 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -872,7 +872,8 @@ def TT_HistogramOp : TT_Op<"histogram", [Pure]> { // // Gather Op // -def TT_GatherOp : TT_Op<"gather", [Pure]> { +def TT_GatherOp : TT_Op<"gather", [Pure, + DeclareOpInterfaceMethods]> { let summary = "local gather operation"; let description = [{ Gather elements from the input tensor using the indices tensor along a diff --git a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp index f29313b589c5..dd169bf61a35 100644 --- a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp @@ -75,20 +75,27 @@ GatherOpConversion::matchAndRewrite(GatherOp op, OpAdaptor adaptor, emitIndices(loc, rewriter, targetInfo, dstType.getEncoding(), dstType, /*withCTAOffset=*/true); - unsigned axis = op.getDim(); - SmallVector dstShapePerCTA = - convertType(triton::gpu::getShapePerCTA(dstType)); + unsigned idxWidth = op.getIndices().getType().getElementTypeBitWidth(); + unsigned axis = op.getDim(); SmallVector results(dstIndices.size()); for (auto [i, idx, indices] : llvm::enumerate(idxValues, dstIndices)) { + // The LL index computations are performed with 32 bit integers. If the + // indices are something else, cast them to i32. + if (idxWidth > 32) { + idx = trunc(i32_ty, idx); + } else if (idxWidth < 32) { + // Negative indices don't make sense, so zero-extend. + idx = zext(i32_ty, idx); + } indices[axis] = idx; - Value offset = LLVM::linearize(rewriter, loc, indices, dstShapePerCTA); - Value ptr = gep(smemPtrType, elemType, smemBase, offset); + Value offset = LLVM::linearize(rewriter, loc, indices, srcShapePerCTA); + Value ptr = gep(smemBase.getType(), elemType, smemBase, offset); results[i] = load(elemType, ptr); } - Value packed = packLLElements(loc, getTypeConverter(), results, rewriter, - dstType); + Value packed = + packLLElements(loc, getTypeConverter(), results, rewriter, dstType); rewriter.replaceOp(op, packed); return success(); } diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index d4bbb9fa819c..8fa4b6c4f2d2 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -1086,6 +1086,9 @@ LogicalResult GatherOp::verify() { if (srcTy.getRank() != indicesTy.getRank()) { return emitOpError("input and indices ranks must match"); } + if (getDim() >= srcTy.getRank()) { + return emitOpError("gather dimension must be less than the input rank"); + } for (int dim = 0; dim < indicesTy.getRank(); ++dim) { if (dim == getDim()) continue; @@ -1099,6 +1102,21 @@ LogicalResult GatherOp::verify() { return success(); } +LogicalResult GatherOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + GatherOpAdaptor adaptor(operands, attributes, properties, regions); + auto indicesType = cast(adaptor.getIndices().getType()); + auto srcType = cast(adaptor.getSrc().getType()); + + // Shape and encoding of the indices with the element type of the src. + inferredReturnTypes.push_back( + RankedTensorType::get(indicesType.getShape(), srcType.getElementType(), + indicesType.getEncoding())); + return success(); +} + // -- ExperimentalTensormapCreateOp -- LogicalResult ExperimentalTensormapCreateOp::verify() { auto rank = getBoxDim().size(); diff --git a/python/src/ir.cc b/python/src/ir.cc index bb09dae7171a..16e96b52dcf0 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -1627,6 +1627,9 @@ void init_triton_ir(py::module &&m) { IntegerType::get(operand.getContext(), 32)), operand); }) + .def("create_gather", + [](TritonOpBuilder &self, Value src, Value indices, int dim) + -> Value { return self.create(src, indices, dim); }) // Force GPU barrier .def("create_barrier", [](TritonOpBuilder &self) { self.create(); }) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index b2e0fa59a8d4..df058a6ce4c5 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -6078,3 +6078,46 @@ def kernel(In, Out, # perm[0], perm[1], perm[2], perm[3], perm[4], red_dims[0], red_dims[1], red_dims[2]) assert torch.all(ref == result) + + +@pytest.mark.parametrize("src_shape, indices_shape, dim", [ + ([4, 4], [8, 2], 0), + ([256, 128], [512, 64], 0), + ([256, 128], [256, 256], 1), +]) +def test_gather(src_shape, indices_shape, dim): + + @triton.jit + def gather_kernel(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, + src_dim0: tl.constexpr, src_dim1: tl.constexpr, src_stride0: tl.constexpr, src_stride1: tl.constexpr, + idx_dim0: tl.constexpr, idx_dim1: tl.constexpr, idx_stride0: tl.constexpr, idx_stride1: tl.constexpr, + out_dim0: tl.constexpr, out_dim1: tl.constexpr, out_stride0: tl.constexpr, out_stride1: tl.constexpr): + src_offs = (tl.arange(0, src_dim0)[:, None] * src_stride0 + + tl.arange(0, src_dim1)[None, :] * src_stride1) + src = tl.load(src_ptr + src_offs) + + idx_offs = (tl.arange(0, idx_dim0)[:, None] * idx_stride0 + + tl.arange(0, idx_dim1)[None, :] * idx_stride1) + idx = tl.load(idx_ptr + idx_offs) + + out = tl.gather(src, idx, axis) + + out_offs = (tl.arange(0, out_dim0)[:, None] * out_stride0 + + tl.arange(0, out_dim1)[None, :] * out_stride1) + tl.store(out_ptr + out_offs, out) + + def triton_gather(src: torch.Tensor, dim: int, indices: torch.Tensor): + output = torch.empty(indices.shape, dtype=src.dtype, device=src.device) + + gather_kernel[(1,)](src, indices, output, dim, + src.shape[0], src.shape[1], src.stride(0), src.stride(1), + indices.shape[0], indices.shape[1], indices.stride(0), indices.stride(1), + output.shape[0], output.shape[1], output.stride(0), output.stride(1)) + + return output + + src = torch.randn(src_shape, device='cuda') + indices = torch.randint(0, src.shape[dim], indices_shape, device='cuda') + ref = torch.gather(src, dim, indices) + result = triton_gather(src, dim, indices) + assert torch.all(ref == result) \ No newline at end of file diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index 737ff06e6aed..0c8965fc520a 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -70,6 +70,7 @@ float8e5b16, full, function_type, + gather, histogram, inline_asm_elementwise, int1, @@ -188,6 +189,7 @@ "fma", "full", "function_type", + "gather", "histogram", "inline_asm_elementwise", "interleave", diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 145c9648298d..0b64de0100eb 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1098,6 +1098,9 @@ def reduce(self, axis, combine_fn, keep_dims=False) -> tensor: def associative_scan(self, axis, combine_fn, reverse=False) -> tensor: ... + def gather(self, indices, axis) -> tensor: + ... + def histogram(self, num_bins) -> tensor: ... @@ -2347,6 +2350,23 @@ def histogram(input, num_bins, _builder=None, _generator=None): return semantic.histogram(input, num_bins, _builder) +@_tensor_member_fn +@builtin +def gather(src, index, dim, _builder=None): + """Gather from a tensor along a given dimension. + + :param src: the source tensor + :type src: Tensor + :param index: the index tensor + :type index: Tensor + :param dim: the dimension to gather along + :type dim: int + + """ + dim = _constexpr_to_value(dim) + return semantic.gather(src, index, dim, _builder) + + # ----------------------- # Compiler Hint Ops # ----------------------- diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 4b27700b00c4..5ccde78cba78 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1677,6 +1677,21 @@ def associative_scan(inputs: Sequence[tl.tensor], axis: int, region_builder_fn, return tuple(wrap_tensor(scan_op.get_result(i), inputs[i].type.scalar, shape) for i in range(len(inputs))) +# ===----------------------------------------------------------------------=== +# Gather +# ===----------------------------------------------------------------------=== + +def gather(src: tl.tensor, index: tl.tensor, dim: int, builder: ir.builder) -> tl.tensor: + assert index.dtype.is_int(), "index must be an integer tensor" + rank = len(src.type.shape) + assert len(index.type.shape) == rank, "src and index must have the same rank" + assert 0 <= dim < rank, "dim must be a valid axis" + for d in range(rank): + assert d == dim or index.type.shape[d] <= src.type.shape[d], f"index dim {dim} cannot be greater than the corresponding src dim" + gather = builder.create_gather(src.handle, index.handle, dim) + return wrap_tensor(gather, src.type.scalar, index.type.shape) + + # ===----------------------------------------------------------------------=== # Histogram # ===----------------------------------------------------------------------=== diff --git a/test/Triton/invalid.mlir b/test/Triton/invalid.mlir index 58896a752872..0c5db7dac995 100644 --- a/test/Triton/invalid.mlir +++ b/test/Triton/invalid.mlir @@ -395,4 +395,11 @@ tt.func @gather_op(%arg0: tensor<128x16xf32>, %arg1: tensor<512x32xi32>) { // expected-error @below {{indices dimension 1 cannot be greater than the corresponding input dimension}} %0 = tt.gather %arg0[%arg1] {dim = 0 : i32} : (tensor<128x16xf32>, tensor<512x32xi32>) -> tensor<512x32xf32> tt.return +} +// ----- + +tt.func @gather_op(%arg0: tensor<128x16xf32>, %arg1: tensor<512x4xi32>) { + // expected-error @below {{gather dimension must be less than the input rank}} + %0 = tt.gather %arg0[%arg1] {dim = 3 : i32} : (tensor<128x16xf32>, tensor<512x4xi32>) -> tensor<512x4xf32> + tt.return } \ No newline at end of file From 44dabb4d778b1a592d3dca9ac80f12e984cec215 Mon Sep 17 00:00:00 2001 From: Mogball Date: Tue, 26 Nov 2024 12:46:08 -0800 Subject: [PATCH 04/38] rename dim to axis --- include/triton/Dialect/Triton/IR/TritonOps.td | 6 +++--- .../TritonGPUToLLVM/GatherOpToLLVM.cpp | 2 +- lib/Dialect/Triton/IR/Ops.cpp | 4 ++-- python/src/ir.cc | 4 ++-- python/test/unit/language/test_core.py | 14 +++++++------- python/triton/language/core.py | 10 +++++----- python/triton/language/semantic.py | 18 +++++++++++++----- test/Conversion/allocate_shared_memory.mlir | 2 +- test/Conversion/triton_to_tritongpu.mlir | 4 ++-- test/Triton/invalid.mlir | 12 ++++++------ test/Triton/ops.mlir | 4 ++-- test/TritonGPU/combine.mlir | 2 +- 12 files changed, 45 insertions(+), 37 deletions(-) diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index ac2321e5e456..7ac54764066e 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -877,14 +877,14 @@ def TT_GatherOp : TT_Op<"gather", [Pure, let summary = "local gather operation"; let description = [{ Gather elements from the input tensor using the indices tensor along a - single specified dimension. The output tensor has the same shape as the - indices tensor. The input and indices tensors must have the same number of + single specified axis. The output tensor has the same shape as the indices + tensor. The input and indices tensors must have the same number of dimension, and each dimension of the indices tensor that is not the gather dimension cannot be greater than the corresponding dimension in the input tensor. }]; - let arguments = (ins TT_Tensor:$src, TT_IntTensor:$indices, I32Attr:$dim); + let arguments = (ins TT_Tensor:$src, TT_IntTensor:$indices, I32Attr:$axis); let results = (outs TT_Tensor:$result); let assemblyFormat = [{ diff --git a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp index dd169bf61a35..960655394e50 100644 --- a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp @@ -77,7 +77,7 @@ GatherOpConversion::matchAndRewrite(GatherOp op, OpAdaptor adaptor, unsigned idxWidth = op.getIndices().getType().getElementTypeBitWidth(); - unsigned axis = op.getDim(); + unsigned axis = op.getAxis(); SmallVector results(dstIndices.size()); for (auto [i, idx, indices] : llvm::enumerate(idxValues, dstIndices)) { // The LL index computations are performed with 32 bit integers. If the diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 8fa4b6c4f2d2..f8ff387a7795 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -1086,11 +1086,11 @@ LogicalResult GatherOp::verify() { if (srcTy.getRank() != indicesTy.getRank()) { return emitOpError("input and indices ranks must match"); } - if (getDim() >= srcTy.getRank()) { + if (getAxis() >= srcTy.getRank()) { return emitOpError("gather dimension must be less than the input rank"); } for (int dim = 0; dim < indicesTy.getRank(); ++dim) { - if (dim == getDim()) + if (dim == getAxis()) continue; if (indicesTy.getShape()[dim] > srcTy.getShape()[dim]) { return emitOpError("indices dimension ") diff --git a/python/src/ir.cc b/python/src/ir.cc index 16e96b52dcf0..4daf0250e992 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -1628,8 +1628,8 @@ void init_triton_ir(py::module &&m) { operand); }) .def("create_gather", - [](TritonOpBuilder &self, Value src, Value indices, int dim) - -> Value { return self.create(src, indices, dim); }) + [](TritonOpBuilder &self, Value src, Value indices, int axis) + -> Value { return self.create(src, indices, axis); }) // Force GPU barrier .def("create_barrier", [](TritonOpBuilder &self) { self.create(); }) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index df058a6ce4c5..6ec2fcecc610 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -6080,12 +6080,12 @@ def kernel(In, Out, # assert torch.all(ref == result) -@pytest.mark.parametrize("src_shape, indices_shape, dim", [ +@pytest.mark.parametrize("src_shape, indices_shape, axis", [ ([4, 4], [8, 2], 0), ([256, 128], [512, 64], 0), ([256, 128], [256, 256], 1), ]) -def test_gather(src_shape, indices_shape, dim): +def test_gather(src_shape, indices_shape, axis): @triton.jit def gather_kernel(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, @@ -6106,10 +6106,10 @@ def gather_kernel(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, tl.arange(0, out_dim1)[None, :] * out_stride1) tl.store(out_ptr + out_offs, out) - def triton_gather(src: torch.Tensor, dim: int, indices: torch.Tensor): + def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor): output = torch.empty(indices.shape, dtype=src.dtype, device=src.device) - gather_kernel[(1,)](src, indices, output, dim, + gather_kernel[(1,)](src, indices, output, axis, src.shape[0], src.shape[1], src.stride(0), src.stride(1), indices.shape[0], indices.shape[1], indices.stride(0), indices.stride(1), output.shape[0], output.shape[1], output.stride(0), output.stride(1)) @@ -6117,7 +6117,7 @@ def triton_gather(src: torch.Tensor, dim: int, indices: torch.Tensor): return output src = torch.randn(src_shape, device='cuda') - indices = torch.randint(0, src.shape[dim], indices_shape, device='cuda') - ref = torch.gather(src, dim, indices) - result = triton_gather(src, dim, indices) + indices = torch.randint(0, src.shape[axis], indices_shape, device='cuda') + ref = torch.gather(src, axis, indices) + result = triton_gather(src, axis, indices) assert torch.all(ref == result) \ No newline at end of file diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 0b64de0100eb..c822ef812cd6 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -2352,19 +2352,19 @@ def histogram(input, num_bins, _builder=None, _generator=None): @_tensor_member_fn @builtin -def gather(src, index, dim, _builder=None): +def gather(src, index, axis, _builder=None): """Gather from a tensor along a given dimension. :param src: the source tensor :type src: Tensor :param index: the index tensor :type index: Tensor - :param dim: the dimension to gather along - :type dim: int + :param axis: the dimension to gather along + :type axis: int """ - dim = _constexpr_to_value(dim) - return semantic.gather(src, index, dim, _builder) + axis = _constexpr_to_value(axis) + return semantic.gather(src, index, axis, _builder) # ----------------------- diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 5ccde78cba78..f2586f5347c9 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1681,14 +1681,22 @@ def associative_scan(inputs: Sequence[tl.tensor], axis: int, region_builder_fn, # Gather # ===----------------------------------------------------------------------=== -def gather(src: tl.tensor, index: tl.tensor, dim: int, builder: ir.builder) -> tl.tensor: +def gather(src: tl.tensor, index: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: assert index.dtype.is_int(), "index must be an integer tensor" + rank = len(src.type.shape) - assert len(index.type.shape) == rank, "src and index must have the same rank" - assert 0 <= dim < rank, "dim must be a valid axis" + assert len(index.type.shape) == rank, "source and index tensors must have the same rank" + + assert -rank <= axis < rank, f"gather axis {axis} must be < source rank ({rank})" + if axis < 0: + axis += rank + for d in range(rank): - assert d == dim or index.type.shape[d] <= src.type.shape[d], f"index dim {dim} cannot be greater than the corresponding src dim" - gather = builder.create_gather(src.handle, index.handle, dim) + if d == axis: + continue + assert index.type.shape[d] <= src.type.shape[d], f"index dim {axis} cannot be greater than the corresponding source dim" + + gather = builder.create_gather(src.handle, index.handle, axis) return wrap_tensor(gather, src.type.scalar, index.type.shape) diff --git a/test/Conversion/allocate_shared_memory.mlir b/test/Conversion/allocate_shared_memory.mlir index 8a8151fcc9e6..6335275d6c91 100644 --- a/test/Conversion/allocate_shared_memory.mlir +++ b/test/Conversion/allocate_shared_memory.mlir @@ -8,7 +8,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // TODO(jeff): Optimize the lowering to reduce shared memory usage. tt.func @gather_op(%arg0: tensor<1024x4xi32>, %arg1: tensor<128x256xf32>) { // CHECK-NEXT: allocation.offset = 0 : i32 - %0 = tt.gather %arg1[%arg0] {dim = 0 : i32} : (tensor<128x256xf32>, tensor<1024x4xi32>) -> tensor<1024x4xf32> + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<128x256xf32>, tensor<1024x4xi32>) -> tensor<1024x4xf32> tt.return } diff --git a/test/Conversion/triton_to_tritongpu.mlir b/test/Conversion/triton_to_tritongpu.mlir index 86fdcf3585c6..5c80a4d11cf6 100644 --- a/test/Conversion/triton_to_tritongpu.mlir +++ b/test/Conversion/triton_to_tritongpu.mlir @@ -115,7 +115,7 @@ tt.func @arith_splat_bool(%ptr: !tt.ptr {tt.divisibility = 16 : i32}) { tt.func @gather_op() { %cst = arith.constant dense<1.0> : tensor<128x4xf32> %cst_0 = arith.constant dense<1> : tensor<256x4xi32> - // CHECK: tt.gather %{{.*}}[%{{.*}}] {dim = 0 : i32} : (tensor<128x4xf32, #blocked>, tensor<256x4xi32, #blocked>) -> tensor<256x4xf32, #blocked> - %0 = tt.gather %cst[%cst_0] {dim = 0 : i32} : (tensor<128x4xf32>, tensor<256x4xi32>) -> tensor<256x4xf32> + // CHECK: tt.gather %{{.*}}[%{{.*}}] {axis = 0 : i32} : (tensor<128x4xf32, #blocked>, tensor<256x4xi32, #blocked>) -> tensor<256x4xf32, #blocked> + %0 = tt.gather %cst[%cst_0] {axis = 0 : i32} : (tensor<128x4xf32>, tensor<256x4xi32>) -> tensor<256x4xf32> tt.return } \ No newline at end of file diff --git a/test/Triton/invalid.mlir b/test/Triton/invalid.mlir index 0c5db7dac995..c09a8d530ae3 100644 --- a/test/Triton/invalid.mlir +++ b/test/Triton/invalid.mlir @@ -357,7 +357,7 @@ tt.func public @fn(%arg0: tensor<16x32x64xf32, #shared>) { tt.func @gather_op(%arg0: tensor<128x16xf32>, %arg1: tensor<512x4xi32>) { // expected-error @below {{indices and output shapes must match}} - %0 = tt.gather %arg0[%arg1] {dim = 0 : i32} : (tensor<128x16xf32>, tensor<512x4xi32>) -> tensor<512xf32> + %0 = tt.gather %arg0[%arg1] {axis = 0 : i32} : (tensor<128x16xf32>, tensor<512x4xi32>) -> tensor<512xf32> tt.return } @@ -368,7 +368,7 @@ tt.func @gather_op(%arg0: tensor<128x16xf32>, %arg1: tensor<512x4xi32>) { module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { tt.func @gather_op(%arg0: tensor<128x16xf32>, %arg1: tensor<512x4xi32, #blocked>) { // expected-error @below {{indices and output encodings must match}} - %0 = tt.gather %arg0[%arg1] {dim = 0 : i32} : (tensor<128x16xf32>, tensor<512x4xi32, #blocked>) -> tensor<512x4xf32, #blocked1> + %0 = tt.gather %arg0[%arg1] {axis = 0 : i32} : (tensor<128x16xf32>, tensor<512x4xi32, #blocked>) -> tensor<512x4xf32, #blocked1> tt.return } } @@ -377,7 +377,7 @@ tt.func @gather_op(%arg0: tensor<128x16xf32>, %arg1: tensor<512x4xi32, #blocked> tt.func @gather_op(%arg0: tensor<128x16xf16>, %arg1: tensor<512x4xi32>) { // expected-error @below {{input and output element types must match}} - %0 = tt.gather %arg0[%arg1] {dim = 0 : i32} : (tensor<128x16xf16>, tensor<512x4xi32>) -> tensor<512x4xf32> + %0 = tt.gather %arg0[%arg1] {axis = 0 : i32} : (tensor<128x16xf16>, tensor<512x4xi32>) -> tensor<512x4xf32> tt.return } @@ -385,7 +385,7 @@ tt.func @gather_op(%arg0: tensor<128x16xf16>, %arg1: tensor<512x4xi32>) { tt.func @gather_op(%arg0: tensor<128xf32>, %arg1: tensor<512x4xi32>) { // expected-error @below {{input and indices ranks must match}} - %0 = tt.gather %arg0[%arg1] {dim = 0 : i32} : (tensor<128xf32>, tensor<512x4xi32>) -> tensor<512x4xf32> + %0 = tt.gather %arg0[%arg1] {axis = 0 : i32} : (tensor<128xf32>, tensor<512x4xi32>) -> tensor<512x4xf32> tt.return } @@ -393,13 +393,13 @@ tt.func @gather_op(%arg0: tensor<128xf32>, %arg1: tensor<512x4xi32>) { tt.func @gather_op(%arg0: tensor<128x16xf32>, %arg1: tensor<512x32xi32>) { // expected-error @below {{indices dimension 1 cannot be greater than the corresponding input dimension}} - %0 = tt.gather %arg0[%arg1] {dim = 0 : i32} : (tensor<128x16xf32>, tensor<512x32xi32>) -> tensor<512x32xf32> + %0 = tt.gather %arg0[%arg1] {axis = 0 : i32} : (tensor<128x16xf32>, tensor<512x32xi32>) -> tensor<512x32xf32> tt.return } // ----- tt.func @gather_op(%arg0: tensor<128x16xf32>, %arg1: tensor<512x4xi32>) { // expected-error @below {{gather dimension must be less than the input rank}} - %0 = tt.gather %arg0[%arg1] {dim = 3 : i32} : (tensor<128x16xf32>, tensor<512x4xi32>) -> tensor<512x4xf32> + %0 = tt.gather %arg0[%arg1] {axis = 3 : i32} : (tensor<128x16xf32>, tensor<512x4xi32>) -> tensor<512x4xf32> tt.return } \ No newline at end of file diff --git a/test/Triton/ops.mlir b/test/Triton/ops.mlir index 5c402f78ca63..f8aef5586992 100644 --- a/test/Triton/ops.mlir +++ b/test/Triton/ops.mlir @@ -253,7 +253,7 @@ tt.func @experimental_descriptor_load(%0: !tt.tensordesc>) { // CHECK-LABEL: @gather_op tt.func @gather_op(%arg0: tensor<128x16xf32>, %arg1: tensor<512x4xi32>) -> tensor<512x4xf32> { - // CHECK-NEXT: %0 = tt.gather %arg0[%arg1] {dim = 0 : i32} : (tensor<128x16xf32>, tensor<512x4xi32>) -> tensor<512x4xf32> - %0 = tt.gather %arg0[%arg1] {dim = 0 : i32} : (tensor<128x16xf32>, tensor<512x4xi32>) -> tensor<512x4xf32> + // CHECK-NEXT: %0 = tt.gather %arg0[%arg1] {axis = 0 : i32} : (tensor<128x16xf32>, tensor<512x4xi32>) -> tensor<512x4xf32> + %0 = tt.gather %arg0[%arg1] {axis = 0 : i32} : (tensor<128x16xf32>, tensor<512x4xi32>) -> tensor<512x4xf32> tt.return %0 : tensor<512x4xf32> } \ No newline at end of file diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index ae2ff756aeaf..7f2d6c5ddcde 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -2701,7 +2701,7 @@ tt.func @propagate_layout_gather(%arg0: tensor<1024x4xi32, #blocked>, %arg1: ten // XCHECK-NOT: convert_layout %0 = triton_gpu.convert_layout %arg0 : tensor<1024x4xi32, #blocked> -> tensor<1024x4xi32, #blocked1> - %1 = tt.gather %arg1[%0] {dim = 0 : i32} : (tensor<128x256xf32, #blocked>, tensor<1024x4xi32, #blocked1>) -> tensor<1024x4xf32, #blocked1> + %1 = tt.gather %arg1[%0] {axis = 0 : i32} : (tensor<128x256xf32, #blocked>, tensor<1024x4xi32, #blocked1>) -> tensor<1024x4xf32, #blocked1> %2 = triton_gpu.convert_layout %1 : tensor<1024x4xf32, #blocked1> -> tensor<1024x4xf32, #blocked2> tt.return %2 : tensor<1024x4xf32, #blocked2> } From d4a32c8f02569ba2f72a83b9e37b3d91f3174cc0 Mon Sep 17 00:00:00 2001 From: Mogball Date: Tue, 26 Nov 2024 13:22:52 -0800 Subject: [PATCH 05/38] add LLVMIR test --- test/Conversion/tritongpu_to_llvm.mlir | 31 ++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 3f2fd578da82..f81963099e9a 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -1897,3 +1897,34 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.return } } + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + +tt.func @gather_in_shared(%arg0: tensor<16x2xi32, #blocked1>, %arg1: tensor<8x4xf32, #blocked>) { + // CHECK-LABEL: gather_in_shared + + // CHECK: [[S0:%.*]] = llvm.extractvalue %arg1[0] + + // CHECK: [[SMEM_BASE:%.*]] = llvm.mlir.addressof @global_smem + // CHECK-NEXT: [[SMEM:%.*]] = llvm.getelementptr [[SMEM_BASE]] + // CHECK: store [[S0]] + // CHECK-NEXT: nvvm.barrier0 + + // CHECK: [[I0:%.*]] = llvm.extractvalue %arg0[0] + + // CHECK: [[IDX:%.*]] = llvm.add {{.*}}, [[I0]] + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr [[SMEM]][[[IDX]]] + // CHECK-NEXT: [[OUT0:%.*]] = llvm.load [[PTR]] + + // CHECK: insertvalue [[OUT0]], {{.*}}[0] + + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<8x4xf32, #blocked>, tensor<16x2xi32, #blocked1>) -> tensor<16x2xf32, #blocked1> + tt.return +} + +} \ No newline at end of file From 56d12799d85abb1d9dc2e5da6f85bd91c5628fef Mon Sep 17 00:00:00 2001 From: Mogball Date: Tue, 26 Nov 2024 13:35:00 -0800 Subject: [PATCH 06/38] newlines --- test/Conversion/allocate_shared_memory.mlir | 2 +- test/Conversion/tritongpu_to_llvm.mlir | 2 +- test/Triton/invalid.mlir | 2 +- test/Triton/ops.mlir | 2 +- test/TritonGPU/combine.mlir | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/test/Conversion/allocate_shared_memory.mlir b/test/Conversion/allocate_shared_memory.mlir index 6335275d6c91..f2d122283b38 100644 --- a/test/Conversion/allocate_shared_memory.mlir +++ b/test/Conversion/allocate_shared_memory.mlir @@ -12,4 +12,4 @@ tt.func @gather_op(%arg0: tensor<1024x4xi32>, %arg1: tensor<128x256xf32>) { tt.return } -} \ No newline at end of file +} diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index f81963099e9a..5ac689df315b 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -1927,4 +1927,4 @@ tt.func @gather_in_shared(%arg0: tensor<16x2xi32, #blocked1>, %arg1: tensor<8x4x tt.return } -} \ No newline at end of file +} diff --git a/test/Triton/invalid.mlir b/test/Triton/invalid.mlir index c09a8d530ae3..d58f95441a3b 100644 --- a/test/Triton/invalid.mlir +++ b/test/Triton/invalid.mlir @@ -402,4 +402,4 @@ tt.func @gather_op(%arg0: tensor<128x16xf32>, %arg1: tensor<512x4xi32>) { // expected-error @below {{gather dimension must be less than the input rank}} %0 = tt.gather %arg0[%arg1] {axis = 3 : i32} : (tensor<128x16xf32>, tensor<512x4xi32>) -> tensor<512x4xf32> tt.return -} \ No newline at end of file +} diff --git a/test/Triton/ops.mlir b/test/Triton/ops.mlir index f8aef5586992..77847805bcc1 100644 --- a/test/Triton/ops.mlir +++ b/test/Triton/ops.mlir @@ -256,4 +256,4 @@ tt.func @gather_op(%arg0: tensor<128x16xf32>, %arg1: tensor<512x4xi32>) -> tenso // CHECK-NEXT: %0 = tt.gather %arg0[%arg1] {axis = 0 : i32} : (tensor<128x16xf32>, tensor<512x4xi32>) -> tensor<512x4xf32> %0 = tt.gather %arg0[%arg1] {axis = 0 : i32} : (tensor<128x16xf32>, tensor<512x4xi32>) -> tensor<512x4xf32> tt.return %0 : tensor<512x4xf32> -} \ No newline at end of file +} diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index c19e07f9a7ad..f01ef871ee71 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -2724,4 +2724,4 @@ tt.func @propagate_layout_gather(%arg0: tensor<1024x4xi32, #blocked>, %arg1: ten tt.return %2 : tensor<1024x4xf32, #blocked2> } -} \ No newline at end of file +} From 6a2f7886ae613e97ede2d38418d0f7a243fc8d18 Mon Sep 17 00:00:00 2001 From: Mogball Date: Tue, 26 Nov 2024 13:35:19 -0800 Subject: [PATCH 07/38] more newlines --- lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp index 960655394e50..71adcdfdfb0c 100644 --- a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp @@ -107,4 +107,4 @@ void triton::populateGatherOpToLLVMPatterns(LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, PatternBenefit benefit) { patterns.insert(typeConverter, targetInfo, benefit); -} \ No newline at end of file +} From 8f1358e30ad626e6adb307c70018875b9346a6b4 Mon Sep 17 00:00:00 2001 From: Mogball Date: Tue, 26 Nov 2024 13:35:50 -0800 Subject: [PATCH 08/38] more newlines --- python/test/unit/language/test_core.py | 2 +- test/Conversion/triton_to_tritongpu.mlir | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 31ccd6f74901..394d1b58cb37 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -6129,4 +6129,4 @@ def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor): indices = torch.randint(0, src.shape[axis], indices_shape, device='cuda') ref = torch.gather(src, axis, indices) result = triton_gather(src, axis, indices) - assert torch.all(ref == result) \ No newline at end of file + assert torch.all(ref == result) diff --git a/test/Conversion/triton_to_tritongpu.mlir b/test/Conversion/triton_to_tritongpu.mlir index 5c80a4d11cf6..0b615d44b14c 100644 --- a/test/Conversion/triton_to_tritongpu.mlir +++ b/test/Conversion/triton_to_tritongpu.mlir @@ -118,4 +118,4 @@ tt.func @gather_op() { // CHECK: tt.gather %{{.*}}[%{{.*}}] {axis = 0 : i32} : (tensor<128x4xf32, #blocked>, tensor<256x4xi32, #blocked>) -> tensor<256x4xf32, #blocked> %0 = tt.gather %cst[%cst_0] {axis = 0 : i32} : (tensor<128x4xf32>, tensor<256x4xi32>) -> tensor<256x4xf32> tt.return -} \ No newline at end of file +} From 63d0de2f8186b8b7c96b4c4cc47e1d1595ed7b63 Mon Sep 17 00:00:00 2001 From: Mogball Date: Tue, 26 Nov 2024 13:55:32 -0800 Subject: [PATCH 09/38] format code --- lib/Analysis/Allocation.cpp | 2 +- .../TritonGPUToLLVM/GatherOpToLLVM.cpp | 1 - python/test/unit/language/test_core.py | 25 ++++++++----------- python/triton/language/semantic.py | 4 ++- 4 files changed, 15 insertions(+), 17 deletions(-) diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 8dc03aa81e74..c79e81e65ca6 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -125,7 +125,7 @@ unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) { ScanLoweringHelper helper(scanOp); return helper.getScratchSizeInBytes(); } - if (auto gatherOp=dyn_cast(op)) { + if (auto gatherOp = dyn_cast(op)) { GatherLoweringHelper helper(gatherOp); return helper.getScratchSizeInBytes(); } diff --git a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp index 71adcdfdfb0c..b315e0d646ae 100644 --- a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp @@ -75,7 +75,6 @@ GatherOpConversion::matchAndRewrite(GatherOp op, OpAdaptor adaptor, emitIndices(loc, rewriter, targetInfo, dstType.getEncoding(), dstType, /*withCTAOffset=*/true); - unsigned idxWidth = op.getIndices().getType().getElementTypeBitWidth(); unsigned axis = op.getAxis(); SmallVector results(dstIndices.size()); diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 394d1b58cb37..29eec914658e 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -6097,31 +6097,28 @@ def kernel(In, Out, # def test_gather(src_shape, indices_shape, axis): @triton.jit - def gather_kernel(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, - src_dim0: tl.constexpr, src_dim1: tl.constexpr, src_stride0: tl.constexpr, src_stride1: tl.constexpr, - idx_dim0: tl.constexpr, idx_dim1: tl.constexpr, idx_stride0: tl.constexpr, idx_stride1: tl.constexpr, - out_dim0: tl.constexpr, out_dim1: tl.constexpr, out_stride0: tl.constexpr, out_stride1: tl.constexpr): - src_offs = (tl.arange(0, src_dim0)[:, None] * src_stride0 + - tl.arange(0, src_dim1)[None, :] * src_stride1) + def gather_kernel(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, src_dim0: tl.constexpr, src_dim1: tl.constexpr, + src_stride0: tl.constexpr, src_stride1: tl.constexpr, idx_dim0: tl.constexpr, + idx_dim1: tl.constexpr, idx_stride0: tl.constexpr, idx_stride1: tl.constexpr, + out_dim0: tl.constexpr, out_dim1: tl.constexpr, out_stride0: tl.constexpr, + out_stride1: tl.constexpr): + src_offs = (tl.arange(0, src_dim0)[:, None] * src_stride0 + tl.arange(0, src_dim1)[None, :] * src_stride1) src = tl.load(src_ptr + src_offs) - idx_offs = (tl.arange(0, idx_dim0)[:, None] * idx_stride0 + - tl.arange(0, idx_dim1)[None, :] * idx_stride1) + idx_offs = (tl.arange(0, idx_dim0)[:, None] * idx_stride0 + tl.arange(0, idx_dim1)[None, :] * idx_stride1) idx = tl.load(idx_ptr + idx_offs) out = tl.gather(src, idx, axis) - out_offs = (tl.arange(0, out_dim0)[:, None] * out_stride0 + - tl.arange(0, out_dim1)[None, :] * out_stride1) + out_offs = (tl.arange(0, out_dim0)[:, None] * out_stride0 + tl.arange(0, out_dim1)[None, :] * out_stride1) tl.store(out_ptr + out_offs, out) def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor): output = torch.empty(indices.shape, dtype=src.dtype, device=src.device) - gather_kernel[(1,)](src, indices, output, axis, - src.shape[0], src.shape[1], src.stride(0), src.stride(1), - indices.shape[0], indices.shape[1], indices.stride(0), indices.stride(1), - output.shape[0], output.shape[1], output.stride(0), output.stride(1)) + gather_kernel[(1, )](src, indices, output, axis, src.shape[0], src.shape[1], + src.stride(0), src.stride(1), indices.shape[0], indices.shape[1], indices.stride(0), + indices.stride(1), output.shape[0], output.shape[1], output.stride(0), output.stride(1)) return output diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index f2586f5347c9..390e2b5d4247 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1681,6 +1681,7 @@ def associative_scan(inputs: Sequence[tl.tensor], axis: int, region_builder_fn, # Gather # ===----------------------------------------------------------------------=== + def gather(src: tl.tensor, index: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: assert index.dtype.is_int(), "index must be an integer tensor" @@ -1694,7 +1695,8 @@ def gather(src: tl.tensor, index: tl.tensor, axis: int, builder: ir.builder) -> for d in range(rank): if d == axis: continue - assert index.type.shape[d] <= src.type.shape[d], f"index dim {axis} cannot be greater than the corresponding source dim" + assert index.type.shape[d] <= src.type.shape[ + d], f"index dim {axis} cannot be greater than the corresponding source dim" gather = builder.create_gather(src.handle, index.handle, axis) return wrap_tensor(gather, src.type.scalar, index.type.shape) From 39a35cea891050055ec59685e5d74624d1d8fde0 Mon Sep 17 00:00:00 2001 From: Mogball Date: Tue, 26 Nov 2024 13:57:13 -0800 Subject: [PATCH 10/38] reduce test_gather smem usage for AMD --- python/test/unit/language/test_core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 29eec914658e..d78371e1729c 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -6091,8 +6091,8 @@ def kernel(In, Out, # @pytest.mark.parametrize("src_shape, indices_shape, axis", [ ([4, 4], [8, 2], 0), - ([256, 128], [512, 64], 0), - ([256, 128], [256, 256], 1), + ([128, 64], [256, 32], 0), + ([128, 64], [128, 128], 1), ]) def test_gather(src_shape, indices_shape, axis): From 741d0a841afd3eeda0f67c7ae875f82635fe3af6 Mon Sep 17 00:00:00 2001 From: Mogball Date: Tue, 26 Nov 2024 22:14:14 -0800 Subject: [PATCH 11/38] assert_close --- python/test/unit/language/test_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index d78371e1729c..e270faeaa74b 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -6126,4 +6126,4 @@ def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor): indices = torch.randint(0, src.shape[axis], indices_shape, device='cuda') ref = torch.gather(src, axis, indices) result = triton_gather(src, axis, indices) - assert torch.all(ref == result) + torch.testing.assert_close(result, ref) From 4189c09a832a36e62d9d61a85ba00a90d056c84f Mon Sep 17 00:00:00 2001 From: Mogball Date: Tue, 26 Nov 2024 22:18:44 -0800 Subject: [PATCH 12/38] clarify gather impl comment --- lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp index b315e0d646ae..0efe8f764d16 100644 --- a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp @@ -66,6 +66,10 @@ GatherOpConversion::matchAndRewrite(GatherOp op, OpAdaptor adaptor, SmallVector idxValues = unpackLLElements(loc, adaptor.getIndices(), rewriter); + // Apply the layout of the destination tensor to obtain the indices of the + // column to gather along, then for each column, replace the index along the + // gather axis with the appropriate index value. + // // I = LL(pid) // idx = indices[I] // I_gather = [I[d] if d != axis else idx for d in range(len(I))] From 9fee613041013e023755c51173ceda4d7a13297c Mon Sep 17 00:00:00 2001 From: Mogball Date: Tue, 26 Nov 2024 22:26:52 -0800 Subject: [PATCH 13/38] add gather lowering test with dot layout --- test/Conversion/tritongpu_to_llvm.mlir | 38 ++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 5ac689df315b..253590828d1d 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -1928,3 +1928,41 @@ tt.func @gather_in_shared(%arg0: tensor<16x2xi32, #blocked1>, %arg1: tensor<8x4x } } + +// ----- + +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [1, 1]}> +#dot = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=1}> +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + +tt.func @gather_in_shared_dot_input(%arg0: tensor<16x2xi32, #blocked>, %arg1: tensor<8x4xf32, #dot>) { + // CHECK-LABEL: gather_in_shared_dot_input + + // CHECK: [[S0:%.*]] = llvm.extractvalue %arg1[0] + // CHECK: [[S1:%.*]] = llvm.extractvalue %arg1[1] + // CHECK: [[S2:%.*]] = llvm.extractvalue %arg1[2] + // CHECK: [[S3:%.*]] = llvm.extractvalue %arg1[3] + + // CHECK: [[SMEM_BASE:%.*]] = llvm.mlir.addressof @global_smem + // CHECK-NEXT: [[SMEM:%.*]] = llvm.getelementptr [[SMEM_BASE]] + // CHECK: store [[S0]] + // CHECK: store [[S1]] + // CHECK: store [[S2]] + // CHECK: store [[S3]] + // CHECK-NEXT: nvvm.barrier0 + + // CHECK: [[I0:%.*]] = llvm.extractvalue %arg0[0] + + // CHECK: [[IDX:%.*]] = llvm.add {{.*}}, [[I0]] + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr [[SMEM]][[[IDX]]] + // CHECK-NEXT: [[OUT0:%.*]] = llvm.load [[PTR]] + + // CHECK: insertvalue [[OUT0]], {{.*}}[0] + + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<8x4xf32, #dot>, tensor<16x2xi32, #blocked>) -> tensor<16x2xf32, #blocked> + tt.return +} + +} From a4d9a2e33379efd9f023195999f5147123e76204 Mon Sep 17 00:00:00 2001 From: Mogball Date: Wed, 27 Nov 2024 09:04:33 -0800 Subject: [PATCH 14/38] warp local --- include/triton/Analysis/Utility.h | 2 ++ lib/Analysis/Utility.cpp | 16 ++++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h index e06db19c6d5a..ec6d57cf9112 100644 --- a/include/triton/Analysis/Utility.h +++ b/include/triton/Analysis/Utility.h @@ -161,6 +161,8 @@ class GatherLoweringHelper { // Get the shared memory scratch size required by this op. unsigned getScratchSizeInBytes(); + // Determine if the gather can be performed completely within a warp. + bool isWarpLocal(); private: triton::GatherOp gatherOp; diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index e6d921464bc9..37c3e4e09468 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -419,6 +419,22 @@ unsigned GatherLoweringHelper::getScratchSizeInBytes() { ceil(srcType.getElementTypeBitWidth(), 8); } +bool GatherLoweringHelper::isWarpLocal() { + // The gather is warp-local if for each column along the gather axis in the + // source tensor, all the elements are owned by the same warp. + RankedTensorType srcType = gatherOp.getSrc().getType(); + std::optional maybeLayout = + toLinearLayout(srcType.getShape(), srcType.getEncoding()); + // FIXME: If an unsupported layout was encountered, assume the gather is not + // warp-local. + if (!maybeLayout) + return false; + LinearLayout layout = std::move(*maybeLayout); + + + return false; +} + unsigned getNumScratchElements(ArrayRef shape) { if (shape.empty()) return 0; From 5b47ed685bee8fb2e3f24e185e43d0fa2afd8865 Mon Sep 17 00:00:00 2001 From: Mogball Date: Wed, 27 Nov 2024 09:08:09 -0800 Subject: [PATCH 15/38] remove membar todo --- lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp index 0efe8f764d16..5ab81eff819c 100644 --- a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp @@ -59,7 +59,6 @@ GatherOpConversion::matchAndRewrite(GatherOp op, OpAdaptor adaptor, } // Synchronize the whole CTA. - // TODO(jeff): Should we teach Membar that gather synchronizes? barrier(); // Grab the index values owned by this thread. From b6700689245a982fe51d3602f37d31fa0ecd0ff1 Mon Sep 17 00:00:00 2001 From: Mogball Date: Wed, 27 Nov 2024 09:17:46 -0800 Subject: [PATCH 16/38] require other dims to match source dims --- lib/Dialect/Triton/IR/Ops.cpp | 5 ++--- python/test/unit/language/test_core.py | 4 ++-- python/triton/language/semantic.py | 3 +-- test/Conversion/allocate_shared_memory.mlir | 4 ++-- test/Conversion/tritongpu_to_llvm.mlir | 8 ++++---- test/Triton/invalid.mlir | 2 +- test/Triton/ops.mlir | 8 ++++---- test/TritonGPU/combine.mlir | 10 +++++----- 8 files changed, 21 insertions(+), 23 deletions(-) diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 7a77d57065b8..a5d8dc3646ee 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -1097,10 +1097,9 @@ LogicalResult GatherOp::verify() { for (int dim = 0; dim < indicesTy.getRank(); ++dim) { if (dim == getAxis()) continue; - if (indicesTy.getShape()[dim] > srcTy.getShape()[dim]) { + if (indicesTy.getShape()[dim] != srcTy.getShape()[dim]) { return emitOpError("indices dimension ") - << dim - << " cannot be greater than the corresponding input dimension"; + << dim << " must match the corresponding input dimension"; } } diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index e270faeaa74b..e98b6d69fc28 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -6090,8 +6090,8 @@ def kernel(In, Out, # @pytest.mark.parametrize("src_shape, indices_shape, axis", [ - ([4, 4], [8, 2], 0), - ([128, 64], [256, 32], 0), + ([4, 4], [8, 4], 0), + ([128, 64], [256, 64], 0), ([128, 64], [128, 128], 1), ]) def test_gather(src_shape, indices_shape, axis): diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 390e2b5d4247..60890ac596eb 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1695,8 +1695,7 @@ def gather(src: tl.tensor, index: tl.tensor, axis: int, builder: ir.builder) -> for d in range(rank): if d == axis: continue - assert index.type.shape[d] <= src.type.shape[ - d], f"index dim {axis} cannot be greater than the corresponding source dim" + assert index.type.shape[d] == src.type.shape[d], f"index dim {axis} must match the corresponding source dim" gather = builder.create_gather(src.handle, index.handle, axis) return wrap_tensor(gather, src.type.scalar, index.type.shape) diff --git a/test/Conversion/allocate_shared_memory.mlir b/test/Conversion/allocate_shared_memory.mlir index f2d122283b38..6b378295a536 100644 --- a/test/Conversion/allocate_shared_memory.mlir +++ b/test/Conversion/allocate_shared_memory.mlir @@ -6,9 +6,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-LABEL: @gather_op // TODO(jeff): Optimize the lowering to reduce shared memory usage. -tt.func @gather_op(%arg0: tensor<1024x4xi32>, %arg1: tensor<128x256xf32>) { +tt.func @gather_op(%arg0: tensor<1024x256xi32>, %arg1: tensor<128x256xf32>) { // CHECK-NEXT: allocation.offset = 0 : i32 - %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<128x256xf32>, tensor<1024x4xi32>) -> tensor<1024x4xf32> + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<128x256xf32>, tensor<1024x256xi32>) -> tensor<1024x256xf32> tt.return } diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 253590828d1d..a6b9fb78c679 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -1905,7 +1905,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { -tt.func @gather_in_shared(%arg0: tensor<16x2xi32, #blocked1>, %arg1: tensor<8x4xf32, #blocked>) { +tt.func @gather_in_shared(%arg0: tensor<16x4xi32, #blocked1>, %arg1: tensor<8x4xf32, #blocked>) { // CHECK-LABEL: gather_in_shared // CHECK: [[S0:%.*]] = llvm.extractvalue %arg1[0] @@ -1923,7 +1923,7 @@ tt.func @gather_in_shared(%arg0: tensor<16x2xi32, #blocked1>, %arg1: tensor<8x4x // CHECK: insertvalue [[OUT0]], {{.*}}[0] - %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<8x4xf32, #blocked>, tensor<16x2xi32, #blocked1>) -> tensor<16x2xf32, #blocked1> + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<8x4xf32, #blocked>, tensor<16x4xi32, #blocked1>) -> tensor<16x4xf32, #blocked1> tt.return } @@ -1937,7 +1937,7 @@ tt.func @gather_in_shared(%arg0: tensor<16x2xi32, #blocked1>, %arg1: tensor<8x4x module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { -tt.func @gather_in_shared_dot_input(%arg0: tensor<16x2xi32, #blocked>, %arg1: tensor<8x4xf32, #dot>) { +tt.func @gather_in_shared_dot_input(%arg0: tensor<16x4xi32, #blocked>, %arg1: tensor<8x4xf32, #dot>) { // CHECK-LABEL: gather_in_shared_dot_input // CHECK: [[S0:%.*]] = llvm.extractvalue %arg1[0] @@ -1961,7 +1961,7 @@ tt.func @gather_in_shared_dot_input(%arg0: tensor<16x2xi32, #blocked>, %arg1: te // CHECK: insertvalue [[OUT0]], {{.*}}[0] - %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<8x4xf32, #dot>, tensor<16x2xi32, #blocked>) -> tensor<16x2xf32, #blocked> + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<8x4xf32, #dot>, tensor<16x4xi32, #blocked>) -> tensor<16x4xf32, #blocked> tt.return } diff --git a/test/Triton/invalid.mlir b/test/Triton/invalid.mlir index d58f95441a3b..d99c42d1ba72 100644 --- a/test/Triton/invalid.mlir +++ b/test/Triton/invalid.mlir @@ -392,7 +392,7 @@ tt.func @gather_op(%arg0: tensor<128xf32>, %arg1: tensor<512x4xi32>) { // ----- tt.func @gather_op(%arg0: tensor<128x16xf32>, %arg1: tensor<512x32xi32>) { - // expected-error @below {{indices dimension 1 cannot be greater than the corresponding input dimension}} + // expected-error @below {{indices dimension 1 must match the corresponding input dimension}} %0 = tt.gather %arg0[%arg1] {axis = 0 : i32} : (tensor<128x16xf32>, tensor<512x32xi32>) -> tensor<512x32xf32> tt.return } diff --git a/test/Triton/ops.mlir b/test/Triton/ops.mlir index 77847805bcc1..eb7a63c340a7 100644 --- a/test/Triton/ops.mlir +++ b/test/Triton/ops.mlir @@ -252,8 +252,8 @@ tt.func @experimental_descriptor_load(%0: !tt.tensordesc>) { } // CHECK-LABEL: @gather_op -tt.func @gather_op(%arg0: tensor<128x16xf32>, %arg1: tensor<512x4xi32>) -> tensor<512x4xf32> { - // CHECK-NEXT: %0 = tt.gather %arg0[%arg1] {axis = 0 : i32} : (tensor<128x16xf32>, tensor<512x4xi32>) -> tensor<512x4xf32> - %0 = tt.gather %arg0[%arg1] {axis = 0 : i32} : (tensor<128x16xf32>, tensor<512x4xi32>) -> tensor<512x4xf32> - tt.return %0 : tensor<512x4xf32> +tt.func @gather_op(%arg0: tensor<128x16xf32>, %arg1: tensor<512x16xi32>) -> tensor<512x16xf32> { + // CHECK-NEXT: %0 = tt.gather %arg0[%arg1] {axis = 0 : i32} : (tensor<128x16xf32>, tensor<512x16xi32>) -> tensor<512x16xf32> + %0 = tt.gather %arg0[%arg1] {axis = 0 : i32} : (tensor<128x16xf32>, tensor<512x16xi32>) -> tensor<512x16xf32> + tt.return %0 : tensor<512x16xf32> } diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index f01ef871ee71..ad6011faa9ab 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -2714,14 +2714,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // TODO(jeff): Support indices -> dst layout propagation to remove both // layout conversions here. -tt.func @propagate_layout_gather(%arg0: tensor<1024x4xi32, #blocked>, %arg1: tensor<128x256xf32, #blocked>) -> tensor<1024x4xf32, #blocked2> { +tt.func @propagate_layout_gather(%arg0: tensor<1024x256xi32, #blocked>, %arg1: tensor<128x256xf32, #blocked>) -> tensor<1024x256xf32, #blocked2> { // CHECK-LABEL: propagate_layout_gather // XCHECK-NOT: convert_layout - %0 = triton_gpu.convert_layout %arg0 : tensor<1024x4xi32, #blocked> -> tensor<1024x4xi32, #blocked1> - %1 = tt.gather %arg1[%0] {axis = 0 : i32} : (tensor<128x256xf32, #blocked>, tensor<1024x4xi32, #blocked1>) -> tensor<1024x4xf32, #blocked1> - %2 = triton_gpu.convert_layout %1 : tensor<1024x4xf32, #blocked1> -> tensor<1024x4xf32, #blocked2> - tt.return %2 : tensor<1024x4xf32, #blocked2> + %0 = triton_gpu.convert_layout %arg0 : tensor<1024x256xi32, #blocked> -> tensor<1024x256xi32, #blocked1> + %1 = tt.gather %arg1[%0] {axis = 0 : i32} : (tensor<128x256xf32, #blocked>, tensor<1024x256xi32, #blocked1>) -> tensor<1024x256xf32, #blocked1> + %2 = triton_gpu.convert_layout %1 : tensor<1024x256xf32, #blocked1> -> tensor<1024x256xf32, #blocked2> + tt.return %2 : tensor<1024x256xf32, #blocked2> } } From dad8ca5bfe037dafa7e64aeaff6716cfe3f33f49 Mon Sep 17 00:00:00 2001 From: Mogball Date: Wed, 27 Nov 2024 09:20:08 -0800 Subject: [PATCH 17/38] tol=0 --- python/test/unit/language/test_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index e98b6d69fc28..82d2b6b36465 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -6126,4 +6126,4 @@ def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor): indices = torch.randint(0, src.shape[axis], indices_shape, device='cuda') ref = torch.gather(src, axis, indices) result = triton_gather(src, axis, indices) - torch.testing.assert_close(result, ref) + torch.testing.assert_close(result, ref, rtol=0, atol=0) From e023fecfa69257303046c42e9ffb549e7a9147d6 Mon Sep 17 00:00:00 2001 From: Mogball Date: Wed, 27 Nov 2024 09:39:11 -0800 Subject: [PATCH 18/38] sublayoutIsZero --- lib/Analysis/Utility.cpp | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 37c3e4e09468..02f472909db1 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -412,8 +412,13 @@ GatherLoweringHelper::GatherLoweringHelper(triton::GatherOp gatherOp) : gatherOp(gatherOp) {} unsigned GatherLoweringHelper::getScratchSizeInBytes() { - // For now, lower the gather op by writing the source tensor to shared memory. - // TODO(jeff): Leverage locality to avoid using scratch space when possible. + // If the gather is warp-local, no scratch space is needed. + if (isWarpLocal()) + return 0; + + // Otherwise, performing the gather will require scratch space to communicate + // the source tensor across threads. For now, assume the whole source tensor + // is written back to shared memory. RankedTensorType srcType = gatherOp.getSrc().getType(); return product(srcType.getShape()) * ceil(srcType.getElementTypeBitWidth(), 8); @@ -431,8 +436,14 @@ bool GatherLoweringHelper::isWarpLocal() { return false; LinearLayout layout = std::move(*maybeLayout); - - return false; + // If the sublayout `(block, warp) -> dimN` is zero, then changing the warp or + // block does not alter how elements are mapped to `dimN`. + Builder b(gatherOp.getContext()); + StringAttr block = b.getStringAttr("block"); + StringAttr warp = b.getStringAttr("warp"); + StringAttr gatherDim = + b.getStringAttr("dim" + std::to_string(gatherOp.getAxis())); + return layout.sublayoutIsZero({block, warp}, gatherDim); } unsigned getNumScratchElements(ArrayRef shape) { From c0f3702aa6f567f7c4c155ded63d920592aa156b Mon Sep 17 00:00:00 2001 From: Mogball Date: Wed, 27 Nov 2024 10:14:58 -0800 Subject: [PATCH 19/38] sublayout check --- lib/Analysis/Utility.cpp | 33 +++++++++++++++---- .../TritonGPUToLLVM/GatherOpToLLVM.cpp | 27 ++++++++++++++- 2 files changed, 52 insertions(+), 8 deletions(-) diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 02f472909db1..3694698ffa8f 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -426,24 +426,43 @@ unsigned GatherLoweringHelper::getScratchSizeInBytes() { bool GatherLoweringHelper::isWarpLocal() { // The gather is warp-local if for each column along the gather axis in the - // source tensor, all the elements are owned by the same warp. + // source and index tensors, all the elements are owned by the same warp. RankedTensorType srcType = gatherOp.getSrc().getType(); - std::optional maybeLayout = + RankedTensorType idxType = gatherOp.getIndices().getType(); + std::optional srcLayout = toLinearLayout(srcType.getShape(), srcType.getEncoding()); + std::optional idxLayout = + toLinearLayout(idxType.getShape(), idxType.getEncoding()); + // FIXME: If an unsupported layout was encountered, assume the gather is not // warp-local. - if (!maybeLayout) + if (!srcLayout || !idxLayout) return false; - LinearLayout layout = std::move(*maybeLayout); - // If the sublayout `(block, warp) -> dimN` is zero, then changing the warp or - // block does not alter how elements are mapped to `dimN`. Builder b(gatherOp.getContext()); StringAttr block = b.getStringAttr("block"); StringAttr warp = b.getStringAttr("warp"); StringAttr gatherDim = b.getStringAttr("dim" + std::to_string(gatherOp.getAxis())); - return layout.sublayoutIsZero({block, warp}, gatherDim); + + // If the sublayout `(block, warp) -> dimN` is zero, then changing the warp or + // block does not alter how elements are mapped to `dimN`. + if (!srcLayout->sublayoutIsZero({block, warp}, gatherDim) || + !idxLayout->sublayoutIsZero({block, warp}, gatherDim)) + return false; + + // `dimN` is invariant to the warp, but the `(block, warp)` mapping to all + // other dimensions must be the same for both layouts. If so, then the warp + // that owns a particular index element also owns all the source elements it + // could index into. + SmallVector otherDims; + for (unsigned dim = 0, rank = srcType.getRank(); dim < rank; ++dim) { + if (dim != gatherOp.getAxis()){ + otherDims.push_back(b.getStringAttr("dim" + Twine(dim))); + } + } + return srcLayout->sublayout({block, warp}, otherDims) == + idxLayout->sublayout({block, warp}, otherDims); } unsigned getNumScratchElements(ArrayRef shape) { diff --git a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp index 5ab81eff819c..13bb1ed659e9 100644 --- a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp @@ -17,12 +17,33 @@ class GatherOpConversion : public ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) const override; private: + // Codegen the gather by storing the source tensor into shared memory and then + // gathering directly from shared memory. + void emitGatherInShared(GatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const; + // Codegen a warp-local gather by shuffling elements across the warp and + // selecting from them. + void emitWarpLocalGather(GatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const; + const TargetInfoBase &targetInfo; }; LogicalResult GatherOpConversion::matchAndRewrite(GatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + GatherLoweringHelper helper(op); + // Specialize the lowering based on the source layout. + if (helper.isWarpLocal()) { + emitWarpLocalGather(op, adaptor, rewriter); + } else { + emitGatherInShared(op, adaptor, rewriter); + } + return success(); +} + +void GatherOpConversion::emitGatherInShared( + GatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op.getLoc(); RankedTensorType srcType = op.getSrc().getType(); @@ -99,7 +120,11 @@ GatherOpConversion::matchAndRewrite(GatherOp op, OpAdaptor adaptor, Value packed = packLLElements(loc, getTypeConverter(), results, rewriter, dstType); rewriter.replaceOp(op, packed); - return success(); +} + +void GatherOpConversion::emitWarpLocalGather( + GatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + Location loc = op.getLoc(); } } // namespace From 4f53fd0295e4056cd5925c4043e82d8783b56929 Mon Sep 17 00:00:00 2001 From: Mogball Date: Wed, 27 Nov 2024 16:05:53 -0800 Subject: [PATCH 20/38] struggle --- bin/triton-tensor-layout.cpp | 25 ++++++++ .../TritonGPUToLLVM/GatherOpToLLVM.cpp | 59 +++++++++++++++++++ 2 files changed, 84 insertions(+) diff --git a/bin/triton-tensor-layout.cpp b/bin/triton-tensor-layout.cpp index 4087ac135022..9045824b2dfd 100644 --- a/bin/triton-tensor-layout.cpp +++ b/bin/triton-tensor-layout.cpp @@ -85,6 +85,31 @@ LogicalResult layoutPrint(RankedTensorType tensorType, raw_ostream &os) { // Dispatch to the corresponding dialect helper function to print the layout. if (dialectName == "triton_gpu") { os << triton::gpu::getLayoutStr(tensorType, UseHWPointOfView); + + auto ll = + *cast(tensorType.getEncoding()) + .toLinearLayout(tensorType.getShape()); + + llvm::errs() << ll.toString() << "\n"; + + MLIRContext *ctx = tensorType.getContext(); + auto kWarp = StringAttr::get(ctx, "warp"); + auto kBlock = StringAttr::get(ctx, "block"); + auto kLane = StringAttr::get(ctx, "lane"); + auto kRegister = StringAttr::get(ctx, "register"); + auto kGatherDim = StringAttr::get(ctx, "dim0"); + + auto threadLayout = ll.sublayout({kRegister}, kGatherDim); + + llvm::errs() << threadLayout.toString() << "\n"; + llvm::errs() << threadLayout.getInDimSize(kRegister) << "\n"; + for (unsigned i = 0; i < threadLayout.getInDimSize(kRegister); ++i) { + auto k = threadLayout.apply({ {kRegister, i}}); + for (auto [j, e] : llvm::enumerate(k)) { + llvm::errs() << i << " : " << j << " : " << e.first << " : " << e.second << "\n"; + } + } + return success(); } diff --git a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp index 13bb1ed659e9..da6c0870c6e2 100644 --- a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp @@ -1,8 +1,10 @@ #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" using namespace mlir; using namespace mlir::triton; +using namespace mlir::triton::gpu; namespace { class GatherOpConversion : public ConvertOpToLLVMPattern { @@ -124,7 +126,64 @@ void GatherOpConversion::emitGatherInShared( void GatherOpConversion::emitWarpLocalGather( GatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + MLIRContext *ctx = op.getContext(); Location loc = op.getLoc(); + RankedTensorType srcType = op.getSrc().getType(); + RankedTensorType idxType = op.getIndices().getType(); + + llvm::errs() << getLayoutStr(srcType, false) << "\n"; + llvm::errs() << getLayoutStr(idxType, false) << "\n"; + + StringAttr kLane = str_attr("lane"); + StringAttr kRegister = str_attr("register"); + StringAttr kGatherDim = rewriter.getStringAttr("dim" + Twine(op.getAxis())); + + SmallVector srcValues = + unpackLLElements(loc, adaptor.getSrc(), rewriter); + SmallVector idxValues = + unpackLLElements(loc, adaptor.getIndices(), rewriter); + + // For a warp-local gather, a couple things are true: + // - Each warp owns 2^N columns of the source tensor along the gather axis + // and the same columns of the index tensor, which may be shorter or longer + // than those in the source tensor. + // - Columns may be owned by multiple warps if the layout is oversubscribed. + // - In a particular column, each thread owns at least one element of the + // source tensor and at least one element of the index tensor. + + // Organize the source and index values into columns. + SmallVector otherDims; + for (unsigned dim = 0, rank = srcType.getRank(); dim < rank; ++dim) { + if (dim != op.getAxis()) { + otherDims.push_back(str_attr("dim" + Twine(dim))); + } + } + + LinearLayout srcLayout = + *toLinearLayout(srcType.getShape(), srcType.getEncoding()); + LinearLayout idxLayout = + *toLinearLayout(idxType.getShape(), idxType.getEncoding()); + + LinearLayout srcColLayout = srcLayout.sublayout({kRegister}, otherDims); + LinearLayout idxColLayout = idxLayout.sublayout({kRegister}, otherDims); + LinearLayout srcThreadLayout = srcLayout.sublayout({kRegister}, kGatherDim); + LinearLayout idxThreadLayout = idxLayout.sublayout({kRegister}, kGatherDim); + + // Sanity check the layouts. + assert(srcColLayout.getInDimSize(kRegister) == srcValues.size()); + assert(idxColLayout.getInDimSize(kRegister) == idxValues.size()); + assert(srcThreadLayout.getInDimSize(kRegister) == srcValues.size()); + assert(idxThreadLayout.getInDimSize(kRegister) == idxValues.size()); + + SmallVector> srcValuesCol, idxValuesCol; + for (auto [i, srcVal] : llvm::enumerate(srcValues)) { + SmallVector> colIdx = + srcColLayout.apply({{kRegister, i}}); + } + + SmallVector tmpResults(idxValues.size(), f32_val(0.0)); + rewriter.replaceOp(op, packLLElements(loc, getTypeConverter(), tmpResults, + rewriter, op.getType())); } } // namespace From 64630e647d901f5c839cfddc3c1fdb587f689927 Mon Sep 17 00:00:00 2001 From: Mogball Date: Wed, 27 Nov 2024 17:11:54 -0800 Subject: [PATCH 21/38] merge --- test/Conversion/allocate_shared_memory.mlir | 4 +-- test/Conversion/tritongpu_to_llvm.mlir | 18 +++++------ test/Triton/invalid.mlir | 6 ++-- test/TritonGPU/coalesce-async-copy.mlir | 36 ++++++++++----------- test/TritonGPU/combine.mlir | 12 +++---- 5 files changed, 38 insertions(+), 38 deletions(-) diff --git a/test/Conversion/allocate_shared_memory.mlir b/test/Conversion/allocate_shared_memory.mlir index 6b378295a536..345714f5b2b3 100644 --- a/test/Conversion/allocate_shared_memory.mlir +++ b/test/Conversion/allocate_shared_memory.mlir @@ -1,8 +1,8 @@ // RUN: triton-opt %s --allocate-shared-memory | FileCheck %s // CHECK-LABEL: module -// CHECK-SAME: triton_gpu.shared = 131072 : i32 -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +// CHECK-SAME: ttg.shared = 131072 : i32 +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: @gather_op // TODO(jeff): Optimize the lowering to reduce shared memory usage. diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 0a4f60cac5e4..17c17a0bee14 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -1901,8 +1901,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // ----- // CHECK: inline_asm_pack -#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { // check specifically for the case where asm has two results, pack > 1, and the result bitwidth is < 32 tt.func public @inline_asm_pack(%80: tensor<64x64xi8, #blocked>) attributes {noinline = false} { // CHECK: llvm.inline_asm asm_dialect {{.*}} (vector<4xi8>) -> !llvm.struct<(vector<2xbf16>, vector<2xbf16>, vector<2xbf16>, vector<2xbf16>)> @@ -1913,10 +1913,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func @gather_in_shared(%arg0: tensor<16x4xi32, #blocked1>, %arg1: tensor<8x4xf32, #blocked>) { // CHECK-LABEL: gather_in_shared @@ -1944,11 +1944,11 @@ tt.func @gather_in_shared(%arg0: tensor<16x4xi32, #blocked1>, %arg1: tensor<8x4x // ----- -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [1, 1]}> -#dot = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=1}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [1, 1]}> +#dot = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=1}> +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func @gather_in_shared_dot_input(%arg0: tensor<16x4xi32, #blocked>, %arg1: tensor<8x4xf32, #dot>) { // CHECK-LABEL: gather_in_shared_dot_input diff --git a/test/Triton/invalid.mlir b/test/Triton/invalid.mlir index 84036a0271c3..ce660d4228a7 100644 --- a/test/Triton/invalid.mlir +++ b/test/Triton/invalid.mlir @@ -363,9 +363,9 @@ tt.func @gather_op(%arg0: tensor<128x16xf32>, %arg1: tensor<512x4xi32>) { // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [0, 1]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [0, 1]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func @gather_op(%arg0: tensor<128x16xf32>, %arg1: tensor<512x4xi32, #blocked>) { // expected-error @below {{indices and output encodings must match}} %0 = tt.gather %arg0[%arg1] {axis = 0 : i32} : (tensor<128x16xf32>, tensor<512x4xi32, #blocked>) -> tensor<512x4xf32, #blocked1> diff --git a/test/TritonGPU/coalesce-async-copy.mlir b/test/TritonGPU/coalesce-async-copy.mlir index 4707ddaca9cb..0190238da135 100644 --- a/test/TritonGPU/coalesce-async-copy.mlir +++ b/test/TritonGPU/coalesce-async-copy.mlir @@ -1,35 +1,35 @@ // RUN: triton-opt %s -split-input-file -tritongpu-coalesce-async-copy | FileCheck %s -// CHECK: #[[NEW_BLOCKED:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> -// CHECK: %{{.*}} = triton_gpu.convert_layout %{{.*}} : {{.*}} -> tensor<64x16x!tt.ptr, #[[NEW_BLOCKED]]> -// CHECK: %{{.*}} = triton_gpu.convert_layout %{{.*}} : {{.*}} -> tensor<64x16xi1, #[[NEW_BLOCKED]]> -// CHECK: %{{.*}} = triton_gpu.convert_layout %{{.*}} : {{.*}} -> tensor<64x16xi8, #[[NEW_BLOCKED]]> -// CHECK: %{{.*}} = triton_gpu.async_copy_global_to_local %{{.*}}: tensor<64x16x!tt.ptr, #[[NEW_BLOCKED]]> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +// CHECK: #[[NEW_BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<64x16x!tt.ptr, #[[NEW_BLOCKED]]> +// CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<64x16xi1, #[[NEW_BLOCKED]]> +// CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<64x16xi8, #[[NEW_BLOCKED]]> +// CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<64x16x!tt.ptr, #[[NEW_BLOCKED]]> +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func @async_copy_i8(%input: tensor<64x16x!tt.ptr, #blocked>, - %view: !triton_gpu.memdesc<64x16xi8, #shared, #triton_gpu.shared_memory, mutable>, + %view: !ttg.memdesc<64x16xi8, #shared, #ttg.shared_memory, mutable>, %mask: tensor<64x16xi1, #blocked>, %other: tensor<64x16xi8, #blocked>) { - %token = triton_gpu.async_copy_global_to_local %input, %view mask %mask other %other: tensor<64x16x!tt.ptr, #blocked> -> <64x16xi8, #shared, #triton_gpu.shared_memory, mutable> + %token = ttg.async_copy_global_to_local %input, %view mask %mask other %other: tensor<64x16x!tt.ptr, #blocked> -> <64x16xi8, #shared, #ttg.shared_memory, mutable> tt.return } } // ----- -// CHECK: #[[NEW_BLOCKED:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> -// CHECK: %{{.*}} = triton_gpu.convert_layout %{{.*}} : {{.*}} -> tensor<64x16x!tt.ptr, #[[NEW_BLOCKED]]> -// CHECK: %{{.*}} = triton_gpu.async_copy_global_to_local %{{.*}}: tensor<64x16x!tt.ptr, #[[NEW_BLOCKED]]> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +// CHECK: #[[NEW_BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<64x16x!tt.ptr, #[[NEW_BLOCKED]]> +// CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<64x16x!tt.ptr, #[[NEW_BLOCKED]]> +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func @async_copy_i8_no_mask_or_other(%input: tensor<64x16x!tt.ptr, #blocked>, - %view: !triton_gpu.memdesc<64x16xi8, #shared, #triton_gpu.shared_memory, mutable>) { - %token = triton_gpu.async_copy_global_to_local %input, %view : tensor<64x16x!tt.ptr, #blocked> -> <64x16xi8, #shared, #triton_gpu.shared_memory, mutable> + %view: !ttg.memdesc<64x16xi8, #shared, #ttg.shared_memory, mutable>) { + %token = ttg.async_copy_global_to_local %input, %view : tensor<64x16x!tt.ptr, #blocked> -> <64x16xi8, #shared, #ttg.shared_memory, mutable> tt.return } } diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index 46c56b54361f..a980e19efe62 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -2706,11 +2706,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [2, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // TODO(jeff): Support indices -> dst layout propagation to remove both // layout conversions here. @@ -2718,9 +2718,9 @@ tt.func @propagate_layout_gather(%arg0: tensor<1024x256xi32, #blocked>, %arg1: t // CHECK-LABEL: propagate_layout_gather // XCHECK-NOT: convert_layout - %0 = triton_gpu.convert_layout %arg0 : tensor<1024x256xi32, #blocked> -> tensor<1024x256xi32, #blocked1> + %0 = ttg.convert_layout %arg0 : tensor<1024x256xi32, #blocked> -> tensor<1024x256xi32, #blocked1> %1 = tt.gather %arg1[%0] {axis = 0 : i32} : (tensor<128x256xf32, #blocked>, tensor<1024x256xi32, #blocked1>) -> tensor<1024x256xf32, #blocked1> - %2 = triton_gpu.convert_layout %1 : tensor<1024x256xf32, #blocked1> -> tensor<1024x256xf32, #blocked2> + %2 = ttg.convert_layout %1 : tensor<1024x256xf32, #blocked1> -> tensor<1024x256xf32, #blocked2> tt.return %2 : tensor<1024x256xf32, #blocked2> } From c8405252af618c72e3fd501812e90433b09c9b11 Mon Sep 17 00:00:00 2001 From: Mogball Date: Tue, 3 Dec 2024 14:30:14 -0800 Subject: [PATCH 22/38] restore layout helper and remove unused code --- bin/triton-tensor-layout.cpp | 25 ------------------- .../TritonGPU/IR/LinearLayoutConversions.cpp | 2 -- 2 files changed, 27 deletions(-) diff --git a/bin/triton-tensor-layout.cpp b/bin/triton-tensor-layout.cpp index ccbaa2fc9e43..7c635dafaa3d 100644 --- a/bin/triton-tensor-layout.cpp +++ b/bin/triton-tensor-layout.cpp @@ -85,31 +85,6 @@ LogicalResult layoutPrint(RankedTensorType tensorType, raw_ostream &os) { // Dispatch to the corresponding dialect helper function to print the layout. if (dialectName == "ttg") { os << triton::gpu::getLayoutStr(tensorType, UseHWPointOfView); - - auto ll = - *cast(tensorType.getEncoding()) - .toLinearLayout(tensorType.getShape()); - - llvm::errs() << ll.toString() << "\n"; - - MLIRContext *ctx = tensorType.getContext(); - auto kWarp = StringAttr::get(ctx, "warp"); - auto kBlock = StringAttr::get(ctx, "block"); - auto kLane = StringAttr::get(ctx, "lane"); - auto kRegister = StringAttr::get(ctx, "register"); - auto kGatherDim = StringAttr::get(ctx, "dim0"); - - auto threadLayout = ll.sublayout({kRegister}, kGatherDim); - - llvm::errs() << threadLayout.toString() << "\n"; - llvm::errs() << threadLayout.getInDimSize(kRegister) << "\n"; - for (unsigned i = 0; i < threadLayout.getInDimSize(kRegister); ++i) { - auto k = threadLayout.apply({ {kRegister, i}}); - for (auto [j, e] : llvm::enumerate(k)) { - llvm::errs() << i << " : " << j << " : " << e.first << " : " << e.second << "\n"; - } - } - return success(); } diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 32152190b6e6..e03fa9f7932b 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -525,9 +525,7 @@ std::optional BlockedEncodingAttr::toLinearLayout(ArrayRef shape) const { assert(shape.size() == getOrder().size()); - int rank = shape.size(); MLIRContext *ctx = getContext(); - SmallVector outDimNames = standardOutDimNames(ctx, rank); const auto &order = getOrder(); LinearLayout ctaLayout = From 6759cc7bef1111599c503a4b139586e01419fa57 Mon Sep 17 00:00:00 2001 From: Mogball Date: Wed, 4 Dec 2024 13:46:46 -0800 Subject: [PATCH 23/38] emitHardwareTuple --- .../Conversion/TritonGPUToLLVM/Utility.h | 6 ++++ lib/Conversion/TritonGPUToLLVM/Utility.cpp | 35 +++++++++++-------- 2 files changed, 27 insertions(+), 14 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 244bc6181aa1..a1c37efb52f1 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -1123,6 +1123,12 @@ emitBaseIndexForLayout(Location loc, RewriterBase &rewriter, return idx; } +// Emit code to compute the (blockId, warpId, laneId) for the current thread. +std::tuple +emitHardwareTuple(Location loc, RewriterBase &rewriter, + const TargetInfoBase &target, bool withCTAOffset, + unsigned threadsPerWarp); + // Emit indices calculation within each ConversionPattern, and returns a // [elemsPerThread X rank] index matrix. // diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 49f05a758e42..a310cdba5f4f 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -99,6 +99,20 @@ applyLinearLayout(Location loc, RewriterBase &rewriter, return outIndices; } +std::tuple emitHardwareTuple(Location loc, + RewriterBase &rewriter, + const TargetInfoBase &target, + bool withCTAOffset, + unsigned threadsPerWarpCst) { + Value threadId = getThreadId(rewriter, loc); + Value threadsPerWarp = i32_val(threadsPerWarpCst); + Value laneId = urem(threadId, threadsPerWarp); + Value warpId = udiv(threadId, threadsPerWarp); + Value blockId = + withCTAOffset ? target.getClusterCTAId(rewriter, loc) : i32_val(0); + return {blockId, warpId, laneId}; +} + SmallVector> emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, Attribute layout, RankedTensorType type, bool withCTAOffset) { @@ -116,12 +130,8 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, StringAttr kWarp = str_attr("warp"); StringAttr kBlock = str_attr("block"); - Value threadId = getThreadId(rewriter, loc); - Value threadsPerWarp = i32_val(ll->getInDimSize(kLane)); - Value laneId = urem(threadId, threadsPerWarp); - Value warpId = udiv(threadId, threadsPerWarp); - Value blockId = - withCTAOffset ? target.getClusterCTAId(rewriter, loc) : i32_val(0); + auto [blockId, warpId, laneId] = emitHardwareTuple( + loc, rewriter, target, withCTAOffset, ll->getInDimSize(kLane)); unsigned rank = shape.size(); SmallVector> ret; // Linear layout function is split in two parts below: @@ -214,10 +224,9 @@ bool emitTransferBetweenRegistersAndShared( std::min(regToSharedLayout->getNumConsecutiveInOut(), maxVecElems.value_or(std::numeric_limits::max())); - Value threadId = getThreadId(rewriter, loc); - Value threadsPerWarp = i32_val(regToSharedLayout->getInDimSize(kLane)); - Value laneId = urem(threadId, threadsPerWarp); - Value warpId = udiv(threadId, threadsPerWarp); + auto [blockId, warpId, laneId] = + emitHardwareTuple(loc, rewriter, target, /*withCTAOffset=*/false, + regToSharedLayout->getInDimSize(kLane)); int numElems = regToSharedLayout->getInDimSize(kRegister); auto vecTy = vec_ty(elemLlvmTy, vecElems); @@ -625,10 +634,8 @@ SmallVector getMultiDimOffset(Attribute layout, Location loc, auto instrShape = mmaLayout.getInstrShape(); SmallVector mmaColIdx(2); SmallVector mmaRowIdx(2); - Value threadId = getThreadId(rewriter, loc); - Value warpSize = i32_val(32); - Value laneId = urem(threadId, warpSize); - Value warpId = udiv(threadId, warpSize); + auto [blockId, warpId, laneId] = emitHardwareTuple( + loc, rewriter, targetInfo, /*withCTAOffset=*/false, 32); // TODO: fix the bug in MMAEncodingAttr document SmallVector multiDimWarpId(2); auto warpsPerCTA = mmaLayout.getWarpsPerCTA(); From 0628835177a9f797579e3e25ad2d9fe7025dc9c6 Mon Sep 17 00:00:00 2001 From: Mogball Date: Wed, 4 Dec 2024 13:46:56 -0800 Subject: [PATCH 24/38] almost there --- lib/Analysis/Utility.cpp | 68 +++++++++--- .../TritonGPUToLLVM/GatherOpToLLVM.cpp | 104 ++++++++++++------ test/Conversion/allocate_shared_memory.mlir | 6 +- 3 files changed, 131 insertions(+), 47 deletions(-) diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 3694698ffa8f..81ee197f7304 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -440,29 +440,71 @@ bool GatherLoweringHelper::isWarpLocal() { return false; Builder b(gatherOp.getContext()); - StringAttr block = b.getStringAttr("block"); - StringAttr warp = b.getStringAttr("warp"); - StringAttr gatherDim = + StringAttr kBlock = b.getStringAttr("block"); + StringAttr kWarp = b.getStringAttr("warp"); + StringAttr kLane = b.getStringAttr("lane"); + StringAttr kGatherDim = b.getStringAttr("dim" + std::to_string(gatherOp.getAxis())); - // If the sublayout `(block, warp) -> dimN` is zero, then changing the warp or - // block does not alter how elements are mapped to `dimN`. - if (!srcLayout->sublayoutIsZero({block, warp}, gatherDim) || - !idxLayout->sublayoutIsZero({block, warp}, gatherDim)) + // The tensor layouts must be distributed layouts, where the basis matrix is a + // subpermutation matrix plus some zero rows for broadcasting. + // FIXME(jeff): Check this invariant somehow. + // + // We want to know if all elements of a column along the gather axis are + // mapped to the same set of warps, which means the gather can be performed + // entirely within the warp. We need to query + // + // srcLayout.inverse().sublayoutIsZero({kGatherDim}, {kBlock, kWarp}) + // + // But due to broadcasting, the matrix might not be invertible. But since the + // matrix is a subpermutation matrix, we can instead query + // + // srcLayout.sublayoutIsZero({kBlock, kWarp}, {kGatherDim}) + // + // Which implies that changing the warp will not change the gather dimension. + // And since there is no swizzling, this applies to all warps. + if (!srcLayout->sublayoutIsZero({kBlock, kWarp}, kGatherDim) || + !idxLayout->sublayoutIsZero({kBlock, kWarp}, kGatherDim)) return false; + SmallVector otherDims; + for (unsigned dim = 0, rank = srcType.getRank(); dim < rank; ++dim) { + if (dim != gatherOp.getAxis()) { + otherDims.push_back(b.getStringAttr("dim" + Twine(dim))); + } + } + // `dimN` is invariant to the warp, but the `(block, warp)` mapping to all // other dimensions must be the same for both layouts. If so, then the warp // that owns a particular index element also owns all the source elements it // could index into. - SmallVector otherDims; - for (unsigned dim = 0, rank = srcType.getRank(); dim < rank; ++dim) { - if (dim != gatherOp.getAxis()){ - otherDims.push_back(b.getStringAttr("dim" + Twine(dim))); + if (srcLayout->sublayout({kBlock, kWarp}, otherDims) != + idxLayout->sublayout({kBlock, kWarp}, otherDims)) + return false; + + // The two constraints above ensure that data-movement to perform the gather + // operation are contained within a warp. The subsequent constraints simplify + // codegen. + + // Require that for any given gather column, the threads mapped to the column + // in the index and source tensors are the same. This means we don't need to + // xor shuffle across threads before emitting index shuffles; we push warp + // shuffling to layout conversions. + if (srcLayout->sublayout(kLane, otherDims) != + idxLayout->sublayout(kLane, otherDims)) + return false; + + // Broadcasted source layouts are not supported at the moment, because we + // rely on the source layout being invertible. + for (auto &bases : srcLayout->getBases()) { + auto isZero = [](ArrayRef base) { + return llvm::all_of(base, [](int32_t b) { return b == 0; }); + }; + if (llvm::any_of(bases.second, isZero)) { + return false; } } - return srcLayout->sublayout({block, warp}, otherDims) == - idxLayout->sublayout({block, warp}, otherDims); + return true; } unsigned getNumScratchElements(ArrayRef shape) { diff --git a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp index da6c0870c6e2..14302ad5ad48 100644 --- a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp @@ -124,6 +124,15 @@ void GatherOpConversion::emitGatherInShared( rewriter.replaceOp(op, packed); } +static LinearLayout +identityND(ArrayRef> dimsAndSizes) { + auto ret = LinearLayout::empty(); + for (auto [dim, size] : dimsAndSizes) { + ret *= LinearLayout::identity1D(size, dim, dim); + } + return ret; +} + void GatherOpConversion::emitWarpLocalGather( GatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { MLIRContext *ctx = op.getContext(); @@ -134,51 +143,82 @@ void GatherOpConversion::emitWarpLocalGather( llvm::errs() << getLayoutStr(srcType, false) << "\n"; llvm::errs() << getLayoutStr(idxType, false) << "\n"; + // Layout dimension names. + StringAttr kBlock = str_attr("block"); + StringAttr kWarp = str_attr("warp"); StringAttr kLane = str_attr("lane"); StringAttr kRegister = str_attr("register"); StringAttr kGatherDim = rewriter.getStringAttr("dim" + Twine(op.getAxis())); - - SmallVector srcValues = - unpackLLElements(loc, adaptor.getSrc(), rewriter); - SmallVector idxValues = - unpackLLElements(loc, adaptor.getIndices(), rewriter); - - // For a warp-local gather, a couple things are true: - // - Each warp owns 2^N columns of the source tensor along the gather axis - // and the same columns of the index tensor, which may be shorter or longer - // than those in the source tensor. - // - Columns may be owned by multiple warps if the layout is oversubscribed. - // - In a particular column, each thread owns at least one element of the - // source tensor and at least one element of the index tensor. - - // Organize the source and index values into columns. - SmallVector otherDims; + SmallVector allDims, otherDims; for (unsigned dim = 0, rank = srcType.getRank(); dim < rank; ++dim) { + allDims.push_back(str_attr("dim" + Twine(dim))); if (dim != op.getAxis()) { - otherDims.push_back(str_attr("dim" + Twine(dim))); + otherDims.push_back(allDims.back()); } } + // Compute the src and idx layouts. LinearLayout srcLayout = *toLinearLayout(srcType.getShape(), srcType.getEncoding()); LinearLayout idxLayout = *toLinearLayout(idxType.getShape(), idxType.getEncoding()); - LinearLayout srcColLayout = srcLayout.sublayout({kRegister}, otherDims); - LinearLayout idxColLayout = idxLayout.sublayout({kRegister}, otherDims); - LinearLayout srcThreadLayout = srcLayout.sublayout({kRegister}, kGatherDim); - LinearLayout idxThreadLayout = idxLayout.sublayout({kRegister}, kGatherDim); - - // Sanity check the layouts. - assert(srcColLayout.getInDimSize(kRegister) == srcValues.size()); - assert(idxColLayout.getInDimSize(kRegister) == idxValues.size()); - assert(srcThreadLayout.getInDimSize(kRegister) == srcValues.size()); - assert(idxThreadLayout.getInDimSize(kRegister) == idxValues.size()); - - SmallVector> srcValuesCol, idxValuesCol; - for (auto [i, srcVal] : llvm::enumerate(srcValues)) { - SmallVector> colIdx = - srcColLayout.apply({{kRegister, i}}); + // Let `ll_src` be the source layout and `ll_idx` be the index layout. + // Let `src_col` be a tuple of dimensions except the gather dimension, + // representing a specific column in the source tensor. Likewise for + // `idx_col`. Let `src_idx` be the index into gather dimension in the source + // tensor. + // + // `(src_lane, src_reg) = ll_src^-1(src_col, src_idx)`, where `src_lane` is + // the thread that contains the required element and `src_reg` is the register + // within that thread. + // + // Because `ll_src(block=0, warp=0, lane=0)[otherDims] == + // idx_src(0, 0, 0)[otherDims]`, we know given any `idx_reg` (element in the + // index tensor) the thread will need to read from the same column in the + // source tensor. + // + // Thus, we can obtain + // + // (src_lane, src_reg) = (ll_src^-1)( + // ll_idx(black, warp, lane, idx_reg)[otherDims], + // idxValues[idx_reg] + // ) + // + // And the mapping will be the correct for each thread. + // + // Given `src_reg \in [0, N)`, we just need to emit N index shuffles for each + // `idx_reg` (the number of index shuffles is quadratic!) and `arith.select` + // using `src_reg` to get the right one. + + // Fully invert the source layout. We know it is invertible because + // `isWarpLocal` checked this. + SmallVector> srcDimsAndSizes; + for (auto [i, size] : llvm::enumerate(srcType.getShape())) { + srcDimsAndSizes.push_back({str_attr("dim" + Twine(i)), size}); + } + auto srcShapeId = identityND(srcDimsAndSizes); + LinearLayout invSrcLayout = srcShapeId.invertAndCompose(srcLayout); + + // Sanity check: the warp must be invariant to the index because otherwise the + // gather would need to read across warps! + assert(invSrcLayout.sublayoutIsZero(kGatherDim, {kBlock, kWarp}) && + "expected a warp-local gather"); + + LinearLayout idxColLayout = + idxLayout.sublayout({kBlock, kWarp, kLane, kRegister}, otherDims); + + SmallVector srcValues = + unpackLLElements(loc, adaptor.getSrc(), rewriter); + SmallVector idxValues = + unpackLLElements(loc, adaptor.getIndices(), rewriter); + + Value threadId = getThreadId(rewriter, loc); + Value threadsPerWarp = i32_val(srcLayout.getInDimSize(kLane)); + assert(srcLayout.getInDimSize(kLane) == idxLayout.getInDimSize(kLane)); + Value laneId = urem(threadId, threadsPerWarp); + Value warpId = udiv(threadId, threadsPerWarp); + for (auto [idxReg, idxVal] : llvm::enumerate(idxValues)) { } SmallVector tmpResults(idxValues.size(), f32_val(0.0)); diff --git a/test/Conversion/allocate_shared_memory.mlir b/test/Conversion/allocate_shared_memory.mlir index 345714f5b2b3..f3c2ed703386 100644 --- a/test/Conversion/allocate_shared_memory.mlir +++ b/test/Conversion/allocate_shared_memory.mlir @@ -1,14 +1,16 @@ // RUN: triton-opt %s --allocate-shared-memory | FileCheck %s +#blocked = #ttg.blocked<{sizePerThread = [32, 32], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> + // CHECK-LABEL: module // CHECK-SAME: ttg.shared = 131072 : i32 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: @gather_op // TODO(jeff): Optimize the lowering to reduce shared memory usage. -tt.func @gather_op(%arg0: tensor<1024x256xi32>, %arg1: tensor<128x256xf32>) { +tt.func @gather_op(%arg0: tensor<1024x256xi32, #blocked>, %arg1: tensor<128x256xf32, #blocked>) { // CHECK-NEXT: allocation.offset = 0 : i32 - %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<128x256xf32>, tensor<1024x256xi32>) -> tensor<1024x256xf32> + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<128x256xf32, #blocked>, tensor<1024x256xi32, #blocked>) -> tensor<1024x256xf32, #blocked> tt.return } From 97a5b5d3a02b802e59f16fa5f6337b4763f94214 Mon Sep 17 00:00:00 2001 From: Mogball Date: Wed, 4 Dec 2024 13:46:46 -0800 Subject: [PATCH 25/38] emitHardwareTuple --- .../Conversion/TritonGPUToLLVM/Utility.h | 6 ++++ lib/Conversion/TritonGPUToLLVM/Utility.cpp | 35 +++++++++++-------- 2 files changed, 27 insertions(+), 14 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 244bc6181aa1..a1c37efb52f1 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -1123,6 +1123,12 @@ emitBaseIndexForLayout(Location loc, RewriterBase &rewriter, return idx; } +// Emit code to compute the (blockId, warpId, laneId) for the current thread. +std::tuple +emitHardwareTuple(Location loc, RewriterBase &rewriter, + const TargetInfoBase &target, bool withCTAOffset, + unsigned threadsPerWarp); + // Emit indices calculation within each ConversionPattern, and returns a // [elemsPerThread X rank] index matrix. // diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 49f05a758e42..a310cdba5f4f 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -99,6 +99,20 @@ applyLinearLayout(Location loc, RewriterBase &rewriter, return outIndices; } +std::tuple emitHardwareTuple(Location loc, + RewriterBase &rewriter, + const TargetInfoBase &target, + bool withCTAOffset, + unsigned threadsPerWarpCst) { + Value threadId = getThreadId(rewriter, loc); + Value threadsPerWarp = i32_val(threadsPerWarpCst); + Value laneId = urem(threadId, threadsPerWarp); + Value warpId = udiv(threadId, threadsPerWarp); + Value blockId = + withCTAOffset ? target.getClusterCTAId(rewriter, loc) : i32_val(0); + return {blockId, warpId, laneId}; +} + SmallVector> emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, Attribute layout, RankedTensorType type, bool withCTAOffset) { @@ -116,12 +130,8 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, StringAttr kWarp = str_attr("warp"); StringAttr kBlock = str_attr("block"); - Value threadId = getThreadId(rewriter, loc); - Value threadsPerWarp = i32_val(ll->getInDimSize(kLane)); - Value laneId = urem(threadId, threadsPerWarp); - Value warpId = udiv(threadId, threadsPerWarp); - Value blockId = - withCTAOffset ? target.getClusterCTAId(rewriter, loc) : i32_val(0); + auto [blockId, warpId, laneId] = emitHardwareTuple( + loc, rewriter, target, withCTAOffset, ll->getInDimSize(kLane)); unsigned rank = shape.size(); SmallVector> ret; // Linear layout function is split in two parts below: @@ -214,10 +224,9 @@ bool emitTransferBetweenRegistersAndShared( std::min(regToSharedLayout->getNumConsecutiveInOut(), maxVecElems.value_or(std::numeric_limits::max())); - Value threadId = getThreadId(rewriter, loc); - Value threadsPerWarp = i32_val(regToSharedLayout->getInDimSize(kLane)); - Value laneId = urem(threadId, threadsPerWarp); - Value warpId = udiv(threadId, threadsPerWarp); + auto [blockId, warpId, laneId] = + emitHardwareTuple(loc, rewriter, target, /*withCTAOffset=*/false, + regToSharedLayout->getInDimSize(kLane)); int numElems = regToSharedLayout->getInDimSize(kRegister); auto vecTy = vec_ty(elemLlvmTy, vecElems); @@ -625,10 +634,8 @@ SmallVector getMultiDimOffset(Attribute layout, Location loc, auto instrShape = mmaLayout.getInstrShape(); SmallVector mmaColIdx(2); SmallVector mmaRowIdx(2); - Value threadId = getThreadId(rewriter, loc); - Value warpSize = i32_val(32); - Value laneId = urem(threadId, warpSize); - Value warpId = udiv(threadId, warpSize); + auto [blockId, warpId, laneId] = emitHardwareTuple( + loc, rewriter, targetInfo, /*withCTAOffset=*/false, 32); // TODO: fix the bug in MMAEncodingAttr document SmallVector multiDimWarpId(2); auto warpsPerCTA = mmaLayout.getWarpsPerCTA(); From 50cb9b26a01fa1f5c63cd561dd8361c750710294 Mon Sep 17 00:00:00 2001 From: Mogball Date: Wed, 4 Dec 2024 14:09:50 -0800 Subject: [PATCH 26/38] unused code --- lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 32152190b6e6..9f6bc4d61f51 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -524,10 +524,7 @@ AMDWmmaEncodingAttr::toLinearLayout(ArrayRef shape) const { std::optional BlockedEncodingAttr::toLinearLayout(ArrayRef shape) const { assert(shape.size() == getOrder().size()); - - int rank = shape.size(); MLIRContext *ctx = getContext(); - SmallVector outDimNames = standardOutDimNames(ctx, rank); const auto &order = getOrder(); LinearLayout ctaLayout = From 5c4ea15885845c2fb27fd58f93c59f7a981c68f6 Mon Sep 17 00:00:00 2001 From: Mogball Date: Wed, 4 Dec 2024 15:00:19 -0800 Subject: [PATCH 27/38] maybe it's here --- include/triton/Tools/LinearLayout.h | 9 +++ .../TritonGPUToLLVM/GatherOpToLLVM.cpp | 64 +++++++++++-------- lib/Tools/LinearLayout.cpp | 16 +++++ 3 files changed, 61 insertions(+), 28 deletions(-) diff --git a/include/triton/Tools/LinearLayout.h b/include/triton/Tools/LinearLayout.h index cfc4c0d13bbe..a779e416235c 100644 --- a/include/triton/Tools/LinearLayout.h +++ b/include/triton/Tools/LinearLayout.h @@ -342,6 +342,11 @@ class LinearLayout { static LinearLayout identity1D(int32_t size, StringAttr inDim, StringAttr outDim); + // Creates an ND -> ND layout that's the identity function, i.e. + // L(x0, x1, ..., x(N-1)) = (x0, x1, ..., x(N-1)). + static LinearLayout + identityND(ArrayRef> dimsAndSizes); + // Creates a 1D -> 1D layout that maps every input value to 0, i.e. L(x) = 0 // for x in [0, size). static LinearLayout zeros1D(int32_t size, StringAttr inDim, @@ -673,6 +678,10 @@ class LinearLayout { // don't place any guarantees on the choices made by this function. [[nodiscard]] LinearLayout invertAndCompose(const LinearLayout &outer) const; + // Inverts or pseudo-inverts this layout. This computes a layout `L^-1` such + // that `L.compose(L^-1)` is the identity layout. + [[nodiscard]] LinearLayout invert() const; + // For each in-dim, returns a bitmask of the "free variables" in the layout // function. // diff --git a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp index 14302ad5ad48..18e9c67363e6 100644 --- a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp @@ -124,15 +124,6 @@ void GatherOpConversion::emitGatherInShared( rewriter.replaceOp(op, packed); } -static LinearLayout -identityND(ArrayRef> dimsAndSizes) { - auto ret = LinearLayout::empty(); - for (auto [dim, size] : dimsAndSizes) { - ret *= LinearLayout::identity1D(size, dim, dim); - } - return ret; -} - void GatherOpConversion::emitWarpLocalGather( GatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { MLIRContext *ctx = op.getContext(); @@ -140,9 +131,6 @@ void GatherOpConversion::emitWarpLocalGather( RankedTensorType srcType = op.getSrc().getType(); RankedTensorType idxType = op.getIndices().getType(); - llvm::errs() << getLayoutStr(srcType, false) << "\n"; - llvm::errs() << getLayoutStr(idxType, false) << "\n"; - // Layout dimension names. StringAttr kBlock = str_attr("block"); StringAttr kWarp = str_attr("warp"); @@ -183,27 +171,23 @@ void GatherOpConversion::emitWarpLocalGather( // (src_lane, src_reg) = (ll_src^-1)( // ll_idx(black, warp, lane, idx_reg)[otherDims], // idxValues[idx_reg] - // ) + // )[{"lane", "register"}] // // And the mapping will be the correct for each thread. // // Given `src_reg \in [0, N)`, we just need to emit N index shuffles for each - // `idx_reg` (the number of index shuffles is quadratic!) and `arith.select` + // `idx_reg` (the number of index shuffles is quadratic!) and `llvm.select` // using `src_reg` to get the right one. // Fully invert the source layout. We know it is invertible because - // `isWarpLocal` checked this. - SmallVector> srcDimsAndSizes; - for (auto [i, size] : llvm::enumerate(srcType.getShape())) { - srcDimsAndSizes.push_back({str_attr("dim" + Twine(i)), size}); - } - auto srcShapeId = identityND(srcDimsAndSizes); - LinearLayout invSrcLayout = srcShapeId.invertAndCompose(srcLayout); + // `isWarpLocal` checked this (subpermutation matrix, no broadcasting). + LinearLayout invSrcLayout = srcLayout.invert(); // Sanity check: the warp must be invariant to the index because otherwise the // gather would need to read across warps! assert(invSrcLayout.sublayoutIsZero(kGatherDim, {kBlock, kWarp}) && "expected a warp-local gather"); + invSrcLayout = invSrcLayout.sublayout(allDims, {kLane, kRegister}); LinearLayout idxColLayout = idxLayout.sublayout({kBlock, kWarp, kLane, kRegister}, otherDims); @@ -213,16 +197,40 @@ void GatherOpConversion::emitWarpLocalGather( SmallVector idxValues = unpackLLElements(loc, adaptor.getIndices(), rewriter); - Value threadId = getThreadId(rewriter, loc); - Value threadsPerWarp = i32_val(srcLayout.getInDimSize(kLane)); - assert(srcLayout.getInDimSize(kLane) == idxLayout.getInDimSize(kLane)); - Value laneId = urem(threadId, threadsPerWarp); - Value warpId = udiv(threadId, threadsPerWarp); + auto [blockId, warpId, laneId] = + emitHardwareTuple(loc, rewriter, targetInfo, /*withCTAOffset=*/true, + srcLayout.getInDimSize(kLane)); + + unsigned /*N=*/srcRegsPerThread = srcLayout.getInDimSize(kRegister); + assert(srcRegsPerThread == srcValues.size()); + SmallVector results; for (auto [idxReg, idxVal] : llvm::enumerate(idxValues)) { + SmallVector> column = + applyLinearLayout(loc, rewriter, idxColLayout, + {{kBlock, blockId}, + {kWarp, warpId}, + {kLane, laneId}, + {kRegister, i32_val(idxReg)}}); + assert(column.size() == otherDims.size()); + + column.emplace_back(kGatherDim, idxVal); + SmallVector> srcLaneAndReg = + applyLinearLayout(loc, rewriter, invSrcLayout, column); + + auto [srcLaneName, srcLane] = srcLaneAndReg.back(); + auto [srcRegName, srcReg] = srcLaneAndReg.front(); + assert(srcLaneName == kLane && srcRegName == kRegister); + + assert(!srcValues.empty() && "can't gather from an empty tensor"); + Value result = undef(srcValues.front().getType()); + for (unsigned i = 0; i != srcRegsPerThread; ++i) { + Value value = targetInfo.shuffleIdx(rewriter, loc, srcValues[i], srcLane); + result = select(icmp_eq(i32_val(i), srcReg), value, result); + } + results.push_back(result); } - SmallVector tmpResults(idxValues.size(), f32_val(0.0)); - rewriter.replaceOp(op, packLLElements(loc, getTypeConverter(), tmpResults, + rewriter.replaceOp(op, packLLElements(loc, getTypeConverter(), results, rewriter, op.getType())); } diff --git a/lib/Tools/LinearLayout.cpp b/lib/Tools/LinearLayout.cpp index 3a81231ac863..c7b1df3abc45 100644 --- a/lib/Tools/LinearLayout.cpp +++ b/lib/Tools/LinearLayout.cpp @@ -364,6 +364,15 @@ LinearLayout::LinearLayout( return LinearLayout({{inDimName, std::move(powersOf2)}}, {outDimName}); } +/*static*/ LinearLayout LinearLayout::identityND( + ArrayRef> dimsAndSizes) { + auto ret = LinearLayout::empty(); + for (auto [dim, size] : dimsAndSizes) { + ret *= LinearLayout::identity1D(size, dim, dim); + } + return ret; +} + /*static*/ LinearLayout LinearLayout::zeros1D(int32_t size, StringAttr inDimName, StringAttr outDimName) { @@ -918,6 +927,13 @@ LinearLayout LinearLayout::invertAndCompose(const LinearLayout &outer) const { return flatComposed.reshapeIns(retInDims).reshapeOuts(retOutDims); } +LinearLayout LinearLayout::invert() const { + SmallVector> dimsAndSizes; + llvm::append_range(dimsAndSizes, outDims); + auto id = identityND(dimsAndSizes); + return id.invertAndCompose(*this); +} + llvm::MapVector LinearLayout::getFreeVariableMasks() const { std::unique_ptr mat = getMatrix(*this); From 3d8401f8a13c6e2474afd90a1abb513d995c81b4 Mon Sep 17 00:00:00 2001 From: Mogball Date: Wed, 4 Dec 2024 15:49:14 -0800 Subject: [PATCH 28/38] add some unit tests --- .../TritonGPUToLLVM/GatherOpToLLVM.cpp | 1 + test/Conversion/gather_to_llvm.mlir | 133 ++++++++++++++++++ 2 files changed, 134 insertions(+) create mode 100644 test/Conversion/gather_to_llvm.mlir diff --git a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp index 18e9c67363e6..3bbd47582edc 100644 --- a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp @@ -213,6 +213,7 @@ void GatherOpConversion::emitWarpLocalGather( {kRegister, i32_val(idxReg)}}); assert(column.size() == otherDims.size()); + // Combine the computed column with the data-dependent gather index. column.emplace_back(kGatherDim, idxVal); SmallVector> srcLaneAndReg = applyLinearLayout(loc, rewriter, invSrcLayout, column); diff --git a/test/Conversion/gather_to_llvm.mlir b/test/Conversion/gather_to_llvm.mlir new file mode 100644 index 000000000000..76232bd3a18d --- /dev/null +++ b/test/Conversion/gather_to_llvm.mlir @@ -0,0 +1,133 @@ +// RUN: triton-opt %s --allocate-shared-memory --convert-triton-gpu-to-llvm --convert-nv-gpu-to-llvm | mlir-translate -mlir-to-llvmir | opt -S -O3 | FileCheck %s + +#trivial_layout = #ttg.linear<{register = [], lane = [[1], [2], [4], [8], [16]], warp = [], block = []}> + +#trivial_layout_wider = #ttg.linear<{register = [[32]], lane = [[1], [2], [4], [8], [16]], warp = [], block = []}> + +#trivial_layout_wider_reg_stride_1 = #ttg.linear<{register = [[1]], lane = [[2], [4], [8], [16], [32]], warp = [], block = []}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { + +// Each source element is mapped to a single thread, so we expect one index shuffle. +// CHECK-LABEL: @gather_warp_local_trivial +tt.func private @gather_warp_local_trivial(%arg0: tensor<32xi32, #trivial_layout>, %arg1: tensor<32xf32, #trivial_layout>) -> tensor<32xf32, #trivial_layout> { + // CHECK-NEXT: [[SRC:%.*]] = extractvalue { float } %1, 0 + // CHECK-NEXT: [[IDX:%.*]] = extractvalue { i32 } %0, 0 + + // CHECK-NEXT: [[LANEID:%.*]] = and i32 [[IDX]], 31 + + // CHECK-NEXT: [[VALUE:%.*]] = bitcast float [[SRC]] to i32 + // CHECK-NEXT: [[RES_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID]], i32 31) + // CHECK-NEXT: [[RES:%.*]] = bitcast i32 [[RES_i32]] to float + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<32xf32, #trivial_layout>, tensor<32xi32, #trivial_layout>) -> tensor<32xf32, #trivial_layout> + + // CHECK-NEXT: ret float [[RES]] + tt.return %0 : tensor<32xf32, #trivial_layout> +} + +// Same as above, but there are two index elements per thread. Expect 2 index shuffles +// with the results packed together. +// CHECK-LABEL: @gather_warp_local_larger_output +tt.func private @gather_warp_local_larger_output(%arg0: tensor<64xi32, #trivial_layout_wider>, %arg1: tensor<32xf32, #trivial_layout>) -> tensor<64xf32, #trivial_layout_wider> { + // CHECK-NEXT: [[SRC:%.*]] = extractvalue { float } %1, 0 + // CHECK-NEXT: [[IDX0:%.*]] = extractvalue { i32, i32 } %0, 0 + // CHECK-NEXT: [[IDX1:%.*]] = extractvalue { i32, i32 } %0, 1 + + // CHECK-NEXT: [[LANEID0:%.*]] = and i32 [[IDX0]], 31 + + // CHECK-NEXT: [[VALUE:%.*]] = bitcast float [[SRC]] to i32 + // CHECK-NEXT: [[RES0_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID0]], i32 31) + // CHECK-NEXT: [[RES0:%.*]] = bitcast i32 [[RES0_i32]] to float + + // CHECK-NEXT: [[LANEID1:%.*]] = and i32 [[IDX1]], 31 + // CHECK-NEXT: [[RES1_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID1]], i32 31) + // CHECK-NEXT: [[RES1:%.*]] = bitcast i32 [[RES1_i32]] to float + + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<32xf32, #trivial_layout>, tensor<64xi32, #trivial_layout_wider>) -> tensor<64xf32, #trivial_layout_wider> + + // CHECK-NEXT: [[PACKED0:%.*]] = insertvalue { float, float } undef, float [[RES0]], 0 + // CHECK-NEXT: [[PACKED1:%.*]] = insertvalue { float, float } [[PACKED0]], float [[RES1]], 1 + // CHECK-NEXT: ret { float, float } [[PACKED1]] + tt.return %0 : tensor<64xf32, #trivial_layout_wider> +} + +// Each thread has 2 elements of the source tensor, strided 32 apart, so we +// expect two index shuffles, using the MSB to select between the two. +// CHECK-LABEL: @gather_warp_local_larger_input +tt.func private @gather_warp_local_larger_input(%arg0: tensor<32xi32, #trivial_layout>, %arg1: tensor<64xf32, #trivial_layout_wider>) -> tensor<32xf32, #trivial_layout> { + // CHECK-NEXT: [[SRC0:%.*]] = extractvalue { float, float } %1, 0 + // CHECK-NEXT: [[SRC1:%.*]] = extractvalue { float, float } %1, 1 + // CHECK-NEXT: [[IDX:%.*]] = extractvalue { i32 } %0, 0 + + // CHECK-NEXT: [[LANEID:%.*]] = and i32 [[IDX]], 31 + // CHECK-NEXT: [[REGID:%.*]] = and i32 [[IDX]], 32 + + // CHECK-NEXT: [[VALUE:%.*]] = bitcast float [[SRC0]] to i32 + // CHECK-NEXT: [[RES0:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID]], i32 31) + + // CHECK-NEXT: [[VALUE:%.*]] = bitcast float [[SRC1]] to i32 + // CHECK-NEXT: [[RES1:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID]], i32 31) + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<64xf32, #trivial_layout_wider>, tensor<32xi32, #trivial_layout>) -> tensor<32xf32, #trivial_layout> + + // CHECK-NEXT: [[PICK:%.*]] = icmp eq i32 [[REGID]], 0 + // CHECK-NEXT: [[RES_i32:%.*]] = select i1 [[PICK]], i32 [[RES0]], i32 [[RES1]] + // CHECK-NEXT: [[RES:%.*]] = bitcast i32 [[RES_i32]] to float + + // CHECK-NEXT: ret float [[RES]] + tt.return %0 : tensor<32xf32, #trivial_layout> +} + +// Same as above, except the RegID comes from the LSB. +// CHECK-LABEL: @gather_warp_local_larger_input +tt.func private @gather_warp_local_larger_input_stride_1(%arg0: tensor<32xi32, #trivial_layout>, %arg1: tensor<64xf32, #trivial_layout_wider_reg_stride_1>) -> tensor<32xf32, #trivial_layout> { + // CHECK-NEXT: [[SRC0:%.*]] = extractvalue { float, float } %1, 0 + // CHECK-NEXT: [[SRC1:%.*]] = extractvalue { float, float } %1, 1 + // CHECK-NEXT: [[IDX:%.*]] = extractvalue { i32 } %0, 0 + + // CHECK-NEXT: [[REGID:%.*]] = and i32 [[IDX]], 1 + // CHECK-NEXT: [[TMP:%.*]] = lshr i32 [[IDX]], 1 + // CHECK-NEXT: [[LANEID:%.*]] = and i32 [[TMP]], 31 + + // CHECK-NEXT: [[VALUE:%.*]] = bitcast float [[SRC0]] to i32 + // CHECK-NEXT: [[RES0:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID]], i32 31) + + // CHECK-NEXT: [[VALUE:%.*]] = bitcast float [[SRC1]] to i32 + // CHECK-NEXT: [[RES1:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID]], i32 31) + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<64xf32, #trivial_layout_wider_reg_stride_1>, tensor<32xi32, #trivial_layout>) -> tensor<32xf32, #trivial_layout> + + // CHECK-NEXT: [[PICK:%.*]] = icmp eq i32 [[REGID]], 0 + // CHECK-NEXT: [[RES_i32:%.*]] = select i1 [[PICK]], i32 [[RES0]], i32 [[RES1]] + // CHECK-NEXT: [[RES:%.*]] = bitcast i32 [[RES_i32]] to float + + // CHECK-NEXT: ret float [[RES]] + tt.return %0 : tensor<32xf32, #trivial_layout> +} + +// Keep LLVM from DCE'ing the above functions. Use volatile stores to stop LLVM +// from removing unused function results. +tt.func @anchor(%ptr: !llvm.ptr, + %arg0: tensor<32xi32, #trivial_layout>, + %arg1: tensor<32xf32, #trivial_layout>, + %arg2: tensor<64xi32, #trivial_layout_wider>, + %arg3: tensor<64xf32, #trivial_layout_wider>, + %arg4: tensor<64xf32, #trivial_layout_wider_reg_stride_1>) { + %0 = tt.call @gather_warp_local_trivial(%arg0, %arg1) : (tensor<32xi32, #trivial_layout>, tensor<32xf32, #trivial_layout>) -> tensor<32xf32, #trivial_layout> + %1 = builtin.unrealized_conversion_cast %0 : tensor<32xf32, #trivial_layout> to !llvm.struct<(f32)> + llvm.store volatile %1, %ptr : !llvm.struct<(f32)>, !llvm.ptr + + %2 = tt.call @gather_warp_local_larger_output(%arg2, %arg1) : (tensor<64xi32, #trivial_layout_wider>, tensor<32xf32, #trivial_layout>) -> tensor<64xf32, #trivial_layout_wider> + %3 = builtin.unrealized_conversion_cast %2 : tensor<64xf32, #trivial_layout_wider> to !llvm.struct<(f32, f32)> + llvm.store volatile %3, %ptr : !llvm.struct<(f32, f32)>, !llvm.ptr + + %4 = tt.call @gather_warp_local_larger_input(%arg0, %arg3) : (tensor<32xi32, #trivial_layout>, tensor<64xf32, #trivial_layout_wider>) -> tensor<32xf32, #trivial_layout> + %5 = builtin.unrealized_conversion_cast %4 : tensor<32xf32, #trivial_layout> to !llvm.struct<(f32)> + llvm.store volatile %5, %ptr : !llvm.struct<(f32)>, !llvm.ptr + + %6 = tt.call @gather_warp_local_larger_input_stride_1(%arg0, %arg4) : (tensor<32xi32, #trivial_layout>, tensor<64xf32, #trivial_layout_wider_reg_stride_1>) -> tensor<32xf32, #trivial_layout> + %7 = builtin.unrealized_conversion_cast %6 : tensor<32xf32, #trivial_layout> to !llvm.struct<(f32)> + llvm.store volatile %7, %ptr : !llvm.struct<(f32)>, !llvm.ptr + + tt.return +} + +} From db051eb1b5835bb0e2bea34130f67ae0d78d1688 Mon Sep 17 00:00:00 2001 From: Mogball Date: Wed, 4 Dec 2024 16:58:29 -0800 Subject: [PATCH 29/38] don't redundantly shuffle index --- .../TritonGPUToLLVM/GatherOpToLLVM.cpp | 68 +++++++++++++++++-- test/Conversion/gather_to_llvm.mlir | 17 ++++- 2 files changed, 77 insertions(+), 8 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp index 3bbd47582edc..684d0f1b39bd 100644 --- a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp @@ -175,9 +175,10 @@ void GatherOpConversion::emitWarpLocalGather( // // And the mapping will be the correct for each thread. // - // Given `src_reg \in [0, N)`, we just need to emit N index shuffles for each - // `idx_reg` (the number of index shuffles is quadratic!) and `llvm.select` - // using `src_reg` to get the right one. + // Given `src_reg \in [0, K*N)`, we just need to emit N index shuffles for + // each `idx_reg` (the number of index shuffles is quadratic!) and + // `llvm.select` using `src_reg` to get the right one. `K` is the number of + // elements per column owned by a thread. // Fully invert the source layout. We know it is invertible because // `isWarpLocal` checked this (subpermutation matrix, no broadcasting). @@ -203,6 +204,47 @@ void GatherOpConversion::emitWarpLocalGather( unsigned /*N=*/srcRegsPerThread = srcLayout.getInDimSize(kRegister); assert(srcRegsPerThread == srcValues.size()); + + // Given a index value, we need to know which sources register values it could + // index into. This is invariant to anything other than the register, which we + // checked already. Compute the full reverse map from + // + // idx_reg -> gather_column -> (src_reg0, src_reg1, ...) + // + LinearLayout invertSrcRegMap = invSrcLayout.sublayout(allDims, {kRegister}); + // Remove zero bases in the gather dimension to make the function injective + // (for a given column) over the same codomain. + LinearLayout::BasesT newInvertRegMapBases; + for (auto &[inDim, inDimBases] : invertSrcRegMap.getBases()) { + auto &newInDimBases = newInvertRegMapBases[inDim]; + if (inDim != kGatherDim) { + newInDimBases = inDimBases; + continue; + } + for (auto &basis : inDimBases) { + if (llvm::any_of(basis, [](int32_t val) { return val != 0; })) { + newInDimBases.push_back(basis); + } + } + } + invertSrcRegMap = LinearLayout( + newInvertRegMapBases, llvm::to_vector(invertSrcRegMap.getOutDimNames())); + // We are left with only non-zero bases in the gather dimension, which means + // the number of registers per column is the size of the "gather dimension". + unsigned numRegsPerColumn = invertSrcRegMap.getInDimSize(kGatherDim); + // Get a map from idx_reg to the column it indexes into. + LinearLayout idxRegToCol = idxLayout.sublayout({kRegister}, otherDims); + // Now given `idx_reg`, we can compute the column it belongs to in both src + // and index tensors, then partially apply `invertSrcRegMap` with this to + // obtain a function that outputs the corresponding registers in the src + // tensor in the same column. + + // L(column, i) = L(column, 0) xor L(0, i) + LinearLayout invertSrcRegMapColPart = + invertSrcRegMap.sublayout(otherDims, {kRegister}); + LinearLayout invertSrcRegMapRest = + invertSrcRegMap.sublayout({kGatherDim}, {kRegister}); + SmallVector results; for (auto [idxReg, idxVal] : llvm::enumerate(idxValues)) { SmallVector> column = @@ -223,11 +265,25 @@ void GatherOpConversion::emitWarpLocalGather( assert(srcLaneName == kLane && srcRegName == kRegister); assert(!srcValues.empty() && "can't gather from an empty tensor"); + + // Figure out which src registers we need to index shuffle from. This is + // invariant to anything else. + SmallVector> normalizedColumn = + idxRegToCol.apply({{kRegister, idxReg}}); + int32_t srcBase = + invertSrcRegMapColPart.apply(normalizedColumn).front().second; + Value result = undef(srcValues.front().getType()); - for (unsigned i = 0; i != srcRegsPerThread; ++i) { - Value value = targetInfo.shuffleIdx(rewriter, loc, srcValues[i], srcLane); - result = select(icmp_eq(i32_val(i), srcReg), value, result); + for (unsigned i = 0; i != numRegsPerColumn; ++i) { + int32_t rest = + invertSrcRegMapRest.apply({{kGatherDim, i}}).front().second; + int32_t srcRegIdx = srcBase ^ rest; + + Value value = + targetInfo.shuffleIdx(rewriter, loc, srcValues[srcRegIdx], srcLane); + result = select(icmp_eq(i32_val(srcRegIdx), srcReg), value, result); } + results.push_back(result); } diff --git a/test/Conversion/gather_to_llvm.mlir b/test/Conversion/gather_to_llvm.mlir index 76232bd3a18d..0c78f7150b76 100644 --- a/test/Conversion/gather_to_llvm.mlir +++ b/test/Conversion/gather_to_llvm.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s --allocate-shared-memory --convert-triton-gpu-to-llvm --convert-nv-gpu-to-llvm | mlir-translate -mlir-to-llvmir | opt -S -O3 | FileCheck %s +// RUN: triton-opt %s --allocate-shared-memory --convert-triton-gpu-to-llvm --convert-nv-gpu-to-llvm | mlir-translate -mlir-to-llvmir | opt -S -O1 | FileCheck %s #trivial_layout = #ttg.linear<{register = [], lane = [[1], [2], [4], [8], [16]], warp = [], block = []}> @@ -6,6 +6,8 @@ #trivial_layout_wider_reg_stride_1 = #ttg.linear<{register = [[1]], lane = [[2], [4], [8], [16], [32]], warp = [], block = []}> +#trivial_2d_one_col = #ttg.linear<{register = [[0, 1]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [], block = []}> + module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { // Each source element is mapped to a single thread, so we expect one index shuffle. @@ -103,6 +105,11 @@ tt.func private @gather_warp_local_larger_input_stride_1(%arg0: tensor<32xi32, # tt.return %0 : tensor<32xf32, #trivial_layout> } +tt.func private @gather_2d_trivial(%arg0: tensor<32x2xi32, #trivial_2d_one_col>, %arg1: tensor<32x2xf32, #trivial_2d_one_col>) -> tensor<32x2xf32, #trivial_2d_one_col> { + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<32x2xf32, #trivial_2d_one_col>, tensor<32x2xi32, #trivial_2d_one_col>) -> tensor<32x2xf32, #trivial_2d_one_col> + tt.return %0 : tensor<32x2xf32, #trivial_2d_one_col> +} + // Keep LLVM from DCE'ing the above functions. Use volatile stores to stop LLVM // from removing unused function results. tt.func @anchor(%ptr: !llvm.ptr, @@ -110,7 +117,9 @@ tt.func @anchor(%ptr: !llvm.ptr, %arg1: tensor<32xf32, #trivial_layout>, %arg2: tensor<64xi32, #trivial_layout_wider>, %arg3: tensor<64xf32, #trivial_layout_wider>, - %arg4: tensor<64xf32, #trivial_layout_wider_reg_stride_1>) { + %arg4: tensor<64xf32, #trivial_layout_wider_reg_stride_1>, + %arg5: tensor<32x2xi32, #trivial_2d_one_col>, + %arg6: tensor<32x2xf32, #trivial_2d_one_col>) { %0 = tt.call @gather_warp_local_trivial(%arg0, %arg1) : (tensor<32xi32, #trivial_layout>, tensor<32xf32, #trivial_layout>) -> tensor<32xf32, #trivial_layout> %1 = builtin.unrealized_conversion_cast %0 : tensor<32xf32, #trivial_layout> to !llvm.struct<(f32)> llvm.store volatile %1, %ptr : !llvm.struct<(f32)>, !llvm.ptr @@ -127,6 +136,10 @@ tt.func @anchor(%ptr: !llvm.ptr, %7 = builtin.unrealized_conversion_cast %6 : tensor<32xf32, #trivial_layout> to !llvm.struct<(f32)> llvm.store volatile %7, %ptr : !llvm.struct<(f32)>, !llvm.ptr + %8 = tt.call @gather_2d_trivial(%arg5, %arg6) : (tensor<32x2xi32, #trivial_2d_one_col>, tensor<32x2xf32, #trivial_2d_one_col>) -> tensor<32x2xf32, #trivial_2d_one_col> + %9 = builtin.unrealized_conversion_cast %8 : tensor<32x2xf32, #trivial_2d_one_col> to !llvm.struct<(f32, f32)> + llvm.store volatile %9, %ptr : !llvm.struct<(f32, f32)>, !llvm.ptr + tt.return } From 4c240d2ed72f4b00341ebdd18e528b3abb84c5cc Mon Sep 17 00:00:00 2001 From: Mogball Date: Wed, 4 Dec 2024 17:01:15 -0800 Subject: [PATCH 30/38] test for trivial 2d --- test/Conversion/gather_to_llvm.mlir | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/test/Conversion/gather_to_llvm.mlir b/test/Conversion/gather_to_llvm.mlir index 0c78f7150b76..d0c874597be6 100644 --- a/test/Conversion/gather_to_llvm.mlir +++ b/test/Conversion/gather_to_llvm.mlir @@ -105,8 +105,30 @@ tt.func private @gather_warp_local_larger_input_stride_1(%arg0: tensor<32xi32, # tt.return %0 : tensor<32xf32, #trivial_layout> } +// Each thread has 1 element in 2 gather columns, so this is the same as the +// trivial case except now it's 2D. We expect 2 independent index shuffles. +// CHECK-LABEL: @gather_2d_trivial tt.func private @gather_2d_trivial(%arg0: tensor<32x2xi32, #trivial_2d_one_col>, %arg1: tensor<32x2xf32, #trivial_2d_one_col>) -> tensor<32x2xf32, #trivial_2d_one_col> { + // CHECK-NEXT: [[SRC0:%.*]] = extractvalue { float, float } %1, 0 + // CHECK-NEXT: [[SRC1:%.*]] = extractvalue { float, float } %1, 1 + // CHECK-NEXT: [[IDX0:%.*]] = extractvalue { i32, i32 } %0, 0 + // CHECK-NEXT: [[IDX1:%.*]] = extractvalue { i32, i32 } %0, 1 + + // CHECK-NEXT: [[LANEID0:%.*]] = and i32 [[IDX0]], 31 + // CHECK-NEXT: [[VALUE0:%.*]] = bitcast float [[SRC0]] to i32 + // CHECK-NEXT: [[RES0_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]], i32 [[LANEID0]], i32 31) + // CHECK-NEXT: [[RES0:%.*]] = bitcast i32 [[RES0_i32]] to float + + // CHECK-NEXT: [[LANEID1:%.*]] = and i32 [[IDX1]], 31 + // CHECK-NEXT: [[VALUE1:%.*]] = bitcast float [[SRC1]] to i32 + // CHECK-NEXT: [[RES1_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]], i32 [[LANEID1]], i32 31) + // CHECK-NEXT: [[RES1:%.*]] = bitcast i32 [[RES1_i32]] to float + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<32x2xf32, #trivial_2d_one_col>, tensor<32x2xi32, #trivial_2d_one_col>) -> tensor<32x2xf32, #trivial_2d_one_col> + + // CHECK-NEXT: [[PACKED0:%.*]] = insertvalue { float, float } undef, float [[RES0]], 0 + // CHECK-NEXT: [[PACKED1:%.*]] = insertvalue { float, float } [[PACKED0]], float [[RES1]], 1 + // CHECK-NEXT: ret { float, float } [[PACKED1]] tt.return %0 : tensor<32x2xf32, #trivial_2d_one_col> } From 19f2d77ef47e37b44826e5f148341577470f9a5b Mon Sep 17 00:00:00 2001 From: Mogball Date: Wed, 4 Dec 2024 21:38:46 -0800 Subject: [PATCH 31/38] more complex 2d test --- test/Conversion/gather_to_llvm.mlir | 64 ++++++++++++++++++++++++++++- 1 file changed, 63 insertions(+), 1 deletion(-) diff --git a/test/Conversion/gather_to_llvm.mlir b/test/Conversion/gather_to_llvm.mlir index d0c874597be6..14133230c170 100644 --- a/test/Conversion/gather_to_llvm.mlir +++ b/test/Conversion/gather_to_llvm.mlir @@ -8,6 +8,8 @@ #trivial_2d_one_col = #ttg.linear<{register = [[0, 1]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [], block = []}> +#span_2d_cols = #ttg.linear<{register = [[1, 0]], lane = [[2, 0], [4, 0], [8, 0], [16, 0], [0, 1]], warp = [], block = []}> + module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { // Each source element is mapped to a single thread, so we expect one index shuffle. @@ -132,6 +134,59 @@ tt.func private @gather_2d_trivial(%arg0: tensor<32x2xi32, #trivial_2d_one_col>, tt.return %0 : tensor<32x2xf32, #trivial_2d_one_col> } +// The single warp is split into two columns. Each column has half contiguous +// threads, each with 2 contiguous elements. Expect 4 index shuffles: two per +// column. Thus, the index should be dependent on the thread id, since the +// register alone is not enough to determine the column. +// CHECK-LABEL: @gather_2d_span_2 +tt.func private @gather_2d_span_2(%arg0: tensor<32x2xi32, #span_2d_cols>, %arg1: tensor<32x2xf32, #span_2d_cols>) -> tensor<32x2xf32, #span_2d_cols> { + // CHECK-NEXT: [[SRC0:%.*]] = extractvalue { float, float } %1, 0 + // CHECK-NEXT: [[SRC1:%.*]] = extractvalue { float, float } %1, 1 + // CHECK-NEXT: [[IDX0:%.*]] = extractvalue { i32, i32 } %0, 0 + // CHECK-NEXT: [[IDX1:%.*]] = extractvalue { i32, i32 } %0, 1 + + // This uses tid to select between the two columns: + // CHECK-NEXT: [[TID:%.*]] = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x() + // CHECK-NEXT: [[COL:%.*]] = and i32 [[TID]], 16 + + // Break the index into reg and thread (within column) components: + // CHECK-NEXT: [[REGID0:%.*]] = and i32 [[IDX0]], 1 + // CHECK-NEXT: [[TMP:%.*]] = lshr i32 [[IDX0]], 1 + // CHECK-NEXT: [[LANEID0:%.*]] = and i32 [[TMP]], 15 + + // CHECK-NEXT: [[SHUFFLE_IDX:%.*]] = or disjoint i32 [[LANEID0]], [[COL]] + + // CHECK-NEXT: [[VALUE0:%.*]] = bitcast float [[SRC0]] to i32 + // CHECK-NEXT: [[SRES0:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]], i32 [[SHUFFLE_IDX]], i32 31) + // CHECK-NEXT: [[VALUE1:%.*]] = bitcast float [[SRC1]] to i32 + // CHECK-NEXT: [[SRES1:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]], i32 [[SHUFFLE_IDX]], i32 31) + + // Use the reg id to select between the two results: + // CHECK-NEXT: [[PICK0:%.*]] = icmp eq i32 [[REGID0]], 0 + // CHECK-NEXT: [[RES0_i32:%.*]] = select i1 [[PICK0]], i32 [[SRES0]], i32 [[SRES1]] + // CHECK-NEXT: [[RES0:%.*]] = bitcast i32 [[RES0_i32]] to float + + // CHECK-NEXT: [[REGID1:%.*]] = and i32 [[IDX1]], 1 + // CHECK-NEXT: [[TMP:%.*]] = lshr i32 [[IDX1]], 1 + // CHECK-NEXT: [[LANEID1:%.*]] = and i32 [[TMP]], 15 + + // CHECK-NEXT: [[SHUFFLE_IDX:%.*]] = or disjoint i32 [[LANEID1]], [[COL]] + + // CHECK-NEXT: [[SRES0:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]], i32 [[SHUFFLE_IDX]], i32 31) + // CHECK-NEXT: [[SRES1:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]], i32 [[SHUFFLE_IDX]], i32 31) + + // CHECK-NEXT: [[PICK0:%.*]] = icmp eq i32 [[REGID1]], 0 + // CHECK-NEXT: [[RES1_i32:%.*]] = select i1 [[PICK0]], i32 [[SRES0]], i32 [[SRES1]] + // CHECK-NEXT: [[RES1:%.*]] = bitcast i32 [[RES1_i32]] to float + + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<32x2xf32, #span_2d_cols>, tensor<32x2xi32, #span_2d_cols>) -> tensor<32x2xf32, #span_2d_cols> + + // CHECK-NEXT: [[PACKED0:%.*]] = insertvalue { float, float } undef, float [[RES0]], 0 + // CHECK-NEXT: [[PACKED1:%.*]] = insertvalue { float, float } [[PACKED0]], float [[RES1]], 1 + // CHECK-NEXT: ret { float, float } [[PACKED1]] + tt.return %0 : tensor<32x2xf32, #span_2d_cols> +} + // Keep LLVM from DCE'ing the above functions. Use volatile stores to stop LLVM // from removing unused function results. tt.func @anchor(%ptr: !llvm.ptr, @@ -141,7 +196,10 @@ tt.func @anchor(%ptr: !llvm.ptr, %arg3: tensor<64xf32, #trivial_layout_wider>, %arg4: tensor<64xf32, #trivial_layout_wider_reg_stride_1>, %arg5: tensor<32x2xi32, #trivial_2d_one_col>, - %arg6: tensor<32x2xf32, #trivial_2d_one_col>) { + %arg6: tensor<32x2xf32, #trivial_2d_one_col>, + %arg7: tensor<32x2xi32, #span_2d_cols>, + %arg8: tensor<32x2xf32, #span_2d_cols>) { + %0 = tt.call @gather_warp_local_trivial(%arg0, %arg1) : (tensor<32xi32, #trivial_layout>, tensor<32xf32, #trivial_layout>) -> tensor<32xf32, #trivial_layout> %1 = builtin.unrealized_conversion_cast %0 : tensor<32xf32, #trivial_layout> to !llvm.struct<(f32)> llvm.store volatile %1, %ptr : !llvm.struct<(f32)>, !llvm.ptr @@ -162,6 +220,10 @@ tt.func @anchor(%ptr: !llvm.ptr, %9 = builtin.unrealized_conversion_cast %8 : tensor<32x2xf32, #trivial_2d_one_col> to !llvm.struct<(f32, f32)> llvm.store volatile %9, %ptr : !llvm.struct<(f32, f32)>, !llvm.ptr + %10 = tt.call @gather_2d_span_2(%arg7, %arg8) : (tensor<32x2xi32, #span_2d_cols>, tensor<32x2xf32, #span_2d_cols>) -> tensor<32x2xf32, #span_2d_cols> + %11 = builtin.unrealized_conversion_cast %10 : tensor<32x2xf32, #span_2d_cols> to !llvm.struct<(f32, f32)> + llvm.store volatile %11, %ptr : !llvm.struct<(f32, f32)>, !llvm.ptr + tt.return } From fbadf01b5dc557bec83485847a3675a5d26a5785 Mon Sep 17 00:00:00 2001 From: Mogball Date: Wed, 4 Dec 2024 23:43:55 -0800 Subject: [PATCH 32/38] trying to add integration test --- python/test/unit/language/test_core.py | 3 ++ test/Conversion/gather_to_llvm.mlir | 43 +++++++++++++++++++++++++- 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 2daa8aaf07d6..ed36382e8d6e 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -6218,3 +6218,6 @@ def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor): ref = torch.gather(src, axis, indices) result = triton_gather(src, axis, indices) torch.testing.assert_close(result, ref, rtol=0, atol=0) + +def test_gather_complex_layouts(): + pass diff --git a/test/Conversion/gather_to_llvm.mlir b/test/Conversion/gather_to_llvm.mlir index 14133230c170..28a8a7e6b246 100644 --- a/test/Conversion/gather_to_llvm.mlir +++ b/test/Conversion/gather_to_llvm.mlir @@ -1,5 +1,8 @@ // RUN: triton-opt %s --allocate-shared-memory --convert-triton-gpu-to-llvm --convert-nv-gpu-to-llvm | mlir-translate -mlir-to-llvmir | opt -S -O1 | FileCheck %s +// Check the optimized LLVMIR, since InstCombine makes the linear layout +// logic understandable enough (in simple cases) to check correctness by eye. + #trivial_layout = #ttg.linear<{register = [], lane = [[1], [2], [4], [8], [16]], warp = [], block = []}> #trivial_layout_wider = #ttg.linear<{register = [[32]], lane = [[1], [2], [4], [8], [16]], warp = [], block = []}> @@ -10,6 +13,9 @@ #span_2d_cols = #ttg.linear<{register = [[1, 0]], lane = [[2, 0], [4, 0], [8, 0], [16, 0], [0, 1]], warp = [], block = []}> +#crazy_2d_src = #ttg.linear<{register = [[0, 2], [2, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}> +#crazy_2d_idx = #ttg.linear<{register = [[2, 0], [0, 2]], lane = [[0, 8], [16, 0], [1, 0], [8, 0], [4, 0]], warp = [[0, 1], [0, 4]], block = []}> + module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { // Each source element is mapped to a single thread, so we expect one index shuffle. @@ -187,6 +193,35 @@ tt.func private @gather_2d_span_2(%arg0: tensor<32x2xi32, #span_2d_cols>, %arg1: tt.return %0 : tensor<32x2xf32, #span_2d_cols> } +// CHECK-LABEL: @gather_2d_crazy +tt.func private @gather_2d_crazy(%arg0: tensor<32x16xi32, #crazy_2d_idx>, %arg1: tensor<32x16xf32, #crazy_2d_src>) -> tensor<32x16xf32, #crazy_2d_idx> { + // The specific logic becomes hard to grasp here. Just check the shuffles. + + // CHECK-NEXT: [[SRC0:%.*]] = extractvalue { float, float, float, float } %1, 0 + // CHECK-NEXT: [[SRC1:%.*]] = extractvalue { float, float, float, float } %1, 1 + // CHECK-NEXT: [[SRC2:%.*]] = extractvalue { float, float, float, float } %1, 2 + // CHECK-NEXT: [[SRC3:%.*]] = extractvalue { float, float, float, float } %1, 3 + + // CHECK: [[VALUE0:%.*]] = bitcast float [[SRC0]] to i32 + // CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]], + // CHECK-NEXT: [[VALUE2:%.*]] = bitcast float [[SRC2]] to i32 + // CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE2]], + + // CHECK: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]], + // CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE2]], + + // CHECK: [[VALUE1:%.*]] = bitcast float [[SRC1]] to i32 + // CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]], + // CHECK-NEXT: [[VALUE3:%.*]] = bitcast float [[SRC3]] to i32 + // CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE3]], + + // CHECK: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]], + // CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE3]], + + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<32x16xf32, #crazy_2d_src>, tensor<32x16xi32, #crazy_2d_idx>) -> tensor<32x16xf32, #crazy_2d_idx> + tt.return %0 : tensor<32x16xf32, #crazy_2d_idx> +} + // Keep LLVM from DCE'ing the above functions. Use volatile stores to stop LLVM // from removing unused function results. tt.func @anchor(%ptr: !llvm.ptr, @@ -198,7 +233,9 @@ tt.func @anchor(%ptr: !llvm.ptr, %arg5: tensor<32x2xi32, #trivial_2d_one_col>, %arg6: tensor<32x2xf32, #trivial_2d_one_col>, %arg7: tensor<32x2xi32, #span_2d_cols>, - %arg8: tensor<32x2xf32, #span_2d_cols>) { + %arg8: tensor<32x2xf32, #span_2d_cols>, + %arg9: tensor<32x16xi32, #crazy_2d_idx>, + %arg10: tensor<32x16xf32, #crazy_2d_src>) { %0 = tt.call @gather_warp_local_trivial(%arg0, %arg1) : (tensor<32xi32, #trivial_layout>, tensor<32xf32, #trivial_layout>) -> tensor<32xf32, #trivial_layout> %1 = builtin.unrealized_conversion_cast %0 : tensor<32xf32, #trivial_layout> to !llvm.struct<(f32)> @@ -224,6 +261,10 @@ tt.func @anchor(%ptr: !llvm.ptr, %11 = builtin.unrealized_conversion_cast %10 : tensor<32x2xf32, #span_2d_cols> to !llvm.struct<(f32, f32)> llvm.store volatile %11, %ptr : !llvm.struct<(f32, f32)>, !llvm.ptr + %12 = tt.call @gather_2d_crazy(%arg9, %arg10) : (tensor<32x16xi32, #crazy_2d_idx>, tensor<32x16xf32, #crazy_2d_src>) -> tensor<32x16xf32, #crazy_2d_idx> + %13 = builtin.unrealized_conversion_cast %12 : tensor<32x16xf32, #crazy_2d_idx> to !llvm.struct<(f32, f32, f32, f32)> + llvm.store volatile %13, %ptr : !llvm.struct<(f32, f32, f32, f32)>, !llvm.ptr + tt.return } From a66c7240ae8a19b437778fc4dc781e925589f263 Mon Sep 17 00:00:00 2001 From: Mogball Date: Thu, 5 Dec 2024 13:09:29 -0800 Subject: [PATCH 33/38] add integration tests --- .../TritonGPUToLLVM/GatherOpToLLVM.cpp | 27 +++-- python/test/unit/language/test_core.py | 113 ++++++++++++++---- 2 files changed, 107 insertions(+), 33 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp index 684d0f1b39bd..678f888addc9 100644 --- a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp @@ -44,6 +44,20 @@ GatherOpConversion::matchAndRewrite(GatherOp op, OpAdaptor adaptor, return success(); } +static Value convertIndexToI32(Location loc, Value index, + ConversionPatternRewriter &rewriter) { + unsigned idxWidth = index.getType().getIntOrFloatBitWidth(); + // The LL index computations are performed with 32 bit integers. If the + // indices are something else, cast them to i32. + if (idxWidth > 32) { + index = trunc(i32_ty, index); + } else if (idxWidth < 32) { + // Negative indices don't make sense, so zero-extend. + index = zext(i32_ty, index); + } + return index; +} + void GatherOpConversion::emitGatherInShared( GatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op.getLoc(); @@ -101,19 +115,10 @@ void GatherOpConversion::emitGatherInShared( emitIndices(loc, rewriter, targetInfo, dstType.getEncoding(), dstType, /*withCTAOffset=*/true); - unsigned idxWidth = op.getIndices().getType().getElementTypeBitWidth(); unsigned axis = op.getAxis(); SmallVector results(dstIndices.size()); for (auto [i, idx, indices] : llvm::enumerate(idxValues, dstIndices)) { - // The LL index computations are performed with 32 bit integers. If the - // indices are something else, cast them to i32. - if (idxWidth > 32) { - idx = trunc(i32_ty, idx); - } else if (idxWidth < 32) { - // Negative indices don't make sense, so zero-extend. - idx = zext(i32_ty, idx); - } - indices[axis] = idx; + indices[axis] = convertIndexToI32(loc, idx, rewriter); Value offset = LLVM::linearize(rewriter, loc, indices, srcShapePerCTA); Value ptr = gep(smemBase.getType(), elemType, smemBase, offset); results[i] = load(elemType, ptr); @@ -256,7 +261,7 @@ void GatherOpConversion::emitWarpLocalGather( assert(column.size() == otherDims.size()); // Combine the computed column with the data-dependent gather index. - column.emplace_back(kGatherDim, idxVal); + column.emplace_back(kGatherDim, convertIndexToI32(loc, idxVal, rewriter)); SmallVector> srcLaneAndReg = applyLinearLayout(loc, rewriter, invSrcLayout, column); diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index ed36382e8d6e..4b5d3294ef41 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -6180,6 +6180,24 @@ def kernel(In, Out, # assert torch.all(ref == result) +@triton.jit +def gather_test_kernel(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, src_dim0: tl.constexpr, src_dim1: tl.constexpr, + src_stride0: tl.constexpr, src_stride1: tl.constexpr, idx_dim0: tl.constexpr, + idx_dim1: tl.constexpr, idx_stride0: tl.constexpr, idx_stride1: tl.constexpr, + out_dim0: tl.constexpr, out_dim1: tl.constexpr, out_stride0: tl.constexpr, + out_stride1: tl.constexpr): + src_offs = (tl.arange(0, src_dim0)[:, None] * src_stride0 + tl.arange(0, src_dim1)[None, :] * src_stride1) + src = tl.load(src_ptr + src_offs) + + idx_offs = (tl.arange(0, idx_dim0)[:, None] * idx_stride0 + tl.arange(0, idx_dim1)[None, :] * idx_stride1) + idx = tl.load(idx_ptr + idx_offs) + + out = tl.gather(src, idx, axis) + + out_offs = (tl.arange(0, out_dim0)[:, None] * out_stride0 + tl.arange(0, out_dim1)[None, :] * out_stride1) + tl.store(out_ptr + out_offs, out) + + @pytest.mark.parametrize("src_shape, indices_shape, axis", [ ([4, 4], [8, 4], 0), ([128, 64], [256, 64], 0), @@ -6187,29 +6205,13 @@ def kernel(In, Out, # ]) def test_gather(src_shape, indices_shape, axis): - @triton.jit - def gather_kernel(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, src_dim0: tl.constexpr, src_dim1: tl.constexpr, - src_stride0: tl.constexpr, src_stride1: tl.constexpr, idx_dim0: tl.constexpr, - idx_dim1: tl.constexpr, idx_stride0: tl.constexpr, idx_stride1: tl.constexpr, - out_dim0: tl.constexpr, out_dim1: tl.constexpr, out_stride0: tl.constexpr, - out_stride1: tl.constexpr): - src_offs = (tl.arange(0, src_dim0)[:, None] * src_stride0 + tl.arange(0, src_dim1)[None, :] * src_stride1) - src = tl.load(src_ptr + src_offs) - - idx_offs = (tl.arange(0, idx_dim0)[:, None] * idx_stride0 + tl.arange(0, idx_dim1)[None, :] * idx_stride1) - idx = tl.load(idx_ptr + idx_offs) - - out = tl.gather(src, idx, axis) - - out_offs = (tl.arange(0, out_dim0)[:, None] * out_stride0 + tl.arange(0, out_dim1)[None, :] * out_stride1) - tl.store(out_ptr + out_offs, out) - def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor): output = torch.empty(indices.shape, dtype=src.dtype, device=src.device) - gather_kernel[(1, )](src, indices, output, axis, src.shape[0], src.shape[1], - src.stride(0), src.stride(1), indices.shape[0], indices.shape[1], indices.stride(0), - indices.stride(1), output.shape[0], output.shape[1], output.stride(0), output.stride(1)) + gather_test_kernel[(1, )](src, indices, output, axis, src.shape[0], + src.shape[1], src.stride(0), src.stride(1), indices.shape[0], indices.shape[1], + indices.stride(0), indices.stride(1), output.shape[0], output.shape[1], + output.stride(0), output.stride(1)) return output @@ -6219,5 +6221,72 @@ def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor): result = triton_gather(src, axis, indices) torch.testing.assert_close(result, ref, rtol=0, atol=0) -def test_gather_complex_layouts(): - pass + +@pytest.mark.parametrize("src_shape, indices_shape, axis, src_layout, indices_layout", [ + ([32, 16], [32, 16], 0, + "linear<{register = [[0, 2], [2, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}>", + "linear<{register = [[2, 0], [0, 2]], lane = [[0, 8], [16, 0], [1, 0], [8, 0], [4, 0]], warp = [[0, 1], [0, 4]], block = []}>" + ), + ([128, 64], [256, 64], 0, + "linear<{register = [[0, 2], [32, 0], [2, 0], [0, 16], [0, 32], [64, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}>", + "linear<{register = [[0, 2], [32, 0], [0, 32], [2, 0], [0, 16], [64, 0], [128, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}>" + ), +]) +def test_gather_complex_layouts(src_shape, indices_shape, axis, src_layout, indices_layout, tmp_path: pathlib.Path): + + def prepare_kernel(src: torch.Tensor, axis: int, indices: torch.Tensor): + output = torch.empty(indices.shape, dtype=src.dtype, device=src.device) + compiled = gather_test_kernel.warmup(src, indices, output, axis, src.shape[0], src.shape[1], src.stride(0), + src.stride(1), indices.shape[0], indices.shape[1], indices.stride(0), + indices.stride(1), output.shape[0], output.shape[1], output.stride(0), + output.stride(1), grid=(1, )) + return output, compiled + + def inject_layout(ir, src: torch.Tensor, axis, indices: torch.Tensor, src_layout, idx_layout): + ir = f""" +#src_layout = #ttg.{src_layout} +#idx_layout = #ttg.{idx_layout} +{ir}""" + + dtypes = {torch.int32: "i32", torch.float32: "f32", torch.int64: "i64", torch.float64: "f64"} + + src_spec = f"{src.shape[0]}x{src.shape[1]}x{dtypes[src.dtype]}" + indices_spec = f"{indices.shape[0]}x{indices.shape[1]}x{dtypes[indices.dtype]}" + output_spec = f"{indices.shape[0]}x{indices.shape[1]}x{dtypes[src.dtype]}" + + pat = r"(%[0-9]+) = tt.gather (%[0-9]+)\[(%[0-9]+)\] {axis = " + pat += str(axis) + pat += r" : i32} : \(tensor\<" + pat += src_spec + pat += r", (#[a-z]+[0-9]+)\>, tensor\<" + pat += indices_spec + pat += r", (#[a-z]+[0-9]+)\>\) -> tensor\<" + pat += output_spec + pat += r", (#[a-z]+[0-9]+)\>" + + repl = r""" + %src = ttg.convert_layout \2 : tensor<""" + src_spec + r""", \4> -> tensor<""" + src_spec + r""", #src_layout> + %idx = ttg.convert_layout \3 : tensor<""" + indices_spec + r""", \5> -> tensor<""" + indices_spec + r""", #idx_layout> + %out = tt.gather %src[%idx] {axis = """ + str( + axis + ) + r""" : i32} : (tensor<""" + src_spec + r""", #src_layout>, tensor<""" + indices_spec + r""", #idx_layout>) -> tensor<""" + output_spec + r""", #idx_layout> + \1 = ttg.convert_layout %out : tensor<""" + output_spec + r""", #idx_layout> -> tensor<""" + output_spec + r""", \6>""" + return re.sub(pat, repl, ir) + + src = torch.randn(src_shape, device='cuda') + indices = torch.randint(0, src.shape[axis], indices_shape, device='cuda') + ref = torch.gather(src, axis, indices) + + output, compiled = prepare_kernel(src, axis, indices) + ir = compiled.asm["ttgir"] + ir = inject_layout(ir, src, axis, indices, src_layout, indices_layout) + + temp_file = tmp_path / "test_warp_gather.ttgir" + temp_file.write_text(ir) + + kernel = triton.compile(str(temp_file)) + assert "llvm.nvvm.shfl.sync.idx" in kernel.asm["llir"] + + kernel[(1, 1, 1)](src, indices, output) + + torch.testing.assert_close(output, ref, rtol=0, atol=0) From e2d4c1b37e7c6971926f754aabc9fb12e2518662 Mon Sep 17 00:00:00 2001 From: Mogball Date: Thu, 5 Dec 2024 13:53:50 -0800 Subject: [PATCH 34/38] add algo description --- .../TritonGPUToLLVM/GatherOpToLLVM.cpp | 51 +++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp index 678f888addc9..c78337038a5f 100644 --- a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp @@ -129,6 +129,57 @@ void GatherOpConversion::emitGatherInShared( rewriter.replaceOp(op, packed); } +// High-level description of the algorithm: +// +// `isWarpLocal` checks that it is possible to compute each output element +// without data movement across warps. +// +// If the gather dim is `dimN`, then this means +// +// ll^-1(dimN)[(block, warp)] == 0 +// +// for both source and index tensors: moving along the gather axis does not +// change the warp. Broadcasted layouts are not supported, so we know the +// layouts are subpermutation matrices. +// +// We can check this with `ll((block, warp))[dimN] == 0`. +// +// Let `gatherCol` be a tuple of all dimensions except the gather dimension. +// We also check that the gather columns line up the same way with respect to +// the warp between the source and index tensors with +// +// ll_src((block, warp))[gatherCol] == ll_idx((block, warp))[gatherCol] +// +// This means that for all index columns, the corresponding column in the source +// tensor is owned by the same warp. +// +// We also check +// +// ll_src(lane)[gatherCol] == ll_idx(lane)[gatherCol] +// +// This boils down to the fact that the algorithm essentially emits a series of +// index shuffles for each index value owned by each thread, and then a pile of +// selects to pick the right value. We need to figure out given an index value +// in a particular column, what are the source register values it could read +// from and who owns them. +// +// If this relationship did not hold, then the possible source registers for +// each index value varies with the thread, meaning the value operand provided +// to each shuffle index instruction would depend on the thread ID. This isn't a +// big deal. It just means would have to emit a pile of selects before each +// shuffle as well, to pick the right source register value. But we choose not +// to handle this. +// +// The codegen algorithm emits code: +// - Given the thread ID and a particular index tensor register, figure out +// which gather column it belongs to using a layout. +// - Using the index value itself as the value for `dimN`, use another layout to +// figure out which lane in the warp owns the desired value and which register +// in that lane it is. +// - For the gather column, figure out the source registers in that column, and +// for each of them, emit an index shuffle with the same computed lane ID. +// - Use the register component to select the right value from the shuffle +// results. void GatherOpConversion::emitWarpLocalGather( GatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { MLIRContext *ctx = op.getContext(); From 89a1970830566369c24f5d8ff677e3a31cc41110 Mon Sep 17 00:00:00 2001 From: Mogball Date: Thu, 5 Dec 2024 13:56:53 -0800 Subject: [PATCH 35/38] fix test on AMD --- python/test/unit/language/test_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 4b5d3294ef41..ac75e24ab713 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -6285,7 +6285,7 @@ def inject_layout(ir, src: torch.Tensor, axis, indices: torch.Tensor, src_layout temp_file.write_text(ir) kernel = triton.compile(str(temp_file)) - assert "llvm.nvvm.shfl.sync.idx" in kernel.asm["llir"] + assert ("nvvm.shfl.sync.idx" in kernel.asm["llir"]) or ("llvm.amdgcn.ds.bpermute" in kernel.asm["llir"]) kernel[(1, 1, 1)](src, indices, output) From e7ba411b48d61a8b89f785280582e74a1f1d40d5 Mon Sep 17 00:00:00 2001 From: Mogball Date: Thu, 5 Dec 2024 14:22:28 -0800 Subject: [PATCH 36/38] skip AMD --- python/test/unit/language/test_core.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index ac75e24ab713..ff70de157a29 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -6233,6 +6233,8 @@ def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor): ), ]) def test_gather_complex_layouts(src_shape, indices_shape, axis, src_layout, indices_layout, tmp_path: pathlib.Path): + if is_hip(): + pytest.skip("warp-local gather has issues on HIP") def prepare_kernel(src: torch.Tensor, axis: int, indices: torch.Tensor): output = torch.empty(indices.shape, dtype=src.dtype, device=src.device) From da62621311e2fc01b0f411737c8946be4151f118 Mon Sep 17 00:00:00 2001 From: Mogball Date: Thu, 5 Dec 2024 14:59:24 -0800 Subject: [PATCH 37/38] fix subpermutation vs permutation terms --- lib/Analysis/Utility.cpp | 4 ++-- lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 72f4b41e937c..261c8756137b 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -447,7 +447,7 @@ bool GatherLoweringHelper::isWarpLocal() { b.getStringAttr("dim" + std::to_string(gatherOp.getAxis())); // The tensor layouts must be distributed layouts, where the basis matrix is a - // subpermutation matrix plus some zero rows for broadcasting. + // subpermutation matrix (permutation matrix plus zeros for broadcasting). // FIXME(jeff): Check this invariant somehow. // // We want to know if all elements of a column along the gather axis are @@ -457,7 +457,7 @@ bool GatherLoweringHelper::isWarpLocal() { // srcLayout.inverse().sublayoutIsZero({kGatherDim}, {kBlock, kWarp}) // // But due to broadcasting, the matrix might not be invertible. But since the - // matrix is a subpermutation matrix, we can instead query + // matrix is a permutation matrix (checked below), we can instead query // // srcLayout.sublayoutIsZero({kBlock, kWarp}, {kGatherDim}) // diff --git a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp index c78337038a5f..3a453ff5218c 100644 --- a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp @@ -140,7 +140,7 @@ void GatherOpConversion::emitGatherInShared( // // for both source and index tensors: moving along the gather axis does not // change the warp. Broadcasted layouts are not supported, so we know the -// layouts are subpermutation matrices. +// layouts are permutation matrices. // // We can check this with `ll((block, warp))[dimN] == 0`. // @@ -237,7 +237,7 @@ void GatherOpConversion::emitWarpLocalGather( // elements per column owned by a thread. // Fully invert the source layout. We know it is invertible because - // `isWarpLocal` checked this (subpermutation matrix, no broadcasting). + // `isWarpLocal` checked this. LinearLayout invSrcLayout = srcLayout.invert(); // Sanity check: the warp must be invariant to the index because otherwise the From 20bae727d7aa1875bb7b9ba3ad853eb616a04f09 Mon Sep 17 00:00:00 2001 From: Mogball Date: Tue, 10 Dec 2024 09:14:17 -0800 Subject: [PATCH 38/38] review comments --- lib/Analysis/Utility.cpp | 24 +++++++------------ .../TritonGPUToLLVM/GatherOpToLLVM.cpp | 8 +++++-- python/test/unit/language/test_core.py | 3 ++- 3 files changed, 16 insertions(+), 19 deletions(-) diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index d9f97fe8483f..69eb196a95a0 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -455,7 +455,7 @@ bool GatherLoweringHelper::isWarpLocal() { // mapped to the same set of warps, which means the gather can be performed // entirely within the warp. We need to query // - // srcLayout.inverse().sublayoutIsZero({kGatherDim}, {kBlock, kWarp}) + // srcLayout.invert().sublayoutIsZero({kGatherDim}, {kBlock, kWarp}) // // But due to broadcasting, the matrix might not be invertible. But since the // matrix is a permutation matrix (checked below), we can instead query @@ -475,10 +475,10 @@ bool GatherLoweringHelper::isWarpLocal() { } } - // `dimN` is invariant to the warp, but the `(block, warp)` mapping to all - // other dimensions must be the same for both layouts. If so, then the warp - // that owns a particular index element also owns all the source elements it - // could index into. + // If the gather axis `dimN` is invariant to the warp, but the `(block, warp)` + // mapping to all other dimensions must be the same for both layouts. If so, + // then the warp that owns a particular index element also owns all the source + // elements it could index into. if (srcLayout->sublayout({kBlock, kWarp}, otherDims) != idxLayout->sublayout({kBlock, kWarp}, otherDims)) return false; @@ -495,17 +495,9 @@ bool GatherLoweringHelper::isWarpLocal() { idxLayout->sublayout(kLane, otherDims)) return false; - // Broadcasted source layouts are not supported at the moment, because we - // rely on the source layout being invertible. - for (auto &bases : srcLayout->getBases()) { - auto isZero = [](ArrayRef base) { - return llvm::all_of(base, [](int32_t b) { return b == 0; }); - }; - if (llvm::any_of(bases.second, isZero)) { - return false; - } - } - return true; + // Otherwise, the source layout has to be invertible. This primarily means + // the codegen path doesn't support broadcasted source layouts. + return srcLayout->isInvertible(); } unsigned getNumScratchElements(ArrayRef shape) { diff --git a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp index 3a453ff5218c..faf781369e0a 100644 --- a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp @@ -35,7 +35,11 @@ LogicalResult GatherOpConversion::matchAndRewrite(GatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { GatherLoweringHelper helper(op); - // Specialize the lowering based on the source layout. + // Specialize the lowering based on the source layout. Given that the cost of + // a warp shuffle is approximately half the cost of a roundtrip to shared + // memory with zero bank conflicts, we will need a more precise heuristic to + // choose between the two codegen paths and rely on the middle end to pick the + // right layout. if (helper.isWarpLocal()) { emitWarpLocalGather(op, adaptor, rewriter); } else { @@ -218,7 +222,7 @@ void GatherOpConversion::emitWarpLocalGather( // within that thread. // // Because `ll_src(block=0, warp=0, lane=0)[otherDims] == - // idx_src(0, 0, 0)[otherDims]`, we know given any `idx_reg` (element in the + // ll_idx(0, 0, 0)[otherDims]`, we know given any `idx_reg` (element in the // index tensor) the thread will need to read from the same column in the // source tensor. // diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 7fa856d87dc2..3ca05bad506d 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -6243,6 +6243,7 @@ def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor): torch.testing.assert_close(result, ref, rtol=0, atol=0) +# These layouts are specially chosen to trigger the warp shuffle codegen. @pytest.mark.parametrize("src_shape, indices_shape, axis, src_layout, indices_layout", [ ([32, 16], [32, 16], 0, "linear<{register = [[0, 2], [2, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}>", @@ -6253,7 +6254,7 @@ def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor): "linear<{register = [[0, 2], [32, 0], [0, 32], [2, 0], [0, 16], [64, 0], [128, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}>" ), ]) -def test_gather_complex_layouts(src_shape, indices_shape, axis, src_layout, indices_layout, tmp_path: pathlib.Path): +def test_gather_warp_shuffle(src_shape, indices_shape, axis, src_layout, indices_layout, tmp_path: pathlib.Path): if is_hip(): pytest.skip("warp-local gather has issues on HIP")