Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,37 @@ SmallVector<unsigned> getWarpsPerCTA(Attribute layout);

SmallVector<unsigned> getSizePerThread(Attribute layout);

// Returns the number of contiguous elements that each thread
// has access to, on each dimension of the tensor. E.g.
// for a blocked layout with sizePerThread = [1, 4], returns [1, 4],
// regardless of the shape of the tensor.
SmallVector<unsigned> getContigPerThread(Attribute layout);
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we've reached the point where we need comments describing each function with a small example. Things are no longer as self-explanatory as they were back when we only had a single ThreadsPerWarp for distributed layouts

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea I agree, I'll add comments


// Returns the number of non-replicated contiguous elements that each thread
// has access to, on each dimension of the tensor. For a blocked layout
// with sizePerThread = [1, 4] and tensor shape = [128, 1], the elements
// for thread 0 would be [A_{0, 0}, A_{0, 0}, A_{0, 0}, A_{0, 0}], returns [1,
// 1]. Whereas for a tensor shape [128, 128], the elements for thread 0 would be
// [A_{0, 0}, A_{0, 1}, A_{0, 2}, A_{0, 3}], returns [1, 4].
SmallVector<unsigned> getUniqueContigPerThread(Type type);

// Returns the number of threads per warp that have access to non-replicated
// elements of the tensor. E.g. for a blocked layout with sizePerThread = [1,
// 1], threadsPerWarp = [2, 16] and tensor shape = [2, 2], threads 0, 1, 16, 17
// have access to the full tensor, whereas the other threads have access to
// replicated elements, so this function returns [2, 2].
SmallVector<unsigned>
getThreadsPerWarpWithUniqueData(Attribute layout,
ArrayRef<int64_t> tensorShape);

// Returns the number of warps per CTA that have access to non-replicated
// elements of the tensor. E.g. for a blocked layout with sizePerThread = [1,
// 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4] and tensor shape = [2, 2],
// returns [1, 1], since the first warp has access to the full tensor, whereas
// the other warps have access to replicated elements.
SmallVector<unsigned>
getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape);

SmallVector<unsigned> getThreadsPerCTA(Attribute layout);

SmallVector<unsigned>
Expand Down
15 changes: 11 additions & 4 deletions lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,23 @@ unsigned ReduceOpHelper::getInterWarpSize() {
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
unsigned sizeIntraWarps = getIntraWarpSize();
return std::min(srcReduceDimSize / sizeIntraWarps,
triton::gpu::getWarpsPerCTA(getSrcLayout())[axis]);
triton::gpu::getWarpsPerCTAWithUniqueData(
getSrcLayout(), getSrcShape())[axis]);
}

unsigned ReduceOpHelper::getIntraWarpSize() {
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
return std::min(srcReduceDimSize,
triton::gpu::getThreadsPerWarp(getSrcLayout())[axis]);
triton::gpu::getThreadsPerWarpWithUniqueData(
getSrcLayout(), getSrcShape())[axis]);
}

unsigned ReduceOpHelper::getThreadsReductionAxis() {
auto srcLayout = getSrcLayout();
return triton::gpu::getThreadsPerWarp(srcLayout)[axis] *
triton::gpu::getWarpsPerCTA(srcLayout)[axis];
auto srcShape = getSrcShape();
return triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout,
srcShape)[axis] *
triton::gpu::getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis];
}

SmallVector<unsigned> ReduceOpHelper::getScratchConfigBasic() {
Expand Down Expand Up @@ -88,6 +92,9 @@ bool ReduceOpHelper::isSupportedLayout() {
return true;
}
}
if (auto sliceLayout = srcLayout.dyn_cast<triton::gpu::SliceEncodingAttr>()) {
return true;
}
return false;
}

Expand Down
19 changes: 14 additions & 5 deletions lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,15 @@ struct ReduceOpConversion
Attribute layout, SmallVector<Value> &index,
SmallVector<Value> &writeIdx,
std::map<int, Value> &ints, unsigned axis) const {
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
auto dim = sliceLayout.getDim();
assert(dim != axis && "Reduction axis cannot be sliced");
auto parentLayout = sliceLayout.getParent();
getWriteIndexBasic(rewriter, loc, parentLayout, index, writeIdx, ints,
axis);
return;
}

