Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions bin/RegisterTritonDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
mlir::registerTritonAMDGPUStreamPipelineV2();
mlir::registerTritonAMDGPUCanonicalizePointers();
mlir::registerTritonAMDGPUConvertToBufferOps();
mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints();
mlir::triton::registerTritonAMDGPULowerInstructionSchedHints();

// TODO: register Triton & TritonGPU passes
registry.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,33 @@ constexpr int patternBenefitPrioritizeOverLLVMConversions = 10;
constexpr int patternBenefitClampOptimizedPattern = 20;
constexpr int patternBenefitConvertLayoutOptimizedPattern = 20;

struct BackendCallbacks {
/**
* A backend-specific callback for appending auxiliary data during
* `LocalStoreOp` conversion.
*
* @param[in] op The reference to the re-written `LocalStoreOp`.
* @param[in] count The number of issued LLVM instructions.
* @param[in] type The input type of issued LLVM instructions.
*/
std::function<void(triton::gpu::LocalStoreOp op, size_t llvmOpCount,
Type llvmOpType)>
localStoreOpConversion = nullptr;
};

void populateElementwiseOpToLLVMPatterns(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo,
PatternBenefit benefit);

void populateMemoryOpToLLVMPattern(LLVMTypeConverter &typeConverter,
const TargetInfoBase &targetInfo,
RewritePatternSet &patterns,
PatternBenefit benefit);
// The given callback is invoked at the end of a successful rewrite. The
// callback receives 1) the current source op, 2) the number of issued LLVM
// instructions and 3) their input types. Each MLIR backend can provide a
// callback and, thus, handle backend-specific behaviors.
void populateMemoryOpToLLVMPattern(
LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
RewritePatternSet &patterns, PatternBenefit benefit,
std::optional<BackendCallbacks> backendCallbacks = std::nullopt);

void populateAssertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
Expand Down
10 changes: 5 additions & 5 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -1366,11 +1366,11 @@ SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy,
Location loc, RewriterBase &rewriter,
const TargetInfoBase &target);

void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy,
Type elemLlvmTy, ArrayRef<Value> srcVals,
Value smemBase, ArrayRef<Value> dstStrides,
Location loc, RewriterBase &rewriter,
const TargetInfoBase &target);
void storeDistributedToShared(
MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy,
ArrayRef<Value> srcVals, Value smemBase, ArrayRef<Value> dstStrides,
Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
std::pair<size_t, Type> *const llvmOpCount = nullptr);

inline Value getStructFromSharedMemoryObject(Location loc,
const SharedMemoryObject &smemObj,
Expand Down
36 changes: 25 additions & 11 deletions lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@ using namespace mlir::triton::gpu;
// blocked -> shared.
// Swizzling in shared memory to avoid bank conflict. Normally used for
// A/B operands of dots.
void lowerDistributedToShared(Location loc, Value src, Value dst,
Value adaptorSrc,
const SharedMemoryObject &smemObj,
const LLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter,
const TargetInfoBase &targetInfo) {
void lowerDistributedToShared(
Location loc, Value src, Value dst, Value adaptorSrc,
const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo,
std::pair<size_t, Type> *const llvmOpCount = nullptr) {
auto srcTy = cast<RankedTensorType>(src.getType());
auto dstTy = cast<MemDescType>(dst.getType());
auto outOrd = mlir::cast<SharedEncodingAttr>(dstTy.getEncoding()).getOrder();
Expand All @@ -33,7 +32,7 @@ void lowerDistributedToShared(Location loc, Value src, Value dst,
auto dstStrides = smemObj.getStrides();
auto inVals = unpackLLElements(loc, adaptorSrc, rewriter);
storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemBase, dstStrides,
loc, rewriter, targetInfo);
loc, rewriter, targetInfo, llvmOpCount);
}

struct LocalAllocOpConversion
Expand Down Expand Up @@ -200,12 +199,15 @@ struct LocalStoreOpConversion
public:
using ConvertOpToLLVMPattern<
triton::gpu::LocalStoreOp>::ConvertOpToLLVMPattern;
using BackendCallbackType =
decltype(BackendCallbacks::localStoreOpConversion);

LocalStoreOpConversion(const LLVMTypeConverter &converter,
const TargetInfoBase &targetInfo,
BackendCallbackType backendCallback,
PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<triton::gpu::LocalStoreOp>(converter, benefit),
targetInfo(targetInfo) {}
targetInfo(targetInfo), backendCallback(backendCallback) {}

LogicalResult
matchAndRewrite(triton::gpu::LocalStoreOp op, OpAdaptor adaptor,
Expand All @@ -215,24 +217,36 @@ struct LocalStoreOpConversion
getTypeConverter()->convertType(op.getDst().getType().getElementType());
auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
op.getLoc(), adaptor.getDst(), llvmElemTy, rewriter);

std::pair<size_t, Type> llvmOpCount;
lowerDistributedToShared(op.getLoc(), op.getSrc(), op.getDst(),
adaptor.getSrc(), smemObj, getTypeConverter(),
rewriter, targetInfo);
rewriter, targetInfo, &llvmOpCount);

if (backendCallback)
(backendCallback)(op, llvmOpCount.first, llvmOpCount.second);

rewriter.eraseOp(op);
return success();
}

