Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
6b5374d
Add GatherOp with lit tests
Mogball Nov 23, 2024
f9bfec3
implement gather op through to LLVM
Mogball Nov 25, 2024
c091f31
expose through frontend and add unit tests
Mogball Nov 26, 2024
44dabb4
rename dim to axis
Mogball Nov 26, 2024
d4a32c8
add LLVMIR test
Mogball Nov 26, 2024
02a2f2b
Merge remote-tracking branch 'origin/main' into mogball/gather
Mogball Nov 26, 2024
56d1279
newlines
Mogball Nov 26, 2024
6a2f788
more newlines
Mogball Nov 26, 2024
8f1358e
more newlines
Mogball Nov 26, 2024
63d0de2
format code
Mogball Nov 26, 2024
39a35ce
reduce test_gather smem usage for AMD
Mogball Nov 26, 2024
508b981
Merge remote-tracking branch 'origin/main' into mogball/gather
Mogball Nov 27, 2024
741d0a8
assert_close
Mogball Nov 27, 2024
4189c09
clarify gather impl comment
Mogball Nov 27, 2024
9fee613
add gather lowering test with dot layout
Mogball Nov 27, 2024
a4d9a2e
warp local
Mogball Nov 27, 2024
5b47ed6
remove membar todo
Mogball Nov 27, 2024
b670068
require other dims to match source dims
Mogball Nov 27, 2024
dad8ca5
tol=0
Mogball Nov 27, 2024
bccf50a
Merge branch 'mogball/gather' into mogball/faster_gather
Mogball Nov 27, 2024
e023fec
sublayoutIsZero
Mogball Nov 27, 2024
c0f3702
sublayout check
Mogball Nov 27, 2024
4f53fd0
struggle
Mogball Nov 28, 2024
309a380
Merge remote-tracking branch 'origin/main' into mogball/gather
Mogball Nov 28, 2024
64630e6
merge
Mogball Nov 28, 2024
4f930c0
Merge branch 'mogball/gather' into mogball/faster_gather
Mogball Nov 28, 2024
45d2c1e
Merge remote-tracking branch 'origin/main' into mogball/faster_gather
Mogball Nov 28, 2024
c840525
restore layout helper and remove unused code
Mogball Dec 3, 2024
6759cc7
emitHardwareTuple
Mogball Dec 4, 2024
0628835
almost there
Mogball Dec 4, 2024
a52cd4f
Merge remote-tracking branch 'origin/main' into mogball/faster_gather
Mogball Dec 4, 2024
97a5b5d
emitHardwareTuple
Mogball Dec 4, 2024
50cb9b2
unused code
Mogball Dec 4, 2024
899f525
Merge branch 'mogball/cleanup' into mogball/faster_gather
Mogball Dec 4, 2024
5c4ea15
maybe it's here
Mogball Dec 4, 2024
3d8401f
add some unit tests
Mogball Dec 4, 2024
db051eb
don't redundantly shuffle index
Mogball Dec 5, 2024
4c240d2
test for trivial 2d
Mogball Dec 5, 2024
19f2d77
more complex 2d test
Mogball Dec 5, 2024
fbadf01
trying to add integration test
Mogball Dec 5, 2024
eff4d85
Merge remote-tracking branch 'origin/main' into mogball/faster_gather
Mogball Dec 5, 2024
a66c724
add integration tests
Mogball Dec 5, 2024
3c6cc8f
Merge remote-tracking branch 'origin/main' into mogball/faster_gather
Mogball Dec 5, 2024
e2d4c1b
add algo description
Mogball Dec 5, 2024
89a1970
fix test on AMD
Mogball Dec 5, 2024
e7ba411
skip AMD
Mogball Dec 5, 2024
da62621
fix subpermutation vs permutation terms
Mogball Dec 5, 2024
19fa7fc
Merge remote-tracking branch 'origin/main' into mogball/faster_gather
Mogball Dec 10, 2024
20bae72
review comments
Mogball Dec 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/triton/Analysis/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ class GatherLoweringHelper {

// Get the shared memory scratch size required by this op.
unsigned getScratchSizeInBytes();
// Determine if the gather can be performed completely within a warp.
bool isWarpLocal();

private:
triton::GatherOp gatherOp;
Expand Down
5 changes: 5 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,12 @@ LinearLayout ensureLayoutNotSmallerThan(
const LinearLayout &layout,
const llvm::SmallDenseMap<StringAttr, int64_t> &shape);

// Return a vector of the standard out dimension names for tensor layouts. These
// are "dim0", "dim1", etc.
SmallVector<StringAttr> standardOutDimNames(MLIRContext *ctx, int rank);
// Return an identity mapping from `inDimName` to the standard out dimensions,
// with the dimensions sized according to the shape. The bases are sorted
// according to `order`, with the most minor dimension first.
LinearLayout identityStandardND(StringAttr inDimName, ArrayRef<unsigned> shape,
ArrayRef<unsigned> order);

Expand Down
84 changes: 82 additions & 2 deletions lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -413,13 +413,93 @@ GatherLoweringHelper::GatherLoweringHelper(triton::GatherOp gatherOp)
: gatherOp(gatherOp) {}

unsigned GatherLoweringHelper::getScratchSizeInBytes() {
// For now, lower the gather op by writing the source tensor to shared memory.
// TODO(jeff): Leverage locality to avoid using scratch space when possible.
// If the gather is warp-local, no scratch space is needed.
if (isWarpLocal())
return 0;

// Otherwise, performing the gather will require scratch space to communicate
// the source tensor across threads. For now, assume the whole source tensor
// is written back to shared memory.
RankedTensorType srcType = gatherOp.getSrc().getType();
return product(srcType.getShape()) *
ceil<unsigned>(srcType.getElementTypeBitWidth(), 8);
}

bool GatherLoweringHelper::isWarpLocal() {
// The gather is warp-local if for each column along the gather axis in the
// source and index tensors, all the elements are owned by the same warp.
RankedTensorType srcType = gatherOp.getSrc().getType();
RankedTensorType idxType = gatherOp.getIndices().getType();
std::optional<LinearLayout> srcLayout =
toLinearLayout(srcType.getShape(), srcType.getEncoding());
std::optional<LinearLayout> idxLayout =
toLinearLayout(idxType.getShape(), idxType.getEncoding());

// FIXME: If an unsupported layout was encountered, assume the gather is not
// warp-local.
if (!srcLayout || !idxLayout)
return false;

Builder b(gatherOp.getContext());
StringAttr kBlock = b.getStringAttr("block");
StringAttr kWarp = b.getStringAttr("warp");
StringAttr kLane = b.getStringAttr("lane");
StringAttr kGatherDim =
b.getStringAttr("dim" + std::to_string(gatherOp.getAxis()));

// The tensor layouts must be distributed layouts, where the basis matrix is a
// subpermutation matrix (permutation matrix plus zeros for broadcasting).
// FIXME(jeff): Check this invariant somehow.
//
// We want to know if all elements of a column along the gather axis are
// mapped to the same set of warps, which means the gather can be performed
// entirely within the warp. We need to query
//
// srcLayout.invert().sublayoutIsZero({kGatherDim}, {kBlock, kWarp})
//
// But due to broadcasting, the matrix might not be invertible. But since the
// matrix is a permutation matrix (checked below), we can instead query
//
// srcLayout.sublayoutIsZero({kBlock, kWarp}, {kGatherDim})
//
// Which implies that changing the warp will not change the gather dimension.
// And since there is no swizzling, this applies to all warps.
if (!srcLayout->sublayoutIsZero({kBlock, kWarp}, kGatherDim) ||
!idxLayout->sublayoutIsZero({kBlock, kWarp}, kGatherDim))
return false;

SmallVector<StringAttr> otherDims;
for (unsigned dim = 0, rank = srcType.getRank(); dim < rank; ++dim) {
if (dim != gatherOp.getAxis()) {
otherDims.push_back(b.getStringAttr("dim" + Twine(dim)));
}
}

// If the gather axis `dimN` is invariant to the warp, but the `(block, warp)`
// mapping to all other dimensions must be the same for both layouts. If so,
// then the warp that owns a particular index element also owns all the source
// elements it could index into.
if (srcLayout->sublayout({kBlock, kWarp}, otherDims) !=
idxLayout->sublayout({kBlock, kWarp}, otherDims))
return false;

// The two constraints above ensure that data-movement to perform the gather
// operation are contained within a warp. The subsequent constraints simplify
// codegen.

// Require that for any given gather column, the threads mapped to the column
// in the index and source tensors are the same. This means we don't need to
// xor shuffle across threads before emitting index shuffles; we push warp
// shuffling to layout conversions.
if (srcLayout->sublayout(kLane, otherDims) !=
idxLayout->sublayout(kLane, otherDims))
return false;

// Otherwise, the source layout has to be invertible. This primarily means
// the codegen path doesn't support broadcasted source layouts.
return srcLayout->isInvertible();
}

unsigned getNumScratchElements(ArrayRef<unsigned> shape) {
if (shape.empty())
return 0;
Expand Down
Loading