Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions third_party/intel/include/Analysis/DPAS.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#ifndef TRITON_INTEL_ANALYSIS_DPAS_H
#define TRITON_INTEL_ANALYSIS_DPAS_H

#include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

namespace mlir::triton::gpu::intel {

//===----------------------------------------------------------------------===//
// Intel DPAS Analysis
//===----------------------------------------------------------------------===//

/// Per-function analysis that records, for each DotOp nested in a function,
/// the DPAS engine type the operation maps to (or NOT_APPLICABLE).
class DPASAnalysis {
public:
  /// Tri-state answer returned by canUseDPAS().
  enum class Result { True, False, Maybe };

  /// Operand data-type combinations; each name lists the types of operands
  /// D, C, A and B, in that order.
  enum class DPASEngineType : uint8_t {
    FP32_FP32_FP16_FP16 = 0, // default
    FP32_FP32_BF16_BF16,
    FP32_FP32_TF32_TF32,
    FP16_FP16_FP16_FP16,
    BF16_BF16_BF16_BF16,
    U32_U32_U8_U8,
    S32_S32_S8_S8,
    NOT_APPLICABLE
  };

  /// Walk \p func and populate the DotOp -> DPAS engine type map.
  DPASAnalysis(FunctionOpInterface func);

  /// Query the map built by the constructor and return:
  ///   - Result::False if the function contains no DotOp, if any DotOp is
  ///     mapped to NOT_APPLICABLE, or if the module's threads-per-warp
  ///     attribute differs from the value supported on the device,
  ///   - Result::Maybe if every DotOp has a DPAS engine type but the module
  ///     has no threads-per-warp attribute yet (the caller may still set it),
  ///   - Result::True if every DotOp has a DPAS engine type and the module's
  ///     threads per warp (aka subgroup size) matches the supported value.
  Result canUseDPAS() const;

  /// Return the threads per warp (aka subgroup size) supported by the DPAS
  /// instruction on device architecture \p arch.
  static unsigned supportedThreadsPerWarp(DeviceArch arch);

  /// Return the DPAS engine type matching the element types of \p op, or
  /// NOT_APPLICABLE when no engine type matches.
  static DPASEngineType getDPASType(DotOp op);

private:
  /// Module enclosing the analyzed function.
  mlir::ModuleOp mod;

  /// DPAS engine type recorded for each DotOp found in the function.
  std::map<DotOp, DPASEngineType> dpasMap;
};

} // namespace mlir::triton::gpu::intel

#endif // TRITON_INTEL_ANALYSIS_DPAS_H
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,6 @@

