Merged
1 change: 1 addition & 0 deletions include/triton/Analysis/Allocation.h
@@ -21,6 +21,7 @@ class AllocationAnalysis;
SmallVector<unsigned>
getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
unsigned &outVec);
SmallVector<unsigned> getRepShapeForCvtLayout(triton::gpu::ConvertLayoutOp op);

} // namespace triton

52 changes: 33 additions & 19 deletions lib/Analysis/Allocation.cpp
@@ -18,6 +18,7 @@ using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getShapePerCTATile;
using ::mlir::triton::gpu::getSizePerThread;
using ::mlir::triton::gpu::getUniqueContigPerThread;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;
@@ -50,9 +51,7 @@ getCvtOrder(Attribute srcLayout, Attribute dstLayout) {
return {inOrd, outOrd};
}

SmallVector<unsigned>
getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
unsigned &outVec) {
SmallVector<unsigned> getRepShapeForCvtLayout(triton::gpu::ConvertLayoutOp op) {
auto srcTy = op.getSrc().getType().cast<RankedTensorType>();
auto dstTy = op.getResult().getType().cast<RankedTensorType>();
Attribute srcLayout = srcTy.getEncoding();
@@ -76,37 +75,52 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
}
}

assert(srcLayout && dstLayout &&
"Unexpected layout in getScratchConfigForCvtLayout()");
auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout);
unsigned srcContigPerThread = getContigPerThread(srcLayout)[inOrd[0]];
unsigned dstContigPerThread = getContigPerThread(dstLayout)[outOrd[0]];
// TODO: Fix the legacy issue that outOrd[0] == 0 always means
// that we cannot do vectorization.
inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
outVec = outOrd[0] == 0 ? 1 : dstContigPerThread;
assert(srcLayout && dstLayout && "Unexpected layout in getRepShape()");

auto srcShapePerCTA = getShapePerCTA(srcTy);
auto dstShapePerCTA = getShapePerCTA(dstTy);
auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape());
auto dstShapePerCTATile = getShapePerCTATile(dstLayout, dstTy.getShape());

unsigned rank = dstTy.getRank();
SmallVector<unsigned> paddedRepShape(rank);
unsigned pad = std::max(inVec, outVec);
SmallVector<unsigned> repShape(rank);
for (unsigned d = 0; d < rank; ++d) {
paddedRepShape[d] =
repShape[d] =
std::max(std::min<unsigned>(srcShapePerCTA[d], srcShapePerCTATile[d]),
std::min<unsigned>(dstShapePerCTA[d], dstShapePerCTATile[d]));
}
if (rank == 1)
return paddedRepShape;
return repShape;
}
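
The rep shape computed above is the per-repetition tile the conversion stages through shared memory: for each dimension, the source and destination footprints are each clamped to the tensor's actual per-CTA extent, and the larger of the two wins. A minimal standalone sketch of that elementwise computation — the function name `repShapeForCvt` and all input values are made up for illustration, not Triton API calls:

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

// Standalone model of getRepShapeForCvtLayout's loop: clamp the src/dst
// layout tile to the tensor extent per CTA, then take the elementwise max.
std::vector<unsigned>
repShapeForCvt(const std::vector<unsigned> &srcShapePerCTA,
               const std::vector<unsigned> &srcShapePerCTATile,
               const std::vector<unsigned> &dstShapePerCTA,
               const std::vector<unsigned> &dstShapePerCTATile) {
  std::vector<unsigned> repShape(srcShapePerCTA.size());
  for (size_t d = 0; d < repShape.size(); ++d)
    repShape[d] = std::max(std::min(srcShapePerCTA[d], srcShapePerCTATile[d]),
                           std::min(dstShapePerCTA[d], dstShapePerCTATile[d]));
  return repShape;
}

int main() {
  // Hypothetical 64x1 tensor: the 1-wide dimension clamps both tiles, which
  // is exactly the situation this PR targets (layout tiles wider than the
  // tensor used to yield offsets past the tensor edge downstream).
  auto rep = repShapeForCvt({64, 1}, {32, 32}, {64, 1}, {16, 64});
  std::printf("repShape = %u x %u\n", rep[0], rep[1]); // 32 x 1
  return 0;
}
```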

SmallVector<unsigned>
getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
unsigned &outVec) {
auto repShape = getRepShapeForCvtLayout(op);

auto srcTy = op.getSrc().getType().cast<RankedTensorType>();
auto dstTy = op.getResult().getType().cast<RankedTensorType>();
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();

auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout);
unsigned srcContigPerThread =
getUniqueContigPerThread(srcLayout, srcTy.getShape())[inOrd[0]];
unsigned dstContigPerThread =
getUniqueContigPerThread(dstLayout, dstTy.getShape())[outOrd[0]];
// TODO: Fix the legacy issue that outOrd[0] == 0 always means
// that we cannot do vectorization.
inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
outVec = outOrd[0] == 0 ? 1 : dstContigPerThread;

if (repShape.size() <= 1)
return repShape;
unsigned paddedDim = 1;
if (auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>()) {
paddedDim = dstBlockedLayout.getOrder()[0];
}
paddedRepShape[paddedDim] += pad;
return paddedRepShape;
unsigned pad = std::max(inVec, outVec);
repShape[paddedDim] += pad;
return repShape;
}
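
After vectorization widths are chosen, the scratch config pads the destination's fastest-varying dimension by `max(inVec, outVec)` elements. A sketch of just that step under the same made-up inputs as above — the bank-conflict rationale in the comment is the usual motivation for padded shared-memory tiles, an assumption rather than something this diff states:

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

// Model of the padding step in getScratchConfigForCvtLayout: widen the
// destination's fastest dimension by max(inVec, outVec) so that consecutive
// rows start in different shared-memory banks.
int main() {
  std::vector<unsigned> repShape = {32, 1};
  unsigned inVec = 4, outVec = 1;
  unsigned paddedDim = 0; // dstLayout's order[0] in this made-up scenario
  if (repShape.size() > 1)
    repShape[paddedDim] += std::max(inVec, outVec);
  std::printf("paddedRepShape = %u x %u\n", repShape[0], repShape[1]); // 36 x 1
  return 0;
}
```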

SmallVector<unsigned>
31 changes: 27 additions & 4 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -237,12 +237,30 @@ struct ConvertLayoutOpConversion
llvm_unreachable("unexpected layout in getMultiDimOffset");
}

SmallVector<Value>
getWrappedMultiDimOffset(ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> multiDimOffset,
ArrayRef<unsigned> shape,
SmallVector<unsigned> shapePerCTATile,
SmallVector<int64_t> shapePerCTA) const {
unsigned rank = shape.size();
SmallVector<Value> multiDimOffsetWrapped(rank);
for (unsigned d = 0; d < rank; ++d) {
if (shapePerCTATile[d] > shapePerCTA[d])
multiDimOffsetWrapped[d] = urem(multiDimOffset[d], i32_val(shape[d]));
else
multiDimOffsetWrapped[d] = multiDimOffset[d];
}
return multiDimOffsetWrapped;
}
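
The new helper wraps a coordinate only along dimensions where the layout tile overhangs the tensor (`shapePerCTATile[d] > shapePerCTA[d]`), i.e. where threads logically map past the tensor edge and must alias back into the real extent. A scalar sketch of the same wrap, with plain integers standing in for the LLVM `urem` values (the function and inputs are illustrative):

```cpp
#include <cstdio>

// Scalar model of getWrappedMultiDimOffset: reduce an offset modulo the
// tensor shape only when the CTA tile exceeds the tensor along that dim.
unsigned wrapOffset(unsigned offset, unsigned shape, unsigned shapePerCTATile,
                    unsigned shapePerCTA) {
  return (shapePerCTATile > shapePerCTA) ? offset % shape : offset;
}

int main() {
  // Hypothetical N=1 dimension under a 32-wide tile: every lane's column
  // offset wraps to 0 instead of indexing out of bounds.
  std::printf("%u\n", wrapOffset(17, /*shape=*/1, /*tile=*/32, /*cta=*/1)); // 0
  // Dimension fully covered by the tensor: the offset passes through.
  std::printf("%u\n", wrapOffset(17, /*shape=*/64, /*tile=*/32, /*cta=*/64)); // 17
  return 0;
}
```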

// shared memory rd/st for blocked or mma layout with data padding
void processReplica(Location loc, ConversionPatternRewriter &rewriter,
bool stNotRd, RankedTensorType type,
ArrayRef<unsigned> numCTAsEachRep,
ArrayRef<unsigned> multiDimRepId, unsigned vec,
ArrayRef<unsigned> paddedRepShape,
ArrayRef<unsigned> origRepShape,
ArrayRef<unsigned> outOrd, SmallVector<Value> &vals,
Value smemBase) const {
auto accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep);
@@ -286,8 +304,11 @@ struct ConvertLayoutOpConversion
SmallVector<Value> multiDimOffset =
getMultiDimOffset(layout, loc, rewriter, elemId, type,
multiDimCTAInRepId, shapePerCTATile);
Value offset =
linearize(rewriter, loc, multiDimOffset, paddedRepShape, outOrd);
SmallVector<Value> multiDimOffsetWrapped = getWrappedMultiDimOffset(
rewriter, loc, multiDimOffset, origRepShape, shapePerCTATile,
shapePerCTA);
Value offset = linearize(rewriter, loc, multiDimOffsetWrapped,
paddedRepShape, outOrd);
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
Value ptr = gep(elemPtrTy, smemBase, offset);
auto vecTy = vec_ty(llvmElemTy, vec);
@@ -575,6 +596,7 @@ struct ConvertLayoutOpConversion
rewriter, srcTy);
unsigned inVec = 0;
unsigned outVec = 0;
auto origRepShape = getRepShapeForCvtLayout(op);
auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec);
if (getElementTypeOrSelf(op.getType())
.isa<mlir::Float8E4M3B11FNUZType, mlir::Float8E4M3FNType>()) {
@@ -618,7 +640,7 @@ struct ConvertLayoutOpConversion
else
processReplica(loc, rewriter, /*stNotRd*/ true, srcTy,
inNumCTAsEachRep, multiDimRepId, inVec, paddedRepShape,
outOrd, vals, smemBase);
origRepShape, outOrd, vals, smemBase);
} else {
assert(0 && "ConvertLayout with input layout not implemented");
return failure();
@@ -651,7 +673,8 @@ struct ConvertLayoutOpConversion
else
processReplica(loc, rewriter, /*stNotRd*/ false, dstTy,
outNumCTAsEachRep, multiDimRepId, outVec,
paddedRepShape, outOrd, outVals, smemBase);
paddedRepShape, origRepShape, outOrd, outVals,
smemBase);
} else {
assert(0 && "ConvertLayout with output layout not implemented");
return failure();
18 changes: 10 additions & 8 deletions lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
@@ -339,6 +339,8 @@ class ConvertTritonGPUOpToLLVMPatternBase {
// Order
auto inOrder = triton::gpu::getOrder(srcEncoding);
auto outOrder = triton::gpu::getOrder(resSharedLayout);
assert(outVec * (maxPhase - 1) <= srcShape[outOrder[0]] &&
       "Swizzling would generate out of bounds memory accesses");

Review thread on the assert above:

Contributor: How is this formula derived?

Contributor: Seems to me outVec * maxPhase makes more sense? Just curious.

Contributor (author): There are only phases 0, 1, ..., maxPhase - 1; if each one increments the address by vec, then the largest address is vec * (maxPhase - 1).

Contributor: The largest address would be vec * maxPhase - 1? I meant the largest starting address is vec * (maxPhase - 1), but we actually access addresses up to vec * maxPhase.

Contributor (author): That's right. I was trying to get away from cases where outVec > 1, maxPhase = 1, and srcShape[outOrder[0]] = 1. In that case we only use the address range [0, 1] regardless of what outVec is, since we don't really swizzle. Maybe that case should also be illegal, even though in practice the code just works since we don't swizzle.

Contributor (author): So the condition should be maxPhase == 1 || vec * maxPhase <= srcShape[outOrder[0]], what do you think?

Contributor: I see. LGTM.
// Tensor indices held by the current thread, as LLVM values
auto srcIndices = emitIndices(loc, rewriter, srcEncoding, srcTy, false);
// Swizzling with leading offsets (e.g. Hopper GMMA)
@@ -452,10 +454,10 @@ class ConvertTritonGPUOpToLLVMPatternBase {
auto dstElemTy = dstTy.getElementType();
auto inOrd = triton::gpu::getOrder(srcSharedLayout);
auto outOrd = triton::gpu::getOrder(dstDistributedLayout);
unsigned outVec =
inOrd == outOrd
? triton::gpu::getContigPerThread(dstDistributedLayout)[outOrd[0]]
: 1;
unsigned outVec = inOrd == outOrd
? triton::gpu::getUniqueContigPerThread(
dstDistributedLayout, dstShape)[outOrd[0]]
: 1;
unsigned inVec = srcSharedLayout.getVec();
unsigned minVec = std::min(outVec, inVec);
unsigned outElems = triton::gpu::getTotalElemsPerThread(dstTy);
@@ -501,10 +503,10 @@ class ConvertTritonGPUOpToLLVMPatternBase {
auto dstElemTy = dstTy.getElementType();
auto inOrd = triton::gpu::getOrder(srcDistributedLayout);
auto outOrd = dstSharedLayout.getOrder();
unsigned inVec =
inOrd == outOrd
? triton::gpu::getContigPerThread(srcDistributedLayout)[inOrd[0]]
: 1;
unsigned inVec = inOrd == outOrd
? triton::gpu::getUniqueContigPerThread(
srcDistributedLayout, srcShape)[inOrd[0]]
: 1;
unsigned outVec = dstSharedLayout.getVec();
unsigned minVec = std::min(outVec, inVec);
unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy);
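
The swap from getContigPerThread to getUniqueContigPerThread in both hunks matters when the tensor is narrower than the layout: a thread may nominally own a contiguous run longer than the dimension itself, and vectorizing by the nominal run would read or write out of bounds. A deliberately simplified model of the intuition — the real Triton helper also handles slice layouts and CTA tiling, so the bare clamp below is an assumption, not its actual implementation:

```cpp
#include <algorithm>
#include <cstdio>

// Simplified intuition for getUniqueContigPerThread: the usable contiguous
// run per thread cannot exceed the tensor extent along that dimension.
// (Assumption: the real helper is more involved; this is only the core idea.)
unsigned uniqueContig(unsigned contigPerThread, unsigned dimSize) {
  return std::min(contigPerThread, dimSize);
}

int main() {
  // Hypothetical layout owning 8 contiguous elements per thread, but the
  // tensor's fastest dimension is only 1 wide (the new M=1 / N=1 test cases):
  std::printf("vec = %u\n", uniqueContig(8, 1)); // 1, not 8
  return 0;
}
```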
68 changes: 35 additions & 33 deletions python/test/unit/language/test_core.py
@@ -3607,6 +3607,7 @@ def kernel(Out):
# MmaLayout((2, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]),
# MmaLayout(1, [4, 1], [1, 1], [0, 1]),
# MmaLayout((2, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 8]),
BlockedLayout([1, 16], [8, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
BlockedLayout([1, 8], [2, 16], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
BlockedLayout([1, 4], [4, 8], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
BlockedLayout([1, 1], [1, 32], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
@@ -3624,15 +3625,16 @@ def kernel(Out):
]


@pytest.mark.parametrize("shape", [(128, 128)])
@pytest.mark.parametrize("M, N", [[64, 1], [64, 64], [128, 128], [1, 64]])
@pytest.mark.parametrize("dtype", ['float16'])
@pytest.mark.parametrize("src_layout", layouts)
@pytest.mark.parametrize("interm_layout", intermediate_layouts)
@pytest.mark.parametrize("dst_layout", layouts)
def test_convert2d(dtype, shape, src_layout, interm_layout, dst_layout, device):
def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device):
if is_hip():
pytest.skip("test_convert2d is not supported in HIP")

if (M == 1 or N == 1) and interm_layout:
pytest.skip("Out of bound access when maxPhase > 1")
if str(src_layout) == str(dst_layout):
pytest.skip()
if 'mma' in str(src_layout) and 'mma' in str(dst_layout):
@@ -3648,43 +3650,43 @@ def test_convert2d(dtype, shape, src_layout, interm_layout, dst_layout, device):
"""

conversion = f"""
%12 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst>
%13 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst>
%12 = triton_gpu.convert_layout %9 : (tensor<{M}x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #dst>
%13 = triton_gpu.convert_layout %11 : (tensor<{M}x{N}xf16, #src>) -> tensor<{M}x{N}xf16, #dst>
""" if interm_layout is None else f"""
%15 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #interm>
%16 = triton_gpu.convert_layout %15 : (tensor<128x128xi32, #interm>) -> tensor<128x128xi32, #src>
%17 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #interm>
%18 = triton_gpu.convert_layout %17 : (tensor<128x128xf16, #interm>) -> tensor<128x128xf16, #src>
%15 = triton_gpu.convert_layout %9 : (tensor<{M}x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #interm>
%16 = triton_gpu.convert_layout %15 : (tensor<{M}x{N}xi32, #interm>) -> tensor<{M}x{N}xi32, #src>
%17 = triton_gpu.convert_layout %11 : (tensor<{M}x{N}xf16, #src>) -> tensor<{M}x{N}xf16, #interm>
%18 = triton_gpu.convert_layout %17 : (tensor<{M}x{N}xf16, #interm>) -> tensor<{M}x{N}xf16, #src>

%12 = triton_gpu.convert_layout %16 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst>
%13 = triton_gpu.convert_layout %18 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst>
%12 = triton_gpu.convert_layout %16 : (tensor<{M}x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #dst>
%13 = triton_gpu.convert_layout %18 : (tensor<{M}x{N}xf16, #src>) -> tensor<{M}x{N}xf16, #dst>
"""

ir = layouts + """
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @kernel_0d1d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
%cst = arith.constant dense<128> : tensor<128x1xi32, #src>
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>
%1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #src}>>
%2 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>, #src>
%4 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>) -> tensor<128x1xi32, #src>
%5 = arith.muli %4, %cst : tensor<128x1xi32, #src>
%6 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #src}>>) -> tensor<1x128xi32, #src>
%7 = tt.broadcast %6 : (tensor<1x128xi32, #src>) -> tensor<128x128xi32, #src>
%8 = tt.broadcast %5 : (tensor<128x1xi32, #src>) -> tensor<128x128xi32, #src>
%9 = arith.addi %8, %7 : tensor<128x128xi32, #src>
%10 = tt.addptr %2, %9 : tensor<128x128x!tt.ptr<f16>, #src>, tensor<128x128xi32, #src>
%11 = tt.load %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #src>
%3 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>, #dst>
""" + conversion + """
%14 = tt.addptr %3, %12 : tensor<128x128x!tt.ptr<f16>, #dst>, tensor<128x128xi32, #dst>
tt.store %14, %13 : tensor<128x128xf16, #dst>
ir = layouts + f"""
module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{
tt.func public @kernel_0d1d(%arg0: !tt.ptr<f16> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<f16> {{tt.divisibility = 16 : i32}}) {{
%cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src>
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>
%2 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<{M}x{N}x!tt.ptr<f16>, #src>
%4 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
%5 = arith.muli %4, %cst : tensor<{M}x1xi32, #src>
%6 = tt.expand_dims %1 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{N}xi32, #src>
%7 = tt.broadcast %6 : (tensor<1x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #src>
%8 = tt.broadcast %5 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src>
%9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #src>
%10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr<f16>, #src>, tensor<{M}x{N}xi32, #src>
%11 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xf16, #src>
%3 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<{M}x{N}x!tt.ptr<f16>, #dst>
""" + conversion + f"""
%14 = tt.addptr %3, %12 : tensor<{M}x{N}x!tt.ptr<f16>, #dst>, tensor<{M}x{N}xi32, #dst>
tt.store %14, %13 : tensor<{M}x{N}xf16, #dst>
tt.return
}
}
}}
}}
"""

x = to_triton(numpy_random(shape, dtype_str=dtype), device=device)
x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device)
z = torch.empty_like(x)

# write the IR to a temporary file using mkstemp