@@ -812,7 +812,11 @@ class Analyzer(

case i @ InsertIntoStatement(u: UnresolvedRelation, _, _, _, _) if i.query.resolved =>
lookupV2Relation(u.multipartIdentifier)
.map(v2Relation => i.copy(table = v2Relation))
.map {
EliminateSubqueryAliases(_) match {
case r: DataSourceV2Relation => i.copy(table = r)
}
}
.getOrElse(i)

case alter @ AlterTable(_, _, u: UnresolvedV2Relation, _) =>
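Note: because lookupV2Relation now returns the relation wrapped in a SubqueryAlias (see the CatalogV2Util change at the bottom of this diff), the insert path must strip the alias before copying the plan. A minimal self-contained model of that unwrap, using stand-in classes rather than Catalyst's (Alias for SubqueryAlias, Relation for DataSourceV2Relation):

sealed trait Plan
case class Relation(table: String) extends Plan
case class Alias(name: String, child: Plan) extends Plan

// EliminateSubqueryAliases does essentially this: strip every alias layer.
def eliminateAliases(plan: Plan): Plan = plan match {
  case Alias(_, child) => eliminateAliases(child)
  case other => other
}

val looked: Plan = Alias("t", Relation("cat.db.t")) // shape the new lookup returns
eliminateAliases(looked) match {
  case r: Relation => println(s"insert into ${r.table}") // i.copy(table = r)
  case _ => () // not produced by the lookup
}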
@@ -827,14 +831,10 @@ class Analyzer(
/**
 * Performs the lookup of DataSourceV2 Tables from the v2 catalog.
*/
private def lookupV2Relation(identifier: Seq[String]): Option[DataSourceV2Relation] =
private def lookupV2Relation(identifier: Seq[String]): Option[LogicalPlan] =
Review comment (Contributor): The return type can still be DataSourceV2Relation?

expandRelationName(identifier) match {
case NonSessionCatalogAndIdentifier(catalog, ident) =>
CatalogV2Util.loadTable(catalog, ident) match {
case Some(table) =>
Some(DataSourceV2Relation.create(table, Some(catalog), Some(ident)))
case None => None
}
CatalogV2Util.loadRelation(catalog, ident)
case _ => None
}
}
@@ -922,7 +922,7 @@ class Analyzer(
case v1Table: V1Table =>
v1SessionCatalog.getRelation(v1Table.v1Table)
case table =>
DataSourceV2Relation.create(table, Some(catalog), Some(ident))
CatalogV2Util.getRelation(catalog, ident, table)
}
val key = catalog.name +: ident.namespace :+ ident.name
Option(AnalysisContext.get.relationCache.getOrElseUpdate(key, loaded.orNull))
@@ -425,8 +425,8 @@ trait CheckAnalysis extends PredicateHelper {
case _ =>
}

case alter: AlterTable if alter.childrenResolved =>
val table = alter.table
case alter @ AlterTable(_, _, SubqueryAlias(_, table: NamedRelation), _)
if alter.childrenResolved =>
def findField(operation: String, fieldName: Array[String]): StructField = {
// include collections because structs nested in maps and arrays may be altered
val field = table.schema.findNestedField(fieldName, includeCollections = true)
@@ -236,8 +236,6 @@ case class AttributeReference(
val qualifier: Seq[String] = Seq.empty[String])
extends Attribute with Unevaluable {

// currently can only handle qualifier of length 2
require(qualifier.length <= 2)
/**
* Returns true iff the expression id is the same for both attributes.
*/
@@ -23,7 +23,6 @@ import com.google.common.collect.Maps

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{StructField, StructType}

/**
@@ -128,90 +127,65 @@ package object expressions {
m.mapValues(_.distinct).map(identity)
}

/** Map to use for direct case insensitive attribute lookups. */
@transient private lazy val direct: Map[String, Seq[Attribute]] = {
/** Attribute name to attributes */
@transient private val attrsMap: Map[String, Seq[Attribute]] = {
unique(attrs.groupBy(_.name.toLowerCase(Locale.ROOT)))
}

/** Map to use for qualified case insensitive attribute lookups with 2 part key */
@transient private lazy val qualified: Map[(String, String), Seq[Attribute]] = {
// key is 2 part: table/alias and name
val grouped = attrs.filter(_.qualifier.nonEmpty).groupBy {
a => (a.qualifier.last.toLowerCase(Locale.ROOT), a.name.toLowerCase(Locale.ROOT))
}
unique(grouped)
}

/** Map to use for qualified case insensitive attribute lookups with 3 part key */
@transient private val qualified3Part: Map[(String, String, String), Seq[Attribute]] = {
// key is 3 part: database name, table name and name
val grouped = attrs.filter(_.qualifier.length == 2).groupBy { a =>
(a.qualifier.head.toLowerCase(Locale.ROOT),
a.qualifier.last.toLowerCase(Locale.ROOT),
a.name.toLowerCase(Locale.ROOT))
}
unique(grouped)
}

/** Perform attribute resolution given a name and a resolver. */
def resolve(nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = {
// Collect matching attributes given a name and a lookup.
def collectMatches(name: String, candidates: Option[Seq[Attribute]]): Seq[Attribute] = {
candidates.toSeq.flatMap(_.collect {
case a if resolver(a.name, name) => a.withName(name)
})
// Returns true if the `short` qualifier matches a suffix of the `long`
// qualifier. For example, Seq("a", "b") matches the tail of Seq("a", "a", "b"),
// but does not match Seq("a", "b", "b").
def matchQualifier(short: Seq[String], long: Seq[String]): Boolean = {
(long.length >= short.length) &&
long.takeRight(short.length)
.zip(short)
.forall(x => resolver(x._1, x._2))
}

// Find matches for the given name assuming that the 1st two parts are qualifier
// (i.e. database name and table name) and the 3rd part is the actual column name.
//
// For example, consider an example where "db1" is the database name, "a" is the table name
// and "b" is the column name and "c" is the struct field name.
// If the name parts are db1.a.b.c, then Attribute will match
// Attribute(b, qualifier("db1", "a")) and List("c") will be the second element
var matches: (Seq[Attribute], Seq[String]) = nameParts match {
case dbPart +: tblPart +: name +: nestedFields =>
val key = (dbPart.toLowerCase(Locale.ROOT),
tblPart.toLowerCase(Locale.ROOT), name.toLowerCase(Locale.ROOT))
val attributes = collectMatches(name, qualified3Part.get(key)).filter {
a => (resolver(dbPart, a.qualifier.head) && resolver(tblPart, a.qualifier.last))
}
(attributes, nestedFields)
case _ =>
(Seq.empty, Seq.empty)
// Collect attributes that match the given name and qualifier.
// A match occurs if
// 1) the given name matches the attribute's name according to the resolver.
// 2) the given qualifier is a subset of the attribute's qualifier.
def collectMatches(
name: String,
qualifier: Seq[String],
candidates: Option[Seq[Attribute]]): Seq[Attribute] = {
candidates.toSeq.flatMap(_.collect {
case a if resolver(name, a.name) && matchQualifier(qualifier, a.qualifier) =>
a.withName(name)
})
}

// If there are no matches, then find matches for the given name assuming that
// the 1st part is a qualifier (i.e. table name, alias, or subquery alias) and the
// 2nd part is the actual name. This returns a tuple of
// matched attributes and a list of parts that are to be resolved.
//
// For example, consider an example where "a" is the table name, "b" is the column name,
// and "c" is the struct field name, i.e. "a.b.c". In this case, Attribute will be "a.b",
// and the second element will be List("c").
if (matches._1.isEmpty) {
matches = nameParts match {
case qualifier +: name +: nestedFields =>
val key = (qualifier.toLowerCase(Locale.ROOT), name.toLowerCase(Locale.ROOT))
val attributes = collectMatches(name, qualified.get(key)).filter { a =>
resolver(qualifier, a.qualifier.last)
}
(attributes, nestedFields)
case _ =>
(Seq.empty[Attribute], Seq.empty[String])
// Iterate each string in `nameParts` in a reverse order and try to match the attributes
// considering the current string as the attribute name. For example, if `nameParts` is
// Seq("a", "b", "c"), the match will be performed in the following order:
// 1) name = "c", qualifier = Seq("a", "b")
// 2) name = "b", qualifier = Seq("a")
// 3) name = "a", qualifier = Seq()
// Note that the match is performed in reverse order so that the longest
// possible qualifier is matched first. If a match is found, the remaining
// portion of `nameParts` is also returned as nested fields.
val matches = nameParts.zipWithIndex.reverseIterator.flatMap { case (name, index) =>
val matched = collectMatches(
name,
nameParts.take(index),
attrsMap.get(name.toLowerCase(Locale.ROOT)))
if (matched.nonEmpty) {
(matched, nameParts.takeRight(nameParts.length - index - 1)) :: Nil
} else {
Nil
}
}

// If none of the attributes match the `database.table.column` or
// `table.column` patterns, we try to resolve the name as a column.
val (candidates, nestedFields) = matches match {
case (Seq(), _) =>
val name = nameParts.head
val attributes = collectMatches(name, direct.get(name.toLowerCase(Locale.ROOT)))
(attributes, nameParts.tail)
case _ => matches
if (matches.isEmpty) {
return None
}

// Note that `matches` is an iterator, and only the first match will be used.
val (candidates, nestedFields) = matches.next

def name = UnresolvedAttribute(nameParts).name
candidates match {
case Seq(a) if nestedFields.nonEmpty =>
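The rewrite replaces the three fixed lookup maps (direct, two-part qualified, three-part qualified) with one name-keyed map plus suffix matching on qualifiers. Below is a self-contained sketch of the reverse-order matching, with SimpleAttr standing in for Catalyst's Attribute and a case-insensitive resolver; it shows db1.a.b.c resolving to attribute b qualified by Seq("db1", "a") with nested field c:

// Stand-in for Catalyst's Attribute; only name and qualifier matter here.
case class SimpleAttr(name: String, qualifier: Seq[String])

// Case-insensitive resolver, mirroring Spark's default.
val resolver: (String, String) => Boolean = _.equalsIgnoreCase(_)

// True if `short` matches the trailing elements of `long`, as in the patch.
def matchQualifier(short: Seq[String], long: Seq[String]): Boolean =
  long.length >= short.length &&
    long.takeRight(short.length).zip(short).forall { case (l, s) => resolver(l, s) }

// Try the last name part first so the longest possible qualifier wins.
def resolve(
    nameParts: Seq[String],
    attrs: Seq[SimpleAttr]): Option[(Seq[SimpleAttr], Seq[String])] = {
  val matches = nameParts.zipWithIndex.reverseIterator.flatMap { case (name, index) =>
    val matched = attrs.filter { a =>
      resolver(name, a.name) && matchQualifier(nameParts.take(index), a.qualifier)
    }
    // Whatever trails the matched name becomes the nested-field path.
    if (matched.nonEmpty) (matched, nameParts.drop(index + 1)) :: Nil else Nil
  }
  if (matches.hasNext) Some(matches.next()) else None
}

val attrs = Seq(SimpleAttr("b", Seq("db1", "a")), SimpleAttr("c", Nil))
println(resolve(Seq("db1", "a", "b", "c"), attrs))
// Some((List(SimpleAttr(b,List(db1, a))),List(c))) -- attribute b, nested field c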
@@ -49,19 +49,21 @@ sealed trait IdentifierWithDatabase {

/**
 * Encapsulates an identifier that is either an alias name or an identifier that has a table
* name and optionally a database name.
* name and a namespace.
* The SubqueryAlias node keeps track of the qualifier using the information in this structure
* @param identifier - Is an alias name or a table name
* @param database - Is a database name and is optional
* @param name - Is an alias name or a table name
 * @param namespace - Is a namespace, i.e. the qualifier parts that precede the name
*/
case class AliasIdentifier(identifier: String, database: Option[String])
extends IdentifierWithDatabase {
case class AliasIdentifier(name: String, namespace: Seq[String]) {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._

def this(identifier: String) = this(identifier, Seq())

def this(identifier: String) = this(identifier, None)
override def toString: String = (namespace :+ name).quoted
}

object AliasIdentifier {
def apply(identifier: String): AliasIdentifier = new AliasIdentifier(identifier)
def apply(name: String): AliasIdentifier = new AliasIdentifier(name)
}

/**
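A usage sketch of the reworked alias identifier, assuming the patched class is in scope; the rendering comes from CatalogV2Implicits.quoted, which backtick-quotes a part only when it needs escaping:

AliasIdentifier("t", Seq("cat", "db")).toString // "cat.db.t"
AliasIdentifier("v").toString                   // "v" (one-arg apply, empty namespace)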
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning}
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.types._
import org.apache.spark.util.random.RandomSampler

@@ -849,18 +850,18 @@ case class Tail(limitExpr: Expression, child: LogicalPlan) extends OrderPreservi
/**
* Aliased subquery.
*
* @param name the alias identifier for this subquery.
* @param identifier the alias identifier for this subquery.
* @param child the logical plan of this subquery.
*/
case class SubqueryAlias(
name: AliasIdentifier,
identifier: AliasIdentifier,
child: LogicalPlan)
extends OrderPreservingUnaryNode {

def alias: String = name.identifier
def alias: String = identifier.name

override def output: Seq[Attribute] = {
val qualifierList = name.database.map(Seq(_, alias)).getOrElse(Seq(alias))
val qualifierList = identifier.namespace :+ alias
child.output.map(_.withQualifier(qualifierList))
}
override def doCanonicalize(): LogicalPlan = child.canonicalized
@@ -877,7 +878,13 @@ object SubqueryAlias {
identifier: String,
database: String,
child: LogicalPlan): SubqueryAlias = {
SubqueryAlias(AliasIdentifier(identifier, Some(database)), child)
SubqueryAlias(AliasIdentifier(identifier, Seq(database)), child)
}

def apply(
identifier: Identifier,
child: LogicalPlan): SubqueryAlias = {
SubqueryAlias(AliasIdentifier(identifier.name, identifier.namespace), child)
}
}
/**
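With the namespace folded into the qualifier, an output attribute can now carry more than two qualifier parts, which is also why the require(qualifier.length <= 2) in AttributeReference is removed above. A small sketch of the new qualifier computation (values illustrative):

// qualifierList = identifier.namespace :+ alias
val namespace = Seq("cat", "db")
val alias = "t"
val qualifierList = namespace :+ alias // Seq("cat", "db", "t")
// Each child output attribute is re-qualified with this list, so a column `c`
// becomes addressable as c, t.c, db.t.c, or cat.db.t.c by the new resolve logic.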
@@ -389,7 +389,7 @@ case class DropTable(
case class AlterTable(
catalog: TableCatalog,
ident: Identifier,
table: NamedRelation,
table: LogicalPlan,
changes: Seq[TableChange]) extends Command {

override lazy val resolved: Boolean = table.resolved && {
@@ -27,7 +27,7 @@ import org.json4s.JsonAST._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.sql.catalyst.IdentifierWithDatabase
import org.apache.spark.sql.catalyst.{AliasIdentifier, IdentifierWithDatabase}
import org.apache.spark.sql.catalyst.ScalaReflection._
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, FunctionResource}
import org.apache.spark.sql.catalyst.errors._
@@ -780,6 +780,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
case exprId: ExprId => true
case field: StructField => true
case id: IdentifierWithDatabase => true
case alias: AliasIdentifier => true
case join: JoinType => true
case spec: BucketSpec => true
case catalog: CatalogTable => true
@@ -22,8 +22,9 @@ import java.util.Collections

import scala.collection.JavaConverters._

import org.apache.spark.sql.catalyst.analysis.{NamedRelation, NoSuchDatabaseException, NoSuchNamespaceException, NoSuchTableException, UnresolvedV2Relation}
import org.apache.spark.sql.catalyst.plans.logical.AlterTable
import org.apache.spark.sql.catalyst.AliasIdentifier
import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchNamespaceException, NoSuchTableException, UnresolvedV2Relation}
import org.apache.spark.sql.catalyst.plans.logical.{AlterTable, LogicalPlan, SubqueryAlias}
import org.apache.spark.sql.connector.catalog.TableChange._
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.types.{ArrayType, MapType, StructField, StructType}
@@ -285,8 +286,14 @@ private[sql] object CatalogV2Util {
case _: NoSuchNamespaceException => None
}

def loadRelation(catalog: CatalogPlugin, ident: Identifier): Option[NamedRelation] = {
loadTable(catalog, ident).map(DataSourceV2Relation.create(_, Some(catalog), Some(ident)))
def loadRelation(catalog: CatalogPlugin, ident: Identifier): Option[LogicalPlan] = {
loadTable(catalog, ident).map(getRelation(catalog, ident, _))
}

def getRelation(catalog: CatalogPlugin, ident: Identifier, table: Table): LogicalPlan = {
SubqueryAlias(
Identifier.of(catalog.name +: ident.namespace, ident.name),
DataSourceV2Relation.create(table, Some(catalog), Some(ident)))
}

def isSessionCatalog(catalog: CatalogPlugin): Boolean = {
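For reference, a hedged sketch of how getRelation derives the alias for the wrapping SubqueryAlias (the catalog name "cat" is illustrative); the result is a SubqueryAlias over a DataSourceV2Relation, which is exactly what EliminateSubqueryAliases strips in the Analyzer hunk at the top of this diff:

import org.apache.spark.sql.connector.catalog.Identifier

val ident = Identifier.of(Array("db"), "t")
// getRelation names the alias with catalog name + namespace + table name:
val aliasParts = ("cat" +: ident.namespace) :+ ident.name
println(aliasParts.mkString(".")) // cat.db.t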