private:
const TargetInfoBase &targetInfo;
BackendCallbackType backendCallback;
};

} // namespace

void mlir::triton::populateMemoryOpToLLVMPattern(
LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
RewritePatternSet &patterns, PatternBenefit benefit) {
RewritePatternSet &patterns, PatternBenefit benefit,
std::optional<BackendCallbacks> backendCallbacks) {
patterns.add<LocalAllocOpConversion>(typeConverter, targetInfo, benefit);
patterns.add<LocalDeallocOpConversion>(typeConverter, benefit);
patterns.add<LocalLoadOpConversion>(typeConverter, targetInfo, benefit);
patterns.add<LocalStoreOpConversion>(typeConverter, targetInfo, benefit);

auto backendCall =
backendCallbacks ? backendCallbacks->localStoreOpConversion : nullptr;
patterns.add<LocalStoreOpConversion>(typeConverter, targetInfo, backendCall,
benefit);
}
8 changes: 7 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,8 @@ void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy,
Type elemLlvmTy, ArrayRef<Value> srcVals,
Value smemBase, ArrayRef<Value> dstStrides,
Location loc, RewriterBase &rewriter,
const TargetInfoBase &target) {
const TargetInfoBase &target,
std::pair<size_t, Type> *const llvmOpCount) {
bool success = emitTransferBetweenRegistersAndShared(
srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemBase,
dstStrides, loc, rewriter, target, [&](VectorType vecTy, Value vecAddr) {
Expand All @@ -418,7 +419,12 @@ void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy,
store(vec, vecAddr)
.setAlignment(vecTy.getNumElements() *
elemLlvmTy.getIntOrFloatBitWidth() / 8);
if (llvmOpCount) {
Comment thread
antiagainst marked this conversation as resolved.
++(llvmOpCount->first);
llvmOpCount->second = vecTy;
}
});

if (!success)
llvm::report_fatal_error("Failed to emit transfer from register to shared");
}
Expand Down
103 changes: 103 additions & 0 deletions test/TritonGPU/amd/amd-instruction-sched.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=iglp0' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP0
// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=iglp1' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP1
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=1' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS1
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS2
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -triton-amdgpu-lower-insert-instruction-sched-hints='variant=ck_v3' -debug-only='lower-insert-instruction-sched-hints' -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=USE_CKV3_GLOBAL_LOAD
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=1' | FileCheck %s -check-prefix=LABELING_PS_1
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' | FileCheck %s -check-prefix=LABELING_PS_2

