Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,6 @@ def struct(*cols):
"""Creates a new struct column.

:param cols: list of column names (string) or list of :class:`Column` expressions
that are named or aliased.

>>> df.select(struct('age', 'name').alias("struct")).collect()
[Row(struct=Row(age=2, name=u'Alice')), Row(struct=Row(age=5, name=u'Bob'))]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ object FunctionRegistry {
expression[Rand]("rand"),
expression[Randn]("randn"),
expression[CreateStruct]("struct"),
expression[CreateNamedStruct]("named_struct"),
expression[Sqrt]("sqrt"),

// math functions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

/**
* Returns an Array containing the evaluation of all children expressions.
Expand Down Expand Up @@ -54,6 +57,8 @@ case class CreateStruct(children: Seq[Expression]) extends Expression {

override def foldable: Boolean = children.forall(_.foldable)

override lazy val resolved: Boolean = childrenResolved

override lazy val dataType: StructType = {
val fields = children.zipWithIndex.map { case (child, idx) =>
child match {
Expand All @@ -74,3 +79,47 @@ case class CreateStruct(children: Seq[Expression]) extends Expression {

override def prettyName: String = "struct"
}

/**
 * Creates a struct with the given field names and values.
 *
 * Children alternate between a field name and a field value. Every name must be a foldable,
 * non-nullable [[StringType]] expression: names are evaluated eagerly against an empty row in
 * order to build the resulting [[StructType]].
 *
 * @param children Seq(name1, val1, name2, val2, ...)
 */
case class CreateNamedStruct(children: Seq[Expression]) extends Expression {

  // Split the flat child list into parallel name / value lists. This is lazy, and the pattern
  // only matches complete pairs, so it must not be forced before checkInputDataTypes() has
  // confirmed that the number of children is even.
  private lazy val (nameExprs, valExprs) =
    children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip

  // Field names obtained by evaluating each name expression without any input row.
  private lazy val names = nameExprs.map(_.eval(EmptyRow).toString)

  override lazy val dataType: StructType = {
    val fields = names.zip(valExprs).map { case (name, valExpr) =>
      StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty)
    }
    StructType(fields)
  }

  override def foldable: Boolean = valExprs.forall(_.foldable)

  // The struct itself is never null; individual fields carry their own nullability.
  override def nullable: Boolean = false

  /**
   * Succeeds only when there is an even number of children and every name expression
   * (odd position, 1-based) is foldable, of StringType and non-nullable.
   */
  override def checkInputDataTypes(): TypeCheckResult = {
    if (children.size % 2 != 0) {
      TypeCheckResult.TypeCheckFailure("CreateNamedStruct expects an even number of arguments.")
    } else {
      // Check the nullability of each *name expression* (`e.nullable`). A bare `nullable`
      // would refer to this struct's own flag, which is always false, and would make the
      // not-null requirement a no-op.
      val invalidNames =
        nameExprs.filterNot(e => e.foldable && e.dataType == StringType && !e.nullable)
      if (invalidNames.nonEmpty) {
        TypeCheckResult.TypeCheckFailure(
          "Odd position only allow foldable and not-null StringType expressions, got :" +
            s" ${invalidNames.mkString(",")}")
      } else {
        TypeCheckResult.TypeCheckSuccess
      }
    }
  }

  // Only the value expressions are evaluated per row; names are part of the data type.
  override def eval(input: InternalRow): Any = {
    InternalRow(valExprs.map(_.eval(input)): _*)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -160,4 +160,15 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertError(Explode('intField),
"input to function explode should be array or map type")
}

test("check types for CreateNamedStruct") {
  // Shared expected fragment for every invalid field-name case.
  val oddPositionError = "Odd position only allow foldable and not-null StringType expressions"

  // An unpaired trailing argument is rejected.
  assertError(CreateNamedStruct(Seq("a", "b", 2.0)), "even number of arguments")
  // A non-string literal in a name position is rejected.
  assertError(CreateNamedStruct(Seq(1, "a", "b", 2.0)), oddPositionError)
  // A bound (row-dependent) reference in a name position is rejected.
  assertError(CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)), oddPositionError)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.expressions

import org.scalatest.exceptions.TestFailedException

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
import org.apache.spark.sql.catalyst.dsl.expressions._
Expand Down Expand Up @@ -119,11 +121,29 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {

test("CreateStruct") {
val row = create_row(1, 2, 3)
val c1 = 'a.int.at(0).as("a")
val c3 = 'c.int.at(2).as("c")
val c1 = 'a.int.at(0)
val c3 = 'c.int.at(2)
checkEvaluation(CreateStruct(Seq(c1, c3)), create_row(1, 3), row)
}

test("CreateNamedStruct") {
  // Fields "a" and "b" take their values from ordinals 0 and 2 of the input row.
  val inputRow = InternalRow(1, 2, 3)
  val firstCol = 'a.int.at(0)
  val thirdCol = 'c.int.at(2)
  val namedStruct = CreateNamedStruct(Seq("a", firstCol, "b", thirdCol))
  checkEvaluation(namedStruct, InternalRow(1, 3), inputRow)
}

test("CreateNamedStruct with literal field") {
  // Field "a" is read from the input row; field "b" is the literal "y".
  val inputRow = InternalRow(1, 2, 3)
  val boundCol = 'a.int.at(0)
  val namedStruct = CreateNamedStruct(Seq("a", boundCol, "b", "y"))
  checkEvaluation(namedStruct, InternalRow(1, "y"), inputRow)
}

test("CreateNamedStruct from all literal fields") {
  // Every value is a literal, so an empty input row is sufficient.
  val namedStruct = CreateNamedStruct(Seq("a", "x", "b", 2.0))
  checkEvaluation(namedStruct, InternalRow("x", 2.0), InternalRow.empty)
}

test("test dsl for complex type") {
def quickResolve(u: UnresolvedExtractValue): Expression = {
ExtractValue(u.child, u.extraction, _ == _)
Expand Down
11 changes: 6 additions & 5 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -739,17 +739,18 @@ object functions {
def sqrt(colName: String): Column = sqrt(Column(colName))

/**
* Creates a new struct column. The input column must be a column in a [[DataFrame]], or
* a derived column expression that is named (i.e. aliased).
* Creates a new struct column.
* If the input column is a column in a [[DataFrame]], or a derived column expression
* that is named (i.e. aliased), its name is kept as the StructField's name;
* otherwise, the StructField's name is auto-generated as col${index + 1},
* i.e. col1, col2, col3, ...
*
* @group normal_funcs
* @since 1.4.0
*/
@scala.annotation.varargs
def struct(cols: Column*): Column = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The documentation above needs to be updated and should specify what happens when the columns are unnamed.

require(cols.forall(_.expr.isInstanceOf[NamedExpression]),
s"struct input columns must all be named or aliased ($cols)")
CreateStruct(cols.map(_.expr.asInstanceOf[NamedExpression]))
CreateStruct(cols.map(_.expr))
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,42 @@ class DataFrameFunctionsSuite extends QueryTest {
assert(row.getAs[Row](0) === Row(2, "str"))
}

test("struct: must use named column expression") {
intercept[IllegalArgumentException] {
struct(col("a") * 2)
}
test("struct with column expression to be automatically named") {
  val input = Seq((1, "str")).toDF("a", "b")
  val structured = input.select(struct(col("a") * 2, col("b")))

  // The unnamed expression is expected under the generated name "col1"; column "b" keeps its name.
  val expectedFields = Seq(
    StructField("col1", IntegerType, nullable = false),
    StructField("b", StringType))

  assert(structured.first.schema(0).dataType === StructType(expectedFields))
  checkAnswer(structured, Row(Row(2, "str")))
}

test("struct with literal columns") {
  val input = Seq((1, "str1"), (2, "str2")).toDF("a", "b")
  val structured = input.select(struct(col("a") * 2, lit(5.0)))

  // Neither input is named, so both fields are expected under generated names.
  val expectedFields = Seq(
    StructField("col1", IntegerType, nullable = false),
    StructField("col2", DoubleType, nullable = false))

  assert(structured.first.schema(0).dataType === StructType(expectedFields))
  val expectedRows = Seq(Row(Row(2, 5.0)), Row(Row(4, 5.0)))
  checkAnswer(structured, expectedRows)
}

test("struct with all literal columns") {
  val input = Seq((1, "str1"), (2, "str2")).toDF("a", "b")
  val structured = input.select(struct(lit("v"), lit(5.0)))

  // Both fields come from literals and are expected under generated names.
  val expectedFields = Seq(
    StructField("col1", StringType, nullable = false),
    StructField("col2", DoubleType, nullable = false))

  assert(structured.first.schema(0).dataType === StructType(expectedFields))
  val expectedRows = Seq(Row(Row("v", 5.0)), Row(Row("v", 5.0)))
  checkAnswer(structured, expectedRows)
}

test("constant functions") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
lower("AA"), "10",
repeat(lower("AA"), 3), "11",
lower(repeat("AA", 3)), "12",
printf("Bb%d", 12), "13",
printf("bb%d", 12), "13",
repeat(printf("s%d", 14), 2), "14") FROM src LIMIT 1""")

createQueryTest("NaN to Decimal",
Expand Down