Pull request diff — changes from all commits (PR status: Closed)
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -119,6 +119,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] {
       case Some(value) => Some(recursiveTransform(value))
       case m: Map[_, _] => m
       case d: DataType => d // Avoid unpacking Structs
+      case stream: Stream[_] => stream.map(recursiveTransform).force
       case seq: Traversable[_] => seq.map(recursiveTransform)
       case other: AnyRef => other
       case null => null
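The one-line fix above works because `Stream.map` is lazy: at most the head is evaluated eagerly, and mapping the tail, together with any side effects it carries, is deferred until the stream is forced. A minimal plain-Scala sketch of that failure mode (not taken from the PR):

// Plain-Scala sketch (not from the PR): side effects inside a Stream.map
// run only when the element is materialized, which is what .force fixes.
object LazyStreamSketch extends App {
  var changed = false
  val mapped = Stream(1, 2).map { v =>
    if (v != 1) changed = true // deferred until this element is evaluated
    v + 1
  }
  // At most the head of the Stream has been evaluated here, so the
  // side effect for the second element has not fired yet:
  println(changed) // false
  mapped.force     // materialize every element, as the diff above does
  println(changed) // true
}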
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -199,44 +199,33 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
     var changed = false
     val remainingNewChildren = newChildren.toBuffer
     val remainingOldChildren = children.toBuffer
+    def mapTreeNode(node: TreeNode[_]): TreeNode[_] = {
+      val newChild = remainingNewChildren.remove(0)
+      val oldChild = remainingOldChildren.remove(0)
+      if (newChild fastEquals oldChild) {
+        oldChild
+      } else {
+        changed = true
+        newChild
+      }
+    }
+    def mapChild(child: Any): Any = child match {
+      case arg: TreeNode[_] if containsChild(arg) => mapTreeNode(arg)
+      case nonChild: AnyRef => nonChild
+      case null => null
+    }
     val newArgs = mapProductIterator {
       case s: StructType => s // Don't convert struct types to some other type of Seq[StructField]
-      // Handle Seq[TreeNode] in TreeNode parameters.
-      case s: Seq[_] => s.map {
-        case arg: TreeNode[_] if containsChild(arg) =>
-          val newChild = remainingNewChildren.remove(0)
-          val oldChild = remainingOldChildren.remove(0)
-          if (newChild fastEquals oldChild) {
-            oldChild
-          } else {
-            changed = true
-            newChild
-          }
-        case nonChild: AnyRef => nonChild
-        case null => null
-      }
-      case m: Map[_, _] => m.mapValues {
-        case arg: TreeNode[_] if containsChild(arg) =>
-          val newChild = remainingNewChildren.remove(0)
-          val oldChild = remainingOldChildren.remove(0)
-          if (newChild fastEquals oldChild) {
-            oldChild
-          } else {
-            changed = true
-            newChild
-          }
-        case nonChild: AnyRef => nonChild
-        case null => null
-      }.view.force // `mapValues` is lazy and we need to force it to materialize
-      case arg: TreeNode[_] if containsChild(arg) =>
-        val newChild = remainingNewChildren.remove(0)
-        val oldChild = remainingOldChildren.remove(0)
-        if (newChild fastEquals oldChild) {
-          oldChild
-        } else {
-          changed = true
-          newChild
-        }
+      case s: Stream[_] =>
+        // Stream is lazy so we need to force materialization
+        s.map(mapChild).force
+      case s: Seq[_] =>
+        s.map(mapChild)
+      case m: Map[_, _] =>
+        // `mapValues` is lazy and we need to force it to materialize
+        m.mapValues(mapChild).view.force
+      case arg: TreeNode[_] if containsChild(arg) => mapTreeNode(arg)
      case nonChild: AnyRef => nonChild
      case null => null
    }
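The reason the new `Stream` case in `withNewChildren` must call `.force` is that the surrounding logic consumes the `remainingNewChildren` and `remainingOldChildren` buffers as a side effect of the `map`; with an un-forced stream only the materialized prefix pulls from those buffers, so the remaining new children would be silently dropped. A rough stand-alone model of that interaction, with hypothetical names, not Spark code:

// Hypothetical stand-alone model: replacements are pulled from a mutable
// buffer inside map, so a lazy Stream drains the buffer only when forced.
import scala.collection.mutable

object BufferDrainSketch extends App {
  val replacements = mutable.Buffer("newA", "newB")
  val children = Stream("oldA", "oldB")
  val swapped = children.map(_ => replacements.remove(0))
  // Only the eagerly evaluated prefix has pulled from the buffer:
  println(replacements.nonEmpty) // true -- at least one replacement still queued
  swapped.force                  // drain the buffer completely
  println(replacements.isEmpty)  // true
}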
@@ -301,6 +290,37 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
   def mapChildren(f: BaseType => BaseType): BaseType = {
     if (children.nonEmpty) {
       var changed = false
+      def mapChild(child: Any): Any = child match {
+        case arg: TreeNode[_] if containsChild(arg) =>
+          val newChild = f(arg.asInstanceOf[BaseType])
+          if (!(newChild fastEquals arg)) {
+            changed = true
+            newChild
+          } else {
+            arg
+          }
+        case tuple @ (arg1: TreeNode[_], arg2: TreeNode[_]) =>
+          val newChild1 = if (containsChild(arg1)) {
+            f(arg1.asInstanceOf[BaseType])
+          } else {
+            arg1.asInstanceOf[BaseType]
+          }
+
+          val newChild2 = if (containsChild(arg2)) {
+            f(arg2.asInstanceOf[BaseType])
+          } else {
+            arg2.asInstanceOf[BaseType]
+          }
+
+          if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) {
+            changed = true
+            (newChild1, newChild2)
+          } else {
+            tuple
+          }
+        case other => other
+      }
+
       val newArgs = mapProductIterator {
         case arg: TreeNode[_] if containsChild(arg) =>
           val newChild = f(arg.asInstanceOf[BaseType])

[Inline review comment on the new mapChild helper]
Contributor: Shall we reuse this code at L326?
Author: Yeah, I was about to do that, but then I noticed that they handle different cases: mapChild handles TreeNode and (TreeNode, TreeNode), whereas L326-L349 handles TreeNode, Option[TreeNode] and Map[_, TreeNode]. I'm not sure combining them is useful, and if it is, I'd rather do it in a different PR.
@@ -330,36 +350,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
         case other => other
       }.view.force // `mapValues` is lazy and we need to force it to materialize
       case d: DataType => d // Avoid unpacking Structs
-      case args: Traversable[_] => args.map {
-        case arg: TreeNode[_] if containsChild(arg) =>
-          val newChild = f(arg.asInstanceOf[BaseType])
-          if (!(newChild fastEquals arg)) {
-            changed = true
-            newChild
-          } else {
-            arg
-          }
-        case tuple @ (arg1: TreeNode[_], arg2: TreeNode[_]) =>
-          val newChild1 = if (containsChild(arg1)) {
-            f(arg1.asInstanceOf[BaseType])
-          } else {
-            arg1.asInstanceOf[BaseType]
-          }
-
-          val newChild2 = if (containsChild(arg2)) {
-            f(arg2.asInstanceOf[BaseType])
-          } else {
-            arg2.asInstanceOf[BaseType]
-          }
-
-          if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) {
-            changed = true
-            (newChild1, newChild2)
-          } else {
-            tuple
-          }
-        case other => other
-      }
+      case args: Stream[_] => args.map(mapChild).force // Force materialization on stream
+      case args: Traversable[_] => args.map(mapChild)
       case nonChild: AnyRef => nonChild
       case null => null
     }
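The `mapValues` comment in this hunk points at the same trap in another API: on Scala 2.11/2.12, `Map.mapValues` returns a lazy view that re-applies the function on every lookup and defers its side effects, hence the `.view.force`. A small sketch, again plain Scala rather than PR code:

// Plain-Scala (2.11/2.12) sketch: mapValues runs the function on access,
// not when mapValues is called, so work and side effects are deferred.
object LazyMapValuesSketch extends App {
  var calls = 0
  val lazyValues = Map("a" -> 1).mapValues { v => calls += 1; v + 1 }
  println(calls) // 0 -- nothing evaluated yet
  lazyValues("a")
  lazyValues("a")
  println(calls) // 2 -- re-evaluated on every lookup
  val strict = lazyValues.view.force // materialize once, as the diff does
}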
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.plans
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Coalesce, Literal, NamedExpression}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types.IntegerType
 
@@ -101,4 +101,22 @@ class LogicalPlanSuite extends SparkFunSuite {
     assert(TestBinaryRelation(relation, incrementalRelation).isStreaming === true)
     assert(TestBinaryRelation(incrementalRelation, incrementalRelation).isStreaming)
   }
+
+  test("transformExpressions works with a Stream") {
+    val id1 = NamedExpression.newExprId
+    val id2 = NamedExpression.newExprId
+    val plan = Project(Stream(
+      Alias(Literal(1), "a")(exprId = id1),
+      Alias(Literal(2), "b")(exprId = id2)),
+      OneRowRelation())
+    val result = plan.transformExpressions {
+      case Literal(v: Int, IntegerType) if v != 1 =>
+        Literal(v + 1, IntegerType)
+    }
+    val expected = Project(Stream(
+      Alias(Literal(1), "a")(exprId = id1),
+      Alias(Literal(3), "b")(exprId = id2)),
+      OneRowRelation())
+    assert(result.sameResult(expected))
+  }
 }
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -29,14 +29,14 @@ import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier}
-import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, FunctionResource, JarResource}
+import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.dsl.expressions.DslString
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.plans.{LeftOuter, NaturalJoin}
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Union}
 import org.apache.spark.sql.catalyst.plans.physical.{IdentityBroadcastMode, RoundRobinPartitioning, SinglePartition}
-import org.apache.spark.sql.types.{BooleanType, DoubleType, FloatType, IntegerType, Metadata, NullType, StringType, StructField, StructType}
+import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
 
 case class Dummy(optKey: Option[Expression]) extends Expression with CodegenFallback {
@@ -574,4 +574,25 @@ class TreeNodeSuite extends SparkFunSuite {
     val right = JsonMethods.parse(rightJson)
     assert(left == right)
   }
+
+  test("transform works on stream of children") {
+    val before = Coalesce(Stream(Literal(1), Literal(2)))
+    // Note it is a bit tricky to exhibit the broken behavior. Basically we want to create the
+    // situation in which the TreeNode.mapChildren function's change detection is not triggered. A
+    // stream's first element is typically materialized, so in order to not trip the TreeNode change
+    // detection logic, we should not change the first element in the sequence.
+    val result = before.transform {
+      case Literal(v: Int, IntegerType) if v != 1 =>
+        Literal(v + 1, IntegerType)
+    }
+    val expected = Coalesce(Stream(Literal(1), Literal(3)))
+    assert(result === expected)
+  }
+
+  test("withNewChildren on stream of children") {
+    val before = Coalesce(Stream(Literal(1), Literal(2)))
+    val result = before.withNewChildren(Stream(Literal(1), Literal(3)))
+    val expected = Coalesce(Stream(Literal(1), Literal(3)))
+    assert(result === expected)
+  }
 }
sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -21,8 +21,8 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{execution, Row}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, RightOuter}
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range, Repartition, Sort}
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range, Repartition, Sort, Union}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
 import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec}
@@ -679,6 +679,13 @@ class PlannerSuite extends SharedSQLContext {
     }
     assert(rangeExecInZeroPartition.head.outputPartitioning == UnknownPartitioning(0))
   }
+
+  test("SPARK-24500: create union with stream of children") {
+    val df = Union(Stream(
+      Range(1, 1, 1, 1),
+      Range(1, 2, 1, 1)))
+    df.queryExecution.executedPlan.execute()
+  }
 }
 
 // Used for unit-testing EnsureRequirements

[Inline review comment on the new test]
Contributor: Did it throw an exception before this change?
Author: Yeah, it would throw an UnsupportedOperationException before.