Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
8935221
rebase on latest master
fqaiser94 Jan 2, 2020
c8bea43
add sql case sensitive tests
fqaiser94 May 18, 2020
1e944b3
clean up expression examples
fqaiser94 May 18, 2020
5d0cdf8
regenerate golden files
fqaiser94 May 18, 2020
20cffbc
support adding/replacing deeply nested field in a single with_field call
fqaiser94 May 22, 2020
32552c4
Merge branch 'master' into SPARK-22231-withField
fqaiser94 May 24, 2020
15f8ee4
regenerate golden files
fqaiser94 May 24, 2020
1e3300e
fix examples
fqaiser94 May 24, 2020
9b98e7d
cleanup
fqaiser94 May 24, 2020
14b9822
throw exception if intermediate struct does not exist
fqaiser94 May 24, 2020
ce1e030
clean up documentation
fqaiser94 May 24, 2020
1bfa465
remove extra brackets
fqaiser94 May 28, 2020
238f2f2
Merge branch 'master' into SPARK-22231-withField
fqaiser94 May 28, 2020
e3c7930
refactor
cloud-fan Jun 2, 2020
a59d94f
fix typo
cloud-fan Jun 3, 2020
1199550
Merge pull request #1 from cloud-fan/help
fqaiser94 Jun 6, 2020
156bf2d
fix tests and implementation
fqaiser94 Jun 7, 2020
337bdb5
added a missing test case
fqaiser94 Jun 7, 2020
40cbc59
consistent language
fqaiser94 Jun 7, 2020
ab36504
better parameter names and order
fqaiser94 Jun 7, 2020
8159e47
4 space indentation for method parameters declaration
fqaiser94 Jun 11, 2020
38747e2
parse the field name at the driver side in functions.withField so it'…
fqaiser94 Jun 11, 2020
a54d3fb
simpler WithFields implementation + 3 small optimizer rules
fqaiser94 Jun 12, 2020
3e540ac
remove unused imports
fqaiser94 Jun 16, 2020
35004cf
fix indentation
fqaiser94 Jun 16, 2020
c2f9216
more concise
fqaiser94 Jun 17, 2020
f5a9420
use `!structExpr.dataType.isInstanceOf[StructType]`
fqaiser94 Jun 17, 2020
b66fcff
change names to Seq[String] and add assertion for names.length == val…
fqaiser94 Jun 18, 2020
8ada917
better tests for SimplifyExtractValueOps
fqaiser94 Jun 19, 2020
1af96cd
failed attempt to genCode
fqaiser94 Jun 30, 2020
6adbbd7
add comment to explain `matches.last._2` code
fqaiser94 Jul 1, 2020
bc5d622
simpler implementation for simplifying GetStructField(WithFields()) a…
fqaiser94 Jul 1, 2020
2a6ec1a
revert back to Unevaluable Expression
fqaiser94 Jul 1, 2020
f428b54
remove unneeded import
fqaiser94 Jul 1, 2020
98d5843
name that is more in line with the rest of Spark codebase
fqaiser94 Jul 1, 2020
b18ac3e
replace NullPointerException with IllegalArgumentException
fqaiser94 Jul 1, 2020
b58f3c8
make replace-value-in-null-struct tests clearer by changing replace v…
fqaiser94 Jul 1, 2020
096bbf4
false.toString -> "false" and true.toString -> "false"
fqaiser94 Jul 1, 2020
6edf5a8
remove unneeded import
fqaiser94 Jul 1, 2020
445e2d8
use KnownNotNull
fqaiser94 Jul 2, 2020
9c3b2a2
don't need a separate addOrReplace method anymore
fqaiser94 Jul 3, 2020
b0a9cc7
separate test cases to avoid confusion
fqaiser94 Jul 3, 2020
9fb0a5d
add example to demonstrate intermediate struct-type fields must have …
fqaiser94 Jul 3, 2020
b53930d
put the 2 new rules in WithFields.scala
fqaiser94 Jul 3, 2020
88244f0
better
fqaiser94 Jul 3, 2020
4315e92
spaces before and after @
fqaiser94 Jul 3, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -539,3 +539,61 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E

override def prettyName: String = "str_to_map"
}

/**
 * Adds/replaces fields in a struct by name.
 *
 * This expression is [[Unevaluable]]: it is a placeholder that must be rewritten into
 * [[evalExpr]] before it can be executed.
 *
 * @param structExpr expression producing the struct to modify; must be of [[StructType]].
 * @param names names of the fields to add or replace, applied in order.
 * @param valExprs value expressions for the fields, paired positionally with `names`.
 */
