Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -316,80 +316,92 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
}

/**
* Returns a copy of this node where `f` has been applied to all the nodes children.
* Returns a copy of this node where `f` has been applied to all the nodes in `children`.
*/
def mapChildren(f: BaseType => BaseType): BaseType = {
if (children.nonEmpty) {
var changed = false
def mapChild(child: Any): Any = child match {
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = f(arg.asInstanceOf[BaseType])
if (!(newChild fastEquals arg)) {
changed = true
newChild
} else {
arg
}
case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) =>
val newChild1 = if (containsChild(arg1)) {
f(arg1.asInstanceOf[BaseType])
} else {
arg1.asInstanceOf[BaseType]
}
mapChildren(f, forceCopy = false)
} else {
this
}
}

val newChild2 = if (containsChild(arg2)) {
f(arg2.asInstanceOf[BaseType])
} else {
arg2.asInstanceOf[BaseType]
}
/**
* Returns a copy of this node where `f` has been applied to all the nodes in `children`.
* @param f The transform function to be applied on applicable `TreeNode` elements.
* @param forceCopy Whether to force making a copy of the nodes even if no child has been changed.
*/
private def mapChildren(
f: BaseType => BaseType,
forceCopy: Boolean): BaseType = {
var changed = false

if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) {
changed = true
(newChild1, newChild2)
} else {
tuple
}
case other => other
}
def mapChild(child: Any): Any = child match {
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = f(arg.asInstanceOf[BaseType])
if (forceCopy || !(newChild fastEquals arg)) {
changed = true
newChild
} else {
arg
}
case tuple @ (arg1: TreeNode[_], arg2: TreeNode[_]) =>
val newChild1 = if (containsChild(arg1)) {
f(arg1.asInstanceOf[BaseType])
} else {
arg1.asInstanceOf[BaseType]
}

val newChild2 = if (containsChild(arg2)) {
f(arg2.asInstanceOf[BaseType])
} else {
arg2.asInstanceOf[BaseType]
}

if (forceCopy || !(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) {
changed = true
(newChild1, newChild2)
} else {
tuple
}
case other => other
}

val newArgs = mapProductIterator {
val newArgs = mapProductIterator {
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = f(arg.asInstanceOf[BaseType])
if (forceCopy || !(newChild fastEquals arg)) {
changed = true
newChild
} else {
arg
}
case Some(arg: TreeNode[_]) if containsChild(arg) =>
val newChild = f(arg.asInstanceOf[BaseType])
if (forceCopy || !(newChild fastEquals arg)) {
changed = true
Some(newChild)
} else {
Some(arg)
}
case m: Map[_, _] => m.mapValues {
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = f(arg.asInstanceOf[BaseType])
if (!(newChild fastEquals arg)) {
if (forceCopy || !(newChild fastEquals arg)) {
changed = true
newChild
} else {
arg
}
case Some(arg: TreeNode[_]) if containsChild(arg) =>
val newChild = f(arg.asInstanceOf[BaseType])
if (!(newChild fastEquals arg)) {
changed = true
Some(newChild)
} else {
Some(arg)
}
case m: Map[_, _] => m.mapValues {
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = f(arg.asInstanceOf[BaseType])
if (!(newChild fastEquals arg)) {
changed = true
newChild
} else {
arg
}
case other => other
}.view.force // `mapValues` is lazy and we need to force it to materialize
case d: DataType => d // Avoid unpacking Structs
case args: Stream[_] => args.map(mapChild).force // Force materialization on stream
case args: Iterable[_] => args.map(mapChild)
case nonChild: AnyRef => nonChild
case null => null
}
if (changed) makeCopy(newArgs) else this
} else {
this
case other => other
}.view.force // `mapValues` is lazy and we need to force it to materialize
case d: DataType => d // Avoid unpacking Structs
case args: Stream[_] => args.map(mapChild).force // Force materialization on stream
case args: Iterable[_] => args.map(mapChild)
case nonChild: AnyRef => nonChild
case null => null
}
if (forceCopy || changed) makeCopy(newArgs, forceCopy) else this
}

/**
Expand All @@ -405,9 +417,20 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
* that are not present in the productIterator.
* @param newArgs the new product arguments.
*/
def makeCopy(newArgs: Array[AnyRef]): BaseType = attachTree(this, "makeCopy") {
def makeCopy(newArgs: Array[AnyRef]): BaseType = makeCopy(newArgs, allowEmptyArgs = false)

/**
* Creates a copy of this type of tree node after a transformation.
* This overload is private and therefore cannot be overridden; child classes that have
* constructor arguments that are not present in the productIterator override the public
* `makeCopy(newArgs)` instead.
* @param newArgs the new product arguments.
* @param allowEmptyArgs whether to allow argument list to be empty.
*/
private def makeCopy(
newArgs: Array[AnyRef],
allowEmptyArgs: Boolean): BaseType = attachTree(this, "makeCopy") {
// Skip no-arg constructors that are just there for kryo.
val ctors = getClass.getConstructors.filter(_.getParameterTypes.size != 0)
val ctors = getClass.getConstructors.filter(allowEmptyArgs || _.getParameterTypes.size != 0)
if (ctors.isEmpty) {
sys.error(s"No valid constructor for $nodeName")
}
Expand Down Expand Up @@ -450,6 +473,10 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
}
}

/**
 * Returns a copy of this node built by cloning every child and rebuilding the node with
 * `forceCopy = true`, so a new instance is requested even when no child has changed.
 */
override def clone(): BaseType = mapChildren(child => child.clone(), forceCopy = true)

/**
* Returns the name of this type of TreeNode. Defaults to the class name.
* Note that we remove the "Exec" suffix for physical operators here.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions.DslString
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.plans.{LeftOuter, NaturalJoin, SQLHelper}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Union}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical.{IdentityBroadcastMode, RoundRobinPartitioning, SinglePartition}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -82,6 +82,11 @@ case class SelfReferenceUDF(
def apply(key: String): Boolean = config.contains(key)
}

/**
 * Test-only logical plan that extends `LeafNode` while still holding another plan as a
 * constructor argument. The "clone" test uses it to check that the wrapped plan is kept by
 * reference in the cloned node.
 *
 * @param child the wrapped plan; its output is exposed as this node's output.
 */
case class FakeLeafPlan(child: LogicalPlan)
extends org.apache.spark.sql.catalyst.plans.logical.LeafNode {
// The output attributes are taken directly from the wrapped plan.
override def output: Seq[Attribute] = child.output
}

class TreeNodeSuite extends SparkFunSuite with SQLHelper {
test("top node changed") {
val after = Literal(1) transform { case Literal(1, _) => Literal(2) }
Expand Down Expand Up @@ -673,4 +678,34 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
})
}
}

