Layout conversion bypass for blocked to dotOperand #4538

Draft · wants to merge 2 commits into base: main
2 changes: 2 additions & 0 deletions include/triton/Analysis/Utility.h
@@ -203,6 +203,8 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy);

bool atomicNeedsSharedMemory(Value result);

bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);

bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
65 changes: 62 additions & 3 deletions lib/Analysis/Utility.cpp
@@ -536,6 +536,64 @@ bool supportMMA(Value value, int version) {
(elemTy.isInteger(8) && version >= 2);
}

bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
auto blockedLayout = dyn_cast<BlockedEncodingAttr>(srcTy.getEncoding());
auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
if (blockedLayout == nullptr || dotOperandLayout == nullptr)
return false;
auto parentLayout =
dyn_cast<BlockedEncodingAttr>(dotOperandLayout.getParent());
if (parentLayout == nullptr)
return false;
auto opShape = srcTy.getShape();
auto rank = opShape.size();

int kDim = dotOperandLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
int nonKDim = dotOperandLayout.getOpIdx() == 0 ? rank - 2 : rank - 1;
auto ctaLayout = blockedLayout.getCTALayout();

Collaborator:
One issue we have in the codebase is lots of mysterious layout/indexing--it's not easy for others reading the code to pick up the intent. The following might not be that tricky, but can we still add a comment explaining what these checks are doing at a high level?

Collaborator:
Thanks for adding the comments! But the wording is quite confusing to me right now. What about something like:

The following logic checks that a source blocked layout B matches a destination dot operand layout whose parent is a blocked layout P. It's considered a match if 1) each thread holds a whole copy of all elements along the K dimension of B, and 2) the distribution along all other non-K dimensions matches between P and B. This guarantees that each thread has all the data it needs for the reduction without exchanging data with other threads. (And/or whatever other reasons we want this kind of match for.)
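As a concrete instance of conditions (1) and (2), the sketch below reuses the layouts from the blocked_to_dot_op_shortcut_gfx940 test added later in this PR; only the function name and the annotations are mine, so treat it as an illustration rather than part of the diff. With opIdx = 0 the K dimension is the last one, each thread owns the full 32-element K slice, and the dot operand's parent is the very same #blocked layout, so the non-K distribution matches trivially.

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} {
  tt.func @illustrate_blocked_to_dot_shortcut(%arg0: tensor<32x32xf16, #blocked>) {
    // Condition 1: sizePerThread[1] = 32 equals the K extent of the 32x32 tensor (opIdx = 0, so K is the last dim),
    //              i.e. each thread already holds the whole K slice.
    // Condition 2: the dot operand's parent is this same #blocked layout, so order, sizePerThread,
    //              threadsPerWarp, and warpsPerCTA along the non-K dim all match.
    %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>
    tt.return
  }
}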

// The layout of a blocked dot operand matches its parent blocked layout except along the K
// dim: the whole vector of elements along K is held by a single thread.
// clang-format off
//
// i.e. tensor<64x32xf16, #dot_op<{opIdx=0, parent=#blocked}>> will have sizePerThread = [<depends on #blocked>, 32]
// and tensor<64x32xf16, #dot_op<{opIdx=1, parent=#blocked}>> will have sizePerThread = [64, <depends on #blocked>]
//
// For example tensor<64x32xf16, #dot_op<{opIdx=0, parent=#blocked<{sizePerThread = [2, 8], threadsPerWarp = [32, 1]}>>>
Collaborator:
This is going from dot operand to blocked layout? Isn't it the reverse of what we are doing in this function? I'm also not sure the distribution is correct. Doesn't this contradict the check at L571?

Contributor (author):

Yes, this is misleading; let me change this comment.

I meant that these dot and blocked layouts are equal; I should not have used "converted" here.

// could be converted to tensor<128x64xf16, #blocked<{sizePerThread = [2, 32], threadsPerWarp = [32, 1]}>>
//
// clang-format on
// The following conditions verify that the src layout holds all elements along K
// within one thread, and that the rest of the src and dst layouts match.
bool ctaLayoutCompatible =
ctaLayout.getCTASplitNum()[kDim] == 1 &&
blockedLayout.getCTALayout() == parentLayout.getCTALayout();
bool threadHoldsWholeKDim =
blockedLayout.getSizePerThread()[kDim] == opShape[kDim];
bool nonKDimCompatible =
blockedLayout.getOrder() == parentLayout.getOrder() &&
blockedLayout.getSizePerThread()[nonKDim] ==
parentLayout.getSizePerThread()[nonKDim] &&
blockedLayout.getThreadsPerWarp()[nonKDim] ==
parentLayout.getThreadsPerWarp()[nonKDim] &&
blockedLayout.getWarpsPerCTA()[nonKDim] ==
parentLayout.getWarpsPerCTA()[nonKDim];
bool matrixDimsCompatible =
ctaLayoutCompatible && threadHoldsWholeKDim && nonKDimCompatible;
if (rank == 2)
return matrixDimsCompatible;

// additional check for batch dimension if it is present
assert(rank == 3);
bool bDimCompatible =
blockedLayout.getSizePerThread()[0] ==
parentLayout.getSizePerThread()[0] &&
blockedLayout.getThreadsPerWarp()[0] ==
parentLayout.getThreadsPerWarp()[0] &&
blockedLayout.getWarpsPerCTA()[0] == parentLayout.getWarpsPerCTA()[0];
return matrixDimsCompatible && bDimCompatible;
}

bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcTy.getEncoding());
auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
@@ -624,12 +682,13 @@ bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
}

bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
// TODO(jlebar): Remove these special cases (`isMmaToDotShortcut` and
// `isMfmaToDotShortcut`) once they're fully subsumed by the linear-layout
// checks.
// TODO(jlebar): Remove these special cases (`isMmaToDotShortcut`,
// `isBlockedToDotShortcut` and `isMfmaToDotShortcut`) once they're fully
// subsumed by the linear-layout checks.
// TODO(Keren): We didn't check `cvtNeedsWarpShuffle` here because it's not
// supported yet in Triton's backend.
return !cvtReordersRegisters(srcTy, dstTy) &&
!isBlockedToDotShortcut(srcTy, dstTy) &&
!isMmaToDotShortcut(srcTy, dstTy) &&
!isMfmaToDotShortcut(srcTy, dstTy);
}
32 changes: 32 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -240,6 +240,36 @@ struct ConvertLayoutOpConversion
const TargetInfoBase &targetInfo;
};

struct ConvertLayoutOpBlockedToDotOpShortcutConversion
: public ConvertOpToLLVMPattern<ConvertLayoutOp> {
const TargetInfoBase &targetInfo;
explicit ConvertLayoutOpBlockedToDotOpShortcutConversion(
LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
}

LogicalResult
matchAndRewrite(ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MLIRContext *ctx = op.getContext();

const auto &shape = op.getType().getShape();
auto srcTy = op.getSrc().getType();
auto dstTy = op.getType();
auto dstDotEncoding = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
if (!dstDotEncoding)
return failure();
if (!isa<BlockedEncodingAttr>(srcTy.getEncoding()) ||
!isa<BlockedEncodingAttr>(dstDotEncoding.getParent()))
return failure();
if (cvtNeedsSharedMemory(srcTy, dstTy))
return failure();
rewriter.replaceOp(op, adaptor.getSrc());
return success();
}
};

struct ConvertLayoutOpUsingLinearLayoutsConversion
: public ConvertOpToLLVMPattern<ConvertLayoutOp> {
const TargetInfoBase &targetInfo;
@@ -666,5 +696,7 @@ void mlir::triton::populateConvertLayoutOpToLLVMPatterns(
// one left.
mlir::triton::populateConvertLayoutOpUsingLinearLayoutsToLLVMPattern(
typeConverter, targetInfo, patterns, benefit.getBenefit() + 1);
patterns.add<ConvertLayoutOpBlockedToDotOpShortcutConversion>(
typeConverter, targetInfo, benefit);
patterns.add<ConvertLayoutOpConversion>(typeConverter, targetInfo, benefit);
}
@@ -83,6 +83,8 @@ void decomposeBlockedToDotLayoutConversion(ModuleOp module) {
OpBuilder builder(cvtOp);
auto srcType = cast<RankedTensorType>(cvtOp.getSrc().getType());
auto dstType = cast<RankedTensorType>(cvtOp.getType());
if (!cvtNeedsSharedMemory(srcType, dstType))
return;
auto srcBlocked =
dyn_cast<triton::gpu::BlockedEncodingAttr>(srcType.getEncoding());
auto dstDotOp =
18 changes: 2 additions & 16 deletions lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp
@@ -42,22 +42,8 @@ class TritonGPUReduceDataDuplicationPass
dyn_cast<triton::gpu::DotOperandEncodingAttr>(dstType.getEncoding());
if (!dstDotOp)
return;
if (auto srcMmaEncoding =
dyn_cast<triton::gpu::NvidiaMmaEncodingAttr>(srcEncoding)) {

if (srcMmaEncoding.getVersionMajor() != 2 ||
(srcMmaEncoding.getWarpsPerCTA()[1] == 1 &&
dstDotOp.getParent() == srcMmaEncoding))
return;
}
if (auto srcMfmaEncoding =
dyn_cast<triton::gpu::AMDMfmaEncodingAttr>(srcEncoding)) {

if (srcMfmaEncoding.getWarpsPerCTA()[1] == 1 &&
srcMfmaEncoding.getIsTransposed() &&
dstDotOp.getParent() == srcMfmaEncoding)
return;
}
if (!cvtNeedsSharedMemory(srcType, dstType))
return;
auto srcOrder = triton::gpu::getOrder(srcEncoding);
auto rank = srcOrder.size();
SmallVector<unsigned> sharedOrder;
104 changes: 88 additions & 16 deletions test/Conversion/amd/decompose-unsupported-conversions.mlir
@@ -1,33 +1,105 @@
// RUN: triton-opt %s --split-input-file --decompose-unsupported-amd-conversions=arch=gfx1130 | FileCheck %s
// RUN: triton-opt %s --split-input-file --decompose-unsupported-amd-conversions | FileCheck %s

// CHECK: #[[BLOCKED:.+]] = #triton_gpu.blocked<{{.*}}>
// CHECK: #[[WMMA:.+]] = #triton_gpu.amd_wmma<{{.*}}>
// CHECK: #[[SHARED:.+]] = #triton_gpu.shared<{{.*}}>
// CHECK: wmma_to_wmma_dot_op
// CHECK: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{{.*}}>
// CHECK: #[[$WMMA:.+]] = #triton_gpu.amd_wmma<{{.*}}>
// CHECK: #[[$SHARED:.+]] = #triton_gpu.shared<{{.*}}>
// CHECK-LABEL: wmma_to_wmma_dot_op
#mma = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx1130", "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func @wmma_to_wmma_dot_op(%arg0: tensor<16x16xf16, #mma>) {
// CHECK: %[[SRC_BLOCKED:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<16x16xf16, #[[WMMA]]> -> tensor<16x16xf16, #[[BLOCKED]]>
// CHECK-NEXT: %[[INT_SHARED:.+]] = triton_gpu.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !tt.memdesc<16x16xf16, #[[SHARED]], #triton_gpu.shared_memory>
// CHECK-NEXT: %[[DST_DOT_OP:.+]] = triton_gpu.local_load %[[INT_SHARED]] : {{.*}} -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA]], kWidth = 16}>>
// CHECK: %[[SRC_BLOCKED:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<16x16xf16, #[[$WMMA]]> -> tensor<16x16xf16, #[[$BLOCKED]]>
// CHECK-NEXT: %[[INT_SHARED:.+]] = triton_gpu.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !tt.memdesc<16x16xf16, #[[$SHARED]], #triton_gpu.shared_memory>
// CHECK-NEXT: %[[DST_DOT_OP:.+]] = triton_gpu.local_load %[[INT_SHARED]] : {{.*}} -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$WMMA]], kWidth = 16}>>
%0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>
tt.return
}
}

// -----

// CHECK: #[[BLOCKED:.+]] = #triton_gpu.blocked<{{.*}}>
// CHECK: #[[WMMA:.+]] = #triton_gpu.amd_wmma<{{.*}}>
// CHECK: #[[SHARED:.+]] = #triton_gpu.shared<{{.*}}>
// CHECK: wmma_to_wmma_dot3d_op
// CHECK: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{{.*}}>
// CHECK: #[[$WMMA:.+]] = #triton_gpu.amd_wmma<{{.*}}>
// CHECK: #[[$SHARED:.+]] = #triton_gpu.shared<{{.*}}>
// CHECK-LABEL: wmma_to_wmma_dot3d_op
#mma = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 2, 2]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func @wmma_to_wmma_dot3d_op(%arg0: tensor<2x16x16xf16, #mma>) {
// CHECK: %[[SRC_BLOCKED:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<2x16x16xf16, #[[WMMA]]> -> tensor<2x16x16xf16, #[[BLOCKED]]>
// CHECK-NEXT: %[[INT_SHARED:.+]] = triton_gpu.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !tt.memdesc<2x16x16xf16, #[[SHARED]], #triton_gpu.shared_memory>
// CHECK-NEXT: %[[DST_DOT_OP:.+]] = triton_gpu.local_load %[[INT_SHARED]] : {{.*}} -> tensor<2x16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA]], kWidth = 16}>>
// CHECK: %[[SRC_BLOCKED:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<2x16x16xf16, #[[$WMMA]]> -> tensor<2x16x16xf16, #[[$BLOCKED]]>
// CHECK-NEXT: %[[INT_SHARED:.+]] = triton_gpu.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !tt.memdesc<2x16x16xf16, #[[$SHARED]], #triton_gpu.shared_memory>
// CHECK-NEXT: %[[DST_DOT_OP:.+]] = triton_gpu.local_load %[[INT_SHARED]] : {{.*}} -> tensor<2x16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$WMMA]], kWidth = 16}>>
%0 = triton_gpu.convert_layout %arg0 : tensor<2x16x16xf16, #mma> -> tensor<2x16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>
tt.return
}
}

// -----

// CHECK-LABEL: blocked_to_dot_op_shortcut_gfx1130
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [1, 0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx1130", "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func @blocked_to_dot_op_shortcut_gfx1130(%arg0: tensor<32x32xf16, #blocked>) {
// CHECK-NOT: triton_gpu.local_alloc
// CHECK: triton_gpu.convert_layout
// CHECK-NOT: triton_gpu.local_alloc
%0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>
tt.return
}
}

// -----

// CHECK-LABEL: blocked_to_dot_op_shortcut_gfx940
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func @blocked_to_dot_op_shortcut_gfx940(%arg0: tensor<32x32xf16, #blocked>) {
// CHECK-NOT: triton_gpu.local_alloc
// CHECK: triton_gpu.convert_layout
// CHECK-NOT: triton_gpu.local_alloc
%0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>
tt.return
}
}

// -----

// CHECK-LABEL: neg_blocked_to_dot_op_incompatible_elems_gfx940
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func @neg_blocked_to_dot_op_incompatible_elems_gfx940(%arg0: tensor<32x32xf16, #blocked>) {
// CHECK-NOT: triton_gpu.convert_layout
// CHECK: triton_gpu.local_alloc
// CHECK: triton_gpu.local_load
%0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>
tt.return
}
}

// -----

// CHECK-LABEL: neg_blocked_to_dot_op_incompatible_threads_gfx940
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [16, 4], warpsPerCTA = [2, 2], order = [1, 0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func @neg_blocked_to_dot_op_incompatible_threads_gfx940(%arg0: tensor<32x32xf16, #blocked>) {
// CHECK-NOT: triton_gpu.convert_layout
// CHECK: triton_gpu.local_alloc
// CHECK: triton_gpu.local_load
%0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>>
tt.return
}
}

// -----

// CHECK-LABEL: neg_blocked_to_dot_op_incompatible_warp_gfx940
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func @neg_blocked_to_dot_op_incompatible_warp_gfx940(%arg0: tensor<32x32xf16, #blocked>) {
// CHECK-NOT: triton_gpu.convert_layout
// CHECK: triton_gpu.local_alloc
// CHECK: triton_gpu.local_load
%0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>>
tt.return
}
}
47 changes: 47 additions & 0 deletions test/Conversion/tritongpu_to_llvm_block_dot_shortcut.mlir
@@ -0,0 +1,47 @@
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm | FileCheck %s

// CHECK-LABEL: blocked_to_dot_op_shortcut_warp32
#blocked = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [0, 1]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func @blocked_to_dot_op_shortcut_warp32(%arg0: tensor<32x32xf16, #blocked>, %arg1: tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) {
%0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>
// CHECK-NOT: load
tt.return
}
}

// -----

// CHECK-LABEL: blocked_to_dot_op_shortcut_warp64
#blocked = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [2, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func @blocked_to_dot_op_shortcut_warp64(%arg0: tensor<32x32xf16, #blocked>) {
%0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>
// CHECK-NOT: load
tt.return
}
}

// -----

// CHECK-LABEL: blocked_to_dot3d_op_shortcut_warp32
#blocked = #triton_gpu.blocked<{sizePerThread = [2, 32, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [2, 1, 2], order = [1, 2, 0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func @blocked_to_dot3d_op_shortcut_warp32(%arg0: tensor<8x32x32xf16, #blocked>) {
%0 = triton_gpu.convert_layout %arg0 : tensor<8x32x32xf16, #blocked> -> tensor<8x32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>
// CHECK-NOT: load
tt.return
}
}

// -----

// CHECK-LABEL: blocked_to_dot3d_op_shortcut_warp64
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32, 1], threadsPerWarp = [1, 2, 32], warpsPerCTA = [2, 2, 1], order = [2, 1, 0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func @blocked_to_dot3d_op_shortcut_warp64(%arg0: tensor<8x32x32xf16, #blocked>) {
%0 = triton_gpu.convert_layout %arg0 : tensor<8x32x32xf16, #blocked> -> tensor<8x32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>
// CHECK-NOT: load
tt.return
}
}