@@ -17,8 +17,10 @@

package org.apache.spark.sql.catalyst.expressions

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{FUNC_ALIAS, FunctionBuilder}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
@@ -541,57 +543,114 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E
}

/**
* Adds/replaces field in struct by name.
* Represents an operation to be applied to the fields of a struct.
*/
case class WithFields(
structExpr: Expression,
names: Seq[String],
valExprs: Seq[Expression]) extends Unevaluable {
trait StructFieldsOperation {

val resolver: Resolver = SQLConf.get.resolver

assert(names.length == valExprs.length)
/**
* Returns an updated list of StructFields and Expressions that will ultimately be used
* as the fields argument for [[StructType]] and as the children argument for
* [[CreateNamedStruct]] respectively inside of [[UpdateFields]].
*/
def apply(values: Seq[(StructField, Expression)]): Seq[(StructField, Expression)]
}
Contributor Author
@viirya I'm not quite done with this PR yet but I wanted to share it with you early because some of the changes I'm making in here may be helpful for #29587 (assuming this PR is accepted). Specifically, it would be possible to implement sorting of fields in a struct simply by:

case class OrderStructFieldsByName() extends StructFieldsOperation {
  override def apply(values: Seq[(StructField, Expression)]): Seq[(StructField, Expression)] =
    values.sortBy { case (field, _) => field.name }
}

UpdateFields(structExpr, OrderStructFieldsByName() :: Nil)


/**
* Add or replace a field by name.
*
* We extend [[Unevaluable]] here to ensure that [[UpdateFields]] can include it as part of its
* children, and thereby enable the analyzer to resolve and transform valExpr as necessary.
*/
case class WithField(name: String, valExpr: Expression)
extends Unevaluable with StructFieldsOperation {

override def apply(values: Seq[(StructField, Expression)]): Seq[(StructField, Expression)] = {
val newFieldExpr = (StructField(name, valExpr.dataType, valExpr.nullable), valExpr)
val result = ArrayBuffer.empty[(StructField, Expression)]
var hasMatch = false
for (existingFieldExpr @ (existingField, _) <- values) {
if (resolver(existingField.name, name)) {
hasMatch = true
result += newFieldExpr
} else {
result += existingFieldExpr
}
}
if (!hasMatch) result += newFieldExpr
result
}

override def children: Seq[Expression] = valExpr :: Nil

override def dataType: DataType = throw new IllegalStateException(
"WithField.dataType should not be called.")

override def nullable: Boolean = throw new IllegalStateException(
"WithField.nullable should not be called.")

override def prettyName: String = "WithField"
}
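
As a quick illustration of the add-or-replace semantics above, here is a minimal standalone sketch (not Catalyst code): it performs the same fold over plain (name, value) pairs instead of (StructField, Expression) pairs, with a case-insensitive matcher standing in for SQLConf.get.resolver.

import scala.collection.mutable.ArrayBuffer

object WithFieldSemanticsSketch extends App {
  // Case-insensitive name matching, standing in for Spark's default resolver.
  val resolver: (String, String) => Boolean = (a, b) => a.equalsIgnoreCase(b)

  // Replace every existing entry whose name matches; append if there was no match.
  def withField(values: Seq[(String, Int)], name: String, value: Int): Seq[(String, Int)] = {
    val result = ArrayBuffer.empty[(String, Int)]
    var hasMatch = false
    for (existing @ (existingName, _) <- values) {
      if (resolver(existingName, name)) {
        hasMatch = true
        result += (name -> value)
      } else {
        result += existing
      }
    }
    if (!hasMatch) result += (name -> value)
    result.toSeq
  }

  println(withField(Seq("a" -> 1, "b" -> 2), "B", 9)) // replaces the existing b field in place
  println(withField(Seq("a" -> 1, "b" -> 2), "c", 3)) // appends c as a new field
}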

/**
* Drop a field by name.
*/
case class DropField(name: String) extends StructFieldsOperation {
override def apply(values: Seq[(StructField, Expression)]): Seq[(StructField, Expression)] =
values.filterNot { case (field, _) => resolver(field.name, name) }
}

/**
* Updates fields in a struct.
*/
case class UpdateFields(structExpr: Expression, fieldOps: Seq[StructFieldsOperation])
extends Unevaluable {

override def checkInputDataTypes(): TypeCheckResult = {
if (!structExpr.dataType.isInstanceOf[StructType]) {
TypeCheckResult.TypeCheckFailure(
"struct argument should be struct type, got: " + structExpr.dataType.catalogString)
val dataType = structExpr.dataType
if (!dataType.isInstanceOf[StructType]) {
TypeCheckResult.TypeCheckFailure("struct argument should be struct type, got: " +
dataType.catalogString)
} else if (newExprs.isEmpty) {
TypeCheckResult.TypeCheckFailure("cannot drop all fields in struct")
} else {
TypeCheckResult.TypeCheckSuccess
}
}

override def children: Seq[Expression] = structExpr +: valExprs
override def children: Seq[Expression] = structExpr +: fieldOps.collect {
case e: Expression => e
}

override def dataType: StructType = evalExpr.dataType.asInstanceOf[StructType]
override def dataType: StructType = StructType(newFields)

override def nullable: Boolean = structExpr.nullable

override def prettyName: String = "with_fields"
override def prettyName: String = "update_fields"

lazy val evalExpr: Expression = {
val existingExprs = structExpr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map {
case (name, i) => (name, GetStructField(KnownNotNull(structExpr), i).asInstanceOf[Expression])
}
private lazy val newFieldExprs: Seq[(StructField, Expression)] = {
val existingFieldExprs: Seq[(StructField, Expression)] =
structExpr.dataType.asInstanceOf[StructType].fields.zipWithIndex.map {
case (field, i) => (field, GetStructField(structExpr, i))
}

val addOrReplaceExprs = names.zip(valExprs)

val resolver = SQLConf.get.resolver
val newExprs = addOrReplaceExprs.foldLeft(existingExprs) {
case (resultExprs, newExpr @ (newExprName, _)) =>
if (resultExprs.exists(x => resolver(x._1, newExprName))) {
resultExprs.map {
case (name, _) if resolver(name, newExprName) => newExpr
case x => x
}
} else {
resultExprs :+ newExpr
}
}.flatMap { case (name, expr) => Seq(Literal(name), expr) }
fieldOps.foldLeft(existingFieldExprs)((exprs, op) => op(exprs))
}

private lazy val newFields: Seq[StructField] = newFieldExprs.map(_._1)

lazy val newExprs: Seq[Expression] = newFieldExprs.map(_._2)

lazy val evalExpr: Expression = {
val createNamedStructExpr = CreateNamedStruct(newFieldExprs.flatMap {
case (field, expr) => Seq(Literal(field.name), expr)
})

val expr = CreateNamedStruct(newExprs)
if (structExpr.nullable) {
If(IsNull(structExpr), Literal(null, expr.dataType), expr)
If(IsNull(structExpr), Literal(null, dataType), createNamedStructExpr)
} else {
expr
createNamedStructExpr
}
}
}
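
The overall flow of UpdateFields above: take the (field, expression) pairs of the incoming struct, fold every StructFieldsOperation over them in order, and build a CreateNamedStruct from the result (wrapped in a null check when the input struct is nullable). A rough standalone sketch of that fold, again over plain pairs rather than Catalyst expressions:

object UpdateFieldsFoldSketch extends App {
  type Fields = Seq[(String, Int)]

  // Plain-pair stand-ins for the WithField / DropField operations in the patch.
  sealed trait FieldOp { def apply(values: Fields): Fields }

  final case class AddOrReplace(name: String, value: Int) extends FieldOp {
    def apply(values: Fields): Fields =
      if (values.exists(_._1 == name)) {
        values.map {
          case (n, _) if n == name => n -> value
          case other => other
        }
      } else {
        values :+ (name -> value)
      }
  }

  final case class Drop(name: String) extends FieldOp {
    def apply(values: Fields): Fields = values.filterNot(_._1 == name)
  }

  // Like UpdateFields.newFieldExprs: fold every operation, in order, over the existing fields.
  def updateFields(existing: Fields, ops: Seq[FieldOp]): Fields =
    ops.foldLeft(existing)((fields, op) => op(fields))

  val result = updateFields(
    Seq("a" -> 1, "b" -> 2),
    Seq(AddOrReplace("c", 3), Drop("b"), AddOrReplace("a", 10)))
  println(result) // fields end up as a -> 10 and c -> 3
}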
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.StructType

/**
* Simplify redundant [[CreateNamedStruct]], [[CreateArray]] and [[CreateMap]] expressions.
@@ -39,18 +40,13 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] {
// Remove redundant field extraction.
case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) =>
createNamedStruct.valExprs(ordinal)
case GetStructField(w @ WithFields(struct, names, valExprs), ordinal, maybeName) =>
val name = w.dataType(ordinal).name
val matches = names.zip(valExprs).filter(_._1 == name)
if (matches.nonEmpty) {
// return last matching element as that is the final value for the field being extracted.
// For example, if a user submits a query like this:
// `$"struct_col".withField("b", lit(1)).withField("b", lit(2)).getField("b")`
// we want to return `lit(2)` (and not `lit(1)`).
val expr = matches.last._2
If(IsNull(struct), Literal(null, expr.dataType), expr)
} else {
GetStructField(struct, ordinal, maybeName)
case GetStructField(u: UpdateFields, ordinal, _) if !u.structExpr.isInstanceOf[UpdateFields] =>
val structExpr = u.structExpr
u.newExprs(ordinal) match {
// if the struct itself is null, then any value extracted from it (expr) will be null
// so we don't need to wrap expr in If(IsNull(struct), Literal(null, expr.dataType), expr)
case expr: GetStructField if expr.child.semanticEquals(structExpr) => expr
case expr => If(IsNull(structExpr), Literal(null, expr.dataType), expr)
}
// Remove redundant array indexing.
case GetArrayStructFields(CreateArray(elems, useStringTypeWhenEmpty), field, ordinal, _, _) =>
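
For intuition on the GetStructField-over-UpdateFields case above, this is roughly the kind of user query it targets (a sketch using the existing Column.withField / getField API; the column and field names are illustrative):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object ExtractFromUpdateFieldsSketch extends App {
  val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
  import spark.implicits._

  // A single struct column with one existing field a1.
  val df = Seq(Tuple1(1)).toDF("a1").select(struct($"a1").as("struct_col"))

  // withField builds an UpdateFields expression and getField wraps it in a GetStructField.
  // SimplifyExtractValueOps can then return lit(2) directly (null-guarded if struct_col is
  // nullable) instead of materialising the whole updated struct.
  val out = df.select($"struct_col".withField("b", lit(2)).getField("b").as("b"))
  out.explain(true)
  out.show()

  spark.stop()
}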
@@ -109,7 +109,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
RemoveRedundantAliases,
UnwrapCastInBinaryComparison,
RemoveNoopOperators,
CombineWithFields,
CombineUpdateFields,
SimplifyExtractValueOps,
CombineConcats) ++
extendedOperatorOptimizationRules
@@ -221,7 +221,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
RemoveNoopOperators) :+
// This batch must be executed after the `RewriteSubquery` batch, which creates joins.
Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+
Batch("ReplaceWithFieldsExpression", Once, ReplaceWithFieldsExpression)
Batch("ReplaceUpdateFieldsExpression", Once, ReplaceUpdateFieldsExpression)

// remove any batches with no rules. this may happen when subclasses do not add optional rules.
batches.filter(_.rules.nonEmpty)
@@ -255,7 +255,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
RewriteCorrelatedScalarSubquery.ruleName ::
RewritePredicateSubquery.ruleName ::
NormalizeFloatingNumbers.ruleName ::
ReplaceWithFieldsExpression.ruleName :: Nil
ReplaceUpdateFieldsExpression.ruleName :: Nil

/**
* Optimize all the subqueries inside expression.
@@ -17,26 +17,26 @@

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions.WithFields
import org.apache.spark.sql.catalyst.expressions.UpdateFields
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule


/**
* Combines all adjacent [[WithFields]] expression into a single [[WithFields]] expression.
* Combines all adjacent [[UpdateFields]] expressions into a single [[UpdateFields]] expression.
*/
object CombineWithFields extends Rule[LogicalPlan] {
object CombineUpdateFields extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case WithFields(WithFields(struct, names1, valExprs1), names2, valExprs2) =>
WithFields(struct, names1 ++ names2, valExprs1 ++ valExprs2)
case UpdateFields(UpdateFields(struct, fieldOps1), fieldOps2) =>
UpdateFields(struct, fieldOps1 ++ fieldOps2)
}
}

/**
* Replaces [[WithFields]] expression with an evaluable expression.
* Replaces [[UpdateFields]] expression with an evaluable expression.
*/
object ReplaceWithFieldsExpression extends Rule[LogicalPlan] {
object ReplaceUpdateFieldsExpression extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case w: WithFields => w.evalExpr
case u: UpdateFields => u.evalExpr
}
}
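
Taken together: a chain of withField calls produces nested UpdateFields expressions, CombineUpdateFields flattens them into one during operator optimization, and ReplaceUpdateFieldsExpression finally lowers the surviving UpdateFields to its evalExpr (a CreateNamedStruct, null-guarded when the input struct is nullable). A DataFrame-level sketch with illustrative column names; the CombineUpdateFieldsSuite changes below exercise the same rewrite at the Catalyst level:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object CombineUpdateFieldsSketch extends App {
  val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
  import spark.implicits._

  val df = Seq(Tuple1(1)).toDF("a1").select(struct($"a1").as("a"))

  // Each withField call wraps the previous result in another UpdateFields; after
  // CombineUpdateFields the optimized plan carries a single UpdateFields holding both
  // WithField operations, which ReplaceUpdateFieldsExpression then lowers for execution.
  val out = df.select($"a".withField("b1", lit(4)).withField("c1", lit(5)).as("out"))
  out.explain(true)

  spark.stop()
}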
@@ -19,56 +19,53 @@ package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, WithFields}
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, UpdateFields, WithField}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._


class CombineWithFieldsSuite extends PlanTest {
class CombineUpdateFieldsSuite extends PlanTest {

object Optimize extends RuleExecutor[LogicalPlan] {
val batches = Batch("CombineWithFields", FixedPoint(10), CombineWithFields) :: Nil
val batches = Batch("CombineUpdateFields", FixedPoint(10), CombineUpdateFields) :: Nil
}

private val testRelation = LocalRelation('a.struct('a1.int))

test("combines two WithFields") {
test("combines two adjacent UpdateFields Expressions") {
val originalQuery = testRelation
.select(Alias(
WithFields(
WithFields(
UpdateFields(
UpdateFields(
'a,
Seq("b1"),
Seq(Literal(4))),
Seq("c1"),
Seq(Literal(5))), "out")())
WithField("b1", Literal(4)) :: Nil),
WithField("c1", Literal(5)) :: Nil), "out")())

val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = testRelation
.select(Alias(WithFields('a, Seq("b1", "c1"), Seq(Literal(4), Literal(5))), "out")())
.select(Alias(UpdateFields('a, WithField("b1", Literal(4)) :: WithField("c1", Literal(5)) ::
Nil), "out")())
.analyze

comparePlans(optimized, correctAnswer)
}

test("combines three WithFields") {
test("combines three adjacent UpdateFields Expressions") {
val originalQuery = testRelation
.select(Alias(
WithFields(
WithFields(
WithFields(
UpdateFields(
UpdateFields(
UpdateFields(
'a,
Seq("b1"),
Seq(Literal(4))),
Seq("c1"),
Seq(Literal(5))),
Seq("d1"),
Seq(Literal(6))), "out")())
WithField("b1", Literal(4)) :: Nil),
WithField("c1", Literal(5)) :: Nil),
WithField("d1", Literal(6)) :: Nil), "out")())

val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = testRelation
.select(Alias(WithFields('a, Seq("b1", "c1", "d1"), Seq(4, 5, 6).map(Literal(_))), "out")())
.select(Alias(UpdateFields('a, WithField("b1", Literal(4)) :: WithField("c1", Literal(5)) ::
WithField("d1", Literal(6)) :: Nil), "out")())
.analyze

comparePlans(optimized, correctAnswer)