diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 0ddf1a7df19fc..de0e90285f541 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, QuaternaryLike, SQLQueryContext, TernaryLike, TreeNode, UnaryLike} +import org.apache.spark.sql.catalyst.trees.{BinaryLike, CurrentOrigin, LeafLike, QuaternaryLike, SQLQueryContext, TernaryLike, TreeNode, UnaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern.{RUNTIME_REPLACEABLE, TreePattern} import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.errors.{QueryErrorsBase, QueryExecutionErrors} @@ -47,8 +47,6 @@ import org.apache.spark.sql.types._ * There are a few important traits or abstract classes: * * - [[Nondeterministic]]: an expression that is not deterministic. - * - [[Stateful]]: an expression that contains mutable state. For example, MonotonicallyIncreasingID - * and Rand. A stateful expression is always non-deterministic. * - [[Unevaluable]]: an expression that is not supposed to be evaluated. * - [[CodegenFallback]]: an expression that does not have code gen implemented and falls back to * interpreted mode. @@ -127,6 +125,54 @@ abstract class Expression extends TreeNode[Expression] { def references: AttributeSet = _references + /** + * Returns true if the expression contains mutable state. 
+ * + * A stateful expression should never be evaluated multiple times for a single row. This should + * only be a problem for interpreted execution. This can be prevented by creating fresh copies + * of the stateful expression before execution. A common example to trigger this issue: + * {{{ + * val rand = functions.rand() + * df.select(rand, rand) // These 2 rand should not share a state. + * }}} + */ + def stateful: Boolean = false + + /** + * Returns a copy of this expression where all stateful expressions are replaced with fresh + * uninitialized copies. If the expression contains no stateful expressions then the original + * expression is returned. + */ + def freshCopyIfContainsStatefulExpression(): Expression = { + val childrenIndexedSeq: IndexedSeq[Expression] = children match { + case types: IndexedSeq[Expression] => types + case other => other.toIndexedSeq + } + val newChildren = childrenIndexedSeq.map(_.freshCopyIfContainsStatefulExpression()) + // A more efficient version of `children.zip(newChildren).exists(_ ne _)` + val anyChildChanged = { + val size = newChildren.length + var i = 0 + var res: Boolean = false + while (!res && i < size) { + res |= (childrenIndexedSeq(i) ne newChildren(i)) + i += 1 + } + res + } + // If the children contain stateful expressions and get copied, or this expression is stateful, + // copy this expression with the new children. + if (anyChildChanged || stateful) { + CurrentOrigin.withOrigin(origin) { + val res = withNewChildrenInternal(newChildren) + res.copyTagsFrom(this) + res + } + } else { + this + } + } + /** Returns the result of evaluating this expression on a given input Row */ def eval(input: InternalRow = null): Any @@ -472,33 +518,6 @@ trait ConditionalExpression extends Expression { def branchGroups: Seq[Seq[Expression]] } -/** - * An expression that contains mutable state. 
A stateful expression is always non-deterministic - * because the results it produces during evaluation are not only dependent on the given input - * but also on its internal state. - * - * The state of the expressions is generally not exposed in the parameter list and this makes - * comparing stateful expressions problematic because similar stateful expressions (with the same - * parameter list) but with different internal state will be considered equal. This is especially - * problematic during tree transformations. In order to counter this the `fastEquals` method for - * stateful expressions only returns `true` for the same reference. - * - * A stateful expression should never be evaluated multiple times for a single row. This should - * only be a problem for interpreted execution. This can be prevented by creating fresh copies - * of the stateful expression before execution, these can be made using the `freshCopy` function. - */ -trait Stateful extends Nondeterministic { - /** - * Return a fresh uninitialized copy of the stateful expression. - */ - def freshCopy(): Stateful - - /** - * Only the same reference is considered equal. - */ - override def fastEquals(other: TreeNode[_]): Boolean = this eq other -} - /** * A leaf expression, i.e. one without any child expressions. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionsEvaluator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionsEvaluator.scala new file mode 100644 index 0000000000000..dcbc6926cd335 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionsEvaluator.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.internal.SQLConf + +// A helper trait for evaluating expressions; mixed into `Projection` and `BasePredicate`. +trait ExpressionsEvaluator { + protected lazy val runtime = + new SubExprEvaluationRuntime(SQLConf.get.subexpressionEliminationCacheMaxEntries) + + protected def prepareExpressions( + exprs: Seq[Expression], + subExprEliminationEnabled: Boolean): Seq[Expression] = { + // Take fresh copies of stateful expressions so that their mutable state is never reused. + val cleanedExpressions = exprs.map(_.freshCopyIfContainsStatefulExpression()) + if (subExprEliminationEnabled) { + runtime.proxyExpressions(cleanedExpressions) + } else { + cleanedExpressions + } + } + + /** + * Initializes internal states given the current partition index. + * This is used by nondeterministic expressions to set initial states. + * The default implementation does nothing. 
+ */ + def initialize(partitionIndex: Int): Unit = {} +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala index 5d95ac71be8d0..682604b9bf72c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala @@ -36,18 +36,12 @@ class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mutable this(bindReferences(expressions, inputSchema)) private[this] val subExprEliminationEnabled = SQLConf.get.subexpressionEliminationEnabled - private[this] lazy val runtime = - new SubExprEvaluationRuntime(SQLConf.get.subexpressionEliminationCacheMaxEntries) - private[this] val exprs = if (subExprEliminationEnabled) { - runtime.proxyExpressions(expressions) - } else { - expressions - } + private[this] val exprs = prepareExpressions(expressions, subExprEliminationEnabled) private[this] val buffer = new Array[Any](expressions.size) override def initialize(partitionIndex: Int): Unit = { - expressions.foreach(_.foreach { + exprs.foreach(_.foreach { case n: Nondeterministic => n.initialize(partitionIndex) case _ => }) @@ -117,10 +111,6 @@ object InterpretedMutableProjection { * Returns a [[MutableProjection]] for given sequence of bound Expressions. */ def createProjection(exprs: Seq[Expression]): MutableProjection = { - // We need to make sure that we do not reuse stateful expressions. 
- val cleanedExpressions = exprs.map(_.transform { - case s: Stateful => s.freshCopy() - }) - new InterpretedMutableProjection(cleanedExpressions) + new InterpretedMutableProjection(exprs) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedSafeProjection.scala index 0e71892db666b..84263d97f5da7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedSafeProjection.scala @@ -32,13 +32,7 @@ import org.apache.spark.sql.types._ class InterpretedSafeProjection(expressions: Seq[Expression]) extends Projection { private[this] val subExprEliminationEnabled = SQLConf.get.subexpressionEliminationEnabled - private[this] lazy val runtime = - new SubExprEvaluationRuntime(SQLConf.get.subexpressionEliminationCacheMaxEntries) - private[this] val exprs = if (subExprEliminationEnabled) { - runtime.proxyExpressions(expressions) - } else { - expressions - } + private[this] val exprs = prepareExpressions(expressions, subExprEliminationEnabled) private[this] val mutableRow = new SpecificInternalRow(expressions.map(_.dataType)) @@ -106,6 +100,13 @@ class InterpretedSafeProjection(expressions: Seq[Expression]) extends Projection case _ => identity } + override def initialize(partitionIndex: Int): Unit = { + exprs.foreach(_.foreach { + case n: Nondeterministic => n.initialize(partitionIndex) + case _ => + }) + } + override def apply(row: InternalRow): InternalRow = { if (subExprEliminationEnabled) { runtime.setInput(row) @@ -130,10 +131,6 @@ object InterpretedSafeProjection { * Returns an [[SafeProjection]] for given sequence of bound Expressions. */ def createProjection(exprs: Seq[Expression]): Projection = { - // We need to make sure that we do not reuse stateful expressions. 
- val cleanedExpressions = exprs.map(_.transform { - case s: Stateful => s.freshCopy() - }) - new InterpretedSafeProjection(cleanedExpressions) + new InterpretedSafeProjection(exprs) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala index 9a9a41b1f18d1..d87c0c006cf24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -35,13 +35,7 @@ class InterpretedUnsafeProjection(expressions: Array[Expression]) extends Unsafe import InterpretedUnsafeProjection._ private[this] val subExprEliminationEnabled = SQLConf.get.subexpressionEliminationEnabled - private[this] lazy val runtime = - new SubExprEvaluationRuntime(SQLConf.get.subexpressionEliminationCacheMaxEntries) - private[this] val exprs = if (subExprEliminationEnabled) { - runtime.proxyExpressions(expressions) - } else { - expressions.toSeq - } + private[this] val exprs = prepareExpressions(expressions, subExprEliminationEnabled) /** Number of (top level) fields in the resulting row. */ private[this] val numFields = expressions.length @@ -106,11 +100,7 @@ object InterpretedUnsafeProjection { * Returns an [[UnsafeProjection]] for given sequence of bound Expressions. */ def createProjection(exprs: Seq[Expression]): UnsafeProjection = { - // We need to make sure that we do not reuse stateful expressions. 
- val cleanedExpressions = exprs.map(_.transform { - case s: Stateful => s.freshCopy() - }) - new InterpretedUnsafeProjection(cleanedExpressions.toArray) + new InterpretedUnsafeProjection(exprs.toArray) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index ecf254f65f5a1..8dc1ba4846adb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -48,7 +48,7 @@ import org.apache.spark.sql.types.{DataType, LongType} """, since = "1.4.0", group = "misc_funcs") -case class MonotonicallyIncreasingID() extends LeafExpression with Stateful { +case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterministic { /** * Record ID within each partition. 
By being transient, count's value is reset to 0 every time @@ -58,11 +58,17 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Stateful { @transient private[this] var partitionMask: Long = _ + override def stateful: Boolean = true + override protected def initializeInternal(partitionIndex: Int): Unit = { count = 0L partitionMask = partitionIndex.toLong << 33 } + override def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = { + MonotonicallyIncreasingID() + } + override def nullable: Boolean = false override def dataType: DataType = LongType @@ -88,6 +94,4 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Stateful { override def nodeName: String = "monotonically_increasing_id" override def sql: String = s"$prettyName()" - - override def freshCopy(): MonotonicallyIncreasingID = MonotonicallyIncreasingID() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index b4a85e3e50bec..20969fa584a87 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -33,16 +33,20 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = this(bindReferences(expressions, inputSchema)) + // null check is required for when Kryo invokes the no-arg constructor. 
+ protected val exprArray = if (expressions != null) { + prepareExpressions(expressions, subExprEliminationEnabled = false).toArray + } else { + null + } + override def initialize(partitionIndex: Int): Unit = { - expressions.foreach(_.foreach { + exprArray.foreach(_.foreach { case n: Nondeterministic => n.initialize(partitionIndex) case _ => }) } - // null check is required for when Kryo invokes the no-arg constructor. - protected val exprArray = if (expressions != null) expressions.toArray else null - def apply(input: InternalRow): InternalRow = { val outputArray = new Array[Any](exprArray.length) var i = 0 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index f8ff5f583f602..137a8976a40ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -58,6 +58,11 @@ case class ScalaUDF( override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic) + // `ScalaUDF` uses `ExpressionEncoder` to convert the function result to Catalyst internal format. + // `ExpressionEncoder` is stateful as it reuses the `UnsafeRow` instance, thus `ScalaUDF` is + // stateful as well. 
+ override def stateful: Boolean = true + final override val nodePatterns: Seq[TreePattern] = Seq(SCALA_UDF) override def toString: String = s"$name(${children.mkString(", ")})" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 61f888f17b1f9..12103ceef6ee0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1271,8 +1271,11 @@ class CodegenContext extends Logging { def generateExpressions( expressions: Seq[Expression], doSubexpressionElimination: Boolean = false): Seq[ExprCode] = { - if (doSubexpressionElimination) subexpressionElimination(expressions) - expressions.map(e => e.genCode(this)) + // We need to make sure that we do not reuse stateful expressions. This is needed for codegen + // as well because some expressions may implement `CodegenFallback`. 
+ val cleanedExpressions = expressions.map(_.freshCopyIfContainsStatefulExpression()) + if (doSubexpressionElimination) subexpressionElimination(cleanedExpressions) + cleanedExpressions.map(e => e.genCode(this)) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 229987fc0c89b..22584a64f7de3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1125,11 +1125,13 @@ case class SortArray(base: Expression, ascendingOrder: Expression) """, group = "array_funcs", since = "2.4.0") -case class Shuffle(child: Expression, randomSeed: Option[Long] = None) - extends UnaryExpression with ExpectsInputTypes with Stateful with ExpressionWithRandomSeed { +case class Shuffle(child: Expression, randomSeed: Option[Long] = None) extends UnaryExpression + with ExpectsInputTypes with Nondeterministic with ExpressionWithRandomSeed { def this(child: Expression) = this(child, None) + override def stateful: Boolean = true + override def seedExpression: Expression = randomSeed.map(Literal.apply).getOrElse(UnresolvedSeed) override def withNewSeed(seed: Long): Shuffle = copy(randomSeed = Some(seed)) @@ -1195,8 +1197,6 @@ case class Shuffle(child: Expression, randomSeed: Option[Long] = None) """.stripMargin } - override def freshCopy(): Shuffle = Shuffle(child, randomSeed) - override def withNewChildInternal(newChild: Expression): Shuffle = copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index eb21bd555db7d..bf9dd700dfabd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -201,7 +201,7 @@ case class CurrentCatalog() extends LeafExpression with Unevaluable { since = "2.3.0", group = "misc_funcs") // scalastyle:on line.size.limit -case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Stateful +case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Nondeterministic with ExpressionWithRandomSeed { def this() = this(None) @@ -216,6 +216,8 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Sta override def dataType: DataType = StringType + override def stateful: Boolean = true + @transient private[this] var randomGenerator: RandomUUIDGenerator = _ override protected def initializeInternal(partitionIndex: Int): Unit = @@ -235,8 +237,6 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Sta ev.copy(code = code"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();", isNull = FalseLiteral) } - - override def freshCopy(): Uuid = Uuid(randomSeed) } // scalastyle:off line.size.limit diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index ededac3d91706..44813ac7b614e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -68,15 +68,7 @@ package object expressions { * column of the new row. If the schema of the input row is specified, then the given expression * will be bound to that schema. */ - abstract class Projection extends (InternalRow => InternalRow) { - - /** - * Initializes internal states given the current partition index. - * This is used by nondeterministic expressions to set initial states. - * The default implementation does nothing. 
- */ - def initialize(partitionIndex: Int): Unit = {} - } + abstract class Projection extends (InternalRow => InternalRow) with ExpressionsEvaluator /** * An identity projection. This returns the input row. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index f51c9fd5ef367..4e4ac6ee49265 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -36,26 +36,13 @@ import org.apache.spark.sql.types._ /** * A base class for generated/interpreted predicate */ -abstract class BasePredicate { +abstract class BasePredicate extends ExpressionsEvaluator { def eval(r: InternalRow): Boolean - - /** - * Initializes internal states given the current partition index. - * This is used by nondeterministic expressions to set initial states. - * The default implementation does nothing. 
- */ - def initialize(partitionIndex: Int): Unit = {} } case class InterpretedPredicate(expression: Expression) extends BasePredicate { private[this] val subExprEliminationEnabled = SQLConf.get.subexpressionEliminationEnabled - private[this] lazy val runtime = - new SubExprEvaluationRuntime(SQLConf.get.subexpressionEliminationCacheMaxEntries) - private[this] val expr = if (subExprEliminationEnabled) { - runtime.proxyExpressions(Seq(expression)).head - } else { - expression - } + private[this] val expr = prepareExpressions(Seq(expression), subExprEliminationEnabled).head override def eval(r: InternalRow): Boolean = { if (subExprEliminationEnabled) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index e2eb7fb1643b4..db78415a0cc54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -32,7 +32,7 @@ import org.apache.spark.util.random.XORShiftRandom * * Since this expression is stateful, it cannot be a case object. */ -abstract class RDG extends UnaryExpression with ExpectsInputTypes with Stateful +abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterministic with ExpressionWithRandomSeed { /** * Record ID within each partition. 
By being transient, the Random Number Generator is @@ -40,6 +40,8 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Stateful */ @transient protected var rng: XORShiftRandom = _ + override def stateful: Boolean = true + override protected def initializeInternal(partitionIndex: Int): Unit = { rng = new XORShiftRandom(seed + partitionIndex) } @@ -108,8 +110,6 @@ case class Rand(child: Expression, hideSeed: Boolean = false) extends RDG { isNull = FalseLiteral) } - override def freshCopy(): Rand = Rand(child, hideSeed) - override def flatArguments: Iterator[Any] = Iterator(child) override def sql: String = { s"rand(${if (hideSeed) "" else child.sql})" @@ -161,8 +161,6 @@ case class Randn(child: Expression, hideSeed: Boolean = false) extends RDG { isNull = FalseLiteral) } - override def freshCopy(): Randn = Randn(child, hideSeed) - override def flatArguments: Iterator[Any] = Iterator(child) override def sql: String = { s"randn(${if (hideSeed) "" else child.sql})" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index a7573fc1bd9c4..9510aa4d9e707 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -1113,7 +1113,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre trait LeafLike[T <: TreeNode[T]] { self: TreeNode[T] => override final def children: Seq[T] = Nil override final def mapChildren(f: T => T): T = this.asInstanceOf[T] - override final def withNewChildrenInternal(newChildren: IndexedSeq[T]): T = this.asInstanceOf[T] + // Stateful expressions should override this method to return a new instance. 
+ override def withNewChildrenInternal(newChildren: IndexedSeq[T]): T = this.asInstanceOf[T] } trait UnaryLike[T <: TreeNode[T]] { self: TreeNode[T] => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index f6c529ec4ced1..32b3840760f91 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -2103,12 +2103,6 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper evaluateWithMutableProjection(Shuffle(ai0, seed2))) assert(evaluateWithUnsafeProjection(Shuffle(ai0, seed1)) !== evaluateWithUnsafeProjection(Shuffle(ai0, seed2))) - - val shuffle = Shuffle(ai0, seed1) - assert(shuffle.fastEquals(shuffle)) - assert(!shuffle.fastEquals(Shuffle(ai0, seed1))) - assert(!shuffle.fastEquals(shuffle.freshCopy())) - assert(!shuffle.fastEquals(Shuffle(ai0, seed2))) } test("Array Except") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala index 15a0695943b5b..d449de3defb2e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala @@ -70,12 +70,6 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { evaluateWithMutableProjection(Uuid(seed2))) assert(evaluateWithUnsafeProjection(Uuid(seed1)) !== evaluateWithUnsafeProjection(Uuid(seed2))) - - val uuid = Uuid(seed1) - assert(uuid.fastEquals(uuid)) - assert(!uuid.fastEquals(Uuid(seed1))) - assert(!uuid.fastEquals(uuid.freshCopy())) - 
assert(!uuid.fastEquals(Uuid(seed2))) } test("PrintToStderr") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index bef88a7c0a356..286d3dddae6ea 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -820,6 +820,50 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper { assert(leaf.child.eq(leafCloned.asInstanceOf[FakeLeafPlan].child)) } + test("Expression.freshCopyIfContainsStatefulExpression()") { + val tag = TreeNodeTag[String]("test") + + def makeExprWithPositionAndTag(block: => Expression): Expression = { + CurrentOrigin.setPosition(1, 1) + val expr = block + CurrentOrigin.reset() + expr.setTagValue(tag, "tagValue") + expr + } + + // Test generic assertions which should always hold for any value returned + // from freshCopyIfContainsStatefulExpression() + def genericAssertions(before: Expression, after: Expression): Unit = { + assert(before == after) + assert(before.origin == after.origin) + assert(before.getTagValue(tag) == after.getTagValue(tag)) + } + + // Doesn't transform for non-stateful expressions: + val onePlusOneBefore = makeExprWithPositionAndTag(Add(Literal(1), Literal(1))) + val onePlusOneAfter = onePlusOneBefore.freshCopyIfContainsStatefulExpression() + genericAssertions(onePlusOneBefore, onePlusOneAfter) + assert(onePlusOneBefore eq onePlusOneAfter) + + // Transforms stateful expressions with no nesting: + val statefulExprBefore = makeExprWithPositionAndTag(Rand(Literal(1))) + val statefulExprAfter = statefulExprBefore.freshCopyIfContainsStatefulExpression() + genericAssertions(statefulExprBefore, statefulExprAfter) + assert(statefulExprBefore ne statefulExprAfter) + + // Transforms expressions nested three levels deep: + val withNestedStatefulBefore = 
makeExprWithPositionAndTag( + Add(Literal(1), Add(Literal(1), Rand(Literal(1)))) + ) + val withNestedStatefulAfter = withNestedStatefulBefore.freshCopyIfContainsStatefulExpression() + genericAssertions(withNestedStatefulBefore, withNestedStatefulAfter) + assert(withNestedStatefulBefore ne withNestedStatefulAfter) + def getStateful(e: Expression): Expression = { + e.collect { case e if e.stateful => e }.head + } + assert(getStateful(withNestedStatefulBefore) ne getStateful(withNestedStatefulAfter)) + } + object MalformedClassObject extends Serializable { case class MalformedNameExpression(child: Expression) extends TaggingExpression { override protected def withNewChildInternal(newChild: Expression): Expression = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 4269aaea0dfcd..a7bb0a2d1bd8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -3567,6 +3567,17 @@ class DataFrameSuite extends QueryTest }.isEmpty) } } + + test("SPARK-41049: stateful expression should be copied correctly") { + val df = spark.sparkContext.parallelize(1 to 5).toDF("x") + val v1 = (rand() * 10000).cast(IntegerType) + val v2 = to_csv(struct(v1.as("a"))) // to_csv is CodegenFallback + df.select(v1, v1, v2, v2).collect.foreach { row => + assert(row.getInt(0) == row.getInt(1)) + assert(row.getInt(0).toString == row.getString(2)) + assert(row.getInt(0).toString == row.getString(3)) + } + } } case class GroupByKey(a: Int, b: Int)