diff --git a/docs/python-api/triton.language.rst b/docs/python-api/triton.language.rst index ecd0fb3b94b6..415091a1000e 100644 --- a/docs/python-api/triton.language.rst +++ b/docs/python-api/triton.language.rst @@ -59,6 +59,7 @@ Linear Algebra Ops :nosignatures: dot + dot_scaled Memory/Pointer Ops diff --git a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td index f3159338bd0a..04e4c25fd6d8 100644 --- a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td +++ b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td @@ -119,15 +119,16 @@ def TT_InputPrecisionAttr : I32EnumAttr< let cppNamespace = "::mlir::triton"; } -// Type for F8F6F4 kind of floats. -def TT_F8F6F4TypeAttr : I32EnumAttr< - "F8F6F4Type", "", +// Type for ScaleDotElemType kind of floats. +def TT_ScaleDotElemTypeAttr : I32EnumAttr< + "ScaleDotElemType", "", [ I32EnumAttrCase<"E4M3", 0, "e4m3">, I32EnumAttrCase<"E5M2", 1, "e5m2">, I32EnumAttrCase<"E2M3", 2, "e2m3">, I32EnumAttrCase<"E3M2", 3, "e3m2">, - I32EnumAttrCase<"E2M1", 4, "e2m1"> + I32EnumAttrCase<"E2M1", 4, "e2m1">, + I32EnumAttrCase<"BF16", 5, "bf16"> ]>{ let cppNamespace = "::mlir::triton"; diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index d3bb95ca959c..2c3a1bf71442 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -685,15 +685,15 @@ def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure, let arguments = ( ins - // inputs are integer types as they are packed types and we currently - // don't have a representation for those. - TT_IntTensor:$lhs, - TT_IntTensor:$rhs, + // inputs are floats if we have a type for them, otherwise (fp4), + // they are packed in pairs in an I8Tensor + RankedTensorOf<[TT_Float,I8]>:$lhs, + RankedTensorOf<[TT_Float,I8]>:$rhs, TT_FloatTensor:$c, - TT_IntTensor:$lhs_scale, - Optional:$rhs_scale, - TT_F8F6F4TypeAttr:$lhs_type, - TT_F8F6F4TypeAttr:$rhs_type + RankedTensorOf<[I8]>:$lhs_scale, + Optional>:$rhs_scale, + TT_ScaleDotElemTypeAttr:$lhs_type, + TT_ScaleDotElemTypeAttr:$rhs_type ); let results = (outs TT_FloatTensor:$d); diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index a290cb20310a..6299ee6ed43d 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -268,7 +268,7 @@ def TTG_UpcastMXFPOp : TTG_Op<"upcast_mxfp", [Pure, DeclareOpInterfaceMethods(encoding); auto newVEncoding = DotOperandEncodingAttr::get( ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(), diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index a2d4012bf23e..3ddab364d7c6 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -415,22 +415,12 @@ class ScaledBlockedToMMAv2 auto aType = dotOp.getLhsType(); auto bType = dotOp.getRhsType(); - auto enumToType = [&rewriter](F8F6F4Type type) { - switch (type) { - case F8F6F4Type::E4M3: - return rewriter.getFloat8E4M3FNType(); - case F8F6F4Type::E5M2: - return rewriter.getFloat8E5M2Type(); - default: - llvm_unreachable("unexpected type"); - } - }; - - assert((aType == F8F6F4Type::E4M3 || aType == F8F6F4Type::E5M2 || - aType == F8F6F4Type::E2M1) && + assert((aType == ScaleDotElemType::E4M3 || + aType == ScaleDotElemType::E5M2 || + aType == ScaleDotElemType::E2M1) && "NYI: lhs supports fp4 or fp8"); - assert(bType == F8F6F4Type::E4M3 || - bType == F8F6F4Type::E5M2 && "NYI: rhs supports fp8"); + assert(bType == ScaleDotElemType::E4M3 || bType == ScaleDotElemType::E5M2 || + bType == ScaleDotElemType::BF16 && "NYI: rhs supports fp8 and bf16"); // TODO run accelerate matmul on A and B first to choose their layouts // Set return type @@ -454,11 +444,12 @@ class ScaledBlockedToMMAv2 auto newAcc = rewriter.create(oldAcc.getLoc(), newRetType, oldAcc); - auto toMMABf16 = [&newRetType, &rewriter, &ctx, &enumToType]( - TypedValue v, int idx, - F8F6F4Type type) -> TypedValue { + auto toMMABf16 = + [&newRetType, &rewriter, + &ctx](TypedValue v, int idx, + ScaleDotElemType type) -> TypedValue { auto vType = v.getType(); - if (type == F8F6F4Type::E2M1) { + if (type == ScaleDotElemType::E2M1) { // A bit too dynamically typed... // perhaps return ints in both cases? @@ -469,23 +460,23 @@ class ScaledBlockedToMMAv2 vType.getShape(), vType.getElementType(), newVEncoding); return rewriter.create(v.getLoc(), newVType, v); } else { - assert(type == F8F6F4Type::E5M2 || type == F8F6F4Type::E4M3); + assert(type == ScaleDotElemType::E5M2 || + type == ScaleDotElemType::E4M3 || + type == ScaleDotElemType::BF16); auto newVEncoding = DotOperandEncodingAttr::get( ctx, idx, newRetType.getEncoding(), /*kWidth=*/8); auto newVType = RankedTensorType::get( vType.getShape(), vType.getElementType(), newVEncoding); v = rewriter.create(v.getLoc(), newVType, v); - // Bitcast - auto vTypeFp8 = RankedTensorType::get(vType.getShape(), - enumToType(type), newVEncoding); - v = cast>( - rewriter.create(v.getLoc(), vTypeFp8, v).getResult()); - - // Convert to bf16 - auto vTypeBf16 = RankedTensorType::get( - vType.getShape(), rewriter.getBF16Type(), newVEncoding); - return rewriter.create(v.getLoc(), vTypeBf16, v); + if (type == ScaleDotElemType::BF16) { + return v; + } else { + // Convert to bf16 + auto vTypeBf16 = RankedTensorType::get( + vType.getShape(), rewriter.getBF16Type(), newVEncoding); + return rewriter.create(v.getLoc(), vTypeBf16, v); + } } }; a = toMMABf16(a, 0, aType); @@ -515,11 +506,11 @@ class ScaledBlockedToMMAv2 auto newScaleEncoding = triton::gpu::BlockedEncodingAttr::get( ctx, {1, 1}, threadsPerWarp, warpsPerCTA, {1, 0}, CTALayout); - auto newScaleType = RankedTensorType::get(scale.getType().getShape(), - scale.getType().getElementType(), - newScaleEncoding); - scale = - rewriter.create(scale.getLoc(), newScaleType, scale); + auto newScaleDotElemType = RankedTensorType::get( + scale.getType().getShape(), scale.getType().getElementType(), + newScaleEncoding); + scale = rewriter.create(scale.getLoc(), + newScaleDotElemType, scale); auto scaledA = rewriter.create( dotOp.getLoc(), a, scale, dotOp.getLhsType()); diff --git a/python/src/ir.cc b/python/src/ir.cc index 9945c6188294..cce7c87e8d87 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -205,12 +205,13 @@ void init_triton_ir(py::module &&m) { .value("IEEE", InputPrecision::IEEE) .export_values(); - py::enum_(m, "F8F6F4TY", py::module_local()) - .value("E4M3", F8F6F4Type::E4M3) - .value("E5M2", F8F6F4Type::E5M2) - .value("E2M3", F8F6F4Type::E2M3) - .value("E3M2", F8F6F4Type::E3M2) - .value("E2M1", F8F6F4Type::E2M1) + py::enum_(m, "ScaleDotElemTypeTY", py::module_local()) + .value("E4M3", ScaleDotElemType::E4M3) + .value("E5M2", ScaleDotElemType::E5M2) + .value("E2M3", ScaleDotElemType::E2M3) + .value("E3M2", ScaleDotElemType::E3M2) + .value("E2M1", ScaleDotElemType::E2M1) + .value("BF16", ScaleDotElemType::BF16) .export_values(); py::class_(m, "context", py::module_local()) @@ -1423,9 +1424,9 @@ void init_triton_ir(py::module &&m) { }) .def("create_dot_scaled", [](TritonOpBuilder &self, mlir::Value &lhs, mlir::Value &lhs_scale, - F8F6F4Type lhs_format, mlir::Value &rhs, - std::optional &rhs_scale, F8F6F4Type rhs_format, - mlir::Value &c) -> mlir::Value { + ScaleDotElemType lhs_format, mlir::Value &rhs, + std::optional &rhs_scale, + ScaleDotElemType rhs_format, mlir::Value &c) -> mlir::Value { return self.create( c.getType(), lhs, rhs, c, lhs_scale, rhs_scale.value_or(Value()), lhs_format, rhs_format); diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 1cebd2577969..703768b2b942 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -3327,7 +3327,7 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid for M, N, K in itertools.product([32, 64, 128], [32, 64, 128], [64, 128]) for col_a, col_b in itertools.product([True, False], repeat=2) for type_a in ["e2m1", "e4m3", "e5m2"] - for type_b in ["e4m3", "e5m2"] + for type_b in ["e4m3", "e5m2", "bf16"] for mma in ([32, 16] if is_hip() else [16]) for kpack in ([1, 2] if is_hip() else [1])]) def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, mma, kpack, device): @@ -3345,7 +3345,7 @@ def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, mma, kpack def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, out, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, type_a: tl.constexpr, type_b: tl.constexpr): - tl.static_assert(type_b == "e4m3" or type_b == "e5m2", "type_b must be fp8") + tl.static_assert((type_b == "e4m3" or type_b == "e5m2") or type_b == "bf16", "type_b must be fp8 or bf16") IS_FP8: tl.constexpr = type_a == "e4m3" or type_a == "e5m2" DIV_FACTOR: tl.constexpr = 1 if IS_FP8 else 2 PACKED_BLOCK_K_A: tl.constexpr = BLOCK_K // DIV_FACTOR @@ -3436,7 +3436,7 @@ def mxfp_to_bf16_kernel( def dot_scale_ref(x, scale, y, type_x, type_y): e_bits, m_bits = {"e2m1": (2, 1), "e4m3": (4, 3), "e5m2": (5, 2)}[type_x] - type_fp8_y = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2}[type_y] + type_y = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2, "bf16": torch.bfloat16}[type_y] comp_dtype = torch.bfloat16 @@ -3449,7 +3449,7 @@ def dot_scale_ref(x, scale, y, type_x, type_y): mxfp_to_bf16_kernel[grid](x, scale, x_upcast, scale.numel(), e_bits, m_bits, BLOCK_SIZE, num_warps=num_warps) assert x_upcast.isfinite().all() - y_upcast = y.view(type_fp8_y).to(comp_dtype) + y_upcast = y.view(type_y).to(comp_dtype) class AccumulateInFp32: @@ -3461,28 +3461,30 @@ def __exit__(self, exc_type, exc_val, exc_tb): torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = self.prev_value with AccumulateInFp32(): - return torch.matmul(x_upcast.to(comp_dtype), y_upcast.to(comp_dtype)) + return torch.matmul(x_upcast, y_upcast) torch.manual_seed(0) - def create_uint8(shape, col_major=False, max_val=255): + def make_arg(shape, ty, col_major=False, max_val=255): if col_major: shape = shape[:-2] + (shape[-1], shape[-2]) - ret = torch.randint(max_val + 1, shape, dtype=torch.uint8, device=device) + if ty == "bf16": + ret = torch.randn(shape, dtype=torch.bfloat16, device=device) + # Clamp to avoid relative error issues + ret.clamp_(-2**15, 2**15 - 1) + else: + ret = torch.randint(max_val + 1, shape, dtype=torch.uint8, device=device) if col_major: ret = ret.mT return ret DIV_FACTOR = 2 if type_a == "e2m1" else 1 - x = create_uint8((M, K // DIV_FACTOR), col_major=col_a) - y = create_uint8((K, N), col_major=col_b) + x = make_arg((M, K // DIV_FACTOR), type_a, col_major=col_a) + y = make_arg((K, N), type_b, col_major=col_b) # sample scales that don't overflow as otherwise it's implementation defined (underflowing is alright) - # We substract a reasonably high number (64) so that the sum of all the mxfp elements does not overflow - m_bytes = int(type_a[1]) - bias_type_a = 1 << (m_bytes - 1) - 1 - max_exponent_type_a = (1 << m_bytes) - 1 - bias_type_a - scale_x = create_uint8((M, K // 32), max_val=255 - max_exponent_type_a - 64) + # Max scale= 2**15 + scale_x = make_arg((M, K // 32), "e8m0", max_val=127 + 15) def make_finite(x, dtype): # e5m2 has too many non-finite values when sampled uniformly (1 / 32) and @@ -3507,7 +3509,6 @@ def make_finite(x, dtype): z_ref = dot_scale_ref(x, scale_x, y, type_a, type_b) - # generous rtol as we are sampling the whole range of floats torch.testing.assert_close(z, z_ref, atol=1e-5, rtol=1e-2) # make sure ld/st are vectorized diff --git a/python/triton/language/core.py b/python/triton/language/core.py index e2c57b388bb0..856b537c5103 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1555,15 +1555,17 @@ def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, lhs and rhs use microscaling formats described here: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf :param lhs: The first tensor to be multiplied. - :type lhs: 2D tensor of f8, f6 or f4 format packed in int32 format. + :type lhs: 2D tensor representing fp4 or fp8 elements packed into uint8 for fp4 inputs, or in uint8 or the corresponding fp8 type for fp8 inputs. :param lhs_scale: Scale factor for lhs tensor. - :type lhs_scale: ue8m0 float8 type (currently represented as an int8 tensor). - :param lhs_format: format of the lhs tensor, available formats: {:code:`e4m3`, :code: `e5m2`, :code:`e2m3`, :code:`e3m2`, :code:`e2m1`}. + :type lhs_scale: e8m0 type represented as an uint8 tensor. + :param lhs_format: format of the lhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code: `e5m2`}. + :type lhs_format: str :param rhs: The second tensor to be multiplied. - :type rhs: 2D tensor of f8, f6 or f4 format packed in int32 format. + :type rhs: 2D tensor representing fp8 or bf16 elements in uint8 or the corresponding fp8 type for fp8 inputs or bf16 for bf16 inputs. :param rhs_scale: Scale factor for rhs tensor. - :type rhs_scale: ue8m0 float8 type (currently represented as an int8 tensor). - :param rhs_format: format of the rhs tensor, available formats: {:code:`e4m3`, :code: `e5m2`, :code:`e2m3`, :code:`e3m2`, :code:`e2m1`}. + :type rhs_scale: e8m0 type represented as an uint8 tensor. + :param rhs_format: format of the rhs tensor. Available formats: {:code:`e4m3`, :code: `e5m2`, :code:`bf16`}. + :type rhs_format: str :param acc: The accumulator tensor. If not None, the result is added to this tensor. """ out_dtype = _constexpr_to_value(out_dtype) diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index be157c5b4609..a9af8c8d808b 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1527,33 +1527,48 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optiona ret_ty) -def _str_to_fp_type(float_format: Optional[str]): - if float_format == 'e4m3': - return ir.F8F6F4TY.E4M3 - if float_format == 'e5m2': - return ir.F8F6F4TY.E5M2 - if float_format == 'e2m3': - return ir.F8F6F4TY.E2M3 - if float_format == 'e3m2': - return ir.F8F6F4TY.E3M2 - if float_format == 'e2m1': - return ir.F8F6F4TY.E2M1 - raise ValueError(f"Invalid float format: {float_format}.") - - -def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format, rhs: tl.tensor, rhs_scale: Optional[tl.tensor], - rhs_format, acc: tl.tensor | None, out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: +def _str_to_fp_type(float_format: str): + ty_enum = getattr(ir.ScaleDotElemTypeTY, float_format.upper(), None) + if ty_enum is None: + raise ValueError(f"Invalid float format: {float_format}.") + return ty_enum + + +def _bitcast_to_fp_type(val: tl.tensor, float_format: str, builder: ir.builder): + """ + If float_format is subbyte, make sure it's packed as uint8 and return it. + Otherwise, return a tensor (perhaps bitcasting) of the specified float format. + """ + triton_ty = {"e5m2": tl.float8e5, "e4m3": tl.float8e4nv, "bf16": tl.bfloat16}.get(float_format) + if triton_ty is None: + assert float_format == "e2m1", f"Internal Error: Unexpected float format: {float_format}" + assert val.dtype == tl.uint8, f"e2m1 format must be packed as uint8. Got {val.dtype}" + return val + if val.dtype == triton_ty: + return val + else: + unsigned_ty = {"e5m2": tl.uint8, "e4m3": tl.uint8, "bf16": tl.uint16}[float_format] + assert val.dtype == unsigned_ty, f"Unexpected dtype for {float_format}. Got {val.dtype}" + return bitcast(val, triton_ty, builder) + + +def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format: str, rhs: tl.tensor, rhs_scale: Optional[tl.tensor], + rhs_format: str, acc: tl.tensor | None, out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: assert lhs.type.is_block() and rhs.type.is_block() #TODO: validate types. lhs_rank = len(lhs.shape) rhs_rank = len(rhs.shape) assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})" + lhs_format: str = lhs_format.value + rhs_format: str = rhs_format.value lhs_format_enum = _str_to_fp_type(lhs_format) rhs_format_enum = _str_to_fp_type(rhs_format) assert lhs_format in ("e2m1", "e4m3", "e5m2"), f"NYI: lhs_format {lhs_format}" - assert rhs_format in ("e4m3", "e5m2"), f"NYI: rhs_format {rhs_format}" + assert rhs_format in ("e4m3", "e5m2", "bf16"), f"NYI: rhs_format {rhs_format}" rhs_scale_is_none = isinstance(rhs_scale, tl.constexpr) and rhs_scale.value is None assert rhs_scale_is_none, "NYI: rhs_scale not supported" + lhs = _bitcast_to_fp_type(lhs, lhs_format, builder) + rhs = _bitcast_to_fp_type(rhs, rhs_format, builder) M = lhs.type.shape[-2] K, N = rhs.type.shape[-2:] diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir index 85b37f3ed3a9..420a9d5c2cbf 100644 --- a/test/TritonGPU/accelerate-matmul.mlir +++ b/test/TritonGPU/accelerate-matmul.mlir @@ -164,21 +164,21 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // ----- -// Verify that dot_scaled (mxfp8 x fp8) decomposes as expected +// Verify that dot_scaled (mxfp4 x bf16) decomposes as expected #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { // CHECK-LABEL: dot_scaled tt.func @dot_scaled( - %a: tensor<128x64xi8, #blocked2>, + %a: tensor<128x32xi8, #blocked2>, %scale: tensor<128x2xi8, #blocked1>, - %b: tensor<64x128xi8, #blocked>) + %b: tensor<64x128xbf16, #blocked>) -> tensor<128x128xf32, #blocked> { // CHECK: triton_gpu.upcast_mxfp // CHECK: tt.dot %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> - %result = tt.dot_scaled %a, %scale, %b, %cst lhs = e4m3 rhs = e4m3 : tensor<128x64xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xi8, #blocked> -> tensor<128x128xf32, #blocked> + %result = tt.dot_scaled %a, %scale, %b, %cst lhs = e2m1 rhs = bf16 : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xbf16, #blocked> -> tensor<128x128xf32, #blocked> tt.return %result : tensor<128x128xf32, #blocked> } } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp index 289ceb61a51b..b35e28272c05 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp @@ -43,7 +43,7 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { matchAndRewrite(UpcastMXFPOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto fpType = op.getFpType(); - if (!(fpType == F8F6F4Type::E4M3 || fpType == F8F6F4Type::E5M2)) + if (!(fpType == ScaleDotElemType::E4M3 || fpType == ScaleDotElemType::E5M2)) return rewriter.notifyMatchFailure(op, "NYI: non-mxfp8 cases"); Location loc = op.getLoc(); diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp index 3aa009c3639a..201a7b0212fe 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp @@ -504,12 +504,14 @@ class ScaledBlockedToMFMA final : public OpRewritePattern { TensorValue a = dotOp.getLhs(); TensorValue b = dotOp.getRhs(); TensorValue aScale = dotOp.getLhsScale(); - F8F6F4Type aElemType = dotOp.getLhsType(); - F8F6F4Type bElemType = dotOp.getRhsType(); + ScaleDotElemType aElemType = dotOp.getLhsType(); + ScaleDotElemType bElemType = dotOp.getRhsType(); - if (!(aElemType == F8F6F4Type::E4M3 || aElemType == F8F6F4Type::E5M2)) + if (!(aElemType == ScaleDotElemType::E4M3 || + aElemType == ScaleDotElemType::E5M2)) return rewriter.notifyMatchFailure(dotOp, "NYI: non-mxfp8 LHS"); - if (!(bElemType == F8F6F4Type::E4M3 || bElemType == F8F6F4Type::E5M2)) + if (!(bElemType == ScaleDotElemType::E4M3 || + bElemType == ScaleDotElemType::E5M2)) return rewriter.notifyMatchFailure(dotOp, "NYI: non-fp8 RHS"); MLIRContext *ctx = dotOp.getContext(); @@ -553,11 +555,11 @@ class ScaledBlockedToMFMA final : public OpRewritePattern { // OCP mxfp8 requires implementations to follow OCP fp8 elements. We are // doing software emulation using bf16 here, so we map to OCP fp8 f8E4M3FN // and f8E5M2. - auto enumToType = [&rewriter](F8F6F4Type type) { + auto enumToType = [&rewriter](ScaleDotElemType type) { switch (type) { - case F8F6F4Type::E4M3: + case ScaleDotElemType::E4M3: return rewriter.getFloat8E4M3FNType(); - case F8F6F4Type::E5M2: + case ScaleDotElemType::E5M2: return rewriter.getFloat8E5M2Type(); default: llvm_unreachable("unexpected fp type"); @@ -565,8 +567,8 @@ class ScaledBlockedToMFMA final : public OpRewritePattern { }; auto toMMABf16 = [&](TensorValue v, int idx, - F8F6F4Type type) -> TensorValue { - assert(type == F8F6F4Type::E5M2 || type == F8F6F4Type::E4M3); + ScaleDotElemType type) -> TensorValue { + assert(type == ScaleDotElemType::E5M2 || type == ScaleDotElemType::E4M3); auto vType = v.getType(); auto newVEnc = DotOperandEncodingAttr::get( ctx, idx, newRetType.getEncoding(), kWdith); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp index 722bf56cd015..136b69613216 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp @@ -103,7 +103,7 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { Value warpId = udiv(tid, warpSize); Value laneId = urem(tid, warpSize); - if (fpType == F8F6F4Type::E2M1) { + if (fpType == ScaleDotElemType::E2M1) { xVals = unpackFP4Elements(loc, rewriter, xVals, laneId); }