case class WithFields(
    structExpr: Expression,
    names: Seq[String],
    valExprs: Seq[Expression]) extends Unevaluable {

  // Every name needs exactly one value expression, since the two are zipped in `evalExpr`.
  assert(names.length == valExprs.length)

  override def checkInputDataTypes(): TypeCheckResult = {
    if (!structExpr.dataType.isInstanceOf[StructType]) {
      TypeCheckResult.TypeCheckFailure(
        "struct argument should be struct type, got: " + structExpr.dataType.catalogString)
    } else {
      TypeCheckResult.TypeCheckSuccess
    }
  }

  override def children: Seq[Expression] = structExpr +: valExprs

  // The resulting schema is derived from the evaluable rewrite so the two can never disagree.
  override def dataType: StructType = evalExpr.dataType.asInstanceOf[StructType]

  // NOTE: `foldable` is deliberately NOT overridden, so it keeps the default value of `false`.
  // This expression is Unevaluable; reporting it as foldable makes constant folding try to
  // evaluate it, which fails when the struct and all values are foldable
  // (e.g. `named_struct('a', 1).withField("c", lit(3))`).

  override def nullable: Boolean = structExpr.nullable

  override def prettyName: String = "with_fields"

  /**
   * The evaluable equivalent of this expression: a [[CreateNamedStruct]] built from the existing
   * fields of `structExpr` with the requested fields added or replaced, wrapped in a null check
   * when `structExpr` is nullable so that a null input struct yields a null output struct.
   */
  lazy val evalExpr: Expression = {
    val existingExprs = structExpr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map {
      case (name, i) => (name, GetStructField(KnownNotNull(structExpr), i).asInstanceOf[Expression])
    }

    val addOrReplaceExprs = names.zip(valExprs)

    val resolver = SQLConf.get.resolver
    val newExprs = addOrReplaceExprs.foldLeft(existingExprs) {
      case (resultExprs, newExpr @ (newExprName, _)) =>
        if (resultExprs.exists(x => resolver(x._1, newExprName))) {
          // Replace every field whose name matches, keeping its position in the struct.
          resultExprs.map {
            case (name, _) if resolver(name, newExprName) => newExpr
            case x => x
          }
        } else {
          // No existing field with this name: append it as a new field at the end.
          resultExprs :+ newExpr
        }
    }.flatMap { case (name, expr) => Seq(Literal(name), expr) }

    val expr = CreateNamedStruct(newExprs)
    if (structExpr.nullable) {
      If(IsNull(structExpr), Literal(null, expr.dataType), expr)
    } else {
      expr
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,18 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] {
// Remove redundant field extraction.
case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) =>
createNamedStruct.valExprs(ordinal)

case GetStructField(w @ WithFields(struct, names, valExprs), ordinal, maybeName) =>
  // Name of the extracted field as it appears in the output schema of the WithFields.
  val name = w.dataType(ordinal).name
  // The last assignment to `name` wins, since it is the final value of the extracted field.
  // For example, for a query like:
  // `$"struct_col".withField("b", lit(1)).withField("b", lit(2)).getField("b")`
  // the result must be `lit(2)` (and not `lit(1)`).
  // If the field is never assigned, extract it straight from the underlying struct instead.
  names.zip(valExprs)
    .collect { case (`name`, valExpr) => valExpr }
    .lastOption
    .getOrElse(GetStructField(struct, ordinal, maybeName))
// Remove redundant array indexing.
case GetArrayStructFields(CreateArray(elems, useStringTypeWhenEmpty), field, ordinal, _, _) =>
// Instead of selecting the field on the entire array, select it from each member
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
EliminateSerialization,
RemoveRedundantAliases,
RemoveNoopOperators,
CombineWithFields,
SimplifyExtractValueOps,
CombineConcats) ++
extendedOperatorOptimizationRules
Expand Down Expand Up @@ -202,7 +203,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
CollapseProject,
RemoveNoopOperators) :+
// This batch must be executed after the `RewriteSubquery` batch, which creates joins.
Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers)
Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+
Batch("ReplaceWithFieldsExpression", Once, ReplaceWithFieldsExpression)

// remove any batches with no rules. this may happen when subclasses do not add optional rules.
batches.filter(_.rules.nonEmpty)
Expand Down Expand Up @@ -235,7 +237,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
PullupCorrelatedPredicates.ruleName ::
RewriteCorrelatedScalarSubquery.ruleName ::
RewritePredicateSubquery.ruleName ::
NormalizeFloatingNumbers.ruleName :: Nil
NormalizeFloatingNumbers.ruleName ::
ReplaceWithFieldsExpression.ruleName :: Nil

