diff --git a/third_party/intel/include/Analysis/DPAS.h b/third_party/intel/include/Analysis/DPAS.h new file mode 100644 index 0000000000..98a3390ec5 --- /dev/null +++ b/third_party/intel/include/Analysis/DPAS.h @@ -0,0 +1,59 @@ +#ifndef TRITON_INTEL_ANALYSIS_DPAS_H +#define TRITON_INTEL_ANALYSIS_DPAS_H + +#include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace mlir::triton::gpu::intel { + +//===----------------------------------------------------------------------===// +// Intel DPAS Analysis +//===----------------------------------------------------------------------===// + +class DPASAnalysis { +public: + DPASAnalysis(FunctionOpInterface func); + + enum class Result { True, False, Maybe }; + + enum class DPASEngineType : uint8_t { + // data types for operands D,C,A,B. + FP32_FP32_FP16_FP16 = 0, // default + FP32_FP32_BF16_BF16, + FP32_FP32_TF32_TF32, + FP16_FP16_FP16_FP16, + BF16_BF16_BF16_BF16, + U32_U32_U8_U8, + S32_S32_S8_S8, + NOT_APPLICABLE + }; + + /// Analyze the dpasMap and return: + /// - Result::True if the function associated with this analysis contains + /// DotOp operations that can be lowered to DPAS instructions, + /// - Result::False if it contains DotOp operations that cannot be lowered + /// to DPAS instructions, and + /// - Result::Maybe if it contains DotOp operations that could be lowered to + /// DPAS instructions if the module was executed with a different subgroup + /// (aka threads per warp) size. + Result canUseDPAS() const; + + /// Return the threads per warp (aka subgroup size) supported by the DPAS + /// instruction on the given device architecture. + static unsigned supportedThreadsPerWarp(DeviceArch arch); + + /// Given a DotOp operation, return the DPAS engine type. + static DPASEngineType getDPASType(DotOp op); + +private: + /// The module enclosing the function associated with the analysis. + mlir::ModuleOp mod; + + /// The map of DotOp to DPAS type. + std::map dpasMap; +}; + +} // namespace mlir::triton::gpu::intel + +#endif // TRITON_INTEL_ANALYSIS_DPAS_H diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h index 28954785da..c4963b6412 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h +++ b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h @@ -15,12 +15,6 @@ namespace mlir::triton::gpu::intel { -enum class DeviceArch { - UNKNOWN = 0, - ATS, - PVC, -}; - #define GEN_PASS_DECL #include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h.inc" diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h index 30b4f3b7ce..d0dcaaccc1 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h +++ b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h @@ -9,9 +9,8 @@ #ifndef TRITON_DIALECT_TRITONINTELGPU_TRANSFORMS_UTILITY_H #define TRITON_DIALECT_TRITONINTELGPU_TRANSFORMS_UTILITY_H -#include "triton/Dialect/Triton/IR/Dialect.h" - -#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" namespace mlir { class ConversionPatternRewriter; @@ -19,23 +18,7 @@ class ConversionPatternRewriter; namespace mlir::triton::gpu::intel { -// data type for D_C_A_B. -enum class DPASEngineType : uint8_t { - // floating-point XMX engine instr - FP32_FP32_FP16_FP16 = 0, // default - FP32_FP32_BF16_BF16, - FP32_FP32_TF32_TF32, - FP16_FP16_FP16_FP16, - BF16_BF16_BF16_BF16, - // integer XMX engine instr - U32_U32_U8_U8, - S32_S32_S8_S8, - // - NOT_APPLICABLE, -}; - -bool supportDPAS(DotOp op, DeviceArch arch); -DPASEngineType getDPASType(DotOp op); +enum class DeviceArch { UNKNOWN = 0, ATS, PVC }; // Infers the encoding of the source of op given the result encoding. std::optional inferSrcEncoding(Operation *op, Attribute encoding); @@ -69,6 +52,7 @@ LLVM::CallOp createSPIRVBuiltinCall(Location loc, LLVM::LLVMFuncOp func, ValueRange args); DeviceArch getDeviceArch(Operation *module); + } // namespace mlir::triton::gpu::intel #endif // TRITON_DIALECT_TRITONINTELGPU_TRANSFORMS_UTILITY_H diff --git a/third_party/intel/lib/Analysis/CMakeLists.txt b/third_party/intel/lib/Analysis/CMakeLists.txt new file mode 100644 index 0000000000..163cb22029 --- /dev/null +++ b/third_party/intel/lib/Analysis/CMakeLists.txt @@ -0,0 +1,10 @@ +add_triton_library(TritonIntelAnalysis + DPAS.cpp + + DEPENDS + TritonTableGen + TritonGPUAttrDefsIncGen + + LINK_LIBS PUBLIC + TritonIR +) diff --git a/third_party/intel/lib/Analysis/DPAS.cpp b/third_party/intel/lib/Analysis/DPAS.cpp new file mode 100644 index 0000000000..f79cf2348a --- /dev/null +++ b/third_party/intel/lib/Analysis/DPAS.cpp @@ -0,0 +1,115 @@ +#include "intel/include/Analysis/DPAS.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +using namespace mlir; + +namespace mlir::triton::gpu::intel { + +DPASAnalysis::DPASAnalysis(FunctionOpInterface func) + : mod(func->getParentOfType()) { + DeviceArch arch = getDeviceArch(mod); + + // Populate the DPAS map. + func.walk([&](DotOp dotOp) { + DPASEngineType dpasEngineType = + (mod->hasAttr("triton_gpu.is_lts") || arch == DeviceArch::UNKNOWN) + ? DPASEngineType::NOT_APPLICABLE + : DPASAnalysis::getDPASType(dotOp); + dpasMap[dotOp] = dpasEngineType; + + // Only PVC supports TF32. + if (dpasEngineType == DPASEngineType::FP32_FP32_TF32_TF32) { + if (arch != DeviceArch::PVC || + dotOp.getInputPrecision() != InputPrecision::TF32) + dpasMap[dotOp] = DPASEngineType::NOT_APPLICABLE; + } + }); +} + +DPASAnalysis::Result DPASAnalysis::canUseDPAS() const { + if (dpasMap.empty()) + return Result::False; + + // Ensure all dot operations can be lowered to DPAS instructions. + if (llvm::any_of(dpasMap, [](const auto &entry) { + return entry.second == DPASEngineType::NOT_APPLICABLE; + })) + return Result::False; + + // Verify whether the module has the correct number of threads per warp. + // Note: if the module doesn't have the warp size attribute, return + // Result::Maybe to allow the caller to set warp size. + Attribute threadsPerWarpAttr = + mod->getDiscardableAttr(TritonGPUDialect::getThreadsPerWarpAttrName()); + if (!threadsPerWarpAttr) + return Result::Maybe; + + unsigned threadsPerWarp = cast(threadsPerWarpAttr).getInt(); + DeviceArch arch = getDeviceArch(mod); + if (threadsPerWarp == supportedThreadsPerWarp(arch)) + return Result::True; + + return Result::False; +} + +unsigned DPASAnalysis::supportedThreadsPerWarp(DeviceArch arch) { + switch (arch) { + case DeviceArch::PVC: + return 16; + case DeviceArch::ATS: + return 8; + default: + llvm_unreachable("Unexpected target architecture"); + } +} + +DPASAnalysis::DPASEngineType DPASAnalysis::getDPASType(DotOp op) { + // d = a * b + c + auto aTy = cast(op.getA().getType()); + auto bTy = cast(op.getB().getType()); + auto cTy = cast(op.getC().getType()); + auto dTy = cast(op.getD().getType()); + Type aElemTy = aTy.getElementType(); + Type bElemTy = bTy.getElementType(); + Type cElemTy = cTy.getElementType(); + Type dElemTy = dTy.getElementType(); + + assert(cElemTy == dElemTy && "Unexpected element type mismatch"); + + if (aElemTy != bElemTy) + return DPASEngineType::NOT_APPLICABLE; + + if (dElemTy.isIntOrIndex()) { + if (dElemTy.getIntOrFloatBitWidth() == 32 && + aElemTy.getIntOrFloatBitWidth() == 8) + return dElemTy.isSignedInteger() ? DPASEngineType::S32_S32_S8_S8 + : DPASEngineType::U32_U32_U8_U8; + return DPASEngineType::NOT_APPLICABLE; + } + + if (isa(dElemTy)) { + if (dElemTy.isF32()) { + if (aElemTy.isF16()) + return DPASEngineType::FP32_FP32_FP16_FP16; + if (aElemTy.isBF16()) + return DPASEngineType::FP32_FP32_BF16_BF16; + if (aElemTy.isF32() && op.getInputPrecision() == InputPrecision::TF32) + return DPASEngineType::FP32_FP32_TF32_TF32; + // For FP8XFP8->FP32, upcast to FP16 + if (aElemTy.isFloat8E5M2()) + return DPASEngineType::FP32_FP32_FP16_FP16; + if (aElemTy.isFloat8E4M3FNUZ()) + return DPASEngineType::FP32_FP32_FP16_FP16; + } else if (dElemTy.isF16()) { + if (aElemTy.isF16()) + return DPASEngineType::FP16_FP16_FP16_FP16; + } else if (dElemTy.isBF16()) { + if (aElemTy.isBF16()) + return DPASEngineType::BF16_BF16_BF16_BF16; + } + } + + return DPASEngineType::NOT_APPLICABLE; +} + +} // namespace mlir::triton::gpu::intel diff --git a/third_party/intel/lib/CMakeLists.txt b/third_party/intel/lib/CMakeLists.txt index bd58f09bdb..5cbf44a3b6 100644 --- a/third_party/intel/lib/CMakeLists.txt +++ b/third_party/intel/lib/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(Analysis) add_subdirectory(Dialect) add_subdirectory(GPUToTritonGEN) add_subdirectory(Target) diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp index bbb5a526dd..158076fd2c 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp @@ -2,8 +2,10 @@ #include "../Utility.h" #include "mlir/IR/BuiltinTypes.h" +#include "intel/include/Analysis/DPAS.h" #include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.h" #include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h" + #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include @@ -28,7 +30,8 @@ class DotOpDPASConversionHelper { typeConverter(typeConverter), loc(loc), ctx(dpasLayout.getContext()) {} std::tuple static getDPASOperandsType( - DPASEngineType dpasType, MLIRContext *ctx, DpasEncodingAttr layout) { + DPASAnalysis::DPASEngineType dpasType, MLIRContext *ctx, + DpasEncodingAttr layout) { Type fp32Ty = type::f32Ty(ctx); Type fp16Ty = type::f16Ty(ctx); Type bf16Ty = type::bf16Ty(ctx); @@ -44,6 +47,8 @@ class DotOpDPASConversionHelper { unsigned elemNumA = product(shapeA) / threadsPerWarp; SmallVector shapeB = layout.getShapeB(); unsigned elemNumB = product(shapeB) / threadsPerWarp; + + using DPASEngineType = DPASAnalysis::DPASEngineType; switch (dpasType) { case DPASEngineType::FP32_FP32_FP16_FP16: { Type cTy = vec_ty(fp32Ty, elemNumC); @@ -138,7 +143,7 @@ class DotOpDPASConversionHelper { unsigned repM = repA[0], repN = repB[1], repK = repA[1]; - auto dpasType = getDPASType(op); + auto dpasType = DPASAnalysis::getDPASType(op); auto dpasEncoding = cast(DTensorTy.getEncoding()); Type aTy, bTy, cTy, dTy; std::tie(dTy, cTy, aTy, bTy) = diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp index 5d107ec6b5..71394c65c2 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp @@ -1,21 +1,18 @@ #include "mlir/Analysis/SliceAnalysis.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "intel/include/Analysis/DPAS.h" #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h" #include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h" #include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h" -#include "triton/Analysis/Utility.h" #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "llvm/ADT/TypeSwitch.h" using namespace mlir; -namespace tt = mlir::triton; -namespace ttg = mlir::triton::gpu; -namespace ttgi = mlir::triton::gpu::intel; +using namespace mlir::triton; +using namespace mlir::triton::gpu; namespace mlir::triton::gpu::intel { #define GEN_PASS_DEF_TRITONINTELGPUACCELERATEMATMUL @@ -23,11 +20,6 @@ namespace mlir::triton::gpu::intel { } // namespace mlir::triton::gpu::intel namespace { -using tt::DotOp; -using ttg::ConvertLayoutOp; -using ttg::DotOperandEncodingAttr; -using ttgi::DeviceArch; -using ttgi::DpasEncodingAttr; struct IntelDPASCapability { uint32_t systolicDepth; @@ -36,12 +28,12 @@ struct IntelDPASCapability { uint32_t opsChanBitWidths; }; -IntelDPASCapability getDPASCapability(DeviceArch arch) { +IntelDPASCapability getDPASCapability(intel::DeviceArch arch) { switch (arch) { - case DeviceArch::UNKNOWN: + case intel::DeviceArch::UNKNOWN: return IntelDPASCapability(); - case DeviceArch::ATS: { + case intel::DeviceArch::ATS: { IntelDPASCapability cap; cap.systolicDepth = 8; cap.repeatCount = 8; @@ -50,7 +42,7 @@ IntelDPASCapability getDPASCapability(DeviceArch arch) { return cap; } - case DeviceArch::PVC: { + case intel::DeviceArch::PVC: { IntelDPASCapability cap; cap.systolicDepth = 8; cap.repeatCount = 8; @@ -64,14 +56,15 @@ IntelDPASCapability getDPASCapability(DeviceArch arch) { } } -SmallVector getWarpsPerTile(tt::DotOp dotOp, +SmallVector getWarpsPerTile(DotOp dotOp, struct IntelDPASCapability dpasCap, const ArrayRef shape, unsigned numWarps) { auto filter = [&dotOp](Operation *op) { return op->getParentRegion() == dotOp->getParentRegion(); }; - auto slices = mlir::getSlice(dotOp, {filter}); + + SetVector slices = getSlice(dotOp, {filter}); // TODO: revisit this in flash attention. for (Operation *op : slices) if (isa(op) && (op != dotOp)) @@ -88,49 +81,51 @@ SmallVector getWarpsPerTile(tt::DotOp dotOp, ceil(dpasCap.repeatCount, dpasCap.executionSize); uint32_t colRowRatio = ceil(dpasCap.executionSize, dpasCap.repeatCount); + do { if (ret[0] * ret[1] >= numWarps) break; if (shape[0] / (shapePerWarp[0] * colRowRatio) / ret[0] >= shape[1] / (shapePerWarp[1] * rowColRatio) / ret[1]) { - if (ret[0] < shape[0] / shapePerWarp[0]) { + if (ret[0] < shape[0] / shapePerWarp[0]) ret[0] *= 2; - } else + else ret[1] *= 2; } else { ret[1] *= 2; } } while (true); + return ret; } -class BlockedToDPAS : public mlir::RewritePattern { - DeviceArch arch; +class BlockedToDPAS : public RewritePattern { + intel::DeviceArch arch; public: - BlockedToDPAS(mlir::MLIRContext *context, DeviceArch arch) - : mlir::RewritePattern(tt::DotOp::getOperationName(), 2, context), - arch(arch) {} + BlockedToDPAS(MLIRContext *context, intel::DeviceArch arch) + : RewritePattern(DotOp::getOperationName(), 2, context), arch(arch) {} - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { DotOp dotOp = cast(op); RankedTensorType oldRetType = cast(dotOp.getResult().getType()); if (!oldRetType.getEncoding() || - isa(oldRetType.getEncoding())) + isa(oldRetType.getEncoding())) return failure(); - if (!supportDPAS(dotOp, arch)) + using Result = intel::DPASAnalysis::Result; + auto funcOp = op->getParentOfType(); + Result canUseDPAS = intel::DPASAnalysis(funcOp).canUseDPAS(); + if (canUseDPAS != Result::True) return failure(); // Create DPAS encoding for the given number of warps ArrayRef retShape = oldRetType.getShape(); - ModuleOp mod = op->getParentOfType(); - unsigned numWarps = ttg::TritonGPUDialect::getNumWarps(mod); + ModuleOp mod = funcOp->getParentOfType(); + unsigned numWarps = TritonGPUDialect::getNumWarps(mod); - // operands Value a = dotOp.getA(); Value b = dotOp.getB(); RankedTensorType oldAType = cast(a.getType()); @@ -139,32 +134,33 @@ class BlockedToDPAS : public mlir::RewritePattern { IntelDPASCapability dpasCap = getDPASCapability(arch); unsigned dpasElemBitWidths = oldAType.getElementType().getIntOrFloatBitWidth(); + // We are upcasting FP8 to FP16 if (oldAType.getElementType().isFloat8E5M2() || oldAType.getElementType().isFloat8E4M3FNUZ()) dpasElemBitWidths = 2 * dpasElemBitWidths; - unsigned opsPerChan = dpasCap.opsChanBitWidths / dpasElemBitWidths; + unsigned opsPerChan = dpasCap.opsChanBitWidths / dpasElemBitWidths; SmallVector warpsPerTile = getWarpsPerTile(dotOp, dpasCap, retShape, numWarps); - unsigned threadsPerWarp = ttg::TritonGPUDialect::getThreadsPerWarp(mod); - DpasEncodingAttr dpasEnc = - DpasEncodingAttr::get(oldRetType.getContext(), dpasCap.repeatCount, - dpasCap.systolicDepth, dpasCap.executionSize, - opsPerChan, warpsPerTile, {1, 1}, threadsPerWarp); + unsigned threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod); + auto dpasEnc = intel::DpasEncodingAttr::get( + oldRetType.getContext(), dpasCap.repeatCount, dpasCap.systolicDepth, + dpasCap.executionSize, opsPerChan, warpsPerTile, {1, 1}, + threadsPerWarp); RankedTensorType newRetType = RankedTensorType::get(retShape, oldRetType.getElementType(), dpasEnc); // convert accumulator Value oldAcc = dotOp.getOperand(2); - ConvertLayoutOp newAcc = rewriter.create( - oldAcc.getLoc(), newRetType, oldAcc); + ConvertLayoutOp newAcc = + rewriter.create(oldAcc.getLoc(), newRetType, oldAcc); - DotOperandEncodingAttr newAEncoding = ttg::DotOperandEncodingAttr::get( + DotOperandEncodingAttr newAEncoding = DotOperandEncodingAttr::get( oldAType.getContext(), 0, newRetType.getEncoding(), opsPerChan); - DotOperandEncodingAttr newBEncoding = ttg::DotOperandEncodingAttr::get( + DotOperandEncodingAttr newBEncoding = DotOperandEncodingAttr::get( oldBType.getContext(), 1, newRetType.getEncoding(), opsPerChan); RankedTensorType newAType = RankedTensorType::get( @@ -172,17 +168,18 @@ class BlockedToDPAS : public mlir::RewritePattern { RankedTensorType newBType = RankedTensorType::get( oldBType.getShape(), oldBType.getElementType(), newBEncoding); - a = rewriter.create(a.getLoc(), newAType, a); - b = rewriter.create(b.getLoc(), newBType, b); + a = rewriter.create(a.getLoc(), newAType, a); + b = rewriter.create(b.getLoc(), newBType, b); DotOp newDot = rewriter.create(dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.getInputPrecision(), dotOp.getMaxNumImpreciseAcc()); - rewriter.replaceOpWithNewOp(op, oldRetType, - newDot.getResult()); + rewriter.replaceOpWithNewOp(op, oldRetType, + newDot.getResult()); return success(); } }; + } // namespace static Value promoteOperand(OpBuilder &builder, Location loc, Value operand, @@ -190,9 +187,10 @@ static Value promoteOperand(OpBuilder &builder, Location loc, Value operand, auto tensorPromotedType = cast(operand.getType()) .cloneWith(std::nullopt, promotedType); Type elemType = tensorPromotedType.getElementType(); + return llvm::TypeSwitch(elemType) .Case([&](auto) { - return builder.create(loc, tensorPromotedType, operand); + return builder.create(loc, tensorPromotedType, operand); }) .Case([&](auto) { unsigned tgtBitWidth = elemType.getIntOrFloatBitWidth(), @@ -210,13 +208,14 @@ static Value promoteOperand(OpBuilder &builder, Location loc, Value operand, // promote operands of dot op if the existing combination is not natively // supported. static void decomposeMixedModeDotOp(ModuleOp mod) { - mod.walk([](tt::DotOp dotOp) -> void { + mod.walk([](DotOp dotOp) -> void { auto D = dotOp.getD(); OpBuilder builder(dotOp); Type AElType = dotOp.getA().getType().getElementType(); + auto dpasLayout = + dyn_cast(D.getType().getEncoding()); + Type promoteType; - DpasEncodingAttr dpasLayout = - dyn_cast(D.getType().getEncoding()); if (dpasLayout) { bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FNUZ(); // promote operands for fp8 since fp8 DPAS is not natively supported @@ -231,6 +230,7 @@ static void decomposeMixedModeDotOp(ModuleOp mod) { return; promoteType = DElType; } + Location loc = dotOp.getLoc(); Value promotedA = promoteOperand(builder, loc, dotOp.getA(), promoteType); Value promotedB = promoteOperand(builder, loc, dotOp.getB(), promoteType); @@ -249,13 +249,13 @@ class TritonIntelGPUAccelerateMatmulPass void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp m = getOperation(); - auto deviceArch = ttgi::getDeviceArch(m); + auto deviceArch = intel::getDeviceArch(m); - mlir::RewritePatternSet patterns(context); - patterns.add<::BlockedToDPAS>(context, deviceArch); - if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) { + RewritePatternSet patterns(context); + patterns.add(context, deviceArch); + if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) signalPassFailure(); - } + // now that we pick the scalar type decompose dot that are not natively // supported. decomposeMixedModeDotOp(m); diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt b/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt index 4af4b3e840..2b3e19d717 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt +++ b/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt @@ -16,7 +16,7 @@ add_triton_library(TritonIntelGPUTransforms MLIRSCFTransforms MLIRTransforms MLIRTransformUtils - TritonAnalysis + TritonIntelAnalysis TritonIR TritonGENIR TritonGPUIR diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp index 30c425eb73..865f838185 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp @@ -24,87 +24,6 @@ namespace ttgi = mlir::triton::gpu::intel; namespace mlir::triton::gpu::intel { -bool supportDPAS(DotOp op, DeviceArch arch) { - if (op->getParentOfType()->hasAttr("triton_gpu.is_lts")) - return false; - - if (arch == DeviceArch::UNKNOWN) - return false; - - auto mod = op->getParentOfType(); - int threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod); - - if (arch == DeviceArch::PVC && threadsPerWarp != 16) { - // Only support threadsPerWarp 16 for PVC now. - return false; - } - - if (arch == DeviceArch::ATS && threadsPerWarp != 8) { - // Only support threadsPerWarp 8 for ATS now. - return false; - } - - DPASEngineType dpasType = getDPASType(op); - - if (dpasType == DPASEngineType::FP32_FP32_TF32_TF32) { - // Only PVC support TF32. - return op.getInputPrecision() == InputPrecision::TF32 && - arch == DeviceArch::PVC; - } - - return dpasType != DPASEngineType::NOT_APPLICABLE; -} - -DPASEngineType getDPASType(DotOp op) { - // d = a * b + c - auto aTy = cast(op.getA().getType()); - auto bTy = cast(op.getB().getType()); - auto cTy = cast(op.getC().getType()); - auto dTy = cast(op.getD().getType()); - - if (aTy.getElementType() != bTy.getElementType() || - cTy.getElementType() != dTy.getElementType()) - return DPASEngineType::NOT_APPLICABLE; - - // TODO: add more dpas supported data type. - if (dTy.getElementType().isIntOrIndex()) { - // Integer - if (dTy.getElementType().getIntOrFloatBitWidth() == 32) { - if (aTy.getElementType().getIntOrFloatBitWidth() == 8 && - bTy.getElementType().getIntOrFloatBitWidth() == 8) - return dTy.getElementType().isSignedInteger() - ? DPASEngineType::S32_S32_S8_S8 - : DPASEngineType::U32_U32_U8_U8; - } - } else { - // floating. - if (dTy.getElementType().isF32()) { - if (aTy.getElementType().isF16() && bTy.getElementType().isF16()) - return DPASEngineType::FP32_FP32_FP16_FP16; - if (aTy.getElementType().isBF16() && bTy.getElementType().isBF16()) - return DPASEngineType::FP32_FP32_BF16_BF16; - if (aTy.getElementType().isF32() && bTy.getElementType().isF32() && - op.getInputPrecision() == InputPrecision::TF32) - return DPASEngineType::FP32_FP32_TF32_TF32; - // For FP8XFP8->FP32, upcast to FP16 - if (aTy.getElementType().isFloat8E5M2() && - bTy.getElementType().isFloat8E5M2()) - return DPASEngineType::FP32_FP32_FP16_FP16; - if (aTy.getElementType().isFloat8E4M3FNUZ() && - bTy.getElementType().isFloat8E4M3FNUZ()) - return DPASEngineType::FP32_FP32_FP16_FP16; - } else if (dTy.getElementType().isF16()) { - if (aTy.getElementType().isF16() && bTy.getElementType().isF16()) - return DPASEngineType::FP16_FP16_FP16_FP16; - } else if (dTy.getElementType().isBF16()) { - if (aTy.getElementType().isBF16() && bTy.getElementType().isBF16()) - return DPASEngineType::BF16_BF16_BF16_BF16; - } - } - - return DPASEngineType::NOT_APPLICABLE; -} - std::optional inferSrcEncoding(Operation *op, Attribute encoding) { if (auto makeTensorPtrOp = dyn_cast(op)) return encoding; diff --git a/third_party/intel/triton_xpu.cc b/third_party/intel/triton_xpu.cc index 655c48adf3..fc3cef3b49 100644 --- a/third_party/intel/triton_xpu.cc +++ b/third_party/intel/triton_xpu.cc @@ -7,6 +7,7 @@ #include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.h" #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h" #include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h" +#include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h" #include "intel/include/Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.h" #include "intel/include/Target/LLVMIR/PostProcess.h" #include "intel/include/TritonIntelGPUToLLVM/Passes.h"