diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 5212ef3930bc9..33eefd5d9fbc6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -539,3 +539,61 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E
 
   override def prettyName: String = "str_to_map"
 }
+
+/**
+ * Adds/replaces a field in a struct by name.
+ */
+case class WithFields(
+    structExpr: Expression,
+    names: Seq[String],
+    valExprs: Seq[Expression]) extends Unevaluable {
+
+  assert(names.length == valExprs.length)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (!structExpr.dataType.isInstanceOf[StructType]) {
+      TypeCheckResult.TypeCheckFailure(
+        "struct argument should be struct type, got: " + structExpr.dataType.catalogString)
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override def children: Seq[Expression] = structExpr +: valExprs
+
+  override def dataType: StructType = evalExpr.dataType.asInstanceOf[StructType]
+
+  override def foldable: Boolean = structExpr.foldable && valExprs.forall(_.foldable)
+
+  override def nullable: Boolean = structExpr.nullable
+
+  override def prettyName: String = "with_fields"
+
+  lazy val evalExpr: Expression = {
+    val existingExprs = structExpr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map {
+      case (name, i) => (name, GetStructField(KnownNotNull(structExpr), i).asInstanceOf[Expression])
+    }
+
+    val addOrReplaceExprs = names.zip(valExprs)
+
+    val resolver = SQLConf.get.resolver
+    val newExprs = addOrReplaceExprs.foldLeft(existingExprs) {
+      case (resultExprs, newExpr @ (newExprName, _)) =>
+        if (resultExprs.exists(x => resolver(x._1, newExprName))) {
+          resultExprs.map {
+            case (name, _) if resolver(name, newExprName) => newExpr
+            case x => x
+          }
+        } else {
+          resultExprs :+ newExpr
+        }
+    }.flatMap { case (name, expr) => Seq(Literal(name), expr) }
+
+    val expr = CreateNamedStruct(newExprs)
+    if (structExpr.nullable) {
+      If(IsNull(structExpr), Literal(null, expr.dataType), expr)
+    } else {
+      expr
+    }
+  }
+}
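
A note on the lowering above: `evalExpr` rewrites the whole struct as a `CreateNamedStruct` over the existing fields plus the added/replaced ones, matching names through the analyzer's resolver so that `spark.sql.caseSensitive` is honored, and null-guarding the result when the input struct is nullable. Below is a minimal standalone sketch of just the add-or-replace fold, with plain `(name, value)` pairs standing in for Catalyst expressions; the object and helper names are illustrative, not part of the patch.

```scala
// Sketch of the foldLeft in WithFields.evalExpr, assuming the default
// case-insensitive resolver returned by SQLConf.get.resolver.
object WithFieldsFoldSketch extends App {
  val resolver: (String, String) => Boolean = _ equalsIgnoreCase _

  def addOrReplace(
      existing: Seq[(String, Int)],
      updates: Seq[(String, Int)]): Seq[(String, Int)] =
    updates.foldLeft(existing) { case (result, update @ (name, _)) =>
      if (result.exists { case (n, _) => resolver(n, name) }) {
        // Known name: overwrite every matching field (duplicates included).
        result.map {
          case (n, _) if resolver(n, name) => update
          case other => other
        }
      } else {
        // Unknown name: append as a new trailing field.
        result :+ update
      }
    }

  // struct(a = 1, b = 2) after withField("B", 3) and withField("c", 4):
  println(addOrReplace(Seq("a" -> 1, "b" -> 2), Seq("B" -> 3, "c" -> 4)))
  // List((a,1), (B,3), (c,4))
}
```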
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
index f79dabf758c14..1c33a2c7c3136 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
@@ -39,7 +39,18 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] {
       // Remove redundant field extraction.
       case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) =>
         createNamedStruct.valExprs(ordinal)
-
+      case GetStructField(w @ WithFields(struct, names, valExprs), ordinal, maybeName) =>
+        val name = w.dataType(ordinal).name
+        val matches = names.zip(valExprs).filter(_._1 == name)
+        if (matches.nonEmpty) {
+          // Return the last matching element, as that is the final value for the field being
+          // extracted. For example, if a user submits a query like
+          // `$"struct_col".withField("b", lit(1)).withField("b", lit(2)).getField("b")`,
+          // we want to return `lit(2)` (and not `lit(1)`).
+          matches.last._2
+        } else {
+          GetStructField(struct, ordinal, maybeName)
+        }
       // Remove redundant array indexing.
       case GetArrayStructFields(CreateArray(elems, useStringTypeWhenEmpty), field, ordinal, _, _) =>
         // Instead of selecting the field on the entire array, select it from each member
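
The effect of the new `GetStructField`-over-`WithFields` case is that a field extraction never materializes the intermediate struct. A quick way to see it, assuming this patch is applied and a SparkSession named `spark` (e.g. in spark-shell); the column name is illustrative:

```scala
import org.apache.spark.sql.functions._
import spark.implicits._

val df = spark.sql("SELECT named_struct('a', 1, 'b', 2) AS s")

// Extracting a field that was overwritten twice folds to the last value
// written (lit(2)); the optimized plan contains no struct construction.
df.select($"s".withField("b", lit(1)).withField("b", lit(2)).getField("b")).explain(true)

// Extracting a field the WithFields never touched drops the wrapper,
// leaving a plain GetStructField on the original column.
df.select($"s".withField("b", lit(1)).getField("a")).explain(true)
```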
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index f1a307b1c2cc1..36b49f5ee7866 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -106,6 +106,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
         EliminateSerialization,
         RemoveRedundantAliases,
         RemoveNoopOperators,
+        CombineWithFields,
         SimplifyExtractValueOps,
         CombineConcats) ++
         extendedOperatorOptimizationRules
@@ -202,7 +203,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
       CollapseProject,
       RemoveNoopOperators) :+
     // This batch must be executed after the `RewriteSubquery` batch, which creates joins.
-    Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers)
+    Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+
+    Batch("ReplaceWithFieldsExpression", Once, ReplaceWithFieldsExpression)
 
     // remove any batches with no rules. this may happen when subclasses do not add optional rules.
     batches.filter(_.rules.nonEmpty)
@@ -235,7 +237,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
       PullupCorrelatedPredicates.ruleName ::
       RewriteCorrelatedScalarSubquery.ruleName ::
       RewritePredicateSubquery.ruleName ::
-      NormalizeFloatingNumbers.ruleName :: Nil
+      NormalizeFloatingNumbers.ruleName ::
+      ReplaceWithFieldsExpression.ruleName :: Nil
 
   /**
    * Optimize all the subqueries inside expression.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala
new file mode 100644
index 0000000000000..05c90864e4bb0
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions.WithFields
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+
+
+/**
+ * Combines all adjacent [[WithFields]] expressions into a single [[WithFields]] expression.
+ */
+object CombineWithFields extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+    case WithFields(WithFields(struct, names1, valExprs1), names2, valExprs2) =>
+      WithFields(struct, names1 ++ names2, valExprs1 ++ valExprs2)
+  }
+}
+
+/**
+ * Replaces a [[WithFields]] expression with an evaluable expression.
+ */
+object ReplaceWithFieldsExpression extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+    case w: WithFields => w.evalExpr
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineWithFieldsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineWithFieldsSuite.scala
new file mode 100644
index 0000000000000..a3e0bbc57e639
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineWithFieldsSuite.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, WithFields}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+
+
+class CombineWithFieldsSuite extends PlanTest {
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches = Batch("CombineWithFields", FixedPoint(10), CombineWithFields) :: Nil
+  }
+
+  private val testRelation = LocalRelation('a.struct('a1.int))
+
+  test("combines two WithFields") {
+    val originalQuery = testRelation
+      .select(Alias(
+        WithFields(
+          WithFields(
+            'a,
+            Seq("b1"),
+            Seq(Literal(4))),
+          Seq("c1"),
+          Seq(Literal(5))), "out")())
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation
+      .select(Alias(WithFields('a, Seq("b1", "c1"), Seq(Literal(4), Literal(5))), "out")())
+      .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("combines three WithFields") {
+    val originalQuery = testRelation
+      .select(Alias(
+        WithFields(
+          WithFields(
+            WithFields(
+              'a,
+              Seq("b1"),
+              Seq(Literal(4))),
+            Seq("c1"),
+            Seq(Literal(5))),
+          Seq("d1"),
+          Seq(Literal(6))), "out")())
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation
+      .select(Alias(WithFields('a, Seq("b1", "c1", "d1"), Seq(4, 5, 6).map(Literal(_))), "out")())
+      .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+}
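
On rule placement: `CombineWithFields` runs inside the main operator-optimization batch so that chained `withField` calls collapse into a single `WithFields` before `SimplifyExtractValueOps` sees them, while `ReplaceWithFieldsExpression` runs once, in a final batch, to lower whatever survives into its evaluable form. A sketch of that two-step pipeline in the same `RuleExecutor` style as the suite above; the executor object is illustrative, not part of the patch:

```scala
import org.apache.spark.sql.catalyst.optimizer.{CombineWithFields, ReplaceWithFieldsExpression}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.RuleExecutor

// First collapse chained WithFields, then rewrite the survivor to its
// evalExpr: a CreateNamedStruct, null-guarded when the input is nullable.
// After execution, no WithFields node remains in the plan.
object LowerWithFieldsSketch extends RuleExecutor[LogicalPlan] {
  val batches =
    Batch("CombineWithFields", FixedPoint(10), CombineWithFields) ::
    Batch("ReplaceWithFieldsExpression", Once, ReplaceWithFieldsExpression) :: Nil
}
```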
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
index d55746002783a..c71e7dbe7d6f9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
@@ -452,4 +452,61 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
     checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](2, 1), BinaryType)), "2")
     checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](3, 4))), null)
   }
+
+  private val structAttr = 'struct1.struct('a.int)
+  private val testStructRelation = LocalRelation(structAttr)
+
+  test("simplify GetStructField on WithFields that is not changing the attribute being extracted") {
+    val query = testStructRelation.select(
+      GetStructField(WithFields('struct1, Seq("b"), Seq(Literal(1))), 0, Some("a")) as "outerAtt")
+    val expected = testStructRelation.select(GetStructField('struct1, 0, Some("a")) as "outerAtt")
+    checkRule(query, expected)
+  }
+
+  test("simplify GetStructField on WithFields that is changing the attribute being extracted") {
+    val query = testStructRelation.select(
+      GetStructField(WithFields('struct1, Seq("b"), Seq(Literal(1))), 1, Some("b")) as "outerAtt")
+    val expected = testStructRelation.select(Literal(1) as "outerAtt")
+    checkRule(query, expected)
+  }
+
+  test(
+    "simplify GetStructField on WithFields that is changing the attribute being extracted twice") {
+    val query = testStructRelation
+      .select(GetStructField(WithFields('struct1, Seq("b", "b"), Seq(Literal(1), Literal(2))), 1,
+        Some("b")) as "outerAtt")
+    val expected = testStructRelation.select(Literal(2) as "outerAtt")
+    checkRule(query, expected)
+  }
+
+  test("collapse multiple GetStructField on the same WithFields") {
+    val query = testStructRelation
+      .select(WithFields('struct1, Seq("b"), Seq(Literal(2))) as "struct2")
+      .select(
+        GetStructField('struct2, 0, Some("a")) as "struct1A",
+        GetStructField('struct2, 1, Some("b")) as "struct1B")
+    val expected = testStructRelation.select(
+      GetStructField('struct1, 0, Some("a")) as "struct1A",
+      Literal(2) as "struct1B")
+    checkRule(query, expected)
+  }
+
+  test("collapse multiple GetStructField on different WithFields") {
+    val query = testStructRelation
+      .select(
+        WithFields('struct1, Seq("b"), Seq(Literal(2))) as "struct2",
+        WithFields('struct1, Seq("b"), Seq(Literal(3))) as "struct3")
+      .select(
+        GetStructField('struct2, 0, Some("a")) as "struct2A",
+        GetStructField('struct2, 1, Some("b")) as "struct2B",
+        GetStructField('struct3, 0, Some("a")) as "struct3A",
+        GetStructField('struct3, 1, Some("b")) as "struct3B")
+    val expected = testStructRelation
+      .select(
+        GetStructField('struct1, 0, Some("a")) as "struct2A",
+        Literal(2) as "struct2B",
+        GetStructField('struct1, 0, Some("a")) as "struct3A",
+        Literal(3) as "struct3B")
+    checkRule(query, expected)
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 2144472937f9b..516ef40e2baa8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -871,6 +871,72 @@ class Column(val expr: Expression) extends Logging {
    */
   def getItem(key: Any): Column = withExpr { UnresolvedExtractValue(expr, Literal(key)) }
 
+  // scalastyle:off line.size.limit
+  /**
+   * An expression that adds/replaces a field in a `StructType` by name.
+   *
+   * {{{
+   *   val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+   *   df.select($"struct_col".withField("c", lit(3)))
+   *   // result: {"a":1,"b":2,"c":3}
+   *
+   *   val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+   *   df.select($"struct_col".withField("b", lit(3)))
+   *   // result: {"a":1,"b":3}
+   *
+   *   val df = sql("SELECT CAST(NULL AS struct<a:int,b:int>) struct_col")
+   *   df.select($"struct_col".withField("c", lit(3)))
+   *   // result: null of type struct<a:int,b:int,c:int>
+   *
+   *   val df = sql("SELECT named_struct('a', 1, 'b', 2, 'b', 3) struct_col")
+   *   df.select($"struct_col".withField("b", lit(100)))
+   *   // result: {"a":1,"b":100,"b":100}
+   *
+   *   val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
+   *   df.select($"struct_col".withField("a.c", lit(3)))
+   *   // result: {"a":{"a":1,"b":2,"c":3}}
+   *
+   *   val df = sql("SELECT named_struct('a', named_struct('b', 1), 'a', named_struct('c', 2)) struct_col")
+   *   df.select($"struct_col".withField("a.c", lit(3)))
+   *   // result: org.apache.spark.sql.AnalysisException: Ambiguous reference to fields
+   * }}}
+   *
+   * @group expr_ops
+   * @since 3.1.0
+   */
+  // scalastyle:on line.size.limit
+  def withField(fieldName: String, col: Column): Column = withExpr {
+    require(fieldName != null, "fieldName cannot be null")
+    require(col != null, "col cannot be null")
+
+    val nameParts = if (fieldName.isEmpty) {
+      fieldName :: Nil
+    } else {
+      CatalystSqlParser.parseMultipartIdentifier(fieldName)
+    }
+    withFieldHelper(expr, nameParts, Nil, col.expr)
+  }
+
+  private def withFieldHelper(
+      struct: Expression,
+      namePartsRemaining: Seq[String],
+      namePartsDone: Seq[String],
+      value: Expression): WithFields = {
+    val name = namePartsRemaining.head
+    if (namePartsRemaining.length == 1) {
+      WithFields(struct, name :: Nil, value :: Nil)
+    } else {
+      val newNamesRemaining = namePartsRemaining.tail
+      val newNamesDone = namePartsDone :+ name
+      val newValue = withFieldHelper(
+        struct = UnresolvedExtractValue(struct, Literal(name)),
+        namePartsRemaining = newNamesRemaining,
+        namePartsDone = newNamesDone,
+        value = value)
+      WithFields(struct, name :: Nil, newValue :: Nil)
+    }
+  }
+
   /**
    * An expression that gets a field by name in a `StructType`.
    *
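
On the recursion in `withFieldHelper`: a dotted name such as `"a.c"` is split by `CatalystSqlParser.parseMultipartIdentifier` into `Seq("a", "c")`, and each intermediate part becomes an `UnresolvedExtractValue` so the analyzer can resolve it (or fail with the usual missing/ambiguous-field errors). Assuming a SparkSession `spark`, this mirrors the nested example in the scaladoc above:

```scala
import org.apache.spark.sql.functions._
import spark.implicits._

val df = spark.sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) AS struct_col")

// "a.c" parses into Seq("a", "c"), which withFieldHelper folds into
//   WithFields(struct_col, Seq("a"),
//     Seq(WithFields(UnresolvedExtractValue(struct_col, "a"), Seq("c"), Seq(lit(3)))))
df.select($"struct_col".withField("a.c", lit(3)))
// result: {"a":{"a":1,"b":2,"c":3}}

// Backticks keep dots literal: withField("`a.c`", lit(3)) parses as the single
// part "a.c" and would add a top-level field with that exact name instead.
```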
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index fa06484a73d95..131ab1b94f59e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -923,4 +923,503 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { val inSet = InSet(Literal("a"), Set("a", "b").map(UTF8String.fromString)) assert(inSet.sql === "('a' IN ('a', 'b'))") } + + def checkAnswerAndSchema( + df: => DataFrame, + expectedAnswer: Seq[Row], + expectedSchema: StructType): Unit = { + + checkAnswer(df, expectedAnswer) + assert(df.schema == expectedSchema) + } + + private lazy val structType = StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false))) + + private lazy val structLevel1: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(Row(1, null, 3)) :: Nil), + StructType(Seq(StructField("a", structType, nullable = false)))) + + private lazy val nullStructLevel1: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(null) :: Nil), + StructType(Seq(StructField("a", structType, nullable = true)))) + + private lazy val structLevel2: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(Row(Row(1, null, 3))) :: Nil), + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", structType, nullable = false))), + nullable = false)))) + + private lazy val nullStructLevel2: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(Row(null)) :: Nil), + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", structType, nullable = true))), + nullable = false)))) + + private lazy val structLevel3: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(Row(Row(Row(1, null, 3)))) :: Nil), + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", structType, nullable = false))), + nullable = false))), + nullable = false)))) + + test("withField should throw an exception if called on a non-StructType column") { + intercept[AnalysisException] { + testData.withColumn("key", $"key".withField("a", lit(2))) + }.getMessage should include("struct argument should be struct type, got: int") + } + + test("withField should throw an exception if either fieldName or col argument are null") { + intercept[IllegalArgumentException] { + structLevel1.withColumn("a", $"a".withField(null, lit(2))) + }.getMessage should include("fieldName cannot be null") + + intercept[IllegalArgumentException] { + structLevel1.withColumn("a", $"a".withField("b", null)) + }.getMessage should include("col cannot be null") + + intercept[IllegalArgumentException] { + structLevel1.withColumn("a", $"a".withField(null, null)) + }.getMessage should include("fieldName cannot be null") + } + + test("withField should throw an exception if any intermediate structs don't exist") { + intercept[AnalysisException] { + structLevel2.withColumn("a", 'a.withField("x.b", lit(2))) + }.getMessage should include("No such struct field x in a") + + intercept[AnalysisException] { + structLevel3.withColumn("a", 'a.withField("a.x.b", lit(2))) + }.getMessage should include("No such struct field x in a") + } + + test("withField should throw an exception if intermediate field is not a struct") { + intercept[AnalysisException] { + 
+      structLevel1.withColumn("a", 'a.withField("b.a", lit(2)))
+    }.getMessage should include("struct argument should be struct type, got: int")
+  }
+
+  test("withField should throw an exception if intermediate field reference is ambiguous") {
+    intercept[AnalysisException] {
+      val structLevel2: DataFrame = spark.createDataFrame(
+        sparkContext.parallelize(Row(Row(Row(1, null, 3), 4)) :: Nil),
+        StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", structType, nullable = false),
+            StructField("a", structType, nullable = false))),
+            nullable = false))))
+
+      structLevel2.withColumn("a", 'a.withField("a.b", lit(2)))
+    }.getMessage should include("Ambiguous reference to fields")
+  }
+
+  test("withField should add field with no name") {
+    checkAnswerAndSchema(
+      structLevel1.withColumn("a", $"a".withField("", lit(4))),
+      Row(Row(1, null, 3, 4)) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = true),
+          StructField("c", IntegerType, nullable = false),
+          StructField("", IntegerType, nullable = false))),
+          nullable = false))))
+  }
+
+  test("withField should add field to struct") {
+    checkAnswerAndSchema(
+      structLevel1.withColumn("a", 'a.withField("d", lit(4))),
+      Row(Row(1, null, 3, 4)) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = true),
+          StructField("c", IntegerType, nullable = false),
+          StructField("d", IntegerType, nullable = false))),
+          nullable = false))))
+  }
+
+  test("withField should add field to null struct") {
+    checkAnswerAndSchema(
+      nullStructLevel1.withColumn("a", $"a".withField("d", lit(4))),
+      Row(null) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = true),
+          StructField("c", IntegerType, nullable = false),
+          StructField("d", IntegerType, nullable = false))),
+          nullable = true))))
+  }
+
+  test("withField should add field to nested null struct") {
+    checkAnswerAndSchema(
+      nullStructLevel2.withColumn("a", $"a".withField("a.d", lit(4))),
+      Row(Row(null)) :: Nil,
+      StructType(
+        Seq(StructField("a", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false),
+            StructField("b", IntegerType, nullable = true),
+            StructField("c", IntegerType, nullable = false),
+            StructField("d", IntegerType, nullable = false))),
+            nullable = true))),
+          nullable = false))))
+  }
+
+  test("withField should add null field to struct") {
+    checkAnswerAndSchema(
+      structLevel1.withColumn("a", 'a.withField("d", lit(null).cast(IntegerType))),
+      Row(Row(1, null, 3, null)) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = true),
+          StructField("c", IntegerType, nullable = false),
+          StructField("d", IntegerType, nullable = true))),
+          nullable = false))))
+  }
+
+  test("withField should add multiple fields to struct") {
+    checkAnswerAndSchema(
+      structLevel1.withColumn("a", 'a.withField("d", lit(4)).withField("e", lit(5))),
+      Row(Row(1, null, 3, 4, 5)) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = true),
+          StructField("c", IntegerType, nullable = false),
+          StructField("d", IntegerType, nullable = false),
+          StructField("e", IntegerType, nullable = false))),
+          nullable = false))))
+  }
+
+  test("withField should add field to nested struct") {
+    Seq(
+      structLevel2.withColumn("a", 'a.withField("a.d", lit(4))),
+      structLevel2.withColumn("a", 'a.withField("a", $"a.a".withField("d", lit(4))))
+    ).foreach { df =>
+      checkAnswerAndSchema(
+        df,
+        Row(Row(Row(1, null, 3, 4))) :: Nil,
+        StructType(
+          Seq(StructField("a", StructType(Seq(
+            StructField("a", StructType(Seq(
+              StructField("a", IntegerType, nullable = false),
+              StructField("b", IntegerType, nullable = true),
+              StructField("c", IntegerType, nullable = false),
+              StructField("d", IntegerType, nullable = false))),
+              nullable = false))),
+            nullable = false))))
+    }
+  }
+
+  test("withField should add field to deeply nested struct") {
+    checkAnswerAndSchema(
+      structLevel3.withColumn("a", 'a.withField("a.a.d", lit(4))),
+      Row(Row(Row(Row(1, null, 3, 4)))) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", StructType(Seq(
+              StructField("a", IntegerType, nullable = false),
+              StructField("b", IntegerType, nullable = true),
+              StructField("c", IntegerType, nullable = false),
+              StructField("d", IntegerType, nullable = false))),
+              nullable = false))),
+            nullable = false))),
+          nullable = false))))
+  }
+
+  test("withField should replace field in struct") {
+    checkAnswerAndSchema(
+      structLevel1.withColumn("a", 'a.withField("b", lit(2))),
+      Row(Row(1, 2, 3)) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = false),
+          StructField("c", IntegerType, nullable = false))),
+          nullable = false))))
+  }
+
+  test("withField should replace field in null struct") {
+    checkAnswerAndSchema(
+      nullStructLevel1.withColumn("a", 'a.withField("b", lit("foo"))),
+      Row(null) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("b", StringType, nullable = false),
+          StructField("c", IntegerType, nullable = false))),
+          nullable = true))))
+  }
+
+  test("withField should replace field in nested null struct") {
+    checkAnswerAndSchema(
+      nullStructLevel2.withColumn("a", $"a".withField("a.b", lit("foo"))),
+      Row(Row(null)) :: Nil,
+      StructType(
+        Seq(StructField("a", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false),
+            StructField("b", StringType, nullable = false),
+            StructField("c", IntegerType, nullable = false))),
+            nullable = true))),
+          nullable = false))))
+  }
+
+  test("withField should replace field with null value in struct") {
+    checkAnswerAndSchema(
+      structLevel1.withColumn("a", 'a.withField("c", lit(null).cast(IntegerType))),
+      Row(Row(1, null, null)) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = true),
+          StructField("c", IntegerType, nullable = true))),
+          nullable = false))))
+  }
+
+  test("withField should replace multiple fields in struct") {
+    checkAnswerAndSchema(
+      structLevel1.withColumn("a", 'a.withField("a", lit(10)).withField("b", lit(20))),
+      Row(Row(10, 20, 3)) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = false),
+          StructField("c", IntegerType, nullable = false))),
+          nullable = false))))
+  }
+
+  test("withField should replace field in nested struct") {
+    Seq(
+      structLevel2.withColumn("a", $"a".withField("a.b", lit(2))),
$"a".withField("a.b", lit(2))), + structLevel2.withColumn("a", 'a.withField("a", $"a.a".withField("b", lit(2)))) + ).foreach { df => + checkAnswerAndSchema( + df, + Row(Row(Row(1, 2, 3))) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false), + StructField("c", IntegerType, nullable = false))), + nullable = false))), + nullable = false)))) + } + } + + test("withField should replace field in deeply nested struct") { + checkAnswerAndSchema( + structLevel3.withColumn("a", $"a".withField("a.a.b", lit(2))), + Row(Row(Row(Row(1, 2, 3)))) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false), + StructField("c", IntegerType, nullable = false))), + nullable = false))), + nullable = false))), + nullable = false)))) + } + + test("withField should replace all fields with given name in struct") { + val structLevel1 = spark.createDataFrame( + sparkContext.parallelize(Row(Row(1, 2, 3)) :: Nil), + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false))), + nullable = false)))) + + checkAnswerAndSchema( + structLevel1.withColumn("a", 'a.withField("b", lit(100))), + Row(Row(1, 100, 100)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false))), + nullable = false)))) + } + + test("withField should replace fields in struct in given order") { + checkAnswerAndSchema( + structLevel1.withColumn("a", 'a.withField("b", lit(2)).withField("b", lit(20))), + Row(Row(1, 20, 3)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false), + StructField("c", IntegerType, nullable = false))), + nullable = false)))) + } + + test("withField should add field and then replace same field in struct") { + checkAnswerAndSchema( + structLevel1.withColumn("a", 'a.withField("d", lit(4)).withField("d", lit(5))), + Row(Row(1, null, 3, 5)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false), + StructField("d", IntegerType, nullable = false))), + nullable = false)))) + } + + test("withField should handle fields with dots in their name if correctly quoted") { + val df: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(Row(Row(1, null, 3))) :: Nil), + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a.b", StructType(Seq( + StructField("c.d", IntegerType, nullable = false), + StructField("e.f", IntegerType, nullable = true), + StructField("g.h", IntegerType, nullable = false))), + nullable = false))), + nullable = false)))) + + checkAnswerAndSchema( + df.withColumn("a", 'a.withField("`a.b`.`e.f`", lit(2))), + Row(Row(Row(1, 2, 3))) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a.b", StructType(Seq( + StructField("c.d", IntegerType, nullable = false), + StructField("e.f", IntegerType, nullable = false), + 
StructField("g.h", IntegerType, nullable = false))), + nullable = false))), + nullable = false)))) + + intercept[AnalysisException] { + df.withColumn("a", 'a.withField("a.b.e.f", lit(2))) + }.getMessage should include("No such struct field a in a.b") + } + + private lazy val mixedCaseStructLevel1: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(Row(1, 1)) :: Nil), + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("B", IntegerType, nullable = false))), + nullable = false)))) + + test("withField should replace field in struct even if casing is different") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + checkAnswerAndSchema( + mixedCaseStructLevel1.withColumn("a", 'a.withField("A", lit(2))), + Row(Row(2, 1)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("A", IntegerType, nullable = false), + StructField("B", IntegerType, nullable = false))), + nullable = false)))) + + checkAnswerAndSchema( + mixedCaseStructLevel1.withColumn("a", 'a.withField("b", lit(2))), + Row(Row(1, 2)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false))), + nullable = false)))) + } + } + + test("withField should add field to struct because casing is different") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + checkAnswerAndSchema( + mixedCaseStructLevel1.withColumn("a", 'a.withField("A", lit(2))), + Row(Row(1, 1, 2)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("B", IntegerType, nullable = false), + StructField("A", IntegerType, nullable = false))), + nullable = false)))) + + checkAnswerAndSchema( + mixedCaseStructLevel1.withColumn("a", 'a.withField("b", lit(2))), + Row(Row(1, 1, 2)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("B", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false))), + nullable = false)))) + } + } + + private lazy val mixedCaseStructLevel2: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(Row(Row(1, 1), Row(1, 1))) :: Nil), + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false))), + nullable = false), + StructField("B", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false))), + nullable = false))), + nullable = false)))) + + test("withField should replace nested field in struct even if casing is different") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + checkAnswerAndSchema( + mixedCaseStructLevel2.withColumn("a", 'a.withField("A.a", lit(2))), + Row(Row(Row(2, 1), Row(1, 1))) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("A", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false))), + nullable = false), + StructField("B", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false))), + nullable = false))), + nullable = false)))) + + checkAnswerAndSchema( + mixedCaseStructLevel2.withColumn("a", 'a.withField("b.a", lit(2))), + Row(Row(Row(1, 1), Row(2, 1))) :: Nil, + StructType(Seq( + StructField("a", 
+            StructType(Seq(
+              StructField("a", StructType(Seq(
+                StructField("a", IntegerType, nullable = false),
+                StructField("b", IntegerType, nullable = false))),
+                nullable = false),
+              StructField("b", StructType(Seq(
+                StructField("a", IntegerType, nullable = false),
+                StructField("b", IntegerType, nullable = false))),
+                nullable = false))),
+            nullable = false))))
+    }
+  }
+
+  test("withField should throw an exception because casing is different") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+      intercept[AnalysisException] {
+        mixedCaseStructLevel2.withColumn("a", 'a.withField("A.a", lit(2)))
+      }.getMessage should include("No such struct field A in a, B")
+
+      intercept[AnalysisException] {
+        mixedCaseStructLevel2.withColumn("a", 'a.withField("b.a", lit(2)))
+      }.getMessage should include("No such struct field b in a, B")
+    }
+  }
 }
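
Finally, a quick end-to-end check of the new API, restating the scaladoc examples against a live session (assumes this patch is applied and a SparkSession `spark`, e.g. in spark-shell):

```scala
import org.apache.spark.sql.functions._
import spark.implicits._

val df = spark.sql("SELECT named_struct('a', 1, 'b', 2) AS struct_col")

// Adding a new field appends it: {"a":1,"b":2,"c":3}
df.select($"struct_col".withField("c", lit(3)).as("struct_col")).show(false)

// Replacing an existing field keeps its position: {"a":1,"b":3}
df.select($"struct_col".withField("b", lit(3)).as("struct_col")).show(false)
```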