/**
* Optimize all the subqueries inside expression.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions.WithFields
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule


/**
 * Merges a [[WithFields]] whose struct input is itself a [[WithFields]] into a single
 * [[WithFields]] expression, with the inner fields listed before the outer ones.
 */
object CombineWithFields extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions {
    case WithFields(inner: WithFields, outerNames, outerValExprs) =>
      WithFields(inner.structExpr, inner.names ++ outerNames, inner.valExprs ++ outerValExprs)
  }
}

/**
 * Rewrites every [[WithFields]] expression in the plan into its evaluable form, `evalExpr`.
 */
object ReplaceWithFieldsExpression extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions {
    case withFields: WithFields => withFields.evalExpr
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, WithFields}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._


/**
 * Checks that nested [[WithFields]] expressions are collapsed by [[CombineWithFields]].
 */
class CombineWithFieldsSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("CombineWithFields", FixedPoint(10), CombineWithFields) :: Nil
  }

  private val testRelation = LocalRelation('a.struct('a1.int))

  // Optimizes a projection of `nested` and compares it with a projection of `combined`.
  private def checkCombined(nested: WithFields, combined: WithFields): Unit = {
    val originalQuery = testRelation.select(Alias(nested, "out")())
    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation.select(Alias(combined, "out")()).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("combines two WithFields") {
    val addB1 = WithFields('a, Seq("b1"), Seq(Literal(4)))
    val addC1 = WithFields(addB1, Seq("c1"), Seq(Literal(5)))

    checkCombined(
      nested = addC1,
      combined = WithFields('a, Seq("b1", "c1"), Seq(Literal(4), Literal(5))))
  }

  test("combines three WithFields") {
    val addB1 = WithFields('a, Seq("b1"), Seq(Literal(4)))
    val addC1 = WithFields(addB1, Seq("c1"), Seq(Literal(5)))
    val addD1 = WithFields(addC1, Seq("d1"), Seq(Literal(6)))

    checkCombined(
      nested = addD1,
      combined = WithFields('a, Seq("b1", "c1", "d1"), Seq(4, 5, 6).map(Literal(_))))
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -452,4 +452,61 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](2, 1), BinaryType)), "2")
checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](3, 4))), null)
}

// Relation with a single struct column `struct1` that has one int field `a`; it is the input
// of the WithFields simplification tests that follow.
private val structAttr = 'struct1.struct('a.int)
private val testStructRelation = LocalRelation(structAttr)

