Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
case e: Expression => transformExpressionDown(e)
case Some(e: Expression) => Some(transformExpressionDown(e))
case m: Map[_,_] => m
case d: DataType => d // Avoid unpacking Structs
case seq: Traversable[_] => seq.map {
case e: Expression => transformExpressionDown(e)
case other => other
Expand Down Expand Up @@ -114,6 +115,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
case e: Expression => transformExpressionUp(e)
case Some(e: Expression) => Some(transformExpressionUp(e))
case m: Map[_,_] => m
case d: DataType => d // Avoid unpacking Structs
case seq: Traversable[_] => seq.map {
case e: Expression => transformExpressionUp(e)
case other => other
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.trees

import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.types.DataType

/** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */
private class MutableInt(var i: Int)
Expand Down Expand Up @@ -220,6 +221,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
Some(arg)
}
case m: Map[_,_] => m
case d: DataType => d // Avoid unpacking Structs
case args: Traversable[_] => args.map {
case arg: TreeNode[_] if children contains arg =>
val newChild = arg.asInstanceOf[BaseType].transformDown(rule)
Expand Down Expand Up @@ -276,6 +278,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
Some(arg)
}
case m: Map[_,_] => m
case d: DataType => d // Avoid unpacking Structs
case args: Traversable[_] => args.map {
case arg: TreeNode[_] if children contains arg =>
val newChild = arg.asInstanceOf[BaseType].transformUp(rule)
Expand Down Expand Up @@ -307,10 +310,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
* @param newArgs the new product arguments.
*/
def makeCopy(newArgs: Array[AnyRef]): this.type = attachTree(this, "makeCopy") {
val defaultCtor =
getClass.getConstructors
.find(_.getParameterTypes.size != 0)
.headOption
.getOrElse(sys.error(s"No valid constructor for $nodeName"))

try {
CurrentOrigin.withOrigin(origin) {
// Skip no-arg constructors that are just there for kryo.
val defaultCtor = getClass.getConstructors.find(_.getParameterTypes.size != 0).head
if (otherCopyArgs.isEmpty) {
defaultCtor.newInstance(newArgs: _*).asInstanceOf[this.type]
} else {
Expand All @@ -320,8 +328,14 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
} catch {
case e: java.lang.IllegalArgumentException =>
throw new TreeNodeException(
this, s"Failed to copy node. Is otherCopyArgs specified correctly for $nodeName? "
+ s"Exception message: ${e.getMessage}.")
this,
s"""
|Failed to copy node.
|Is otherCopyArgs specified correctly for $nodeName.
|Exception message: ${e.getMessage}
|ctor: $defaultCtor?
|args: ${newArgs.mkString(", ")}
""".stripMargin)
}
}

Expand Down
6 changes: 6 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,10 @@ class UDFSuite extends QueryTest {
.select($"ret.f1").head().getString(0)
assert(result === "test")
}

test("udf that is transformed") {
udf.register("makeStruct", (x: Int, y: Int) => (x, y))
// 1 + 1 is constant folded causing a transformation.
assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2))
}
}