Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 0 additions & 67 deletions mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,73 +46,6 @@ lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
{i32Type, i64Type, i64Type}, symbolTables);
}

/// Given two operands of vector type and vector result type (with the same
/// shape), call the given function for each pair of scalar operands and
/// package the result into a vector. If the given operands and result type are
/// not vectors, call the function directly. The second operand is optional.
template <typename Fn, typename... Values>
static Value forEachScalarValue(RewriterBase &rewriter, Location loc,
Value operand1, Value operand2, Type resultType,
Fn fn) {
auto vecTy1 = dyn_cast<VectorType>(operand1.getType());
if (operand2) {
// Sanity check: Operand types must match.
assert(vecTy1 == dyn_cast<VectorType>(operand2.getType()) &&
"expected same vector types");
}
if (!vecTy1) {
// Not a vector. Call the function directly.
return fn(operand1, operand2, resultType);
}

// Prepare scalar operands.
ResultRange sclars1 =
vector::ToElementsOp::create(rewriter, loc, operand1)->getResults();
SmallVector<Value> scalars2;
if (!operand2) {
// No second operand. Create a vector of empty values.
scalars2.assign(vecTy1.getNumElements(), Value());
} else {
llvm::append_range(
scalars2,
vector::ToElementsOp::create(rewriter, loc, operand2)->getResults());
}

// Call the function for each pair of scalar operands.
auto resultVecType = cast<VectorType>(resultType);
SmallVector<Value> results;
for (auto [scalar1, scalar2] : llvm::zip_equal(sclars1, scalars2)) {
Value result = fn(scalar1, scalar2, resultVecType.getElementType());
results.push_back(result);
}

// Package the results into a vector.
return vector::FromElementsOp::create(
rewriter, loc,
vecTy1.cloneWith(/*shape=*/std::nullopt, results.front().getType()),
results);
}

/// Check preconditions for the conversion:
/// 1. All operands / results must be integers or floats (or fixed-size vectors
///    thereof).
/// 2. The bitwidth of the operands / results must be <= 64.
///
/// Returns success if every operand and result of `op` satisfies the above.
/// Otherwise reports the reason through `rewriter.notifyMatchFailure` and
/// returns its failure result.
static LogicalResult checkPreconditions(RewriterBase &rewriter, Operation *op) {
  for (Value value : llvm::concat<Value>(op->getOperands(), op->getResults())) {
    Type type = value.getType();
    if (auto vecTy = dyn_cast<VectorType>(type)) {
      // Vectors are lowered element by element (see forEachScalarValue), which
      // needs a statically known number of elements. Scalable vectors do not
      // have one, so reject them up front.
      if (vecTy.isScalable())
        return rewriter.notifyMatchFailure(
            op, "scalable vectors are not supported");
      type = vecTy.getElementType();
    }
    if (!type.isIntOrFloat()) {
      return rewriter.notifyMatchFailure(
          op, "only integers and floats (or vectors thereof) are supported");
    }
    if (type.getIntOrFloatBitWidth() > 64)
      return rewriter.notifyMatchFailure(op,
                                         "bitwidth > 64 bits is not supported");
  }
  return success();
}