namespace mlir::triton::gpu::intel {

/// Device architectures distinguished by the Intel GPU transforms. UNKNOWN is
/// the zero value.
/// NOTE(review): the expansion of the ATS / PVC abbreviations is not stated
/// here — confirm before spelling them out.
enum class DeviceArch {

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: moved to Utility.h

UNKNOWN = 0,
ATS,
PVC,
};

#define GEN_PASS_DECL
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h.inc"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,33 +9,16 @@
#ifndef TRITON_DIALECT_TRITONINTELGPU_TRANSFORMS_UTILITY_H
#define TRITON_DIALECT_TRITONINTELGPU_TRANSFORMS_UTILITY_H

#include "triton/Dialect/Triton/IR/Dialect.h"

#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"

namespace mlir {
class ConversionPatternRewriter;
}

namespace mlir::triton::gpu::intel {

// data type for D_C_A_B.
Comment thread
etiotto marked this conversation as resolved.
/// DPAS engine types. Each enumerator name lists the data types of operands
/// D, C, A and B, in that order; the underlying storage is a single byte.
enum class DPASEngineType : uint8_t {
// floating-point XMX engine instructions
FP32_FP32_FP16_FP16 = 0, // default
FP32_FP32_BF16_BF16,
FP32_FP32_TF32_TF32,
FP16_FP16_FP16_FP16,
BF16_BF16_BF16_BF16,
// integer XMX engine instructions
U32_U32_U8_U8,
S32_S32_S8_S8,
// sentinel: no DPAS engine type applies
NOT_APPLICABLE,
};

/// Presumably reports whether \p op can be lowered to DPAS instructions on
/// \p arch. NOTE(review): the exact criteria live in the implementation,
/// which is not visible here — confirm.
bool supportDPAS(DotOp op, DeviceArch arch);
/// Return the DPAS engine type for the given DotOp.
DPASEngineType getDPASType(DotOp op);
/// Device architectures distinguished by the Intel GPU transforms. UNKNOWN is
/// the zero value.
enum class DeviceArch { UNKNOWN = 0, ATS, PVC };

// Infers the encoding of the source of op given the result encoding.
std::optional<Attribute> inferSrcEncoding(Operation *op, Attribute encoding);
Expand Down Expand Up @@ -69,6 +52,7 @@ LLVM::CallOp createSPIRVBuiltinCall(Location loc,
LLVM::LLVMFuncOp func, ValueRange args);

/// Return the device architecture associated with \p module.
/// NOTE(review): how the architecture is derived (e.g. from a module
/// attribute) is defined in the implementation — confirm there.
DeviceArch getDeviceArch(Operation *module);

} // namespace mlir::triton::gpu::intel

#endif // TRITON_DIALECT_TRITONINTELGPU_TRANSFORMS_UTILITY_H
10 changes: 10 additions & 0 deletions third_party/intel/lib/Analysis/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Analysis library of the Intel backend; built from DPAS.cpp only.
add_triton_library(TritonIntelAnalysis
DPAS.cpp

# Targets this library depends on (presumably the generated-header targets
# that must exist before DPAS.cpp compiles — confirm).
DEPENDS
TritonTableGen
TritonGPUAttrDefsIncGen

# Libraries linked into this target.
LINK_LIBS PUBLIC
TritonIR
)
115 changes: 115 additions & 0 deletions third_party/intel/lib/Analysis/DPAS.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
#include "intel/include/Analysis/DPAS.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

using namespace mlir;

namespace mlir::triton::gpu::intel {

/// Build the analysis for \p func by recording the DPAS engine type of every
/// DotOp nested in the function.
///
/// Every DotOp is mapped to DPASEngineType::NOT_APPLICABLE when the device
/// architecture is unknown or the enclosing module carries the
/// "triton_gpu.is_lts" attribute. Otherwise the type comes from getDPASType(),
/// except that the TF32 engine type is kept only on PVC.
DPASAnalysis::DPASAnalysis(FunctionOpInterface func)
    : mod(func->getParentOfType<mlir::ModuleOp>()) {
  const DeviceArch arch = getDeviceArch(mod);

  // This condition does not depend on the visited operation, so evaluate it
  // once instead of repeating the attribute lookup for every DotOp.
  const bool dpasUnavailable =
      arch == DeviceArch::UNKNOWN || mod->hasAttr("triton_gpu.is_lts");

  // Populate the DPAS map.
  func.walk([&](DotOp dotOp) {
    DPASEngineType dpasEngineType = dpasUnavailable
                                        ? DPASEngineType::NOT_APPLICABLE
                                        : DPASAnalysis::getDPASType(dotOp);

    // Only PVC supports TF32.
    if (dpasEngineType == DPASEngineType::FP32_FP32_TF32_TF32 &&
        (arch != DeviceArch::PVC ||
         dotOp.getInputPrecision() != InputPrecision::TF32))
      dpasEngineType = DPASEngineType::NOT_APPLICABLE;

    // Write the final classification once.
    dpasMap[dotOp] = dpasEngineType;
  });
}

/// Decide whether the analyzed function can use DPAS instructions.
///
/// Returns Result::False when the function has no DotOp, when any DotOp is
/// mapped to NOT_APPLICABLE, or when the module's threads-per-warp attribute
/// differs from the value supported on the device; Result::Maybe when the
/// module has no threads-per-warp attribute; Result::True otherwise.
DPASAnalysis::Result DPASAnalysis::canUseDPAS() const {
  // An empty map means the function contains no DotOp at all.
  if (dpasMap.empty())
    return Result::False;

  // Every dot operation must be lowerable to DPAS instructions.
  const bool everyDotIsLowerable =
      llvm::all_of(dpasMap, [](const auto &dotAndType) {
        return dotAndType.second != DPASEngineType::NOT_APPLICABLE;
      });
  if (!everyDotIsLowerable)
    return Result::False;

  // Without a warp size attribute on the module the answer stays open, which
  // lets the caller choose the warp size.
  Attribute warpSizeAttr =
      mod->getDiscardableAttr(TritonGPUDialect::getThreadsPerWarpAttrName());
  if (!warpSizeAttr)
    return Result::Maybe;

  // The module's warp size must match what DPAS supports on this device.
  const unsigned moduleWarpSize = cast<IntegerAttr>(warpSizeAttr).getInt();
  const unsigned requiredWarpSize = supportedThreadsPerWarp(getDeviceArch(mod));
  return moduleWarpSize == requiredWarpSize ? Result::True : Result::False;
}

/// Return the threads per warp (aka subgroup size) supported by the DPAS
/// instruction on \p arch: 16 on PVC, 8 on ATS. Any other architecture is
/// treated as unreachable.
unsigned DPASAnalysis::supportedThreadsPerWarp(DeviceArch arch) {
  if (arch == DeviceArch::PVC)
    return 16;
  if (arch == DeviceArch::ATS)
    return 8;
  llvm_unreachable("Unexpected target architecture");
}

/// Given a DotOp operation (d = a * b + c), return the DPAS engine type that
/// matches the element types of its operands, or
/// DPASEngineType::NOT_APPLICABLE when no engine type matches.
///
/// A and B must have the same element type. Integer engine types require a
/// 32-bit integer accumulator with 8-bit integer operands. FP32 accumulators
/// accept FP16, BF16, FP32 (only with TF32 input precision) and the two FP8
/// formats checked below; FP16 and BF16 accumulators accept only operands of
/// their own type.
DPASAnalysis::DPASEngineType DPASAnalysis::getDPASType(DotOp op) {
  const auto elemTypeOf = [](Value v) {
    return cast<RankedTensorType>(v.getType()).getElementType();
  };

  // d = a * b + c
  Type aElemTy = elemTypeOf(op.getA());
  Type bElemTy = elemTypeOf(op.getB());
  Type cElemTy = elemTypeOf(op.getC());
  Type dElemTy = elemTypeOf(op.getD());

  assert(cElemTy == dElemTy && "Unexpected element type mismatch");

  if (aElemTy != bElemTy)
    return DPASEngineType::NOT_APPLICABLE;

  // Integer engines. Test for integer types explicitly: an `index`
  // accumulator or a non-integer 8-bit operand (e.g. FP8) must not be
  // classified as an integer DPAS, and querying the bit width of a type that
  // is neither integer nor float is invalid.
  if (isa<IntegerType>(dElemTy)) {
    if (dElemTy.isInteger(32) && aElemTy.isInteger(8))
      // NOTE(review): isSignedInteger() holds only for explicitly signed
      // integer types, so signless integers select the U32 variant — confirm
      // this is intended.
      return dElemTy.isSignedInteger() ? DPASEngineType::S32_S32_S8_S8
                                       : DPASEngineType::U32_U32_U8_U8;
    return DPASEngineType::NOT_APPLICABLE;
  }

  if (dElemTy.isF32()) {
    if (aElemTy.isF16())
      return DPASEngineType::FP32_FP32_FP16_FP16;
    if (aElemTy.isBF16())
      return DPASEngineType::FP32_FP32_BF16_BF16;
    if (aElemTy.isF32() && op.getInputPrecision() == InputPrecision::TF32)
      return DPASEngineType::FP32_FP32_TF32_TF32;
    // For FP8XFP8->FP32, upcast to FP16
    if (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ())
      return DPASEngineType::FP32_FP32_FP16_FP16;
    return DPASEngineType::NOT_APPLICABLE;
  }

  if (dElemTy.isF16() && aElemTy.isF16())
    return DPASEngineType::FP16_FP16_FP16_FP16;

  if (dElemTy.isBF16() && aElemTy.isBF16())
    return DPASEngineType::BF16_BF16_BF16_BF16;

  return DPASEngineType::NOT_APPLICABLE;
}

} // namespace mlir::triton::gpu::intel
1 change: 1 addition & 0 deletions third_party/intel/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
add_subdirectory(Analysis)
add_subdirectory(Dialect)
add_subdirectory(GPUToTritonGEN)
add_subdirectory(Target)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
#include "../Utility.h"
#include "mlir/IR/BuiltinTypes.h"

#include "intel/include/Analysis/DPAS.h"
#include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.h"
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"

#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
Expand All @@ -28,7 +30,8 @@ class DotOpDPASConversionHelper {
typeConverter(typeConverter), loc(loc), ctx(dpasLayout.getContext()) {}

std::tuple<Type, Type, Type, Type> static getDPASOperandsType(
DPASEngineType dpasType, MLIRContext *ctx, DpasEncodingAttr layout) {
DPASAnalysis::DPASEngineType dpasType, MLIRContext *ctx,
DpasEncodingAttr layout) {
Type fp32Ty = type::f32Ty(ctx);
Type fp16Ty = type::f16Ty(ctx);
Type bf16Ty = type::bf16Ty(ctx);
Expand All @@ -44,6 +47,8 @@ class DotOpDPASConversionHelper {
unsigned elemNumA = product<unsigned>(shapeA) / threadsPerWarp;
SmallVector<unsigned> shapeB = layout.getShapeB();
unsigned elemNumB = product<unsigned>(shapeB) / threadsPerWarp;

using DPASEngineType = DPASAnalysis::DPASEngineType;
switch (dpasType) {
case DPASEngineType::FP32_FP32_FP16_FP16: {
Type cTy = vec_ty(fp32Ty, elemNumC);
Expand Down Expand Up @@ -138,7 +143,7 @@ class DotOpDPASConversionHelper {

unsigned repM = repA[0], repN = repB[1], repK = repA[1];

auto dpasType = getDPASType(op);
auto dpasType = DPASAnalysis::getDPASType(op);
auto dpasEncoding = cast<DpasEncodingAttr>(DTensorTy.getEncoding());
Type aTy, bTy, cTy, dTy;
std::tie(dTy, cTy, aTy, bTy) =
Expand Down
Loading