Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,12 @@

package org.apache.spark.sql.catalyst

import java.util.Locale

import com.google.common.collect.Maps

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{StructField, StructType}

Expand Down Expand Up @@ -138,6 +142,88 @@ package object expressions {
def indexOf(exprId: ExprId): Int = {
Option(exprIdToOrdinal.get(exprId)).getOrElse(-1)
}

private def unique[T](m: Map[T, Seq[Attribute]]): Map[T, Seq[Attribute]] = {
m.mapValues(_.distinct).map(identity)
}

/** Map to use for direct case insensitive attribute lookups. */
@transient private lazy val direct: Map[String, Seq[Attribute]] = {
unique(attrs.groupBy(_.name.toLowerCase(Locale.ROOT)))
}

/** Map to use for qualified case insensitive attribute lookups. */
@transient private val qualified: Map[(String, String), Seq[Attribute]] = {
val grouped = attrs.filter(_.qualifier.isDefined).groupBy { a =>
(a.qualifier.get.toLowerCase(Locale.ROOT), a.name.toLowerCase(Locale.ROOT))
}
unique(grouped)
}

/** Perform attribute resolution given a name and a resolver. */
def resolve(nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = {
// Collect matching attributes given a name and a lookup.
def collectMatches(name: String, candidates: Option[Seq[Attribute]]): Seq[Attribute] = {
candidates.toSeq.flatMap(_.collect {
case a if resolver(a.name, name) => a.withName(name)
})
}

// Find matches for the given name assuming that the 1st part is a qualifier (i.e. table name,
// alias, or subquery alias) and the 2nd part is the actual name. This returns a tuple of
// matched attributes and a list of parts that are to be resolved.
//
// For example, consider an example where "a" is the table name, "b" is the column name,
// and "c" is the struct field name, i.e. "a.b.c". In this case, Attribute will be "a.b",
// and the second element will be List("c").
val matches = nameParts match {
case qualifier +: name +: nestedFields =>
val key = (qualifier.toLowerCase(Locale.ROOT), name.toLowerCase(Locale.ROOT))
val attributes = collectMatches(name, qualified.get(key)).filter { a =>
resolver(qualifier, a.qualifier.get)
}
(attributes, nestedFields)
case all =>
(Nil, all)
}

// If none of attributes match `table.column` pattern, we try to resolve it as a column.
val (candidates, nestedFields) = matches match {
case (Seq(), _) =>
val name = nameParts.head
val attributes = collectMatches(name, direct.get(name.toLowerCase(Locale.ROOT)))
(attributes, nameParts.tail)
case _ => matches
}

def name = UnresolvedAttribute(nameParts).name
candidates match {
case Seq(a) if nestedFields.nonEmpty =>
// One match, but we also need to extract the requested nested field.
// The foldLeft adds ExtractValues for every remaining parts of the identifier,
// and aliased it with the last part of the name.
// For example, consider "a.b.c", where "a" is resolved to an existing attribute.
// Then this will add ExtractValue("c", ExtractValue("b", a)), and alias the final
// expression as "c".
val fieldExprs = nestedFields.foldLeft(a: Expression) { (e, name) =>
ExtractValue(e, Literal(name), resolver)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ExtractValue has the same perf problem, but this can be fixed in a follow up

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there an issue for the follow up?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@heuermh I have not filed the issue for this. Do you want to work on this?

}
Some(Alias(fieldExprs, nestedFields.last)())

case Seq(a) =>
// One match, no nested fields, use it.
Some(a)

case Seq() =>
// No matches.
None

case ambiguousReferences =>
// More than one match.
val referenceNames = ambiguousReferences.map(_.qualifiedName).mkString(", ")
throw new AnalysisException(s"Reference '$name' is ambiguous, could be: $referenceNames.")
}
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ abstract class LogicalPlan
}
}

private[this] lazy val childAttributes = AttributeSeq(children.flatMap(_.output))

private[this] lazy val outputAttributes = AttributeSeq(output)

/**
* Optionally resolves the given strings to a [[NamedExpression]] using the input from all child
* nodes of this LogicalPlan. The attribute is expressed as
Expand All @@ -94,7 +98,7 @@ abstract class LogicalPlan
def resolveChildren(
nameParts: Seq[String],
resolver: Resolver): Option[NamedExpression] =
resolve(nameParts, children.flatMap(_.output), resolver)
childAttributes.resolve(nameParts, resolver)

/**
* Optionally resolves the given strings to a [[NamedExpression]] based on the output of this
Expand All @@ -104,7 +108,7 @@ abstract class LogicalPlan
def resolve(
nameParts: Seq[String],
resolver: Resolver): Option[NamedExpression] =
resolve(nameParts, output, resolver)
outputAttributes.resolve(nameParts, resolver)

/**
* Given an attribute name, split it to name parts by dot, but
Expand All @@ -114,105 +118,7 @@ abstract class LogicalPlan
def resolveQuoted(
name: String,
resolver: Resolver): Option[NamedExpression] = {
resolve(UnresolvedAttribute.parseAttributeName(name), output, resolver)
}

/**
* Resolve the given `name` string against the given attribute, returning either 0 or 1 match.
*
* This assumes `name` has multiple parts, where the 1st part is a qualifier
* (i.e. table name, alias, or subquery alias).
* See the comment above `candidates` variable in resolve() for semantics the returned data.
*/
private def resolveAsTableColumn(
nameParts: Seq[String],
resolver: Resolver,
attribute: Attribute): Option[(Attribute, List[String])] = {
assert(nameParts.length > 1)
if (attribute.qualifier.exists(resolver(_, nameParts.head))) {
// At least one qualifier matches. See if remaining parts match.
val remainingParts = nameParts.tail
resolveAsColumn(remainingParts, resolver, attribute)
} else {
None
}
}

/**
* Resolve the given `name` string against the given attribute, returning either 0 or 1 match.
*
* Different from resolveAsTableColumn, this assumes `name` does NOT start with a qualifier.
* See the comment above `candidates` variable in resolve() for semantics the returned data.
*/
private def resolveAsColumn(
nameParts: Seq[String],
resolver: Resolver,
attribute: Attribute): Option[(Attribute, List[String])] = {
if (resolver(attribute.name, nameParts.head)) {
Option((attribute.withName(nameParts.head), nameParts.tail.toList))
} else {
None
}
}

/** Performs attribute resolution given a name and a sequence of possible attributes. */
protected def resolve(
nameParts: Seq[String],
input: Seq[Attribute],
resolver: Resolver): Option[NamedExpression] = {

// A sequence of possible candidate matches.
// Each candidate is a tuple. The first element is a resolved attribute, followed by a list
// of parts that are to be resolved.
// For example, consider an example where "a" is the table name, "b" is the column name,
// and "c" is the struct field name, i.e. "a.b.c". In this case, Attribute will be "a.b",
// and the second element will be List("c").
var candidates: Seq[(Attribute, List[String])] = {
// If the name has 2 or more parts, try to resolve it as `table.column` first.
if (nameParts.length > 1) {
input.flatMap { option =>
resolveAsTableColumn(nameParts, resolver, option)
}
} else {
Seq.empty
}
}

// If none of attributes match `table.column` pattern, we try to resolve it as a column.
if (candidates.isEmpty) {
candidates = input.flatMap { candidate =>
resolveAsColumn(nameParts, resolver, candidate)
}
}

def name = UnresolvedAttribute(nameParts).name

candidates.distinct match {
// One match, no nested fields, use it.
case Seq((a, Nil)) => Some(a)

// One match, but we also need to extract the requested nested field.
case Seq((a, nestedFields)) =>
// The foldLeft adds ExtractValues for every remaining parts of the identifier,
// and aliased it with the last part of the name.
// For example, consider "a.b.c", where "a" is resolved to an existing attribute.
// Then this will add ExtractValue("c", ExtractValue("b", a)), and alias the final
// expression as "c".
val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) =>
ExtractValue(expr, Literal(fieldName), resolver))
Some(Alias(fieldExprs, nestedFields.last)())

// No matches.
case Seq() =>
logTrace(s"Could not find $name in ${input.mkString(", ")}")
None

// More than one match.
case ambiguousReferences =>
val referenceNames = ambiguousReferences.map(_._1.qualifiedName).mkString(", ")
throw new AnalysisException(
s"Reference '$name' is ambiguous, could be: $referenceNames.")
}
outputAttributes.resolve(UnresolvedAttribute.parseAttributeName(name), resolver)
}

/**
Expand Down