10 changes: 5 additions & 5 deletions test/TritonGPU/amd/amd-instruction-sched.mlir
@@ -1,8 +1,8 @@
-// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=iglp0' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP0
-// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=iglp1' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP1
+// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=llvm-iglp-0' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP0
+// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=llvm-iglp-1' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP1
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=1' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS1
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS2
-// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -triton-amdgpu-lower-insert-instruction-sched-hints='variant=ck_v3' -debug-only='lower-insert-instruction-sched-hints' -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=USE_CKV3_GLOBAL_LOAD
+// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -triton-amdgpu-lower-insert-instruction-sched-hints='variant=local-prefetch' -debug-only='lower-insert-instruction-sched-hints' -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=USE_LOCAL_PREFETCH_GLOBAL_LOAD
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=1' | FileCheck %s -check-prefix=LABELING_PS_1
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' | FileCheck %s -check-prefix=LABELING_PS_2

@@ -68,8 +68,8 @@ module {
// INSTR_COUNT_NS2-SAME: numGlobalLoadsB = #amdgpu.InstCounter<4, vector<4xf16>>
// INSTR_COUNT_NS2-SAME: numMMAs = #amdgpu.InstCounter<16, tensor<32x32x8xf16>>

-// USE_CKV3_GLOBAL_LOAD: [lower-insert-instruction-sched-hints]
-// USE_CKV3_GLOBAL_LOAD-SAME: Skipping instruction scheduling because `ck_v3` scheduling can be used only with `buffer_load` instructions.
+// USE_LOCAL_PREFETCH_GLOBAL_LOAD: [lower-insert-instruction-sched-hints]
+// USE_LOCAL_PREFETCH_GLOBAL_LOAD-SAME: skipping `local-prefetch` scheduling given it needs `buffer_load` instructions

// LABELING_PS_1: scf.for
// LABELING_PS_1: %[[REG0_OP0:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<0>}
22 changes: 21 additions & 1 deletion third_party/amd/backend/compiler.py
@@ -52,6 +52,18 @@ class HIPOptions:
# instruction scheduling of the AMDGPU backend which aims at maximizing occupancy.
# The option is experimental and may change at any time regarding its semantics and/or may
# be gone entirely anytime.
+#
+# Current experimental scheduling variants:
+#
+# llvm-iglp-0: injects an `llvm.amdgcn.iglp_opt` intrinsic call with value `0` into the GEMM's
+# k-loop; i.e., "interleave DS and MFMA instructions for small GEMM kernels".
+# llvm-iglp-1: injects an `llvm.amdgcn.iglp_opt` intrinsic call with value `1` into the GEMM's
+# k-loop; i.e., "interleave DS and MFMA instructions for single wave small
+# GEMM kernels".
+# local-prefetch: implements instruction scheduling similar to the one from the ROCm Composable
+# Kernel library. Note that this variant requires the use of buffer load/store ops
+# and a special software pipelining style, i.e., 1x LDS and 1x register
+# prefetch buffers for each GEMM tile.
instruction_sched_variant: str = 'none'

def __post_init__(self):
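Editor's note, not part of the diff: `instruction_sched_variant` is an ordinary `HIPOptions` field, so it should be selectable the way other AMD backend options (e.g. `waves_per_eu`) are, as an extra kernel-launch keyword argument. Below is a minimal sketch under that assumption; the kernel and tensor names are made up, and the elementwise kernel only demonstrates the kwarg plumbing, since the hints themselves target GEMM k-loops:

```python
# Hypothetical usage sketch: select a scheduling variant per kernel launch.
# Assumes extra launch kwargs are forwarded into HIPOptions, as with other
# AMD backend options; all names here are illustrative, not from this PR.
import torch
import triton
import triton.language as tl

@triton.jit
def axpy_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, 2.0 * x + y, mask=mask)

n = 4096
x = torch.rand(n, device="cuda")  # ROCm devices also surface as "cuda"
y = torch.rand(n, device="cuda")
out = torch.empty_like(x)
# num_stages >= 2 is required, otherwise the lowering falls back to `none`
# (see the `numStages < 2` check in SchedInstructions.cpp below).
axpy_kernel[(n // 1024,)](
    x, y, out, n, BLOCK=1024,
    num_stages=2, instruction_sched_variant="local-prefetch")
```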
@@ -215,6 +227,7 @@ def make_ttgir(mod, metadata, options):
passes.ttgpuir.add_remove_layout_conversions(pm)
amd.passes.ttgpuir.add_optimize_epilogue(pm)
passes.ttgpuir.add_optimize_dot_operands(pm, True)

if amd.has_matrix_core_feature(options.arch):
assert options.num_stages != 0, ("Triton AMD backend pipeliner has been updated. "
"We used to trigger software pipelining with "
@@ -229,7 +242,14 @@
passes.ttgpuir.add_reduce_data_duplication(pm)
if amd.has_matrix_core_feature(options.arch):
amd.passes.ttgpuir.add_reorder_instructions(pm)
-if os.environ.get("AMDGCN_USE_BUFFER_OPS", "0") == "1":
+
+use_buffer_ops = os.environ.get("AMDGCN_USE_BUFFER_OPS", "0") == "1"
+
+# The `local-prefetch` scheduling variant requires turning on buffer ops.
+if options.instruction_sched_variant == "local-prefetch":
+    use_buffer_ops = True
+
+if use_buffer_ops:
amd.passes.ttgpuir.add_canonicalize_pointers(pm)
passes.common.add_canonicalizer(pm)
amd.passes.ttgpuir.add_convert_to_buffer_ops(pm)
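To make the precedence explicit, here is the same decision logic as a standalone, runnable sketch; `should_use_buffer_ops` is an illustrative name, not a Triton API:

```python
# Standalone restatement of the buffer-ops decision above.
import os

def should_use_buffer_ops(instruction_sched_variant: str) -> bool:
    # Global opt-in via the environment ...
    use_buffer_ops = os.environ.get("AMDGCN_USE_BUFFER_OPS", "0") == "1"
    # ... but `local-prefetch` cannot work without buffer_load/buffer_store,
    # so it overrides the environment setting.
    return use_buffer_ops or instruction_sched_variant == "local-prefetch"

assert should_use_buffer_ops("local-prefetch")   # forced on by the variant
os.environ["AMDGCN_USE_BUFFER_OPS"] = "1"
assert should_use_buffer_ops("none")             # env opt-in still works
```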
55 changes: 24 additions & 31 deletions third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp
@@ -6,7 +6,6 @@
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Pass/Pass.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "llvm/TargetParser/TargetParser.h"

namespace mlir::triton {
#define GEN_PASS_DEF_TRITONAMDGPUINSERTINSTRUCTIONSCHEDHINTS
@@ -221,12 +220,13 @@ struct InstructionSchedHintsRewriter
std::transform(variant.begin(), variant.end(), variant.begin(),
[](unsigned char c) { return std::tolower(c); });

-this->schedulingType = llvm::StringSwitch<SchedulingType>(variant)
-.Case("none", SchedulingType::NONE)
-.Case("iglp0", SchedulingType::IGLP0)
-.Case("iglp1", SchedulingType::IGLP1)
-.Case("ck_v3", SchedulingType::CK_V3)
-.Default(SchedulingType::UNKNOWN);
+this->schedulingType =
+llvm::StringSwitch<SchedulingType>(variant)
+.Case("none", SchedulingType::NONE)
+.Case("llvm-iglp-0", SchedulingType::LLVM_IGLP_0)
+.Case("llvm-iglp-1", SchedulingType::LLVM_IGLP_1)
+.Case("local-prefetch", SchedulingType::LOCAL_PREFETCH)
+.Default(SchedulingType::UNKNOWN);

if (this->numStages < 2) {
this->schedulingType = SchedulingType::NONE;
@@ -237,26 +237,24 @@

enum class SchedulingType : uint32_t {
NONE = 0,
-IGLP0,
-IGLP1,
-CK_V3,
+LLVM_IGLP_0,
+LLVM_IGLP_1,
+LOCAL_PREFETCH,
UNKNOWN
};

-// This is the implementation of the CK's V3 pipelining (see
-// see ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp).
+// The following is inspired by ROCm Composable Kernel library's V3 pipelining
+// (see ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp).
// This scheduling requires 1x register and 1x LDS buffers combined with the
// local (LDS to registers) and global (HBM to registers) data prefetching.
-// see:
-// include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.h
-void
-createCKV3Schedule(PatternRewriter &rewriter, Location loc,
-triton::amdgpu::InstructionSchedHint schedHint) const {
+void createLocalPrefetchSchedule(
+PatternRewriter &rewriter, Location loc,
+triton::amdgpu::InstructionSchedHint schedHint) const {

if (!(schedHint.getIsBufferLoadsAEnabled() &&
schedHint.getIsBufferLoadsBEnabled())) {
LDBG("Skipping instruction scheduling because `ck_v3` "
"scheduling can be used only with `buffer_load` instructions.");
LDBG("skipping `local-prefetch` scheduling given it needs `buffer_load` "
"instructions");
return;
}
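Schematically, the pipelining style this variant assumes keeps exactly one LDS buffer and one register prefetch buffer per GEMM tile, so each k-loop iteration overlaps the global load of tile k+1, the local (LDS-to-register) load of the current tile, and the MFMAs on it. A toy, runnable sketch of that dataflow with made-up helpers; the real pass arranges sched-barrier/group-barrier intrinsics around actual buffer_load/ds_read/MFMA operations:

```python
# Toy model of the local-prefetch dataflow; every helper is a stand-in.
def global_load(k):   # HBM -> registers (buffer_load in the real schedule)
    return f"tile {k} in regs"

def lds_write(regs):  # registers -> the single LDS buffer
    return regs.replace("regs", "LDS")

def lds_read(lds):    # LDS -> registers (ds_read), the local prefetch
    return lds.replace("LDS", "mfma operands")

num_k_tiles = 4
lds = lds_write(global_load(0))   # prologue fills the one LDS buffer
for k in range(num_k_tiles):
    regs = global_load(k + 1) if k + 1 < num_k_tiles else None  # register prefetch
    operands = lds_read(lds)      # local prefetch for the current tile
    print(f"iteration {k}: mfma on", operands)  # compute overlaps the loads
    if regs is not None:
        lds = lds_write(regs)     # the single LDS buffer is reused
```

The single-buffer scheme keeps LDS and register pressure constant while a load for the next tile is always in flight, which is why the guard above bails out when buffer loads are disabled.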

@@ -435,8 +433,8 @@ struct InstructionSchedHintsRewriter
// backend documentation.
const bool limitSchedulingRange =
!(schedulingType == SchedulingType::NONE ||
-schedulingType == SchedulingType::IGLP0 ||
-schedulingType == SchedulingType::IGLP1);
+schedulingType == SchedulingType::LLVM_IGLP_0 ||
+schedulingType == SchedulingType::LLVM_IGLP_1);
Location loc = instructionSchedHint->getLoc();
Block *block = instructionSchedHint->getBlock();
if (limitSchedulingRange) {
@@ -448,22 +446,17 @@
rewriter.setInsertionPoint(block, std::prev(block->end()));

switch (schedulingType) {
-case SchedulingType::IGLP0:
-[[fallthrough]];
-case SchedulingType::IGLP1: {
+case SchedulingType::LLVM_IGLP_0:
+case SchedulingType::LLVM_IGLP_1:
createIglpOpt(rewriter, loc, static_cast<int>(schedulingType) - 1);
break;
-}
-case SchedulingType::CK_V3: {
-createCKV3Schedule(rewriter, loc, instructionSchedHint);
+case SchedulingType::LOCAL_PREFETCH:
+createLocalPrefetchSchedule(rewriter, loc, instructionSchedHint);
break;
-}
case SchedulingType::NONE:
-[[fallthrough]];
-default: {
+default:
break;
-}
}

if (limitSchedulingRange)
createSchedBarrier(rewriter, loc,