From 380f84880f30f227994efa2bd31ac70fdbd65cd3 Mon Sep 17 00:00:00 2001 From: Hasitha Algewaththa Date: Fri, 20 Sep 2024 23:00:26 +0000 Subject: [PATCH 01/15] [AMD] Adds Support For ViewSlice Operation Introduces a new operation for amdgpus to slice a tensor in memory - Adds new TritonAMDGPUDialect operation ViewSliceOp - Adds verifier for ViewSliceOp - Adds conversion of the operation to llvm --- .../TritonAMDGPU/IR/TritonAMDGPUOps.td | 58 +++++++++++ .../PatternTritonAMDGPUToLLVM.h | 18 ++++ .../lib/Dialect/TritonAMDGPU/IR/Dialect.cpp | 49 +++++++++- .../TritonAMDGPUDialectToLLVM/CMakeLists.txt | 1 + .../TritonAMDGPUToLLVMPatterns.cpp | 4 +- .../ViewSliceOpToLLVM.cpp | 98 +++++++++++++++++++ 6 files changed, 226 insertions(+), 2 deletions(-) create mode 100644 third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h create mode 100644 third_party/amd/lib/TritonAMDGPUDialectToLLVM/ViewSliceOpToLLVM.cpp diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td index 68c50d48635b..b291c956cff7 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td @@ -34,6 +34,9 @@ include "mlir/Interfaces/InferTypeOpInterface.td" include "triton/Dialect/Triton/IR/TritonInterfaces.td" include "TritonAMDGPUDialect.td" include "TritonAMDGPUAttrDefs.td" +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure +include "mlir/Interfaces/ViewLikeInterface.td" // OffsetSizeAndStrideOpInterface + class TT_AMDGPU_Op traits = []> : Op { @@ -43,6 +46,61 @@ class TT_AMDGPU_Op traits = []> : // Interfaces // def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">; +def TritonAMDGPU_ViewSliceOp : TritonAMDGPU_Op<"view_slice", + [AttrSizedOperandSegments, + Pure, + OffsetSizeAndStrideOpInterface + ]> { + let summary = "view slice operation"; + let description = [{ + Represents view of the slice of the tensor in registers. Syntax of the operation is the same + as for extract_slice op. However, unlike 'extract_slice' which slices in shared memory, + 'view_slice' specifically slices within registers. + Slice of the tensor is required to have the same layout as the original tensor. + In a way, semantics of the 'view_slice' operation is a combination of the 'extract_slice' and 'view' operations semantics. + }]; + + let arguments = (ins + AnyRankedTensor:$source, + Variadic:$offsets, + Variadic:$sizes, + Variadic:$strides, + DenseI64ArrayAttr:$static_offsets, + DenseI64ArrayAttr:$static_sizes, + DenseI64ArrayAttr:$static_strides + ); + let results = (outs AnyRankedTensor:$result); + + let builders = [ + // Build a ViewSliceOp with mixed static and dynamic entries and the same + // result type + OpBuilder<(ins "RankedTensorType":$resultType, "Value":$source, + "ArrayRef":$offsets, "ArrayRef":$sizes, + "ArrayRef":$strides, + CArg<"ArrayRef", "{}">:$attrs)>, + ]; + + let extraClassDeclaration = [{ + /// Return the number of leading operands before the `offsets`, `sizes` and + /// and `strides` operands. 
+ static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; } + + std::array getArrayAttrMaxRanks() { + unsigned rank = getSource().getType().getRank(); + return {rank, rank, rank}; + } + }]; + + let assemblyFormat = [{ + $source `` + custom($offsets, $static_offsets) + custom($sizes, $static_sizes) + custom($strides, $static_strides) + attr-dict `:` type($source) `to` type($result) + }]; + + let hasVerifier = 1; +} def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> { let summary = "A placeholder op for instruction scheduling hints within a basic block"; diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h b/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h new file mode 100644 index 000000000000..ecdce2f47b61 --- /dev/null +++ b/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h @@ -0,0 +1,18 @@ +#ifndef THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPU_TO_LLVM_PATTERNS_AMDGPU_OP_TO_LLVM_H +#define THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPU_TO_LLVM_PATTERNS_AMDGPU_OP_TO_LLVM_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "triton/Analysis/AxisInfo.h" + +using namespace mlir; +// using namespace mlir::triton; + +namespace mlir::triton::AMD { + +void populateViewSliceOpTritonAMDGPUToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit); + +} + +#endif \ No newline at end of file diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp index 1e429fdc39a9..4fd913848b87 100644 --- a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp +++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp @@ -25,9 +25,10 @@ #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/OperationSupport.h" - #include "llvm/ADT/TypeSwitch.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + // clang-format off #include "Dialect/TritonAMDGPU/IR/Dialect.h" #include "Dialect/TritonAMDGPU/IR/Dialect.cpp.inc" @@ -53,3 +54,49 @@ void mlir::triton::amdgpu::TritonAMDGPUDialect::initialize() { #define GET_OP_CLASSES #include "Dialect/TritonAMDGPU/IR/Ops.cpp.inc" + +namespace mlir::triton::amdgpu { + +LogicalResult ViewSliceOp::verify() { + auto srcTy = getSource().getType(); + auto srcLayout = srcTy.getEncoding(); + auto srcElementType = dyn_cast(srcTy).getElementType(); + auto resultTy = getResult().getType(); + auto resultLayout = resultTy.getEncoding(); + auto resultElementType = + dyn_cast(resultTy).getElementType(); + + if (srcElementType != resultElementType) { + return emitError("result type must match source type"); + } + + if (srcLayout != resultLayout) + return emitError("result layout must match source layout"); + + auto srcShape = srcTy.getShape(); + auto shapePerCTA = mlir::triton::gpu::getShapePerCTATile(srcLayout, srcShape); + shapePerCTA[0] = std::min(srcShape[0], (long)shapePerCTA[0]); + shapePerCTA[1] = std::min(srcShape[1], (long)shapePerCTA[1]); + + auto offsets = getStaticOffsets(); + auto sizes = getStaticSizes(); + + // ViewSlice only supports slicing where offsets and sizes are multiples of + // shapePerCTA. This condition ensures that slice has the same layout as the + // original tensor. 
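+  // Illustrative example (hypothetical numbers, following the #blocked1
+  // layout used by the lit test added later in this series): with
+  // sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1],
+  // shapePerCTA is [256, 16]; offsets [0, 0] with sizes [256, 16] are
+  // accepted, whereas an offset such as [0, 8] is rejected because 8 is not
+  // a multiple of 16.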
+ + if (offsets[0] % shapePerCTA[0] != 0 || offsets[1] % shapePerCTA[1] != 0) { + return emitError("incorrect static offset"); + } + + if (sizes[0] % shapePerCTA[0] != 0 || sizes[1] % shapePerCTA[1] != 0) { + return emitError("incorrect static size"); + } + + if (!hasUnitStride()) { + return emitError("unsupported stride"); + } + + return success(); +} +} // namespace mlir::triton::amdgpu \ No newline at end of file diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/CMakeLists.txt index e6da8f28777e..aab42744aed2 100644 --- a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/CMakeLists.txt @@ -1,5 +1,6 @@ add_triton_library(TritonAMDGPUDialectToLLVM TritonAMDGPUToLLVMPatterns.cpp + ViewSliceOpToLLVM.cpp DEPENDS TritonAMDGPUIR diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp index 5d172fea9cfa..786043d1078f 100644 --- a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp +++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp @@ -1,9 +1,11 @@ +#include "third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h" #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" namespace mlir::triton::AMD { void populateTritonAMDGPUToLLVMPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit) { - // TODO: Insert TrtionAMDGPU dialect patterns. + populateViewSliceOpTritonAMDGPUToLLVMPatterns(typeConverter, patterns, + benefit); } } // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ViewSliceOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ViewSliceOpToLLVM.cpp new file mode 100644 index 000000000000..44f9f1496c2e --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ViewSliceOpToLLVM.cpp @@ -0,0 +1,98 @@ +#include "Dialect/TritonAMDGPU/IR/Dialect.h" +#include "TritonAMDGPUToLLVM/GCNAsmFormat.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; +namespace tta = mlir::triton::amdgpu; + +namespace { +struct ViewSliceOpConversion : public ConvertOpToLLVMPattern { + explicit ViewSliceOpConversion(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit) {} + + LogicalResult processLayout(tta::ViewSliceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = op->getLoc(); + auto srcTy = dyn_cast(op.getSource().getType()); + auto srcLayout = srcTy.getEncoding(); + auto srcShape = srcTy.getShape(); + auto resultTy = cast(op.getType()); + auto vals = unpackLLElements(loc, adaptor.getSource(), rewriter); + auto elemsPerThread = mlir::triton::gpu::getElemsPerThread(srcTy); + auto sizePerThread = getSizePerThread(srcLayout); + auto totalSizePerThread = sizePerThread[0] * sizePerThread[1]; + auto order = getOrder(srcLayout); + auto shapePerCTA = getShapePerCTATile(srcLayout, srcShape); + shapePerCTA[0] = std::min(srcShape[0], (long)shapePerCTA[0]); + 
shapePerCTA[1] = std::min(srcShape[1], (long)shapePerCTA[1]); + + auto offsets = op.getStaticOffsets(); + auto sizes = op.getStaticSizes(); + + // Calculate offsets and sizes in terms of CTA units. + std::vector CTAOffsets{offsets[0] / shapePerCTA[0], + offsets[1] / shapePerCTA[1]}; + std::vector CTASizes{sizes[0] / shapePerCTA[0], + sizes[1] / shapePerCTA[1]}; + std::vector CTAPerShape{srcShape[0] / shapePerCTA[0], + srcShape[1] / shapePerCTA[1]}; + + // The diagram above illustrates the graphical representation of the + // skipElems, tensorStride, and lastIdx variables. + auto skipElems = CTAOffsets[order[1]] * + (elemsPerThread[order[0]] * sizePerThread[order[1]]) + + CTAOffsets[order[0]] * totalSizePerThread; + auto tensorStride = + (CTAPerShape[order[0]] - CTASizes[order[0]]) * totalSizePerThread; + auto lastIdx = + (CTAOffsets[order[1]] + CTASizes[order[1]] - 1) * + elemsPerThread[order[0]] * sizePerThread[order[1]] + + (CTAOffsets[order[0]] + CTASizes[order[0]]) * totalSizePerThread; + + assert(lastIdx <= vals.size()); + + SmallVector resultVals; + for (int i = skipElems; i < lastIdx; i += tensorStride) { + for (int j = 0; j < totalSizePerThread * CTASizes[order[0]]; ++j, ++i) { + assert(i < lastIdx); + resultVals.push_back(vals[i]); + } + } + Value ret = packLLElements(loc, this->getTypeConverter(), resultVals, + rewriter, resultTy); + + rewriter.replaceOp(op, ret); + return success(); + } + + LogicalResult + matchAndRewrite(tta::ViewSliceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto srcTy = op.getSource().getType(); + if (isa(op.getSource().getType().getEncoding()) || + isa(op.getSource().getType().getEncoding())) { + return processLayout(op, adaptor, rewriter); + } else { + assert(false && "Unsupported layout in viewSlice."); + return failure(); + } + } +}; +} // namespace + +namespace mlir::triton::AMD { + +void populateViewSliceOpTritonAMDGPUToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); +} +} // namespace mlir::triton::AMD From 7769656dd70dbb58bb5437e23754d21b7c0b3823 Mon Sep 17 00:00:00 2001 From: Hasitha Algewaththa Date: Mon, 23 Sep 2024 20:21:07 +0000 Subject: [PATCH 02/15] Adds lit test --- test/TritonGPU/amd/amd-viewslice-op.mlir | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 test/TritonGPU/amd/amd-viewslice-op.mlir diff --git a/test/TritonGPU/amd/amd-viewslice-op.mlir b/test/TritonGPU/amd/amd-viewslice-op.mlir new file mode 100644 index 000000000000..9c8b5fa167d3 --- /dev/null +++ b/test/TritonGPU/amd/amd-viewslice-op.mlir @@ -0,0 +1,16 @@ +// RUN: triton-opt %s --convert-triton-amdgpu-to-llvm='arch=gfx942' | FileCheck %s + +#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +module attributes {"triton_gpu.compute-capability" = 0 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32, triton_gpu.target = "hip:gfx942"} { + tt.func @basic_insert_slice_async_1d(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // CHECK: llvm.func @basic_insert_slice_async_1d + // CHECK: %0 = llvm.extractvalue %arg0[0] : 
!llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK-COUNT-63: %{{[0-9]*}} = llvm.extractvalue %arg0{{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK: %64 = llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK-COUNT-8: %{{[0-9]*}} = llvm.insertvalue %{{[0-9]*}}, %{{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> + %72 = amdgpu.view_slice %arg0[0,0] [256, 16] [1,1] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> + tt.return + } +} + From 2943fc07a7501db53719901e5d84db009165a6a5 Mon Sep 17 00:00:00 2001 From: Hasitha Algewaththa Date: Tue, 24 Sep 2024 21:28:48 +0000 Subject: [PATCH 03/15] Adds comments and formatting changes --- test/TritonGPU/amd/amd-viewslice-op.mlir | 10 ++--- .../TritonAMDGPU/IR/TritonAMDGPUOps.td | 4 +- .../PatternTritonAMDGPUToLLVM.h | 3 +- .../lib/Dialect/TritonAMDGPU/IR/Dialect.cpp | 2 +- .../ViewSliceOpToLLVM.cpp | 38 +++++++++++++++++++ 5 files changed, 45 insertions(+), 12 deletions(-) diff --git a/test/TritonGPU/amd/amd-viewslice-op.mlir b/test/TritonGPU/amd/amd-viewslice-op.mlir index 9c8b5fa167d3..fdaa1ccbdba3 100644 --- a/test/TritonGPU/amd/amd-viewslice-op.mlir +++ b/test/TritonGPU/amd/amd-viewslice-op.mlir @@ -2,15 +2,13 @@ #blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> #blocked2 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> -module attributes {"triton_gpu.compute-capability" = 0 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32, triton_gpu.target = "hip:gfx942"} { +module attributes {"triton_gpu.compute-capability" = 0 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { tt.func @basic_insert_slice_async_1d(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { // CHECK: llvm.func @basic_insert_slice_async_1d - // CHECK: %0 = llvm.extractvalue %arg0[0] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> - // CHECK-COUNT-63: %{{[0-9]*}} = llvm.extractvalue %arg0{{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK-COUNT-64: 
%{{[0-9]*}} = llvm.extractvalue %arg0[{{[0-9]*}}] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> // CHECK: %64 = llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> - // CHECK-COUNT-8: %{{[0-9]*}} = llvm.insertvalue %{{[0-9]*}}, %{{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> - %72 = amdgpu.view_slice %arg0[0,0] [256, 16] [1,1] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> + // CHECK-COUNT-8: %{{[0-9]*}} = llvm.insertvalue %{{[0-9]*}}, %{{[0-9]*}}[{{[0-9]*}}] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> + %72 = amdgpu.view_slice %arg0[0,0] [256, 16] [1,1] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> tt.return } } - diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td index b291c956cff7..02fcc1971ae9 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td @@ -53,11 +53,9 @@ def TritonAMDGPU_ViewSliceOp : TritonAMDGPU_Op<"view_slice", ]> { let summary = "view slice operation"; let description = [{ - Represents view of the slice of the tensor in registers. Syntax of the operation is the same - as for extract_slice op. However, unlike 'extract_slice' which slices in shared memory, + Represents view of the slice of the tensor in registers. However, unlike 'memdesc_subview' which provides a view in shared memory, 'view_slice' specifically slices within registers. Slice of the tensor is required to have the same layout as the original tensor. - In a way, semantics of the 'view_slice' operation is a combination of the 'extract_slice' and 'view' operations semantics. 
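+
+    Example (illustrative, mirroring the accompanying lit test; `#blocked`
+    stands for any layout that satisfies the verifier's shapePerCTA
+    divisibility constraints):
+
+    ```mlir
+    %1 = amdgpu.view_slice %0[0, 0] [256, 16] [1, 1]
+         : tensor<256x128xi32, #blocked> to tensor<256x16xi32, #blocked>
+    ```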
}]; let arguments = (ins diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h b/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h index ecdce2f47b61..4e3321af8554 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h +++ b/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h @@ -5,7 +5,6 @@ #include "triton/Analysis/AxisInfo.h" using namespace mlir; -// using namespace mlir::triton; namespace mlir::triton::AMD { @@ -15,4 +14,4 @@ void populateViewSliceOpTritonAMDGPUToLLVMPatterns( } -#endif \ No newline at end of file +#endif diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp index 4fd913848b87..e27bff296d69 100644 --- a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp +++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp @@ -99,4 +99,4 @@ LogicalResult ViewSliceOp::verify() { return success(); } -} // namespace mlir::triton::amdgpu \ No newline at end of file +} // namespace mlir::triton::amdgpu diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ViewSliceOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ViewSliceOpToLLVM.cpp index 44f9f1496c2e..4238f5349c71 100644 --- a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ViewSliceOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ViewSliceOpToLLVM.cpp @@ -12,6 +12,44 @@ using namespace mlir::triton; using namespace mlir::triton::gpu; namespace tta = mlir::triton::amdgpu; +// clang-format off +/*** + # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # + # WO # W1 # | # + # # # | # + # # # # # | # + # W2 # W3 # .... | # + # # # | SkipElems # + # # # # # | # + # | # + # Slice | # + # . / \ | # + # . / \ | # + # . / \| # + # # # # # # # + # # W0 # W1 # # + # # # # # + # # # # # # tensorStride # + # # W2 # W3 # --------------------------------# + # # # # # + # # # # # # # + # tensorStride # W0 # W1 # # + # ---------------------------------- # # # # + # # # # # # # + # # W2 # W3 # # + # # # # # + # # # # # # ---> lastIdx # + # . # + # . # + # . 
# + # # + # # + # # + # # + # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # +***/ +// clang-format on + namespace { struct ViewSliceOpConversion : public ConvertOpToLLVMPattern { explicit ViewSliceOpConversion(LLVMTypeConverter &typeConverter, From d7d05a86380caa774a4618a0d917b0200a5c6e45 Mon Sep 17 00:00:00 2001 From: Hasitha Algewaththa Date: Wed, 25 Sep 2024 14:41:34 +0000 Subject: [PATCH 04/15] Changes casting --- third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp | 4 ++-- .../amd/lib/TritonAMDGPUDialectToLLVM/ViewSliceOpToLLVM.cpp | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp index e27bff296d69..fdc42ccf82d9 100644 --- a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp +++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp @@ -75,8 +75,8 @@ LogicalResult ViewSliceOp::verify() { auto srcShape = srcTy.getShape(); auto shapePerCTA = mlir::triton::gpu::getShapePerCTATile(srcLayout, srcShape); - shapePerCTA[0] = std::min(srcShape[0], (long)shapePerCTA[0]); - shapePerCTA[1] = std::min(srcShape[1], (long)shapePerCTA[1]); + shapePerCTA[0] = std::min(static_cast(srcShape[0]), shapePerCTA[0]); + shapePerCTA[1] = std::min(static_cast(srcShape[1]), shapePerCTA[1]); auto offsets = getStaticOffsets(); auto sizes = getStaticSizes(); diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ViewSliceOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ViewSliceOpToLLVM.cpp index 4238f5349c71..bbdbefb28ee3 100644 --- a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ViewSliceOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ViewSliceOpToLLVM.cpp @@ -69,8 +69,10 @@ struct ViewSliceOpConversion : public ConvertOpToLLVMPattern { auto totalSizePerThread = sizePerThread[0] * sizePerThread[1]; auto order = getOrder(srcLayout); auto shapePerCTA = getShapePerCTATile(srcLayout, srcShape); - shapePerCTA[0] = std::min(srcShape[0], (long)shapePerCTA[0]); - shapePerCTA[1] = std::min(srcShape[1], (long)shapePerCTA[1]); + shapePerCTA[0] = + std::min(static_cast(srcShape[0]), shapePerCTA[0]); + shapePerCTA[1] = + std::min(static_cast(srcShape[1]), shapePerCTA[1]); auto offsets = op.getStaticOffsets(); auto sizes = op.getStaticSizes(); From b835923c7570437dfc9b3bcb8c4745b505580966 Mon Sep 17 00:00:00 2001 From: Hasitha Algewaththa Date: Wed, 25 Sep 2024 15:57:18 +0000 Subject: [PATCH 05/15] Adds pytest --- python/test/unit/language/test_core.py | 648 +++++++++++-------------- 1 file changed, 276 insertions(+), 372 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index a499dc232146..79ce102d2377 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -5,7 +5,7 @@ from typing import Optional import math import textwrap -import pathlib +import tempfile import numpy as np import pytest @@ -23,16 +23,11 @@ int_dtypes, uint_dtypes, float_dtypes, - float_dtypes_with_bfloat16, dtypes, dtypes_with_bfloat16, is_cuda, is_interpreter, is_hip, - is_hip_cdna, - is_hip_mi200, - is_hip_mi300, - is_xpu, get_arch, torch_float8_dtypes, torch_dtypes, @@ -71,11 +66,6 @@ def _bitwidth(dtype: str) -> int: return int(re.search(r'(\d+)$', dtype).group(1)) -def _dtype(dtype: str) -> str: - # ex.: "int64" -> "int" - return re.match(r'([a-zA-Z]+)', dtype).group(0) - - def patch_kernel(template, to_replace): if is_interpreter(): 
local_namespace = {} @@ -149,17 +139,6 @@ def __str__(self): return f"#{GPU_DIALECT}.nvidia_mma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA={self.warps_per_cta}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}, instrShape={self.instr_shape}}}>" -class DotOperandLayout: - - def __init__(self, parent, op_idx, k_width): - self.parent = parent - self.op_idx = op_idx - self.k_width = k_width - - def __str__(self): - return f"#{GPU_DIALECT}.dot_op<{{parent={self.parent}, opIdx={self.op_idx}, kWidth={self.k_width}}}>" - - class BlockedLayout: def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga, cta_split_num, cta_order): @@ -288,8 +267,7 @@ def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]: def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', num_ctas=1, - x_low=None, x_high=None, y_low=None, y_high=None, filter_y=None, test_broadcast=True, - test_scalar=True): + y_low=None, y_high=None, filter_y=None, test_broadcast=True, test_scalar=True): check_type_supported(dtype_x, device) # early return if dtype_x is not supported check_type_supported(dtype_y, device) SIZE = 128 @@ -334,7 +312,7 @@ def kernel_scalar_rhs(Z, X, y: tl.constexpr, SIZE: tl.constexpr): # inputs rs = RandomState(17) - x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs, low=x_low, high=x_high) + x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs) y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs, low=y_low, high=y_high) if filter_y: y[filter_y(y)] = 1 @@ -368,7 +346,7 @@ def do_test(x, y, kernel_fn): z_tri = to_triton(np.empty(SIZE, dtype=z_ref.dtype), device=device) kernel_fn[(1, )](z_tri, x_tri, y_tri, SIZE=SIZE, num_warps=4, num_ctas=num_ctas) err_msg = f"{expr}, {kernel_fn.__name__}" - np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=err_msg, atol=7e-3, rtol=0.01) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=err_msg, atol=3e-3, rtol=0.01) def get_scalar(x, dtype, low, high, filter): # If dtype is int, don't choose a huge number for the scalar @@ -402,32 +380,28 @@ def get_scalar(x, dtype, low, high, filter): do_test(x, y[:1].reshape(()), kernel_broadcast_rhs) -def _min_max_integral_mod_value(dtype_x, dtype_y) -> Optional[int]: - """ - Limit min/max values for integral types for mod values. Leads to - overflow/underflow when casting large integral types to floats. 
- """ - x_bitwidth = _bitwidth(dtype_x) - y_bitwidth = _bitwidth(dtype_y) - - # hard cap max value bit-width to 32 if 64 bit-width types - min_bitwidth = min(x_bitwidth, y_bitwidth, 32) - - # Limit max value bit-width to be one integral type less than the min bit-width - # For example: - # int64, float32 -> int16 - # uint16, float16 -> uint8 - x_dtype = _dtype(dtype_x) - max_bitwidth = max(min_bitwidth >> 1, 8) - dtype_max = x_dtype + str(max_bitwidth) - - max_info = np.iinfo(getattr(np, dtype_max)) - - # Still need to limit values here for uints - if max_bitwidth >= 16 and dtype_max in uint_dtypes: - return max_info.min, max_info.max // 4 - else: - return max_info.min, max_info.max +def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool: + # FIXME For large x, we are casting x to a floating point where it does not fit + # For small y, we are computing floor(div(float(x), y)) which may not fit + return (dtype_x, dtype_y) in [ + ('int32', 'bfloat16'), + ('int32', 'float16'), + ('int32', 'float32'), + ('int64', 'bfloat16'), + ('int64', 'float16'), + ('int64', 'float32'), + ('int64', 'float64'), + ('uint16', 'bfloat16'), + ('uint16', 'float16'), + ('uint16', 'float32'), + ('uint32', 'bfloat16'), + ('uint32', 'float16'), + ('uint32', 'float32'), + ('uint64', 'bfloat16'), + ('uint64', 'float16'), + ('uint64', 'float32'), + ('uint64', 'float64'), + ] def test_dtype_codegen(): @@ -451,35 +425,35 @@ def test_dtype_codegen(): @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_bin_op(dtype_x, dtype_y, op, num_ctas, device): expr = f'x {op} y' - np_expr_gen = (lambda x, y: f'{x} {op} {y}') if op != '%' else (lambda x, y: f'np.fmod({x}, {y})') - - # Triton promotes 16-bit floating-point / and % to 32-bit because there - # are no native div or FRem operations on float16. Since we have to - # convert anyway, we may as well take the accuracy bump. - def promote_to_fp32(dtype_x, dtype_y): - return dtype_x in ('float16', 'bfloat16') and dtype_y not in ('float32', 'float64') - - if op in ('/', '%') and (promote_to_fp32(dtype_x, dtype_y) or promote_to_fp32(dtype_y, dtype_x)): - numpy_expr = np_expr_gen('x.astype(np.float32)', 'y.astype(np.float32)') + if op == '%' and dtype_x in int_dtypes + uint_dtypes and dtype_y in int_dtypes + uint_dtypes: + # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders. + numpy_expr = 'np.fmod(x, y)' + elif op in ('/', '%') and dtype_x in ('int16', 'float16', 'bfloat16') and dtype_y in ('int16', 'float16', + 'bfloat16'): + # Triton promotes 16-bit floating-point / and % to 32-bit because there + # are no native div or FRem operations on float16. Since we have to + # convert anyway, we may as well take the accuracy bump. + numpy_expr = f'x.astype(np.float32) {op} y.astype(np.float32)' elif (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): - numpy_expr = np_expr_gen(f'x.astype(np.{dtype_x})', f'y.astype(np.{dtype_x})') + numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})' elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)): - numpy_expr = np_expr_gen(f'x.astype(np.{dtype_y})', f'y.astype(np.{dtype_y})') - elif op == '%': - # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders. 
- numpy_expr = np_expr_gen('x', 'y') + numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})' else: numpy_expr = None - - if (op in ('%', '/') and ((dtype_x in int_dtypes and dtype_y in uint_dtypes) or - (dtype_x in uint_dtypes and dtype_y in int_dtypes))): + if op == '%' and _mod_operation_ill_conditioned(dtype_x, dtype_y): + with pytest.raises(AssertionError, match="Not equal to tolerance"): + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) + elif (op in ('%', '/') and ((dtype_x in int_dtypes and dtype_y in uint_dtypes) or + (dtype_x in uint_dtypes and dtype_y in int_dtypes))): with pytest.raises(triton.TritonError, match='Cannot use .* because they have different signedness'): _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) else: # skip when bfloat16, as NumPy's ref performs the computation in float32 # while Triton performs it in bfloat16 + # We also skip mod when it is ill-conditioned skip_scalar_test = ((dtype_x == "bfloat16" and "float" in dtype_y) - or (op in ('/', '%') and dtype_x in ("float16", "bfloat16"))) + or (expr == "x % y" and dtype_x in int_dtypes + uint_dtypes and dtype_y in float_dtypes + and _mod_operation_ill_conditioned(dtype_x, "float32"))) # can't divide by zero not_zero = op in ('/', '%') and dtype_x in integral_dtypes and dtype_y in integral_dtypes # can't represent -int(max) @@ -488,17 +462,11 @@ def promote_to_fp32(dtype_x, dtype_y): filter_y = lambda y: not_zero * (y == 0) | not_minus_one * (y == -1) else: filter_y = None - - if op == "%" and dtype_x in integral_dtypes and dtype_y in float_dtypes_with_bfloat16: - x_low, x_high = _min_max_integral_mod_value(dtype_x, dtype_y) - else: - x_low, x_high = None, None - _test_binary( dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas, # fails with values where fmod(x, y) is roughly zero, but happens to # pass with the random values chosen for non-broadcast tests - test_broadcast=(op != "%"), x_low=x_low, x_high=x_high, filter_y=filter_y, test_scalar=not skip_scalar_test) + test_broadcast=(op != "%"), filter_y=filter_y, test_scalar=not skip_scalar_test) @pytest.mark.interpreter @@ -1143,9 +1111,6 @@ def kernel(): a = tl.arange(0, 64).reshape(2, 4, 8).trans(2, 1, 0) tl.static_assert(a.shape == [tl.constexpr(8), tl.constexpr(4), tl.constexpr(2)]) - a = tl.arange(0, 64).reshape(2, 4, 8).trans((2, 1, 0)) - tl.static_assert(a.shape == [tl.constexpr(8), tl.constexpr(4), tl.constexpr(2)]) - a = tl.arange(0, 64).view(2, 4, 8) tl.static_assert(a.shape == [tl.constexpr(2), tl.constexpr(4), tl.constexpr(8)]) @@ -1488,27 +1453,17 @@ def kernel(X): for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32), (64, 64)] for axis in [0, 1] for num_ctas in num_ctas_list - for dtype_x_str in ['float16', 'float32', 'uint64', 'int64', 'float64']]) + for dtype_x_str in ['float32', 'uint64', 'int64', 'float64']]) def test_tensor_atomic_rmw(shape, axis, num_ctas, dtype_x_str, device): shape0, shape1 = shape # triton kernel @triton.jit - def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr, DTYPE: tl.constexpr): + def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr): off0 = tl.arange(0, SHAPE0) off1 = tl.arange(0, SHAPE1) x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :]) - - if DTYPE == tl.float16: - # sum can have bad numerics when accumulating in float16. - # if we're dealing with float16, do the sum in float32. 
- x = x.to(tl.float32) - z = tl.sum(x, axis=AXIS) - - if DTYPE == tl.float16: - z = z.to(DTYPE) - if AXIS == 1: old = tl.atomic_add(Z + off0, z) tl.store(OLD + off0, old) @@ -1522,23 +1477,13 @@ def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.const z = numpy_random(z_shape, dtype_str=dtype_x_str, rs=rs) old = np.zeros(z_shape, dtype=getattr(np, dtype_x_str)) # reference results - if x.dtype == np.float16: - # do the sum in float32 to reduce numerical variation - z_ref = z + np.sum(x.astype(np.float32), axis=axis, keepdims=False).astype(x.dtype) - else: - z_ref = z + np.sum(x, axis=axis, keepdims=False) + z_ref = z + np.sum(x, axis=axis, keepdims=False) old_ref = np.copy(z) # triton result x_tri = to_triton(x, device=device) z_tri = to_triton(z, device=device) old_tri = to_triton(old, device=device) - - def torch_to_triton_dtype(t): - if t == torch.float16: - return tl.float16 - return None - - kernel[(1, )](z_tri, x_tri, old_tri, axis, shape0, shape1, torch_to_triton_dtype(x_tri.dtype), num_ctas=num_ctas) + kernel[(1, )](z_tri, x_tri, old_tri, axis, shape0, shape1, num_ctas=num_ctas) np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4) np.testing.assert_equal(old_ref, to_numpy(old_tri)) @@ -1754,33 +1699,47 @@ def kernel(X, Y, Z, N: tl.constexpr): @pytest.mark.interpreter @pytest.mark.parametrize("dtype_str", list(torch_dtypes)) -@pytest.mark.parametrize("constant_field", ["value", "mask"]) @pytest.mark.parametrize("num_ctas", num_ctas_list) -def test_store_constant(num_ctas, dtype_str, constant_field, device): +def test_store_constant(dtype_str, num_ctas, device): check_type_supported(dtype_str, device) + """Tests that boolean True is stored as 1""" @triton.jit - def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, CONSTANT_FIELD: tl.constexpr): + def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - if CONSTANT_FIELD == "value": - value = 1 - output = tl.full([BLOCK_SIZE], value=value, dtype=value.dtype) - mask = offsets < n_elements - elif CONSTANT_FIELD == "mask": - output = offsets < n_elements - mask = False + mask = offsets < n_elements + output = GENERATE_TEST_HERE tl.store(output_ptr + offsets, output, mask=mask) + triton_dtype_str = 'uint8' if dtype_str == 'bool' else dtype_str + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.zeros([BLOCK_SIZE], dtype=tl.{triton_dtype_str}) + 1'}) block_size = 128 ref = torch.ones([block_size], dtype=getattr(torch, dtype_str), device=device) output = torch.zeros([block_size], dtype=getattr(torch, dtype_str), device=device) + kernel[(1, )](output, block_size, BLOCK_SIZE=block_size, num_ctas=num_ctas) - kernel[(1, )](output, block_size, BLOCK_SIZE=block_size, num_ctas=num_ctas, CONSTANT_FIELD=constant_field) + assert torch.all(output == ref) - if constant_field == "value": - assert torch.all(output == ref) - else: - assert torch.all(output == 0) + +@pytest.mark.interpreter +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_store_constant_default_dtype(num_ctas, device): + """Tests that boolean True is stored as 1""" + + @triton.jit + def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + value = 1 + output = tl.full([BLOCK_SIZE], value=value, dtype=value.dtype) + tl.store(output_ptr + offsets, output, mask=mask) + + block_size = 128 + ref = torch.ones([block_size], dtype=getattr(torch, 
'int32'), device=device) + output = torch.zeros([block_size], dtype=getattr(torch, 'int32'), device=device) + kernel[(1, )](output, block_size, BLOCK_SIZE=block_size, num_ctas=num_ctas) + + assert torch.all(output == ref) def test_load_store_same_ptr(device): @@ -2489,9 +2448,6 @@ def histogram_kernel(x_ptr, z_ptr, M: tl.constexpr, N: tl.constexpr): offset2 = tl.arange(0, N) x = tl.load(x_ptr + offset1) z = tl.histogram(x, N) - bias = tl.full([M, N], 1, dtype=tl.int32) - # check that histogram produces object compatible with broadcasting - biased = z + bias tl.store(z_ptr + offset2, z) torch.manual_seed(17) @@ -2561,22 +2517,7 @@ def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, NUM_PID_N: tl. @pytest.mark.parametrize("M, N", [[32, 16], [32, 32], [32, 64], [64, 32]]) @pytest.mark.parametrize("src_layout", scan_layouts) @pytest.mark.parametrize("axis", [0, 1]) -@pytest.mark.parametrize("add_overflow_check", [False, True]) -def test_scan_layouts(M, N, src_layout, axis, add_overflow_check, device, tmp_path: pathlib.Path): - if add_overflow_check is True and is_hip(): - pytest.skip("overflow check disabled on HIP while fixing issues") - - overflow_check = """ - %17 = arith.extsi %arg2 : i32 to i64 - %18 = arith.extsi %arg3 : i32 to i64 - %19 = arith.addi %17, %18 : i64 - %i32.min = arith.constant -2147483648: i64 - %i32.max = arith.constant 2147483647: i64 - %20 = arith.cmpi slt, %19, %i32.max : i64 - %21 = arith.cmpi sge, %19, %i32.min : i64 - %22 = arith.andi %20, %21 : i1 - tt.assert %22, "overflow detected" : i1 - """ +def test_scan_layouts(M, N, src_layout, axis, device): ir = f""" #blocked = {src_layout} @@ -2596,7 +2537,7 @@ def test_scan_layouts(M, N, src_layout, axis, add_overflow_check, device, tmp_pa %10 = tt.load %9 : tensor<{M}x{N}x!tt.ptr, #blocked> %11 = "tt.scan"(%10) <{{axis = {axis} : i32, reverse = false}}> ({{ ^bb0(%arg2: i32, %arg3: i32): - %16 = arith.addi %arg2, %arg3 : i32{overflow_check if add_overflow_check else ""} + %16 = arith.addi %arg2, %arg3 : i32 tt.scan.return %16 : i32 }}) : (tensor<{M}x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked> %12 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x1x!tt.ptr, #blocked> @@ -2609,10 +2550,10 @@ def test_scan_layouts(M, N, src_layout, axis, add_overflow_check, device, tmp_pa }} """ - temp_file = tmp_path / "test_scan_layouts.ttgir" - temp_file.write_text(ir) - kernel = triton.compile(str(temp_file)) - + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) rs = RandomState(17) x = rs.randint(-100, 100, (M, N)).astype('int32') @@ -2658,36 +2599,20 @@ def test_scan_layouts(M, N, src_layout, axis, add_overflow_check, device, tmp_pa @pytest.mark.parametrize("src_layout", filter_layouts(layouts)) @pytest.mark.parametrize("axis", [0, 1]) @pytest.mark.parametrize("epilogue_kind", ['reduce1d', 'reduce2d', 'expand_reduce2d']) -@pytest.mark.parametrize("dtype_str,add_overflow_check", [("int32", False), ("int32", True), ("float32", False), - ("float16", False)]) +@pytest.mark.parametrize("dtype_str", ["int32", "float32", "float16"]) @pytest.mark.parametrize("reduce_op", ["sum", "max"]) -def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, add_overflow_check, reduce_op, device, - tmp_path: pathlib.Path): +def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce_op, device): if isinstance(src_layout, (MfmaLayout, MmaLayout)) and (M < src_layout.instr_shape[0] or N < src_layout.instr_shape[1]): 
pytest.skip("Skipping because tensor shape is smaller than M(f)maLayout instr_shape") if is_hip() and isinstance(src_layout, MfmaLayout) and ((M, N) == (128, 128)): pytest.skip("Skipping test because it runs out of shared memory") - if add_overflow_check is True and is_hip(): - pytest.skip("overflow check disabled on HIP while fixing issues") if reduce_op == "sum" and dtype_str == "float16" and M * N > 1024: pytest.skip("Skipping sum reduction on float16 due to accuracy issues") if isinstance(src_layout, MmaLayout) and src_layout.version == 3: src_layout[2] = 16 if dtype_str == "float16" else 8 - overflow_check = """ - %18 = arith.extsi %arg3 : i32 to i64 - %19 = arith.extsi %arg4 : i32 to i64 - %20 = arith.addi %18, %19 : i64 - %i32.min = arith.constant -2147483648: i64 - %i32.max = arith.constant 2147483647: i64 - %21 = arith.cmpi slt, %20, %i32.max : i64 - %22 = arith.cmpi sge, %20, %i32.min : i64 - %23 = arith.andi %21, %22 : i1 - tt.assert %23, "overflow detected" : i1 - """ - ty = {"int32": "i32", "float32": "f32", "float16": "f16"}[dtype_str] arith_op = { "max": {"int32": "arith.maxsi", "float32": "arith.maximumf", "float16": "arith.maximumf"}, # @@ -2720,7 +2645,7 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, add_ov f""" %14 = "tt.reduce"(%13) ({{ ^bb0(%arg3: {ty}, %arg4: {ty}): - %17 = {arith_op} %arg3, %arg4 : {ty}{overflow_check if add_overflow_check else ""} + %17 = {arith_op} %arg3, %arg4 : {ty} tt.reduce.return %17 : {ty} }}) {{axis = 0 : i32}} : (tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>>) -> {ty} tt.store %arg2, %14 : !tt.ptr<{ty}> @@ -2732,7 +2657,7 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, add_ov %14 = tt.expand_dims %13 {{axis = {axis} : i32}} : tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>> -> tensor<{expanded_shape}x{ty}, #src> %15 = "tt.reduce"(%14) ({{ ^bb0(%arg3: {ty}, %arg4: {ty}): - %17 = {arith_op} %arg3, %arg4 : {ty}{overflow_check if add_overflow_check else ""} + %17 = {arith_op} %arg3, %arg4 : {ty} tt.reduce.return %17 : {ty} }}) {{axis = {other_axis} : i32}} : (tensor<{expanded_shape}x{ty}, #src>) -> (tensor<1x{ty}, #{GPU_DIALECT}.slice<{{dim = {other_axis}, parent = #src}}>>) %16 = triton_gpu.convert_layout %15 : tensor<1x{ty}, #{GPU_DIALECT}.slice<{{dim = {other_axis}, parent = #src}}>> -> tensor<1x{ty}, #one_d_layout> @@ -2765,14 +2690,15 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, add_ov %12 = {GPU_DIALECT}.convert_layout %11 : tensor<{M}x{N}x{ty}, #blocked> -> tensor<{M}x{N}x{ty}, #src> %13 = "tt.reduce"(%12) ({{ ^bb0(%arg3: {ty}, %arg4: {ty}): - %17 = {arith_op} %arg3, %arg4 : {ty}{overflow_check if add_overflow_check else ""} + %17 = {arith_op} %arg3, %arg4 : {ty} tt.reduce.return %17 : {ty} }}) {{axis = {axis} : i32}} : (tensor<{M}x{N}x{ty}, #src>) -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>> """ + epilogue - temp_file = tmp_path / "test_reduce_layouts.ttgir" - temp_file.write_text(ir) - kernel = triton.compile(str(temp_file)) + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) rs = RandomState(17) x = numpy_random((M, N), dtype_str=dtype_str, rs=rs, low=0, high=10) @@ -2802,7 +2728,7 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, add_ov @pytest.mark.parametrize("M", [32, 64, 128, 256]) @pytest.mark.parametrize("src_layout", layouts) -def 
test_store_op(M, src_layout, device, tmp_path: pathlib.Path): +def test_store_op(M, src_layout, device): ir = f""" #src = {src_layout} @@ -2823,9 +2749,10 @@ def test_store_op(M, src_layout, device, tmp_path: pathlib.Path): }} """ - temp_file = tmp_path / "test_store_op.ttgir" - temp_file.write_text(ir) - store_kernel = triton.compile(str(temp_file)) + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + store_kernel = triton.compile(f.name) rs = RandomState(17) x = rs.randint(0, 4, (M, 1)).astype('float32') @@ -2852,7 +2779,7 @@ def test_store_op(M, src_layout, device, tmp_path: pathlib.Path): @pytest.mark.parametrize("dst_layout", filter_layouts(layouts)) @pytest.mark.parametrize("src_dim", [0, 1]) @pytest.mark.parametrize("dst_dim", [0, 1]) -def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device, tmp_path: pathlib.Path): +def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device): ir = f""" #dst = {dst_layout} @@ -2872,9 +2799,10 @@ def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device, tmp_path }} }} """ - temp_file = tmp_path / "test_convert1d.ttgir" - temp_file.write_text(ir) - kernel = triton.compile(str(temp_file)) + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) rs = RandomState(17) x = rs.randint(0, 4, (M, )).astype('int32') @@ -2912,7 +2840,7 @@ def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2): @pytest.mark.parametrize("src_layout", layouts) @pytest.mark.parametrize("op", ["sum", "max"]) @pytest.mark.parametrize("first_axis", [0, 1]) -def test_chain_reduce(M, N, src_layout, op, device, first_axis, tmp_path: pathlib.Path): +def test_chain_reduce(M, N, src_layout, op, device, first_axis): op_str = "" if op == "sum": @@ -2953,9 +2881,10 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis, tmp_path: pathli }} }} """ - temp_file = tmp_path / "test_chain_reduce.ttgir" - temp_file.write_text(ir) - kernel = triton.compile(str(temp_file)) + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) rs = RandomState(17) x = rs.randint(0, 4, (M, N)).astype('int32') @@ -3371,53 +3300,41 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid assert 'wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3' in ptx -@pytest.mark.parametrize("M, N, K, col_a, col_b, rhs_scale, normal_type, mxfp_type, num_warps, mma, kpack", - [(M, N, K, col_a, col_b, rhs_scale, normal_type, mxfp_type, 4, mma, kpack) +@pytest.mark.parametrize("M, N, K, col_a, col_b, type_a, type_b, num_warps", + [(M, N, K, col_a, col_b, type_a, type_b, 4) for M, N, K in itertools.product([32, 64, 128], [32, 64, 128], [64, 128]) for col_a, col_b in itertools.product([True, False], repeat=2) - for rhs_scale in [False, True] - for normal_type in ["e2m1", "e4m3", "e5m2"] - for mxfp_type in ["e4m3", "e5m2", "bf16"] - for mma in ([32, 16] if is_hip() else [16]) - for kpack in ([1, 2] if is_hip() else [1])]) -def test_scaled_dot(M, N, K, col_a, col_b, rhs_scale, normal_type, mxfp_type, num_warps, mma, kpack, device): - if is_cuda(): + for type_a in ["e2m1", "e4m3", "e5m2"] + for type_b in ["e4m3", "e5m2"]]) +def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, device): + if not is_cuda(): + pytest.skip("scaled_dot only supported on CUDA") + else: cc = torch.cuda.get_device_capability() if cc < (8, 9): pytest.skip("float8e4nv not 
supported on CUDA < 8.9") - if is_hip(): - if not is_hip_cdna(): - pytest.skip("scaled_dot only implemented for HIP CDNA") - if "e4m3" in (normal_type, mxfp_type) and not is_hip_mi300(): - pytest.skip(f"scaled_dot({normal_type}, {mxfp_type}) only implemented for MI300") - if mma == 16 and K == 64: - pytest.skip(f"K == {K} too small for mfma {mma} in scaled_dot") @triton.jit - def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, b_scale, out, + def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, out, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, type_a: tl.constexpr, type_b: tl.constexpr): - DIV_FACTOR_A: tl.constexpr = 2 if type_a == "e2m1" else 1 - DIV_FACTOR_B: tl.constexpr = 2 if type_b == "e2m1" else 1 - PACKED_BLOCK_K_A: tl.constexpr = BLOCK_K // DIV_FACTOR_A - PACKED_BLOCK_K_B: tl.constexpr = BLOCK_K // DIV_FACTOR_B + tl.static_assert(type_b == "e4m3" or type_b == "e5m2", "type_b must be fp8") + IS_FP8: tl.constexpr = type_a == "e4m3" or type_a == "e5m2" + DIV_FACTOR: tl.constexpr = 1 if IS_FP8 else 2 + PACKED_BLOCK_K_A: tl.constexpr = BLOCK_K // DIV_FACTOR + PACKED_BLOCK_K_B: tl.constexpr = BLOCK_K a_ptr = a_base + tl.arange(0, BLOCK_M)[:, None] * stride_a0 + tl.arange(0, PACKED_BLOCK_K_A)[None, :] * stride_a1 b_ptr = b_base + tl.arange(0, PACKED_BLOCK_K_B)[:, None] * stride_b0 + tl.arange(0, BLOCK_N)[None, :] * stride_b1 + SCALE_BLOCK_K: tl.constexpr = BLOCK_K // 32 + scale_a_ptr = a_scale + tl.arange(0, BLOCK_M)[:, None] * SCALE_BLOCK_K + tl.arange(0, SCALE_BLOCK_K)[None, :] + a = tl.load(a_ptr) b = tl.load(b_ptr) - SCALE_BLOCK_K: tl.constexpr = BLOCK_K // 32 - if a_scale is not None: - scale_a_ptr = a_scale + tl.arange(0, BLOCK_M)[:, None] * SCALE_BLOCK_K + tl.arange(0, - SCALE_BLOCK_K)[None, :] - a_scale = tl.load(scale_a_ptr) - if b_scale is not None: - scale_b_ptr = b_scale + tl.arange(0, BLOCK_N)[:, None] * SCALE_BLOCK_K + tl.arange(0, - SCALE_BLOCK_K)[None, :] - b_scale = tl.load(scale_b_ptr) - c = tl.dot_scaled(a, a_scale, type_a, b, b_scale, type_b) + a_scale = tl.load(scale_a_ptr) + c = tl.dot_scaled(a, a_scale, type_a, b, None, type_b) out_ptr = out + tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :] tl.store(out_ptr, c.to(tl.bfloat16)) @@ -3490,31 +3407,22 @@ def mxfp_to_bf16_kernel( offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) tl.store(mxfp_ptr + offsets, tl.ravel(mxfp), mask=offsets < N * 32) - def dot_scale_ref(x, scale_x, y, scale_y, type_x, type_y): - - def upcast(v, scale, type, transposed): - comp_dtype = torch.bfloat16 - if scale is None: - type = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2, "bf16": torch.bfloat16}[type] - return v.view(type).to(comp_dtype) - e_bits, m_bits = {"e2m1": (2, 1), "e4m3": (4, 3), "e5m2": (5, 2)}[type] - # Packing is always on the K dimension so we transpose before upcasting then transpose back. 
- if transposed: - v = v.mT.contiguous() - v = v.contiguous() - v_upcast = v.new_empty(scale.shape[:-1] + (32 * scale.shape[-1], ), dtype=comp_dtype) - N = v_upcast.numel() - BLOCK_SIZE = 512 - grid = ((N + BLOCK_SIZE - 1) // BLOCK_SIZE, ) - mxfp_to_bf16_kernel[grid](v, scale, v_upcast, scale.numel(), e_bits, m_bits, BLOCK_SIZE, - num_warps=num_warps) - assert v_upcast.isfinite().all() - if transposed: - v_upcast = v_upcast.mT - return v_upcast - - x_upcast = upcast(x, scale_x, type_x, False) - y_upcast = upcast(y, scale_y, type_y, True) + def dot_scale_ref(x, scale, y, type_x, type_y): + e_bits, m_bits = {"e2m1": (2, 1), "e4m3": (4, 3), "e5m2": (5, 2)}[type_x] + type_fp8_y = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2}[type_y] + + comp_dtype = torch.bfloat16 + + x = x.contiguous() + x_upcast = x.new_empty(scale.shape[:-1] + (32 * scale.shape[-1], ), dtype=comp_dtype) + + N = x_upcast.numel() + BLOCK_SIZE = 512 + grid = ((N + BLOCK_SIZE - 1) // BLOCK_SIZE, ) + mxfp_to_bf16_kernel[grid](x, scale, x_upcast, scale.numel(), e_bits, m_bits, BLOCK_SIZE, num_warps=num_warps) + assert x_upcast.isfinite().all() + + y_upcast = y.view(type_fp8_y).to(comp_dtype) class AccumulateInFp32: @@ -3526,39 +3434,28 @@ def __exit__(self, exc_type, exc_val, exc_tb): torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = self.prev_value with AccumulateInFp32(): - return torch.matmul(x_upcast, y_upcast) + return torch.matmul(x_upcast.to(comp_dtype), y_upcast.to(comp_dtype)) torch.manual_seed(0) - def make_arg(shape, ty, col_major=False, max_val=255): + def create_uint8(shape, col_major=False, max_val=255): if col_major: shape = shape[:-2] + (shape[-1], shape[-2]) - if ty == "bf16": - ret = torch.randn(shape, dtype=torch.bfloat16, device=device) - # Clamp to avoid relative error issues - ret.clamp_(-2**15, 2**15 - 1) - else: - ret = torch.randint(max_val + 1, shape, dtype=torch.uint8, device=device) + ret = torch.randint(max_val + 1, shape, dtype=torch.uint8, device=device) if col_major: ret = ret.mT return ret - type_a = normal_type if not rhs_scale else mxfp_type - type_b = mxfp_type if not rhs_scale else normal_type - - DIV_FACTOR_A = 2 if type_a == "e2m1" else 1 - DIV_FACTOR_B = 2 if type_b == "e2m1" else 1 - x = make_arg((M, K // DIV_FACTOR_A), type_a, col_major=col_a) - y = make_arg((K // DIV_FACTOR_B, N), type_b, col_major=col_b) + DIV_FACTOR = 2 if type_a == "e2m1" else 1 + x = create_uint8((M, K // DIV_FACTOR), col_major=col_a) + y = create_uint8((K, N), col_major=col_b) # sample scales that don't overflow as otherwise it's implementation defined (underflowing is alright) - # Max scale= 2**15 - scale_x = make_arg((M, K // 32), "e8m0", max_val=127 + 15) - scale_y = make_arg((N, K // 32), "e8m0", max_val=127 + 15) - if rhs_scale: - scale_x = None - else: - scale_y = None + # We substract a reasonably high number (64) so that the sum of all the mxfp elements does not overflow + m_bytes = int(type_a[1]) + bias_type_a = 1 << (m_bytes - 1) - 1 + max_exponent_type_a = (1 << m_bytes) - 1 - bias_type_a + scale_x = create_uint8((M, K // 32), max_val=255 - max_exponent_type_a - 64) def make_finite(x, dtype): # e5m2 has too many non-finite values when sampled uniformly (1 / 32) and @@ -3573,30 +3470,23 @@ def make_finite(x, dtype): x = make_finite(x, type_a) y = make_finite(y, type_b) - kernel_kwargs = {"num_warps": num_warps} - if is_hip(): - kernel_kwargs["kpack"] = kpack - kernel_kwargs["matrix_instr_nonkdim"] = mma + z = x.new_empty((M, N), dtype=torch.bfloat16) - pgm = 
dot_scale_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), scale_y, z, M, N, K, type_a, type_b, - **kernel_kwargs) - z_ref = dot_scale_ref(x, scale_x, y, scale_y, type_a, type_b) - # Bigger tolerance for AMD MI200 devices. - # MI200 devices use reduced precision fp16 and bf16 and flush input and output denormal values - # to zero. Detailed info is at: - # https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices - atol = 2e-4 if is_hip_mi200() else 1e-5 - rtol = 2e-2 if is_hip_mi200() else 1e-2 - torch.testing.assert_close(z, z_ref, atol=atol, rtol=rtol) + pgm = dot_scale_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), z, M, N, K, type_a, type_b, + num_warps=num_warps) + + z_ref = dot_scale_ref(x, scale_x, y, type_a, type_b) + + # generous rtol as we are sampling the whole range of floats + torch.testing.assert_close(z, z_ref, atol=1e-5, rtol=1e-2) # make sure ld/st are vectorized - if is_cuda(): - ptx = pgm.asm['ptx'] - if (max(M, N) * K) // (num_warps * 32) >= 4: - assert 'ld.global.v4' in ptx - if M * N // (num_warps * 32) >= 4: - assert 'st.global.v4' in ptx - assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.bf16.bf16', ptx) + ptx = pgm.asm['ptx'] + if (max(M, N) * K) // (num_warps * 32) >= 4: + assert 'ld.global.v4' in ptx + if M * N // (num_warps * 32) >= 4: + assert 'st.global.v4' in ptx + assert re.search(r'mma.sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.bf16.bf16', ptx) @pytest.mark.interpreter @@ -4088,14 +3978,14 @@ def _kernel(dst, src, CACHE: tl.constexpr): amdgcn = pgm.asm['amdgcn'] cg_cache_modifier_str = 'nt' cv_cache_modifier_str = 'sc0 sc1' - buffer_load_line = [line for line in amdgcn.splitlines() if "buffer_load" in line] global_load_line = [line for line in amdgcn.splitlines() if "global_load" in line] + flat_load_line = [line for line in amdgcn.splitlines() if "flat_load" in line] if cache == '' or cache == '.ca': - assert cg_cache_modifier_str not in (global_load_line[0] if global_load_line else buffer_load_line[0]) + assert cg_cache_modifier_str not in global_load_line[0] if cache == '.cg': assert cg_cache_modifier_str in global_load_line[0] if cache == '.cv': - assert cv_cache_modifier_str in global_load_line[0] + assert cv_cache_modifier_str in flat_load_line[0] if is_cuda(): ptx = pgm.asm['ptx'] @@ -5182,9 +5072,7 @@ def kernel(Out): a = torch.empty((), device=device, dtype=torch.int32) h = kernel[(1, )](a) assert "ub.poison" in h.asm["ttir"], h.asm["ttir"] - # xpu uses llvm.store, which in this case is removed by the optimizer - if not is_xpu(): - assert "poison" in h.asm["llir"], h.asm["llir"] + assert "poison" in h.asm["llir"], h.asm["llir"] # ----------------------- @@ -5258,14 +5146,6 @@ def kernel(Out): BlockedLayout([4, 1], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), BlockedLayout([1, 1], [THREADS_PER_WARP, 1], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), BlockedLayout([4, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), - DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=2), - DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=2), - DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=2), - DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=2), - DotOperandLayout(parent=MmaLayout([2, 
0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=8), - DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=8), - DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=8), - DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=8), MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), ] @@ -5300,13 +5180,9 @@ def compute_scratch_buffer_shape(src_layout, dst_layout, shape): @pytest.mark.parametrize("src_layout", layouts) @pytest.mark.parametrize("interm_layout", intermediate_layouts) @pytest.mark.parametrize("dst_layout", layouts) -def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, tmp_path: pathlib.Path): +def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device): if str(src_layout) == str(dst_layout): pytest.skip() - if (isinstance(src_layout, DotOperandLayout) - and isinstance(interm_layout, SharedLayout)) or (isinstance(dst_layout, DotOperandLayout) - and isinstance(interm_layout, SharedLayout)): - pytest.skip("DotOperandLayout <-> SharedLayout conversion is not completely supported") if is_hip(): try: scratch_shape = compute_scratch_buffer_shape(src_layout, dst_layout, (M, N)) @@ -5369,10 +5245,10 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, t x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device) z = torch.empty_like(x, device=device) - temp_file = tmp_path / "test_convert2d.ttgir" - temp_file.write_text(ir) - kernel = triton.compile(str(temp_file)) - + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr()) assert torch.equal(z, x) @@ -5425,7 +5301,7 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, t @pytest.mark.parametrize("M, N", [[64, 1], [1, 64], [64, 64], [128, 128], [256, 256]]) @pytest.mark.parametrize("dtype", ['float16']) @pytest.mark.parametrize("mma_pair", mma_pairs) -def test_convertmma2mma(M, N, mma_pair, dtype, device, tmp_path: pathlib.Path): +def test_convertmma2mma(M, N, mma_pair, dtype, device): if is_hip(): pytest.skip("test_mma2mma is not supported in HIP") @@ -5482,10 +5358,10 @@ def do_test(src_layout, dst_layout): x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device) z = torch.empty_like(x) - temp_file = tmp_path / "test_convertmma2mma.ttgir" - temp_file.write_text(ir) - kernel = triton.compile(str(temp_file)) - + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr()) assert torch.equal(z, x) @@ -5581,7 +5457,7 @@ def matmul_kernel( # stride_cm, stride_cn, # BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # low_precision_acc: tl.constexpr, # - num_stages: tl.constexpr = 3 # + num_pipeline_stages: tl.constexpr = 3 # ): pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) @@ -5593,7 +5469,7 @@ def matmul_kernel( # a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K), num_stages=num_stages): + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K), 
num_stages=num_pipeline_stages): a = tl.load(a_ptrs) b = tl.load(b_ptrs) accumulator = tl.dot(a, b, acc=accumulator, max_num_imprecise_acc=low_precision_acc) @@ -5632,7 +5508,7 @@ def test_dot_max_num_imprecise_acc(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, in_type_s max_num_impressive_acc = low_precision_acc if low_precision_acc <= BLOCK_K else None h = matmul_kernel[grid](a, b, C, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), C.stride(0), C.stride(1), BLOCK_M, BLOCK_N, BLOCK_K, max_num_impressive_acc, num_warps=num_warps, - num_stages=num_stages) + num_pipeline_stages=num_stages) torch_a = torch.from_numpy(A).to(device=device) th_a = f8_to_f16(torch_a, in_type_str) torch_b = torch.from_numpy(B).to(device=device) @@ -5824,7 +5700,7 @@ def test_tl_range(device): pgm = matmul_kernel[ 1, ](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), BLOCK_M, BLOCK_N, - BLOCK_K, 0, num_stages=5) + BLOCK_K, 0, num_pipeline_stages=5) ref_out = torch.matmul(a, b).to(torch.float32) if is_interpreter(): # GPU invokes tensor core for float16 matmul, which is not supported in interpreter. @@ -5850,8 +5726,8 @@ def maxnreg_noinline2(X): tl.store(X, 0) -@pytest.mark.interpreter def test_maxnreg(device): + assert not is_interpreter(), "this test won't work with the interpreter" if not is_cuda(): pytest.skip('maxnreg only works on CUDA') @@ -5865,15 +5741,14 @@ def kernel(X): X = torch.empty(1, dtype=torch.int32, device=device) k = kernel[(1, )](X, maxnreg=42) - if not is_interpreter(): - # Ensure that .maxnreg is set on the kernel function (marked with .entry) - # and not on either of the noinline functions (marked with .func). - try: - assert re.search(r'\.visible \.entry [^{;]*\.maxnreg 42', k.asm["ptx"]) - assert not re.search(r'\.visible \.func [^{;]*\.maxnreg', k.asm["ptx"]) - except AssertionError: - print("Failing ptx:\n", k.asm["ptx"]) - raise + # Ensure that .maxnreg is set on the kernel function (marked with .entry) + # and not on either of the noinline functions (marked with .func). 
+ try: + assert re.search(r'\.visible \.entry [^{;]*\.maxnreg 42', k.asm["ptx"]) + assert not re.search(r'\.visible \.func [^{;]*\.maxnreg', k.asm["ptx"]) + except AssertionError: + print("Failing ptx:\n", k.asm["ptx"]) + raise @pytest.mark.interpreter @@ -5967,6 +5842,7 @@ def kernel( # ----------------------- +<<<<<<< HEAD # test loop unrolling # ----------------------- @@ -6021,30 +5897,6 @@ def sanitize_sum_kernel(Z, X, BLOCK: tl.constexpr): torch.testing.assert_close(Z, X.sum().to(torch.int32)) -@pytest.mark.parametrize("reduce_dim", [0, 1]) -def test_side_effectful_reduction_2d(device, reduce_dim): - if device != "cuda": - pytest.skip() - - @triton.jit(debug=True) - def sanitize_sum_2d_kernel(Z, X, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, reduce_dim: tl.constexpr, - NON_REDUCE_DIM: tl.constexpr): - offsets = tl.arange(0, BLOCK_0)[:, None] * BLOCK_1 + tl.arange(0, BLOCK_1)[None, :] - vals = tl.load(X + offsets) - z = tl.reduce(vals, reduce_dim, sanitize_add) - tl.store(Z + tl.arange(0, NON_REDUCE_DIM), z) - - BLOCK_0 = 16 - BLOCK_1 = 32 - NON_REDUCE_DIM = BLOCK_1 if reduce_dim == 0 else BLOCK_0 - torch.manual_seed(42) - X = torch.randint(0, 10, [BLOCK_0, BLOCK_1], device="cuda", dtype=torch.int32) - Z = torch.zeros([NON_REDUCE_DIM], device="cuda", dtype=torch.int32) - sanitize_sum_2d_kernel[(1, )](Z, X, BLOCK_0=BLOCK_0, BLOCK_1=BLOCK_1, reduce_dim=reduce_dim, - NON_REDUCE_DIM=NON_REDUCE_DIM) - torch.testing.assert_close(Z, X.sum(reduce_dim).to(torch.int32)) - - def test_side_effectful_scan(device): if device != "cuda": pytest.skip() @@ -6063,33 +5915,85 @@ def sanitize_cumsum_kernel(Z, X, BLOCK: tl.constexpr): Z = torch.zeros_like(X) sanitize_cumsum_kernel[(1, )](Z, X, BLOCK=BLOCK) torch.testing.assert_close(Z, X.cumsum(0).to(torch.int32)) +======= +# test view slice +# ----------------------- +view_layout = [ + BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [16, 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [16, 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 8], [16, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), +] +blocked_layout = [ + BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [16, 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [16, 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 8], [16, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), +] -# stress test slice layout usages in reductions. 
-@pytest.mark.parametrize("in_shape, perm, red_dims", [ - ((4, 32, 32, 4, 2), [2, 1, 0, 3, 4], [3, 1, 0]), - ((8, 2, 32, 4, 16), [4, 0, 1, 3, 2], [0, 2, 0]), -]) -def test_chained_reductions(in_shape, perm, red_dims, device): - @triton.jit - def kernel(In, Out, # - dim_0: tl.constexpr, dim_1: tl.constexpr, dim_2: tl.constexpr, dim_3: tl.constexpr, dim_4: tl.constexpr, - perm_0: tl.constexpr, perm_1: tl.constexpr, perm_2: tl.constexpr, perm_3: tl.constexpr, - perm_4: tl.constexpr, red_dim_0: tl.constexpr, red_dim_1: tl.constexpr, red_dim_2: tl.constexpr): - idx = tl.arange(0, dim_0 * dim_1 * dim_2 * dim_3 * dim_4) - idx = idx.reshape(dim_0, dim_1, dim_2, dim_3, dim_4) - vals = tl.load(In + idx) - vals = tl.permute(vals, [perm_0, perm_1, perm_2, perm_3, perm_4]) - r = tl.sum(tl.sum(tl.sum(vals, red_dim_0), red_dim_1), red_dim_2) - st_idx = tl.arange(0, r.shape[0] * r.shape[1]).reshape(r.shape) - tl.store(Out + st_idx, r) - - input = torch.randint(0, 1000, in_shape, device=device, dtype=torch.int32) - temp = torch.permute(input, perm).contiguous() - ref = torch.sum(torch.sum(torch.sum(temp, dim=red_dims[0]), dim=red_dims[1]), dim=red_dims[2]) - result = torch.empty_like(ref) - kernel[(1, )](input, result, input.shape[0], input.shape[1], input.shape[2], input.shape[3], input.shape[4], - perm[0], perm[1], perm[2], perm[3], perm[4], red_dims[0], red_dims[1], red_dims[2]) - - assert torch.all(ref == result) +@pytest.mark.parametrize("M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset", + [[256, 256, 256, 32, 0, 32], [128, 128, 128, 64, 0, 64]]) +@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("view_layout", view_layout) +@pytest.mark.parametrize("blocked_layout", blocked_layout) +def test_view_slice(dtype, M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset, blocked_layout, view_layout, + device): + if not is_hip(): + pytest.skip('view_slice is AMD specific instruction.') + + ir = f""" + #blocked = {blocked_layout} + #view_layout = {view_layout} + """ + """ + module attributes {"triton_gpu.num-ctas" = 1, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = """ + str( + 64) + f""" : i32}} {{ + tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked> + %cst_n = arith.constant dense<{N_tile_size}> : tensor<{M_tile_size}x1xi32, #blocked> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>> + %42 = tt.make_range {{end = {M_tile_size} : i32, start = 0 : i32}} : tensor<{M_tile_size}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>> + %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #blocked> + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked> + %43 = tt.expand_dims %42 {{axis = 1 : i32}} : tensor<{M_tile_size}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M_tile_size}x1xi32, #blocked> + %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #blocked> + %44 = arith.muli %43, %cst_n : tensor<{M_tile_size}x1xi32, #blocked> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{M}xi32, #blocked> + %7 = tt.broadcast %6 : 
tensor<1x{M}xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> + %8 = tt.broadcast %5 : tensor<{M}x1xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> + %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #blocked> + %33 = tt.make_range {{end = {N_tile_size} : i32, start = 0 : i32}} : tensor<{N_tile_size}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>> + %34 = tt.splat %arg1 : !tt.ptr -> tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked> + %37 = tt.expand_dims %33 {{axis = 0 : i32}} : tensor<{N_tile_size}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{N_tile_size}xi32, #blocked> + %38 = tt.broadcast %37 : tensor<1x{N_tile_size}xi32, #blocked> -> tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> + %39 = tt.broadcast %44 : tensor<{M_tile_size}x1xi32, #blocked> -> tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> + %40 = arith.addi %38, %39 : tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> + %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> + %11 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}x!tt.ptr, #blocked> + %12 = triton_gpu.convert_layout %11 : tensor<{M}x{N}xf16, #blocked> -> tensor<{M}x{N}xf16, #view_layout> + %13 = amdgpu.view_slice %12[{M_tile_offset}, {N_tile_offset}] [{M_tile_size}, {N_tile_size}] [1, 1] : tensor<{M}x{N}xf16, #view_layout> to tensor<{M_tile_size}x{N_tile_size}xf16, #view_layout> + %14 = triton_gpu.convert_layout %13 : tensor<{M_tile_size}x{N_tile_size}xf16, #view_layout> -> tensor<{M_tile_size}x{N_tile_size}xf16, #blocked> + %15 = tt.addptr %34, %40 : tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked>, tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> + tt.store %15, %14 : tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked> + tt.return + }} + }} + """ + x = torch.randn((M, N), device=device, dtype=torch.float16) + import tempfile + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) + + view = torch.empty((M_tile_size, N_tile_size), device=device, dtype=torch.float16) + + kernel[(1, 1, 1)](x.data_ptr(), view) + test_result = torch.eq(x[M_tile_offset:M_tile_size + M_tile_offset, N_tile_offset:N_tile_offset + N_tile_size], + view).all() + assert test_result +>>>>>>> 1b17054b (Adds pytest) From 3709f35a88811cfb9201f25c4e686ac27ce38191 Mon Sep 17 00:00:00 2001 From: Hasitha Algewaththa Date: Wed, 2 Oct 2024 21:53:32 +0000 Subject: [PATCH 06/15] Addresses review comments --- python/test/unit/language/test_core.py | 2 +- test/TritonGPU/amd/amd-viewslice-op.mlir | 12 +- .../TritonAMDGPU/IR/TritonAMDGPUOps.td | 70 ++++++--- .../PatternTritonAMDGPUToLLVM.h | 13 +- .../lib/Dialect/TritonAMDGPU/IR/Dialect.cpp | 28 ++-- .../TritonAMDGPUToLLVMPatterns.cpp | 3 +- .../ViewSliceOpToLLVM.cpp | 45 +++--- third_party/amd/python/test/test_core.py | 146 ++++++++++++++++++ 8 files changed, 251 insertions(+), 68 deletions(-) create mode 100644 third_party/amd/python/test/test_core.py diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 79ce102d2377..957ed30190bb 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -5839,6 +5839,7 @@ def kernel( kernel[(1, )](x_tri, y_tri, shape[0], BLOCK_SIZE=shape[0]) # compare np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01) +<<<<<<< HEAD # ----------------------- @@ -5996,4 +5997,3 @@ def test_view_slice(dtype, M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile 
test_result = torch.eq(x[M_tile_offset:M_tile_size + M_tile_offset, N_tile_offset:N_tile_offset + N_tile_size], view).all() assert test_result ->>>>>>> 1b17054b (Adds pytest) diff --git a/test/TritonGPU/amd/amd-viewslice-op.mlir b/test/TritonGPU/amd/amd-viewslice-op.mlir index fdaa1ccbdba3..fa3ca934405b 100644 --- a/test/TritonGPU/amd/amd-viewslice-op.mlir +++ b/test/TritonGPU/amd/amd-viewslice-op.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s --convert-triton-amdgpu-to-llvm='arch=gfx942' | FileCheck %s +// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm='arch=gfx942' | FileCheck %s #blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> #blocked2 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> @@ -12,3 +12,13 @@ module attributes {"triton_gpu.compute-capability" = 0 : i32, "triton_gpu.num-ct tt.return } } + +module attributes {"triton_gpu.compute-capability" = 0 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func @basic_insert_slice_async_1d(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // CHECK: llvm.func @basic_insert_slice_async_1d + // CHECK: error: sizes [256, 2] must be a multiple of shapePerCTA [256, 16] + // XFAIL: * + %72 = amdgpu.view_slice %arg0[0,0] [256, 2] [1,1] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> + tt.return + } +} diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td index 02fcc1971ae9..b48e02555a8f 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td @@ -21,7 +21,6 @@ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ - #ifndef TRITON_AMDGPU_OPS #define TRITON_AMDGPU_OPS @@ -31,11 +30,16 @@ include "mlir/IR/EnumAttr.td" include "triton/Dialect/Triton/IR/TritonTypes.td" include "mlir/Dialect/LLVMIR/LLVMOpBase.td" include "mlir/Interfaces/InferTypeOpInterface.td" +<<<<<<< HEAD include "triton/Dialect/Triton/IR/TritonInterfaces.td" include "TritonAMDGPUDialect.td" include "TritonAMDGPUAttrDefs.td" +======= +>>>>>>> a1aaf2f1 (Addresses review comments) include "mlir/Interfaces/SideEffectInterfaces.td" // Pure include "mlir/Interfaces/ViewLikeInterface.td" // OffsetSizeAndStrideOpInterface +include "TritonAMDGPUDialect.td" +include "TritonAMDGPUAttrDefs.td" class TT_AMDGPU_Op traits = []> : @@ -46,36 +50,54 @@ class TT_AMDGPU_Op traits = []> : // Interfaces // def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">; -def TritonAMDGPU_ViewSliceOp : TritonAMDGPU_Op<"view_slice", - [AttrSizedOperandSegments, - Pure, - OffsetSizeAndStrideOpInterface - ]> { + +//===----------------------------------------------------------------------===// +// ViewSliceOp +//===----------------------------------------------------------------------===// + +def TritonAMDGPU_ViewSliceOp + : TritonAMDGPU_Op<"view_slice", [AttrSizedOperandSegments, + OffsetSizeAndStrideOpInterface, Pure]> { let summary = "view slice operation"; let description = [{ - Represents view of the slice of the tensor in registers. 
However, unlike 'memdesc_subview' which provides a view in shared memory, - 'view_slice' specifically slices within registers. - Slice of the tensor is required to have the same layout as the original tensor. + The "view_slice" operation enables "viewing" a slice of a tensor in + registers without data exchange. + + The "view_slice" operation supports the following arguments: + + * source: the base tensor on which to create a "view" tensor + * offsets: offsets into the base tensor at which to create the "view" + * size: size of the result "view" tensor + * strides: the number of strides for each dimension + + Currently only 2D tensors are supported. + + Example 1: + + ```mlir + %1 = triton_gpu.convert_layout %0 : tensor<128x128x!tt.ptr, #blocked> + -> tensor<128x128x!tt.ptr, #blocked2> + // create a slice of base tensor %1 with + // static offsets and sizes for each dimension + %2 = amdgpu.view_slice %0[0, 0] [128, 8] [1, 1] : + tensor<128x128x!tt.ptr, #blocked2> to + tensor<128x8x!tt.ptr, #blocked2> + ``` }]; - let arguments = (ins - AnyRankedTensor:$source, - Variadic:$offsets, - Variadic:$sizes, - Variadic:$strides, - DenseI64ArrayAttr:$static_offsets, - DenseI64ArrayAttr:$static_sizes, - DenseI64ArrayAttr:$static_strides - ); + let arguments = (ins AnyRankedTensor:$source, Variadic:$offsets, + Variadic:$sizes, Variadic:$strides, + DenseI64ArrayAttr:$static_offsets, DenseI64ArrayAttr:$static_sizes, + DenseI64ArrayAttr:$static_strides); let results = (outs AnyRankedTensor:$result); let builders = [ - // Build a ViewSliceOp with mixed static and dynamic entries and the same - // result type - OpBuilder<(ins "RankedTensorType":$resultType, "Value":$source, - "ArrayRef":$offsets, "ArrayRef":$sizes, - "ArrayRef":$strides, - CArg<"ArrayRef", "{}">:$attrs)>, + // Build a ViewSliceOp with mixed static and dynamic entries and the same + // result type + OpBuilder<(ins "RankedTensorType":$resultType, "Value":$source, + "ArrayRef":$offsets, "ArrayRef":$sizes, + "ArrayRef":$strides, + CArg<"ArrayRef", "{}">:$attrs)>, ]; let extraClassDeclaration = [{ diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h b/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h index 4e3321af8554..a7ae33f8bff9 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h +++ b/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h @@ -1,16 +1,13 @@ -#ifndef THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPU_TO_LLVM_PATTERNS_AMDGPU_OP_TO_LLVM_H -#define THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPU_TO_LLVM_PATTERNS_AMDGPU_OP_TO_LLVM_H +#ifndef AMD_INCLUDE_TRITONAMDGPU_TO_LLVM_PATTERNS_AMDGPU_OP_TO_LLVM_H +#define AMD_INCLUDE_TRITONAMDGPU_TO_LLVM_PATTERNS_AMDGPU_OP_TO_LLVM_H #include "mlir/Conversion/LLVMCommon/TypeConverter.h" -#include "triton/Analysis/AxisInfo.h" - -using namespace mlir; namespace mlir::triton::AMD { -void populateViewSliceOpTritonAMDGPUToLLVMPatterns( - LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - PatternBenefit benefit); +void populateViewSliceOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, + mlir::RewritePatternSet &patterns, + mlir::PatternBenefit benefit); } diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp index fdc42ccf82d9..8ba61b0e9646 100644 --- a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp +++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp @@ -60,18 +60,23 @@ namespace mlir::triton::amdgpu { 
LogicalResult ViewSliceOp::verify() { auto srcTy = getSource().getType(); auto srcLayout = srcTy.getEncoding(); - auto srcElementType = dyn_cast(srcTy).getElementType(); + auto srcElementType = getElementTypeOrSelf(srcTy); auto resultTy = getResult().getType(); auto resultLayout = resultTy.getEncoding(); - auto resultElementType = - dyn_cast(resultTy).getElementType(); + auto resultElementType = getElementTypeOrSelf(resultTy); if (srcElementType != resultElementType) { - return emitError("result type must match source type"); + return emitError("result element type must match source element type"); } - - if (srcLayout != resultLayout) + if (srcLayout != resultLayout) { return emitError("result layout must match source layout"); + } + if (srcTy.getRank() != resultTy.getRank()) { + return emitError("result rank must be equal to source rank"); + } + if (srcTy.getRank() != 2) { + return emitError("currently only 2D tensors are supported"); + } auto srcShape = srcTy.getShape(); auto shapePerCTA = mlir::triton::gpu::getShapePerCTATile(srcLayout, srcShape); @@ -86,15 +91,20 @@ LogicalResult ViewSliceOp::verify() { // original tensor. if (offsets[0] % shapePerCTA[0] != 0 || offsets[1] % shapePerCTA[1] != 0) { - return emitError("incorrect static offset"); + return emitError() << "offset [" << offsets + << "] must be a multiple of shapePerCTA [" << shapePerCTA + << "]"; } if (sizes[0] % shapePerCTA[0] != 0 || sizes[1] % shapePerCTA[1] != 0) { - return emitError("incorrect static size"); + return emitError() << "sizes [" << sizes + << "] must be a multiple of shapePerCTA [" << shapePerCTA + << "]"; } if (!hasUnitStride()) { - return emitError("unsupported stride"); + return emitError("expected unit strides but found unsupported stride [") + << getStrides() << "]"; } return success(); diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp index 786043d1078f..2dc6a476e5c4 100644 --- a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp +++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp @@ -5,7 +5,6 @@ namespace mlir::triton::AMD { void populateTritonAMDGPUToLLVMPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit) { - populateViewSliceOpTritonAMDGPUToLLVMPatterns(typeConverter, patterns, - benefit); + populateViewSliceOpToLLVMPatterns(typeConverter, patterns, benefit); } } // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ViewSliceOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ViewSliceOpToLLVM.cpp index bbdbefb28ee3..65ef2d1577dc 100644 --- a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ViewSliceOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ViewSliceOpToLLVM.cpp @@ -9,8 +9,6 @@ using namespace mlir; using namespace mlir::triton; -using namespace mlir::triton::gpu; -namespace tta = mlir::triton::amdgpu; // clang-format off /*** @@ -51,24 +49,27 @@ namespace tta = mlir::triton::amdgpu; // clang-format on namespace { -struct ViewSliceOpConversion : public ConvertOpToLLVMPattern { +struct ViewSliceOpConversion + : public ConvertOpToLLVMPattern { explicit ViewSliceOpConversion(LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1) - : ConvertOpToLLVMPattern(typeConverter, benefit) {} + : ConvertOpToLLVMPattern(typeConverter, benefit) {} - LogicalResult processLayout(tta::ViewSliceOp op, OpAdaptor adaptor, + 
LogicalResult processLayout(amdgpu::ViewSliceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op->getLoc(); - auto srcTy = dyn_cast(op.getSource().getType()); + auto srcTy = cast(op.getSource().getType()); auto srcLayout = srcTy.getEncoding(); auto srcShape = srcTy.getShape(); auto resultTy = cast(op.getType()); auto vals = unpackLLElements(loc, adaptor.getSource(), rewriter); - auto elemsPerThread = mlir::triton::gpu::getElemsPerThread(srcTy); - auto sizePerThread = getSizePerThread(srcLayout); + auto elemsPerThread = triton::gpu::getElemsPerThread(srcTy); + auto sizePerThread = triton::gpu::getSizePerThread(srcLayout); auto totalSizePerThread = sizePerThread[0] * sizePerThread[1]; - auto order = getOrder(srcLayout); - auto shapePerCTA = getShapePerCTATile(srcLayout, srcShape); + auto order = triton::gpu::getOrder(srcLayout); + + // Calculate valid total number of workers in each dimension + auto shapePerCTA = triton::gpu::getShapePerCTATile(srcLayout, srcShape); shapePerCTA[0] = std::min(static_cast(srcShape[0]), shapePerCTA[0]); shapePerCTA[1] = @@ -78,12 +79,12 @@ struct ViewSliceOpConversion : public ConvertOpToLLVMPattern { auto sizes = op.getStaticSizes(); // Calculate offsets and sizes in terms of CTA units. - std::vector CTAOffsets{offsets[0] / shapePerCTA[0], - offsets[1] / shapePerCTA[1]}; - std::vector CTASizes{sizes[0] / shapePerCTA[0], - sizes[1] / shapePerCTA[1]}; - std::vector CTAPerShape{srcShape[0] / shapePerCTA[0], - srcShape[1] / shapePerCTA[1]}; + std::vector CTAOffsets{offsets[0] / shapePerCTA[0], + offsets[1] / shapePerCTA[1]}; + std::vector CTASizes{sizes[0] / shapePerCTA[0], + sizes[1] / shapePerCTA[1]}; + std::vector CTAPerShape{srcShape[0] / shapePerCTA[0], + srcShape[1] / shapePerCTA[1]}; // The diagram above illustrates the graphical representation of the // skipElems, tensorStride, and lastIdx variables. 
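The comment above refers to a diagram and to `skipElems`, `tensorStride`, and `lastIdx`, whose defining expressions fall outside this excerpt. As a rough sketch, assuming only the variables already computed in `processLayout` (`vals`, `CTAOffsets`, `CTASizes`, `CTAPerShape`, `order`, `sizePerThread`, `elemsPerThread`, `totalSizePerThread`), the index arithmetic and the gather over register values could look like the following; it illustrates the indexing scheme rather than reproducing the committed implementation:

```cpp
// Illustrative sketch only; names follow the local variables in processLayout.
// skipElems: register values that come before the first element of the slice
// (whole rows of CTA tiles above it plus whole tiles to its left).
size_t skipElems = CTAOffsets[order[1]] *
                       (elemsPerThread[order[0]] * sizePerThread[order[1]]) +
                   CTAOffsets[order[0]] * totalSizePerThread;
// tensorStride: register values to skip between two consecutive tile rows
// that belong to the slice.
size_t tensorStride =
    (CTAPerShape[order[0]] - CTASizes[order[0]]) * totalSizePerThread;
// lastIdx: one past the last register value that belongs to the slice.
size_t lastIdx =
    (CTAOffsets[order[1]] + CTASizes[order[1]] - 1) *
        elemsPerThread[order[0]] * sizePerThread[order[1]] +
    (CTAOffsets[order[0]] + CTASizes[order[0]]) * totalSizePerThread;

// Gather the register values of the slice, walking tile row by tile row and
// skipping the tiles that fall outside the requested region.
SmallVector<Value> resultVals;
for (size_t i = skipElems; i < lastIdx; i += tensorStride)
  for (size_t j = 0; j < CTASizes[order[0]] * totalSizePerThread; ++j, ++i)
    resultVals.push_back(vals[i]);
```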
@@ -114,25 +115,23 @@ struct ViewSliceOpConversion : public ConvertOpToLLVMPattern { } LogicalResult - matchAndRewrite(tta::ViewSliceOp op, OpAdaptor adaptor, + matchAndRewrite(amdgpu::ViewSliceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto srcTy = op.getSource().getType(); if (isa(op.getSource().getType().getEncoding()) || isa(op.getSource().getType().getEncoding())) { return processLayout(op, adaptor, rewriter); - } else { - assert(false && "Unsupported layout in viewSlice."); - return failure(); } + return failure(); } }; } // namespace namespace mlir::triton::AMD { -void populateViewSliceOpTritonAMDGPUToLLVMPatterns( - LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - PatternBenefit benefit) { +void populateViewSliceOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit) { patterns.add(typeConverter, benefit); } } // namespace mlir::triton::AMD diff --git a/third_party/amd/python/test/test_core.py b/third_party/amd/python/test/test_core.py new file mode 100644 index 000000000000..7270be7e5913 --- /dev/null +++ b/third_party/amd/python/test/test_core.py @@ -0,0 +1,146 @@ +# flake8: noqa: F821,F841 +import contextlib +import itertools +import re +from typing import Optional +import math +import textwrap +import tempfile + +import numpy as np +import pytest +import torch +import os +import inspect +from numpy.random import RandomState + +import triton +import triton.language as tl +from triton.language.extra import libdevice + +from triton._internal_testing import ( + is_interpreter, + is_hip, + get_arch, + torch_float8_dtypes, + torch_dtypes, +) + + +@contextlib.contextmanager +def promotion_numpy_2_0(): + state = np._get_promotion_state() + np._set_promotion_state("weak") + try: + yield + finally: + np._set_promotion_state(state) + + +# TODO: enable multiple cta cluster testing. 
+# num_ctas_list = [1, 4] if torch.cuda.get_device_capability()[0] == 9 else [1] +num_ctas_list = [1] + +GPU_DIALECT = "triton_gpu" +if is_interpreter(): + THREADS_PER_WARP = 1 +elif is_hip(): + THREADS_PER_WARP = triton.runtime.driver.active.get_current_target().warp_size +else: + THREADS_PER_WARP = 32 + + +class BlockedLayout: + + def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga, cta_split_num, cta_order): + self.sz_per_thread = size_per_thread + self.threads_per_warp = threads_per_warp + self.warps_per_cta = warps_per_cta + self.order = order + self.ctas_per_cga = ctas_per_cga + self.cta_split_num = cta_split_num + self.cta_order = cta_order + + def __str__(self): + return f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" + + +# ----------------------- +# test view slice +# ----------------------- + +view_layout = [ + BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [16, 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [16, 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 8], [16, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), +] +blocked_layout = [ + BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [16, 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [16, 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 8], [16, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), +] + + +@pytest.mark.parametrize("M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset", + [[256, 256, 256, 32, 0, 32], [128, 128, 128, 64, 0, 64]]) +@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("view_layout", view_layout) +@pytest.mark.parametrize("blocked_layout", blocked_layout) +def test_view_slice(dtype, M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset, blocked_layout, view_layout, + device='cuda'): + if torch.version.hip is None: + pytest.skip("view_slice is AMD specific instruction.") + + ir = f""" + #blocked = {blocked_layout} + #view_layout = {view_layout} + module attributes {{"triton_gpu.num-ctas" = 1, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = {str(64)} : i32}} {{ + tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked> + %cst_n = arith.constant dense<{N_tile_size}> : tensor<{M_tile_size}x1xi32, #blocked> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>> + %42 = tt.make_range {{end = {M_tile_size} : i32, start = 0 : i32}} : tensor<{M_tile_size}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>> + %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #blocked> + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked> + %43 = tt.expand_dims %42 {{axis = 1 : i32}} : 
tensor<{M_tile_size}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M_tile_size}x1xi32, #blocked> + %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #blocked> + %44 = arith.muli %43, %cst_n : tensor<{M_tile_size}x1xi32, #blocked> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{M}xi32, #blocked> + %7 = tt.broadcast %6 : tensor<1x{M}xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> + %8 = tt.broadcast %5 : tensor<{M}x1xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> + %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #blocked> + %33 = tt.make_range {{end = {N_tile_size} : i32, start = 0 : i32}} : tensor<{N_tile_size}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>> + %34 = tt.splat %arg1 : !tt.ptr -> tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked> + %37 = tt.expand_dims %33 {{axis = 0 : i32}} : tensor<{N_tile_size}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{N_tile_size}xi32, #blocked> + %38 = tt.broadcast %37 : tensor<1x{N_tile_size}xi32, #blocked> -> tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> + %39 = tt.broadcast %44 : tensor<{M_tile_size}x1xi32, #blocked> -> tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> + %40 = arith.addi %38, %39 : tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> + %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> + %11 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}x!tt.ptr, #blocked> + %12 = triton_gpu.convert_layout %11 : tensor<{M}x{N}xf16, #blocked> -> tensor<{M}x{N}xf16, #view_layout> + %13 = amdgpu.view_slice %12[{M_tile_offset}, {N_tile_offset}] [{M_tile_size}, {N_tile_size}] [1, 1] : tensor<{M}x{N}xf16, #view_layout> to tensor<{M_tile_size}x{N_tile_size}xf16, #view_layout> + %14 = triton_gpu.convert_layout %13 : tensor<{M_tile_size}x{N_tile_size}xf16, #view_layout> -> tensor<{M_tile_size}x{N_tile_size}xf16, #blocked> + %15 = tt.addptr %34, %40 : tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked>, tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> + tt.store %15, %14 : tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked> + tt.return + }} + }} + """ + x = torch.randn((M, N), device=device, dtype=torch.float16) + import tempfile + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) + + view = torch.empty((M_tile_size, N_tile_size), device=device, dtype=torch.float16) + + kernel[(1, 1, 1)](x.data_ptr(), view) + test_result = torch.equal(x[M_tile_offset:M_tile_size + M_tile_offset, N_tile_offset:N_tile_offset + N_tile_size], + view) + assert test_result From 5e36ba3a80c76004311ee7a55c79a7232a79abbf Mon Sep 17 00:00:00 2001 From: Hasitha Algewaththa Date: Mon, 7 Oct 2024 21:23:58 +0000 Subject: [PATCH 07/15] Adds changes to address review comments --- .../amd/invalid_viewslice_to_llvm.mlir | 70 +++++++++++++++++++ test/TritonGPU/amd/amd-viewslice-op.mlir | 16 +---- .../TritonAMDGPU/IR/TritonAMDGPUOps.td | 47 +++++++++---- .../PatternTritonAMDGPUToLLVM.h | 4 +- .../lib/Dialect/TritonAMDGPU/IR/Dialect.cpp | 2 +- third_party/amd/python/test/test_core.py | 44 ++---------- 6 files changed, 115 insertions(+), 68 deletions(-) create mode 100644 test/Conversion/amd/invalid_viewslice_to_llvm.mlir diff --git a/test/Conversion/amd/invalid_viewslice_to_llvm.mlir b/test/Conversion/amd/invalid_viewslice_to_llvm.mlir new file mode 100644 index 000000000000..e26817cd026e --- /dev/null +++ 
b/test/Conversion/amd/invalid_viewslice_to_llvm.mlir @@ -0,0 +1,70 @@ +// RUN: triton-opt -split-input-file %s --convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics + +// Invalid size +#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_size_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{sizes [256, 2] must be a multiple of shapePerCTA [256, 16]}} + %1 = amdgpu.view_slice %arg0[0,0] [256, 2] [1,1] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> + tt.return +} + +// ----- + +// Invalid offset +#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_offset_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{offset [0, 5] must be a multiple of shapePerCTA [256, 16]}} + %1 = amdgpu.view_slice %arg0[0,5] [256, 16] [1,1] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> + tt.return +} + +// ----- + +// Invalid result layout +#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_result_layout(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{result layout must match source layout}} + %1 = amdgpu.view_slice %arg0[0,0] [256, 16] [1,1] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked2> + tt.return +} + +// ----- + +// Invalid result element type +#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_result_element_type(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{result element type must match source element type}} + %1 = amdgpu.view_slice %arg0[0,0] [256, 16] [1,1] : tensor<256x128xi32, #blocked1> to tensor<256x16xi64, #blocked1> + tt.return +} + +// ----- + +// Invalid result rank +#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_result_rank(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{result rank must be equal to source rank}} + %1 = amdgpu.view_slice %arg0[0,0] [256, 16] [1,1] : tensor<256x128xi32, #blocked1> to tensor<256x16x2xi32, #blocked1> + tt.return +} + +// ----- + +// Invalid rank +#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_rank(%arg0: tensor<256x128x2xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{currently only 2D tensors are supported}} + %1 = amdgpu.view_slice %arg0[0,0,0] [256,16,2] [1,1,1] : tensor<256x128x2xi32, #blocked1> to tensor<256x16x2xi32, #blocked1> + tt.return +} + +// 
----- + +// Invalid stride +#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_stride(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{expected unit strides but found unsupported stride [1, 2]}} + %1 = amdgpu.view_slice %arg0[0,0] [256, 16] [1,2] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> + tt.return +} diff --git a/test/TritonGPU/amd/amd-viewslice-op.mlir b/test/TritonGPU/amd/amd-viewslice-op.mlir index fa3ca934405b..ff967e88686e 100644 --- a/test/TritonGPU/amd/amd-viewslice-op.mlir +++ b/test/TritonGPU/amd/amd-viewslice-op.mlir @@ -1,10 +1,10 @@ -// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm='arch=gfx942' | FileCheck %s +// RUN: triton-opt %s --convert-triton-amdgpu-to-llvm='arch=gfx942' | FileCheck %s #blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> #blocked2 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> module attributes {"triton_gpu.compute-capability" = 0 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - tt.func @basic_insert_slice_async_1d(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { - // CHECK: llvm.func @basic_insert_slice_async_1d + tt.func @basic_insert_slice(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // CHECK: llvm.func @basic_insert_slice // CHECK-COUNT-64: %{{[0-9]*}} = llvm.extractvalue %arg0[{{[0-9]*}}] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> // CHECK: %64 = llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> // CHECK-COUNT-8: %{{[0-9]*}} = llvm.insertvalue %{{[0-9]*}}, %{{[0-9]*}}[{{[0-9]*}}] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> @@ -12,13 +12,3 @@ module attributes {"triton_gpu.compute-capability" = 0 : i32, "triton_gpu.num-ct tt.return } } - -module attributes {"triton_gpu.compute-capability" = 0 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - tt.func @basic_insert_slice_async_1d(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { - // CHECK: llvm.func @basic_insert_slice_async_1d - // CHECK: error: sizes [256, 2] must be a multiple of shapePerCTA [256, 16] - // XFAIL: * - %72 = amdgpu.view_slice %arg0[0,0] [256, 2] [1,1] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> - tt.return - } -} diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td index b48e02555a8f..bfe549a03277 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td @@ -60,42 +60,59 @@ def TritonAMDGPU_ViewSliceOp OffsetSizeAndStrideOpInterface, Pure]> { let 
summary = "view slice operation"; let description = [{ - The "view_slice" operation enables "viewing" a slice of a tensor in + The "view_slice" operation enables viewing a slice of a tensor in registers without data exchange. The "view_slice" operation supports the following arguments: - * source: the base tensor on which to create a "view" tensor - * offsets: offsets into the base tensor at which to create the "view" + * source: the base tensor on which to create a view tensor + * offsets: offsets into the base tensor at which to create the view * size: size of the result "view" tensor * strides: the number of strides for each dimension - Currently only 2D tensors are supported. - Example 1: ```mlir - %1 = triton_gpu.convert_layout %0 : tensor<128x128x!tt.ptr, #blocked> - -> tensor<128x128x!tt.ptr, #blocked2> + #blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], + threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [0, 1]}> + #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], + threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [0, 1]}> + %1 = triton_gpu.convert_layout %0 : tensor<128x128xf16, #blocked> + -> tensor<128x128xf16, #blocked1> // create a slice of base tensor %1 with // static offsets and sizes for each dimension - %2 = amdgpu.view_slice %0[0, 0] [128, 8] [1, 1] : - tensor<128x128x!tt.ptr, #blocked2> to - tensor<128x8x!tt.ptr, #blocked2> + %2 = amdgpu.view_slice %0[0, 0] [128, 32] [1, 1] : + tensor<128x128xf16, #blocked1> to tensor<128x32xf16, #blocked1> ``` + + Example 1 shows how "view_slice" operation may be used. In this example a + new view of 128x32 is created. "view_slice" works on tensors with layout + where the desired slice has the same layout as the source tensor. + "%0" cannot be sliced directly as the resulting slice cannot have the same + layout as "%0". Therefore it needs to be converted to a layout suitable + for slicing. "#blocked1" layout is appropriate for this as it keeps the + sizePerThread the same thus keeping coalescing properties the same. + In order to utilize all threads in a warp, "threadsPerWarp" is set to + [16,4] for this new layout. This layout conversion carried out before + using "view_slice" ensures slicing still uses all threads efficiently. 
}]; - let arguments = (ins AnyRankedTensor:$source, Variadic:$offsets, - Variadic:$sizes, Variadic:$strides, - DenseI64ArrayAttr:$static_offsets, DenseI64ArrayAttr:$static_sizes, + let arguments = (ins AnyRankedTensor:$source, + Variadic:$offsets, + Variadic:$sizes, + Variadic:$strides, + DenseI64ArrayAttr:$static_offsets, + DenseI64ArrayAttr:$static_sizes, DenseI64ArrayAttr:$static_strides); let results = (outs AnyRankedTensor:$result); let builders = [ // Build a ViewSliceOp with mixed static and dynamic entries and the same // result type - OpBuilder<(ins "RankedTensorType":$resultType, "Value":$source, - "ArrayRef":$offsets, "ArrayRef":$sizes, + OpBuilder<(ins "RankedTensorType":$resultType, + "Value":$source, + "ArrayRef":$offsets, + "ArrayRef":$sizes, "ArrayRef":$strides, CArg<"ArrayRef", "{}">:$attrs)>, ]; diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h b/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h index a7ae33f8bff9..bc0b03d00461 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h +++ b/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h @@ -1,5 +1,5 @@ -#ifndef AMD_INCLUDE_TRITONAMDGPU_TO_LLVM_PATTERNS_AMDGPU_OP_TO_LLVM_H -#define AMD_INCLUDE_TRITONAMDGPU_TO_LLVM_PATTERNS_AMDGPU_OP_TO_LLVM_H +#ifndef TRITONAMDGPU_TO_LLVM_PATTERNS_AMDGPU_OP_TO_LLVM_H +#define TRITONAMDGPU_TO_LLVM_PATTERNS_AMDGPU_OP_TO_LLVM_H #include "mlir/Conversion/LLVMCommon/TypeConverter.h" diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp index 8ba61b0e9646..9ac1058bb6bb 100644 --- a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp +++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp @@ -104,7 +104,7 @@ LogicalResult ViewSliceOp::verify() { if (!hasUnitStride()) { return emitError("expected unit strides but found unsupported stride [") - << getStrides() << "]"; + << getStaticStrides() << "]"; } return success(); diff --git a/third_party/amd/python/test/test_core.py b/third_party/amd/python/test/test_core.py index 7270be7e5913..08bca4541681 100644 --- a/third_party/amd/python/test/test_core.py +++ b/third_party/amd/python/test/test_core.py @@ -1,50 +1,20 @@ # flake8: noqa: F821,F841 -import contextlib -import itertools -import re -from typing import Optional -import math -import textwrap import tempfile import numpy as np import pytest import torch -import os -import inspect -from numpy.random import RandomState import triton import triton.language as tl -from triton.language.extra import libdevice -from triton._internal_testing import ( - is_interpreter, - is_hip, - get_arch, - torch_float8_dtypes, - torch_dtypes, -) +from triton._internal_testing import is_hip - -@contextlib.contextmanager -def promotion_numpy_2_0(): - state = np._get_promotion_state() - np._set_promotion_state("weak") - try: - yield - finally: - np._set_promotion_state(state) - - -# TODO: enable multiple cta cluster testing. 
-# num_ctas_list = [1, 4] if torch.cuda.get_device_capability()[0] == 9 else [1] num_ctas_list = [1] GPU_DIALECT = "triton_gpu" -if is_interpreter(): - THREADS_PER_WARP = 1 -elif is_hip(): + +if is_hip(): THREADS_PER_WARP = triton.runtime.driver.active.get_current_target().warp_size else: THREADS_PER_WARP = 32 @@ -71,8 +41,8 @@ def __str__(self): view_layout = [ BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), - BlockedLayout([2, 2], [16, 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), - BlockedLayout([2, 2], [16, 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [64, 1], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [16, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), BlockedLayout([1, 8], [16, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), ] @@ -92,14 +62,14 @@ def __str__(self): @pytest.mark.parametrize("blocked_layout", blocked_layout) def test_view_slice(dtype, M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset, blocked_layout, view_layout, device='cuda'): - if torch.version.hip is None: + if not is_hip(): pytest.skip("view_slice is AMD specific instruction.") ir = f""" #blocked = {blocked_layout} #view_layout = {view_layout} module attributes {{"triton_gpu.num-ctas" = 1, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = {str(64)} : i32}} {{ - tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked> %cst_n = arith.constant dense<{N_tile_size}> : tensor<{M_tile_size}x1xi32, #blocked> %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>> From 4ff61273acad8b7f6bdb19b5adedf58142bf7c23 Mon Sep 17 00:00:00 2001 From: Hasitha Algewaththa Date: Mon, 14 Oct 2024 16:37:13 +0000 Subject: [PATCH 08/15] Moves pytest, adds pytest to CI, verifies for static input args to view_slice --- .github/workflows/integration-tests.yml | 1 + .github/workflows/integration-tests.yml.in | 1 + .../amd/invalid_viewslice_to_llvm.mlir | 30 +++++++++++++++++++ .../TritonAMDGPU/IR/TritonAMDGPUOps.td | 11 ++----- .../lib/Dialect/TritonAMDGPU/IR/Dialect.cpp | 14 +++++++++ .../test/{test_core.py => test_view_slice.py} | 1 - 6 files changed, 49 insertions(+), 9 deletions(-) rename third_party/amd/python/test/{test_core.py => test_view_slice.py} (99%) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index f7bcd24d5403..bf21797370a6 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -404,6 +404,7 @@ jobs: echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1 fi pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py + pytest --capture=tee-sys -rfs third_party/amd/python/test/test_view_slice.py cd python/test/unit pytest --capture=tee-sys -rfs -n 16 language runtime \ --ignore=language/test_line_info.py \ diff --git a/.github/workflows/integration-tests.yml.in b/.github/workflows/integration-tests.yml.in index c587cfc27ae8..d189586f2171 100644 --- a/.github/workflows/integration-tests.yml.in +++ b/.github/workflows/integration-tests.yml.in @@ -402,6 +402,7 @@ jobs: echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1 fi pytest 
--capture=tee-sys -rfs python/tutorials/06-fused-attention.py + pytest --capture=tee-sys -rfs third_party/amd/python/test/test_view_slice.py cd python/test/unit pytest --capture=tee-sys -rfs -n 16 language runtime \ --ignore=language/test_line_info.py \ diff --git a/test/Conversion/amd/invalid_viewslice_to_llvm.mlir b/test/Conversion/amd/invalid_viewslice_to_llvm.mlir index e26817cd026e..b0df9aa5b9c4 100644 --- a/test/Conversion/amd/invalid_viewslice_to_llvm.mlir +++ b/test/Conversion/amd/invalid_viewslice_to_llvm.mlir @@ -68,3 +68,33 @@ tt.func @invalid_stride(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = %1 = amdgpu.view_slice %arg0[0,0] [256, 16] [1,2] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> tt.return } + +// ----- + +// Invalid non static offset +#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_non_static_offset(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}, %arg1: i32) { + // expected-error @+1 {{currently only static offsets are supported}} + %2 = amdgpu.view_slice %arg0[0,%arg1] [256, 16] [1,1] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> + tt.return +} + +// ----- + +// Invalid non static size +#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_non_static_size(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}, %arg1: i32) { + // expected-error @+1 {{currently only static sizes are supported}} + %2 = amdgpu.view_slice %arg0[0,0] [256, %arg1] [1,1] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> + tt.return +} + +// ----- + +// Invalid non static stride +#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_non_static_stride(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}, %arg1: i32) { + // expected-error @+1 {{currently only static strides are supported}} + %2 = amdgpu.view_slice %arg0[0,0] [256, 16] [1,%arg1] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> + tt.return +} diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td index bfe549a03277..d793c351eee5 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td @@ -30,14 +30,9 @@ include "mlir/IR/EnumAttr.td" include "triton/Dialect/Triton/IR/TritonTypes.td" include "mlir/Dialect/LLVMIR/LLVMOpBase.td" include "mlir/Interfaces/InferTypeOpInterface.td" -<<<<<<< HEAD -include "triton/Dialect/Triton/IR/TritonInterfaces.td" -include "TritonAMDGPUDialect.td" -include "TritonAMDGPUAttrDefs.td" -======= ->>>>>>> a1aaf2f1 (Addresses review comments) include "mlir/Interfaces/SideEffectInterfaces.td" // Pure include "mlir/Interfaces/ViewLikeInterface.td" // OffsetSizeAndStrideOpInterface +include "triton/Dialect/Triton/IR/TritonInterfaces.td" include "TritonAMDGPUDialect.td" include "TritonAMDGPUAttrDefs.td" @@ -55,8 +50,8 @@ def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">; // ViewSliceOp 
//===----------------------------------------------------------------------===// -def TritonAMDGPU_ViewSliceOp - : TritonAMDGPU_Op<"view_slice", [AttrSizedOperandSegments, +def ViewSliceOp + : TT_AMDGPU_Op<"view_slice", [AttrSizedOperandSegments, OffsetSizeAndStrideOpInterface, Pure]> { let summary = "view slice operation"; let description = [{ diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp index 9ac1058bb6bb..60e8476684ef 100644 --- a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp +++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp @@ -83,6 +83,20 @@ LogicalResult ViewSliceOp::verify() { shapePerCTA[0] = std::min(static_cast(srcShape[0]), shapePerCTA[0]); shapePerCTA[1] = std::min(static_cast(srcShape[1]), shapePerCTA[1]); + auto checkForConstInts = [](Value val) { + return isa(val.getDefiningOp()); + }; + + if (!llvm::all_of(getOffsets(), checkForConstInts)) { + return emitError("currently only static offsets are supported"); + } + if (!llvm::all_of(getSizes(), checkForConstInts)) { + return emitError("currently only static sizes are supported"); + } + if (!llvm::all_of(getStrides(), checkForConstInts)) { + return emitError("currently only static strides are supported"); + } + auto offsets = getStaticOffsets(); auto sizes = getStaticSizes(); diff --git a/third_party/amd/python/test/test_core.py b/third_party/amd/python/test/test_view_slice.py similarity index 99% rename from third_party/amd/python/test/test_core.py rename to third_party/amd/python/test/test_view_slice.py index 08bca4541681..e691039c2f20 100644 --- a/third_party/amd/python/test/test_core.py +++ b/third_party/amd/python/test/test_view_slice.py @@ -1,4 +1,3 @@ -# flake8: noqa: F821,F841 import tempfile import numpy as np From 2ce2e441f9054b902a11be36c858ea9f20da7d14 Mon Sep 17 00:00:00 2001 From: Hasitha Algewaththa Date: Tue, 15 Oct 2024 15:28:08 +0000 Subject: [PATCH 09/15] Fixes non static check to handle both the attributes and input args --- .../amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp index 60e8476684ef..660e50b97b75 100644 --- a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp +++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp @@ -83,17 +83,17 @@ LogicalResult ViewSliceOp::verify() { shapePerCTA[0] = std::min(static_cast(srcShape[0]), shapePerCTA[0]); shapePerCTA[1] = std::min(static_cast(srcShape[1]), shapePerCTA[1]); - auto checkForConstInts = [](Value val) { - return isa(val.getDefiningOp()); + auto checkForConstInts = [](OpFoldResult ofr) { + return getConstantIntValue(ofr).has_value(); }; - if (!llvm::all_of(getOffsets(), checkForConstInts)) { + if (!llvm::all_of(getMixedOffsets(), checkForConstInts)) { return emitError("currently only static offsets are supported"); } - if (!llvm::all_of(getSizes(), checkForConstInts)) { + if (!llvm::all_of(getMixedSizes(), checkForConstInts)) { return emitError("currently only static sizes are supported"); } - if (!llvm::all_of(getStrides(), checkForConstInts)) { + if (!llvm::all_of(getMixedStrides(), checkForConstInts)) { return emitError("currently only static strides are supported"); } From 09cc6dcee624f26935414d15bdfad17e3c1c0a19 Mon Sep 17 00:00:00 2001 From: Hasitha Algewaththa Date: Tue, 15 Oct 2024 15:28:08 +0000 Subject: [PATCH 10/15] Fixes non static check to 
handle both the attributes and input args --- python/test/unit/language/test_core.py | 83 ------------------- .../lib/Dialect/TritonAMDGPU/IR/Dialect.cpp | 2 +- 2 files changed, 1 insertion(+), 84 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 957ed30190bb..3013bbf53177 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -5839,11 +5839,9 @@ def kernel( kernel[(1, )](x_tri, y_tri, shape[0], BLOCK_SIZE=shape[0]) # compare np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01) -<<<<<<< HEAD # ----------------------- -<<<<<<< HEAD # test loop unrolling # ----------------------- @@ -5916,84 +5914,3 @@ def sanitize_cumsum_kernel(Z, X, BLOCK: tl.constexpr): Z = torch.zeros_like(X) sanitize_cumsum_kernel[(1, )](Z, X, BLOCK=BLOCK) torch.testing.assert_close(Z, X.cumsum(0).to(torch.int32)) -======= -# test view slice -# ----------------------- - -view_layout = [ - BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), - BlockedLayout([2, 2], [16, 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), - BlockedLayout([2, 2], [16, 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), - BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), - BlockedLayout([1, 8], [16, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), -] -blocked_layout = [ - BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), - BlockedLayout([2, 2], [16, 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), - BlockedLayout([2, 2], [16, 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), - BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), - BlockedLayout([1, 8], [16, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), -] - - -@pytest.mark.parametrize("M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset", - [[256, 256, 256, 32, 0, 32], [128, 128, 128, 64, 0, 64]]) -@pytest.mark.parametrize("dtype", [torch.float16]) -@pytest.mark.parametrize("view_layout", view_layout) -@pytest.mark.parametrize("blocked_layout", blocked_layout) -def test_view_slice(dtype, M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset, blocked_layout, view_layout, - device): - if not is_hip(): - pytest.skip('view_slice is AMD specific instruction.') - - ir = f""" - #blocked = {blocked_layout} - #view_layout = {view_layout} - """ + """ - module attributes {"triton_gpu.num-ctas" = 1, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = """ + str( - 64) + f""" : i32}} {{ - tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ - %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked> - %cst_n = arith.constant dense<{N_tile_size}> : tensor<{M_tile_size}x1xi32, #blocked> - %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>> - %42 = tt.make_range {{end = {M_tile_size} : i32, start = 0 : i32}} : tensor<{M_tile_size}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>> - %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>> - %2 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #blocked> - %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked> - %43 = tt.expand_dims %42 {{axis = 1 : i32}} : tensor<{M_tile_size}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>> -> 
tensor<{M_tile_size}x1xi32, #blocked> - %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #blocked> - %44 = arith.muli %43, %cst_n : tensor<{M_tile_size}x1xi32, #blocked> - %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{M}xi32, #blocked> - %7 = tt.broadcast %6 : tensor<1x{M}xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> - %8 = tt.broadcast %5 : tensor<{M}x1xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> - %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #blocked> - %33 = tt.make_range {{end = {N_tile_size} : i32, start = 0 : i32}} : tensor<{N_tile_size}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>> - %34 = tt.splat %arg1 : !tt.ptr -> tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked> - %37 = tt.expand_dims %33 {{axis = 0 : i32}} : tensor<{N_tile_size}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{N_tile_size}xi32, #blocked> - %38 = tt.broadcast %37 : tensor<1x{N_tile_size}xi32, #blocked> -> tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> - %39 = tt.broadcast %44 : tensor<{M_tile_size}x1xi32, #blocked> -> tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> - %40 = arith.addi %38, %39 : tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> - %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> - %11 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}x!tt.ptr, #blocked> - %12 = triton_gpu.convert_layout %11 : tensor<{M}x{N}xf16, #blocked> -> tensor<{M}x{N}xf16, #view_layout> - %13 = amdgpu.view_slice %12[{M_tile_offset}, {N_tile_offset}] [{M_tile_size}, {N_tile_size}] [1, 1] : tensor<{M}x{N}xf16, #view_layout> to tensor<{M_tile_size}x{N_tile_size}xf16, #view_layout> - %14 = triton_gpu.convert_layout %13 : tensor<{M_tile_size}x{N_tile_size}xf16, #view_layout> -> tensor<{M_tile_size}x{N_tile_size}xf16, #blocked> - %15 = tt.addptr %34, %40 : tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked>, tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> - tt.store %15, %14 : tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked> - tt.return - }} - }} - """ - x = torch.randn((M, N), device=device, dtype=torch.float16) - import tempfile - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: - f.write(ir) - f.flush() - kernel = triton.compile(f.name) - - view = torch.empty((M_tile_size, N_tile_size), device=device, dtype=torch.float16) - - kernel[(1, 1, 1)](x.data_ptr(), view) - test_result = torch.eq(x[M_tile_offset:M_tile_size + M_tile_offset, N_tile_offset:N_tile_offset + N_tile_size], - view).all() - assert test_result diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp index 660e50b97b75..26af80d620a1 100644 --- a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp +++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp @@ -23,8 +23,8 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" -#include "mlir/IR/OpImplementation.h" #include "mlir/IR/OperationSupport.h" +#include "mlir/IR/OpImplementation.h" #include "llvm/ADT/TypeSwitch.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" From 3beb781d75cf891f0b46183e9c11a40674cc4b5b Mon Sep 17 00:00:00 2001 From: Hasitha Algewaththa Date: Mon, 21 Oct 2024 21:48:52 +0000 Subject: [PATCH 11/15] changes operation name and assembly format, modifies tests to reflect new format --- .github/workflows/integration-tests.yml | 2 +- 
.github/workflows/integration-tests.yml.in | 2 +- .../amd/invalid_viewslice_to_llvm.mlir | 51 ++++++----------- test/TritonGPU/amd/amd-viewslice-op.mlir | 2 +- .../TritonAMDGPU/IR/TritonAMDGPUOps.td | 56 ++++++------------- .../PatternTritonAMDGPUToLLVM.h | 6 +- .../lib/Dialect/TritonAMDGPU/IR/Dialect.cpp | 43 ++++++-------- .../TritonAMDGPUDialectToLLVM/CMakeLists.txt | 2 +- ...eOpToLLVM.cpp => ExtractSliceOpToLLVM.cpp} | 30 ++++++---- .../TritonAMDGPUToLLVMPatterns.cpp | 2 +- ...st_view_slice.py => test_extract_slice.py} | 12 ++-- 11 files changed, 82 insertions(+), 126 deletions(-) rename third_party/amd/lib/TritonAMDGPUDialectToLLVM/{ViewSliceOpToLLVM.cpp => ExtractSliceOpToLLVM.cpp} (87%) rename third_party/amd/python/test/{test_view_slice.py => test_extract_slice.py} (92%) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index bf21797370a6..cfba6d7225b8 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -404,7 +404,7 @@ jobs: echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1 fi pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py - pytest --capture=tee-sys -rfs third_party/amd/python/test/test_view_slice.py + pytest --capture=tee-sys -rfs third_party/amd/python/test/test_extract_slice.py cd python/test/unit pytest --capture=tee-sys -rfs -n 16 language runtime \ --ignore=language/test_line_info.py \ diff --git a/.github/workflows/integration-tests.yml.in b/.github/workflows/integration-tests.yml.in index d189586f2171..7da4aa079327 100644 --- a/.github/workflows/integration-tests.yml.in +++ b/.github/workflows/integration-tests.yml.in @@ -402,7 +402,7 @@ jobs: echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1 fi pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py - pytest --capture=tee-sys -rfs third_party/amd/python/test/test_view_slice.py + pytest --capture=tee-sys -rfs third_party/amd/python/test/test_extract_slice.py cd python/test/unit pytest --capture=tee-sys -rfs -n 16 language runtime \ --ignore=language/test_line_info.py \ diff --git a/test/Conversion/amd/invalid_viewslice_to_llvm.mlir b/test/Conversion/amd/invalid_viewslice_to_llvm.mlir index b0df9aa5b9c4..0cacabf12a17 100644 --- a/test/Conversion/amd/invalid_viewslice_to_llvm.mlir +++ b/test/Conversion/amd/invalid_viewslice_to_llvm.mlir @@ -4,7 +4,7 @@ #blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> tt.func @invalid_size_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { // expected-error @+1 {{sizes [256, 2] must be a multiple of shapePerCTA [256, 16]}} - %1 = amdgpu.view_slice %arg0[0,0] [256, 2] [1,1] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> + %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x2xi32, #blocked1> tt.return } @@ -14,7 +14,7 @@ tt.func @invalid_size_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibili #blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> tt.func @invalid_offset_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { // expected-error @+1 {{offset [0, 5] must be a multiple of shapePerCTA [256, 16]}} - %1 = amdgpu.view_slice %arg0[0,5] [256, 16] [1,1] : tensor<256x128xi32, 
#blocked1> to tensor<256x16xi32, #blocked1> + %1 = amdgpu.extract_slice %arg0 [0,5] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> tt.return } @@ -25,7 +25,7 @@ tt.func @invalid_offset_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibi #blocked2 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> tt.func @invalid_result_layout(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { // expected-error @+1 {{result layout must match source layout}} - %1 = amdgpu.view_slice %arg0[0,0] [256, 16] [1,1] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked2> + %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked2> tt.return } @@ -35,7 +35,7 @@ tt.func @invalid_result_layout(%arg0: tensor<256x128xi32, #blocked1> {tt.divisib #blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> tt.func @invalid_result_element_type(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { // expected-error @+1 {{result element type must match source element type}} - %1 = amdgpu.view_slice %arg0[0,0] [256, 16] [1,1] : tensor<256x128xi32, #blocked1> to tensor<256x16xi64, #blocked1> + %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x16xi64, #blocked1> tt.return } @@ -45,27 +45,27 @@ tt.func @invalid_result_element_type(%arg0: tensor<256x128xi32, #blocked1> {tt.d #blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> tt.func @invalid_result_rank(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { // expected-error @+1 {{result rank must be equal to source rank}} - %1 = amdgpu.view_slice %arg0[0,0] [256, 16] [1,1] : tensor<256x128xi32, #blocked1> to tensor<256x16x2xi32, #blocked1> + %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x16x2xi32, #blocked1> tt.return } // ----- -// Invalid rank +// Invalid result shape #blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> -tt.func @invalid_rank(%arg0: tensor<256x128x2xi32, #blocked1> {tt.divisibility = 16 : i32}) { - // expected-error @+1 {{currently only 2D tensors are supported}} - %1 = amdgpu.view_slice %arg0[0,0,0] [256,16,2] [1,1,1] : tensor<256x128x2xi32, #blocked1> to tensor<256x16x2xi32, #blocked1> +tt.func @invalid_result_rank(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{result shape cannot be larger than input shape at dimension 1}} + %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x256xi32, #blocked1> tt.return } // ----- -// Invalid stride +// Invalid rank #blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> -tt.func @invalid_stride(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { - // expected-error @+1 {{expected unit strides but found unsupported stride [1, 2]}} - %1 = amdgpu.view_slice %arg0[0,0] [256, 16] [1,2] : 
tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> +tt.func @invalid_rank(%arg0: tensor<256x128x2xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{currently only 2D tensors are supported}} + %1 = amdgpu.extract_slice %arg0 [0,0,0] : tensor<256x128x2xi32, #blocked1> to tensor<256x16x2xi32, #blocked1> tt.return } @@ -74,27 +74,8 @@ tt.func @invalid_stride(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = // Invalid non static offset #blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> tt.func @invalid_non_static_offset(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}, %arg1: i32) { - // expected-error @+1 {{currently only static offsets are supported}} - %2 = amdgpu.view_slice %arg0[0,%arg1] [256, 16] [1,1] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> - tt.return -} - -// ----- - -// Invalid non static size -#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> -tt.func @invalid_non_static_size(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}, %arg1: i32) { - // expected-error @+1 {{currently only static sizes are supported}} - %2 = amdgpu.view_slice %arg0[0,0] [256, %arg1] [1,1] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> - tt.return -} - -// ----- - -// Invalid non static stride -#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> -tt.func @invalid_non_static_stride(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}, %arg1: i32) { - // expected-error @+1 {{currently only static strides are supported}} - %2 = amdgpu.view_slice %arg0[0,0] [256, 16] [1,%arg1] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> + // expected-error @+2 {{expected ']'}} + // expected-error @+1 {{expected integer value}} + %2 = amdgpu.extract_slice %arg0 [%arg1, 0] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> tt.return } diff --git a/test/TritonGPU/amd/amd-viewslice-op.mlir b/test/TritonGPU/amd/amd-viewslice-op.mlir index ff967e88686e..ef47a9f9b434 100644 --- a/test/TritonGPU/amd/amd-viewslice-op.mlir +++ b/test/TritonGPU/amd/amd-viewslice-op.mlir @@ -8,7 +8,7 @@ module attributes {"triton_gpu.compute-capability" = 0 : i32, "triton_gpu.num-ct // CHECK-COUNT-64: %{{[0-9]*}} = llvm.extractvalue %arg0[{{[0-9]*}}] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> // CHECK: %64 = llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> // CHECK-COUNT-8: %{{[0-9]*}} = llvm.insertvalue %{{[0-9]*}}, %{{[0-9]*}}[{{[0-9]*}}] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> - %72 = amdgpu.view_slice %arg0[0,0] [256, 16] [1,1] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> + %72 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> tt.return } } diff --git 
a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td index d793c351eee5..0b865cd1b8af 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td @@ -21,6 +21,7 @@ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ + #ifndef TRITON_AMDGPU_OPS #define TRITON_AMDGPU_OPS @@ -31,7 +32,6 @@ include "triton/Dialect/Triton/IR/TritonTypes.td" include "mlir/Dialect/LLVMIR/LLVMOpBase.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" // Pure -include "mlir/Interfaces/ViewLikeInterface.td" // OffsetSizeAndStrideOpInterface include "triton/Dialect/Triton/IR/TritonInterfaces.td" include "TritonAMDGPUDialect.td" include "TritonAMDGPUAttrDefs.td" @@ -47,23 +47,20 @@ class TT_AMDGPU_Op traits = []> : def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">; //===----------------------------------------------------------------------===// -// ViewSliceOp +// ExtractSliceOp //===----------------------------------------------------------------------===// -def ViewSliceOp - : TT_AMDGPU_Op<"view_slice", [AttrSizedOperandSegments, - OffsetSizeAndStrideOpInterface, Pure]> { - let summary = "view slice operation"; +def ExtractSliceOp + : TT_AMDGPU_Op<"extract_slice", [Pure]> { + let summary = "extract slice operation"; let description = [{ - The "view_slice" operation enables viewing a slice of a tensor in - registers without data exchange. + The "extract_slice" operation enables extracting a slice of a tensor in + registers. - The "view_slice" operation supports the following arguments: + The "extract_slice" operation supports the following arguments: * source: the base tensor on which to create a view tensor * offsets: offsets into the base tensor at which to create the view - * size: size of the result "view" tensor - * strides: the number of strides for each dimension Example 1: @@ -74,14 +71,13 @@ def ViewSliceOp threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [0, 1]}> %1 = triton_gpu.convert_layout %0 : tensor<128x128xf16, #blocked> -> tensor<128x128xf16, #blocked1> - // create a slice of base tensor %1 with - // static offsets and sizes for each dimension - %2 = amdgpu.view_slice %0[0, 0] [128, 32] [1, 1] : + // create a slice of base tensor %1 with static offsets + %2 = amdgpu.extract_slice %0 [0, 0] : tensor<128x128xf16, #blocked1> to tensor<128x32xf16, #blocked1> ``` - Example 1 shows how "view_slice" operation may be used. In this example a - new view of 128x32 is created. "view_slice" works on tensors with layout + Example 1 shows how "extract_slice" operation may be used. In this example a + new slice of 128x32 is created. "extract_slice" works on tensors with layout where the desired slice has the same layout as the source tensor. "%0" cannot be sliced directly as the resulting slice cannot have the same layout as "%0". Therefore it needs to be converted to a layout suitable @@ -89,34 +85,22 @@ def ViewSliceOp sizePerThread the same thus keeping coalescing properties the same. In order to utilize all threads in a warp, "threadsPerWarp" is set to [16,4] for this new layout. This layout conversion carried out before - using "view_slice" ensures slicing still uses all threads efficiently. + using "extract_slice" ensures slicing still uses all threads efficiently. The + size of the slice is determined by the result type. 
}]; let arguments = (ins AnyRankedTensor:$source, - Variadic:$offsets, - Variadic:$sizes, - Variadic:$strides, - DenseI64ArrayAttr:$static_offsets, - DenseI64ArrayAttr:$static_sizes, - DenseI64ArrayAttr:$static_strides); + DenseI64ArrayAttr:$static_offsets); let results = (outs AnyRankedTensor:$result); let builders = [ - // Build a ViewSliceOp with mixed static and dynamic entries and the same - // result type + // Build a ExtractSliceOp with static offsets and the same result type OpBuilder<(ins "RankedTensorType":$resultType, "Value":$source, - "ArrayRef":$offsets, - "ArrayRef":$sizes, - "ArrayRef":$strides, - CArg<"ArrayRef", "{}">:$attrs)>, + "ArrayRef": $static_offsets)>, ]; let extraClassDeclaration = [{ - /// Return the number of leading operands before the `offsets`, `sizes` and - /// and `strides` operands. - static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; } - std::array getArrayAttrMaxRanks() { unsigned rank = getSource().getType().getRank(); return {rank, rank, rank}; @@ -124,11 +108,7 @@ def ViewSliceOp }]; let assemblyFormat = [{ - $source `` - custom($offsets, $static_offsets) - custom($sizes, $static_sizes) - custom($strides, $static_strides) - attr-dict `:` type($source) `to` type($result) + $source $static_offsets attr-dict `:` type($source) `to` type($result) }]; let hasVerifier = 1; diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h b/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h index bc0b03d00461..90922e802988 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h +++ b/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h @@ -5,9 +5,9 @@ namespace mlir::triton::AMD { -void populateViewSliceOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, - mlir::RewritePatternSet &patterns, - mlir::PatternBenefit benefit); +void populateExtractSliceOpToLLVMPatterns( + mlir::LLVMTypeConverter &typeConverter, mlir::RewritePatternSet &patterns, + mlir::PatternBenefit benefit); } diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp index 26af80d620a1..2f5dd3535699 100644 --- a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp +++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp @@ -23,7 +23,6 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" -#include "mlir/IR/OperationSupport.h" #include "mlir/IR/OpImplementation.h" #include "llvm/ADT/TypeSwitch.h" @@ -57,7 +56,7 @@ void mlir::triton::amdgpu::TritonAMDGPUDialect::initialize() { namespace mlir::triton::amdgpu { -LogicalResult ViewSliceOp::verify() { +LogicalResult ExtractSliceOp::verify() { auto srcTy = getSource().getType(); auto srcLayout = srcTy.getEncoding(); auto srcElementType = getElementTypeOrSelf(srcTy); @@ -83,31 +82,18 @@ LogicalResult ViewSliceOp::verify() { shapePerCTA[0] = std::min(static_cast(srcShape[0]), shapePerCTA[0]); shapePerCTA[1] = std::min(static_cast(srcShape[1]), shapePerCTA[1]); - auto checkForConstInts = [](OpFoldResult ofr) { - return getConstantIntValue(ofr).has_value(); - }; - - if (!llvm::all_of(getMixedOffsets(), checkForConstInts)) { - return emitError("currently only static offsets are supported"); - } - if (!llvm::all_of(getMixedSizes(), checkForConstInts)) { - return emitError("currently only static sizes are supported"); - } - if (!llvm::all_of(getMixedStrides(), checkForConstInts)) { - return emitError("currently only static strides are 
supported"); - } - - auto offsets = getStaticOffsets(); - auto sizes = getStaticSizes(); - - // ViewSlice only supports slicing where offsets and sizes are multiples of + // ExtractSlice only supports slicing where offsets and sizes are multiples of // shapePerCTA. This condition ensures that slice has the same layout as the // original tensor. - if (offsets[0] % shapePerCTA[0] != 0 || offsets[1] % shapePerCTA[1] != 0) { - return emitError() << "offset [" << offsets - << "] must be a multiple of shapePerCTA [" << shapePerCTA - << "]"; + SmallVector sizes; + for (auto i = 0; i < 2; ++i) { + if (resultTy.getDimSize(i) > srcTy.getDimSize(i)) { + return emitError( + "result shape cannot be larger than input shape at dimension ") + << i; + } + sizes.push_back(resultTy.getDimSize(i)); } if (sizes[0] % shapePerCTA[0] != 0 || sizes[1] % shapePerCTA[1] != 0) { @@ -116,9 +102,12 @@ LogicalResult ViewSliceOp::verify() { << "]"; } - if (!hasUnitStride()) { - return emitError("expected unit strides but found unsupported stride [") - << getStaticStrides() << "]"; + auto offsets = getStaticOffsets(); + + if (offsets[0] % shapePerCTA[0] != 0 || offsets[1] % shapePerCTA[1] != 0) { + return emitError() << "offset [" << offsets + << "] must be a multiple of shapePerCTA [" << shapePerCTA + << "]"; } return success(); diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/CMakeLists.txt index aab42744aed2..4aebabc0a275 100644 --- a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/CMakeLists.txt @@ -1,6 +1,6 @@ add_triton_library(TritonAMDGPUDialectToLLVM TritonAMDGPUToLLVMPatterns.cpp - ViewSliceOpToLLVM.cpp + ExtractSliceOpToLLVM.cpp DEPENDS TritonAMDGPUIR diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ViewSliceOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp similarity index 87% rename from third_party/amd/lib/TritonAMDGPUDialectToLLVM/ViewSliceOpToLLVM.cpp rename to third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp index 65ef2d1577dc..1af72316d1dc 100644 --- a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ViewSliceOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp @@ -49,13 +49,14 @@ using namespace mlir::triton; // clang-format on namespace { -struct ViewSliceOpConversion - : public ConvertOpToLLVMPattern { - explicit ViewSliceOpConversion(LLVMTypeConverter &typeConverter, - PatternBenefit benefit = 1) - : ConvertOpToLLVMPattern(typeConverter, benefit) {} +struct ExtractSliceOpConversion + : public ConvertOpToLLVMPattern { + explicit ExtractSliceOpConversion(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit) { + } - LogicalResult processLayout(amdgpu::ViewSliceOp op, OpAdaptor adaptor, + LogicalResult processLayout(amdgpu::ExtractSliceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op->getLoc(); auto srcTy = cast(op.getSource().getType()); @@ -75,8 +76,13 @@ struct ViewSliceOpConversion shapePerCTA[1] = std::min(static_cast(srcShape[1]), shapePerCTA[1]); + // Rank == 2 checked in the verifier + SmallVector sizes; + for (auto i = 0; i < 2; ++i) { + sizes.push_back(resultTy.getDimSize(i)); + } + auto offsets = op.getStaticOffsets(); - auto sizes = op.getStaticSizes(); // Calculate offsets and sizes in terms of CTA units. 
std::vector CTAOffsets{offsets[0] / shapePerCTA[0], @@ -115,7 +121,7 @@ struct ViewSliceOpConversion } LogicalResult - matchAndRewrite(amdgpu::ViewSliceOp op, OpAdaptor adaptor, + matchAndRewrite(amdgpu::ExtractSliceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto srcTy = op.getSource().getType(); if (isa(op.getSource().getType().getEncoding()) || @@ -129,9 +135,9 @@ struct ViewSliceOpConversion namespace mlir::triton::AMD { -void populateViewSliceOpToLLVMPatterns(LLVMTypeConverter &typeConverter, - RewritePatternSet &patterns, - PatternBenefit benefit) { - patterns.add(typeConverter, benefit); +void populateExtractSliceOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); } } // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp index 2dc6a476e5c4..c7c2f56d31de 100644 --- a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp +++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp @@ -5,6 +5,6 @@ namespace mlir::triton::AMD { void populateTritonAMDGPUToLLVMPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit) { - populateViewSliceOpToLLVMPatterns(typeConverter, patterns, benefit); + populateExtractSliceOpToLLVMPatterns(typeConverter, patterns, benefit); } } // namespace mlir::triton::AMD diff --git a/third_party/amd/python/test/test_view_slice.py b/third_party/amd/python/test/test_extract_slice.py similarity index 92% rename from third_party/amd/python/test/test_view_slice.py rename to third_party/amd/python/test/test_extract_slice.py index e691039c2f20..d9f388c985bf 100644 --- a/third_party/amd/python/test/test_view_slice.py +++ b/third_party/amd/python/test/test_extract_slice.py @@ -38,7 +38,7 @@ def __str__(self): # test view slice # ----------------------- -view_layout = [ +extract_layout = [ BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), BlockedLayout([2, 2], [64, 1], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), BlockedLayout([2, 2], [16, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), @@ -57,12 +57,12 @@ def __str__(self): @pytest.mark.parametrize("M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset", [[256, 256, 256, 32, 0, 32], [128, 128, 128, 64, 0, 64]]) @pytest.mark.parametrize("dtype", [torch.float16]) -@pytest.mark.parametrize("view_layout", view_layout) +@pytest.mark.parametrize("view_layout", extract_layout) @pytest.mark.parametrize("blocked_layout", blocked_layout) -def test_view_slice(dtype, M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset, blocked_layout, view_layout, - device='cuda'): +def test_extract_slice(dtype, M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset, blocked_layout, view_layout, + device='cuda'): if not is_hip(): - pytest.skip("view_slice is AMD specific instruction.") + pytest.skip("extract_slice is AMD specific instruction.") ir = f""" #blocked = {blocked_layout} @@ -92,7 +92,7 @@ def test_view_slice(dtype, M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> %11 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}x!tt.ptr, #blocked> %12 = triton_gpu.convert_layout %11 : tensor<{M}x{N}xf16, #blocked> -> 
tensor<{M}x{N}xf16, #view_layout> - %13 = amdgpu.view_slice %12[{M_tile_offset}, {N_tile_offset}] [{M_tile_size}, {N_tile_size}] [1, 1] : tensor<{M}x{N}xf16, #view_layout> to tensor<{M_tile_size}x{N_tile_size}xf16, #view_layout> + %13 = amdgpu.extract_slice %12 [{M_tile_offset}, {N_tile_offset}] : tensor<{M}x{N}xf16, #view_layout> to tensor<{M_tile_size}x{N_tile_size}xf16, #view_layout> %14 = triton_gpu.convert_layout %13 : tensor<{M_tile_size}x{N_tile_size}xf16, #view_layout> -> tensor<{M_tile_size}x{N_tile_size}xf16, #blocked> %15 = tt.addptr %34, %40 : tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked>, tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> tt.store %15, %14 : tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked> From e4729bac9d7a5b2e361dcecd685a9f6b8f37a156 Mon Sep 17 00:00:00 2001 From: Hasitha Algewaththa Date: Fri, 25 Oct 2024 15:56:43 +0000 Subject: [PATCH 12/15] Adds bound checks for each dimension and renames files to reflect extract slice --- ...m.mlir => invalid_extractslice_to_llvm.mlir} | 12 +++++++++++- ...ewslice-op.mlir => amd-extractslice-op.mlir} | 0 .../amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp | 17 +++++++++++++---- .../amd/python/test/test_extract_slice.py | 8 ++++---- 4 files changed, 28 insertions(+), 9 deletions(-) rename test/Conversion/amd/{invalid_viewslice_to_llvm.mlir => invalid_extractslice_to_llvm.mlir} (88%) rename test/TritonGPU/amd/{amd-viewslice-op.mlir => amd-extractslice-op.mlir} (100%) diff --git a/test/Conversion/amd/invalid_viewslice_to_llvm.mlir b/test/Conversion/amd/invalid_extractslice_to_llvm.mlir similarity index 88% rename from test/Conversion/amd/invalid_viewslice_to_llvm.mlir rename to test/Conversion/amd/invalid_extractslice_to_llvm.mlir index 0cacabf12a17..ce2a1c1c164b 100644 --- a/test/Conversion/amd/invalid_viewslice_to_llvm.mlir +++ b/test/Conversion/amd/invalid_extractslice_to_llvm.mlir @@ -10,7 +10,7 @@ tt.func @invalid_size_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibili // ----- -// Invalid offset +// Invalid offset, not multiple of shapePerTile #blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> tt.func @invalid_offset_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { // expected-error @+1 {{offset [0, 5] must be a multiple of shapePerCTA [256, 16]}} @@ -20,6 +20,16 @@ tt.func @invalid_offset_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibi // ----- +// Invalid offset, out of bounds for dimension +#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_offset_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{invalid offset 128 at dimension 1}} + %1 = amdgpu.extract_slice %arg0 [0,128] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> + tt.return +} + +// ----- + // Invalid result layout #blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> #blocked2 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> diff --git a/test/TritonGPU/amd/amd-viewslice-op.mlir 
b/test/TritonGPU/amd/amd-extractslice-op.mlir similarity index 100% rename from test/TritonGPU/amd/amd-viewslice-op.mlir rename to test/TritonGPU/amd/amd-extractslice-op.mlir diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp index 2f5dd3535699..8dd4517036d0 100644 --- a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp +++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp @@ -86,14 +86,25 @@ LogicalResult ExtractSliceOp::verify() { // shapePerCTA. This condition ensures that slice has the same layout as the // original tensor. + auto offsets = getStaticOffsets(); + if (offsets.size() != 2) { + return emitError("invalid offset shape ") << offsets; + } + SmallVector sizes; for (auto i = 0; i < 2; ++i) { - if (resultTy.getDimSize(i) > srcTy.getDimSize(i)) { + auto resultDimSize = resultTy.getDimSize(i); + auto srcDimSize = srcTy.getDimSize(i); + if (resultDimSize > srcDimSize) { return emitError( "result shape cannot be larger than input shape at dimension ") << i; } - sizes.push_back(resultTy.getDimSize(i)); + if (offsets[i] + resultDimSize > srcDimSize) { + return emitError("invalid offset ") + << offsets[i] << " at dimension " << i; + } + sizes.push_back(resultDimSize); } if (sizes[0] % shapePerCTA[0] != 0 || sizes[1] % shapePerCTA[1] != 0) { @@ -102,8 +113,6 @@ LogicalResult ExtractSliceOp::verify() { << "]"; } - auto offsets = getStaticOffsets(); - if (offsets[0] % shapePerCTA[0] != 0 || offsets[1] % shapePerCTA[1] != 0) { return emitError() << "offset [" << offsets << "] must be a multiple of shapePerCTA [" << shapePerCTA diff --git a/third_party/amd/python/test/test_extract_slice.py b/third_party/amd/python/test/test_extract_slice.py index d9f388c985bf..59dd56cb5f95 100644 --- a/third_party/amd/python/test/test_extract_slice.py +++ b/third_party/amd/python/test/test_extract_slice.py @@ -35,7 +35,7 @@ def __str__(self): # ----------------------- -# test view slice +# test extract slice # ----------------------- extract_layout = [ @@ -107,9 +107,9 @@ def test_extract_slice(dtype, M, N, M_tile_size, N_tile_size, M_tile_offset, N_t f.flush() kernel = triton.compile(f.name) - view = torch.empty((M_tile_size, N_tile_size), device=device, dtype=torch.float16) + extract_slice = torch.empty((M_tile_size, N_tile_size), device=device, dtype=torch.float16) - kernel[(1, 1, 1)](x.data_ptr(), view) + kernel[(1, 1, 1)](x.data_ptr(), extract_slice) test_result = torch.equal(x[M_tile_offset:M_tile_size + M_tile_offset, N_tile_offset:N_tile_offset + N_tile_size], - view) + extract_slice) assert test_result From b4c14eb310b7f362d9139dcb4cb235a44446373a Mon Sep 17 00:00:00 2001 From: Hasitha Algewaththa Date: Thu, 14 Nov 2024 18:53:26 +0000 Subject: [PATCH 13/15] Adds zero dimension check and related tests --- .../amd/invalid_extractslice_to_llvm.mlir | 20 +++++++++++++++++++ .../lib/Dialect/TritonAMDGPU/IR/Dialect.cpp | 6 ++++++ .../ExtractSliceOpToLLVM.cpp | 8 ++++---- 3 files changed, 30 insertions(+), 4 deletions(-) diff --git a/test/Conversion/amd/invalid_extractslice_to_llvm.mlir b/test/Conversion/amd/invalid_extractslice_to_llvm.mlir index ce2a1c1c164b..e98d11cd48bc 100644 --- a/test/Conversion/amd/invalid_extractslice_to_llvm.mlir +++ b/test/Conversion/amd/invalid_extractslice_to_llvm.mlir @@ -10,6 +10,26 @@ tt.func @invalid_size_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibili // ----- +// Invalid zero source dimension +#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 
16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_size_input(%arg0: tensor<256x0xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{source tensor dimension size zero at dimension 1}} + %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x0xi32, #blocked1> to tensor<256x16xi32, #blocked1> + tt.return +} + +// ----- + +// Invalid zero result dimension +#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_size_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{result tensor dimension size zero at dimension 1}} + %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x0xi32, #blocked1> + tt.return +} + +// ----- + // Invalid offset, not multiple of shapePerTile #blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> tt.func @invalid_offset_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp index 8dd4517036d0..6404828db016 100644 --- a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp +++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp @@ -95,6 +95,12 @@ LogicalResult ExtractSliceOp::verify() { for (auto i = 0; i < 2; ++i) { auto resultDimSize = resultTy.getDimSize(i); auto srcDimSize = srcTy.getDimSize(i); + if (resultDimSize == 0) { + return emitError("result tensor dimension size zero at dimension ") << i; + } + if (srcDimSize == 0) { + return emitError("source tensor dimension size zero at dimension ") << i; + } if (resultDimSize > srcDimSize) { return emitError( "result shape cannot be larger than input shape at dimension ") diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp index 1af72316d1dc..47112b2f1e6a 100644 --- a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp @@ -66,7 +66,7 @@ struct ExtractSliceOpConversion auto vals = unpackLLElements(loc, adaptor.getSource(), rewriter); auto elemsPerThread = triton::gpu::getElemsPerThread(srcTy); auto sizePerThread = triton::gpu::getSizePerThread(srcLayout); - auto totalSizePerThread = sizePerThread[0] * sizePerThread[1]; + auto totalSizePerThread = product(sizePerThread); auto order = triton::gpu::getOrder(srcLayout); // Calculate valid total number of workers in each dimension @@ -85,11 +85,11 @@ struct ExtractSliceOpConversion auto offsets = op.getStaticOffsets(); // Calculate offsets and sizes in terms of CTA units. 
- std::vector CTAOffsets{offsets[0] / shapePerCTA[0], + std::array CTAOffsets{offsets[0] / shapePerCTA[0], offsets[1] / shapePerCTA[1]}; - std::vector CTASizes{sizes[0] / shapePerCTA[0], + std::array CTASizes{sizes[0] / shapePerCTA[0], sizes[1] / shapePerCTA[1]}; - std::vector CTAPerShape{srcShape[0] / shapePerCTA[0], + std::array CTAPerShape{srcShape[0] / shapePerCTA[0], srcShape[1] / shapePerCTA[1]}; // The diagram above illustrates the graphical representation of the From 5454ef366565a4156009053e7cc3ac1b540d0a2b Mon Sep 17 00:00:00 2001 From: Hasitha Algewaththa Date: Tue, 19 Nov 2024 03:52:29 +0000 Subject: [PATCH 14/15] Refactors code (shapePerCTATile, isa<>) --- .../amd/invalid_extractslice_to_llvm.mlir | 4 +-- .../lib/Dialect/TritonAMDGPU/IR/Dialect.cpp | 27 +++++++++++-------- .../ExtractSliceOpToLLVM.cpp | 26 +++++++++--------- 3 files changed, 31 insertions(+), 26 deletions(-) diff --git a/test/Conversion/amd/invalid_extractslice_to_llvm.mlir b/test/Conversion/amd/invalid_extractslice_to_llvm.mlir index e98d11cd48bc..e561dfb26905 100644 --- a/test/Conversion/amd/invalid_extractslice_to_llvm.mlir +++ b/test/Conversion/amd/invalid_extractslice_to_llvm.mlir @@ -3,7 +3,7 @@ // Invalid size #blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> tt.func @invalid_size_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { - // expected-error @+1 {{sizes [256, 2] must be a multiple of shapePerCTA [256, 16]}} + // expected-error @+1 {{sizes [256, 2] must be a multiple of shapePerCTATile [256, 16]}} %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x2xi32, #blocked1> tt.return } @@ -33,7 +33,7 @@ tt.func @invalid_size_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibili // Invalid offset, not multiple of shapePerTile #blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> tt.func @invalid_offset_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { - // expected-error @+1 {{offset [0, 5] must be a multiple of shapePerCTA [256, 16]}} + // expected-error @+1 {{offset [0, 5] must be a multiple of shapePerCTATile [256, 16]}} %1 = amdgpu.extract_slice %arg0 [0,5] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> tt.return } diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp index 6404828db016..7c2473dbe56f 100644 --- a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp +++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp @@ -78,13 +78,16 @@ LogicalResult ExtractSliceOp::verify() { } auto srcShape = srcTy.getShape(); - auto shapePerCTA = mlir::triton::gpu::getShapePerCTATile(srcLayout, srcShape); - shapePerCTA[0] = std::min(static_cast(srcShape[0]), shapePerCTA[0]); - shapePerCTA[1] = std::min(static_cast(srcShape[1]), shapePerCTA[1]); + auto shapePerCTATile = + mlir::triton::gpu::getShapePerCTATile(srcLayout, srcShape); + shapePerCTATile[0] = + std::min(static_cast(srcShape[0]), shapePerCTATile[0]); + shapePerCTATile[1] = + std::min(static_cast(srcShape[1]), shapePerCTATile[1]); // ExtractSlice only supports slicing where offsets and sizes are multiples of - // shapePerCTA. This condition ensures that slice has the same layout as the - // original tensor. 
+ // shapePerCTATile. This condition ensures that slice has the same layout as + // the original tensor. auto offsets = getStaticOffsets(); if (offsets.size() != 2) { @@ -113,16 +116,18 @@ LogicalResult ExtractSliceOp::verify() { sizes.push_back(resultDimSize); } - if (sizes[0] % shapePerCTA[0] != 0 || sizes[1] % shapePerCTA[1] != 0) { + if (sizes[0] % shapePerCTATile[0] != 0 || + sizes[1] % shapePerCTATile[1] != 0) { return emitError() << "sizes [" << sizes - << "] must be a multiple of shapePerCTA [" << shapePerCTA - << "]"; + << "] must be a multiple of shapePerCTATile [" + << shapePerCTATile << "]"; } - if (offsets[0] % shapePerCTA[0] != 0 || offsets[1] % shapePerCTA[1] != 0) { + if (offsets[0] % shapePerCTATile[0] != 0 || + offsets[1] % shapePerCTATile[1] != 0) { return emitError() << "offset [" << offsets - << "] must be a multiple of shapePerCTA [" << shapePerCTA - << "]"; + << "] must be a multiple of shapePerCTATile [" + << shapePerCTATile << "]"; } return success(); diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp index 47112b2f1e6a..c0100812f299 100644 --- a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp @@ -70,11 +70,11 @@ struct ExtractSliceOpConversion auto order = triton::gpu::getOrder(srcLayout); // Calculate valid total number of workers in each dimension - auto shapePerCTA = triton::gpu::getShapePerCTATile(srcLayout, srcShape); - shapePerCTA[0] = - std::min(static_cast(srcShape[0]), shapePerCTA[0]); - shapePerCTA[1] = - std::min(static_cast(srcShape[1]), shapePerCTA[1]); + auto shapePerCTATile = triton::gpu::getShapePerCTATile(srcLayout, srcShape); + shapePerCTATile[0] = + std::min(static_cast(srcShape[0]), shapePerCTATile[0]); + shapePerCTATile[1] = + std::min(static_cast(srcShape[1]), shapePerCTATile[1]); // Rank == 2 checked in the verifier SmallVector sizes; @@ -85,12 +85,12 @@ struct ExtractSliceOpConversion auto offsets = op.getStaticOffsets(); // Calculate offsets and sizes in terms of CTA units. - std::array CTAOffsets{offsets[0] / shapePerCTA[0], - offsets[1] / shapePerCTA[1]}; - std::array CTASizes{sizes[0] / shapePerCTA[0], - sizes[1] / shapePerCTA[1]}; - std::array CTAPerShape{srcShape[0] / shapePerCTA[0], - srcShape[1] / shapePerCTA[1]}; + std::array CTAOffsets{offsets[0] / shapePerCTATile[0], + offsets[1] / shapePerCTATile[1]}; + std::array CTASizes{sizes[0] / shapePerCTATile[0], + sizes[1] / shapePerCTATile[1]}; + std::array CTAPerShape{srcShape[0] / shapePerCTATile[0], + srcShape[1] / shapePerCTATile[1]}; // The diagram above illustrates the graphical representation of the // skipElems, tensorStride, and lastIdx variables. 
@@ -124,8 +124,8 @@ struct ExtractSliceOpConversion matchAndRewrite(amdgpu::ExtractSliceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto srcTy = op.getSource().getType(); - if (isa(op.getSource().getType().getEncoding()) || - isa(op.getSource().getType().getEncoding())) { + if (isa( + op.getSource().getType().getEncoding())) { return processLayout(op, adaptor, rewriter); } return failure(); From 17e617eeb2e42fb0ff8d06e597d491e9a2d87da3 Mon Sep 17 00:00:00 2001 From: Hasitha Algewaththa Date: Tue, 19 Nov 2024 15:43:49 +0000 Subject: [PATCH 15/15] refactors test_extract_slice.py --- python/test/unit/language/test_core.py | 571 ++++++++++++------ .../amd/python/test/test_extract_slice.py | 14 +- 2 files changed, 382 insertions(+), 203 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 3013bbf53177..a499dc232146 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -5,7 +5,7 @@ from typing import Optional import math import textwrap -import tempfile +import pathlib import numpy as np import pytest @@ -23,11 +23,16 @@ int_dtypes, uint_dtypes, float_dtypes, + float_dtypes_with_bfloat16, dtypes, dtypes_with_bfloat16, is_cuda, is_interpreter, is_hip, + is_hip_cdna, + is_hip_mi200, + is_hip_mi300, + is_xpu, get_arch, torch_float8_dtypes, torch_dtypes, @@ -66,6 +71,11 @@ def _bitwidth(dtype: str) -> int: return int(re.search(r'(\d+)$', dtype).group(1)) +def _dtype(dtype: str) -> str: + # ex.: "int64" -> "int" + return re.match(r'([a-zA-Z]+)', dtype).group(0) + + def patch_kernel(template, to_replace): if is_interpreter(): local_namespace = {} @@ -139,6 +149,17 @@ def __str__(self): return f"#{GPU_DIALECT}.nvidia_mma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA={self.warps_per_cta}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}, instrShape={self.instr_shape}}}>" +class DotOperandLayout: + + def __init__(self, parent, op_idx, k_width): + self.parent = parent + self.op_idx = op_idx + self.k_width = k_width + + def __str__(self): + return f"#{GPU_DIALECT}.dot_op<{{parent={self.parent}, opIdx={self.op_idx}, kWidth={self.k_width}}}>" + + class BlockedLayout: def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga, cta_split_num, cta_order): @@ -267,7 +288,8 @@ def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]: def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', num_ctas=1, - y_low=None, y_high=None, filter_y=None, test_broadcast=True, test_scalar=True): + x_low=None, x_high=None, y_low=None, y_high=None, filter_y=None, test_broadcast=True, + test_scalar=True): check_type_supported(dtype_x, device) # early return if dtype_x is not supported check_type_supported(dtype_y, device) SIZE = 128 @@ -312,7 +334,7 @@ def kernel_scalar_rhs(Z, X, y: tl.constexpr, SIZE: tl.constexpr): # inputs rs = RandomState(17) - x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs) + x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs, low=x_low, high=x_high) y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs, low=y_low, high=y_high) if filter_y: y[filter_y(y)] = 1 @@ -346,7 +368,7 @@ def do_test(x, y, kernel_fn): z_tri = to_triton(np.empty(SIZE, dtype=z_ref.dtype), device=device) kernel_fn[(1, )](z_tri, x_tri, y_tri, SIZE=SIZE, num_warps=4, num_ctas=num_ctas) err_msg = f"{expr}, {kernel_fn.__name__}" - 
np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=err_msg, atol=3e-3, rtol=0.01) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=err_msg, atol=7e-3, rtol=0.01) def get_scalar(x, dtype, low, high, filter): # If dtype is int, don't choose a huge number for the scalar @@ -380,28 +402,32 @@ def get_scalar(x, dtype, low, high, filter): do_test(x, y[:1].reshape(()), kernel_broadcast_rhs) -def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool: - # FIXME For large x, we are casting x to a floating point where it does not fit - # For small y, we are computing floor(div(float(x), y)) which may not fit - return (dtype_x, dtype_y) in [ - ('int32', 'bfloat16'), - ('int32', 'float16'), - ('int32', 'float32'), - ('int64', 'bfloat16'), - ('int64', 'float16'), - ('int64', 'float32'), - ('int64', 'float64'), - ('uint16', 'bfloat16'), - ('uint16', 'float16'), - ('uint16', 'float32'), - ('uint32', 'bfloat16'), - ('uint32', 'float16'), - ('uint32', 'float32'), - ('uint64', 'bfloat16'), - ('uint64', 'float16'), - ('uint64', 'float32'), - ('uint64', 'float64'), - ] +def _min_max_integral_mod_value(dtype_x, dtype_y) -> Optional[int]: + """ + Limit min/max values for integral types for mod values. Leads to + overflow/underflow when casting large integral types to floats. + """ + x_bitwidth = _bitwidth(dtype_x) + y_bitwidth = _bitwidth(dtype_y) + + # hard cap max value bit-width to 32 if 64 bit-width types + min_bitwidth = min(x_bitwidth, y_bitwidth, 32) + + # Limit max value bit-width to be one integral type less than the min bit-width + # For example: + # int64, float32 -> int16 + # uint16, float16 -> uint8 + x_dtype = _dtype(dtype_x) + max_bitwidth = max(min_bitwidth >> 1, 8) + dtype_max = x_dtype + str(max_bitwidth) + + max_info = np.iinfo(getattr(np, dtype_max)) + + # Still need to limit values here for uints + if max_bitwidth >= 16 and dtype_max in uint_dtypes: + return max_info.min, max_info.max // 4 + else: + return max_info.min, max_info.max def test_dtype_codegen(): @@ -425,35 +451,35 @@ def test_dtype_codegen(): @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_bin_op(dtype_x, dtype_y, op, num_ctas, device): expr = f'x {op} y' - if op == '%' and dtype_x in int_dtypes + uint_dtypes and dtype_y in int_dtypes + uint_dtypes: - # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders. - numpy_expr = 'np.fmod(x, y)' - elif op in ('/', '%') and dtype_x in ('int16', 'float16', 'bfloat16') and dtype_y in ('int16', 'float16', - 'bfloat16'): - # Triton promotes 16-bit floating-point / and % to 32-bit because there - # are no native div or FRem operations on float16. Since we have to - # convert anyway, we may as well take the accuracy bump. - numpy_expr = f'x.astype(np.float32) {op} y.astype(np.float32)' + np_expr_gen = (lambda x, y: f'{x} {op} {y}') if op != '%' else (lambda x, y: f'np.fmod({x}, {y})') + + # Triton promotes 16-bit floating-point / and % to 32-bit because there + # are no native div or FRem operations on float16. Since we have to + # convert anyway, we may as well take the accuracy bump. 
+ def promote_to_fp32(dtype_x, dtype_y): + return dtype_x in ('float16', 'bfloat16') and dtype_y not in ('float32', 'float64') + + if op in ('/', '%') and (promote_to_fp32(dtype_x, dtype_y) or promote_to_fp32(dtype_y, dtype_x)): + numpy_expr = np_expr_gen('x.astype(np.float32)', 'y.astype(np.float32)') elif (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): - numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})' + numpy_expr = np_expr_gen(f'x.astype(np.{dtype_x})', f'y.astype(np.{dtype_x})') elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)): - numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})' + numpy_expr = np_expr_gen(f'x.astype(np.{dtype_y})', f'y.astype(np.{dtype_y})') + elif op == '%': + # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders. + numpy_expr = np_expr_gen('x', 'y') else: numpy_expr = None - if op == '%' and _mod_operation_ill_conditioned(dtype_x, dtype_y): - with pytest.raises(AssertionError, match="Not equal to tolerance"): - _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) - elif (op in ('%', '/') and ((dtype_x in int_dtypes and dtype_y in uint_dtypes) or - (dtype_x in uint_dtypes and dtype_y in int_dtypes))): + + if (op in ('%', '/') and ((dtype_x in int_dtypes and dtype_y in uint_dtypes) or + (dtype_x in uint_dtypes and dtype_y in int_dtypes))): with pytest.raises(triton.TritonError, match='Cannot use .* because they have different signedness'): _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) else: # skip when bfloat16, as NumPy's ref performs the computation in float32 # while Triton performs it in bfloat16 - # We also skip mod when it is ill-conditioned skip_scalar_test = ((dtype_x == "bfloat16" and "float" in dtype_y) - or (expr == "x % y" and dtype_x in int_dtypes + uint_dtypes and dtype_y in float_dtypes - and _mod_operation_ill_conditioned(dtype_x, "float32"))) + or (op in ('/', '%') and dtype_x in ("float16", "bfloat16"))) # can't divide by zero not_zero = op in ('/', '%') and dtype_x in integral_dtypes and dtype_y in integral_dtypes # can't represent -int(max) @@ -462,11 +488,17 @@ def test_bin_op(dtype_x, dtype_y, op, num_ctas, device): filter_y = lambda y: not_zero * (y == 0) | not_minus_one * (y == -1) else: filter_y = None + + if op == "%" and dtype_x in integral_dtypes and dtype_y in float_dtypes_with_bfloat16: + x_low, x_high = _min_max_integral_mod_value(dtype_x, dtype_y) + else: + x_low, x_high = None, None + _test_binary( dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas, # fails with values where fmod(x, y) is roughly zero, but happens to # pass with the random values chosen for non-broadcast tests - test_broadcast=(op != "%"), filter_y=filter_y, test_scalar=not skip_scalar_test) + test_broadcast=(op != "%"), x_low=x_low, x_high=x_high, filter_y=filter_y, test_scalar=not skip_scalar_test) @pytest.mark.interpreter @@ -1111,6 +1143,9 @@ def kernel(): a = tl.arange(0, 64).reshape(2, 4, 8).trans(2, 1, 0) tl.static_assert(a.shape == [tl.constexpr(8), tl.constexpr(4), tl.constexpr(2)]) + a = tl.arange(0, 64).reshape(2, 4, 8).trans((2, 1, 0)) + tl.static_assert(a.shape == [tl.constexpr(8), tl.constexpr(4), tl.constexpr(2)]) + a = tl.arange(0, 64).view(2, 4, 8) tl.static_assert(a.shape == [tl.constexpr(2), tl.constexpr(4), tl.constexpr(8)]) @@ -1453,17 +1488,27 @@ def kernel(X): for shape in [(2, 2), (2, 8), (8, 2), (8, 
8), (32, 32), (64, 64)] for axis in [0, 1] for num_ctas in num_ctas_list - for dtype_x_str in ['float32', 'uint64', 'int64', 'float64']]) + for dtype_x_str in ['float16', 'float32', 'uint64', 'int64', 'float64']]) def test_tensor_atomic_rmw(shape, axis, num_ctas, dtype_x_str, device): shape0, shape1 = shape # triton kernel @triton.jit - def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr): + def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr, DTYPE: tl.constexpr): off0 = tl.arange(0, SHAPE0) off1 = tl.arange(0, SHAPE1) x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :]) + + if DTYPE == tl.float16: + # sum can have bad numerics when accumulating in float16. + # if we're dealing with float16, do the sum in float32. + x = x.to(tl.float32) + z = tl.sum(x, axis=AXIS) + + if DTYPE == tl.float16: + z = z.to(DTYPE) + if AXIS == 1: old = tl.atomic_add(Z + off0, z) tl.store(OLD + off0, old) @@ -1477,13 +1522,23 @@ def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.const z = numpy_random(z_shape, dtype_str=dtype_x_str, rs=rs) old = np.zeros(z_shape, dtype=getattr(np, dtype_x_str)) # reference results - z_ref = z + np.sum(x, axis=axis, keepdims=False) + if x.dtype == np.float16: + # do the sum in float32 to reduce numerical variation + z_ref = z + np.sum(x.astype(np.float32), axis=axis, keepdims=False).astype(x.dtype) + else: + z_ref = z + np.sum(x, axis=axis, keepdims=False) old_ref = np.copy(z) # triton result x_tri = to_triton(x, device=device) z_tri = to_triton(z, device=device) old_tri = to_triton(old, device=device) - kernel[(1, )](z_tri, x_tri, old_tri, axis, shape0, shape1, num_ctas=num_ctas) + + def torch_to_triton_dtype(t): + if t == torch.float16: + return tl.float16 + return None + + kernel[(1, )](z_tri, x_tri, old_tri, axis, shape0, shape1, torch_to_triton_dtype(x_tri.dtype), num_ctas=num_ctas) np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4) np.testing.assert_equal(old_ref, to_numpy(old_tri)) @@ -1699,47 +1754,33 @@ def kernel(X, Y, Z, N: tl.constexpr): @pytest.mark.interpreter @pytest.mark.parametrize("dtype_str", list(torch_dtypes)) +@pytest.mark.parametrize("constant_field", ["value", "mask"]) @pytest.mark.parametrize("num_ctas", num_ctas_list) -def test_store_constant(dtype_str, num_ctas, device): +def test_store_constant(num_ctas, dtype_str, constant_field, device): check_type_supported(dtype_str, device) - """Tests that boolean True is stored as 1""" @triton.jit - def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, CONSTANT_FIELD: tl.constexpr): offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - output = GENERATE_TEST_HERE + if CONSTANT_FIELD == "value": + value = 1 + output = tl.full([BLOCK_SIZE], value=value, dtype=value.dtype) + mask = offsets < n_elements + elif CONSTANT_FIELD == "mask": + output = offsets < n_elements + mask = False tl.store(output_ptr + offsets, output, mask=mask) - triton_dtype_str = 'uint8' if dtype_str == 'bool' else dtype_str - kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.zeros([BLOCK_SIZE], dtype=tl.{triton_dtype_str}) + 1'}) block_size = 128 ref = torch.ones([block_size], dtype=getattr(torch, dtype_str), device=device) output = torch.zeros([block_size], dtype=getattr(torch, dtype_str), device=device) - kernel[(1, )](output, block_size, BLOCK_SIZE=block_size, num_ctas=num_ctas) - - assert torch.all(output == 
ref) - - -@pytest.mark.interpreter -@pytest.mark.parametrize("num_ctas", num_ctas_list) -def test_store_constant_default_dtype(num_ctas, device): - """Tests that boolean True is stored as 1""" - @triton.jit - def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): - offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - value = 1 - output = tl.full([BLOCK_SIZE], value=value, dtype=value.dtype) - tl.store(output_ptr + offsets, output, mask=mask) - - block_size = 128 - ref = torch.ones([block_size], dtype=getattr(torch, 'int32'), device=device) - output = torch.zeros([block_size], dtype=getattr(torch, 'int32'), device=device) - kernel[(1, )](output, block_size, BLOCK_SIZE=block_size, num_ctas=num_ctas) + kernel[(1, )](output, block_size, BLOCK_SIZE=block_size, num_ctas=num_ctas, CONSTANT_FIELD=constant_field) - assert torch.all(output == ref) + if constant_field == "value": + assert torch.all(output == ref) + else: + assert torch.all(output == 0) def test_load_store_same_ptr(device): @@ -2448,6 +2489,9 @@ def histogram_kernel(x_ptr, z_ptr, M: tl.constexpr, N: tl.constexpr): offset2 = tl.arange(0, N) x = tl.load(x_ptr + offset1) z = tl.histogram(x, N) + bias = tl.full([M, N], 1, dtype=tl.int32) + # check that histogram produces object compatible with broadcasting + biased = z + bias tl.store(z_ptr + offset2, z) torch.manual_seed(17) @@ -2517,7 +2561,22 @@ def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, NUM_PID_N: tl. @pytest.mark.parametrize("M, N", [[32, 16], [32, 32], [32, 64], [64, 32]]) @pytest.mark.parametrize("src_layout", scan_layouts) @pytest.mark.parametrize("axis", [0, 1]) -def test_scan_layouts(M, N, src_layout, axis, device): +@pytest.mark.parametrize("add_overflow_check", [False, True]) +def test_scan_layouts(M, N, src_layout, axis, add_overflow_check, device, tmp_path: pathlib.Path): + if add_overflow_check is True and is_hip(): + pytest.skip("overflow check disabled on HIP while fixing issues") + + overflow_check = """ + %17 = arith.extsi %arg2 : i32 to i64 + %18 = arith.extsi %arg3 : i32 to i64 + %19 = arith.addi %17, %18 : i64 + %i32.min = arith.constant -2147483648: i64 + %i32.max = arith.constant 2147483647: i64 + %20 = arith.cmpi slt, %19, %i32.max : i64 + %21 = arith.cmpi sge, %19, %i32.min : i64 + %22 = arith.andi %20, %21 : i1 + tt.assert %22, "overflow detected" : i1 + """ ir = f""" #blocked = {src_layout} @@ -2537,7 +2596,7 @@ def test_scan_layouts(M, N, src_layout, axis, device): %10 = tt.load %9 : tensor<{M}x{N}x!tt.ptr, #blocked> %11 = "tt.scan"(%10) <{{axis = {axis} : i32, reverse = false}}> ({{ ^bb0(%arg2: i32, %arg3: i32): - %16 = arith.addi %arg2, %arg3 : i32 + %16 = arith.addi %arg2, %arg3 : i32{overflow_check if add_overflow_check else ""} tt.scan.return %16 : i32 }}) : (tensor<{M}x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked> %12 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x1x!tt.ptr, #blocked> @@ -2550,10 +2609,10 @@ def test_scan_layouts(M, N, src_layout, axis, device): }} """ - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: - f.write(ir) - f.flush() - kernel = triton.compile(f.name) + temp_file = tmp_path / "test_scan_layouts.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + rs = RandomState(17) x = rs.randint(-100, 100, (M, N)).astype('int32') @@ -2599,20 +2658,36 @@ def test_scan_layouts(M, N, src_layout, axis, device): @pytest.mark.parametrize("src_layout", filter_layouts(layouts)) @pytest.mark.parametrize("axis", [0, 1]) 
@pytest.mark.parametrize("epilogue_kind", ['reduce1d', 'reduce2d', 'expand_reduce2d']) -@pytest.mark.parametrize("dtype_str", ["int32", "float32", "float16"]) +@pytest.mark.parametrize("dtype_str,add_overflow_check", [("int32", False), ("int32", True), ("float32", False), + ("float16", False)]) @pytest.mark.parametrize("reduce_op", ["sum", "max"]) -def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce_op, device): +def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, add_overflow_check, reduce_op, device, + tmp_path: pathlib.Path): if isinstance(src_layout, (MfmaLayout, MmaLayout)) and (M < src_layout.instr_shape[0] or N < src_layout.instr_shape[1]): pytest.skip("Skipping because tensor shape is smaller than M(f)maLayout instr_shape") if is_hip() and isinstance(src_layout, MfmaLayout) and ((M, N) == (128, 128)): pytest.skip("Skipping test because it runs out of shared memory") + if add_overflow_check is True and is_hip(): + pytest.skip("overflow check disabled on HIP while fixing issues") if reduce_op == "sum" and dtype_str == "float16" and M * N > 1024: pytest.skip("Skipping sum reduction on float16 due to accuracy issues") if isinstance(src_layout, MmaLayout) and src_layout.version == 3: src_layout[2] = 16 if dtype_str == "float16" else 8 + overflow_check = """ + %18 = arith.extsi %arg3 : i32 to i64 + %19 = arith.extsi %arg4 : i32 to i64 + %20 = arith.addi %18, %19 : i64 + %i32.min = arith.constant -2147483648: i64 + %i32.max = arith.constant 2147483647: i64 + %21 = arith.cmpi slt, %20, %i32.max : i64 + %22 = arith.cmpi sge, %20, %i32.min : i64 + %23 = arith.andi %21, %22 : i1 + tt.assert %23, "overflow detected" : i1 + """ + ty = {"int32": "i32", "float32": "f32", "float16": "f16"}[dtype_str] arith_op = { "max": {"int32": "arith.maxsi", "float32": "arith.maximumf", "float16": "arith.maximumf"}, # @@ -2645,7 +2720,7 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce f""" %14 = "tt.reduce"(%13) ({{ ^bb0(%arg3: {ty}, %arg4: {ty}): - %17 = {arith_op} %arg3, %arg4 : {ty} + %17 = {arith_op} %arg3, %arg4 : {ty}{overflow_check if add_overflow_check else ""} tt.reduce.return %17 : {ty} }}) {{axis = 0 : i32}} : (tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>>) -> {ty} tt.store %arg2, %14 : !tt.ptr<{ty}> @@ -2657,7 +2732,7 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce %14 = tt.expand_dims %13 {{axis = {axis} : i32}} : tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>> -> tensor<{expanded_shape}x{ty}, #src> %15 = "tt.reduce"(%14) ({{ ^bb0(%arg3: {ty}, %arg4: {ty}): - %17 = {arith_op} %arg3, %arg4 : {ty} + %17 = {arith_op} %arg3, %arg4 : {ty}{overflow_check if add_overflow_check else ""} tt.reduce.return %17 : {ty} }}) {{axis = {other_axis} : i32}} : (tensor<{expanded_shape}x{ty}, #src>) -> (tensor<1x{ty}, #{GPU_DIALECT}.slice<{{dim = {other_axis}, parent = #src}}>>) %16 = triton_gpu.convert_layout %15 : tensor<1x{ty}, #{GPU_DIALECT}.slice<{{dim = {other_axis}, parent = #src}}>> -> tensor<1x{ty}, #one_d_layout> @@ -2690,15 +2765,14 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce %12 = {GPU_DIALECT}.convert_layout %11 : tensor<{M}x{N}x{ty}, #blocked> -> tensor<{M}x{N}x{ty}, #src> %13 = "tt.reduce"(%12) ({{ ^bb0(%arg3: {ty}, %arg4: {ty}): - %17 = {arith_op} %arg3, %arg4 : {ty} + %17 = {arith_op} %arg3, %arg4 : {ty}{overflow_check if add_overflow_check else ""} tt.reduce.return %17 : {ty} }}) 
{{axis = {axis} : i32}} : (tensor<{M}x{N}x{ty}, #src>) -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>> """ + epilogue - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: - f.write(ir) - f.flush() - kernel = triton.compile(f.name) + temp_file = tmp_path / "test_reduce_layouts.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) rs = RandomState(17) x = numpy_random((M, N), dtype_str=dtype_str, rs=rs, low=0, high=10) @@ -2728,7 +2802,7 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce @pytest.mark.parametrize("M", [32, 64, 128, 256]) @pytest.mark.parametrize("src_layout", layouts) -def test_store_op(M, src_layout, device): +def test_store_op(M, src_layout, device, tmp_path: pathlib.Path): ir = f""" #src = {src_layout} @@ -2749,10 +2823,9 @@ def test_store_op(M, src_layout, device): }} """ - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: - f.write(ir) - f.flush() - store_kernel = triton.compile(f.name) + temp_file = tmp_path / "test_store_op.ttgir" + temp_file.write_text(ir) + store_kernel = triton.compile(str(temp_file)) rs = RandomState(17) x = rs.randint(0, 4, (M, 1)).astype('float32') @@ -2779,7 +2852,7 @@ def test_store_op(M, src_layout, device): @pytest.mark.parametrize("dst_layout", filter_layouts(layouts)) @pytest.mark.parametrize("src_dim", [0, 1]) @pytest.mark.parametrize("dst_dim", [0, 1]) -def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device): +def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device, tmp_path: pathlib.Path): ir = f""" #dst = {dst_layout} @@ -2799,10 +2872,9 @@ def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device): }} }} """ - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: - f.write(ir) - f.flush() - kernel = triton.compile(f.name) + temp_file = tmp_path / "test_convert1d.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) rs = RandomState(17) x = rs.randint(0, 4, (M, )).astype('int32') @@ -2840,7 +2912,7 @@ def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2): @pytest.mark.parametrize("src_layout", layouts) @pytest.mark.parametrize("op", ["sum", "max"]) @pytest.mark.parametrize("first_axis", [0, 1]) -def test_chain_reduce(M, N, src_layout, op, device, first_axis): +def test_chain_reduce(M, N, src_layout, op, device, first_axis, tmp_path: pathlib.Path): op_str = "" if op == "sum": @@ -2881,10 +2953,9 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis): }} }} """ - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: - f.write(ir) - f.flush() - kernel = triton.compile(f.name) + temp_file = tmp_path / "test_chain_reduce.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) rs = RandomState(17) x = rs.randint(0, 4, (M, N)).astype('int32') @@ -3300,41 +3371,53 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid assert 'wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3' in ptx -@pytest.mark.parametrize("M, N, K, col_a, col_b, type_a, type_b, num_warps", - [(M, N, K, col_a, col_b, type_a, type_b, 4) +@pytest.mark.parametrize("M, N, K, col_a, col_b, rhs_scale, normal_type, mxfp_type, num_warps, mma, kpack", + [(M, N, K, col_a, col_b, rhs_scale, normal_type, mxfp_type, 4, mma, kpack) for M, N, K in itertools.product([32, 64, 128], [32, 64, 128], [64, 128]) for col_a, col_b in itertools.product([True, False], repeat=2) - for type_a in ["e2m1", "e4m3", 
"e5m2"] - for type_b in ["e4m3", "e5m2"]]) -def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, device): - if not is_cuda(): - pytest.skip("scaled_dot only supported on CUDA") - else: + for rhs_scale in [False, True] + for normal_type in ["e2m1", "e4m3", "e5m2"] + for mxfp_type in ["e4m3", "e5m2", "bf16"] + for mma in ([32, 16] if is_hip() else [16]) + for kpack in ([1, 2] if is_hip() else [1])]) +def test_scaled_dot(M, N, K, col_a, col_b, rhs_scale, normal_type, mxfp_type, num_warps, mma, kpack, device): + if is_cuda(): cc = torch.cuda.get_device_capability() if cc < (8, 9): pytest.skip("float8e4nv not supported on CUDA < 8.9") + if is_hip(): + if not is_hip_cdna(): + pytest.skip("scaled_dot only implemented for HIP CDNA") + if "e4m3" in (normal_type, mxfp_type) and not is_hip_mi300(): + pytest.skip(f"scaled_dot({normal_type}, {mxfp_type}) only implemented for MI300") + if mma == 16 and K == 64: + pytest.skip(f"K == {K} too small for mfma {mma} in scaled_dot") @triton.jit - def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, out, + def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, b_scale, out, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, type_a: tl.constexpr, type_b: tl.constexpr): - tl.static_assert(type_b == "e4m3" or type_b == "e5m2", "type_b must be fp8") - IS_FP8: tl.constexpr = type_a == "e4m3" or type_a == "e5m2" - DIV_FACTOR: tl.constexpr = 1 if IS_FP8 else 2 - PACKED_BLOCK_K_A: tl.constexpr = BLOCK_K // DIV_FACTOR - PACKED_BLOCK_K_B: tl.constexpr = BLOCK_K + DIV_FACTOR_A: tl.constexpr = 2 if type_a == "e2m1" else 1 + DIV_FACTOR_B: tl.constexpr = 2 if type_b == "e2m1" else 1 + PACKED_BLOCK_K_A: tl.constexpr = BLOCK_K // DIV_FACTOR_A + PACKED_BLOCK_K_B: tl.constexpr = BLOCK_K // DIV_FACTOR_B a_ptr = a_base + tl.arange(0, BLOCK_M)[:, None] * stride_a0 + tl.arange(0, PACKED_BLOCK_K_A)[None, :] * stride_a1 b_ptr = b_base + tl.arange(0, PACKED_BLOCK_K_B)[:, None] * stride_b0 + tl.arange(0, BLOCK_N)[None, :] * stride_b1 - SCALE_BLOCK_K: tl.constexpr = BLOCK_K // 32 - scale_a_ptr = a_scale + tl.arange(0, BLOCK_M)[:, None] * SCALE_BLOCK_K + tl.arange(0, SCALE_BLOCK_K)[None, :] - a = tl.load(a_ptr) b = tl.load(b_ptr) - a_scale = tl.load(scale_a_ptr) - c = tl.dot_scaled(a, a_scale, type_a, b, None, type_b) + SCALE_BLOCK_K: tl.constexpr = BLOCK_K // 32 + if a_scale is not None: + scale_a_ptr = a_scale + tl.arange(0, BLOCK_M)[:, None] * SCALE_BLOCK_K + tl.arange(0, + SCALE_BLOCK_K)[None, :] + a_scale = tl.load(scale_a_ptr) + if b_scale is not None: + scale_b_ptr = b_scale + tl.arange(0, BLOCK_N)[:, None] * SCALE_BLOCK_K + tl.arange(0, + SCALE_BLOCK_K)[None, :] + b_scale = tl.load(scale_b_ptr) + c = tl.dot_scaled(a, a_scale, type_a, b, b_scale, type_b) out_ptr = out + tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :] tl.store(out_ptr, c.to(tl.bfloat16)) @@ -3407,22 +3490,31 @@ def mxfp_to_bf16_kernel( offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) tl.store(mxfp_ptr + offsets, tl.ravel(mxfp), mask=offsets < N * 32) - def dot_scale_ref(x, scale, y, type_x, type_y): - e_bits, m_bits = {"e2m1": (2, 1), "e4m3": (4, 3), "e5m2": (5, 2)}[type_x] - type_fp8_y = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2}[type_y] - - comp_dtype = torch.bfloat16 - - x = x.contiguous() - x_upcast = x.new_empty(scale.shape[:-1] + (32 * scale.shape[-1], ), dtype=comp_dtype) - - N = x_upcast.numel() - BLOCK_SIZE = 512 - grid = ((N + BLOCK_SIZE - 1) 
// BLOCK_SIZE, ) - mxfp_to_bf16_kernel[grid](x, scale, x_upcast, scale.numel(), e_bits, m_bits, BLOCK_SIZE, num_warps=num_warps) - assert x_upcast.isfinite().all() - - y_upcast = y.view(type_fp8_y).to(comp_dtype) + def dot_scale_ref(x, scale_x, y, scale_y, type_x, type_y): + + def upcast(v, scale, type, transposed): + comp_dtype = torch.bfloat16 + if scale is None: + type = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2, "bf16": torch.bfloat16}[type] + return v.view(type).to(comp_dtype) + e_bits, m_bits = {"e2m1": (2, 1), "e4m3": (4, 3), "e5m2": (5, 2)}[type] + # Packing is always on the K dimension so we transpose before upcasting then transpose back. + if transposed: + v = v.mT.contiguous() + v = v.contiguous() + v_upcast = v.new_empty(scale.shape[:-1] + (32 * scale.shape[-1], ), dtype=comp_dtype) + N = v_upcast.numel() + BLOCK_SIZE = 512 + grid = ((N + BLOCK_SIZE - 1) // BLOCK_SIZE, ) + mxfp_to_bf16_kernel[grid](v, scale, v_upcast, scale.numel(), e_bits, m_bits, BLOCK_SIZE, + num_warps=num_warps) + assert v_upcast.isfinite().all() + if transposed: + v_upcast = v_upcast.mT + return v_upcast + + x_upcast = upcast(x, scale_x, type_x, False) + y_upcast = upcast(y, scale_y, type_y, True) class AccumulateInFp32: @@ -3434,28 +3526,39 @@ def __exit__(self, exc_type, exc_val, exc_tb): torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = self.prev_value with AccumulateInFp32(): - return torch.matmul(x_upcast.to(comp_dtype), y_upcast.to(comp_dtype)) + return torch.matmul(x_upcast, y_upcast) torch.manual_seed(0) - def create_uint8(shape, col_major=False, max_val=255): + def make_arg(shape, ty, col_major=False, max_val=255): if col_major: shape = shape[:-2] + (shape[-1], shape[-2]) - ret = torch.randint(max_val + 1, shape, dtype=torch.uint8, device=device) + if ty == "bf16": + ret = torch.randn(shape, dtype=torch.bfloat16, device=device) + # Clamp to avoid relative error issues + ret.clamp_(-2**15, 2**15 - 1) + else: + ret = torch.randint(max_val + 1, shape, dtype=torch.uint8, device=device) if col_major: ret = ret.mT return ret - DIV_FACTOR = 2 if type_a == "e2m1" else 1 - x = create_uint8((M, K // DIV_FACTOR), col_major=col_a) - y = create_uint8((K, N), col_major=col_b) + type_a = normal_type if not rhs_scale else mxfp_type + type_b = mxfp_type if not rhs_scale else normal_type + + DIV_FACTOR_A = 2 if type_a == "e2m1" else 1 + DIV_FACTOR_B = 2 if type_b == "e2m1" else 1 + x = make_arg((M, K // DIV_FACTOR_A), type_a, col_major=col_a) + y = make_arg((K // DIV_FACTOR_B, N), type_b, col_major=col_b) # sample scales that don't overflow as otherwise it's implementation defined (underflowing is alright) - # We substract a reasonably high number (64) so that the sum of all the mxfp elements does not overflow - m_bytes = int(type_a[1]) - bias_type_a = 1 << (m_bytes - 1) - 1 - max_exponent_type_a = (1 << m_bytes) - 1 - bias_type_a - scale_x = create_uint8((M, K // 32), max_val=255 - max_exponent_type_a - 64) + # Max scale= 2**15 + scale_x = make_arg((M, K // 32), "e8m0", max_val=127 + 15) + scale_y = make_arg((N, K // 32), "e8m0", max_val=127 + 15) + if rhs_scale: + scale_x = None + else: + scale_y = None def make_finite(x, dtype): # e5m2 has too many non-finite values when sampled uniformly (1 / 32) and @@ -3470,23 +3573,30 @@ def make_finite(x, dtype): x = make_finite(x, type_a) y = make_finite(y, type_b) - + kernel_kwargs = {"num_warps": num_warps} + if is_hip(): + kernel_kwargs["kpack"] = kpack + kernel_kwargs["matrix_instr_nonkdim"] = mma z = x.new_empty((M, N), 
dtype=torch.bfloat16) - pgm = dot_scale_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), z, M, N, K, type_a, type_b, - num_warps=num_warps) - - z_ref = dot_scale_ref(x, scale_x, y, type_a, type_b) - - # generous rtol as we are sampling the whole range of floats - torch.testing.assert_close(z, z_ref, atol=1e-5, rtol=1e-2) + pgm = dot_scale_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), scale_y, z, M, N, K, type_a, type_b, + **kernel_kwargs) + z_ref = dot_scale_ref(x, scale_x, y, scale_y, type_a, type_b) + # Bigger tolerance for AMD MI200 devices. + # MI200 devices use reduced precision fp16 and bf16 and flush input and output denormal values + # to zero. Detailed info is at: + # https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices + atol = 2e-4 if is_hip_mi200() else 1e-5 + rtol = 2e-2 if is_hip_mi200() else 1e-2 + torch.testing.assert_close(z, z_ref, atol=atol, rtol=rtol) # make sure ld/st are vectorized - ptx = pgm.asm['ptx'] - if (max(M, N) * K) // (num_warps * 32) >= 4: - assert 'ld.global.v4' in ptx - if M * N // (num_warps * 32) >= 4: - assert 'st.global.v4' in ptx - assert re.search(r'mma.sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.bf16.bf16', ptx) + if is_cuda(): + ptx = pgm.asm['ptx'] + if (max(M, N) * K) // (num_warps * 32) >= 4: + assert 'ld.global.v4' in ptx + if M * N // (num_warps * 32) >= 4: + assert 'st.global.v4' in ptx + assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.bf16.bf16', ptx) @pytest.mark.interpreter @@ -3978,14 +4088,14 @@ def _kernel(dst, src, CACHE: tl.constexpr): amdgcn = pgm.asm['amdgcn'] cg_cache_modifier_str = 'nt' cv_cache_modifier_str = 'sc0 sc1' + buffer_load_line = [line for line in amdgcn.splitlines() if "buffer_load" in line] global_load_line = [line for line in amdgcn.splitlines() if "global_load" in line] - flat_load_line = [line for line in amdgcn.splitlines() if "flat_load" in line] if cache == '' or cache == '.ca': - assert cg_cache_modifier_str not in global_load_line[0] + assert cg_cache_modifier_str not in (global_load_line[0] if global_load_line else buffer_load_line[0]) if cache == '.cg': assert cg_cache_modifier_str in global_load_line[0] if cache == '.cv': - assert cv_cache_modifier_str in flat_load_line[0] + assert cv_cache_modifier_str in global_load_line[0] if is_cuda(): ptx = pgm.asm['ptx'] @@ -5072,7 +5182,9 @@ def kernel(Out): a = torch.empty((), device=device, dtype=torch.int32) h = kernel[(1, )](a) assert "ub.poison" in h.asm["ttir"], h.asm["ttir"] - assert "poison" in h.asm["llir"], h.asm["llir"] + # xpu uses llvm.store, which in this case is removed by the optimizer + if not is_xpu(): + assert "poison" in h.asm["llir"], h.asm["llir"] # ----------------------- @@ -5146,6 +5258,14 @@ def kernel(Out): BlockedLayout([4, 1], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), BlockedLayout([1, 1], [THREADS_PER_WARP, 1], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), BlockedLayout([4, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=2), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=2), + DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=2), + DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=2), + 
DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=8), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=8), + DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=8), + DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=8), MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), ] @@ -5180,9 +5300,13 @@ def compute_scratch_buffer_shape(src_layout, dst_layout, shape): @pytest.mark.parametrize("src_layout", layouts) @pytest.mark.parametrize("interm_layout", intermediate_layouts) @pytest.mark.parametrize("dst_layout", layouts) -def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device): +def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, tmp_path: pathlib.Path): if str(src_layout) == str(dst_layout): pytest.skip() + if (isinstance(src_layout, DotOperandLayout) + and isinstance(interm_layout, SharedLayout)) or (isinstance(dst_layout, DotOperandLayout) + and isinstance(interm_layout, SharedLayout)): + pytest.skip("DotOperandLayout <-> SharedLayout conversion is not completely supported") if is_hip(): try: scratch_shape = compute_scratch_buffer_shape(src_layout, dst_layout, (M, N)) @@ -5245,10 +5369,10 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device): x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device) z = torch.empty_like(x, device=device) - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: - f.write(ir) - f.flush() - kernel = triton.compile(f.name) + temp_file = tmp_path / "test_convert2d.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr()) assert torch.equal(z, x) @@ -5301,7 +5425,7 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device): @pytest.mark.parametrize("M, N", [[64, 1], [1, 64], [64, 64], [128, 128], [256, 256]]) @pytest.mark.parametrize("dtype", ['float16']) @pytest.mark.parametrize("mma_pair", mma_pairs) -def test_convertmma2mma(M, N, mma_pair, dtype, device): +def test_convertmma2mma(M, N, mma_pair, dtype, device, tmp_path: pathlib.Path): if is_hip(): pytest.skip("test_mma2mma is not supported in HIP") @@ -5358,10 +5482,10 @@ def do_test(src_layout, dst_layout): x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device) z = torch.empty_like(x) - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: - f.write(ir) - f.flush() - kernel = triton.compile(f.name) + temp_file = tmp_path / "test_convertmma2mma.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr()) assert torch.equal(z, x) @@ -5457,7 +5581,7 @@ def matmul_kernel( # stride_cm, stride_cn, # BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # low_precision_acc: tl.constexpr, # - num_pipeline_stages: tl.constexpr = 3 # + num_stages: tl.constexpr = 3 # ): pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) @@ -5469,7 +5593,7 @@ def matmul_kernel( # a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K), num_stages=num_pipeline_stages): + for k in tl.range(0, 
tl.cdiv(K, BLOCK_SIZE_K), num_stages=num_stages): a = tl.load(a_ptrs) b = tl.load(b_ptrs) accumulator = tl.dot(a, b, acc=accumulator, max_num_imprecise_acc=low_precision_acc) @@ -5508,7 +5632,7 @@ def test_dot_max_num_imprecise_acc(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, in_type_s max_num_impressive_acc = low_precision_acc if low_precision_acc <= BLOCK_K else None h = matmul_kernel[grid](a, b, C, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), C.stride(0), C.stride(1), BLOCK_M, BLOCK_N, BLOCK_K, max_num_impressive_acc, num_warps=num_warps, - num_pipeline_stages=num_stages) + num_stages=num_stages) torch_a = torch.from_numpy(A).to(device=device) th_a = f8_to_f16(torch_a, in_type_str) torch_b = torch.from_numpy(B).to(device=device) @@ -5700,7 +5824,7 @@ def test_tl_range(device): pgm = matmul_kernel[ 1, ](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), BLOCK_M, BLOCK_N, - BLOCK_K, 0, num_pipeline_stages=5) + BLOCK_K, 0, num_stages=5) ref_out = torch.matmul(a, b).to(torch.float32) if is_interpreter(): # GPU invokes tensor core for float16 matmul, which is not supported in interpreter. @@ -5726,8 +5850,8 @@ def maxnreg_noinline2(X): tl.store(X, 0) +@pytest.mark.interpreter def test_maxnreg(device): - assert not is_interpreter(), "this test won't work with the interpreter" if not is_cuda(): pytest.skip('maxnreg only works on CUDA') @@ -5741,14 +5865,15 @@ def kernel(X): X = torch.empty(1, dtype=torch.int32, device=device) k = kernel[(1, )](X, maxnreg=42) - # Ensure that .maxnreg is set on the kernel function (marked with .entry) - # and not on either of the noinline functions (marked with .func). - try: - assert re.search(r'\.visible \.entry [^{;]*\.maxnreg 42', k.asm["ptx"]) - assert not re.search(r'\.visible \.func [^{;]*\.maxnreg', k.asm["ptx"]) - except AssertionError: - print("Failing ptx:\n", k.asm["ptx"]) - raise + if not is_interpreter(): + # Ensure that .maxnreg is set on the kernel function (marked with .entry) + # and not on either of the noinline functions (marked with .func). 
+ try: + assert re.search(r'\.visible \.entry [^{;]*\.maxnreg 42', k.asm["ptx"]) + assert not re.search(r'\.visible \.func [^{;]*\.maxnreg', k.asm["ptx"]) + except AssertionError: + print("Failing ptx:\n", k.asm["ptx"]) + raise @pytest.mark.interpreter @@ -5896,6 +6021,30 @@ def sanitize_sum_kernel(Z, X, BLOCK: tl.constexpr): torch.testing.assert_close(Z, X.sum().to(torch.int32)) +@pytest.mark.parametrize("reduce_dim", [0, 1]) +def test_side_effectful_reduction_2d(device, reduce_dim): + if device != "cuda": + pytest.skip() + + @triton.jit(debug=True) + def sanitize_sum_2d_kernel(Z, X, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, reduce_dim: tl.constexpr, + NON_REDUCE_DIM: tl.constexpr): + offsets = tl.arange(0, BLOCK_0)[:, None] * BLOCK_1 + tl.arange(0, BLOCK_1)[None, :] + vals = tl.load(X + offsets) + z = tl.reduce(vals, reduce_dim, sanitize_add) + tl.store(Z + tl.arange(0, NON_REDUCE_DIM), z) + + BLOCK_0 = 16 + BLOCK_1 = 32 + NON_REDUCE_DIM = BLOCK_1 if reduce_dim == 0 else BLOCK_0 + torch.manual_seed(42) + X = torch.randint(0, 10, [BLOCK_0, BLOCK_1], device="cuda", dtype=torch.int32) + Z = torch.zeros([NON_REDUCE_DIM], device="cuda", dtype=torch.int32) + sanitize_sum_2d_kernel[(1, )](Z, X, BLOCK_0=BLOCK_0, BLOCK_1=BLOCK_1, reduce_dim=reduce_dim, + NON_REDUCE_DIM=NON_REDUCE_DIM) + torch.testing.assert_close(Z, X.sum(reduce_dim).to(torch.int32)) + + def test_side_effectful_scan(device): if device != "cuda": pytest.skip() @@ -5914,3 +6063,33 @@ def sanitize_cumsum_kernel(Z, X, BLOCK: tl.constexpr): Z = torch.zeros_like(X) sanitize_cumsum_kernel[(1, )](Z, X, BLOCK=BLOCK) torch.testing.assert_close(Z, X.cumsum(0).to(torch.int32)) + + +# stress test slice layout usages in reductions. +@pytest.mark.parametrize("in_shape, perm, red_dims", [ + ((4, 32, 32, 4, 2), [2, 1, 0, 3, 4], [3, 1, 0]), + ((8, 2, 32, 4, 16), [4, 0, 1, 3, 2], [0, 2, 0]), +]) +def test_chained_reductions(in_shape, perm, red_dims, device): + + @triton.jit + def kernel(In, Out, # + dim_0: tl.constexpr, dim_1: tl.constexpr, dim_2: tl.constexpr, dim_3: tl.constexpr, dim_4: tl.constexpr, + perm_0: tl.constexpr, perm_1: tl.constexpr, perm_2: tl.constexpr, perm_3: tl.constexpr, + perm_4: tl.constexpr, red_dim_0: tl.constexpr, red_dim_1: tl.constexpr, red_dim_2: tl.constexpr): + idx = tl.arange(0, dim_0 * dim_1 * dim_2 * dim_3 * dim_4) + idx = idx.reshape(dim_0, dim_1, dim_2, dim_3, dim_4) + vals = tl.load(In + idx) + vals = tl.permute(vals, [perm_0, perm_1, perm_2, perm_3, perm_4]) + r = tl.sum(tl.sum(tl.sum(vals, red_dim_0), red_dim_1), red_dim_2) + st_idx = tl.arange(0, r.shape[0] * r.shape[1]).reshape(r.shape) + tl.store(Out + st_idx, r) + + input = torch.randint(0, 1000, in_shape, device=device, dtype=torch.int32) + temp = torch.permute(input, perm).contiguous() + ref = torch.sum(torch.sum(torch.sum(temp, dim=red_dims[0]), dim=red_dims[1]), dim=red_dims[2]) + result = torch.empty_like(ref) + kernel[(1, )](input, result, input.shape[0], input.shape[1], input.shape[2], input.shape[3], input.shape[4], + perm[0], perm[1], perm[2], perm[3], perm[4], red_dims[0], red_dims[1], red_dims[2]) + + assert torch.all(ref == result) diff --git a/third_party/amd/python/test/test_extract_slice.py b/third_party/amd/python/test/test_extract_slice.py index 59dd56cb5f95..bca36038405d 100644 --- a/third_party/amd/python/test/test_extract_slice.py +++ b/third_party/amd/python/test/test_extract_slice.py @@ -57,16 +57,16 @@ def __str__(self): @pytest.mark.parametrize("M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset", [[256, 256, 256, 32, 
0, 32], [128, 128, 128, 64, 0, 64]]) @pytest.mark.parametrize("dtype", [torch.float16]) -@pytest.mark.parametrize("view_layout", extract_layout) +@pytest.mark.parametrize("extract_layout", extract_layout) @pytest.mark.parametrize("blocked_layout", blocked_layout) -def test_extract_slice(dtype, M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset, blocked_layout, view_layout, - device='cuda'): +def test_extract_slice(dtype, M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset, blocked_layout, + extract_layout, device='cuda'): if not is_hip(): pytest.skip("extract_slice is AMD specific instruction.") ir = f""" #blocked = {blocked_layout} - #view_layout = {view_layout} + #extract_layout = {extract_layout} module attributes {{"triton_gpu.num-ctas" = 1, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = {str(64)} : i32}} {{ tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked> @@ -91,9 +91,9 @@ def test_extract_slice(dtype, M, N, M_tile_size, N_tile_size, M_tile_offset, N_t %40 = arith.addi %38, %39 : tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> %11 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}x!tt.ptr, #blocked> - %12 = triton_gpu.convert_layout %11 : tensor<{M}x{N}xf16, #blocked> -> tensor<{M}x{N}xf16, #view_layout> - %13 = amdgpu.extract_slice %12 [{M_tile_offset}, {N_tile_offset}] : tensor<{M}x{N}xf16, #view_layout> to tensor<{M_tile_size}x{N_tile_size}xf16, #view_layout> - %14 = triton_gpu.convert_layout %13 : tensor<{M_tile_size}x{N_tile_size}xf16, #view_layout> -> tensor<{M_tile_size}x{N_tile_size}xf16, #blocked> + %12 = triton_gpu.convert_layout %11 : tensor<{M}x{N}xf16, #blocked> -> tensor<{M}x{N}xf16, #extract_layout> + %13 = amdgpu.extract_slice %12 [{M_tile_offset}, {N_tile_offset}] : tensor<{M}x{N}xf16, #extract_layout> to tensor<{M_tile_size}x{N_tile_size}xf16, #extract_layout> + %14 = triton_gpu.convert_layout %13 : tensor<{M_tile_size}x{N_tile_size}xf16, #extract_layout> -> tensor<{M_tile_size}x{N_tile_size}xf16, #blocked> %15 = tt.addptr %34, %40 : tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked>, tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> tt.store %15, %14 : tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked> tt.return
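For reference, a minimal PyTorch sketch of what the TTGIR above is expected to compute (illustrative only, not part of the patch): the extracted tile should equal a plain slice of the input at the given offsets. Shapes and offsets are taken from the first parametrization of test_extract_slice; the tensor names are hypothetical.

import torch

M, N = 256, 256
M_tile_size, N_tile_size = 256, 32
M_tile_offset, N_tile_offset = 0, 32

x = torch.randn(M, N, dtype=torch.float16)
# amdgpu.extract_slice [M_tile_offset, N_tile_offset] corresponds to this slice
expected = x[M_tile_offset:M_tile_offset + M_tile_size,
             N_tile_offset:N_tile_offset + N_tile_size]
assert expected.shape == (M_tile_size, N_tile_size)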