writeIdx = index;
auto sizePerThread = triton::gpu::getSizePerThread(layout);
Value axisSizePerThread = ints[sizePerThread[axis]];
Expand All @@ -100,9 +109,10 @@ struct ReduceOpConversion
// to map every `axisSizePerThread` to 1 value in smem as:
// writeIdx[axis] = index[axis] / axisSizePerThread
writeIdx[axis] = udiv(index[axis], axisSizePerThread);
}
auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>();
if (mmaLayout && mmaLayout.isAmpere()) {
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (!mmaLayout.isAmpere()) {
llvm::report_fatal_error("Unsupported layout");
}
if (axis == 0) {
// Because warpTileSize = [16, 8] and threadsPerWarp = [8, 4], each 8
// rows in smem would correspond to a warp. The mapping
Expand All @@ -113,8 +123,7 @@ struct ReduceOpConversion
// Same as BlockedEncodingAttr case
writeIdx[axis] = udiv(index[axis], axisSizePerThread);
}
}
if (mmaLayout && !mmaLayout.isAmpere()) {
} else {
llvm::report_fatal_error("Unsupported layout");
}
}
Expand Down
78 changes: 71 additions & 7 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,41 @@ SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
if (mmaLayout.isAmpere())
return {8, 4};
}
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
auto parent = sliceLayout.getParent();
auto parentThreadsPerWarp = getThreadsPerWarp(parent);
SmallVector<unsigned> threadsPerWarp = parentThreadsPerWarp;
threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim());
for (unsigned i = 0; i < threadsPerWarp.size(); i++)
threadsPerWarp[i] *= parentThreadsPerWarp[sliceLayout.getDim()];
return threadsPerWarp;
}
assert(0 && "getThreadsPerWarp not implemented");
return {};
}

SmallVector<unsigned>
getThreadsPerWarpWithUniqueData(Attribute layout,
ArrayRef<int64_t> tensorShape) {
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
auto parentLayout = sliceLayout.getParent();
auto parentShape = sliceLayout.paddedShape(tensorShape);
auto parentThreadsPerWarp =
getThreadsPerWarpWithUniqueData(parentLayout, parentShape);
SmallVector<unsigned> threadsPerWarp = parentThreadsPerWarp;
threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim());
return threadsPerWarp;
}
auto threadsPerWarp = getThreadsPerWarp(layout);
assert(threadsPerWarp.size() == tensorShape.size() &&
"layout and tensor shape must have the same rank");
for (unsigned i = 0; i < threadsPerWarp.size(); i++) {
threadsPerWarp[i] = std::min<unsigned>(threadsPerWarp[i], tensorShape[i]);
}

return threadsPerWarp;
}

SmallVector<unsigned> getWarpsPerCTA(Attribute layout) {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return SmallVector<unsigned>(blockedLayout.getWarpsPerCTA().begin(),
Expand All @@ -94,10 +125,43 @@ SmallVector<unsigned> getWarpsPerCTA(Attribute layout) {
return SmallVector<unsigned>(mmaLayout.getWarpsPerCTA().begin(),
mmaLayout.getWarpsPerCTA().end());
}
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
auto parent = sliceLayout.getParent();
auto parentWarpsPerCTA = getWarpsPerCTA(parent);
SmallVector<unsigned> warpsPerCTA = parentWarpsPerCTA;
warpsPerCTA.erase(warpsPerCTA.begin() + sliceLayout.getDim());
for (unsigned i = 0; i < warpsPerCTA.size(); i++)
warpsPerCTA[i] *= parentWarpsPerCTA[sliceLayout.getDim()];
return warpsPerCTA;
}
assert(0 && "getWarpsPerCTA not implemented");
return {};
}

SmallVector<unsigned>
getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape) {
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
auto parentLayout = sliceLayout.getParent();
auto parentShape = sliceLayout.paddedShape(tensorShape);
auto parentWarpsPerCTA =
getWarpsPerCTAWithUniqueData(parentLayout, parentShape);
SmallVector<unsigned> warpsPerCTA = parentWarpsPerCTA;
warpsPerCTA.erase(warpsPerCTA.begin() + sliceLayout.getDim());
return warpsPerCTA;
}
auto warpsPerCTA = getWarpsPerCTA(layout);
assert(warpsPerCTA.size() == tensorShape.size() &&
"layout and tensor shape must have the same rank");
for (unsigned i = 0; i < warpsPerCTA.size(); i++) {
auto sizePerWarp =
getSizePerThread(layout)[i] * getThreadsPerWarp(layout)[i];
auto maxWarpsPerDim = ceil<unsigned>(tensorShape[i], sizePerWarp);
warpsPerCTA[i] = std::min<unsigned>(warpsPerCTA[i], maxWarpsPerDim);
}

