Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions test/TritonGPU/loop-pipeline-hip.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -263,3 +263,46 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
tt.return
}
}

// -----

// Check that both tensors are pipelined, but only 1 (B) goes to shared memory. This is restricted
// due to small contiguous range (< 32 bits).
// CHECK-LABEL: matmul_reg_buffer
// CHECK: triton_gpu.local_alloc
// CHECK-NOT: triton_gpu.local_alloc
// CHECK-COUNT-2: tt.load
// CHECK: scf.for
// Layout aliases shared by the test function below.
// #blocked is the layout of the 1x128 pointer/offset tensors (the operand later converted to dot operand 1).
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
// #blocked1 is the layout of the 128x1 pointer/offset tensors (the operand later converted to dot operand 0).
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
// #mma is the AMD MFMA layout (version 3.0, 16x16 instruction shape, transposed) of the 128x128 f32 accumulator.
#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
// Test input: a loop that accumulates tt.dot of a 128x1 f16 operand (read through %arg3) and a
// 1x128 f16 operand (read through %arg4) into a 128x128 f32 result in the #mma layout.
// NOTE(review): per the test header comment, only the 1x128 operand (B) is expected to be staged
// through shared memory; presumably the 128x1 operand's contiguous range is too small -- confirm
// against the pipeliner's width computation.
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// %arg0 / %arg1 / %arg2 are the loop lower bound, upper bound and step; %arg3 / %arg4 are the f16 base pointers.
tt.func @matmul_reg_buffer(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #mma> {
// Constants: the 4.0 scale applied to dot operand 1, the per-iteration pointer offsets (4) for each
// operand, and the zero-initialized accumulator.
%cst = arith.constant dense<4.000000e+00> : tensor<1x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
%cst_0 = arith.constant dense<4> : tensor<1x128xi32, #blocked>
%cst_1 = arith.constant dense<4> : tensor<128x1xi32, #blocked1>
%cst_2 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
// Build the initial 128x1 pointer tensor for operand 0: splat %arg3 and add a broadcast range [0, 1).
%0 = tt.splat %arg3 : !tt.ptr<f16> -> tensor<128x1x!tt.ptr<f16>, #blocked1>
%1 = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<1xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x1xi32, #blocked1>
%3 = tt.broadcast %2 : tensor<1x1xi32, #blocked1> -> tensor<128x1xi32, #blocked1>
%4 = tt.addptr %0, %3 : tensor<128x1x!tt.ptr<f16>, #blocked1>, tensor<128x1xi32, #blocked1>
// Build the initial 1x128 pointer tensor for operand 1: splat %arg4 and add the range [0, 128).
%5 = tt.splat %arg4 : !tt.ptr<f16> -> tensor<1x128x!tt.ptr<f16>, #blocked>
%6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked>
%8 = tt.broadcast %7 : tensor<1x128xi32, #blocked> -> tensor<1x128xi32, #blocked>
%9 = tt.addptr %5, %8 : tensor<1x128x!tt.ptr<f16>, #blocked>, tensor<1x128xi32, #blocked>
// Main loop; the iteration arguments carry both pointer tensors and the running accumulator.
%10:3 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %4, %arg7 = %9, %arg8 = %cst_2) -> (tensor<128x1x!tt.ptr<f16>, #blocked1>, tensor<1x128x!tt.ptr<f16>, #blocked>, tensor<128x128xf32, #mma>) {
// Load operand 0 and convert it to the dot-operand layout (opIdx = 0).
%11 = tt.load %arg6 : tensor<128x1x!tt.ptr<f16>, #blocked1>
%12 = triton_gpu.convert_layout %11 : tensor<128x1xf16, #blocked1> -> tensor<128x1xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
// Load operand 1, convert it to the dot-operand layout (opIdx = 1), and scale it by %cst before the dot.
%13 = tt.load %arg7 : tensor<1x128x!tt.ptr<f16>, #blocked>
%14 = triton_gpu.convert_layout %13 : tensor<1x128xf16, #blocked> -> tensor<1x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
%15 = arith.mulf %14, %cst : tensor<1x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
// Accumulate the product into the loop-carried accumulator %arg8.
%16 = tt.dot %12, %15, %arg8 : tensor<128x1xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<1x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma>
// Advance both pointer tensors by their constant offsets for the next iteration.
%17 = tt.addptr %arg6, %cst_1 : tensor<128x1x!tt.ptr<f16>, #blocked1>, tensor<128x1xi32, #blocked1>
%18 = tt.addptr %arg7, %cst_0 : tensor<1x128x!tt.ptr<f16>, #blocked>, tensor<1x128xi32, #blocked>
scf.yield %17, %18, %16 : tensor<128x1x!tt.ptr<f16>, #blocked1>, tensor<1x128x!tt.ptr<f16>, #blocked>, tensor<128x128xf32, #mma>
}
// Return the final accumulator (third loop result).
tt.return %10#2 : tensor<128x128xf32, #mma>
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -448,18 +448,14 @@ void StreamPipeliner::assignMemoryLayouts() {
cast<tt::PointerType>(tensorTy.getElementType()).getPointeeType();
unsigned width = vec * pointeeTy.getIntOrFloatBitWidth();

// Limit shared memory sharing to loads whose width is >= 32 bits.
LDBG("Load " << *loadOp << " has width " << width);
if (width < 32) {
LDBG("Skip width<32 load " << *loadOp);
continue;
}

if (use->hasTrait<OpTrait::DotLike>()) {
// Only use shared memory when feeding into a dot op.
loadInfo.usedByDot = true;
loadInfo.sharedEncoding =
getSharedEncIfAllUsersAreDotEnc(op->getResult(0)).value_or(nullptr);
// If the max contiguous bits we can read is < 32, buffer in registers.
if (width >= 32) {
loadInfo.sharedEncoding =
getSharedEncIfAllUsersAreDotEnc(op->getResult(0)).value_or(nullptr);
}
} else if (auto useOp = dyn_cast<tt::LoadOp>(use)) {
// The use of this loadOp is another loadOp. If the use is not in the
// loadToInfo already, it means that the use is not valid for pipelining
Expand Down