Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions include/triton/Dialect/TritonGPU/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -192,4 +192,16 @@ def TritonGPULoopScheduling: Pass<"tritongpu-loop-scheduling", "mlir::ModuleOp">
"number of pipeline stages">
];
}

def TritonGPUCoalesceAsyncCopy: Pass<"tritongpu-coalesce-async-copy", "mlir::ModuleOp"> {
let summary = "Improve coalescing for async global to local copies";

let description = "For AsyncCopyGlobalToLocal ops where the shared encoding's vec is less than "
"the blocked encoding's sizePerThread, this pass improves coalescing by clipping the "
"sizePerThread value";

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];
}

#endif
5 changes: 5 additions & 0 deletions include/triton/Dialect/TritonGPU/Transforms/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,11 @@ enum class MMALoadType {
// pipelining
};
MMALoadType getMMALoadType(Operation *loadOp);

// Returns composed LinearLayout for register to shared copy
std::optional<triton::LinearLayout>
getRegToSharedLayout(MLIRContext *ctx, ArrayRef<int64_t> shape,
Attribute srcEnc, Attribute dstEnc, int elemBitWidth);
} // namespace mlir

#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
46 changes: 12 additions & 34 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "llvm/ADT/STLExtras.h"

namespace mlir {
Expand Down Expand Up @@ -174,41 +175,17 @@ bool emitTransferBetweenRegistersAndShared(
StringAttr kLane = str_attr("lane");
StringAttr kWarp = str_attr("warp");

std::optional<LinearLayout> regLayout =
triton::gpu::toLinearLayout(shape, registerTy.getEncoding());
std::optional<LinearLayout> sharedLayout = triton::gpu::toLinearLayout(
shape, sharedTy.getEncoding(), elemLlvmTy.getIntOrFloatBitWidth());
if (!regLayout.has_value() || !sharedLayout.has_value()) {
auto regToSharedLayout = getRegToSharedLayout(
ctx, shape, registerTy.getEncoding(), sharedTy.getEncoding(),
elemLlvmTy.getIntOrFloatBitWidth());
if (!regToSharedLayout.has_value())
return false;
}
auto sharedOrder = triton::gpu::getOrder(sharedTy.getEncoding());

// sharedLayout's in-dims are currently (offset, block). Reshape to
// (offsetX1, offsetX2, ..., block) so that we can apply the N-dimensional
// shmem strides. (The offsetX's appear in minor-to-major order.)
auto sharedLegacy =
cast<triton::gpu::SharedEncodingAttr>(sharedTy.getEncoding());
SmallVector<std::pair<StringAttr, int32_t>> multiDimSharedSize;
for (int i = 0; i < rank; i++) {
int dim = sharedOrder[i];
int64_t size = std::max(
int64_t{1},
shape[dim] / sharedLegacy.getCTALayout().getCTASplitNum()[dim]);
multiDimSharedSize.push_back(
{str_attr("offset" + std::to_string(dim)), size});
}
multiDimSharedSize.push_back({kBlock, sharedLayout->getInDimSize(kBlock)});
sharedLayout = sharedLayout->reshapeIns(multiDimSharedSize);

// regToSharedLayout maps from (register, lane, warp, block) to (offsetX1,
// ..., offsetXN, block), where the offsetX's are in minor-to-major order.
LinearLayout regToSharedLayout = regLayout->invertAndCompose(*sharedLayout);

// TODO(jlebar): We don't currently support loading from shared memory in a
// different CTA. We'd need to emit `mapa.shared::cluster` instructions.
for (int inBlock = 1; inBlock < regToSharedLayout.getInDimSize(kBlock);
for (int inBlock = 1; inBlock < regToSharedLayout->getInDimSize(kBlock);
inBlock *= 2) {
auto idx = llvm::to_vector(llvm::make_second_range(regToSharedLayout.apply(
auto idx = llvm::to_vector(llvm::make_second_range(regToSharedLayout->apply(
{{kRegister, 0}, {kLane, 0}, {kWarp, 0}, {kBlock, inBlock}})));
// offsetX1, ..., offsetXN must all be 0.
if (!llvm::all_of(ArrayRef(idx).drop_back(1),
Expand All @@ -234,15 +211,15 @@ bool emitTransferBetweenRegistersAndShared(
// which have known strides. This would allow us to vectorize across multiple
// shmem out dimensions where possible.
const int vecElems =
std::min(regToSharedLayout.getNumConsecutiveInOut(),
std::min(regToSharedLayout->getNumConsecutiveInOut(),
maxVecElems.value_or(std::numeric_limits<int>::max()));

Value threadId = getThreadId(rewriter, loc);
Value threadsPerWarp = i32_val(regToSharedLayout.getInDimSize(kLane));
Value threadsPerWarp = i32_val(regToSharedLayout->getInDimSize(kLane));
Value laneId = urem(threadId, threadsPerWarp);
Value warpId = udiv(threadId, threadsPerWarp);

int numElems = regToSharedLayout.getInDimSize(kRegister);
int numElems = regToSharedLayout->getInDimSize(kRegister);
auto vecTy = vec_ty(elemLlvmTy, vecElems);
auto ptrTy = shmemBase.getType();
Value zero = i32_val(0);
Expand All @@ -253,14 +230,15 @@ bool emitTransferBetweenRegistersAndShared(
// we drop_end to drop block, which we know from above will be 0.
auto multiDimShmemOffset =
llvm::to_vector(llvm::drop_end(llvm::make_second_range(
applyLinearLayout(loc, rewriter, regToSharedLayout,
applyLinearLayout(loc, rewriter, *regToSharedLayout,
{{kRegister, i32_val(i * vecElems)},
{kLane, laneId},
{kWarp, warpId},
{kBlock, zero}}))));

// Reorder strides according to `order`. This way they match the
// multi-dimensional offsets in regToSharedLayout.
auto sharedOrder = triton::gpu::getOrder(sharedTy.getEncoding());
Value shmemOffset = dot(rewriter, loc, multiDimShmemOffset,
applyPermutation(shmemStrides, sharedOrder));
auto vecAddr = gep(ptrTy, elemLlvmTy, shmemBase, shmemOffset);
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ add_triton_library(TritonGPUTransforms
Prefetch.cpp
RemoveLayoutConversions.cpp
ReorderInstructions.cpp
CoalesceAsyncCopy.cpp
Utility.cpp

DEPENDS
Expand Down
124 changes: 124 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/CoalesceAsyncCopy.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"

namespace mlir {
namespace triton {
namespace gpu {

#define GEN_PASS_DEF_TRITONGPUCOALESCEASYNCCOPY
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

// This pass currently only applies if the following are all true...
// 1) Operand A for WGMMA is to be loaded in registers
// 2) We upcast operand A in registers before the WGMMA
// (downcasting is not yet supported)
// 3) Pipelining is enabled for loading A
//
// ...then for the AsyncCopyGlobalToLocal op, the SharedEncoding
// vec will be less than BlockedEncoding's sizePerThread for k-dim. E.g. if
// we're upcasting from int8 to bf16, then shared vec is 8 and sizePerThread
// for k is 16. In this case, AsyncCopyGlobalToLocal will generate two
// 8-byte-cp.async's for each contiguous 16B global data owned by each
// thread. This breaks coalescing (i.e. results 2x the minimum required
// transactions).
//
// This issue occurs for cp.async because it combines load and store into one
// instruction. The fix is to clip each dim of sizePerThread by shared vec, so
// that the vectorization of load and store are equal along the contiguous
// dimension. In the above example, each thread will then only own 8B contiguous
// global data.
struct ClipAsyncCopySizePerThread
: public OpRewritePattern<AsyncCopyGlobalToLocalOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(AsyncCopyGlobalToLocalOp copyOp,
PatternRewriter &rewriter) const override {
Value src = copyOp.getSrc();
Value mask = copyOp.getMask();
Value other = copyOp.getOther();
auto srcTy = cast<RankedTensorType>(src.getType());
auto dstTy = cast<MemDescType>(copyOp.getResult().getType());
auto blockEnc = dyn_cast<BlockedEncodingAttr>(srcTy.getEncoding());
if (!blockEnc)
return rewriter.notifyMatchFailure(copyOp,
"src must be of blocked encoding");
auto sharedEnc = cast<SharedEncodingAttr>(dstTy.getEncoding());
auto sharedVec = sharedEnc.getVec();

// obtain max contiguous copy size
// Note this can be further optimized, as copyContigSize can be even
// smaller when lowering, depending on contiguity and mask alignment
// (see AsyncCopyGlobalToLocalOpConversion)
auto elemBitWidth = dstTy.getElementTypeBitWidth();
auto regToSharedLayout =
getRegToSharedLayout(rewriter.getContext(), srcTy.getShape(), blockEnc,
sharedEnc, elemBitWidth);
auto copyContigSize = regToSharedLayout->getNumConsecutiveInOut();

// obtain block sizePerThread along contig dim
auto sizePerThread = blockEnc.getSizePerThread();
auto blockContigSize = sizePerThread[blockEnc.getOrder()[0]];

if (blockContigSize <= copyContigSize)
return rewriter.notifyMatchFailure(
copyOp,
"blocked sizePerThread along contiguous dim must be greater than the "
"max contiguous copy size ");

sizePerThread[blockEnc.getOrder()[0]] = copyContigSize;

// obtain new blockedEnc based on clipped sizePerThread
auto mod = copyOp->getParentOfType<ModuleOp>();
int numWarps = TritonGPUDialect::getNumWarps(mod);
int threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod);
auto newBlockEnc = BlockedEncodingAttr::get(
copyOp.getContext(), srcTy.getShape(), sizePerThread,
blockEnc.getOrder(), numWarps, threadsPerWarp, blockEnc.getCTALayout());

// insert cvt's after src, mask, and other
auto convertBlockLayout = [&](Value src, BlockedEncodingAttr enc) {
auto ty = cast<TensorType>(src.getType());
auto newTy =
RankedTensorType::get(ty.getShape(), ty.getElementType(), enc);
auto cvt = rewriter.create<ConvertLayoutOp>(copyOp->getLoc(), newTy, src);
return cvt.getResult();
};
src = convertBlockLayout(src, newBlockEnc);
if (mask)
mask = convertBlockLayout(mask, newBlockEnc);
if (other)
other = convertBlockLayout(other, newBlockEnc);

rewriter.modifyOpInPlace(copyOp, [&]() {
copyOp.getSrcMutable().assign(src);
if (mask)
copyOp.getMaskMutable().assign(mask);
if (other)
copyOp.getOtherMutable().assign(other);
});

return success();
}
};

class CoalesceAsyncCopyPass
: public impl::TritonGPUCoalesceAsyncCopyBase<CoalesceAsyncCopyPass> {
public:
void runOnOperation() override {
ModuleOp m = getOperation();
MLIRContext *context = &getContext();

mlir::RewritePatternSet patterns(context);
patterns.add<ClipAsyncCopySizePerThread>(context);

if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns))))
signalPassFailure();
}
};

} // namespace gpu
} // namespace triton
} // namespace mlir
36 changes: 36 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1153,4 +1153,40 @@ void populateForOpDeadArgumentElimination(RewritePatternSet &patterns) {
patterns.add<ForOpDeadArgElimination>(patterns.getContext());
}

std::optional<LinearLayout>
getRegToSharedLayout(MLIRContext *ctx, ArrayRef<int64_t> shape,
Attribute srcEnc, Attribute dstEnc, int elemBitWidth) {
StringAttr kBlock = StringAttr::get(ctx, ("block"));
int rank = shape.size();

std::optional<LinearLayout> regLayout =
triton::gpu::toLinearLayout(shape, srcEnc);
std::optional<LinearLayout> sharedLayout =
triton::gpu::toLinearLayout(shape, dstEnc, elemBitWidth);
if (!regLayout.has_value() || !sharedLayout.has_value()) {
return std::nullopt;
}
auto sharedOrder = triton::gpu::getOrder(dstEnc);

// sharedLayout's in-dims are currently (offset, block). Reshape to
// (offsetX1, offsetX2, ..., block) so that we can apply the N-dimensional
// shmem strides. (The offsetX's appear in minor-to-major order.)
auto sharedLegacy = cast<triton::gpu::SharedEncodingAttr>(dstEnc);
SmallVector<std::pair<StringAttr, int32_t>> multiDimSharedSize;
for (int i = 0; i < rank; i++) {
int dim = sharedOrder[i];
int64_t size = std::max(
int64_t{1},
shape[dim] / sharedLegacy.getCTALayout().getCTASplitNum()[dim]);
multiDimSharedSize.push_back(
{StringAttr::get(ctx, ("offset" + std::to_string(dim))), size});
}
multiDimSharedSize.push_back({kBlock, sharedLayout->getInDimSize(kBlock)});
sharedLayout = sharedLayout->reshapeIns(multiDimSharedSize);

// regToSharedLayout maps from (register, lane, warp, block) to (offsetX1,
// ..., offsetXN, block), where the offsetX's are in minor-to-major order.
return regLayout->invertAndCompose(*sharedLayout);
}

} // namespace mlir
2 changes: 2 additions & 0 deletions python/src/passes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ void init_triton_passes_ttgpuir(py::module &&m) {
createTritonGPUOptimizeAccumulatorInit);
ADD_PASS_OPTION_WRAPPER_1("add_loop_scheduling",
createTritonGPULoopScheduling, int);
ADD_PASS_WRAPPER_0("add_coalesce_async_copy",
createTritonGPUCoalesceAsyncCopy);
}

void init_triton_passes_convert(py::module &&m) {
Expand Down
35 changes: 35 additions & 0 deletions test/TritonGPU/coalesce-async-copy.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// RUN: triton-opt %s -split-input-file -tritongpu-coalesce-async-copy | FileCheck %s

// CHECK: #[[NEW_BLOCKED:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
// CHECK: %{{.*}} = triton_gpu.convert_layout %{{.*}} : {{.*}} -> tensor<64x16x!tt.ptr<i8>, #[[NEW_BLOCKED]]>
// CHECK: %{{.*}} = triton_gpu.convert_layout %{{.*}} : {{.*}} -> tensor<64x16xi1, #[[NEW_BLOCKED]]>
// CHECK: %{{.*}} = triton_gpu.convert_layout %{{.*}} : {{.*}} -> tensor<64x16xi8, #[[NEW_BLOCKED]]>
// CHECK: %{{.*}} = triton_gpu.async_copy_global_to_local %{{.*}}: tensor<64x16x!tt.ptr<i8>, #[[NEW_BLOCKED]]>
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
tt.func @async_copy_i8(%input: tensor<64x16x!tt.ptr<i8>, #blocked>,
%view: !triton_gpu.memdesc<64x16xi8, #shared, #triton_gpu.shared_memory, mutable>,
%mask: tensor<64x16xi1, #blocked>,
%other: tensor<64x16xi8, #blocked>) {
%token = triton_gpu.async_copy_global_to_local %input, %view mask %mask other %other: tensor<64x16x!tt.ptr<i8>, #blocked> -> <64x16xi8, #shared, #triton_gpu.shared_memory, mutable>
tt.return
}
}

// -----

// CHECK: #[[NEW_BLOCKED:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
// CHECK: %{{.*}} = triton_gpu.convert_layout %{{.*}} : {{.*}} -> tensor<64x16x!tt.ptr<i8>, #[[NEW_BLOCKED]]>
// CHECK: %{{.*}} = triton_gpu.async_copy_global_to_local %{{.*}}: tensor<64x16x!tt.ptr<i8>, #[[NEW_BLOCKED]]>
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
tt.func @async_copy_i8_no_mask_or_other(%input: tensor<64x16x!tt.ptr<i8>, #blocked>,
%view: !triton_gpu.memdesc<64x16xi8, #shared, #triton_gpu.shared_memory, mutable>) {
%token = triton_gpu.async_copy_global_to_local %input, %view : tensor<64x16x!tt.ptr<i8>, #blocked> -> <64x16xi8, #shared, #triton_gpu.shared_memory, mutable>
tt.return
}
}
1 change: 1 addition & 0 deletions third_party/nvidia/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ def make_ttgir(mod, metadata, opt, capability):
passes.ttgpuir.add_pipeline(pm, opt.num_stages)
passes.ttgpuir.add_prefetch(pm)
passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
passes.ttgpuir.add_coalesce_async_copy(pm)
passes.ttgpuir.add_remove_layout_conversions(pm)
passes.ttgpuir.add_reduce_data_duplication(pm)
passes.ttgpuir.add_reorder_instructions(pm)
Expand Down