From 0acb7be0595d88a653af7fe7a9e05d48d8a3d254 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Thu, 18 Jun 2015 15:16:08 +0800 Subject: [PATCH 01/16] Add CreateNamedStruct in both DataFrame function API and FunctionRegistery --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/complexTypeCreator.scala | 43 ++++++++++++++ .../expressions/ComplexTypeSuite.scala | 54 ++++++++++++++++- .../org/apache/spark/sql/functions.scala | 15 +++++ .../spark/sql/DataFrameFunctionsSuite.scala | 58 +++++++++++++++++++ 5 files changed, 169 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 6f04298d4711..65c372cb7df7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -96,6 +96,7 @@ object FunctionRegistry { expression[Rand]("rand"), expression[Randn]("randn"), expression[CreateStruct]("struct"), + expression[CreateNamedStruct]("named_struct"), expression[Sqrt]("sqrt"), // math functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 67e7dc4ec8b1..ce59a6fdeb9d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -17,9 +17,11 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Returns an 
Array containing the evaluation of all children expressions. @@ -74,3 +76,44 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { override def prettyName: String = "struct" } + +/** + * Creates a struct with the given field names and values + * + * @param children Seq(name1, val1, name2, val2, ...) + */ +case class CreateNamedStruct(children: Seq[Expression]) extends Expression { + assert(children.size % 2 == 0, "NamedStruct expects an even number of arguments.") + + private val nameExprs = children.zipWithIndex.filter(_._2 % 2 == 0).map(_._1) + private val valExprs = children.zipWithIndex.filter(_._2 % 2 == 1).map(_._1) + + private lazy val names = nameExprs.map { case name => + name match { + case NonNullLiteral(str, StringType) => + str.asInstanceOf[UTF8String].toString + case _ => + throw new IllegalArgumentException("Expressions of odd index should be" + + s" Literal(_, StringType), get ${name.dataType} instead") + } + } + + override def foldable: Boolean = children.forall(_.foldable) + + override lazy val resolved: Boolean = childrenResolved + + override lazy val dataType: StructType = { + assert(resolved, + s"CreateStruct contains unresolvable children: ${children.filterNot(_.resolved)}.") + val fields = names.zip(valExprs).map { case (name, valExpr) => + StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty) + } + StructType(fields) + } + + override def nullable: Boolean = false + + override def eval(input: InternalRow): Any = { + InternalRow(valExprs.map(_.eval(input)): _*) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 3515d044b2f7..eb357fce9483 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ 
-100,6 +100,38 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { assert(getStructField(nullStruct, "a").nullable === true) } + test("complex type") { + val row = create_row( + "^Ba*n", // 0 + null.asInstanceOf[UTF8String], // 1 + create_row("aa", "bb"), // 2 + Map("aa" -> "bb"), // 3 + Seq("aa", "bb") // 4 + ) + + val typeS = StructType( + StructField("a", StringType, true) :: StructField("b", StringType, true) :: Nil + ) + val typeMap = MapType(StringType, StringType) + val typeArray = ArrayType(StringType) + + checkEvaluation(GetMapValue(BoundReference(3, typeMap, true), + Literal("aa")), "bb", row) + checkEvaluation(GetMapValue(Literal.create(null, typeMap), Literal("aa")), null, row) + checkEvaluation( + GetMapValue(Literal.create(null, typeMap), Literal.create(null, StringType)), null, row) + checkEvaluation(GetMapValue(BoundReference(3, typeMap, true), + Literal.create(null, StringType)), null, row) + + checkEvaluation(GetArrayItem(BoundReference(4, typeArray, true), + Literal(1)), "bb", row) + checkEvaluation(GetArrayItem(Literal.create(null, typeArray), Literal(1)), null, row) + checkEvaluation( + GetArrayItem(Literal.create(null, typeArray), Literal.create(null, IntegerType)), null, row) + checkEvaluation(GetArrayItem(BoundReference(4, typeArray, true), + Literal.create(null, IntegerType)), null, row) + } + test("GetArrayStructFields") { val typeAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val arrayStruct = Literal.create(Seq(create_row(1)), typeAS) @@ -119,11 +151,29 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { test("CreateStruct") { val row = create_row(1, 2, 3) - val c1 = 'a.int.at(0).as("a") - val c3 = 'c.int.at(2).as("c") + val c1 = 'a.int.at(0) + val c3 = 'c.int.at(2) checkEvaluation(CreateStruct(Seq(c1, c3)), create_row(1, 3), row) } + test("CreateNamedStruct") { + val row = InternalRow(1, 2, 3) + val c1 = 'a.int.at(0) + val c3 = 'c.int.at(2) + 
checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", c3)), InternalRow(1, 3), row) + } + + test("CreateNamedStruct with literal field") { + val row = InternalRow(1, 2, 3) + val c1 = 'a.int.at(0) + checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", "y")), InternalRow(1, "y"), row) + } + + test("CreateNamedStruct from all literal fields") { + checkEvaluation( + CreateNamedStruct(Seq("a", "x", "b", 2.0)), InternalRow("x", 2.0), InternalRow.empty) + } + test("test dsl for complex type") { def quickResolve(u: UnresolvedExtractValue): Expression = { ExtractValue(u.child, u.extraction, _ == _) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index e6f623bdf39e..5e4c5c6cdfa1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -762,6 +762,21 @@ object functions { struct((colName +: colNames).map(col) : _*) } + /** + * Creates a new struct column with given field names and columns. + * The input columns should be of length 2*n and follow (name1, col1, name2, col2), + * name* should be String Literal + * + * @group normal_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def named_struct(cols: Column*): Column = { + require(cols.length % 2 == 0, + s"named_struct expects an even number of arguments.") + CreateNamedStruct(cols.map(_.expr)) + } + /** * Converts a string expression to upper case. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 7ae89bcb1b9c..5daca2d003eb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -85,6 +85,64 @@ class DataFrameFunctionsSuite extends QueryTest { } } + test("named_struct with column expression") { + val df = Seq((1, "str1"), (2, "str2")).toDF("a", "b") + val row = df.select( + named_struct(lit("x"), (col("a") * 2), lit("y"), col("b"))).take(2) + + val expectedType = StructType(Seq( + StructField("x", IntegerType, nullable = false), + StructField("y", StringType) + )) + + assert(row(0).schema(0).dataType === expectedType) + assert(row(0).getAs[Row](0) === Row(2, "str1")) + assert(row(1).getAs[Row](0) === Row(4, "str2")) + } + + test("named_struct with literal columns") { + val df = Seq((1, "str1"), (2, "str2")).toDF("a", "b") + val row = df.select( + named_struct(lit("x"), (col("a") * 2), lit("y"), lit(5.0))).take(2) + + val expectedType = StructType(Seq( + StructField("x", IntegerType, nullable = false), + StructField("y", DoubleType, nullable = false) + )) + + assert(row(0).schema(0).dataType === expectedType) + assert(row(0).getAs[Row](0) === Row(2, 5.0)) + assert(row(1).getAs[Row](0) === Row(4, 5.0)) + } + + test("named_struct with all literal columns") { + val df = Seq((1, "str1"), (2, "str2")).toDF("a", "b") + val row = df.select( + named_struct(lit("x"), lit("v"), lit("y"), lit(5.0))).take(2) + + val expectedType = StructType(Seq( + StructField("x", StringType, nullable = false), + StructField("y", DoubleType, nullable = false) + )) + + assert(row(0).schema(0).dataType === expectedType) + assert(row(0).getAs[Row](0) === Row("v", 5.0)) + assert(row(1).getAs[Row](0) === Row("v", 5.0)) + } + + test("named_struct with odd arguments") { + intercept[IllegalArgumentException] { 
+ named_struct(col("x")) + } + } + + test("named_struct with non string literal names") { + val df = Seq((1, "str")).toDF("a", "b") + intercept[IllegalArgumentException] { + df.select(named_struct(lit(1), (col("a") * 2), lit("y"), lit(5.0))) + } + } + test("constant functions") { checkAnswer( ctx.sql("SELECT E()"), From 4bd75adbe5720aa6a980bf660b0ce1a8fee3965f Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Thu, 18 Jun 2015 17:13:00 +0800 Subject: [PATCH 02/16] loosen struct method in functions.scala to take Expression children --- .../org/apache/spark/sql/functions.scala | 4 +- .../spark/sql/DataFrameFunctionsSuite.scala | 45 +++++++++++++++++-- 2 files changed, 42 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 5e4c5c6cdfa1..41a2c1e2609e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -747,9 +747,7 @@ object functions { */ @scala.annotation.varargs def struct(cols: Column*): Column = { - require(cols.forall(_.expr.isInstanceOf[NamedExpression]), - s"struct input columns must all be named or aliased ($cols)") - CreateStruct(cols.map(_.expr.asInstanceOf[NamedExpression])) + CreateStruct(cols.map(_.expr)) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 5daca2d003eb..b3a5dd44d9c6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -79,12 +79,49 @@ class DataFrameFunctionsSuite extends QueryTest { assert(row.getAs[Row](0) === Row(2, "str")) } - test("struct: must use named column expression") { - intercept[IllegalArgumentException] { - struct(col("a") * 2) - } + test("struct with column expression to be 
automatically named") { + val df = Seq((1, "str")).toDF("a", "b") + val row = df.select(struct((col("a") * 2), col("b"))).first() + + val expectedType = StructType(Seq( + StructField("col1", IntegerType, nullable = false), + StructField("b", StringType) + )) + assert(row.schema(0).dataType === expectedType) + assert(row.getAs[Row](0) === Row(2, "str")) + } + + test("struct with literal columns") { + val df = Seq((1, "str1"), (2, "str2")).toDF("a", "b") + val row = df.select( + struct((col("a") * 2), lit(5.0))).take(2) + + val expectedType = StructType(Seq( + StructField("col1", IntegerType, nullable = false), + StructField("col2", DoubleType, nullable = false) + )) + + assert(row(0).schema(0).dataType === expectedType) + assert(row(0).getAs[Row](0) === Row(2, 5.0)) + assert(row(1).getAs[Row](0) === Row(4, 5.0)) + } + + test("struct with all literal columns") { + val df = Seq((1, "str1"), (2, "str2")).toDF("a", "b") + val row = df.select( + struct(lit("v"), lit(5.0))).take(2) + + val expectedType = StructType(Seq( + StructField("col1", StringType, nullable = false), + StructField("col2", DoubleType, nullable = false) + )) + + assert(row(0).schema(0).dataType === expectedType) + assert(row(0).getAs[Row](0) === Row("v", 5.0)) + assert(row(1).getAs[Row](0) === Row("v", 5.0)) } + test("named_struct with column expression") { val df = Seq((1, "str1"), (2, "str2")).toDF("a", "b") val row = df.select( From 917e680b48cadd39674f12806746a60121cd6961 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Fri, 19 Jun 2015 12:34:33 +0800 Subject: [PATCH 03/16] Fix reviews --- .../expressions/complexTypeCreator.scala | 2 -- .../org/apache/spark/sql/functions.scala | 4 ++-- .../spark/sql/DataFrameFunctionsSuite.scala | 20 +++++++++---------- 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index ce59a6fdeb9d..5bf19da059d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -100,8 +100,6 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) - override lazy val resolved: Boolean = childrenResolved - override lazy val dataType: StructType = { assert(resolved, s"CreateStruct contains unresolvable children: ${children.filterNot(_.resolved)}.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 41a2c1e2609e..3495d23950b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -769,9 +769,9 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def named_struct(cols: Column*): Column = { + def namedStruct(cols: Column*): Column = { require(cols.length % 2 == 0, - s"named_struct expects an even number of arguments.") + s"namedStruct expects an even number of arguments.") CreateNamedStruct(cols.map(_.expr)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index b3a5dd44d9c6..2607c87a0472 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -122,10 +122,10 @@ class DataFrameFunctionsSuite extends QueryTest { } - test("named_struct with column expression") { + test("namedStruct with column expression") { val df = Seq((1, "str1"), (2, "str2")).toDF("a", "b") val row = df.select( - named_struct(lit("x"), 
(col("a") * 2), lit("y"), col("b"))).take(2) + namedStruct(lit("x"), (col("a") * 2), lit("y"), col("b"))).take(2) val expectedType = StructType(Seq( StructField("x", IntegerType, nullable = false), @@ -137,10 +137,10 @@ class DataFrameFunctionsSuite extends QueryTest { assert(row(1).getAs[Row](0) === Row(4, "str2")) } - test("named_struct with literal columns") { + test("namedStruct with literal columns") { val df = Seq((1, "str1"), (2, "str2")).toDF("a", "b") val row = df.select( - named_struct(lit("x"), (col("a") * 2), lit("y"), lit(5.0))).take(2) + namedStruct(lit("x"), (col("a") * 2), lit("y"), lit(5.0))).take(2) val expectedType = StructType(Seq( StructField("x", IntegerType, nullable = false), @@ -152,10 +152,10 @@ class DataFrameFunctionsSuite extends QueryTest { assert(row(1).getAs[Row](0) === Row(4, 5.0)) } - test("named_struct with all literal columns") { + test("namedStruct with all literal columns") { val df = Seq((1, "str1"), (2, "str2")).toDF("a", "b") val row = df.select( - named_struct(lit("x"), lit("v"), lit("y"), lit(5.0))).take(2) + namedStruct(lit("x"), lit("v"), lit("y"), lit(5.0))).take(2) val expectedType = StructType(Seq( StructField("x", StringType, nullable = false), @@ -167,16 +167,16 @@ class DataFrameFunctionsSuite extends QueryTest { assert(row(1).getAs[Row](0) === Row("v", 5.0)) } - test("named_struct with odd arguments") { + test("namedStruct with odd arguments") { intercept[IllegalArgumentException] { - named_struct(col("x")) + namedStruct(col("x")) } } - test("named_struct with non string literal names") { + test("namedStruct with non string literal names") { val df = Seq((1, "str")).toDF("a", "b") intercept[IllegalArgumentException] { - df.select(named_struct(lit(1), (col("a") * 2), lit("y"), lit(5.0))) + df.select(namedStruct(lit(1), (col("a") * 2), lit("y"), lit(5.0))) } } From 47da3323f2b4e87784e587b600efc7507e12d3d4 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Fri, 19 Jun 2015 13:56:14 +0800 Subject: [PATCH 04/16] remove 
nameStruct API from DataFrame --- .../org/apache/spark/sql/functions.scala | 15 ----- .../spark/sql/DataFrameFunctionsSuite.scala | 59 ------------------- 2 files changed, 74 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 3495d23950b7..11e960d241cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -760,21 +760,6 @@ object functions { struct((colName +: colNames).map(col) : _*) } - /** - * Creates a new struct column with given field names and columns. - * The input columns should be of length 2*n and follow (name1, col1, name2, col2), - * name* should be String Literal - * - * @group normal_funcs - * @since 1.5.0 - */ - @scala.annotation.varargs - def namedStruct(cols: Column*): Column = { - require(cols.length % 2 == 0, - s"namedStruct expects an even number of arguments.") - CreateNamedStruct(cols.map(_.expr)) - } - /** * Converts a string expression to upper case. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 2607c87a0472..116b7af09503 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -121,65 +121,6 @@ class DataFrameFunctionsSuite extends QueryTest { assert(row(1).getAs[Row](0) === Row("v", 5.0)) } - - test("namedStruct with column expression") { - val df = Seq((1, "str1"), (2, "str2")).toDF("a", "b") - val row = df.select( - namedStruct(lit("x"), (col("a") * 2), lit("y"), col("b"))).take(2) - - val expectedType = StructType(Seq( - StructField("x", IntegerType, nullable = false), - StructField("y", StringType) - )) - - assert(row(0).schema(0).dataType === expectedType) - assert(row(0).getAs[Row](0) === Row(2, "str1")) - assert(row(1).getAs[Row](0) === Row(4, "str2")) - } - - test("namedStruct with literal columns") { - val df = Seq((1, "str1"), (2, "str2")).toDF("a", "b") - val row = df.select( - namedStruct(lit("x"), (col("a") * 2), lit("y"), lit(5.0))).take(2) - - val expectedType = StructType(Seq( - StructField("x", IntegerType, nullable = false), - StructField("y", DoubleType, nullable = false) - )) - - assert(row(0).schema(0).dataType === expectedType) - assert(row(0).getAs[Row](0) === Row(2, 5.0)) - assert(row(1).getAs[Row](0) === Row(4, 5.0)) - } - - test("namedStruct with all literal columns") { - val df = Seq((1, "str1"), (2, "str2")).toDF("a", "b") - val row = df.select( - namedStruct(lit("x"), lit("v"), lit("y"), lit(5.0))).take(2) - - val expectedType = StructType(Seq( - StructField("x", StringType, nullable = false), - StructField("y", DoubleType, nullable = false) - )) - - assert(row(0).schema(0).dataType === expectedType) - assert(row(0).getAs[Row](0) === Row("v", 5.0)) - assert(row(1).getAs[Row](0) === Row("v", 5.0)) - } - - test("namedStruct with odd arguments") 
{ - intercept[IllegalArgumentException] { - namedStruct(col("x")) - } - } - - test("namedStruct with non string literal names") { - val df = Seq((1, "str")).toDF("a", "b") - intercept[IllegalArgumentException] { - df.select(namedStruct(lit(1), (col("a") * 2), lit("y"), lit(5.0))) - } - } - test("constant functions") { checkAnswer( ctx.sql("SELECT E()"), From ccbbd86ec350917bc990b99c22636ca247fcad0e Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Fri, 19 Jun 2015 17:47:55 +0800 Subject: [PATCH 05/16] Fix reviews --- .../expressions/complexTypeCreator.scala | 56 +++++++++++++------ .../expressions/ComplexTypeSuite.scala | 18 ++++++ 2 files changed, 56 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 5bf19da059d8..be54265ca304 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String /** * Returns an Array containing the evaluation of all children expressions. 
@@ -56,6 +56,8 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) + override lazy val resolved: Boolean = childrenResolved + override lazy val dataType: StructType = { val fields = children.zipWithIndex.map { case (child, idx) => child match { @@ -83,35 +85,53 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { * @param children Seq(name1, val1, name2, val2, ...) */ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { - assert(children.size % 2 == 0, "NamedStruct expects an even number of arguments.") - - private val nameExprs = children.zipWithIndex.filter(_._2 % 2 == 0).map(_._1) - private val valExprs = children.zipWithIndex.filter(_._2 % 2 == 1).map(_._1) - - private lazy val names = nameExprs.map { case name => - name match { - case NonNullLiteral(str, StringType) => - str.asInstanceOf[UTF8String].toString - case _ => - throw new IllegalArgumentException("Expressions of odd index should be" + - s" Literal(_, StringType), get ${name.dataType} instead") - } - } - override def foldable: Boolean = children.forall(_.foldable) + private lazy val (nameExprs, valExprs) = + children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip + + private lazy val names = nameExprs.map(_.asInstanceOf[Literal].value.toString) override lazy val dataType: StructType = { - assert(resolved, - s"CreateStruct contains unresolvable children: ${children.filterNot(_.resolved)}.") + require(resolved, resolveFailureMessage) val fields = names.zip(valExprs).map { case (name, valExpr) => StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty) } StructType(fields) } + override def foldable: Boolean = children.forall(_.foldable) + override def nullable: Boolean = false + override def checkInputDataTypes(): TypeCheckResult = { + if (children.size % 2 != 0) { + TypeCheckResult.TypeCheckFailure("CreateNamedStruct expects an even number 
of arguments.") + } else { + val invalidNames = nameExprs.filterNot { case name => + name match { + case NonNullLiteral(str, StringType) => true + case _ => false + } + } + if (invalidNames.size != 0) { + TypeCheckResult.TypeCheckFailure( + s"Non String Literal fields at odd position $invalidNames") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + } + override def eval(input: InternalRow): Any = { + require(resolved, resolveFailureMessage) InternalRow(valExprs.map(_.eval(input)): _*) } + + private def resolveFailureMessage(): String = { + if (!childrenResolved) { + s"CreateNamedStruct contains unresolvable children: ${children.filterNot(_.resolved)}." + } else { + checkInputDataTypes().asInstanceOf[TypeCheckFailure].message + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index eb357fce9483..aa65be5a3953 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import org.scalatest.exceptions.TestFailedException + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -100,6 +102,22 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { assert(getStructField(nullStruct, "a").nullable === true) } + test("CreateNamedStruct with odd number of parameters") { + val thrown = intercept[TestFailedException] { + checkEvaluation( + CreateNamedStruct(Seq("a", "b", 2.0)), InternalRow("x", 2.0), InternalRow.empty) + } + assert(thrown.getCause.getMessage.contains("even number of arguments")) + } + + test("CreateNamedStruct with non String Literal name") { + val thrown = 
intercept[TestFailedException] { + checkEvaluation( + CreateNamedStruct(Seq(1, "a", "b", 2.0)), InternalRow("x", 2.0), InternalRow.empty) + } + assert(thrown.getCause.getMessage.contains("Non String Literal fields")) + } + test("complex type") { val row = create_row( "^Ba*n", // 0 From 7a7125570cd4e8fe4a486bd4921e6be48993fbb0 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Fri, 19 Jun 2015 17:51:26 +0800 Subject: [PATCH 06/16] tiny fix --- .../spark/sql/catalyst/expressions/complexTypeCreator.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index be54265ca304..cd4bf0b8fe2d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -115,7 +115,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { } if (invalidNames.size != 0) { TypeCheckResult.TypeCheckFailure( - s"Non String Literal fields at odd position $invalidNames") + s"Non String Literal fields at odd position : ${invalidNames.mkString(",")}") } else { TypeCheckResult.TypeCheckSuccess } From fd3cd8ed9afeb99a3e946ae5f82566530e06d99b Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Sun, 21 Jun 2015 00:37:29 +0800 Subject: [PATCH 07/16] remove type check from eval --- .../expressions/complexTypeCreator.scala | 3 +-- .../analysis/ExpressionTypeCheckingSuite.scala | 7 +++++++ .../catalyst/expressions/ComplexTypeSuite.scala | 16 ---------------- 3 files changed, 8 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index cd4bf0b8fe2d..8134b71b1050 
100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -92,7 +92,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { private lazy val names = nameExprs.map(_.asInstanceOf[Literal].value.toString) override lazy val dataType: StructType = { - require(resolved, resolveFailureMessage) + assert(resolved, resolveFailureMessage) val fields = names.zip(valExprs).map { case (name, valExpr) => StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty) } @@ -123,7 +123,6 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { } override def eval(input: InternalRow): Any = { - require(resolved, resolveFailureMessage) InternalRow(valExprs.map(_.eval(input)): _*) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index bc1537b0715b..bacbefb3a2a4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -160,4 +160,11 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertError(Explode('intField), "input to function explode should be array or map type") } + + test("check types for CreateNamedStruct") { + assertError( + CreateNamedStruct(Seq("a", "b", 2.0)), "even number of arguments") + assertError( + CreateNamedStruct(Seq(1, "a", "b", 2.0)), "Non String Literal fields") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index aa65be5a3953..57b8d6b569a4 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -102,22 +102,6 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { assert(getStructField(nullStruct, "a").nullable === true) } - test("CreateNamedStruct with odd number of parameters") { - val thrown = intercept[TestFailedException] { - checkEvaluation( - CreateNamedStruct(Seq("a", "b", 2.0)), InternalRow("x", 2.0), InternalRow.empty) - } - assert(thrown.getCause.getMessage.contains("even number of arguments")) - } - - test("CreateNamedStruct with non String Literal name") { - val thrown = intercept[TestFailedException] { - checkEvaluation( - CreateNamedStruct(Seq(1, "a", "b", 2.0)), InternalRow("x", 2.0), InternalRow.empty) - } - assert(thrown.getCause.getMessage.contains("Non String Literal fields")) - } - test("complex type") { val row = create_row( "^Ba*n", // 0 From 828d69491b7f08b0b797eaae2e23169c5f3be9fc Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Sun, 21 Jun 2015 00:56:46 +0800 Subject: [PATCH 08/16] remove unnecessary resolved assertion inside dataType method --- .../sql/catalyst/expressions/complexTypeCreator.scala | 9 --------- 1 file changed, 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 8134b71b1050..d44299957d88 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -92,7 +92,6 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { private lazy val names = nameExprs.map(_.asInstanceOf[Literal].value.toString) override lazy val dataType: StructType = { - assert(resolved, 
resolveFailureMessage) val fields = names.zip(valExprs).map { case (name, valExpr) => StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty) } @@ -125,12 +124,4 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { override def eval(input: InternalRow): Any = { InternalRow(valExprs.map(_.eval(input)): _*) } - - private def resolveFailureMessage(): String = { - if (!childrenResolved) { - s"CreateNamedStruct contains unresolvable children: ${children.filterNot(_.resolved)}." - } else { - checkInputDataTypes().asInstanceOf[TypeCheckFailure].message - } - } } From 60812a7719f3f4067f383617762017e7c82e02ab Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Wed, 24 Jun 2015 10:49:41 +0800 Subject: [PATCH 09/16] Fix type check --- .../sql/catalyst/expressions/complexTypeCreator.scala | 11 ++++++----- .../analysis/ExpressionTypeCheckingSuite.scala | 2 ++ ...for generic udf-0-cc120a2331158f570a073599985d3f55 | 2 +- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index d44299957d88..f87ad9eb9dbf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Returns an Array containing the evaluation of all children expressions. 
@@ -89,7 +90,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { private lazy val (nameExprs, valExprs) = children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip - private lazy val names = nameExprs.map(_.asInstanceOf[Literal].value.toString) + private lazy val names = nameExprs.map(_.eval(EmptyRow).asInstanceOf[UTF8String].toString) override lazy val dataType: StructType = { val fields = names.zip(valExprs).map { case (name, valExpr) => @@ -106,11 +107,11 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { if (children.size % 2 != 0) { TypeCheckResult.TypeCheckFailure("CreateNamedStruct expects an even number of arguments.") } else { - val invalidNames = nameExprs.filterNot { case name => - name match { - case NonNullLiteral(str, StringType) => true + val invalidNames = nameExprs.filter { case name => + (name.find { + case _: Attribute | _: BoundReference => true; case _ => false - } + } != None) || (name.dataType != StringType) } if (invalidNames.size != 0) { TypeCheckResult.TypeCheckFailure( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index bacbefb3a2a4..17acd96bd35b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -166,5 +166,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { CreateNamedStruct(Seq("a", "b", 2.0)), "even number of arguments") assertError( CreateNamedStruct(Seq(1, "a", "b", 2.0)), "Non String Literal fields") + assertError( + CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)), "Non String Literal fields") } } diff --git a/sql/hive/src/test/resources/golden/constant object inspector for generic 
udf-0-cc120a2331158f570a073599985d3f55 b/sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-cc120a2331158f570a073599985d3f55 index 7bc77e7f2a4d..6903b102b332 100644 --- a/sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-cc120a2331158f570a073599985d3f55 +++ b/sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-cc120a2331158f570a073599985d3f55 @@ -1 +1 @@ -{"aa":"10","aaaaaa":"11","aaaaaa":"12","bb12":"13","s14s14":"14"} +{"aa":"10","aaaaaa":"11","aaaaaa":"12","Bb12":"13","s14s14":"14"} From 7fef712b0181e85745f807d21f527a97271a9629 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Wed, 24 Jun 2015 14:47:34 +0800 Subject: [PATCH 10/16] Fix checkInputTypes' implementation using foldable and nullable --- .../sql/catalyst/expressions/complexTypeCreator.scala | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index f87ad9eb9dbf..67e649a8a62e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -107,12 +107,8 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { if (children.size % 2 != 0) { TypeCheckResult.TypeCheckFailure("CreateNamedStruct expects an even number of arguments.") } else { - val invalidNames = nameExprs.filter { case name => - (name.find { - case _: Attribute | _: BoundReference => true; - case _ => false - } != None) || (name.dataType != StringType) - } + val invalidNames = + nameExprs.filterNot(e => e.foldable && e.dataType == StringType && !nullable) if (invalidNames.size != 0) { TypeCheckResult.TypeCheckFailure( s"Non String Literal fields at odd position : 
${invalidNames.mkString(",")}") From 9613be91b42f70688fc662884ca5bb5d905391a7 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Wed, 24 Jun 2015 19:17:11 +0800 Subject: [PATCH 11/16] review fix --- .../spark/sql/catalyst/expressions/complexTypeCreator.scala | 5 +++-- .../sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala | 6 ++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 67e649a8a62e..6aaafba005b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -99,7 +99,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { StructType(fields) } - override def foldable: Boolean = children.forall(_.foldable) + override def foldable: Boolean = valExprs.forall(_.foldable) override def nullable: Boolean = false @@ -111,7 +111,8 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { nameExprs.filterNot(e => e.foldable && e.dataType == StringType && !nullable) if (invalidNames.size != 0) { TypeCheckResult.TypeCheckFailure( - s"Non String Literal fields at odd position : ${invalidNames.mkString(",")}") + s"Odd position only allow foldable and not-null StringType expressions, got :" + + s" ${invalidNames.mkString(",")}") } else { TypeCheckResult.TypeCheckSuccess } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 17acd96bd35b..8e0551b23eea 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -165,8 +165,10 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertError( CreateNamedStruct(Seq("a", "b", 2.0)), "even number of arguments") assertError( - CreateNamedStruct(Seq(1, "a", "b", 2.0)), "Non String Literal fields") + CreateNamedStruct(Seq(1, "a", "b", 2.0)), + "Odd position only allow foldable and not-null StringType expressions") assertError( - CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)), "Non String Literal fields") + CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)), + "Odd position only allow foldable and not-null StringType expressions") } } From f07e114575563e64b85ab80adaa39fc0979028f4 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Wed, 24 Jun 2015 19:23:15 +0800 Subject: [PATCH 12/16] tiny fix --- .../spark/sql/catalyst/expressions/complexTypeCreator.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 6aaafba005b7..fa70409353e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -90,7 +90,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { private lazy val (nameExprs, valExprs) = children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip - private lazy val names = nameExprs.map(_.eval(EmptyRow).asInstanceOf[UTF8String].toString) + private lazy val names = nameExprs.map(_.eval(EmptyRow).toString) override lazy val dataType: StructType = { val fields = names.zip(valExprs).map { case (name, valExpr) => From b48735444d819624b9f5154aed3b7c75c2fe8b40 Mon Sep 17 00:00:00 2001 From: Yijie Shen 
Date: Thu, 25 Jun 2015 13:10:41 +0800 Subject: [PATCH 13/16] replace assert using checkAnswer --- .../spark/sql/DataFrameFunctionsSuite.scala | 22 ++++++++----------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 116b7af09503..0d43aca877f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -81,44 +81,40 @@ class DataFrameFunctionsSuite extends QueryTest { test("struct with column expression to be automatically named") { val df = Seq((1, "str")).toDF("a", "b") - val row = df.select(struct((col("a") * 2), col("b"))).first() + val result = df.select(struct((col("a") * 2), col("b"))) val expectedType = StructType(Seq( StructField("col1", IntegerType, nullable = false), StructField("b", StringType) )) - assert(row.schema(0).dataType === expectedType) - assert(row.getAs[Row](0) === Row(2, "str")) + assert(result.first.schema(0).dataType === expectedType) + checkAnswer(result, Row(Row(2, "str"))) } test("struct with literal columns") { val df = Seq((1, "str1"), (2, "str2")).toDF("a", "b") - val row = df.select( - struct((col("a") * 2), lit(5.0))).take(2) + val result = df.select(struct((col("a") * 2), lit(5.0))) val expectedType = StructType(Seq( StructField("col1", IntegerType, nullable = false), StructField("col2", DoubleType, nullable = false) )) - assert(row(0).schema(0).dataType === expectedType) - assert(row(0).getAs[Row](0) === Row(2, 5.0)) - assert(row(1).getAs[Row](0) === Row(4, 5.0)) + assert(result.first.schema(0).dataType === expectedType) + checkAnswer(result, Seq(Row(Row(2, 5.0)), Row(Row(4, 5.0)))) } test("struct with all literal columns") { val df = Seq((1, "str1"), (2, "str2")).toDF("a", "b") - val row = df.select( - struct(lit("v"), lit(5.0))).take(2) + val result 
= df.select(struct(lit("v"), lit(5.0))) val expectedType = StructType(Seq( StructField("col1", StringType, nullable = false), StructField("col2", DoubleType, nullable = false) )) - assert(row(0).schema(0).dataType === expectedType) - assert(row(0).getAs[Row](0) === Row("v", 5.0)) - assert(row(1).getAs[Row](0) === Row("v", 5.0)) + assert(result.first.schema(0).dataType === expectedType) + checkAnswer(result, Seq(Row(Row("v", 5.0)), Row(Row("v", 5.0)))) } test("constant functions") { From 9a7039e3253f688b8c921afb16de33037f42b942 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Fri, 26 Jun 2015 17:05:19 +0800 Subject: [PATCH 14/16] fix reviews and regenerate golden answers --- ...inspector for generic udf-0-638f81ad9077c7d0c5c735c6e73742ad | 1 + ...inspector for generic udf-0-cc120a2331158f570a073599985d3f55 | 1 - .../org/apache/spark/sql/hive/execution/HiveQuerySuite.scala | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) create mode 100644 sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-638f81ad9077c7d0c5c735c6e73742ad delete mode 100644 sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-cc120a2331158f570a073599985d3f55 diff --git a/sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-638f81ad9077c7d0c5c735c6e73742ad b/sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-638f81ad9077c7d0c5c735c6e73742ad new file mode 100644 index 000000000000..7bc77e7f2a4d --- /dev/null +++ b/sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-638f81ad9077c7d0c5c735c6e73742ad @@ -0,0 +1 @@ +{"aa":"10","aaaaaa":"11","aaaaaa":"12","bb12":"13","s14s14":"14"} diff --git a/sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-cc120a2331158f570a073599985d3f55 b/sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-cc120a2331158f570a073599985d3f55 deleted file mode 100644 index 
6903b102b332..000000000000 --- a/sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-cc120a2331158f570a073599985d3f55 +++ /dev/null @@ -1 +0,0 @@ -{"aa":"10","aaaaaa":"11","aaaaaa":"12","Bb12":"13","s14s14":"14"} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 4cdba03b2702..991da2f829ae 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -132,7 +132,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { lower("AA"), "10", repeat(lower("AA"), 3), "11", lower(repeat("AA", 3)), "12", - printf("Bb%d", 12), "13", + printf("bb%d", 12), "13", repeat(printf("s%d", 14), 2), "14") FROM src LIMIT 1""") createQueryTest("NaN to Decimal", From d599d0b94c03a72583b3425f0676269e11bc3188 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Thu, 2 Jul 2015 21:39:03 +0800 Subject: [PATCH 15/16] rebase code --- .../expressions/ComplexTypeSuite.scala | 32 ------------------- 1 file changed, 32 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 57b8d6b569a4..a09014e1ffc1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -102,38 +102,6 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { assert(getStructField(nullStruct, "a").nullable === true) } - test("complex type") { - val row = create_row( - "^Ba*n", // 0 - null.asInstanceOf[UTF8String], // 1 - create_row("aa", "bb"), // 2 - Map("aa" -> "bb"), // 3 - Seq("aa", "bb") // 4 - ) - - val 
typeS = StructType( - StructField("a", StringType, true) :: StructField("b", StringType, true) :: Nil - ) - val typeMap = MapType(StringType, StringType) - val typeArray = ArrayType(StringType) - - checkEvaluation(GetMapValue(BoundReference(3, typeMap, true), - Literal("aa")), "bb", row) - checkEvaluation(GetMapValue(Literal.create(null, typeMap), Literal("aa")), null, row) - checkEvaluation( - GetMapValue(Literal.create(null, typeMap), Literal.create(null, StringType)), null, row) - checkEvaluation(GetMapValue(BoundReference(3, typeMap, true), - Literal.create(null, StringType)), null, row) - - checkEvaluation(GetArrayItem(BoundReference(4, typeArray, true), - Literal(1)), "bb", row) - checkEvaluation(GetArrayItem(Literal.create(null, typeArray), Literal(1)), null, row) - checkEvaluation( - GetArrayItem(Literal.create(null, typeArray), Literal.create(null, IntegerType)), null, row) - checkEvaluation(GetArrayItem(BoundReference(4, typeArray, true), - Literal.create(null, IntegerType)), null, row) - } - test("GetArrayStructFields") { val typeAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val arrayStruct = Literal.create(Seq(create_row(1)), typeAS) From 4cd3375ac6da46ae52b21a9b63340cc88b09a3f2 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Thu, 2 Jul 2015 22:14:07 +0800 Subject: [PATCH 16/16] change struct documentation --- python/pyspark/sql/functions.py | 1 - .../src/main/scala/org/apache/spark/sql/functions.scala | 7 +++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f9a15d4a6630..074e25a78a13 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -443,7 +443,6 @@ def struct(*cols): """Creates a new struct column. :param cols: list of column names (string) or list of :class:`Column` expressions - that are named or aliased. 
>>> df.select(struct('age', 'name').alias("struct")).collect() [Row(struct=Row(age=2, name=u'Alice')), Row(struct=Row(age=5, name=u'Bob'))] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 11e960d241cc..ef0205170b83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -739,8 +739,11 @@ object functions { def sqrt(colName: String): Column = sqrt(Column(colName)) /** - * Creates a new struct column. The input column must be a column in a [[DataFrame]], or - * a derived column expression that is named (i.e. aliased). + * Creates a new struct column. + * If the input column is a column in a [[DataFrame]], or a derived column expression + * that is named (i.e. aliased), its name is kept as the StructField's name; + * otherwise, the newly generated StructField's name is auto-generated as col${index + 1}, + * i.e. col1, col2, col3, ... * * @group normal_funcs * @since 1.4.0