24 changes: 20 additions & 4 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td
@@ -49,24 +49,27 @@ def TTG_MemDescType : TTG_TypeDef<"MemDesc", "memdesc", [ShapedTypeInterface]> {
"Type":$elementType,
"Attribute":$encoding,
"Attribute":$memorySpace,
"bool":$mutable_memory
"bool":$mutableMemory,
ArrayRefParameter<"int64_t">:$allocShape
);

let extraClassDeclaration = [{
MemDescType cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type elementType) const {
- return MemDescType::get(shape.value_or(getShape()), elementType, getEncoding(), getMemorySpace(), getMutableMemory());
+ return MemDescType::get(shape.value_or(getShape()), elementType, getEncoding(), getMemorySpace(), getMutableMemory(), getAllocShape());
}

bool hasRank() const { return true; }
}];

let builders = [
TypeBuilderWithInferredContext<(ins
"llvm::ArrayRef<int64_t>":$shape,
"Type":$elementType,
"Attribute":$encoding,
"Attribute":$memorySpace
), [{
- return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, /*mutableMemory=*/false);
+ return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, /*mutableMemory=*/false, /*allocShape=*/shape);
}]>,
TypeBuilderWithInferredContext<(ins
"llvm::ArrayRef<int64_t>":$shape,
@@ -75,10 +78,23 @@ def TTG_MemDescType : TTG_TypeDef<"MemDesc", "memdesc", [ShapedTypeInterface]> {
"Attribute":$memorySpace,
"bool":$mutableMemory
), [{
- return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, mutableMemory);
+ return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, mutableMemory, /*allocShape=*/shape);
}]>,
+ TypeBuilderWithInferredContext<(ins
+ "llvm::ArrayRef<int64_t>":$shape,
+ "Type":$elementType,
+ "Attribute":$encoding,
+ "Attribute":$memorySpace,
+ "bool":$mutableMemory,
+ "llvm::ArrayRef<int64_t>":$allocShape
+ ), [{
+ return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, mutableMemory, allocShape);
+ }]>

];

let hasCustomAssemblyFormat = 1;
+ let genVerifyDecl = 1;
}
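Together with the custom parser and printer in Types.cpp below, the new `allocShape` parameter gives the type textual forms like the following (a sketch, not output copied from the PR; `#shared` and `#smem` stand for a shared encoding and `#ttg.shared_memory`, as defined in the tests further down):

```mlir
// allocShape defaults to the logical shape and is elided when the two match.
!ttg.memdesc<16x16xf16, #shared, #smem>
!ttg.memdesc<16x16xf16, #shared, #smem, mutable>
// A 16x16 view into a 3x16x16 multi-buffered allocation: the trailing
// dimension list records the shape of the parent allocation.
!ttg.memdesc<16x16xf16, #shared, #smem, mutable, 3x16x16>
```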


6 changes: 6 additions & 0 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -2459,6 +2459,7 @@ class TritonGPUOpAsmInterface : public OpAsmDialectInterface {
using OpAsmDialectInterface::OpAsmDialectInterface;

AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
+ // Encoding attributes
if (auto mmaAttr = mlir::dyn_cast<MmaEncodingTrait>(attr)) {
os << "mma";
return AliasResult::FinalAlias;
@@ -2475,6 +2476,11 @@
os << "slice";
return AliasResult::FinalAlias;
} */
+ // Memory space attributes
+ if (auto smem = mlir::dyn_cast<SharedMemorySpaceAttr>(attr)) {
+ os << "smem";
+ return AliasResult::FinalAlias;
+ }
return OpAsmDialectInterface::getAlias(attr, os);
}
};
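The practical effect on printed IR: rather than spelling out `#ttg.shared_memory` at every use, the printer now emits a `#smem` alias once and refers to it, which is exactly what the updated lit tests below switch to. A before/after sketch:

```mlir
// Before: the memory space attribute is repeated inline.
%1 = ttg.local_alloc %0 : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory>

// After: the alias hook prints one definition up front and reuses it.
#smem = #ttg.shared_memory
%1 = ttg.local_alloc %0 : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #smem>
```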
72 changes: 48 additions & 24 deletions lib/Dialect/TritonGPU/IR/Types.cpp
@@ -30,47 +30,54 @@ void TokenType::print(AsmPrinter &printer) const {
static constexpr llvm::StringRef kMutableMemory = "mutable";

Type MemDescType::parse(AsmParser &parser) {
- if (parser.parseLess())
+ if (failed(parser.parseLess()))
return Type();

- SmallVector<int64_t> dimensions;
- if (parser.parseDimensionList(dimensions, /*allowDynamic=*/false))
+ SmallVector<int64_t> dimensions; // required
+ if (failed(parser.parseDimensionList(dimensions, /*allowDynamic=*/false)))
return Type();

- // Parse the element type.
- Type elementType;
- if (parser.parseType(elementType))
+ Type elementType; // required
+ if (failed(parser.parseType(elementType)))
return Type();

- Attribute encoding;
- if (succeeded(parser.parseOptionalComma())) {
- if (parser.parseAttribute(encoding))
- return Type();
- }
- bool mutableMemory = false;
- Attribute memorySpace;
- if (succeeded(parser.parseOptionalComma())) {
- if (failed(parser.parseOptionalKeyword(kMutableMemory))) {
- if (parser.parseAttribute(memorySpace))
- return Type();
- } else {
- mutableMemory = true;
- }
- }
- if (mutableMemory == false && succeeded(parser.parseOptionalComma())) {
- if (parser.parseOptionalKeyword(kMutableMemory))
- return Type();
- mutableMemory = true;
- }
+ Attribute encoding; // required
+ if (failed(parser.parseComma()) || failed(parser.parseAttribute(encoding)))
+ return Type();
+
+ Attribute memorySpace; // required
+ if (failed(parser.parseComma()) || failed(parser.parseAttribute(memorySpace)))
+ return Type();
+
+ bool mutableMemory = false; // optional
+ SmallVector<int64_t> allocShape; // optional
+ if (succeeded(parser.parseOptionalComma())) {
+ if (succeeded(parser.parseOptionalKeyword(kMutableMemory))) {
+ mutableMemory = true;
+ if (succeeded(parser.parseOptionalComma())) {
+ if (failed(parser.parseDimensionList(allocShape, /*allowDynamic=*/false,
+ /*withTrailingX=*/false))) {
+ return Type();
+ }
+ }
+ } else if (failed(parser.parseDimensionList(allocShape,
+ /*allowDynamic=*/false,
+ /*withTrailingX=*/false))) {
+ return Type();
+ }
+ }

if (parser.parseGreater())
return Type();

+ if (allocShape.size() > 0)
+ return MemDescType::get(parser.getContext(), dimensions, elementType,
+ encoding, memorySpace, mutableMemory, allocShape);
+
return MemDescType::get(parser.getContext(), dimensions, elementType,
- encoding, memorySpace, mutableMemory);
+ encoding, memorySpace, mutableMemory, dimensions);
}

void MemDescType::print(AsmPrinter &printer) const {
printer << "<";
- for (auto dim : getShape())
+ auto shape = getShape();
+ for (auto dim : shape)
printer << dim << "x";
printer << getElementType();
if (getEncoding())
@@ -79,9 +86,26 @@ void MemDescType::print(AsmPrinter &printer) const {
printer << ", " << getMemorySpace();
if (getMutableMemory())
printer << ", " << kMutableMemory;
+ auto allocShape = getAllocShape();
+ if (allocShape != shape) {
+ printer << ", " << allocShape[0];
+ for (auto dim : allocShape.drop_front(1)) {
+ printer << "x" << dim;
+ }
+ }
printer << ">";
}

+ LogicalResult MemDescType::verify(function_ref<InFlightDiagnostic()> emitError,
+ ArrayRef<int64_t> shape, Type elementType,
+ Attribute encoding, Attribute memorySpace,
+ bool mutableMemory,
+ ArrayRef<int64_t> allocShape) {
+ if (allocShape.size() < shape.size())
+ return emitError()
+ << "alloc shape must have at least as many dimensions as shape";
+ return success();
+ }

//===----------------------------------------------------------------------===//
// Triton Dialect
//===----------------------------------------------------------------------===//
35 changes: 19 additions & 16 deletions lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp
@@ -144,7 +144,8 @@ static int createAsyncCopy(scf::ForOp forOp, tt::LoadOp loadOp, Value alloc,
triton::gpu::SharedMemorySpaceAttr::get(forOp.getContext());
ttg::MemDescType subviewTy = ttg::MemDescType::get(
allocTy.getShape().drop_front(), allocTy.getElementType(),
- allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true);
+ allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true,
+ /*allocShape=*/allocTy.getAllocShape());
auto view = builder.createWithStage<ttg::MemDescSubviewOp>(
loc, stage, clusterId, subviewTy, alloc, copyOffsets);
Operation *copy = builder.createWithStage<ttg::AsyncCopyGlobalToLocalOp>(
@@ -232,7 +233,8 @@ createTMAAsyncCopy(scf::ForOp &forOp, tt::ExperimentalDescriptorLoadOp loadOp,
copyOffsets[0] = insertIdx;
ttg::MemDescType subviewTy = ttg::MemDescType::get(
allocTy.getShape().drop_front(), allocTy.getElementType(),
- allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true);
+ allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true,
+ /*allocShape=*/allocTy.getAllocShape());
auto view = builder.createWithStage<ttg::MemDescSubviewOp>(
loc, stage, clusterId, subviewTy, alloc, copyOffsets);

@@ -526,7 +528,7 @@ static Value createAlloc(scf::ForOp &forOp, Operation *loadOp,
bufferShape.insert(bufferShape.begin(), distance);
Type memdescType = ttg::MemDescType::get(bufferShape, ty.getElementType(),
sharedEnc, sharedMemorySpace,
- /*mutableMemory*/ true);
+ /*mutableMemory=*/true);
Value alloc =
builder.create<ttg::LocalAllocOp>(loadOp->getLoc(), memdescType, Value());
return alloc;
@@ -544,12 +546,13 @@ static Value createBarrierAlloc(scf::ForOp &forOp, unsigned distance) {
/*CTASplitNum=*/{1}, /*CTAOrder=*/{0});
auto barrierEncoding =
ttg::SharedEncodingAttr::get(context, 1, 1, 1, {0}, barrierCTALayout);
- Type barrierMemDescType = ttg::MemDescType::get(
+ auto barrierMemDescType = ttg::MemDescType::get(
{distance}, builder.getI64Type(), barrierEncoding, sharedMemorySpace,
/*mutableMemory=*/true);
- Type singleBarrierMemDescType =
- ttg::MemDescType::get({1}, builder.getI64Type(), barrierEncoding,
- sharedMemorySpace, /*mutableMemory=*/true);
+ Type singleBarrierMemDescType = ttg::MemDescType::get(
+ {1}, builder.getI64Type(), barrierEncoding, sharedMemorySpace,
+ /*mutableMemory=*/true,
+ /*allocShape=*/barrierMemDescType.getAllocShape());
Value barrierAlloc =
builder.create<ttg::LocalAllocOp>(loc, barrierMemDescType, Value());
for (unsigned i = 0; i < distance; i++) {
@@ -650,11 +653,11 @@ static void createTMABarrierAndWait(
OpBuilderWithStage builder(forOp);
Attribute sharedMemorySpace =
ttg::SharedMemorySpaceAttr::get(builder.getContext());
+ auto allocTy = cast<ttg::MemDescType>(barrierAlloc.getType());
ttg::MemDescType barrierTy = ttg::MemDescType::get(
- {1}, builder.getI64Type(),
- cast<ttg::MemDescType>(barrierAlloc.getType()).getEncoding(),
- sharedMemorySpace,
- /*mutableMemory=*/true);
+ {1}, builder.getI64Type(), allocTy.getEncoding(), sharedMemorySpace,
+ /*mutableMemory=*/true,
+ /*allocShape=*/allocTy.getAllocShape());
builder.setInsertionPoint(group[0]->loadOp);
Value barrier = builder.createWithStage<ttg::MemDescSubviewOp>(
loc, stage, cluster, barrierTy, barrierAlloc,
@@ -835,14 +838,14 @@ static void invalidateBarriers(OpBuilder &builder,
Attribute sharedMemorySpace =
ttg::SharedMemorySpaceAttr::get(builder.getContext());
for (Value barrier : barriers) {
- int numBarriers = cast<ttg::MemDescType>(barrier.getType()).getShape()[0];
+ auto allocTy = cast<ttg::MemDescType>(barrier.getType());
+ int numBarriers = allocTy.getShape()[0];
for (int i = 0; i < numBarriers; i++) {
Value idx = builder.create<arith::ConstantIntOp>(barrier.getLoc(), i, 32);
ttg::MemDescType barrierTy = ttg::MemDescType::get(
- {1}, builder.getI64Type(),
- cast<ttg::MemDescType>(barrier.getType()).getEncoding(),
- sharedMemorySpace,
- /*mutableMemory=*/true);
+ {1}, builder.getI64Type(), allocTy.getEncoding(), sharedMemorySpace,
+ /*mutableMemory=*/true,
+ /*allocShape=*/allocTy.getShape());
Value barrierView = builder.create<ttg::MemDescSubviewOp>(
barrier.getLoc(), barrierTy, barrier, idx);
builder.create<ttng::InvalBarrierOp>(barrier.getLoc(), barrierView);
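In IR terms, the pipeliner changes above make each per-stage view advertise the multi-buffered allocation it slices into. A hedged sketch (op spelling and shapes illustrative, not copied from the PR):

```mlir
// Triple-buffered staging buffer, as produced by createAlloc.
%buf = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable>
// Per-stage view: the leading buffer dimension is dropped from the shape,
// while the new allocShape component keeps the full 3x128x64 allocation visible.
%view = ttg.memdesc_subview %buf[%stage, %c0, %c0]
  : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable>
  -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64>
```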
5 changes: 3 additions & 2 deletions lib/Dialect/TritonGPU/Transforms/Prefetch.cpp
@@ -136,8 +136,9 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
builder.create<arith::ConstantIntOp>(v.getLoc(), off, 32));
Value newSmem = builder.create<triton::gpu::MemDescSubviewOp>(
v.getLoc(),
- triton::gpu::MemDescType::get(shape, elementType, type.getEncoding(),
- type.getMemorySpace()),
+ triton::gpu::MemDescType::get(
+ shape, elementType, type.getEncoding(), type.getMemorySpace(),
+ type.getMutableMemory(), type.getAllocShape()),
v, offsetsVal);

auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get(
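Note the behavioral fix in this hunk: the old call used the builder overload that defaults `mutableMemory` to false (see the .td diff above), so a prefetched slice of a mutable shared-memory buffer was typed immutable and its alloc shape was reset to the slice's own shape. The rewritten call preserves both from the source type, e.g. (shapes and op spelling illustrative):

```mlir
%slice = ttg.memdesc_subview %smemBuf[%c0, %off]
  : !ttg.memdesc<128x64xf16, #shared, #smem, mutable>
  -> !ttg.memdesc<128x16xf16, #shared, #smem, mutable, 128x64>
```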
15 changes: 9 additions & 6 deletions python/test/unit/language/test_core.py
@@ -5328,20 +5328,22 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, t
layouts = f"""
#src = {src_layout}
#dst = {dst_layout}
+ #smem = #ttg.shared_memory
""" if interm_layout is None else f"""
#src = {src_layout}
#interm = {interm_layout}
#dst = {dst_layout}
+ #smem = #ttg.shared_memory
"""

conversion = f"""
%12 = ttg.convert_layout %9 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst>
%13 = ttg.convert_layout %11 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst>
""" if interm_layout is None else f"""
- %15 = ttg.local_alloc %9 : (tensor<{M}x{N}xi32, #src>) -> !ttg.memdesc<{M}x{N}xi32, #interm, #ttg.shared_memory>
- %16 = ttg.local_load %15 : !ttg.memdesc<{M}x{N}xi32, #interm, #ttg.shared_memory> -> tensor<{M}x{N}xi32, #src>
- %17 = ttg.local_alloc %11 : (tensor<{M}x{N}xf16, #src>) -> !ttg.memdesc<{M}x{N}xf16, #interm, #ttg.shared_memory>
- %18 = ttg.local_load %17 : !ttg.memdesc<{M}x{N}xf16, #interm, #ttg.shared_memory> -> tensor<{M}x{N}xf16, #src>
+ %15 = ttg.local_alloc %9 : (tensor<{M}x{N}xi32, #src>) -> !ttg.memdesc<{M}x{N}xi32, #interm, #smem>
+ %16 = ttg.local_load %15 : !ttg.memdesc<{M}x{N}xi32, #interm, #smem> -> tensor<{M}x{N}xi32, #src>
+ %17 = ttg.local_alloc %11 : (tensor<{M}x{N}xf16, #src>) -> !ttg.memdesc<{M}x{N}xf16, #interm, #smem>
+ %18 = ttg.local_load %17 : !ttg.memdesc<{M}x{N}xf16, #interm, #smem> -> tensor<{M}x{N}xf16, #src>

%12 = ttg.convert_layout %16 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst>
%13 = ttg.convert_layout %18 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst>
@@ -5405,6 +5407,7 @@ def test_local_load_store(M, N, K, dist_layout, shared_layout, device, tmp_path:
layouts = f"""
#dist = {dist_layout}
#shared = {shared_layout}
+ #smem = #ttg.shared_memory
"""
ir = layouts + f"""
module attributes {{"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{
Expand Down Expand Up @@ -5433,8 +5436,8 @@ def test_local_load_store(M, N, K, dist_layout, shared_layout, device, tmp_path:
%17 = tt.broadcast %15 : tensor<{M}x1x1xi32, #dist> -> tensor<{M}x{N}x{K}xi32, #dist>
%18 = tt.addptr %16, %17 : tensor<{M}x{N}x{K}x!tt.ptr<i32>, #dist>, tensor<{M}x{N}x{K}xi32, #dist>
%19 = tt.load %18 : tensor<{M}x{N}x{K}x!tt.ptr<i32>, #dist>
- %20 = ttg.local_alloc %19 : (tensor<{M}x{N}x{K}xi32, #dist>) -> !ttg.memdesc<{M}x{N}x{K}xi32, #shared, #ttg.shared_memory>
- %21 = ttg.local_load %20 : !ttg.memdesc<{M}x{N}x{K}xi32, #shared, #ttg.shared_memory> -> tensor<{M}x{N}x{K}xi32, #dist>
+ %20 = ttg.local_alloc %19 : (tensor<{M}x{N}x{K}xi32, #dist>) -> !ttg.memdesc<{M}x{N}x{K}xi32, #shared, #smem>
+ %21 = ttg.local_load %20 : !ttg.memdesc<{M}x{N}x{K}xi32, #shared, #smem> -> tensor<{M}x{N}x{K}xi32, #dist>
%22 = tt.make_range {{end = {K} : i32, start = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>>
%23 = tt.expand_dims %22 {{axis = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>> -> tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>>
%24 = tt.expand_dims %23 {{axis = 1 : i32}} : tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> -> tensor<1x1x{K}xi32, #dist>
5 changes: 3 additions & 2 deletions test/Conversion/amd/compute-base-ptr.mlir
@@ -3,14 +3,15 @@
#blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = false}>
#shared = #ttg.shared<{vec = 16, perPhase = 4, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
+ #smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 544 : i32, "ttg.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: @local_load_offset
tt.func @local_load_offset(%arg0: tensor<16x16xf16, #mma>) {
%0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #blocked> loc(#loc1)
- %1 = ttg.local_alloc %0 {allocation.offset = 0 : i32} : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory> loc(#loc2)
+ %1 = ttg.local_alloc %0 {allocation.offset = 0 : i32} : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #smem> loc(#loc2)
// This catches base ptr calculation in the computeBasePtr, checks if the gep has correct element type.
// CHECK: llvm.getelementptr {{.*}} (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 local_load:3:0
- %2 = ttg.local_load %1 : !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> loc(#loc3)
+ %2 = ttg.local_load %1 : !ttg.memdesc<16x16xf16, #shared, #smem> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> loc(#loc3)
tt.return
}
}
6 changes: 4 additions & 2 deletions test/Conversion/amd/decompose-unsupported-conversions.mlir
@@ -5,10 +5,11 @@
// CHECK: #[[$SHARED:.+]] = #ttg.shared<{{.*}}>
// CHECK-LABEL: wmma_to_wmma_dot_op
#mma = #ttg.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}>
+ #smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1130", "ttg.threads-per-warp" = 32 : i32} {
tt.func @wmma_to_wmma_dot_op(%arg0: tensor<16x16xf16, #mma>) {
// CHECK: %[[SRC_BLOCKED:.+]] = ttg.convert_layout %{{.*}} : tensor<16x16xf16, #[[$WMMA]]> -> tensor<16x16xf16, #[[$BLOCKED]]>
- // CHECK-NEXT: %[[INT_SHARED:.+]] = ttg.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !ttg.memdesc<16x16xf16, #[[$SHARED]], #ttg.shared_memory>
+ // CHECK-NEXT: %[[INT_SHARED:.+]] = ttg.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !ttg.memdesc<16x16xf16, #[[$SHARED]], #smem>
// CHECK-NEXT: %[[DST_DOT_OP:.+]] = ttg.local_load %[[INT_SHARED]] : {{.*}} -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$WMMA]], kWidth = 16}>>
%0 = ttg.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>
tt.return
@@ -22,10 +23,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
// CHECK: #[[$SHARED:.+]] = #ttg.shared<{{.*}}>
// CHECK-LABEL: wmma_to_wmma_dot3d_op
#mma = #ttg.amd_wmma<{version = 1, warpsPerCTA = [2, 2, 2]}>
+ #smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} {
tt.func @wmma_to_wmma_dot3d_op(%arg0: tensor<2x16x16xf16, #mma>) {
// CHECK: %[[SRC_BLOCKED:.+]] = ttg.convert_layout %{{.*}} : tensor<2x16x16xf16, #[[$WMMA]]> -> tensor<2x16x16xf16, #[[$BLOCKED]]>
- // CHECK-NEXT: %[[INT_SHARED:.+]] = ttg.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !ttg.memdesc<2x16x16xf16, #[[$SHARED]], #ttg.shared_memory>
+ // CHECK-NEXT: %[[INT_SHARED:.+]] = ttg.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !ttg.memdesc<2x16x16xf16, #[[$SHARED]], #smem>
// CHECK-NEXT: %[[DST_DOT_OP:.+]] = ttg.local_load %[[INT_SHARED]] : {{.*}} -> tensor<2x16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$WMMA]], kWidth = 16}>>
%0 = ttg.convert_layout %arg0 : tensor<2x16x16xf16, #mma> -> tensor<2x16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>
tt.return