diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index f7bcd24d5403..cfba6d7225b8 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -404,6 +404,7 @@ jobs: echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1 fi pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py + pytest --capture=tee-sys -rfs third_party/amd/python/test/test_extract_slice.py cd python/test/unit pytest --capture=tee-sys -rfs -n 16 language runtime \ --ignore=language/test_line_info.py \ diff --git a/.github/workflows/integration-tests.yml.in b/.github/workflows/integration-tests.yml.in index c587cfc27ae8..7da4aa079327 100644 --- a/.github/workflows/integration-tests.yml.in +++ b/.github/workflows/integration-tests.yml.in @@ -402,6 +402,7 @@ jobs: echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1 fi pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py + pytest --capture=tee-sys -rfs third_party/amd/python/test/test_extract_slice.py cd python/test/unit pytest --capture=tee-sys -rfs -n 16 language runtime \ --ignore=language/test_line_info.py \ diff --git a/test/Conversion/amd/invalid_extractslice_to_llvm.mlir b/test/Conversion/amd/invalid_extractslice_to_llvm.mlir new file mode 100644 index 000000000000..e561dfb26905 --- /dev/null +++ b/test/Conversion/amd/invalid_extractslice_to_llvm.mlir @@ -0,0 +1,111 @@ +// RUN: triton-opt -split-input-file %s --convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics + +// Invalid size +#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_size_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{sizes [256, 2] must be a multiple of shapePerCTATile [256, 16]}} + %1 = amdgpu.extract_slice %arg0 [0,0] : 
tensor<256x128xi32, #blocked1> to tensor<256x2xi32, #blocked1> + tt.return +} + +// ----- + +// Invalid zero source dimension +#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_size_input(%arg0: tensor<256x0xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{source tensor dimension size zero at dimension 1}} + %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x0xi32, #blocked1> to tensor<256x16xi32, #blocked1> + tt.return +} + +// ----- + +// Invalid zero result dimension +#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_size_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{result tensor dimension size zero at dimension 1}} + %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x0xi32, #blocked1> + tt.return +} + +// ----- + +// Invalid offset, not multiple of shapePerTile +#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_offset_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{offset [0, 5] must be a multiple of shapePerCTATile [256, 16]}} + %1 = amdgpu.extract_slice %arg0 [0,5] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> + tt.return +} + +// ----- + +// Invalid offset, out of bounds for dimension +#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_offset_input(%arg0: tensor<256x128xi32, 
#blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{invalid offset 128 at dimension 1}} + %1 = amdgpu.extract_slice %arg0 [0,128] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> + tt.return +} + +// ----- + +// Invalid result layout +#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_result_layout(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{result layout must match source layout}} + %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked2> + tt.return +} + +// ----- + +// Invalid result element type +#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_result_element_type(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{result element type must match source element type}} + %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x16xi64, #blocked1> + tt.return +} + +// ----- + +// Invalid result rank +#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_result_rank(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{result rank must be equal to source rank}} + %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x16x2xi32, #blocked1> + tt.return +} + +// ----- + +// 
Invalid result shape +#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_result_rank(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{result shape cannot be larger than input shape at dimension 1}} + %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x256xi32, #blocked1> + tt.return +} + +// ----- + +// Invalid rank +#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_rank(%arg0: tensor<256x128x2xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{currently only 2D tensors are supported}} + %1 = amdgpu.extract_slice %arg0 [0,0,0] : tensor<256x128x2xi32, #blocked1> to tensor<256x16x2xi32, #blocked1> + tt.return +} + +// ----- + +// Invalid non static offset +#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_non_static_offset(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}, %arg1: i32) { + // expected-error @+2 {{expected ']'}} + // expected-error @+1 {{expected integer value}} + %2 = amdgpu.extract_slice %arg0 [%arg1, 0] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> + tt.return +} diff --git a/test/TritonGPU/amd/amd-extractslice-op.mlir b/test/TritonGPU/amd/amd-extractslice-op.mlir new file mode 100644 index 000000000000..ef47a9f9b434 --- /dev/null +++ b/test/TritonGPU/amd/amd-extractslice-op.mlir @@ -0,0 +1,14 @@ +// RUN: triton-opt %s --convert-triton-amdgpu-to-llvm='arch=gfx942' | FileCheck %s + +#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], 
warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +module attributes {"triton_gpu.compute-capability" = 0 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func @basic_insert_slice(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // CHECK: llvm.func @basic_insert_slice + // CHECK-COUNT-64: %{{[0-9]*}} = llvm.extractvalue %arg0[{{[0-9]*}}] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK: %64 = llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK-COUNT-8: %{{[0-9]*}} = llvm.insertvalue %{{[0-9]*}}, %{{[0-9]*}}[{{[0-9]*}}] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> + %72 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> + tt.return + } +} diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td index 68c50d48635b..0b865cd1b8af 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td @@ -31,10 +31,12 @@ include "mlir/IR/EnumAttr.td" include "triton/Dialect/Triton/IR/TritonTypes.td" include "mlir/Dialect/LLVMIR/LLVMOpBase.td" include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure include 
"triton/Dialect/Triton/IR/TritonInterfaces.td" include "TritonAMDGPUDialect.td" include "TritonAMDGPUAttrDefs.td" + class TT_AMDGPU_Op traits = []> : Op { } @@ -44,6 +46,74 @@ class TT_AMDGPU_Op traits = []> : // def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">; +//===----------------------------------------------------------------------===// +// ExtractSliceOp +//===----------------------------------------------------------------------===// + +def ExtractSliceOp + : TT_AMDGPU_Op<"extract_slice", [Pure]> { + let summary = "extract slice operation"; + let description = [{ + The "extract_slice" operation enables extracting a slice of a tensor in + registers. + + The "extract_slice" operation supports the following arguments: + + * source: the base tensor on which to create a view tensor + * offsets: offsets into the base tensor at which to create the view + + Example 1: + + ```mlir + #blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], + threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [0, 1]}> + #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], + threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [0, 1]}> + %1 = triton_gpu.convert_layout %0 : tensor<128x128xf16, #blocked> + -> tensor<128x128xf16, #blocked1> + // create a slice of base tensor %1 with static offsets + %2 = amdgpu.extract_slice %1 [0, 0] : + tensor<128x128xf16, #blocked1> to tensor<128x32xf16, #blocked1> + ``` + + Example 1 shows how "extract_slice" operation may be used. In this example a + new slice of 128x32 is created. "extract_slice" works on tensors with layout + where the desired slice has the same layout as the source tensor. + "%0" cannot be sliced directly as the resulting slice cannot have the same + layout as "%0". Therefore it needs to be converted to a layout suitable + for slicing. "#blocked1" layout is appropriate for this as it keeps the + sizePerThread the same thus keeping coalescing properties the same. 
+ In order to utilize all threads in a warp, "threadsPerWarp" is set to + [16,4] for this new layout. This layout conversion carried out before + using "extract_slice" ensures slicing still uses all threads efficiently. The + size of the slice is determined by the result type. + }]; + + let arguments = (ins AnyRankedTensor:$source, + DenseI64ArrayAttr:$static_offsets); + let results = (outs AnyRankedTensor:$result); + + let builders = [ + // Build a ExtractSliceOp with static offsets and the same result type + OpBuilder<(ins "RankedTensorType":$resultType, + "Value":$source, + "ArrayRef": $static_offsets)>, + ]; + + let extraClassDeclaration = [{ + std::array getArrayAttrMaxRanks() { + unsigned rank = getSource().getType().getRank(); + return {rank, rank, rank}; + } + }]; + + let assemblyFormat = [{ + $source $static_offsets attr-dict `:` type($source) `to` type($result) + }]; + + let hasVerifier = 1; +} + def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> { let summary = "A placeholder op for instruction scheduling hints within a basic block"; let description = [{ diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h b/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h new file mode 100644 index 000000000000..90922e802988 --- /dev/null +++ b/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h @@ -0,0 +1,14 @@ +#ifndef TRITONAMDGPU_TO_LLVM_PATTERNS_AMDGPU_OP_TO_LLVM_H +#define TRITONAMDGPU_TO_LLVM_PATTERNS_AMDGPU_OP_TO_LLVM_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" + +namespace mlir::triton::AMD { + +void populateExtractSliceOpToLLVMPatterns( + mlir::LLVMTypeConverter &typeConverter, mlir::RewritePatternSet &patterns, + mlir::PatternBenefit benefit); + +} + +#endif diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp index 1e429fdc39a9..7c2473dbe56f 100644 --- 
a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp +++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp @@ -24,10 +24,10 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/OpImplementation.h" -#include "mlir/IR/OperationSupport.h" - #include "llvm/ADT/TypeSwitch.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + // clang-format off #include "Dialect/TritonAMDGPU/IR/Dialect.h" #include "Dialect/TritonAMDGPU/IR/Dialect.cpp.inc" @@ -53,3 +53,83 @@ void mlir::triton::amdgpu::TritonAMDGPUDialect::initialize() { #define GET_OP_CLASSES #include "Dialect/TritonAMDGPU/IR/Ops.cpp.inc" + +namespace mlir::triton::amdgpu { + +LogicalResult ExtractSliceOp::verify() { + auto srcTy = getSource().getType(); + auto srcLayout = srcTy.getEncoding(); + auto srcElementType = getElementTypeOrSelf(srcTy); + auto resultTy = getResult().getType(); + auto resultLayout = resultTy.getEncoding(); + auto resultElementType = getElementTypeOrSelf(resultTy); + + if (srcElementType != resultElementType) { + return emitError("result element type must match source element type"); + } + if (srcLayout != resultLayout) { + return emitError("result layout must match source layout"); + } + if (srcTy.getRank() != resultTy.getRank()) { + return emitError("result rank must be equal to source rank"); + } + if (srcTy.getRank() != 2) { + return emitError("currently only 2D tensors are supported"); + } + + auto srcShape = srcTy.getShape(); + auto shapePerCTATile = + mlir::triton::gpu::getShapePerCTATile(srcLayout, srcShape); + shapePerCTATile[0] = + std::min(static_cast(srcShape[0]), shapePerCTATile[0]); + shapePerCTATile[1] = + std::min(static_cast(srcShape[1]), shapePerCTATile[1]); + + // ExtractSlice only supports slicing where offsets and sizes are multiples of + // shapePerCTATile. This condition ensures that slice has the same layout as + // the original tensor. 
+ + auto offsets = getStaticOffsets(); + if (offsets.size() != 2) { + return emitError("invalid offset shape ") << offsets; + } + + SmallVector sizes; + for (auto i = 0; i < 2; ++i) { + auto resultDimSize = resultTy.getDimSize(i); + auto srcDimSize = srcTy.getDimSize(i); + if (resultDimSize == 0) { + return emitError("result tensor dimension size zero at dimension ") << i; + } + if (srcDimSize == 0) { + return emitError("source tensor dimension size zero at dimension ") << i; + } + if (resultDimSize > srcDimSize) { + return emitError( + "result shape cannot be larger than input shape at dimension ") + << i; + } + if (offsets[i] + resultDimSize > srcDimSize) { + return emitError("invalid offset ") + << offsets[i] << " at dimension " << i; + } + sizes.push_back(resultDimSize); + } + + if (sizes[0] % shapePerCTATile[0] != 0 || + sizes[1] % shapePerCTATile[1] != 0) { + return emitError() << "sizes [" << sizes + << "] must be a multiple of shapePerCTATile [" + << shapePerCTATile << "]"; + } + + if (offsets[0] % shapePerCTATile[0] != 0 || + offsets[1] % shapePerCTATile[1] != 0) { + return emitError() << "offset [" << offsets + << "] must be a multiple of shapePerCTATile [" + << shapePerCTATile << "]"; + } + + return success(); +} +} // namespace mlir::triton::amdgpu diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/CMakeLists.txt index e6da8f28777e..4aebabc0a275 100644 --- a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/CMakeLists.txt @@ -1,5 +1,6 @@ add_triton_library(TritonAMDGPUDialectToLLVM TritonAMDGPUToLLVMPatterns.cpp + ExtractSliceOpToLLVM.cpp DEPENDS TritonAMDGPUIR diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp new file mode 100644 index 000000000000..c0100812f299 --- /dev/null +++ 
b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp @@ -0,0 +1,143 @@ +#include "Dialect/TritonAMDGPU/IR/Dialect.h" +#include "TritonAMDGPUToLLVM/GCNAsmFormat.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using namespace mlir; +using namespace mlir::triton; + +// clang-format off +/*** + # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # + # WO # W1 # | # + # # # | # + # # # # # | # + # W2 # W3 # .... | # + # # # | SkipElems # + # # # # # | # + # | # + # Slice | # + # . / \ | # + # . / \ | # + # . / \| # + # # # # # # # + # # W0 # W1 # # + # # # # # + # # # # # # tensorStride # + # # W2 # W3 # --------------------------------# + # # # # # + # # # # # # # + # tensorStride # W0 # W1 # # + # ---------------------------------- # # # # + # # # # # # # + # # W2 # W3 # # + # # # # # + # # # # # # ---> lastIdx # + # . # + # . # + # . 
# + # # + # # + # # + # # + # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # +***/ +// clang-format on + +namespace { +struct ExtractSliceOpConversion + : public ConvertOpToLLVMPattern { + explicit ExtractSliceOpConversion(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit) { + } + + LogicalResult processLayout(amdgpu::ExtractSliceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = op->getLoc(); + auto srcTy = cast(op.getSource().getType()); + auto srcLayout = srcTy.getEncoding(); + auto srcShape = srcTy.getShape(); + auto resultTy = cast(op.getType()); + auto vals = unpackLLElements(loc, adaptor.getSource(), rewriter); + auto elemsPerThread = triton::gpu::getElemsPerThread(srcTy); + auto sizePerThread = triton::gpu::getSizePerThread(srcLayout); + auto totalSizePerThread = product(sizePerThread); + auto order = triton::gpu::getOrder(srcLayout); + + // Calculate valid total number of workers in each dimension + auto shapePerCTATile = triton::gpu::getShapePerCTATile(srcLayout, srcShape); + shapePerCTATile[0] = + std::min(static_cast(srcShape[0]), shapePerCTATile[0]); + shapePerCTATile[1] = + std::min(static_cast(srcShape[1]), shapePerCTATile[1]); + + // Rank == 2 checked in the verifier + SmallVector sizes; + for (auto i = 0; i < 2; ++i) { + sizes.push_back(resultTy.getDimSize(i)); + } + + auto offsets = op.getStaticOffsets(); + + // Calculate offsets and sizes in terms of CTA units. + std::array CTAOffsets{offsets[0] / shapePerCTATile[0], + offsets[1] / shapePerCTATile[1]}; + std::array CTASizes{sizes[0] / shapePerCTATile[0], + sizes[1] / shapePerCTATile[1]}; + std::array CTAPerShape{srcShape[0] / shapePerCTATile[0], + srcShape[1] / shapePerCTATile[1]}; + + // The diagram above illustrates the graphical representation of the + // skipElems, tensorStride, and lastIdx variables. 
+ auto skipElems = CTAOffsets[order[1]] * + (elemsPerThread[order[0]] * sizePerThread[order[1]]) + + CTAOffsets[order[0]] * totalSizePerThread; + auto tensorStride = + (CTAPerShape[order[0]] - CTASizes[order[0]]) * totalSizePerThread; + auto lastIdx = + (CTAOffsets[order[1]] + CTASizes[order[1]] - 1) * + elemsPerThread[order[0]] * sizePerThread[order[1]] + + (CTAOffsets[order[0]] + CTASizes[order[0]]) * totalSizePerThread; + + assert(lastIdx <= vals.size()); + + SmallVector resultVals; + for (int i = skipElems; i < lastIdx; i += tensorStride) { + for (int j = 0; j < totalSizePerThread * CTASizes[order[0]]; ++j, ++i) { + assert(i < lastIdx); + resultVals.push_back(vals[i]); + } + } + Value ret = packLLElements(loc, this->getTypeConverter(), resultVals, + rewriter, resultTy); + + rewriter.replaceOp(op, ret); + return success(); + } + + LogicalResult + matchAndRewrite(amdgpu::ExtractSliceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto srcTy = op.getSource().getType(); + if (isa( + op.getSource().getType().getEncoding())) { + return processLayout(op, adaptor, rewriter); + } + return failure(); + } +}; +} // namespace + +namespace mlir::triton::AMD { + +void populateExtractSliceOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); +} +} // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp index 5d172fea9cfa..c7c2f56d31de 100644 --- a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp +++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp @@ -1,9 +1,10 @@ +#include "third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h" #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" namespace mlir::triton::AMD { 
void populateTritonAMDGPUToLLVMPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit) { - // TODO: Insert TrtionAMDGPU dialect patterns. + populateExtractSliceOpToLLVMPatterns(typeConverter, patterns, benefit); } } // namespace mlir::triton::AMD diff --git a/third_party/amd/python/test/test_extract_slice.py b/third_party/amd/python/test/test_extract_slice.py new file mode 100644 index 000000000000..bca36038405d --- /dev/null +++ b/third_party/amd/python/test/test_extract_slice.py @@ -0,0 +1,115 @@ +import tempfile + +import numpy as np +import pytest +import torch + +import triton +import triton.language as tl + +from triton._internal_testing import is_hip + +num_ctas_list = [1] + +GPU_DIALECT = "triton_gpu" + +if is_hip(): + THREADS_PER_WARP = triton.runtime.driver.active.get_current_target().warp_size +else: + THREADS_PER_WARP = 32 + + +class BlockedLayout: + + def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga, cta_split_num, cta_order): + self.sz_per_thread = size_per_thread + self.threads_per_warp = threads_per_warp + self.warps_per_cta = warps_per_cta + self.order = order + self.ctas_per_cga = ctas_per_cga + self.cta_split_num = cta_split_num + self.cta_order = cta_order + + def __str__(self): + return f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" + + +# ----------------------- +# test extract slice +# ----------------------- + +extract_layout = [ + BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [64, 1], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [16, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 8], [16, 4], [4, 1], 
[0, 1], [1, 1], [1, 1], [0, 1]), +] +blocked_layout = [ + BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [16, 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [16, 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 8], [16, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), +] + + +@pytest.mark.parametrize("M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset", + [[256, 256, 256, 32, 0, 32], [128, 128, 128, 64, 0, 64]]) +@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("extract_layout", extract_layout) +@pytest.mark.parametrize("blocked_layout", blocked_layout) +def test_extract_slice(dtype, M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset, blocked_layout, + extract_layout, device='cuda'): + if not is_hip(): + pytest.skip("extract_slice is AMD specific instruction.") + + ir = f""" + #blocked = {blocked_layout} + #extract_layout = {extract_layout} + module attributes {{"triton_gpu.num-ctas" = 1, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = {str(64)} : i32}} {{ + tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked> + %cst_n = arith.constant dense<{N_tile_size}> : tensor<{M_tile_size}x1xi32, #blocked> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>> + %42 = tt.make_range {{end = {M_tile_size} : i32, start = 0 : i32}} : tensor<{M_tile_size}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>> + %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #blocked> + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : 
tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked> + %43 = tt.expand_dims %42 {{axis = 1 : i32}} : tensor<{M_tile_size}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M_tile_size}x1xi32, #blocked> + %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #blocked> + %44 = arith.muli %43, %cst_n : tensor<{M_tile_size}x1xi32, #blocked> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{M}xi32, #blocked> + %7 = tt.broadcast %6 : tensor<1x{M}xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> + %8 = tt.broadcast %5 : tensor<{M}x1xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> + %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #blocked> + %33 = tt.make_range {{end = {N_tile_size} : i32, start = 0 : i32}} : tensor<{N_tile_size}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>> + %34 = tt.splat %arg1 : !tt.ptr -> tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked> + %37 = tt.expand_dims %33 {{axis = 0 : i32}} : tensor<{N_tile_size}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{N_tile_size}xi32, #blocked> + %38 = tt.broadcast %37 : tensor<1x{N_tile_size}xi32, #blocked> -> tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> + %39 = tt.broadcast %44 : tensor<{M_tile_size}x1xi32, #blocked> -> tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> + %40 = arith.addi %38, %39 : tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> + %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> + %11 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}x!tt.ptr, #blocked> + %12 = triton_gpu.convert_layout %11 : tensor<{M}x{N}xf16, #blocked> -> tensor<{M}x{N}xf16, #extract_layout> + %13 = amdgpu.extract_slice %12 [{M_tile_offset}, {N_tile_offset}] : tensor<{M}x{N}xf16, #extract_layout> to tensor<{M_tile_size}x{N_tile_size}xf16, #extract_layout> + %14 = 
triton_gpu.convert_layout %13 : tensor<{M_tile_size}x{N_tile_size}xf16, #extract_layout> -> tensor<{M_tile_size}x{N_tile_size}xf16, #blocked> + %15 = tt.addptr %34, %40 : tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked>, tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> + tt.store %15, %14 : tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked> + tt.return + }} + }} + """ + x = torch.randn((M, N), device=device, dtype=torch.float16) + import tempfile + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) + + extract_slice = torch.empty((M_tile_size, N_tile_size), device=device, dtype=torch.float16) + + kernel[(1, 1, 1)](x.data_ptr(), extract_slice) + test_result = torch.equal(x[M_tile_offset:M_tile_size + M_tile_offset, N_tile_offset:N_tile_offset + N_tile_size], + extract_slice) + assert test_result