test("clone") {
  // Recursively checks that `copied` is a distinct instance that still compares equal to
  // `original`, and that the same holds for every pair of corresponding children.
  def assertDeepCopy(original: AnyRef, copied: AnyRef): Unit = {
    assert(original.ne(copied) && original == copied)
    val originalChildren: Seq[Any] = original.asInstanceOf[TreeNode[_]].children
    val copiedChildren: Seq[Any] = copied.asInstanceOf[TreeNode[_]].children
    originalChildren.zip(copiedChildren).foreach {
      case (originalChild: AnyRef, copiedChild: AnyRef) =>
        assertDeepCopy(originalChild, copiedChild)
    }
  }

  // A node whose constructor takes no arguments.
  val noArgNode = RowNumber()
  assertDeepCopy(noArgNode, noArgNode.clone())

  // A node that overrides `makeCopy`.
  val relation = OneRowRelation()
  assertDeepCopy(relation, relation.clone())

  // A tree of multi-way operators.
  val multiWay = Intersect(
    relation,
    Union(Seq(relation, relation)),
    isAll = false)
  assertDeepCopy(multiWay, multiWay.clone())

  // A leaf node that wraps another plan: the wrapped plan must be shared by reference
  // between the original and the clone.
  val wrapper = FakeLeafPlan(multiWay)
  val wrapperCloned = wrapper.clone()
  assertDeepCopy(wrapper, wrapperCloned)
  assert(wrapper.child.eq(wrapperCloned.asInstanceOf[FakeLeafPlan].child))
}
}