return warpsPerCTA;
}

SmallVector<unsigned> getSizePerThread(Attribute layout) {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return SmallVector<unsigned>(blockedLayout.getSizePerThread().begin(),
Expand Down Expand Up @@ -189,7 +253,7 @@ SmallVector<unsigned> getThreadsPerCTA(Attribute layout) {
threads.push_back(blockedLayout.getThreadsPerWarp()[d] *
blockedLayout.getWarpsPerCTA()[d]);
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.getVersionMajor() == 2) {
if (mmaLayout.isAmpere()) {
threads = {8 * mmaLayout.getWarpsPerCTA()[0],
4 * mmaLayout.getWarpsPerCTA()[1]};
} else
Expand Down Expand Up @@ -1074,9 +1138,9 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
return mlir::failure();
}
auto newType = op->getResult(0).getType().cast<RankedTensorType>();
// Ensure that the new insert_slice op is placed in the same place as the
// old insert_slice op. Otherwise, the new insert_slice op may be placed
// after the async_wait op, which is not allowed.
// Ensure that the new insert_slice op is placed in the same place as
// the old insert_slice op. Otherwise, the new insert_slice op may be
// placed after the async_wait op, which is not allowed.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(insert_slice);
auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
Expand Down Expand Up @@ -1104,9 +1168,9 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
auto resType = RankedTensorType::get(
origResType.getShape(), origResType.getElementType(),
extract_slice.getType().cast<RankedTensorType>().getEncoding());
// Ensure that the new extract_slice op is placed in the same place as the
// old extract_slice op. Otherwise, the new extract_slice op may be placed
// after the async_wait op, which is not allowed.
// Ensure that the new extract_slice op is placed in the same place as
// the old extract_slice op. Otherwise, the new extract_slice op may be
// placed after the async_wait op, which is not allowed.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(extract_slice);
auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
Expand Down
64 changes: 64 additions & 0 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1420,6 +1420,70 @@ def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
)


layouts = [
BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0]),
BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0]),
BlockedLayout([1, 4], [1, 32], [1, 4], [1, 0]),
BlockedLayout([1, 4], [8, 4], [2, 2], [0, 1])
]


@pytest.mark.parametrize("M, N", [[32, 128], [128, 128], [128, 32]])
@pytest.mark.parametrize("src_layout", layouts)
def test_reduce_2d(M, N, src_layout, device='cuda'):
ir = f"""
#src = {src_layout}
module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{
tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
%cst = arith.constant dense<{M}> : tensor<{M}x1xi32, #src>
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
%2 = arith.muli %1, %cst : tensor<{M}x1xi32, #src>
%3 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>
%4 = tt.expand_dims %3 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{N}xi32, #src>
%5 = tt.broadcast %2 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src>
%6 = tt.broadcast %4 : (tensor<1x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #src>
%7 = arith.addi %5, %6 : tensor<{M}x{N}xi32, #src>
%8 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<{M}x{N}x!tt.ptr<i32>, #src>
%9 = tt.addptr %8, %7 : tensor<{M}x{N}x!tt.ptr<i32>, #src>, tensor<{M}x{N}xi32, #src>
%10 = tt.load %9 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xi32, #src>
%11 = "tt.reduce"(%10) ({{
^bb0(%arg2: i32, %arg3: i32):
%13 = arith.addi %arg2, %arg3 : i32
tt.reduce.return %13 : i32
}}) {{axis = 1 : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%12 = "tt.reduce"(%11) ({{
^bb0(%arg2: i32, %arg3: i32):
%13 = arith.addi %arg2, %arg3 : i32
tt.reduce.return %13 : i32
}}) {{axis = 0 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> i32
tt.store %arg1, %12 {{cache = 1 : i32, evict = 1 : i32}} : i32
tt.return
}}
}}
"""
import tempfile
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
f.write(ir)
f.flush()
kernel = triton.compile(f.name)

rs = RandomState(17)
x = rs.randint(0, 4, (M, N)).astype('int32')
x = (x.view('uint32') & np.uint32(0xffffe000)).view('int32')

z = np.zeros((1,)).astype('int32')

x_tri = torch.tensor(x, device=device)
z_tri = torch.tensor(z, device=device)

pgm = kernel[(1, 1, 1)](x_tri, z_tri)

z_ref = np.sum(x)

np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3)


def test_generic_reduction(device='cuda'):

@triton.jit
Expand Down