diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index a37ea5e87a316..55b9c3dc0a355 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -112,6 +112,8 @@ static FailureOr getOperatorPrecedence(Operation *operation) { .Default([](auto op) { return op->emitError("unsupported operation"); }); } +static bool shouldBeInlined(Operation *op); + namespace { /// Emitter that uses dialect specific emitters to emit C++ code. struct CppEmitter { @@ -255,24 +257,19 @@ struct CppEmitter { } /// Is expression currently being emitted. - bool isEmittingExpression() { return emittedExpression; } + bool isEmittingExpression() { return !emittedExpressionPrecedence.empty(); } /// Determine whether given value is part of the expression potentially being /// emitted. bool isPartOfCurrentExpression(Value value) { - if (!emittedExpression) - return false; Operation *def = value.getDefiningOp(); - if (!def) - return false; - return isPartOfCurrentExpression(def); + return def ? isPartOfCurrentExpression(def) : false; } /// Determine whether given operation is part of the expression potentially /// being emitted. bool isPartOfCurrentExpression(Operation *def) { - auto operandExpression = dyn_cast(def->getParentOp()); - return operandExpression && operandExpression == emittedExpression; + return isEmittingExpression() && shouldBeInlined(def); }; // Resets the value counter to 0. @@ -319,7 +316,6 @@ struct CppEmitter { unsigned int valueCount{0}; /// State of the current expression being emitted. - ExpressionOp emittedExpression; SmallVector emittedExpressionPrecedence; void pushExpressionPrecedence(int precedence) { @@ -342,12 +338,22 @@ static bool hasDeferredEmission(Operation *op) { emitc::GetFieldOp>(op); } -/// Determine whether expression \p expressionOp should be emitted inline, i.e. +/// Determine whether operation \p op should be emitted inline, i.e. /// as part of its user. This function recommends inlining of any expressions /// that can be inlined unless it is used by another expression, under the /// assumption that any expression fusion/re-materialization was taken care of /// by transformations run by the backend. -static bool shouldBeInlined(ExpressionOp expressionOp) { +static bool shouldBeInlined(Operation *op) { + // CExpression operations are inlined if and only if they reside within an + // ExpressionOp. + if (isa(op)) + return isa(op->getParentOp()); + + // Only other inlinable operation is ExpressionOp itself. + ExpressionOp expressionOp = dyn_cast(op); + if (!expressionOp) + return false; + // Do not inline if expression is marked as such. if (expressionOp.getDoNotInline()) return false; @@ -1585,7 +1591,6 @@ LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) { "Expected precedence stack to be empty"); Operation *rootOp = expressionOp.getRootOp(); - emittedExpression = expressionOp; FailureOr precedence = getOperatorPrecedence(rootOp); if (failed(precedence)) return failure(); @@ -1597,7 +1602,6 @@ LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) { popExpressionPrecedence(); assert(emittedExpressionPrecedence.empty() && "Expected precedence stack to be empty"); - emittedExpression = nullptr; return success(); } @@ -1638,14 +1642,8 @@ LogicalResult CppEmitter::emitOperand(Value value, bool isInBrackets) { // If this operand is a block argument of an expression, emit instead the // matching expression parameter. Operation *argOp = arg.getParentBlock()->getParentOp(); - if (auto expressionOp = dyn_cast(argOp)) { - // This scenario is only expected when one of the operations within the - // expression being emitted references one of the expression's block - // arguments. - assert(expressionOp == emittedExpression && - "Expected expression being emitted"); - value = expressionOp->getOperand(arg.getArgNumber()); - } + if (auto expressionOp = dyn_cast(argOp)) + return emitOperand(expressionOp->getOperand(arg.getArgNumber())); } os << getOrCreateName(value);