/// Rewrite a binary arithmetic operation to an APFloat function call.
template <typename OpTy>
struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
Expand Down
188 changes: 110 additions & 78 deletions mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,8 @@ struct AbsFOpToAPFloatConversion final : OpRewritePattern<math::AbsFOp> {

LogicalResult matchAndRewrite(math::AbsFOp op,
PatternRewriter &rewriter) const override {
// Cast operands to 64-bit integers.
auto operand = op.getOperand();
auto floatTy = dyn_cast<FloatType>(operand.getType());
if (!floatTy)
return rewriter.notifyMatchFailure(op,
"only scalar FloatTypes supported");
if (floatTy.getIntOrFloatBitWidth() > 64) {
return rewriter.notifyMatchFailure(op,
"bitwidth > 64 bits is not supported");
}
if (failed(checkPreconditions(rewriter, op)))
return failure();
// Get APFloat function from runtime library.
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
Expand All @@ -52,23 +44,30 @@ struct AbsFOpToAPFloatConversion final : OpRewritePattern<math::AbsFOp> {
return fn;
Location loc = op.getLoc();
rewriter.setInsertionPoint(op);
auto intWType = rewriter.getIntegerType(floatTy.getWidth());
Value operandBits = arith::ExtUIOp::create(
rewriter, loc, i64Type,
arith::BitcastOp::create(rewriter, loc, intWType, operand));

// Call APFloat function.
Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
SmallVector<Value> params = {semValue, operandBits};
Value negatedBits = func::CallOp::create(rewriter, loc, TypeRange(i64Type),
SymbolRefAttr::get(*fn), params)
->getResult(0);

// Truncate result to the original width.
Value truncatedBits =
arith::TruncIOp::create(rewriter, loc, intWType, negatedBits);
rewriter.replaceOp(
op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits));
// Scalarize and convert to APFloat runtime calls.
Value repl = forEachScalarValue(
rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
[&](Value operand, Value, Type resultType) {
auto floatTy = cast<FloatType>(operand.getType());
auto intWType = rewriter.getIntegerType(floatTy.getWidth());
Value operandBits = arith::ExtUIOp::create(
rewriter, loc, i64Type,
arith::BitcastOp::create(rewriter, loc, intWType, operand));
// Call APFloat function.
Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
SmallVector<Value> params = {semValue, operandBits};
Value negatedBits =
func::CallOp::create(rewriter, loc, TypeRange(i64Type),
SymbolRefAttr::get(*fn), params)
->getResult(0);
// Truncate result to the original width.
auto truncatedBits =
arith::TruncIOp::create(rewriter, loc, intWType, negatedBits);
return arith::BitcastOp::create(rewriter, loc, floatTy,
truncatedBits);
});

rewriter.replaceOp(op, repl);
return success();
}

Expand All @@ -85,16 +84,8 @@ struct IsOpToAPFloatConversion final : OpRewritePattern<OpTy> {

LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
// Cast operands to 64-bit integers.
auto operand = op.getOperand();
auto floatTy = dyn_cast<FloatType>(operand.getType());
if (!floatTy)
return rewriter.notifyMatchFailure(op,
"only scalar FloatTypes supported");
if (floatTy.getIntOrFloatBitWidth() > 64) {
return rewriter.notifyMatchFailure(op,
"bitwidth > 64 bits is not supported");
}
if (failed(checkPreconditions(rewriter, op)))
return failure();
// Get APFloat function from runtime library.
auto i1 = IntegerType::get(symTable->getContext(), 1);
auto i32Type = IntegerType::get(symTable->getContext(), 32);
Expand All @@ -107,16 +98,24 @@ struct IsOpToAPFloatConversion final : OpRewritePattern<OpTy> {
return fn;
Location loc = op.getLoc();
rewriter.setInsertionPoint(op);
auto intWType = rewriter.getIntegerType(floatTy.getWidth());
Value operandBits = arith::ExtUIOp::create(
rewriter, loc, i64Type,
arith::BitcastOp::create(rewriter, loc, intWType, operand));

// Call APFloat function.
Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
SmallVector<Value> params = {semValue, operandBits};
rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(i1),
SymbolRefAttr::get(*fn), params);
// Scalarize and convert to APFloat runtime calls.
Value repl = forEachScalarValue(
rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
[&](Value operand, Value, Type resultType) {
auto floatTy = cast<FloatType>(operand.getType());
auto intWType = rewriter.getIntegerType(floatTy.getWidth());
Value operandBits = arith::ExtUIOp::create(
rewriter, loc, i64Type,
arith::BitcastOp::create(rewriter, loc, intWType, operand));

// Call APFloat function.
Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
Value params[] = {semValue, operandBits};
return func::CallOp::create(rewriter, loc, TypeRange(i1),
SymbolRefAttr::get(*fn), params)
.getResult(0);
});
rewriter.replaceOp(op, repl);
return success();
}

Expand All @@ -131,16 +130,15 @@ struct FmaOpToAPFloatConversion final : OpRewritePattern<math::FmaOp> {

LogicalResult matchAndRewrite(math::FmaOp op,
PatternRewriter &rewriter) const override {
if (failed(checkPreconditions(rewriter, op)))
return failure();
// Cast operands to 64-bit integers.
auto floatTy = cast<FloatType>(op.getResult().getType());
if (!floatTy)
return rewriter.notifyMatchFailure(op,
"only scalar FloatTypes supported");
if (floatTy.getIntOrFloatBitWidth() > 64) {
return rewriter.notifyMatchFailure(op,
"bitwidth > 64 bits is not supported");
mlir::Type resType = op.getResult().getType();
auto floatTy = dyn_cast<FloatType>(resType);
if (!floatTy) {
auto vecTy1 = cast<VectorType>(resType);
floatTy = llvm::cast<FloatType>(vecTy1.getElementType());
}

auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
Expand All @@ -151,29 +149,63 @@ struct FmaOpToAPFloatConversion final : OpRewritePattern<math::FmaOp> {
Location loc = op.getLoc();
rewriter.setInsertionPoint(op);

auto intWType = rewriter.getIntegerType(floatTy.getWidth());
auto int64Type = rewriter.getI64Type();
Value operand = arith::ExtUIOp::create(
rewriter, loc, int64Type,
arith::BitcastOp::create(rewriter, loc, intWType, op.getA()));
Value multiplicand = arith::ExtUIOp::create(
rewriter, loc, int64Type,
arith::BitcastOp::create(rewriter, loc, intWType, op.getB()));
Value addend = arith::ExtUIOp::create(
rewriter, loc, int64Type,
arith::BitcastOp::create(rewriter, loc, intWType, op.getC()));

// Call APFloat function.
Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
SmallVector<Value> params = {semValue, operand, multiplicand, addend};
auto resultOp =
func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
SymbolRefAttr::get(*fn), params);

// Truncate result to the original width.
Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
resultOp->getResult(0));
rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, floatTy, truncatedBits);
IntegerType intWType = rewriter.getIntegerType(floatTy.getWidth());
IntegerType int64Type = rewriter.getI64Type();

auto scalarFMA = [&rewriter, &loc, &floatTy, &fn, &intWType,
&int64Type](Value a, Value b, Value c) {
Value operand = arith::ExtUIOp::create(
rewriter, loc, int64Type,
arith::BitcastOp::create(rewriter, loc, intWType, a));
Value multiplicand = arith::ExtUIOp::create(
rewriter, loc, int64Type,
arith::BitcastOp::create(rewriter, loc, intWType, b));
Value addend = arith::ExtUIOp::create(
rewriter, loc, int64Type,
arith::BitcastOp::create(rewriter, loc, intWType, c));
// Call APFloat function.
Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
SmallVector<Value> params = {semValue, operand, multiplicand, addend};
auto resultOp =
func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
SymbolRefAttr::get(*fn), params);

// Truncate result to the original width.
auto trunc = arith::TruncIOp::create(rewriter, loc, intWType,
resultOp->getResult(0));
return arith::BitcastOp::create(rewriter, loc, floatTy, trunc);
};

if (auto vecTy1 = dyn_cast<VectorType>(op.getA().getType())) {
// Sanity check: Operand types must match.
assert(vecTy1 == dyn_cast<VectorType>(op.getB().getType()) &&
"expected same vector types");
assert(vecTy1 == dyn_cast<VectorType>(op.getC().getType()) &&
"expected same vector types");
// Prepare scalar operands.
ResultRange scalarOperands =
vector::ToElementsOp::create(rewriter, loc, op.getA())->getResults();
ResultRange scalarMultiplicands =
vector::ToElementsOp::create(rewriter, loc, op.getB())->getResults();
ResultRange scalarAddends =
vector::ToElementsOp::create(rewriter, loc, op.getC())->getResults();
// Call the function for each pair of scalar operands.
SmallVector<Value> results;
for (auto [operand, multiplicand, addend] : llvm::zip_equal(
scalarOperands, scalarMultiplicands, scalarAddends)) {
results.push_back(scalarFMA(operand, multiplicand, addend));
}
// Package the results into a vector.
auto fromElements = vector::FromElementsOp::create(
rewriter, loc,
vecTy1.cloneWith(/*shape=*/std::nullopt, results.front().getType()),
results);
rewriter.replaceOp(op, fromElements);
return success();
}

Value repl = scalarFMA(op.getA(), op.getB(), op.getC());
rewriter.replaceOp(op, repl);
return success();
}

