From 2ccb721659e31b5ff3323a1ae2ff90aefb341b03 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 11 Jun 2014 12:35:46 -0700 Subject: [PATCH 1/3] Add support for transformation of optional children. --- .../spark/sql/catalyst/trees/TreeNode.scala | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 0369129393a0..9829b4a54ac0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -187,6 +187,14 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } else { arg } + case Some(arg: TreeNode[_]) if children contains arg => + val newChild = arg.asInstanceOf[BaseType].transformDown(rule) + if (!(newChild fastEquals arg)) { + changed = true + Some(newChild) + } else { + Some(arg) + } case m: Map[_,_] => m case args: Traversable[_] => args.map { case arg: TreeNode[_] if children contains arg => @@ -231,6 +239,14 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } else { arg } + case Some(arg: TreeNode[_]) if children contains arg => + val newChild = arg.asInstanceOf[BaseType].transformUp(rule) + if (!(newChild fastEquals arg)) { + changed = true + Some(newChild) + } else { + Some(arg) + } case m: Map[_,_] => m case args: Traversable[_] => args.map { case arg: TreeNode[_] if children contains arg => From ab78420c4c2722e85c41cd5e0292583e26f6999c Mon Sep 17 00:00:00 2001 From: Zongheng Yang Date: Thu, 12 Jun 2014 22:17:49 -0700 Subject: [PATCH 2/3] Add a test. --- .../spark/sql/catalyst/trees/TreeNode.scala | 3 ++- .../sql/catalyst/trees/TreeNodeSuite.scala | 26 +++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 9829b4a54ac0..cd04bdf02cf8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -289,7 +289,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } catch { case e: java.lang.IllegalArgumentException => throw new TreeNodeException( - this, s"Failed to copy node. Is otherCopyArgs specified correctly for $nodeName?") + this, s"Failed to copy node. Is otherCopyArgs specified correctly for $nodeName? " + + s"Exception message: ${e.getMessage}.") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 1ddc41a731ff..0b40a2dcea7c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types.{StringType, NullType} class TreeNodeSuite extends FunSuite { test("top node changed") { @@ -75,4 +76,29 @@ class TreeNodeSuite extends FunSuite { assert(expected === actual) } + + test("transform works on nodes with Option children") { + case class Dummy(optKey: Option[Expression]) extends Expression { + def children = optKey.toSeq + def references = Set.empty[Attribute] + def nullable = true + def dataType = NullType + override lazy val resolved = true + type EvaluatedType = Any + def eval(input: Row) = null.asInstanceOf[Any] + } + val dummy1 = Dummy(Some(Literal("1", StringType))) + val dummy2 = Dummy(None) + val toZero: PartialFunction[Expression, Expression] = { case Literal(_, _) => Literal(0) } + + var actual = dummy1 transformDown toZero + assert(actual === Dummy(Some(Literal(0)))) + + actual = dummy1 transformUp toZero + assert(actual === Dummy(Some(Literal(0)))) + + actual = dummy2 transform toZero + assert(actual === Dummy(None)) + } + } From 73133c2cd634bcd7ba6d28308f668a69e30418b1 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 13 Jun 2014 19:37:46 -0700 Subject: [PATCH 3/3] TreeNodes can't be inner classes. --- .../sql/catalyst/trees/TreeNodeSuite.scala | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 0b40a2dcea7c..6344874538d6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -24,6 +24,16 @@ import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types.{StringType, NullType} +case class Dummy(optKey: Option[Expression]) extends Expression { + def children = optKey.toSeq + def references = Set.empty[Attribute] + def nullable = true + def dataType = NullType + override lazy val resolved = true + type EvaluatedType = Any + def eval(input: Row) = null.asInstanceOf[Any] +} + class TreeNodeSuite extends FunSuite { test("top node changed") { val after = Literal(1) transform { case Literal(1, _) => Literal(2) } @@ -78,15 +88,6 @@ class TreeNodeSuite extends FunSuite { } test("transform works on nodes with Option children") { - case class Dummy(optKey: Option[Expression]) extends Expression { - def children = optKey.toSeq - def references = Set.empty[Attribute] - def nullable = true - def dataType = NullType - override lazy val resolved = true - type EvaluatedType = Any - def eval(input: Row) = null.asInstanceOf[Any] - } val dummy1 = Dummy(Some(Literal("1", StringType))) val dummy2 = Dummy(None) val toZero: PartialFunction[Expression, Expression] = { case Literal(_, _) => Literal(0) }