module {
// INSERT_IGLP0-LABEL: @test_dot_op
// INSERT_IGLP1-LABEL: @test_dot_op
// INSTR_COUNT_NS1-LABEL: @test_dot_op
// INSTR_COUNT_NS2-LABEL: @test_dot_op
// LABELING_PS_1-LABEL: @test_dot_op
// LABELING_PS_2-LABEL: @test_dot_op
tt.func @test_dot_op(%lb : index, %ub : index, %step : index,
%A : !tt.ptr<f16> {tt.divisibility = 16 : i32},
%B : !tt.ptr<f16> {tt.divisibility = 16 : i32},
%C : !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
// A ptrs
%a_ptr_splat = tt.splat %A : !tt.ptr<f16> -> tensor<128x32x!tt.ptr<f16>>
%a_tmp0 = tt.make_range {end = 32: i32, start = 0: i32} : tensor<32xi32>
%a_tmp1 = tt.expand_dims %a_tmp0 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32>
%a_offs = tt.broadcast %a_tmp1 : tensor<1x32xi32> -> tensor<128x32xi32>
%a_ptr_init = tt.addptr %a_ptr_splat, %a_offs : tensor<128x32x!tt.ptr<f16>>, tensor<128x32xi32>
// B ptrs
%b_ptr_splat = tt.splat %B : !tt.ptr<f16> -> tensor<32x128x!tt.ptr<f16>>
%b_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32>
%b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32>
%b_offs = tt.broadcast %b_tmp1 : tensor<1x128xi32> -> tensor<32x128xi32>
%b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr<f16>>, tensor<32x128xi32>

%a_mask = arith.constant dense<true> : tensor<128x32xi1>
%a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16>
%b_mask = arith.constant dense<true> : tensor<32x128xi1>
%b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16>
%c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32>

%a_off = arith.constant dense<4> : tensor<128x32xi32>
%b_off = arith.constant dense<4> : tensor<32x128xi32>

%loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>>, tensor<32x128x!tt.ptr<f16>>, tensor<128x128xf32>) {
%a = tt.load %a_ptr : tensor<128x32x!tt.ptr<f16>>
%b = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr<f16>>

// INSERT_IGLP0: rocdl.iglp.opt 0
// INSERT_IGLP1: rocdl.iglp.opt 1

// INSTR_COUNT_NS1: amdgpu.instruction_sched_hint
// INSTR_COUNT_NS1-SAME: isBufferLoadsAEnabled = false
// INSTR_COUNT_NS1-SAME: isBufferLoadsBEnabled = false
// INSTR_COUNT_NS1-SAME: numDsReadsA = #amdgpu.InstCounter<8, vector<4xf16>>
// INSTR_COUNT_NS1-SAME: numDsReadsB = #amdgpu.InstCounter<32, vector<1xf16>>
// INSTR_COUNT_NS1-SAME: numDsWritesA = #amdgpu.InstCounter<0, none>
// INSTR_COUNT_NS1-SAME: numDsWritesB = #amdgpu.InstCounter<0, none>
// INSTR_COUNT_NS1-SAME: numGlobalLoadsA = #amdgpu.InstCounter<4, vector<4xf16>>
// INSTR_COUNT_NS1-SAME: numGlobalLoadsB = #amdgpu.InstCounter<4, vector<4xf16>>
// INSTR_COUNT_NS1-SAME: numMMAs = #amdgpu.InstCounter<16, tensor<32x32x8xf16>>

// INSTR_COUNT_NS2: amdgpu.instruction_sched_hint
// INSTR_COUNT_NS2-SAME: isBufferLoadsAEnabled = false
// INSTR_COUNT_NS2-SAME: isBufferLoadsBEnabled = false
// INSTR_COUNT_NS2-SAME: numDsReadsA = #amdgpu.InstCounter<8, vector<4xf16>>
// INSTR_COUNT_NS2-SAME: numDsReadsB = #amdgpu.InstCounter<32, vector<1xf16>>
// INSTR_COUNT_NS2-SAME: numDsWritesA = #amdgpu.InstCounter<4, vector<4xf16>>
// INSTR_COUNT_NS2-SAME: numDsWritesB = #amdgpu.InstCounter<4, vector<4xf16>>
// INSTR_COUNT_NS2-SAME: numGlobalLoadsA = #amdgpu.InstCounter<4, vector<4xf16>>
// INSTR_COUNT_NS2-SAME: numGlobalLoadsB = #amdgpu.InstCounter<4, vector<4xf16>>
// INSTR_COUNT_NS2-SAME: numMMAs = #amdgpu.InstCounter<16, tensor<32x32x8xf16>>

// USE_CKV3_GLOBAL_LOAD: [lower-insert-instruction-sched-hints]
// USE_CKV3_GLOBAL_LOAD-SAME: Skipping instruction scheduling because `ck_v3` scheduling can be used only with `buffer_load` instructions.

// LABELING_PS_1: scf.for
// LABELING_PS_1: %[[REG0_OP0:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<0>}
// LABELING_PS_1: %[[REG0_OP1:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<1>}
// LABELING_PS_1: %[[REG1_OP0:.+]] = triton_gpu.convert_layout %[[REG0_OP0]]
// LABELING_PS_1: %[[REG1_OP1:.+]] = triton_gpu.convert_layout %[[REG0_OP1]]
// LABELING_PS_1: tt.dot %[[REG1_OP0]], %[[REG1_OP1]], {{.*}}

// LABELING_PS_2: scf.for
// LABELING_PS_2: %[[REG0_OP0:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<0>}
// LABELING_PS_2: %[[REG0_OP1:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<1>}
// LABELING_PS_2: triton_gpu.local_store %[[REG0_OP0]], %{{.*}} {OpIdx = #amdgpu.OpIdx<0>}
// LABELING_PS_2: triton_gpu.local_store %[[REG0_OP1]], %{{.*}} {OpIdx = #amdgpu.OpIdx<1>}

%c = tt.dot %a, %b, %prev_c : tensor<128x32xf16> * tensor<32x128xf16> -> tensor<128x128xf32>
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>>, tensor<128x32xi32>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>>, tensor<32x128xi32>
scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f16>>, tensor<32x128x!tt.ptr<f16>>, tensor<128x128xf32>
}

// C ptrs
%c_ptr_splat = tt.splat %C : !tt.ptr<f32> -> tensor<128x128x!tt.ptr<f32>>
%c_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32>
%c_tmp1 = tt.expand_dims %c_tmp0 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32>
%c_offs = tt.broadcast %c_tmp1 : tensor<1x128xi32> -> tensor<128x128xi32>
%c_ptr = tt.addptr %c_ptr_splat, %c_offs : tensor<128x128x!tt.ptr<f32>>, tensor<128x128xi32>

tt.store %c_ptr, %loop#2 : tensor<128x128x!tt.ptr<f32>>
tt.return
}
}
2 changes: 1 addition & 1 deletion third_party/amd/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def make_llir(src, metadata, options):
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
passes.common.add_symbol_dce(pm)
amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.instruction_sched_variant)
amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.num_stages, options.instruction_sched_variant)
Comment thread
antiagainst marked this conversation as resolved.
if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
passes.llvmir.add_di_scope(pm)
amd.passes.ttgpuir.add_builtin_func_to_llvmir(pm, __HIP_FTZ)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,31 @@ class TritonAMDGPU_Attr<string name, list<Trait> traits = [],
: AttrDef<TritonAMDGPU_Dialect, name, traits, baseCppClass> {
}

