diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 5b4549d0d94f..40f3629fc368 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -316,80 +316,92 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } /** - * Returns a copy of this node where `f` has been applied to all the nodes children. + * Returns a copy of this node where `f` has been applied to all the nodes in `children`. */ def mapChildren(f: BaseType => BaseType): BaseType = { if (children.nonEmpty) { - var changed = false - def mapChild(child: Any): Any = child match { - case arg: TreeNode[_] if containsChild(arg) => - val newChild = f(arg.asInstanceOf[BaseType]) - if (!(newChild fastEquals arg)) { - changed = true - newChild - } else { - arg - } - case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) => - val newChild1 = if (containsChild(arg1)) { - f(arg1.asInstanceOf[BaseType]) - } else { - arg1.asInstanceOf[BaseType] - } + mapChildren(f, forceCopy = false) + } else { + this + } + } - val newChild2 = if (containsChild(arg2)) { - f(arg2.asInstanceOf[BaseType]) - } else { - arg2.asInstanceOf[BaseType] - } + /** + * Returns a copy of this node where `f` has been applied to all the nodes in `children`. + * @param f The transform function to be applied on applicable `TreeNode` elements. + * @param forceCopy Whether to force making a copy of the nodes even if no child has been changed. + */ + private def mapChildren( + f: BaseType => BaseType, + forceCopy: Boolean): BaseType = { + var changed = false - if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) { - changed = true - (newChild1, newChild2) - } else { - tuple - } - case other => other - } + def mapChild(child: Any): Any = child match { + case arg: TreeNode[_] if containsChild(arg) => + val newChild = f(arg.asInstanceOf[BaseType]) + if (forceCopy || !(newChild fastEquals arg)) { + changed = true + newChild + } else { + arg + } + case tuple @ (arg1: TreeNode[_], arg2: TreeNode[_]) => + val newChild1 = if (containsChild(arg1)) { + f(arg1.asInstanceOf[BaseType]) + } else { + arg1.asInstanceOf[BaseType] + } + + val newChild2 = if (containsChild(arg2)) { + f(arg2.asInstanceOf[BaseType]) + } else { + arg2.asInstanceOf[BaseType] + } + + if (forceCopy || !(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) { + changed = true + (newChild1, newChild2) + } else { + tuple + } + case other => other + } - val newArgs = mapProductIterator { + val newArgs = mapProductIterator { + case arg: TreeNode[_] if containsChild(arg) => + val newChild = f(arg.asInstanceOf[BaseType]) + if (forceCopy || !(newChild fastEquals arg)) { + changed = true + newChild + } else { + arg + } + case Some(arg: TreeNode[_]) if containsChild(arg) => + val newChild = f(arg.asInstanceOf[BaseType]) + if (forceCopy || !(newChild fastEquals arg)) { + changed = true + Some(newChild) + } else { + Some(arg) + } + case m: Map[_, _] => m.mapValues { case arg: TreeNode[_] if containsChild(arg) => val newChild = f(arg.asInstanceOf[BaseType]) - if (!(newChild fastEquals arg)) { + if (forceCopy || !(newChild fastEquals arg)) { changed = true newChild } else { arg } - case Some(arg: TreeNode[_]) if containsChild(arg) => - val newChild = f(arg.asInstanceOf[BaseType]) - if (!(newChild fastEquals arg)) { - changed = true - Some(newChild) - } else { - Some(arg) - } - case m: Map[_, _] => m.mapValues { - case arg: TreeNode[_] if containsChild(arg) => - val newChild = f(arg.asInstanceOf[BaseType]) - if (!(newChild fastEquals arg)) { - changed = true - newChild - } else { - arg - } - case other => other - }.view.force // `mapValues` is lazy and we need to force it to materialize - case d: DataType => d // Avoid unpacking Structs - case args: Stream[_] => args.map(mapChild).force // Force materialization on stream - case args: Iterable[_] => args.map(mapChild) - case nonChild: AnyRef => nonChild - case null => null - } - if (changed) makeCopy(newArgs) else this - } else { - this + case other => other + }.view.force // `mapValues` is lazy and we need to force it to materialize + case d: DataType => d // Avoid unpacking Structs + case args: Stream[_] => args.map(mapChild).force // Force materialization on stream + case args: Iterable[_] => args.map(mapChild) + case nonChild: AnyRef => nonChild + case null => null } + if (forceCopy || changed) makeCopy(newArgs, forceCopy) else this } /** @@ -405,9 +417,20 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { * that are not present in the productIterator. * @param newArgs the new product arguments. */ - def makeCopy(newArgs: Array[AnyRef]): BaseType = attachTree(this, "makeCopy") { + def makeCopy(newArgs: Array[AnyRef]): BaseType = makeCopy(newArgs, allowEmptyArgs = false) + + /** + * Creates a copy of this type of tree node after a transformation. + * Must be overridden by child classes that have constructor arguments + * that are not present in the productIterator. + * @param newArgs the new product arguments. + * @param allowEmptyArgs whether to allow argument list to be empty. + */ + private def makeCopy( + newArgs: Array[AnyRef], + allowEmptyArgs: Boolean): BaseType = attachTree(this, "makeCopy") { // Skip no-arg constructors that are just there for kryo. - val ctors = getClass.getConstructors.filter(_.getParameterTypes.size != 0) + val ctors = getClass.getConstructors.filter(allowEmptyArgs || _.getParameterTypes.size != 0) if (ctors.isEmpty) { sys.error(s"No valid constructor for $nodeName") } @@ -450,6 +473,10 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } } + override def clone(): BaseType = { + mapChildren(_.clone(), forceCopy = true) + } + /** * Returns the name of this type of TreeNode. Defaults to the class name. * Note that we remove the "Exec" suffix for physical operators here. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 744d522b1b5d..fbaa5527a705 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions.DslString import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.plans.{LeftOuter, NaturalJoin, SQLHelper} -import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Union} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{IdentityBroadcastMode, RoundRobinPartitioning, SinglePartition} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -82,6 +82,11 @@ case class SelfReferenceUDF( def apply(key: String): Boolean = config.contains(key) } +case class FakeLeafPlan(child: LogicalPlan) + extends org.apache.spark.sql.catalyst.plans.logical.LeafNode { + override def output: Seq[Attribute] = child.output +} + class TreeNodeSuite extends SparkFunSuite with SQLHelper { test("top node changed") { val after = Literal(1) transform { case Literal(1, _) => Literal(2) } @@ -673,4 +678,34 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper { }) } } + + test("clone") { + def assertDifferentInstance(before: AnyRef, after: AnyRef): Unit = { + assert(before.ne(after) && before == after) + before.asInstanceOf[TreeNode[_]].children.zip( + after.asInstanceOf[TreeNode[_]].children).foreach { + case (beforeChild: AnyRef, afterChild: AnyRef) => + assertDifferentInstance(beforeChild, afterChild) + } + } + + // Empty constructor + val rowNumber = RowNumber() + assertDifferentInstance(rowNumber, rowNumber.clone()) + + // Overridden `makeCopy` + val oneRowRelation = OneRowRelation() + assertDifferentInstance(oneRowRelation, oneRowRelation.clone()) + + // Multi-way operators + val intersect = + Intersect(oneRowRelation, Union(Seq(oneRowRelation, oneRowRelation)), isAll = false) + assertDifferentInstance(intersect, intersect.clone()) + + // Leaf node with an inner child + val leaf = FakeLeafPlan(intersect) + val leafCloned = leaf.clone() + assertDifferentInstance(leaf, leafCloned) + assert(leaf.child.eq(leafCloned.asInstanceOf[FakeLeafPlan].child)) + } }