Expand Down
67 changes: 65 additions & 2 deletions mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,77 @@
#include "Utils.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"

mlir::Value mlir::getAPFloatSemanticsValue(OpBuilder &b, Location loc,
FloatType floatTy) {
using namespace mlir;

/// Build an `arith.constant` of type i32 that holds the enum value of the
/// APFloat semantics of `floatTy`, created with builder `b` at location `loc`.
Value mlir::getAPFloatSemanticsValue(OpBuilder &b, Location loc,
                                     FloatType floatTy) {
  // Map the float semantics of the type to the integer enum used by APFloat.
  const int32_t semanticsEnum =
      llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
  auto i32Ty = b.getI32Type();
  auto semanticsAttr = b.getIntegerAttr(i32Ty, semanticsEnum);
  return arith::ConstantOp::create(b, loc, i32Ty, semanticsAttr);
}

/// Apply `fn` to the given operands, element by element if they are vectors.
///
/// If `operand1` is not a vector, `fn(operand1, operand2, resultType)` is
/// called exactly once and its result is returned as is. Otherwise the operands
/// are split into scalars with `vector.to_elements`, `fn` is called once per
/// pair of scalars with the element type of `resultType`, and the per-element
/// results are packaged into a vector with `vector.from_elements`.
///
/// The second operand is optional: if `operand2` is a null `Value`, `fn`
/// receives a null `Value` as its second argument for every element. If it is
/// present, it must have the same vector type as `operand1` (or be a
/// non-vector if `operand1` is a non-vector); this is only asserted.
///
/// When `operand1` is a vector, `resultType` must be a vector type as well.
Value mlir::forEachScalarValue(
    mlir::RewriterBase &rewriter, Location loc, Value operand1, Value operand2,
    Type resultType, llvm::function_ref<Value(Value, Value, Type)> fn) {
  auto vecTy1 = dyn_cast<VectorType>(operand1.getType());
  if (operand2) {
    // Sanity check: Operand types must match.
    assert(vecTy1 == dyn_cast<VectorType>(operand2.getType()) &&
           "expected same vector types");
  }
  if (!vecTy1) {
    // Not a vector. Call the function directly.
    return fn(operand1, operand2, resultType);
  }

  // Prepare scalar operands.
  ResultRange scalars1 =
      vector::ToElementsOp::create(rewriter, loc, operand1)->getResults();
  SmallVector<Value> scalars2;
  if (!operand2) {
    // No second operand. Pair every element of the first operand with an
    // empty value so that both ranges can be zipped below.
    scalars2.assign(vecTy1.getNumElements(), Value());
  } else {
    llvm::append_range(
        scalars2,
        vector::ToElementsOp::create(rewriter, loc, operand2)->getResults());
  }

  // Call the function for each pair of scalar operands.
  auto resultVecType = cast<VectorType>(resultType);
  SmallVector<Value> results;
  results.reserve(scalars2.size());
  for (auto [scalar1, scalar2] : llvm::zip_equal(scalars1, scalars2))
    results.push_back(fn(scalar1, scalar2, resultVecType.getElementType()));

  // Package the results into a vector. The result keeps the shape of the
  // first operand and takes its element type from what `fn` produced.
  return vector::FromElementsOp::create(
      rewriter, loc,
      vecTy1.cloneWith(/*shape=*/std::nullopt, results.front().getType()),
      results);
}

/// Check preconditions for the conversion:
/// 1. All operands / results must be integers or floats (or fixed-size vectors
///    thereof).
/// 2. The bitwidth of the operands / results must be <= 64.
///
/// Returns success if every operand and result of `op` satisfies the above.
/// Otherwise reports the reason through `rewriter.notifyMatchFailure` and
/// returns its failure result.
LogicalResult mlir::checkPreconditions(RewriterBase &rewriter, Operation *op) {
  for (Value value : llvm::concat<Value>(op->getOperands(), op->getResults())) {
    Type type = value.getType();
    if (auto vecTy = dyn_cast<VectorType>(type)) {
      // Vectors are lowered element by element (see forEachScalarValue), which
      // needs a statically known number of elements. Scalable vectors do not
      // have one, so reject them up front.
      if (vecTy.isScalable())
        return rewriter.notifyMatchFailure(
            op, "scalable vectors are not supported");
      type = vecTy.getElementType();
    }
    if (!type.isIntOrFloat()) {
      return rewriter.notifyMatchFailure(
          op, "only integers and floats (or vectors thereof) are supported");
    }
    if (type.getIntOrFloatBitWidth() > 64)
      return rewriter.notifyMatchFailure(op,
                                         "bitwidth > 64 bits is not supported");
  }
  return success();
}
Loading