Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 82 additions & 9 deletions polygeist/tools/cgeist/Lib/CGExpr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ using namespace mlir;

extern llvm::cl::opt<bool> GenerateAllSYCLFuncs;

// Command-line switch `-omit-fp-contract` (off by default). When it is set,
// tryEmitFMulAdd() returns early and no fused multiply-add is formed, whatever
// the FP-contraction setting of the expression is.
static llvm::cl::opt<bool>
OmitFPContract("omit-fp-contract", llvm::cl::init(false),
llvm::cl::desc("Do not contract FP operations"));

ValueCategory
MLIRScanner::VisitExtVectorElementExpr(clang::ExtVectorElementExpr *Expr) {
auto Base = Visit(Expr->getBase());
Expand Down Expand Up @@ -2101,23 +2105,79 @@ ValueCategory MLIRScanner::VisitBinAssign(BinaryOperator *E) {
// Immutable bundle describing one binary operation to be emitted: both
// (already emitted) operands, the computation type, the opcode, the
// floating-point options in effect and the originating expression.
class BinOpInfo {
public:
// NOTE(review): this span comes from a diff view without +/- markers; the
// next two lines look like the pre-change constructor (no FPFeatures) and the
// four lines after them like its replacement -- confirm against the real file.
BinOpInfo(ValueCategory LHS, ValueCategory RHS, QualType Ty,
BinaryOperator::Opcode Opcode, const Expr *Expr)
: LHS(LHS), RHS(RHS), Ty(Ty), Opcode(Opcode), E(Expr) {}
BinaryOperator::Opcode Opcode, FPOptions FPFeatures,
const Expr *Expr)
: LHS(LHS), RHS(RHS), Ty(Ty), Opcode(Opcode), FPFeatures(FPFeatures),
E(Expr) {}

// Plain accessors; each returns the stored member by value.
ValueCategory getLHS() const { return LHS; }
ValueCategory getRHS() const { return RHS; }
constexpr QualType getType() const { return Ty; }
constexpr BinaryOperator::Opcode getOpcode() const { return Opcode; }
// Queried by tryEmitFMulAdd() via allowFPContractWithinStatement().
FPOptions getFPFeatures() const { return FPFeatures; }
constexpr const Expr *getExpr() const { return E; }

private:
const ValueCategory LHS;
const ValueCategory RHS;
const QualType Ty; // Computation Type.
const BinaryOperator::Opcode Opcode; // Opcode of BinOp to perform
FPOptions FPFeatures; // FP options in effect for the expression.
const Expr *E; // Expression this operation was built from.
};

// Try to represent a floating-point add/sub as a single `math.fma` operation.
//
// Fusion is attempted only when
//   (a) `-omit-fp-contract` was not given and the FP options of the operation
//       allow contraction within a statement, and
//   (b) one operand is the result of an `arith.mulf` (optionally wrapped in an
//       `arith.negf`) whose value is not used anywhere else.
//
// \param Op      Operands, opcode and FP options of the add/sub to emit.
// \param Builder Builder used to create the new operations.
// \param Loc     Location attached to the new operations.
// \param IsSub   True when \p Op is a subtraction (LHS - RHS).
// \returns The fused value, or std::nullopt when no fusion was performed, in
//          which case the caller emits the plain add/sub itself.
static std::optional<ValueCategory> tryEmitFMulAdd(const BinOpInfo &Op,
                                                   OpBuilder &Builder,
                                                   Location Loc,
                                                   bool IsSub = false) {
  const BinaryOperator::Opcode Opcode = Op.getOpcode();

  assert((Opcode == BO_Add || Opcode == BO_AddAssign || Opcode == BO_Sub ||
          Opcode == BO_SubAssign) &&
         "Only fadd/fsub can be the root of an fmuladd.");

  // Check whether this op is marked as fusable and fusion is allowed.
  if (OmitFPContract || !Op.getFPFeatures().allowFPContractWithinStatement())
    return {};

  // Peek through fneg to look for fmul. Make sure fneg has no users, and that
  // it is the only use of its operand. Returns the value to inspect and
  // whether an fneg was looked through.
  constexpr auto isNegOperand =
      [](ValueCategory Val) -> std::pair<ValueCategory, bool> {
    auto NegOp = Val.val.getDefiningOp<arith::NegFOp>();
    if (NegOp && Val.val.use_empty()) {
      Value Operand = NegOp.getOperand();
      if (Operand.hasOneUse())
        return {ValueCategory(Operand, /*IsReference=*/false), true};
    }
    return {Val, false};
  };
  auto [LHS, NegLHS] = isNegOperand(Op.getLHS());
  auto [RHS, NegRHS] = isNegOperand(Op.getRHS());

  // We have a potentially fusable op. Look for a mul on one of the operands.
  // Also, make sure that the mul result isn't used directly. In that case,
  // there's no point creating a muladd operation.
  //
  // `MulCandidate` is the (fneg-stripped) operand that may be a mul;
  // `Addend` is the other operand of the add/sub.
  constexpr auto tryEmit = [](OpBuilder &Builder, Location Loc,
                              ValueCategory MulCandidate, ValueCategory Addend,
                              bool LookedThroughNeg, bool NegMul,
                              bool NegAdd) -> std::optional<ValueCategory> {
    auto MulOp = MulCandidate.val.getDefiningOp<arith::MulFOp>();
    if (MulOp && (MulCandidate.val.use_empty() || LookedThroughNeg))
      return MulCandidate.FMA(Builder, Loc, Addend, NegMul, NegAdd);
    return {};
  };

  // The addend must be the *original* operand, not the fneg-stripped one:
  // stripping is only meaningful for the multiplication side, and using the
  // stripped value as addend would silently drop a negation
  // (e.g. `-c + a * b` would become `fma(a, b, c)`).
  if (std::optional<ValueCategory> Res = tryEmit(
          Builder, Loc, LHS, Op.getRHS(), NegLHS, /*NegMul=*/NegLHS,
          /*NegAdd=*/IsSub))
    return Res;
  return tryEmit(Builder, Loc, RHS, Op.getLHS(), NegRHS,
                 /*NegMul=*/IsSub ^ NegRHS, /*NegAdd=*/false);
}

ValueCategory MLIRScanner::EmitPromoted(Expr *E, QualType PromotionType) {
assert(E && "Invalid input expression.");
E = E->IgnoreParens();
Expand Down Expand Up @@ -2612,7 +2672,9 @@ std::pair<ValueCategory, ValueCategory> MLIRScanner::EmitCompoundAssignLValue(
LHS = EmitScalarConversion(LHS, LHSTy, E->getComputationLHSType(), Loc);

// Expand the binary operator.
ValueCategory Result = (this->*Func)({LHS, RHS, Ty, OpCode, E});
ValueCategory Result =
(this->*Func)({LHS, RHS, Ty, OpCode,
E->getFPFeaturesInEffect(Glob.getCGM().getLangOpts()), E});
// Convert the result back to the LHS type,
// potentially with Implicit Conversion sanitizer check.
Result = EmitScalarConversion(Result, PromotionTypeCR, LHSTy, Loc);
Expand Down Expand Up @@ -2669,7 +2731,12 @@ BinOpInfo MLIRScanner::EmitBinOps(BinaryOperator *E, QualType PromotionType) {
const ValueCategory RHS = EmitPromotedScalarExpr(E->getRHS(), PromotionType);
const QualType Ty = !PromotionType.isNull() ? PromotionType : E->getType();
const BinaryOperator::Opcode Opcode = E->getOpcode();
return {LHS, RHS, Ty, Opcode, E};
return {LHS,
RHS,
Ty,
Opcode,
E->getFPFeaturesInEffect(Glob.getCGM().getLangOpts()),
E};
}

static void informNoOverflowCheck(LangOptions::SignedOverflowBehaviorTy SOB,
Expand Down Expand Up @@ -2900,8 +2967,10 @@ ValueCategory MLIRScanner::EmitBinAdd(const BinOpInfo &Info) {

assert(!Info.getType()->isConstantMatrixType() && "Not yet implemented");

if (mlirclang::isFPOrFPVectorTy(LHS.val.getType()))
return LHS.FAdd(Builder, Loc, RHS.val);
if (mlirclang::isFPOrFPVectorTy(LHS.val.getType())) {
std::optional<ValueCategory> FMulAdd = tryEmitFMulAdd(Info, Builder, Loc);
return FMulAdd ? *FMulAdd : LHS.FAdd(Builder, Loc, RHS.val);
}

return LHS.Add(Builder, Loc, RHS.val);
}
Expand All @@ -2919,8 +2988,11 @@ ValueCategory MLIRScanner::EmitBinSub(const BinOpInfo &Info) {
return LHS.Sub(Builder, Loc, RHS.val);
}
assert(!Info.getType()->isConstantMatrixType() && "Not yet implemented");
if (mlirclang::isFPOrFPVectorTy(LHS.val.getType()))
return LHS.FSub(Builder, Loc, RHS.val);
if (mlirclang::isFPOrFPVectorTy(LHS.val.getType())) {
std::optional<ValueCategory> FMulAdd =
tryEmitFMulAdd(Info, Builder, Loc, /*IsSub=*/true);
return FMulAdd ? *FMulAdd : LHS.FSub(Builder, Loc, RHS.val);
}
return LHS.Sub(Builder, Loc, RHS.val);
}

Expand Down Expand Up @@ -3103,7 +3175,8 @@ ValueCategory MLIRScanner::VisitMinus(UnaryOperator *E,
const ValueCategory Zero =
ValueCategory::getNullValue(Builder, Loc, Op.val.getType());
return EmitBinSub(
BinOpInfo{Zero, Op, E->getType(), BinaryOperator::Opcode::BO_Sub, E});
BinOpInfo{Zero, Op, E->getType(), BinaryOperator::Opcode::BO_Sub,
E->getFPFeaturesInEffect(Glob.getCGM().getLangOpts()), E});
}

ValueCategory MLIRScanner::VisitImag(UnaryOperator *E, QualType PromotionType) {
Expand Down
21 changes: 21 additions & 0 deletions polygeist/tools/cgeist/Lib/ValueCategory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "ValueCategory.h"
#include "Lib/TypeUtils.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Polygeist/IR/PolygeistOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Support/LLVM.h"
Expand Down Expand Up @@ -804,6 +805,26 @@ ValueCategory ValueCategory::FAdd(OpBuilder &Builder, Location Loc,
return FPBinOp<arith::AddFOp>(Builder, Loc, val, RHS);
}

/// Build `math.fma(x, y, Addend)` where `x * y` is the `arith.mulf` that
/// defines this value.
///
/// \param NegMul When true, the first multiplication operand is negated
///               before being fed to the fma.
/// \param NegAdd When true, \p Addend is negated before being fed to the fma.
/// \returns The fma result as a non-reference ValueCategory.
ValueCategory ValueCategory::FMA(OpBuilder &Builder, Location Loc,
                                 ValueCategory Addend, bool NegMul,
                                 bool NegAdd) const {
  auto Product = val.getDefiningOp<arith::MulFOp>();
  assert(Product && "Expecting arith.mul operation");

  // Negations are materialised in a fixed order (multiplicand first, then
  // addend) so the emitted operation sequence stays stable.
  Value Multiplicand = Product.getLhs();
  if (NegMul)
    Multiplicand = ValueCategory(Multiplicand, /*IsReference=*/false)
                       .FNeg(Builder, Loc)
                       .val;
  const Value Multiplier = Product.getRhs();
  const Value Summand = NegAdd ? Addend.FNeg(Builder, Loc).val : Addend.val;

  auto Fused =
      Builder.create<math::FmaOp>(Loc, Multiplicand, Multiplier, Summand);
  return ValueCategory(Fused, /*IsReference=*/false);
}

ValueCategory ValueCategory::CAdd(OpBuilder &Builder, Location Loc,
Value RHS) const {
assert(isComplexRepresentation(val.getType()) &&
Expand Down
2 changes: 2 additions & 0 deletions polygeist/tools/cgeist/Lib/ValueCategory.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@ class ValueCategory {
bool HasNSW = false) const;
ValueCategory FAdd(mlir::OpBuilder &Builder, mlir::Location Loc,
mlir::Value RHS) const;
/// Emit a `math.fma` from the `arith.mulf` defining this value and \p Addend.
/// The definition asserts that this value is produced by an `arith.mulf`.
/// \p NegMul negates the first multiplication operand and \p NegAdd negates
/// \p Addend before the fma is created.
ValueCategory FMA(mlir::OpBuilder &Builder, mlir::Location Loc,
ValueCategory Addend, bool NegMul, bool NegAdd) const;
ValueCategory Sub(mlir::OpBuilder &Builder, mlir::Location Loc,
mlir::Value RHS, bool HasNUW = false,
bool HasNSW = false) const;
Expand Down
171 changes: 171 additions & 0 deletions polygeist/tools/cgeist/Test/Verification/fmuladd.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
// RUN: cgeist -w -o - -S --function=* %s | FileCheck %s
// RUN: cgeist -omit-fp-contract -w -o - -S --function=* %s | FileCheck %s --check-prefix=CHECK-OMIT

// COM: -omit-fp-contract should yield no 'math.fma' operations

// CHECK-OMIT-NOT: math.fma

using double2 = double __attribute__((ext_vector_type(2)));

// Product on the left of an addition. The expectations in this file require a
// single math.fma of (a, b, c) with no negation.
template <typename T>
T test_simple_lhs(T a, T b, T c) {
return a * b + c;
}

// Negated product on the left of an addition. The expectations in this file
// require a negf of `a` followed by math.fma of (-a, b, c).
template <typename T>
T test_negmul_lhs(T a, T b, T c) {
return -(a * b) + c;
}

// Product on the left of a subtraction. The expectations in this file require
// a negf of `c` followed by math.fma of (a, b, -c).
template <typename T>
T test_negadd_lhs(T a, T b, T c) {
return a * b - c;
}

// Negated product on the left of a subtraction. The expectations in this file
// require negf of `a`, negf of `c`, then math.fma of (-a, b, -c).
template <typename T>
T test_negmul_negadd_lhs(T a, T b, T c) {
return -(a * b) - c;
}

// Product on the right of an addition. The expectations in this file require
// a single math.fma of (b, c, a) with no negation.
template <typename T>
T test_simple_rhs(T a, T b, T c) {
return a + b * c;
}

// Negated product on the right of an addition. The expectations in this file
// require a negf of `b` followed by math.fma of (-b, c, a).
template <typename T>
T test_negmul_rhs(T a, T b, T c) {
return a + -(b * c);
}

// Product on the right of a subtraction. The expectations in this file
// require a negf of `b` followed by math.fma of (-b, c, a).
template <typename T>
T test_negadd_rhs(T a, T b, T c) {
return a - b * c;
}

// Negated product on the right of a subtraction: the two negations cancel.
// The expectations in this file require a single math.fma of (b, c, a) with
// no negf.
template <typename T>
T test_negmul_negadd_rhs(T a, T b, T c) {
return a - -(b * c);
}

// Explicitly instantiates all eight test templates above for `type`, so that
// each one is emitted and can be matched by the expectations in this file.
#define TEST_TYPE(type) \
template type test_simple_lhs(type a, type b, type c); \
template type test_negmul_lhs(type a, type b, type c); \
template type test_negadd_lhs(type a, type b, type c); \
template type test_negmul_negadd_lhs(type a, type b, type c); \
template type test_simple_rhs(type a, type b, type c); \
template type test_negmul_rhs(type a, type b, type c); \
template type test_negadd_rhs(type a, type b, type c); \
template type test_negmul_negadd_rhs(type a, type b, type c);

TEST_TYPE(float)

// CHECK-LABEL: func.func @_Z15test_simple_lhsIfET_S0_S0_S0_(
// CHECK-SAME: %[[VAL_0:.*]]: f32, %[[VAL_1:.*]]: f32,
// CHECK-SAME: %[[VAL_2:.*]]: f32) -> f32
// CHECK: %[[VAL_3:.*]] = math.fma %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : f32
// CHECK: return %[[VAL_3]] : f32
// CHECK: }

// CHECK-LABEL: func.func @_Z15test_negmul_lhsIfET_S0_S0_S0_(
// CHECK-SAME: %[[VAL_0:.*]]: f32, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32) -> f32
// CHECK: %[[VAL_3:.*]] = arith.negf %[[VAL_0]] : f32
// CHECK: %[[VAL_4:.*]] = math.fma %[[VAL_3]], %[[VAL_1]], %[[VAL_2]] : f32
// CHECK: return %[[VAL_4]] : f32
// CHECK: }

// CHECK-LABEL: func.func @_Z15test_negadd_lhsIfET_S0_S0_S0_(
// CHECK-SAME: %[[VAL_0:.*]]: f32, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32) -> f32
// CHECK: %[[VAL_3:.*]] = arith.negf %[[VAL_2]] : f32
// CHECK: %[[VAL_4:.*]] = math.fma %[[VAL_0]], %[[VAL_1]], %[[VAL_3]] : f32
// CHECK: return %[[VAL_4]] : f32
// CHECK: }

// CHECK-LABEL: func.func @_Z22test_negmul_negadd_lhsIfET_S0_S0_S0_(
// CHECK-SAME: %[[VAL_0:.*]]: f32, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32) -> f32
// CHECK: %[[VAL_3:.*]] = arith.negf %[[VAL_0]] : f32
// CHECK: %[[VAL_4:.*]] = arith.negf %[[VAL_2]] : f32
// CHECK: %[[VAL_5:.*]] = math.fma %[[VAL_3]], %[[VAL_1]], %[[VAL_4]] : f32
// CHECK: return %[[VAL_5]] : f32
// CHECK: }

// CHECK-LABEL: func.func @_Z15test_simple_rhsIfET_S0_S0_S0_(
// CHECK-SAME: %[[VAL_0:.*]]: f32, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32) -> f32
// CHECK: %[[VAL_3:.*]] = math.fma %[[VAL_1]], %[[VAL_2]], %[[VAL_0]] : f32
// CHECK: return %[[VAL_3]] : f32
// CHECK: }

// CHECK-LABEL: func.func @_Z15test_negmul_rhsIfET_S0_S0_S0_(
// CHECK-SAME: %[[VAL_0:.*]]: f32, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32) -> f32
// CHECK: %[[VAL_3:.*]] = arith.negf %[[VAL_1]] : f32
// CHECK: %[[VAL_4:.*]] = math.fma %[[VAL_3]], %[[VAL_2]], %[[VAL_0]] : f32
// CHECK: return %[[VAL_4]] : f32
// CHECK: }

// CHECK-LABEL: func.func @_Z15test_negadd_rhsIfET_S0_S0_S0_(
// CHECK-SAME: %[[VAL_0:.*]]: f32, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32) -> f32
// CHECK: %[[VAL_3:.*]] = arith.negf %[[VAL_1]] : f32
// CHECK: %[[VAL_4:.*]] = math.fma %[[VAL_3]], %[[VAL_2]], %[[VAL_0]] : f32
// CHECK: return %[[VAL_4]] : f32
// CHECK: }

// CHECK-LABEL: func.func @_Z22test_negmul_negadd_rhsIfET_S0_S0_S0_(
// CHECK-SAME: %[[VAL_0:.*]]: f32, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32) -> f32
// CHECK: %[[VAL_3:.*]] = math.fma %[[VAL_1]], %[[VAL_2]], %[[VAL_0]] : f32
// CHECK: return %[[VAL_3]] : f32
// CHECK: }

TEST_TYPE(double2)

// CHECK-LABEL: func.func @_Z15test_simple_lhsIDv2_dET_S1_S1_S1_(
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf64>, %[[VAL_1:.*]]: vector<2xf64>, %[[VAL_2:.*]]: vector<2xf64>) -> vector<2xf64>
// CHECK: %[[VAL_3:.*]] = math.fma %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : vector<2xf64>
// CHECK: return %[[VAL_3]] : vector<2xf64>
// CHECK: }

// CHECK-LABEL: func.func @_Z15test_negmul_lhsIDv2_dET_S1_S1_S1_(
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf64>, %[[VAL_1:.*]]: vector<2xf64>, %[[VAL_2:.*]]: vector<2xf64>) -> vector<2xf64>
// CHECK: %[[VAL_3:.*]] = arith.negf %[[VAL_0]] : vector<2xf64>
// CHECK: %[[VAL_4:.*]] = math.fma %[[VAL_3]], %[[VAL_1]], %[[VAL_2]] : vector<2xf64>
// CHECK: return %[[VAL_4]] : vector<2xf64>
// CHECK: }

// CHECK-LABEL: func.func @_Z15test_negadd_lhsIDv2_dET_S1_S1_S1_(
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf64>, %[[VAL_1:.*]]: vector<2xf64>, %[[VAL_2:.*]]: vector<2xf64>) -> vector<2xf64>
// CHECK: %[[VAL_3:.*]] = arith.negf %[[VAL_2]] : vector<2xf64>
// CHECK: %[[VAL_4:.*]] = math.fma %[[VAL_0]], %[[VAL_1]], %[[VAL_3]] : vector<2xf64>
// CHECK: return %[[VAL_4]] : vector<2xf64>
// CHECK: }

// CHECK-LABEL: func.func @_Z22test_negmul_negadd_lhsIDv2_dET_S1_S1_S1_(
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf64>, %[[VAL_1:.*]]: vector<2xf64>, %[[VAL_2:.*]]: vector<2xf64>) -> vector<2xf64>
// CHECK: %[[VAL_3:.*]] = arith.negf %[[VAL_0]] : vector<2xf64>
// CHECK: %[[VAL_4:.*]] = arith.negf %[[VAL_2]] : vector<2xf64>
// CHECK: %[[VAL_5:.*]] = math.fma %[[VAL_3]], %[[VAL_1]], %[[VAL_4]] : vector<2xf64>
// CHECK: return %[[VAL_5]] : vector<2xf64>
// CHECK: }

// CHECK-LABEL: func.func @_Z15test_simple_rhsIDv2_dET_S1_S1_S1_(
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf64>, %[[VAL_1:.*]]: vector<2xf64>, %[[VAL_2:.*]]: vector<2xf64>) -> vector<2xf64>
// CHECK: %[[VAL_3:.*]] = math.fma %[[VAL_1]], %[[VAL_2]], %[[VAL_0]] : vector<2xf64>
// CHECK: return %[[VAL_3]] : vector<2xf64>
// CHECK: }