test("simplify GetStructField on WithFields that is not changing the attribute being extracted") {
  // Field "a" is not touched by the WithFields, so it should be read from `struct1` directly.
  val structWithB = WithFields('struct1, Seq("b"), Seq(Literal(1)))
  val query = testStructRelation.select(GetStructField(structWithB, 0, Some("a")) as "outerAtt")
  val expected = testStructRelation.select(GetStructField('struct1, 0, Some("a")) as "outerAtt")
  checkRule(query, expected)
}

test("simplify GetStructField on WithFields that is changing the attribute being extracted") {
  // Field "b" is set by the WithFields, so extracting it should yield the assigned literal.
  val structWithB = WithFields('struct1, Seq("b"), Seq(Literal(1)))
  val query = testStructRelation.select(GetStructField(structWithB, 1, Some("b")) as "outerAtt")
  val expected = testStructRelation.select(Literal(1) as "outerAtt")
  checkRule(query, expected)
}

test(
  "simplify GetStructField on WithFields that is changing the attribute being extracted twice") {
  // "b" is assigned twice; the later assignment (2) is the one that must be extracted.
  val structWithBTwice = WithFields('struct1, Seq("b", "b"), Seq(Literal(1), Literal(2)))
  val query = testStructRelation.select(
    GetStructField(structWithBTwice, 1, Some("b")) as "outerAtt")
  val expected = testStructRelation.select(Literal(2) as "outerAtt")
  checkRule(query, expected)
}

test("collapse multiple GetStructField on the same WithFields") {
  // Both extractions read from the same aliased WithFields produced by the first projection.
  val struct2 = WithFields('struct1, Seq("b"), Seq(Literal(2))) as "struct2"
  val query = testStructRelation
    .select(struct2)
    .select(
      GetStructField('struct2, 0, Some("a")) as "struct1A",
      GetStructField('struct2, 1, Some("b")) as "struct1B")
  val expected = testStructRelation.select(
    GetStructField('struct1, 0, Some("a")) as "struct1A",
    Literal(2) as "struct1B")
  checkRule(query, expected)
}

test("collapse multiple GetStructField on different WithFields") {
  // Two independent WithFields over the same input struct, each assigning a different "b".
  val struct2 = WithFields('struct1, Seq("b"), Seq(Literal(2))) as "struct2"
  val struct3 = WithFields('struct1, Seq("b"), Seq(Literal(3))) as "struct3"
  val query = testStructRelation
    .select(struct2, struct3)
    .select(
      GetStructField('struct2, 0, Some("a")) as "struct2A",
      GetStructField('struct2, 1, Some("b")) as "struct2B",
      GetStructField('struct3, 0, Some("a")) as "struct3A",
      GetStructField('struct3, 1, Some("b")) as "struct3B")
  val expected = testStructRelation.select(
    GetStructField('struct1, 0, Some("a")) as "struct2A",
    Literal(2) as "struct2B",
    GetStructField('struct1, 0, Some("a")) as "struct3A",
    Literal(3) as "struct3B")
  checkRule(query, expected)
}
}
66 changes: 66 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/Column.scala
Original file line number Diff line number Diff line change
Expand Up @@ -871,6 +871,72 @@ class Column(val expr: Expression) extends Logging {
*/
def getItem(key: Any): Column = withExpr { UnresolvedExtractValue(expr, Literal(key)) }

// scalastyle:off line.size.limit
/**
* An expression that adds/replaces field in `StructType` by name.
*
* {{{
* val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
* df.select($"struct_col".withField("c", lit(3)))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have any of you tried to run these examples? The optimizer's ConstantFolding rule will break these examples.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

weird, we have tests to cover these examples. @fqaiser94 can you take a look?

Copy link
Contributor Author

@fqaiser94 fqaiser94 Aug 3, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I failed to write a test case to cover this scenario, my bad.
And yes, I just tried this example again, and I can see that it fails.
The issue is that I override `foldable` for this Unevaluable expression, so when `foldable` returns true, Spark tries to evaluate the expression and fails at that point.
I realized this recently as well, and in my PR for dropFields here, I've fixed the issue (basically I just don't override `foldable` anymore, so it returns false by default).
I guess I should submit a follow-up PR to fix this immediately, with associated unit tests?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, please

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for raising the issue @gatorsmile
I've created a JIRA and PR to address the issue.

* // result: {"a":1,"b":2,"c":3}
*
* val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
* df.select($"struct_col".withField("b", lit(3)))
* // result: {"a":1,"b":3}
*
* val df = sql("SELECT CAST(NULL AS struct<a:int,b:int>) struct_col")
* df.select($"struct_col".withField("c", lit(3)))
* // result: null of type struct<a:int,b:int,c:int>
*
* val df = sql("SELECT named_struct('a', 1, 'b', 2, 'b', 3) struct_col")
* df.select($"struct_col".withField("b", lit(100)))
* // result: {"a":1,"b":100,"b":100}
*
* val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
* df.select($"struct_col".withField("a.c", lit(3)))
* // result: {"a":{"a":1,"b":2,"c":3}}
*
* val df = sql("SELECT named_struct('a', named_struct('b', 1), 'a', named_struct('c', 2)) struct_col")
* df.select($"struct_col".withField("a.c", lit(3)))
* // result: org.apache.spark.sql.AnalysisException: Ambiguous reference to fields
* }}}
*
* @group expr_ops
* @since 3.1.0
*/
// scalastyle:on line.size.limit
def withField(fieldName: String, col: Column): Column = withExpr {
  require(fieldName != null, "fieldName cannot be null")
  require(col != null, "col cannot be null")

  // An empty name is kept as a single (empty) name part rather than handed to the parser;
  // any other name is split into its parts so that "a.b" addresses field `b` nested in `a`.
  val nameParts = fieldName match {
    case "" => fieldName :: Nil
    case _ => CatalystSqlParser.parseMultipartIdentifier(fieldName)
  }
  withFieldHelper(expr, nameParts, Nil, col.expr)
}

// Builds the (possibly nested) WithFields expression for a multi-part field name.
// The last name part is assigned `value` directly. Every earlier part is assigned a nested
// WithFields built over the struct extracted under that part, so only the innermost field
// is changed.
private def withFieldHelper(
    struct: Expression,
    namePartsRemaining: Seq[String],
    namePartsDone: Seq[String],
    value: Expression): WithFields = {
  val name = namePartsRemaining.head
  val fieldValue: Expression =
    if (namePartsRemaining.length == 1) {
      value
    } else {
      // Descend into the struct stored under `name` and continue with the remaining parts.
      withFieldHelper(
        struct = UnresolvedExtractValue(struct, Literal(name)),
        namePartsRemaining = namePartsRemaining.tail,
        namePartsDone = namePartsDone :+ name,
        value = value)
    }
  WithFields(struct, name :: Nil, fieldValue :: Nil)
}

/**
* An expression that gets a field by name in a `StructType`.
*
Expand Down
Loading