diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
index 72fdcebb4cbc..09e2f01fb24b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
@@ -62,14 +62,11 @@ case class CreateStruct(children: Seq[Expression]) extends Expression {
   override lazy val dataType: StructType = {
     assert(resolved,
       s"CreateStruct contains unresolvable children: ${children.filterNot(_.resolved)}.")
-    val fields = children.zipWithIndex.map { case (child, idx) =>
-      child match {
-        case ne: NamedExpression =>
-          StructField(ne.name, ne.dataType, ne.nullable, ne.metadata)
-        case _ =>
-          StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty)
-      }
-    }
+
+    // Like Hive's GenericUDFStruct, we give the fields default names (col1, col2, ...).
+    val fields = children.zipWithIndex.map { case (child, idx) =>
+      StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty)
+    }
     StructType(fields)
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index dff0932c450a..a0a38e17ec67 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -714,16 +714,14 @@ object functions {
 
   /**
    * Creates a new struct column. The input column must be a column in a [[DataFrame]], or
-   * a derived column expression that is named (i.e. aliased).
+   * a derived column expression.
    *
    * @group normal_funcs
    * @since 1.4.0
    */
   @scala.annotation.varargs
   def struct(cols: Column*): Column = {
-    require(cols.forall(_.expr.isInstanceOf[NamedExpression]),
-      s"struct input columns must all be named or aliased ($cols)")
-    CreateStruct(cols.map(_.expr.asInstanceOf[NamedExpression]))
+    CreateStruct(cols.map(_.expr))
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index cfd23867a9bb..aad07dfe91e7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -60,8 +60,8 @@ class DataFrameFunctionsSuite extends QueryTest {
     val row = df.select(struct("a", "b")).first()
 
     val expectedType = StructType(Seq(
-      StructField("a", IntegerType, nullable = false),
-      StructField("b", StringType)
+      StructField("col1", IntegerType, nullable = false),
+      StructField("col2", StringType)
     ))
     assert(row.schema(0).dataType === expectedType)
     assert(row.getAs[Row](0) === Row(1, "str"))
@@ -72,19 +72,13 @@
     val row = df.select(struct((col("a") * 2).as("c"), col("b"))).first()
 
     val expectedType = StructType(Seq(
-      StructField("c", IntegerType, nullable = false),
-      StructField("b", StringType)
+      StructField("col1", IntegerType, nullable = false),
+      StructField("col2", StringType)
    ))
     assert(row.schema(0).dataType === expectedType)
     assert(row.getAs[Row](0) === Row(2, "str"))
   }
 
-  test("struct: must use named column expression") {
-    intercept[IllegalArgumentException] {
-      struct(col("a") * 2)
-    }
-  }
-
   test("constant functions") {
     checkAnswer(
       testData2.select(e()).limit(1),
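
Illustration (not part of the patch): a minimal spark-shell-style sketch of the behavior this change enables. The input data and the `sqlContext` handle are hypothetical; the field naming mirrors the updated DataFrameFunctionsSuite tests above.

    import org.apache.spark.sql.functions._

    val df = sqlContext.createDataFrame(Seq((1, "str"))).toDF("a", "b")

    // Previously, an unaliased derived expression was rejected up front:
    //   struct(col("a") * 2)  // threw IllegalArgumentException
    // Now any column expression is accepted, and the resulting struct's
    // fields get Hive-style default names regardless of any aliases.
    val row = df.select(struct(col("a") * 2, col("b"))).first()

    row.schema(0).dataType
    // StructType(StructField(col1,IntegerType,false), StructField(col2,StringType,true))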