// CHECK-LABEL: func.func @_Z15test_negmul_rhsIDv2_dET_S1_S1_S1_(
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf64>, %[[VAL_1:.*]]: vector<2xf64>, %[[VAL_2:.*]]: vector<2xf64>) -> vector<2xf64>
// CHECK: %[[VAL_3:.*]] = arith.negf %[[VAL_1]] : vector<2xf64>
// CHECK: %[[VAL_4:.*]] = math.fma %[[VAL_3]], %[[VAL_2]], %[[VAL_0]] : vector<2xf64>
// CHECK: return %[[VAL_4]] : vector<2xf64>
// CHECK: }

// CHECK-LABEL: func.func @_Z15test_negadd_rhsIDv2_dET_S1_S1_S1_(
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf64>, %[[VAL_1:.*]]: vector<2xf64>, %[[VAL_2:.*]]: vector<2xf64>) -> vector<2xf64>
// CHECK: %[[VAL_3:.*]] = arith.negf %[[VAL_1]] : vector<2xf64>
// CHECK: %[[VAL_4:.*]] = math.fma %[[VAL_3]], %[[VAL_2]], %[[VAL_0]] : vector<2xf64>
// CHECK: return %[[VAL_4]] : vector<2xf64>
// CHECK: }

// CHECK-LABEL: func.func @_Z22test_negmul_negadd_rhsIDv2_dET_S1_S1_S1_(
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf64>, %[[VAL_1:.*]]: vector<2xf64>, %[[VAL_2:.*]]: vector<2xf64>) -> vector<2xf64>
// CHECK: %[[VAL_3:.*]] = math.fma %[[VAL_1]], %[[VAL_2]], %[[VAL_0]] : vector<2xf64>
// CHECK: return %[[VAL_3]] : vector<2xf64>
// CHECK: }
5 changes: 2 additions & 3 deletions polygeist/tools/cgeist/Test/Verification/gettimeofday.c
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ double alloc() {
// CHECK-NEXT: %[[VAL_8:.*]] = llvm.getelementptr inbounds %[[VAL_2]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i64, i64)>
// CHECK-NEXT: %[[VAL_9:.*]] = llvm.load %[[VAL_8]] : !llvm.ptr -> i64
// CHECK-NEXT: %[[VAL_10:.*]] = arith.sitofp %[[VAL_9]] : i64 to f64
// CHECK-NEXT: %[[VAL_11:.*]] = arith.mulf %[[VAL_10]], %[[VAL_0]] : f64
// CHECK-NEXT: %[[VAL_12:.*]] = arith.addf %[[VAL_7]], %[[VAL_11]] : f64
// CHECK-NEXT: return %[[VAL_12]] : f64
// CHECK-NEXT: %[[VAL_11:.*]] = math.fma %[[VAL_10]], %[[VAL_0]], %[[VAL_7]] : f64
// CHECK-NEXT: return %[[VAL_11]] : f64
// CHECK-NEXT: }
Loading