diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp index 813a854f2fc97..98185697e4591 100644 --- a/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp +++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp @@ -46,73 +46,6 @@ lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name, {i32Type, i64Type, i64Type}, symbolTables); } -/// Given two operands of vector type and vector result type (with the same -/// shape), call the given function for each pair of scalar operands and -/// package the result into a vector. If the given operands and result type are -/// not vectors, call the function directly. The second operand is optional. -template -static Value forEachScalarValue(RewriterBase &rewriter, Location loc, - Value operand1, Value operand2, Type resultType, - Fn fn) { - auto vecTy1 = dyn_cast(operand1.getType()); - if (operand2) { - // Sanity check: Operand types must match. - assert(vecTy1 == dyn_cast(operand2.getType()) && - "expected same vector types"); - } - if (!vecTy1) { - // Not a vector. Call the function directly. - return fn(operand1, operand2, resultType); - } - - // Prepare scalar operands. - ResultRange sclars1 = - vector::ToElementsOp::create(rewriter, loc, operand1)->getResults(); - SmallVector scalars2; - if (!operand2) { - // No second operand. Create a vector of empty values. - scalars2.assign(vecTy1.getNumElements(), Value()); - } else { - llvm::append_range( - scalars2, - vector::ToElementsOp::create(rewriter, loc, operand2)->getResults()); - } - - // Call the function for each pair of scalar operands. - auto resultVecType = cast(resultType); - SmallVector results; - for (auto [scalar1, scalar2] : llvm::zip_equal(sclars1, scalars2)) { - Value result = fn(scalar1, scalar2, resultVecType.getElementType()); - results.push_back(result); - } - - // Package the results into a vector. 
- return vector::FromElementsOp::create( - rewriter, loc, - vecTy1.cloneWith(/*shape=*/std::nullopt, results.front().getType()), - results); -} - -/// Check preconditions for the conversion: -/// 1. All operands / results must be integers or floats (or vectors thereof). -/// 2. The bitwidth of the operands / results must be <= 64. -static LogicalResult checkPreconditions(RewriterBase &rewriter, Operation *op) { - for (Value value : llvm::concat(op->getOperands(), op->getResults())) { - Type type = value.getType(); - if (auto vecTy = dyn_cast(type)) { - type = vecTy.getElementType(); - } - if (!type.isIntOrFloat()) { - return rewriter.notifyMatchFailure( - op, "only integers and floats (or vectors thereof) are supported"); - } - if (type.getIntOrFloatBitWidth() > 64) - return rewriter.notifyMatchFailure(op, - "bitwidth > 64 bits is not supported"); - } - return success(); -} - /// Rewrite a binary arithmetic operation to an APFloat function call. template struct BinaryArithOpToAPFloatConversion final : OpRewritePattern { diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp index 784028f5cf2eb..af4a42aa308b3 100644 --- a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp +++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp @@ -33,16 +33,8 @@ struct AbsFOpToAPFloatConversion final : OpRewritePattern { LogicalResult matchAndRewrite(math::AbsFOp op, PatternRewriter &rewriter) const override { - // Cast operands to 64-bit integers. - auto operand = op.getOperand(); - auto floatTy = dyn_cast(operand.getType()); - if (!floatTy) - return rewriter.notifyMatchFailure(op, - "only scalar FloatTypes supported"); - if (floatTy.getIntOrFloatBitWidth() > 64) { - return rewriter.notifyMatchFailure(op, - "bitwidth > 64 bits is not supported"); - } + if (failed(checkPreconditions(rewriter, op))) + return failure(); // Get APFloat function from runtime library. 
auto i32Type = IntegerType::get(symTable->getContext(), 32); auto i64Type = IntegerType::get(symTable->getContext(), 64); @@ -52,23 +44,30 @@ struct AbsFOpToAPFloatConversion final : OpRewritePattern { return fn; Location loc = op.getLoc(); rewriter.setInsertionPoint(op); - auto intWType = rewriter.getIntegerType(floatTy.getWidth()); - Value operandBits = arith::ExtUIOp::create( - rewriter, loc, i64Type, - arith::BitcastOp::create(rewriter, loc, intWType, operand)); - - // Call APFloat function. - Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy); - SmallVector params = {semValue, operandBits}; - Value negatedBits = func::CallOp::create(rewriter, loc, TypeRange(i64Type), - SymbolRefAttr::get(*fn), params) - ->getResult(0); - - // Truncate result to the original width. - Value truncatedBits = - arith::TruncIOp::create(rewriter, loc, intWType, negatedBits); - rewriter.replaceOp( - op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits)); + // Scalarize and convert to APFloat runtime calls. + Value repl = forEachScalarValue( + rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(), + [&](Value operand, Value, Type resultType) { + auto floatTy = cast(operand.getType()); + auto intWType = rewriter.getIntegerType(floatTy.getWidth()); + Value operandBits = arith::ExtUIOp::create( + rewriter, loc, i64Type, + arith::BitcastOp::create(rewriter, loc, intWType, operand)); + // Call APFloat function. + Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy); + SmallVector params = {semValue, operandBits}; + Value negatedBits = + func::CallOp::create(rewriter, loc, TypeRange(i64Type), + SymbolRefAttr::get(*fn), params) + ->getResult(0); + // Truncate result to the original width. 
+ auto truncatedBits = + arith::TruncIOp::create(rewriter, loc, intWType, negatedBits); + return arith::BitcastOp::create(rewriter, loc, floatTy, + truncatedBits); + }); + + rewriter.replaceOp(op, repl); return success(); } @@ -85,16 +84,8 @@ struct IsOpToAPFloatConversion final : OpRewritePattern { LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { - // Cast operands to 64-bit integers. - auto operand = op.getOperand(); - auto floatTy = dyn_cast(operand.getType()); - if (!floatTy) - return rewriter.notifyMatchFailure(op, - "only scalar FloatTypes supported"); - if (floatTy.getIntOrFloatBitWidth() > 64) { - return rewriter.notifyMatchFailure(op, - "bitwidth > 64 bits is not supported"); - } + if (failed(checkPreconditions(rewriter, op))) + return failure(); // Get APFloat function from runtime library. auto i1 = IntegerType::get(symTable->getContext(), 1); auto i32Type = IntegerType::get(symTable->getContext(), 32); @@ -107,16 +98,24 @@ struct IsOpToAPFloatConversion final : OpRewritePattern { return fn; Location loc = op.getLoc(); rewriter.setInsertionPoint(op); - auto intWType = rewriter.getIntegerType(floatTy.getWidth()); - Value operandBits = arith::ExtUIOp::create( - rewriter, loc, i64Type, - arith::BitcastOp::create(rewriter, loc, intWType, operand)); - - // Call APFloat function. - Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy); - SmallVector params = {semValue, operandBits}; - rewriter.replaceOpWithNewOp(op, TypeRange(i1), - SymbolRefAttr::get(*fn), params); + // Scalarize and convert to APFloat runtime calls. 
+ Value repl = forEachScalarValue( + rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(), + [&](Value operand, Value, Type resultType) { + auto floatTy = cast(operand.getType()); + auto intWType = rewriter.getIntegerType(floatTy.getWidth()); + Value operandBits = arith::ExtUIOp::create( + rewriter, loc, i64Type, + arith::BitcastOp::create(rewriter, loc, intWType, operand)); + + // Call APFloat function. + Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy); + Value params[] = {semValue, operandBits}; + return func::CallOp::create(rewriter, loc, TypeRange(i1), + SymbolRefAttr::get(*fn), params) + .getResult(0); + }); + rewriter.replaceOp(op, repl); return success(); } @@ -131,16 +130,15 @@ struct FmaOpToAPFloatConversion final : OpRewritePattern { LogicalResult matchAndRewrite(math::FmaOp op, PatternRewriter &rewriter) const override { + if (failed(checkPreconditions(rewriter, op))) + return failure(); // Cast operands to 64-bit integers. - auto floatTy = cast(op.getResult().getType()); - if (!floatTy) - return rewriter.notifyMatchFailure(op, - "only scalar FloatTypes supported"); - if (floatTy.getIntOrFloatBitWidth() > 64) { - return rewriter.notifyMatchFailure(op, - "bitwidth > 64 bits is not supported"); + mlir::Type resType = op.getResult().getType(); + auto floatTy = dyn_cast(resType); + if (!floatTy) { + auto vecTy1 = cast(resType); + floatTy = llvm::cast(vecTy1.getElementType()); } - auto i32Type = IntegerType::get(symTable->getContext(), 32); auto i64Type = IntegerType::get(symTable->getContext(), 64); FailureOr fn = lookupOrCreateFnDecl( @@ -151,29 +149,63 @@ struct FmaOpToAPFloatConversion final : OpRewritePattern { Location loc = op.getLoc(); rewriter.setInsertionPoint(op); - auto intWType = rewriter.getIntegerType(floatTy.getWidth()); - auto int64Type = rewriter.getI64Type(); - Value operand = arith::ExtUIOp::create( - rewriter, loc, int64Type, - arith::BitcastOp::create(rewriter, loc, intWType, op.getA())); - Value 
multiplicand = arith::ExtUIOp::create( - rewriter, loc, int64Type, - arith::BitcastOp::create(rewriter, loc, intWType, op.getB())); - Value addend = arith::ExtUIOp::create( - rewriter, loc, int64Type, - arith::BitcastOp::create(rewriter, loc, intWType, op.getC())); - - // Call APFloat function. - Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy); - SmallVector params = {semValue, operand, multiplicand, addend}; - auto resultOp = - func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()), - SymbolRefAttr::get(*fn), params); - - // Truncate result to the original width. - Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType, - resultOp->getResult(0)); - rewriter.replaceOpWithNewOp(op, floatTy, truncatedBits); + IntegerType intWType = rewriter.getIntegerType(floatTy.getWidth()); + IntegerType int64Type = rewriter.getI64Type(); + + auto scalarFMA = [&rewriter, &loc, &floatTy, &fn, &intWType, + &int64Type](Value a, Value b, Value c) { + Value operand = arith::ExtUIOp::create( + rewriter, loc, int64Type, + arith::BitcastOp::create(rewriter, loc, intWType, a)); + Value multiplicand = arith::ExtUIOp::create( + rewriter, loc, int64Type, + arith::BitcastOp::create(rewriter, loc, intWType, b)); + Value addend = arith::ExtUIOp::create( + rewriter, loc, int64Type, + arith::BitcastOp::create(rewriter, loc, intWType, c)); + // Call APFloat function. + Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy); + SmallVector params = {semValue, operand, multiplicand, addend}; + auto resultOp = + func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()), + SymbolRefAttr::get(*fn), params); + + // Truncate result to the original width. + auto trunc = arith::TruncIOp::create(rewriter, loc, intWType, + resultOp->getResult(0)); + return arith::BitcastOp::create(rewriter, loc, floatTy, trunc); + }; + + if (auto vecTy1 = dyn_cast(op.getA().getType())) { + // Sanity check: Operand types must match. 
+ assert(vecTy1 == dyn_cast(op.getB().getType()) && + "expected same vector types"); + assert(vecTy1 == dyn_cast(op.getC().getType()) && + "expected same vector types"); + // Prepare scalar operands. + ResultRange scalarOperands = + vector::ToElementsOp::create(rewriter, loc, op.getA())->getResults(); + ResultRange scalarMultiplicands = + vector::ToElementsOp::create(rewriter, loc, op.getB())->getResults(); + ResultRange scalarAddends = + vector::ToElementsOp::create(rewriter, loc, op.getC())->getResults(); + // Call the function for each triple of scalar operands. + SmallVector results; + for (auto [operand, multiplicand, addend] : llvm::zip_equal( + scalarOperands, scalarMultiplicands, scalarAddends)) { + results.push_back(scalarFMA(operand, multiplicand, addend)); + } + // Package the results into a vector. + auto fromElements = vector::FromElementsOp::create( + rewriter, loc, + vecTy1.cloneWith(/*shape=*/std::nullopt, results.front().getType()), + results); + rewriter.replaceOp(op, fromElements); + return success(); + } + + Value repl = scalarFMA(op.getA(), op.getB(), op.getC()); + rewriter.replaceOp(op, repl); return success(); } diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp index 2b5857367dc40..01f55a8da15a3 100644 --- a/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp +++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp @@ -9,14 +9,77 @@ #include "Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/Location.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" -mlir::Value mlir::getAPFloatSemanticsValue(OpBuilder &b, Location loc, - FloatType floatTy) { +using namespace mlir; + +Value mlir::getAPFloatSemanticsValue(OpBuilder &b, Location loc, + FloatType floatTy) { int32_t sem = 
llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics()); return arith::ConstantOp::create(b, loc, b.getI32Type(), b.getIntegerAttr(b.getI32Type(), sem)); } + +Value mlir::forEachScalarValue( + mlir::RewriterBase &rewriter, Location loc, Value operand1, Value operand2, + Type resultType, llvm::function_ref fn) { + auto vecTy1 = dyn_cast(operand1.getType()); + if (operand2) { + // Sanity check: Operand types must match. + assert(vecTy1 == dyn_cast(operand2.getType()) && + "expected same vector types"); + } + if (!vecTy1) { + // Not a vector. Call the function directly. + return fn(operand1, operand2, resultType); + } + + // Prepare scalar operands. + ResultRange sclars1 = + vector::ToElementsOp::create(rewriter, loc, operand1)->getResults(); + SmallVector scalars2; + if (!operand2) { + // No second operand. Create a vector of empty values. + scalars2.assign(vecTy1.getNumElements(), Value()); + } else { + llvm::append_range( + scalars2, + vector::ToElementsOp::create(rewriter, loc, operand2)->getResults()); + } + + // Call the function for each pair of scalar operands. + auto resultVecType = cast(resultType); + SmallVector results; + for (auto [scalar1, scalar2] : llvm::zip_equal(sclars1, scalars2)) { + Value result = fn(scalar1, scalar2, resultVecType.getElementType()); + results.push_back(result); + } + + // Package the results into a vector. 
+ return vector::FromElementsOp::create( + rewriter, loc, + vecTy1.cloneWith(/*shape=*/std::nullopt, results.front().getType()), + results); +} + +LogicalResult mlir::checkPreconditions(RewriterBase &rewriter, Operation *op) { + for (Value value : llvm::concat(op->getOperands(), op->getResults())) { + Type type = value.getType(); + if (auto vecTy = dyn_cast(type)) { + type = vecTy.getElementType(); + } + if (!type.isIntOrFloat()) { + return rewriter.notifyMatchFailure( + op, "only integers and floats (or vectors thereof) are supported"); + } + if (type.getIntOrFloatBitWidth() > 64) + return rewriter.notifyMatchFailure(op, + "bitwidth > 64 bits is not supported"); + } + return success(); +} diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h index 5f11d24261b43..dfadf9449b497 100644 --- a/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h +++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h @@ -9,6 +9,9 @@ #ifndef MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_UTILS_H_ #define MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_UTILS_H_ +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/PatternMatch.h" + namespace mlir { class Value; class OpBuilder; @@ -16,6 +19,20 @@ class Location; class FloatType; Value getAPFloatSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy); + +/// Given two operands of vector type and vector result type (with the same +/// shape), call the given function for each pair of scalar operands and +/// package the result into a vector. If the given operands and result type are +/// not vectors, call the function directly. The second operand is optional. +Value forEachScalarValue(mlir::RewriterBase &rewriter, Location loc, + Value operand1, Value operand2, Type resultType, + llvm::function_ref fn); + +/// Check preconditions for the conversion: +/// 1. All operands / results must be integers or floats (or vectors thereof). +/// 2. 
The bitwidth of the operands / results must be <= 64. +LogicalResult checkPreconditions(RewriterBase &rewriter, Operation *op); + } // namespace mlir #endif // MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_UTILS_H_ diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation-vector.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation-vector.mlir index dfd9e7c4aaa14..cc773e60dda3e 100644 --- a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation-vector.mlir +++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation-vector.mlir @@ -1,6 +1,4 @@ // REQUIRES: system-linux || system-darwin -// TODO: Run only on Linux until we figure out how to build -// mlir_apfloat_wrappers in a platform-independent way. // All floating-point arithmetics is lowered through APFloat. // RUN: mlir-opt %s --convert-arith-to-apfloat --convert-vector-to-scf \ diff --git a/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation-vector.mlir b/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation-vector.mlir new file mode 100644 index 0000000000000..c0b2d858c1fec --- /dev/null +++ b/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation-vector.mlir @@ -0,0 +1,41 @@ +// REQUIRES: system-linux || system-darwin + +// All floating-point arithmetic is lowered through APFloat. 
+// RUN: mlir-opt %s --convert-math-to-apfloat --convert-vector-to-scf \ +// RUN: --convert-scf-to-cf --convert-to-llvm | \ +// RUN: mlir-runner -e entry --entry-point-result=void \ +// RUN: --shared-libs=%mlir_c_runner_utils \ +// RUN: --shared-libs=%mlir_apfloat_wrappers | FileCheck %s + +func.func @entry() { + + %neg14fp8 = arith.constant dense<[-1.4, -1.4, -1.4, -1.4]> : vector<4xf8E4M3FN> + %absfp8 = math.absf %neg14fp8 : vector<4xf8E4M3FN> + // CHECK: ( 1.375, 1.375, 1.375, 1.375 ) + vector.print %absfp8 : vector<4xf8E4M3FN> + + %a1_vec = arith.constant dense<[2.0, 2.0, 2.0, 2.0]> : vector<4xf8E4M3FN> + %b1_vec = arith.constant dense<[4.0, 4.0, 4.0, 4.0]> : vector<4xf8E4M3FN> + %c1_vec = arith.constant dense<[8.0, 8.0, 8.0, 8.0]> : vector<4xf8E4M3FN> + %d1_vec = math.fma %a1_vec, %b1_vec, %c1_vec : vector<4xf8E4M3FN> // not supported by LLVM + // CHECK: ( 16, 16, 16, 16 ) + vector.print %d1_vec : vector<4xf8E4M3FN> + + // CHECK: ( 0, 0, 0, 0 ) + %isinffp8 = math.isinf %neg14fp8 : vector<4xf8E4M3FN> + vector.print %isinffp8 : vector<4xi1> + + %isnanfp8 = math.isnan %neg14fp8 : vector<4xf8E4M3FN> + // CHECK: ( 0, 0, 0, 0 ) + vector.print %isnanfp8 : vector<4xi1> + + %isnormalfp8 = math.isnormal %neg14fp8 : vector<4xf8E4M3FN> + // CHECK: ( 1, 1, 1, 1 ) + vector.print %isnormalfp8 : vector<4xi1> + + %isfinitefp8 = math.isfinite %neg14fp8 : vector<4xf8E4M3FN> + // CHECK: ( 1, 1, 1, 1 ) + vector.print %isfinitefp8 : vector<4xi1> + + return +} diff --git a/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir index c890b470b563a..0cc3d3f2218f0 100644 --- a/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir +++ b/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir @@ -24,15 +24,18 @@ func.func @entry() { // CHECK: 16 vector.print %fmafp8 : f8E8M0FNU - // CHECK: 0 %isinffp8 = math.isinf %neg14fp8 : f8E4M3FN - vector.print %isinffp8 : i1 // CHECK: 0 + 
vector.print %isinffp8 : i1 + %isnanfp8 = math.isnan %neg14fp8 : f8E4M3FN + // CHECK: 0 vector.print %isnanfp8 : i1 + %isnormalfp8 = math.isnormal %neg14fp8 : f8E4M3FN // CHECK: 1 vector.print %isnormalfp8 : i1 + %isfinitefp8 = math.isfinite %neg14fp8 : f8E4M3FN // CHECK: 1 vector.print %isfinitefp8 : i1 @@ -51,15 +54,18 @@ func.func @entry() { // CHECK: 16 vector.print %fmafp32 : f32 - // CHECK: 0 %isinffp32 = math.isinf %neg14fp32 : f32 - vector.print %isinffp32 : i1 // CHECK: 0 + vector.print %isinffp32 : i1 + %isnanfp32 = math.isnan %neg14fp32 : f32 + // CHECK: 0 vector.print %isnanfp32 : i1 + %isnormalfp32 = math.isnormal %neg14fp32 : f32 // CHECK: 1 vector.print %isnormalfp32 : i1 + %isfinitefp32 = math.isfinite %neg14fp32 : f32 // CHECK: 1 vector.print %isfinitefp32 : i1