diff --git a/include/triton/Dialect/Triton/IR/TritonTypes.td b/include/triton/Dialect/Triton/IR/TritonTypes.td index 6ceb4bc47665..4c709cd4420b 100644 --- a/include/triton/Dialect/Triton/IR/TritonTypes.td +++ b/include/triton/Dialect/Triton/IR/TritonTypes.td @@ -15,7 +15,7 @@ class TritonTypeDef traits = []> } // Floating-point Type -def TT_Float : AnyTypeOf<[F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">; +def TT_Float : AnyTypeOf<[F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">; def TT_FloatTensor : RankedTensorOf<[TT_Float]>; def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>; diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 933f062d8191..7d4088b16f2b 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -426,6 +426,7 @@ bool supportMFMATypes(Type a, Type b) { return false; auto F8E5M2 = TypeID::get(); + auto F8E4M3FN = TypeID::get(); auto F8E4M3FNUZ = TypeID::get(); auto F8E5M2FNUZ = TypeID::get(); auto F16 = TypeID::get(); @@ -437,6 +438,7 @@ bool supportMFMATypes(Type a, Type b) { {F16, F16}, {BF16, BF16}, {F8E5M2, F8E5M2}, + {F8E4M3FN, F8E4M3FN}, {F8E4M3FNUZ, F8E4M3FNUZ}, {F8E4M3FNUZ, F8E5M2FNUZ}, {F8E5M2FNUZ, F8E4M3FNUZ}, @@ -495,14 +497,14 @@ bool supportMMA(triton::DotOp op, int version) { return false; if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 && retShapePerCTA[rank - 1] % 8 == 0 && - (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ() || + (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN() || aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() || aElemTy.isF32()))) { return false; } // We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op. if (op.getMaxNumImpreciseAcc() < 32 && - (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ()) && + (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN()) && cast(op.getType()).getElementType().isF32()) { return false; } diff --git a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp index 53705c3b78b9..c0371cfe1d1b 100644 --- a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp @@ -34,6 +34,9 @@ TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter( addConversion([&](mlir::Float8E4M3FNUZType type) -> std::optional { return IntegerType::get(type.getContext(), 8); }); + addConversion([&](mlir::Float8E4M3FNType type) -> std::optional { + return IntegerType::get(type.getContext(), 8); + }); addConversion([&](mlir::Float8E5M2Type type) -> std::optional { return IntegerType::get(type.getContext(), 8); }); diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index 39c043695bc6..6c15ce06979f 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -357,7 +357,7 @@ static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) { NvidiaMmaEncodingAttr mmaLayout = dyn_cast(D.getType().getEncoding()); if (mmaLayout) { - bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FNUZ(); + bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FN(); // promote operands for sm < 89 since fp8 mma is not natively supported // promote operands for sm >= 90 when mma is not v3 if (!isNativeFP8 || diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index c764ef1b6957..db980c5fcaf8 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -45,8 +45,9 @@ SmallVector mmaVersionToInstrShape(int version, SmallVector validN; // MMAv3 with larger instruction shape is preferred. - if (eltType.isFloat8E5M2() || eltType.isFloat8E4M3FNUZ() || - eltType.isF16() || eltType.isBF16() || eltType.isF32()) { + if (eltType.isFloat8E5M2() || eltType.isFloat8E4M3FN() || + eltType.isFloat8E4M3FNUZ() || eltType.isF16() || eltType.isBF16() || + eltType.isF32()) { validN.assign({256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176, 168, 160, 152, 144, 136, 128, 120, 112, 104, 96, 88, 80, 72, 64, 56, 48, 40, 32, 24, 16, 8}); diff --git a/python/src/ir.cc b/python/src/ir.cc index 46095dcc6653..34e4feb78c05 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -745,10 +745,8 @@ void init_triton_ir(py::module &&m) { return self.getBuilder().getI64Type(); }) .def("get_fp8e4nv_ty", - // TODO: fp8e4nv is using Float8E4M3FNUZType, which - // does not seem right. It should use FloatE4M3FNType [](TritonOpBuilder &self) -> Type { - return self.getBuilder().getType(); + return self.getBuilder().getType(); }) .def("get_fp8e4b8_ty", [](TritonOpBuilder &self) -> Type { diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir index 7ecee2eba11b..3c16ea0260ef 100644 --- a/test/Conversion/tritongpu_to_llvm_hopper.mlir +++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir @@ -129,24 +129,24 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { // CHECK-LABEL: test_fp8_to_f16_conversion tt.func @test_fp8_to_f16_conversion( - %in0: tensor<128xf8E5M2, #blocked>, %in1: tensor<128xf8E4M3FNUZ, #blocked>, + %in0: tensor<128xf8E5M2, #blocked>, %in1: tensor<128xf8E4M3FN, #blocked>, %in2: tensor<128xf16, #blocked>, %in3: tensor<128xf32, #blocked>) { // CHECK-COUNT-2: cvt.rn.f16x2.e5m2x2 {{.*}} "=r,h" %{{.*}} : (i16) -> vector<2xf16> %out0 = tt.fp_to_fp %in0 : tensor<128xf8E5M2, #blocked> -> tensor<128xf16, #blocked> // CHECK-COUNT-2: cvt.rn.f16x2.e4m3x2 {{.*}} "=r,h" %{{.*}} : (i16) -> vector<2xf16> - %out1 = tt.fp_to_fp %in1 : tensor<128xf8E4M3FNUZ, #blocked> -> tensor<128xf16, #blocked> + %out1 = tt.fp_to_fp %in1 : tensor<128xf8E4M3FN, #blocked> -> tensor<128xf16, #blocked> // CHECK-COUNT-2: mul.rn.bf16x2 %out2 = tt.fp_to_fp %in0 : tensor<128xf8E5M2, #blocked> -> tensor<128xbf16, #blocked> // CHECK-COUNT-2: cvt.rn.satfinite.e5m2x2.f16x2 {{.*}} "=h,r" %{{.*}} : (i32) -> vector<2xi8> %out3 = tt.fp_to_fp %in2, rounding = rtne : tensor<128xf16, #blocked> -> tensor<128xf8E5M2, #blocked> // CHECK-COUNT-2: cvt.rn.satfinite.e4m3x2.f16x2 {{.*}} "=h,r" %{{.*}} : (i32) -> vector<2xi8> - %out4 = tt.fp_to_fp %in2, rounding = rtne : tensor<128xf16, #blocked> -> tensor<128xf8E4M3FNUZ, #blocked> + %out4 = tt.fp_to_fp %in2, rounding = rtne : tensor<128xf16, #blocked> -> tensor<128xf8E4M3FN, #blocked> // CHECK-COUNT-2: cvt.rn.satfinite.e5m2x2.f32 {{.*}} "=h,r,r" %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi8> %out5 = tt.fp_to_fp %in3, rounding = rtne : tensor<128xf32, #blocked> -> tensor<128xf8E5M2, #blocked> // CHECK-COUNT-2: cvt.rn.satfinite.e4m3x2.f32 {{.*}} "=h,r,r" %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi8> - %out6 = tt.fp_to_fp %in3, rounding = rtne : tensor<128xf32, #blocked> -> tensor<128xf8E4M3FNUZ, #blocked> + %out6 = tt.fp_to_fp %in3, rounding = rtne : tensor<128xf32, #blocked> -> tensor<128xf8E4M3FN, #blocked> tt.return } } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp index 2d16dc19b3b3..af897ef546dd 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp @@ -81,9 +81,9 @@ enum class TensorCoreType : uint8_t { FP32_TF32_TF32_FP32, FP16_FP16_FP16_FP16, FP32_FP8E5M2_FP8E5M2_FP32, - FP32_FP8E5M2_FP8E4M3FNUZ_FP32, - FP32_FP8E4M3FNUZ_FP8E5M2_FP32, - FP32_FP8E4M3FNUZ_FP8E4M3FNUZ_FP32, + FP32_FP8E5M2_FP8E4M3FN_FP32, + FP32_FP8E4M3FN_FP8E5M2_FP32, + FP32_FP8E4M3FN_FP8E4M3FN_FP32, // integer tensor core instr INT32_INT1_INT1_INT32, // Not implemented INT32_INT4_INT4_INT32, // Not implemented @@ -112,9 +112,9 @@ Type getMmaRetType(TensorCoreType mmaType, MLIRContext *ctx) { case TensorCoreType::FP16_FP16_FP16_FP16: return fp16x2Pack2Ty; case TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32: - case TensorCoreType::FP32_FP8E5M2_FP8E4M3FNUZ_FP32: - case TensorCoreType::FP32_FP8E4M3FNUZ_FP8E5M2_FP32: - case TensorCoreType::FP32_FP8E4M3FNUZ_FP8E4M3FNUZ_FP32: + case TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32: + case TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32: + case TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32: return fp32x4Ty; case TensorCoreType::INT32_INT8_INT8_INT32: return i32x4Ty; @@ -140,14 +140,14 @@ TensorCoreType getMmaType(triton::DotOp op) { bTy.getElementType().isFloat8E5M2()) return TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32; if (aTy.getElementType().isFloat8E5M2() && - bTy.getElementType().isFloat8E4M3FNUZ()) - return TensorCoreType::FP32_FP8E5M2_FP8E4M3FNUZ_FP32; - if (aTy.getElementType().isFloat8E4M3FNUZ() && + bTy.getElementType().isFloat8E4M3FN()) + return TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32; + if (aTy.getElementType().isFloat8E4M3FN() && bTy.getElementType().isFloat8E5M2()) - return TensorCoreType::FP32_FP8E4M3FNUZ_FP8E5M2_FP32; - if (aTy.getElementType().isFloat8E4M3FNUZ() && - bTy.getElementType().isFloat8E4M3FNUZ()) - return TensorCoreType::FP32_FP8E4M3FNUZ_FP8E4M3FNUZ_FP32; + return TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32; + if (aTy.getElementType().isFloat8E4M3FN() && + bTy.getElementType().isFloat8E4M3FN()) + return TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32; if (aTy.getElementType().isF32() && bTy.getElementType().isF32() && op.getInputPrecision() == InputPrecision::TF32) return TensorCoreType::FP32_TF32_TF32_FP32; @@ -193,11 +193,11 @@ inline static const std::map mmaInstrPtxAmpere = { {TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32, "mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32"}, - {TensorCoreType::FP32_FP8E5M2_FP8E4M3FNUZ_FP32, + {TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32, "mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e4m3.f32"}, - {TensorCoreType::FP32_FP8E4M3FNUZ_FP8E5M2_FP32, + {TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32, "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e5m2.f32"}, - {TensorCoreType::FP32_FP8E4M3FNUZ_FP8E4M3FNUZ_FP32, + {TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32, "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32"}, }; diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp index baed96a29704..41e36503f593 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp @@ -58,7 +58,7 @@ triton::nvgpu::WGMMAEltType getMmaOperandType(Value a, bool allowTF32) { return triton::nvgpu::WGMMAEltType::s8; } else if (aTy.isFloat8E5M2()) { return triton::nvgpu::WGMMAEltType::e5m2; - } else if (aTy.isFloat8E4M3FNUZ()) { + } else if (aTy.isFloat8E4M3FN()) { return triton::nvgpu::WGMMAEltType::e4m3; } else { llvm::report_fatal_error("Unsupported mma operand type found"); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp index 91e9a4bbf888..0b663a875422 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -401,7 +401,7 @@ struct FpToFpOpConversion std::pair getConversionFunc(Type srcTy, Type dstTy, std::optional roundingMode) const { - auto F8E4M3TyID = TypeID::get(); + auto F8E4M3TyID = TypeID::get(); auto F8E5M2TyID = TypeID::get(); auto F16TyID = TypeID::get(); auto BF16TyID = TypeID::get(); @@ -445,7 +445,7 @@ struct FpToFpOpConversion llvm::report_fatal_error("Unsupported rounding mode for conversion."); } if (computeCapability < 89 && - (srcTy.isFloat8E4M3FNUZ() || dstTy.isFloat8E4M3FNUZ())) { + (srcTy.isFloat8E4M3FN() || dstTy.isFloat8E4M3FN())) { llvm::errs() << "Conversion from/to f8e4m3nv is only supported on " "compute capability >= 89" << "\n"; @@ -467,7 +467,7 @@ struct FpToFpOpConversion auto dstElementType = getElementType(op.getResult()); auto roundingMode = op.getRounding(); - if (dstElementType.isFloat8E5M2() || dstElementType.isFloat8E4M3FNUZ()) { + if (dstElementType.isFloat8E5M2() || dstElementType.isFloat8E4M3FN()) { assert(roundingMode.has_value() && "Rounding mode must be specified for convertsions to fp8"); @@ -504,7 +504,7 @@ struct FpToFpOpConversion bool useFP16IntermediateSrc = srcElementType.isF32() && - (!(computeCapability >= 90 && (dstElementType.isFloat8E4M3FNUZ() || + (!(computeCapability >= 90 && (dstElementType.isFloat8E4M3FN() || dstElementType.isFloat8E5M2())) || roundingMode.value() == RoundingMode::RTZ); bool isDstFP32 = dstElementType.isF32();