def TritonAMDGPU_OpIdxAttr : TritonAMDGPU_Attr<"OpIdx"> {
let cppNamespace = "::mlir::triton::amdgpu";
let mnemonic = "OpIdx";
let summary = "An operand index attribute.";
let description = [{
The attribute is a way to describe which input argument of the target
operation (e.g., `tt.dot`) the result of a given operation belongs to.
}];

let parameters = (ins "uint32_t":$value);
let assemblyFormat = "`<` $value `>`";
}

def TritonAMDGPU_InstCounter : TritonAMDGPU_Attr<"InstCounter"> {
let cppNamespace = "::mlir::triton::amdgpu";
let mnemonic = "InstCounter";
let summary = "An instruction counter attribute.";
let description = [{
The attribute holds the number of issued LLVM instructions of a specific kind as well as
the data type.
}];

let parameters = (ins "uint32_t":$value, "Type":$type);
let assemblyFormat = "`<` params `>`";
}


#endif
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ def TritonAMDGPU_Dialect : Dialect {
}];

let dependentDialects = [];

let useDefaultAttributePrinterParser = 1;
let usePropertiesForAttributes = 1;
}

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,29 @@ def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> {
interleave for better instruction level parallelism.
}];

let assemblyFormat = [{attr-dict}];
let arguments = (ins
TritonAMDGPU_InstCounter:$numDsReadsA,
TritonAMDGPU_InstCounter:$numDsReadsB,
TritonAMDGPU_InstCounter:$numDsWritesA,
TritonAMDGPU_InstCounter:$numDsWritesB,
TritonAMDGPU_InstCounter:$numGlobalLoadsA,
TritonAMDGPU_InstCounter:$numGlobalLoadsB,
BoolAttr:$isBufferLoadsAEnabled,
BoolAttr:$isBufferLoadsBEnabled,
TritonAMDGPU_InstCounter:$numMMAs
);

let builders = [
OpBuilder<(ins), [{
auto ctx = $_state.getContext();
auto noneType = NoneType::get(ctx);
auto emptyAttr = amdgpu::InstCounterAttr::get(ctx, 0, noneType);
build($_builder, $_state, emptyAttr, emptyAttr, emptyAttr, emptyAttr,
emptyAttr, emptyAttr, false, false, emptyAttr);
}]>
];

let assemblyFormat = [{ attr-dict }];
}

//
Expand Down
Loading