Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, QuaternaryLike, SQLQueryContext, TernaryLike, TreeNode, UnaryLike}
import org.apache.spark.sql.catalyst.trees.{BinaryLike, CurrentOrigin, LeafLike, QuaternaryLike, SQLQueryContext, TernaryLike, TreeNode, UnaryLike}
import org.apache.spark.sql.catalyst.trees.TreePattern.{RUNTIME_REPLACEABLE, TreePattern}
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.errors.{QueryErrorsBase, QueryExecutionErrors}
Expand All @@ -47,8 +47,6 @@ import org.apache.spark.sql.types._
* There are a few important traits or abstract classes:
*
* - [[Nondeterministic]]: an expression that is not deterministic.
* - [[Stateful]]: an expression that contains mutable state. For example, MonotonicallyIncreasingID
* and Rand. A stateful expression is always non-deterministic.
* - [[Unevaluable]]: an expression that is not supposed to be evaluated.
* - [[CodegenFallback]]: an expression that does not have code gen implemented and falls back to
* interpreted mode.
Expand Down Expand Up @@ -127,6 +125,54 @@ abstract class Expression extends TreeNode[Expression] {

def references: AttributeSet = _references

/**
 * Returns true if the expression contains mutable state.
 *
 * A stateful expression should never be evaluated multiple times for a single row. This should
 * only be a problem for interpreted execution. This can be prevented by creating fresh copies
 * of the stateful expression before execution, e.g. via
 * `freshCopyIfContainsStatefulExpression()`. A common example to trigger this issue:
 * {{{
 *   val rand = functions.rand()
 *   df.select(rand, rand) // These 2 rand should not share a state.
 * }}}
 */
def stateful: Boolean = false

/**
* Returns a copy of this expression where all stateful expressions are replaced with fresh
* uninitialized copies. If the expression contains no stateful expressions then the original
* expression is returned.
*/
/**
 * Returns a copy of this expression where all stateful expressions are replaced with fresh
 * uninitialized copies. If the expression contains no stateful expressions then the original
 * expression is returned.
 */
def freshCopyIfContainsStatefulExpression(): Expression = {
  // Materialize children as an IndexedSeq once so the positional comparison below is O(1) per
  // element instead of repeatedly traversing a linked Seq.
  val childrenIndexedSeq: IndexedSeq[Expression] = children match {
    case types: IndexedSeq[Expression] => types
    case other => other.toIndexedSeq
  }
  val newChildren = childrenIndexedSeq.map(_.freshCopyIfContainsStatefulExpression())
  // A more efficient version of `children.zip(newChildren).exists(_ ne _)`.
  // NOTE: reference equality (`ne`) is essential here. A fresh copy of a stateful expression is
  // `==` to the original (state is not part of the parameter list), so a value-equality check —
  // as performed by `transform`/`fastEquals` on non-stateful parents — would wrongly conclude
  // nothing changed and drop the copies for expressions nested more than one level deep.
  val anyChildChanged = {
    val size = newChildren.length
    var i = 0
    var res: Boolean = false
    while (!res && i < size) {
      res |= (childrenIndexedSeq(i) ne newChildren(i))
      i += 1
    }
    res
  }
  // If the children contain stateful expressions and get copied, or this expression is stateful,
  // copy this expression with the new children, preserving origin and node tags
  // (this mirrors TreeNode.withNewChildren).
  if (anyChildChanged || stateful) {
    CurrentOrigin.withOrigin(origin) {
      val res = withNewChildrenInternal(newChildren)
      res.copyTagsFrom(this)
      res
    }
  } else {
    this
  }
}

/** Returns the result of evaluating this expression on a given input Row */
def eval(input: InternalRow = null): Any

Expand Down Expand Up @@ -472,33 +518,6 @@ trait ConditionalExpression extends Expression {
def branchGroups: Seq[Seq[Expression]]
}

/**
* An expression that contains mutable state. A stateful expression is always non-deterministic
* because the results it produces during evaluation are not only dependent on the given input
* but also on its internal state.
*
* The state of the expressions is generally not exposed in the parameter list and this makes
* comparing stateful expressions problematic because similar stateful expressions (with the same
* parameter list) but with different internal state will be considered equal. This is especially
* problematic during tree transformations. In order to counter this the `fastEquals` method for
* stateful expressions only returns `true` for the same reference.
*
* A stateful expression should never be evaluated multiple times for a single row. This should
* only be a problem for interpreted execution. This can be prevented by creating fresh copies
* of the stateful expression before execution, these can be made using the `freshCopy` function.
*/
trait Stateful extends Nondeterministic {
/**
* Return a fresh uninitialized copy of the stateful expression.
*/
def freshCopy(): Stateful

/**
* Only the same reference is considered equal.
*/
override def fastEquals(other: TreeNode[_]): Boolean = this eq other
}

/**
* A leaf expression, i.e. one without any child expressions.
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.internal.SQLConf

// A helper trait for evaluators (e.g. interpreted projections) that evaluate expressions.
trait ExpressionsEvaluator {
  // Lazily created so it is only instantiated when sub-expression elimination is enabled.
  protected lazy val runtime =
    new SubExprEvaluationRuntime(SQLConf.get.subexpressionEliminationCacheMaxEntries)

  /**
   * Prepares expressions for interpreted evaluation: replaces stateful expressions with fresh
   * copies (a stateful expression must never be shared/re-evaluated for a single row) and,
   * when `subExprEliminationEnabled` is set, wraps common sub-expressions with proxies backed
   * by `runtime`.
   */
  protected def prepareExpressions(
      exprs: Seq[Expression],
      subExprEliminationEnabled: Boolean): Seq[Expression] = {
    // We need to make sure that we do not reuse stateful expressions.
    val cleanedExpressions = exprs.map(_.freshCopyIfContainsStatefulExpression())
    if (subExprEliminationEnabled) {
      runtime.proxyExpressions(cleanedExpressions)
    } else {
      cleanedExpressions
    }
  }

  /**
   * Initializes internal states given the current partition index.
   * This is used by nondeterministic expressions to set initial states.
   * The default implementation does nothing. Implementations should initialize the final
   * (prepared) expressions they actually evaluate, after `prepareExpressions` has run.
   */
  def initialize(partitionIndex: Int): Unit = {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,12 @@ class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mutable
this(bindReferences(expressions, inputSchema))

private[this] val subExprEliminationEnabled = SQLConf.get.subexpressionEliminationEnabled
private[this] lazy val runtime =
new SubExprEvaluationRuntime(SQLConf.get.subexpressionEliminationCacheMaxEntries)
private[this] val exprs = if (subExprEliminationEnabled) {
runtime.proxyExpressions(expressions)
} else {
expressions
}
private[this] val exprs = prepareExpressions(expressions, subExprEliminationEnabled)

private[this] val buffer = new Array[Any](expressions.size)

override def initialize(partitionIndex: Int): Unit = {
expressions.foreach(_.foreach {
exprs.foreach(_.foreach {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a pre-existing bug. The final expressions we use is exprs, not expressions

case n: Nondeterministic => n.initialize(partitionIndex)
case _ =>
})
Expand Down Expand Up @@ -117,10 +111,6 @@ object InterpretedMutableProjection {
* Returns a [[MutableProjection]] for given sequence of bound Expressions.
*/
def createProjection(exprs: Seq[Expression]): MutableProjection = {
// We need to make sure that we do not reuse stateful expressions.
val cleanedExpressions = exprs.map(_.transform {
case s: Stateful => s.freshCopy()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The old usage of .transform here contained a subtle bug related to how fastEquals works.

Let's say that we have a tree which looks like this:

Outer(Middle(Stateful()))

where Outer and Middle are non-Stateful expressions.

When the .transform is applied to Stateful() and .freshCopy() is called, the returned value will be == to the original Stateful expression but will have a different object identity (because it's a fresh object). Internally, .transform will use fastEquals to check whether the transformation modified the node. Stateful overrides fastEquals so that it only considers object identity, so the transform will return the freshCopy() result.

At the next level up, Middle will check whether any of its children have been changed in the recursive bottom-up transformation (see childrenFastEquals() in withNewChildren(), which is called from mapChildren()). It will detect that its children have changed, so the transform will return a new Middle node.

Finally, at the top level, Outer will perform the same check to see if any of its children have changed. This time, however, it will be calling Middle.fastEquals instead of Stateful.fastEquals. Middle's fastEquals method is the regular implementation which also considers object equality. Both the original and new Middle nodes will be ==, so fastEquals will be true and Outer will conclude that its children have not been changed by the transformation and the original Outer will be returned (losing the copy of the stateful expression).

In other words, the old .transform and copying logic here was incorrect if the Stateful expression was nested more than a single level deep.

In this PR I chose to fix this by adding a freshCopyIfContainsStatefulExpression() method to Expression which implements a custom tree traversal that considers only object identity when determining whether the transform has changed a node or a node's children.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, this is a nice catch!

})
new InterpretedMutableProjection(cleanedExpressions)
new InterpretedMutableProjection(exprs)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,7 @@ import org.apache.spark.sql.types._
class InterpretedSafeProjection(expressions: Seq[Expression]) extends Projection {

private[this] val subExprEliminationEnabled = SQLConf.get.subexpressionEliminationEnabled
private[this] lazy val runtime =
new SubExprEvaluationRuntime(SQLConf.get.subexpressionEliminationCacheMaxEntries)
private[this] val exprs = if (subExprEliminationEnabled) {
runtime.proxyExpressions(expressions)
} else {
expressions
}
private[this] val exprs = prepareExpressions(expressions, subExprEliminationEnabled)

private[this] val mutableRow = new SpecificInternalRow(expressions.map(_.dataType))

Expand Down Expand Up @@ -106,6 +100,13 @@ class InterpretedSafeProjection(expressions: Seq[Expression]) extends Projection
case _ => identity
}

override def initialize(partitionIndex: Int): Unit = {
  // Initialize the prepared `exprs` — the fresh stateful copies / sub-expression proxies that
  // `apply` actually evaluates — not the original `expressions`; initializing the originals
  // would leave the evaluated copies uninitialized.
  exprs.foreach(_.foreach {
    case n: Nondeterministic => n.initialize(partitionIndex)
    case _ =>
  })
}

override def apply(row: InternalRow): InternalRow = {
if (subExprEliminationEnabled) {
runtime.setInput(row)
Expand All @@ -130,10 +131,6 @@ object InterpretedSafeProjection {
* Returns an [[SafeProjection]] for given sequence of bound Expressions.
*/
def createProjection(exprs: Seq[Expression]): Projection = {
// We need to make sure that we do not reuse stateful expressions.
val cleanedExpressions = exprs.map(_.transform {
case s: Stateful => s.freshCopy()
})
new InterpretedSafeProjection(cleanedExpressions)
new InterpretedSafeProjection(exprs)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,7 @@ class InterpretedUnsafeProjection(expressions: Array[Expression]) extends Unsafe
import InterpretedUnsafeProjection._

private[this] val subExprEliminationEnabled = SQLConf.get.subexpressionEliminationEnabled
private[this] lazy val runtime =
new SubExprEvaluationRuntime(SQLConf.get.subexpressionEliminationCacheMaxEntries)
private[this] val exprs = if (subExprEliminationEnabled) {
runtime.proxyExpressions(expressions)
} else {
expressions.toSeq
}
private[this] val exprs = prepareExpressions(expressions, subExprEliminationEnabled)

/** Number of (top level) fields in the resulting row. */
private[this] val numFields = expressions.length
Expand Down Expand Up @@ -106,11 +100,7 @@ object InterpretedUnsafeProjection {
* Returns an [[UnsafeProjection]] for given sequence of bound Expressions.
*/
def createProjection(exprs: Seq[Expression]): UnsafeProjection = {
// We need to make sure that we do not reuse stateful expressions.
val cleanedExpressions = exprs.map(_.transform {
case s: Stateful => s.freshCopy()
})
new InterpretedUnsafeProjection(cleanedExpressions.toArray)
new InterpretedUnsafeProjection(exprs.toArray)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ import org.apache.spark.sql.types.{DataType, LongType}
""",
since = "1.4.0",
group = "misc_funcs")
case class MonotonicallyIncreasingID() extends LeafExpression with Stateful {
case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterministic {

/**
* Record ID within each partition. By being transient, count's value is reset to 0 every time
Expand All @@ -58,11 +58,17 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Stateful {

@transient private[this] var partitionMask: Long = _

override def stateful: Boolean = true

override protected def initializeInternal(partitionIndex: Int): Unit = {
count = 0L
partitionMask = partitionIndex.toLong << 33
}

override def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = {
MonotonicallyIncreasingID()
}

override def nullable: Boolean = false

override def dataType: DataType = LongType
Expand All @@ -88,6 +94,4 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Stateful {
override def nodeName: String = "monotonically_increasing_id"

override def sql: String = s"$prettyName()"

override def freshCopy(): MonotonicallyIncreasingID = MonotonicallyIncreasingID()
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,20 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
this(bindReferences(expressions, inputSchema))

// null check is required for when Kryo invokes the no-arg constructor.
protected val exprArray = if (expressions != null) {
prepareExpressions(expressions, subExprEliminationEnabled = false).toArray
} else {
null
}

override def initialize(partitionIndex: Int): Unit = {
expressions.foreach(_.foreach {
exprArray.foreach(_.foreach {
case n: Nondeterministic => n.initialize(partitionIndex)
case _ =>
})
}

// null check is required for when Kryo invokes the no-arg constructor.
protected val exprArray = if (expressions != null) expressions.toArray else null

def apply(input: InternalRow): InternalRow = {
val outputArray = new Array[Any](exprArray.length)
var i = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ case class ScalaUDF(

override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)

// `ScalaUDF` uses `ExpressionEncoder` to convert the function result to Catalyst internal format.
// `ExpressionEncoder` is stateful as it reuses the `UnsafeRow` instance, thus `ScalaUDF` is
// stateful as well.
override def stateful: Boolean = true

final override val nodePatterns: Seq[TreePattern] = Seq(SCALA_UDF)

override def toString: String = s"$name(${children.mkString(", ")})"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1271,8 +1271,11 @@ class CodegenContext extends Logging {
def generateExpressions(
    expressions: Seq[Expression],
    doSubexpressionElimination: Boolean = false): Seq[ExprCode] = {
  // Stateful expressions must not be shared between evaluation sites; take fresh copies first.
  // This matters for codegen too, because some expressions may implement `CodegenFallback`.
  val safeExprs = expressions.map(_.freshCopyIfContainsStatefulExpression())
  if (doSubexpressionElimination) {
    subexpressionElimination(safeExprs)
  }
  safeExprs.map(_.genCode(this))
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1125,11 +1125,13 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
""",
group = "array_funcs",
since = "2.4.0")
case class Shuffle(child: Expression, randomSeed: Option[Long] = None)
extends UnaryExpression with ExpectsInputTypes with Stateful with ExpressionWithRandomSeed {
case class Shuffle(child: Expression, randomSeed: Option[Long] = None) extends UnaryExpression
with ExpectsInputTypes with Nondeterministic with ExpressionWithRandomSeed {

def this(child: Expression) = this(child, None)

override def stateful: Boolean = true

override def seedExpression: Expression = randomSeed.map(Literal.apply).getOrElse(UnresolvedSeed)

override def withNewSeed(seed: Long): Shuffle = copy(randomSeed = Some(seed))
Expand Down Expand Up @@ -1195,8 +1197,6 @@ case class Shuffle(child: Expression, randomSeed: Option[Long] = None)
""".stripMargin
}

override def freshCopy(): Shuffle = Shuffle(child, randomSeed)

override def withNewChildInternal(newChild: Expression): Shuffle = copy(child = newChild)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ case class CurrentCatalog() extends LeafExpression with Unevaluable {
since = "2.3.0",
group = "misc_funcs")
// scalastyle:on line.size.limit
case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Stateful
case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Nondeterministic
with ExpressionWithRandomSeed {

def this() = this(None)
Expand All @@ -216,6 +216,8 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Sta

override def dataType: DataType = StringType

override def stateful: Boolean = true

@transient private[this] var randomGenerator: RandomUUIDGenerator = _

override protected def initializeInternal(partitionIndex: Int): Unit =
Expand All @@ -235,8 +237,6 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Sta
ev.copy(code = code"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();",
isNull = FalseLiteral)
}

override def freshCopy(): Uuid = Uuid(randomSeed)
}

// scalastyle:off line.size.limit
Expand Down
Loading