Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/triton/Dialect/Triton/IR/TritonTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class TritonTypeDef<string name, string _mnemonic, list<Trait> traits = []>
}

// Floating-point Type
def TT_Float : AnyTypeOf<[F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">;
def TT_Float : AnyTypeOf<[F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">;
def TT_FloatTensor : RankedTensorOf<[TT_Float]>;
def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;

Expand Down
6 changes: 4 additions & 2 deletions lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,7 @@ bool supportMFMATypes(Type a, Type b) {
return false;

auto F8E5M2 = TypeID::get<Float8E5M2Type>();
auto F8E4M3FN = TypeID::get<Float8E4M3FNType>();
auto F8E4M3FNUZ = TypeID::get<Float8E4M3FNUZType>();
auto F8E5M2FNUZ = TypeID::get<Float8E5M2FNUZType>();
auto F16 = TypeID::get<Float16Type>();
Expand All @@ -437,6 +438,7 @@ bool supportMFMATypes(Type a, Type b) {
{F16, F16},
{BF16, BF16},
{F8E5M2, F8E5M2},
{F8E4M3FN, F8E4M3FN},
{F8E4M3FNUZ, F8E4M3FNUZ},
{F8E4M3FNUZ, F8E5M2FNUZ},
{F8E5M2FNUZ, F8E4M3FNUZ},
Expand Down Expand Up @@ -495,14 +497,14 @@ bool supportMMA(triton::DotOp op, int version) {
return false;
if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 &&
retShapePerCTA[rank - 1] % 8 == 0 &&
(aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ() ||
(aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN() ||
aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
aElemTy.isF32()))) {
return false;
}
// We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op.
if (op.getMaxNumImpreciseAcc() < 32 &&
(aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ()) &&
(aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN()) &&
cast<RankedTensorType>(op.getType()).getElementType().isF32()) {
return false;
}
Expand Down
3 changes: 3 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
addConversion([&](mlir::Float8E4M3FNUZType type) -> std::optional<Type> {
return IntegerType::get(type.getContext(), 8);
});
addConversion([&](mlir::Float8E4M3FNType type) -> std::optional<Type> {
return IntegerType::get(type.getContext(), 8);
});
addConversion([&](mlir::Float8E5M2Type type) -> std::optional<Type> {
return IntegerType::get(type.getContext(), 8);
});
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) {
NvidiaMmaEncodingAttr mmaLayout =
dyn_cast<NvidiaMmaEncodingAttr>(D.getType().getEncoding());
if (mmaLayout) {
bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FNUZ();
bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FN();
// promote operands for sm < 89 since fp8 mma is not natively supported
// promote operands for sm >= 90 when mma is not v3
if (!isNativeFP8 ||
Expand Down
5 changes: 3 additions & 2 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@ SmallVector<unsigned, 3> mmaVersionToInstrShape(int version,
SmallVector<unsigned> validN;

// MMAv3 with larger instruction shape is preferred.
if (eltType.isFloat8E5M2() || eltType.isFloat8E4M3FNUZ() ||
eltType.isF16() || eltType.isBF16() || eltType.isF32()) {
if (eltType.isFloat8E5M2() || eltType.isFloat8E4M3FN() ||
eltType.isFloat8E4M3FNUZ() || eltType.isF16() || eltType.isBF16() ||
eltType.isF32()) {
validN.assign({256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176,
168, 160, 152, 144, 136, 128, 120, 112, 104, 96, 88,
80, 72, 64, 56, 48, 40, 32, 24, 16, 8});
Expand Down
4 changes: 1 addition & 3 deletions python/src/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -745,10 +745,8 @@ void init_triton_ir(py::module &&m) {
return self.getBuilder().getI64Type();
})
.def("get_fp8e4nv_ty",
// TODO: fp8e4nv is using Float8E4M3FNUZType, which
// does not seem right. It should use FloatE4M3FNType
[](TritonOpBuilder &self) -> Type {
return self.getBuilder().getType<Float8E4M3FNUZType>();
return self.getBuilder().getType<Float8E4M3FNType>();
})
.def("get_fp8e4b8_ty",
[](TritonOpBuilder &self) -> Type {
Expand Down
8 changes: 4 additions & 4 deletions test/Conversion/tritongpu_to_llvm_hopper.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -129,24 +129,24 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: test_fp8_to_f16_conversion
tt.func @test_fp8_to_f16_conversion(
%in0: tensor<128xf8E5M2, #blocked>, %in1: tensor<128xf8E4M3FNUZ, #blocked>,
%in0: tensor<128xf8E5M2, #blocked>, %in1: tensor<128xf8E4M3FN, #blocked>,
%in2: tensor<128xf16, #blocked>, %in3: tensor<128xf32, #blocked>) {
// CHECK-COUNT-2: cvt.rn.f16x2.e5m2x2 {{.*}} "=r,h" %{{.*}} : (i16) -> vector<2xf16>
%out0 = tt.fp_to_fp %in0 : tensor<128xf8E5M2, #blocked> -> tensor<128xf16, #blocked>
// CHECK-COUNT-2: cvt.rn.f16x2.e4m3x2 {{.*}} "=r,h" %{{.*}} : (i16) -> vector<2xf16>
%out1 = tt.fp_to_fp %in1 : tensor<128xf8E4M3FNUZ, #blocked> -> tensor<128xf16, #blocked>
%out1 = tt.fp_to_fp %in1 : tensor<128xf8E4M3FN, #blocked> -> tensor<128xf16, #blocked>
// CHECK-COUNT-2: mul.rn.bf16x2
%out2 = tt.fp_to_fp %in0 : tensor<128xf8E5M2, #blocked> -> tensor<128xbf16, #blocked>

// CHECK-COUNT-2: cvt.rn.satfinite.e5m2x2.f16x2 {{.*}} "=h,r" %{{.*}} : (i32) -> vector<2xi8>
%out3 = tt.fp_to_fp %in2, rounding = rtne : tensor<128xf16, #blocked> -> tensor<128xf8E5M2, #blocked>
// CHECK-COUNT-2: cvt.rn.satfinite.e4m3x2.f16x2 {{.*}} "=h,r" %{{.*}} : (i32) -> vector<2xi8>
%out4 = tt.fp_to_fp %in2, rounding = rtne : tensor<128xf16, #blocked> -> tensor<128xf8E4M3FNUZ, #blocked>
%out4 = tt.fp_to_fp %in2, rounding = rtne : tensor<128xf16, #blocked> -> tensor<128xf8E4M3FN, #blocked>

// CHECK-COUNT-2: cvt.rn.satfinite.e5m2x2.f32 {{.*}} "=h,r,r" %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi8>
%out5 = tt.fp_to_fp %in3, rounding = rtne : tensor<128xf32, #blocked> -> tensor<128xf8E5M2, #blocked>
// CHECK-COUNT-2: cvt.rn.satfinite.e4m3x2.f32 {{.*}} "=h,r,r" %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi8>
%out6 = tt.fp_to_fp %in3, rounding = rtne : tensor<128xf32, #blocked> -> tensor<128xf8E4M3FNUZ, #blocked>
%out6 = tt.fp_to_fp %in3, rounding = rtne : tensor<128xf32, #blocked> -> tensor<128xf8E4M3FN, #blocked>
tt.return
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,9 @@ enum class TensorCoreType : uint8_t {
FP32_TF32_TF32_FP32,
FP16_FP16_FP16_FP16,
FP32_FP8E5M2_FP8E5M2_FP32,
FP32_FP8E5M2_FP8E4M3FNUZ_FP32,
FP32_FP8E4M3FNUZ_FP8E5M2_FP32,
FP32_FP8E4M3FNUZ_FP8E4M3FNUZ_FP32,
FP32_FP8E5M2_FP8E4M3FN_FP32,
FP32_FP8E4M3FN_FP8E5M2_FP32,
FP32_FP8E4M3FN_FP8E4M3FN_FP32,
// integer tensor core instr
INT32_INT1_INT1_INT32, // Not implemented
INT32_INT4_INT4_INT32, // Not implemented
Expand Down Expand Up @@ -112,9 +112,9 @@ Type getMmaRetType(TensorCoreType mmaType, MLIRContext *ctx) {
case TensorCoreType::FP16_FP16_FP16_FP16:
return fp16x2Pack2Ty;
case TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32:
case TensorCoreType::FP32_FP8E5M2_FP8E4M3FNUZ_FP32:
case TensorCoreType::FP32_FP8E4M3FNUZ_FP8E5M2_FP32:
case TensorCoreType::FP32_FP8E4M3FNUZ_FP8E4M3FNUZ_FP32:
case TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32:
case TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32:
case TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32:
return fp32x4Ty;
case TensorCoreType::INT32_INT8_INT8_INT32:
return i32x4Ty;
Expand All @@ -140,14 +140,14 @@ TensorCoreType getMmaType(triton::DotOp op) {
bTy.getElementType().isFloat8E5M2())
return TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32;
if (aTy.getElementType().isFloat8E5M2() &&
bTy.getElementType().isFloat8E4M3FNUZ())
return TensorCoreType::FP32_FP8E5M2_FP8E4M3FNUZ_FP32;
if (aTy.getElementType().isFloat8E4M3FNUZ() &&
bTy.getElementType().isFloat8E4M3FN())
return TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32;
if (aTy.getElementType().isFloat8E4M3FN() &&
bTy.getElementType().isFloat8E5M2())
return TensorCoreType::FP32_FP8E4M3FNUZ_FP8E5M2_FP32;
if (aTy.getElementType().isFloat8E4M3FNUZ() &&
bTy.getElementType().isFloat8E4M3FNUZ())
return TensorCoreType::FP32_FP8E4M3FNUZ_FP8E4M3FNUZ_FP32;
return TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32;
if (aTy.getElementType().isFloat8E4M3FN() &&
bTy.getElementType().isFloat8E4M3FN())
return TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32;
if (aTy.getElementType().isF32() && bTy.getElementType().isF32() &&
op.getInputPrecision() == InputPrecision::TF32)
return TensorCoreType::FP32_TF32_TF32_FP32;
Expand Down Expand Up @@ -193,11 +193,11 @@ inline static const std::map<TensorCoreType, std::string> mmaInstrPtxAmpere = {

{TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32,
"mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32"},
{TensorCoreType::FP32_FP8E5M2_FP8E4M3FNUZ_FP32,
{TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32,
"mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e4m3.f32"},
{TensorCoreType::FP32_FP8E4M3FNUZ_FP8E5M2_FP32,
{TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32,
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e5m2.f32"},
{TensorCoreType::FP32_FP8E4M3FNUZ_FP8E4M3FNUZ_FP32,
{TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32,
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32"},
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ triton::nvgpu::WGMMAEltType getMmaOperandType(Value a, bool allowTF32) {
return triton::nvgpu::WGMMAEltType::s8;
} else if (aTy.isFloat8E5M2()) {
return triton::nvgpu::WGMMAEltType::e5m2;
} else if (aTy.isFloat8E4M3FNUZ()) {
} else if (aTy.isFloat8E4M3FN()) {
return triton::nvgpu::WGMMAEltType::e4m3;
} else {
llvm::report_fatal_error("Unsupported mma operand type found");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ struct FpToFpOpConversion
std::pair<ConverterT, size_t>
getConversionFunc(Type srcTy, Type dstTy,
std::optional<RoundingMode> roundingMode) const {
auto F8E4M3TyID = TypeID::get<Float8E4M3FNUZType>();
auto F8E4M3TyID = TypeID::get<Float8E4M3FNType>();
auto F8E5M2TyID = TypeID::get<Float8E5M2Type>();
auto F16TyID = TypeID::get<Float16Type>();
auto BF16TyID = TypeID::get<BFloat16Type>();
Expand Down Expand Up @@ -445,7 +445,7 @@ struct FpToFpOpConversion
llvm::report_fatal_error("Unsupported rounding mode for conversion.");
}
if (computeCapability < 89 &&
(srcTy.isFloat8E4M3FNUZ() || dstTy.isFloat8E4M3FNUZ())) {
(srcTy.isFloat8E4M3FN() || dstTy.isFloat8E4M3FN())) {
llvm::errs() << "Conversion from/to f8e4m3nv is only supported on "
"compute capability >= 89"
<< "\n";
Expand All @@ -467,7 +467,7 @@ struct FpToFpOpConversion
auto dstElementType = getElementType(op.getResult());
auto roundingMode = op.getRounding();

if (dstElementType.isFloat8E5M2() || dstElementType.isFloat8E4M3FNUZ()) {
if (dstElementType.isFloat8E5M2() || dstElementType.isFloat8E4M3FN()) {
assert(roundingMode.has_value() &&
"Rounding mode must be specified for convertsions to fp8");

Expand Down Expand Up @@ -504,7 +504,7 @@ struct FpToFpOpConversion

bool useFP16IntermediateSrc =
srcElementType.isF32() &&
(!(computeCapability >= 90 && (dstElementType.isFloat8E4M3FNUZ() ||
(!(computeCapability >= 90 && (dstElementType.isFloat8E4M3FN() ||
dstElementType.isFloat8E5M2())) ||
roundingMode.value() == RoundingMode::RTZ);
bool isDstFP32 = dstElementType.isF32();
Expand Down