272 changes: 272 additions & 0 deletions third_party/triton/temporary/fp8_fix.patch
@@ -0,0 +1,272 @@
This patch can be removed as part of the next integrate.
The corresponding import patch has already been added.

==== triton/include/triton/Dialect/Triton/IR/TritonTypes.td#13 - triton/include/triton/Dialect/Triton/IR/TritonTypes.td ====
# action=edit type=text
--- triton/include/triton/Dialect/Triton/IR/TritonTypes.td 2024-06-07 05:28:31.000000000 -0700
+++ triton/include/triton/Dialect/Triton/IR/TritonTypes.td 2024-08-20 06:34:55.000000000 -0700
@@ -15,7 +15,7 @@
}

// Floating-point Type
-def TT_Float : AnyTypeOf<[F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">;
+def TT_Float : AnyTypeOf<[F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">;
def TT_FloatTensor : RankedTensorOf<[TT_Float]>;
def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;

==== triton/lib/Analysis/Utility.cpp#42 - triton/lib/Analysis/Utility.cpp ====
# action=edit type=text
--- triton/lib/Analysis/Utility.cpp 2024-08-14 09:36:23.000000000 -0700
+++ triton/lib/Analysis/Utility.cpp 2024-08-20 06:34:55.000000000 -0700
@@ -425,6 +425,7 @@
if (a.getIntOrFloatBitWidth() != b.getIntOrFloatBitWidth())
return false;

+ auto F8E4M3FN = TypeID::get<Float8E4M3FNType>();
auto F8E4M3FNUZ = TypeID::get<Float8E4M3FNUZType>();
auto F8E5M2FNUZ = TypeID::get<Float8E5M2FNUZType>();
auto F16 = TypeID::get<Float16Type>();
@@ -435,6 +436,7 @@
{F32, F32},
{F16, F16},
{BF16, BF16},
+ {F8E4M3FN, F8E4M3FN},
{F8E4M3FNUZ, F8E4M3FNUZ},
{F8E4M3FNUZ, F8E5M2FNUZ},
{F8E5M2FNUZ, F8E4M3FNUZ},
@@ -493,14 +495,14 @@
return false;
if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 &&
retShapePerCTA[rank - 1] % 8 == 0 &&
- (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ() ||
+ (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN() ||
aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
aElemTy.isF32()))) {
return false;
}
// We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op.
if (op.getMaxNumImpreciseAcc() < 32 &&
- (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ()) &&
+ (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN()) &&
cast<RankedTensorType>(op.getType()).getElementType().isF32()) {
return false;
}
==== triton/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp#20 - triton/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp ====
# action=edit type=text
--- triton/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp 2024-06-07 05:28:31.000000000 -0700
+++ triton/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp 2024-08-20 06:34:55.000000000 -0700
@@ -34,6 +34,9 @@
addConversion([&](mlir::Float8E4M3FNUZType type) -> std::optional<Type> {
return IntegerType::get(type.getContext(), 8);
});
+ addConversion([&](mlir::Float8E4M3FNType type) -> std::optional<Type> {
+ return IntegerType::get(type.getContext(), 8);
+ });
addConversion([&](mlir::Float8E5M2Type type) -> std::optional<Type> {
return IntegerType::get(type.getContext(), 8);
});
==== triton/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp#44 - triton/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp ====
# action=edit type=text
--- triton/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp 2024-07-31 01:05:00.000000000 -0700
+++ triton/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp 2024-08-20 06:40:32.000000000 -0700
@@ -382,7 +382,7 @@
NvidiaMmaEncodingAttr mmaLayout =
dyn_cast<NvidiaMmaEncodingAttr>(D.getType().getEncoding());
if (mmaLayout) {
- bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FNUZ();
+ bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FN();
// promote operands for sm < 89 since fp8 mma is not natively supported
// promote operands for sm >= 90 when mma is not v3
if (!isNativeFP8 ||
==== triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp#39 - triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp ====
# action=edit type=text
--- triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp 2024-08-14 09:36:23.000000000 -0700
+++ triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp 2024-08-20 06:34:55.000000000 -0700
@@ -45,8 +45,9 @@
SmallVector<unsigned> validN;

// MMAv3 with larger instruction shape is preferred.
- if (eltType.isFloat8E5M2() || eltType.isFloat8E4M3FNUZ() ||
- eltType.isF16() || eltType.isBF16() || eltType.isF32()) {
+ if (eltType.isFloat8E5M2() || eltType.isFloat8E4M3FN() ||
+ eltType.isFloat8E4M3FNUZ() || eltType.isF16() || eltType.isBF16() ||
+ eltType.isF32()) {
validN.assign({256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176,
168, 160, 152, 144, 136, 128, 120, 112, 104, 96, 88,
80, 72, 64, 56, 48, 40, 32, 24, 16, 8});
==== triton/patches/public/fp8_fix.patch#None - triton/patches/public/fp8_fix.patch ====
# action=add type=text
--- /dev/null 1969-12-31 16:00:00.000000000 -0800
+++ triton/patches/public/fp8_fix.patch 2024-08-21 01:51:13.000000000 -0700
@@ -0,0 +1,2 @@
+triton/patches/public/fp8_fix.patch#1 - opened for add
+triton/patches/public/fp8_fix.patch - empty, assuming text.
==== triton/python/src/ir.cc#24 - triton/python/src/ir.cc ====
# action=edit type=text
--- triton/python/src/ir.cc 2024-08-12 00:24:31.000000000 -0700
+++ triton/python/src/ir.cc 2024-08-21 01:46:02.000000000 -0700
@@ -745,10 +745,8 @@
return self.getBuilder().getI64Type();
})
.def("get_fp8e4nv_ty",
- // TODO: fp8e4nv is using Float8E4M3FNUZType, which
- // does not seem right. It should use FloatE4M3FNType
[](TritonOpBuilder &self) -> Type {
- return self.getBuilder().getType<Float8E4M3FNUZType>();
+ return self.getBuilder().getType<Float8E4M3FNType>();
})
.def("get_fp8e4b8_ty",
[](TritonOpBuilder &self) -> Type {
==== triton/test/Conversion/tritongpu_to_llvm_hopper.mlir#25 - triton/test/Conversion/tritongpu_to_llvm_hopper.mlir ====
# action=edit type=text
--- triton/test/Conversion/tritongpu_to_llvm_hopper.mlir 2024-07-03 07:14:55.000000000 -0700
+++ triton/test/Conversion/tritongpu_to_llvm_hopper.mlir 2024-08-20 06:34:55.000000000 -0700
@@ -129,24 +129,24 @@
module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: test_fp8_to_f16_conversion
tt.func @test_fp8_to_f16_conversion(
- %in0: tensor<128xf8E5M2, #blocked>, %in1: tensor<128xf8E4M3FNUZ, #blocked>,
+ %in0: tensor<128xf8E5M2, #blocked>, %in1: tensor<128xf8E4M3FN, #blocked>,
%in2: tensor<128xf16, #blocked>, %in3: tensor<128xf32, #blocked>) {
// CHECK-COUNT-2: cvt.rn.f16x2.e5m2x2 {{.*}} "=r,h" %{{.*}} : (i16) -> vector<2xf16>
%out0 = tt.fp_to_fp %in0 : tensor<128xf8E5M2, #blocked> -> tensor<128xf16, #blocked>
// CHECK-COUNT-2: cvt.rn.f16x2.e4m3x2 {{.*}} "=r,h" %{{.*}} : (i16) -> vector<2xf16>
- %out1 = tt.fp_to_fp %in1 : tensor<128xf8E4M3FNUZ, #blocked> -> tensor<128xf16, #blocked>
+ %out1 = tt.fp_to_fp %in1 : tensor<128xf8E4M3FN, #blocked> -> tensor<128xf16, #blocked>
// CHECK-COUNT-2: mul.rn.bf16x2
%out2 = tt.fp_to_fp %in0 : tensor<128xf8E5M2, #blocked> -> tensor<128xbf16, #blocked>

// CHECK-COUNT-2: cvt.rn.satfinite.e5m2x2.f16x2 {{.*}} "=h,r" %{{.*}} : (i32) -> vector<2xi8>
%out3 = tt.fp_to_fp %in2, rounding = rtne : tensor<128xf16, #blocked> -> tensor<128xf8E5M2, #blocked>
// CHECK-COUNT-2: cvt.rn.satfinite.e4m3x2.f16x2 {{.*}} "=h,r" %{{.*}} : (i32) -> vector<2xi8>
- %out4 = tt.fp_to_fp %in2, rounding = rtne : tensor<128xf16, #blocked> -> tensor<128xf8E4M3FNUZ, #blocked>
+ %out4 = tt.fp_to_fp %in2, rounding = rtne : tensor<128xf16, #blocked> -> tensor<128xf8E4M3FN, #blocked>

// CHECK-COUNT-2: cvt.rn.satfinite.e5m2x2.f32 {{.*}} "=h,r,r" %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi8>
%out5 = tt.fp_to_fp %in3, rounding = rtne : tensor<128xf32, #blocked> -> tensor<128xf8E5M2, #blocked>
// CHECK-COUNT-2: cvt.rn.satfinite.e4m3x2.f32 {{.*}} "=h,r,r" %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi8>
- %out6 = tt.fp_to_fp %in3, rounding = rtne : tensor<128xf32, #blocked> -> tensor<128xf8E4M3FNUZ, #blocked>
+ %out6 = tt.fp_to_fp %in3, rounding = rtne : tensor<128xf32, #blocked> -> tensor<128xf8E4M3FN, #blocked>
tt.return
}
}
==== triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp#4 - triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp ====
# action=edit type=text
--- triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp 2024-05-14 06:33:36.000000000 -0700
+++ triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp 2024-08-20 06:34:55.000000000 -0700
@@ -81,9 +81,9 @@
FP32_TF32_TF32_FP32,
FP16_FP16_FP16_FP16,
FP32_FP8E5M2_FP8E5M2_FP32,
- FP32_FP8E5M2_FP8E4M3FNUZ_FP32,
- FP32_FP8E4M3FNUZ_FP8E5M2_FP32,
- FP32_FP8E4M3FNUZ_FP8E4M3FNUZ_FP32,
+ FP32_FP8E5M2_FP8E4M3FN_FP32,
+ FP32_FP8E4M3FN_FP8E5M2_FP32,
+ FP32_FP8E4M3FN_FP8E4M3FN_FP32,
// integer tensor core instr
INT32_INT1_INT1_INT32, // Not implemented
INT32_INT4_INT4_INT32, // Not implemented
@@ -112,9 +112,9 @@
case TensorCoreType::FP16_FP16_FP16_FP16:
return fp16x2Pack2Ty;
case TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32:
- case TensorCoreType::FP32_FP8E5M2_FP8E4M3FNUZ_FP32:
- case TensorCoreType::FP32_FP8E4M3FNUZ_FP8E5M2_FP32:
- case TensorCoreType::FP32_FP8E4M3FNUZ_FP8E4M3FNUZ_FP32:
+ case TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32:
+ case TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32:
+ case TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32:
return fp32x4Ty;
case TensorCoreType::INT32_INT8_INT8_INT32:
return i32x4Ty;
@@ -140,14 +140,14 @@
bTy.getElementType().isFloat8E5M2())
return TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32;
if (aTy.getElementType().isFloat8E5M2() &&
- bTy.getElementType().isFloat8E4M3FNUZ())
- return TensorCoreType::FP32_FP8E5M2_FP8E4M3FNUZ_FP32;
- if (aTy.getElementType().isFloat8E4M3FNUZ() &&
+ bTy.getElementType().isFloat8E4M3FN())
+ return TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32;
+ if (aTy.getElementType().isFloat8E4M3FN() &&
bTy.getElementType().isFloat8E5M2())
- return TensorCoreType::FP32_FP8E4M3FNUZ_FP8E5M2_FP32;
- if (aTy.getElementType().isFloat8E4M3FNUZ() &&
- bTy.getElementType().isFloat8E4M3FNUZ())
- return TensorCoreType::FP32_FP8E4M3FNUZ_FP8E4M3FNUZ_FP32;
+ return TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32;
+ if (aTy.getElementType().isFloat8E4M3FN() &&
+ bTy.getElementType().isFloat8E4M3FN())
+ return TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32;
if (aTy.getElementType().isF32() && bTy.getElementType().isF32() &&
op.getInputPrecision() == InputPrecision::TF32)
return TensorCoreType::FP32_TF32_TF32_FP32;
@@ -193,11 +193,11 @@

{TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32,
"mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32"},
- {TensorCoreType::FP32_FP8E5M2_FP8E4M3FNUZ_FP32,
+ {TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32,
"mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e4m3.f32"},
- {TensorCoreType::FP32_FP8E4M3FNUZ_FP8E5M2_FP32,
+ {TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32,
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e5m2.f32"},
- {TensorCoreType::FP32_FP8E4M3FNUZ_FP8E4M3FNUZ_FP32,
+ {TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32,
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32"},
};

==== triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp#9 - triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp ====
# action=edit type=text
--- triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp 2024-06-07 05:28:31.000000000 -0700
+++ triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp 2024-08-20 06:34:55.000000000 -0700
@@ -58,7 +58,7 @@
return triton::nvgpu::WGMMAEltType::s8;
} else if (aTy.isFloat8E5M2()) {
return triton::nvgpu::WGMMAEltType::e5m2;
- } else if (aTy.isFloat8E4M3FNUZ()) {
+ } else if (aTy.isFloat8E4M3FN()) {
return triton::nvgpu::WGMMAEltType::e4m3;
} else {
llvm::report_fatal_error("Unsupported mma operand type found");
==== triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp#9 - triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp ====
# action=edit type=text
--- triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp 2024-07-17 02:05:59.000000000 -0700
+++ triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp 2024-08-20 06:34:55.000000000 -0700
@@ -386,7 +386,7 @@
std::pair<ConverterT, size_t>
getConversionFunc(Type srcTy, Type dstTy,
std::optional<RoundingMode> roundingMode) const {
- auto F8E4M3TyID = TypeID::get<Float8E4M3FNUZType>();
+ auto F8E4M3TyID = TypeID::get<Float8E4M3FNType>();
auto F8E5M2TyID = TypeID::get<Float8E5M2Type>();
auto F16TyID = TypeID::get<Float16Type>();
auto BF16TyID = TypeID::get<BFloat16Type>();
@@ -430,7 +430,7 @@
llvm::report_fatal_error("Unsupported rounding mode for conversion.");
}
if (computeCapability < 89 &&
- (srcTy.isFloat8E4M3FNUZ() || dstTy.isFloat8E4M3FNUZ())) {
+ (srcTy.isFloat8E4M3FN() || dstTy.isFloat8E4M3FN())) {
llvm::errs() << "Conversion from/to f8e4m3nv is only supported on "
"compute capability >= 89"
<< "\n";
@@ -452,7 +452,7 @@
auto dstElementType = getElementType(op.getResult());
auto roundingMode = op.getRounding();

- if (dstElementType.isFloat8E5M2() || dstElementType.isFloat8E4M3FNUZ()) {
+ if (dstElementType.isFloat8E5M2() || dstElementType.isFloat8E4M3FN()) {
assert(roundingMode.has_value() &&
"Rounding mode must be specified for convertsions to fp8");

@@ -489,7 +489,7 @@

bool useFP16IntermediateSrc =
srcElementType.isF32() &&
- (!(computeCapability >= 90 && (dstElementType.isFloat8E4M3FNUZ() ||
+ (!(computeCapability >= 90 && (dstElementType.isFloat8E4M3FN() ||
dstElementType.isFloat8E5M2())) ||
roundingMode.value() == RoundingMode::RTZ);
bool isDstFP32 = dstElementType.isF32();
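Background for the type swap above (this note and the sketch below are not part of the patch): MLIR's Float8E4M3FN is the OCP/NVIDIA "e4m3" encoding (exponent bias 7, finite max 448, NaN but no infinities), whereas Float8E4M3FNUZ is the no-infinity/no-negative-zero variant (bias 8, finite max 240). The PTX cvt/mma/wgmma "e4m3" instructions targeted here operate on the FN layout, which is why Triton's fp8e4nv and the MMA/WGMMA lowerings are moved from Float8E4M3FNUZType to Float8E4M3FNType. A minimal C++ sketch of the distinction, using only the Type helpers that already appear in the diff (the helper names are illustrative, not part of the change):

#include "mlir/IR/BuiltinTypes.h"

// Sketch only: true when elemTy matches the hardware "e4m3" encoding that the
// PTX cvt/mma instructions expect on sm_89+ (the FN flavour).
static bool isPtxNativeE4M3(mlir::Type elemTy) {
  return elemTy.isFloat8E4M3FN();
}

// The FNUZ flavour has a different bias and value range, so it cannot be fed
// to the native e4m3 conversion or tensor-core paths and needs separate handling.
static bool isNonNativeE4M3(mlir::Type elemTy) {
  return elemTy.isFloat8E4M3FNUZ();
}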
1 change: 1 addition & 0 deletions third_party/triton/temporary/series.bzl
@@ -15,5 +15,6 @@ those to this list.

temporary_patch_list = [
"//third_party/triton:temporary/highestPowOf2Divisor-underflow-fix.patch",
"//third_party/triton:temporary/fp8_fix.patch",
# Add new patches just above this line
]
5 changes: 1 addition & 4 deletions xla/service/gpu/fusions/triton/triton_fusion_emitter.cc
@@ -195,10 +195,7 @@ absl::StatusOr<Type> TritonType(mlir::OpBuilder b, PrimitiveType t) {
case F8E5M2:
return b.getFloat8E5M2Type();
case F8E4M3FN:
- // TODO(b/345700241) Note that we return UZ type as Triton mistakenly uses
- // this type for F8E4M3FN. The mapping must be changed when it's fixed in
- // Triton.
- return b.getFloat8E4M3FNUZType();
+ return b.getFloat8E4M3FNType();
default:
return absl::UnimplementedError(
absl::StrCat("This type is not supported yet: ",
8 changes: 4 additions & 4 deletions xla/service/gpu/tests/fp8_to_llvm_hopper.mlir
@@ -20,14 +20,14 @@ module attributes {

// CHECK-LABEL: e4m3_mapping
tt.func @e4m3_mapping(
- %arg0: tensor<16x256xf8E4M3FNUZ, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>,
- %arg1: tensor<256x16xf8E4M3FNUZ, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
+ %arg0: tensor<16x256xf8E4M3FN, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>,
+ %arg1: tensor<256x16xf8E4M3FN, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
) {
%cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma>
// CHECK: mma.{{.*}}.e4m3.e4m3.f32
%res = tt.dot %arg0, %arg1, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32}
- : tensor<16x256xf8E4M3FNUZ, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> *
- tensor<256x16xf8E4M3FNUZ, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
+ : tensor<16x256xf8E4M3FN, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> *
+ tensor<256x16xf8E4M3FN, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
-> tensor<16x16xf32, #mma>
tt.return
}