Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,16 @@
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.expressions.UpCast
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.logical.Project
import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.{View => V2View}
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.connector.catalog.CatalogPlugin
import org.apache.spark.sql.connector.catalog.Identifier
Expand All @@ -41,12 +46,12 @@ case class ResolveViews(spark: SparkSession) extends Rule[LogicalPlan] with Look

override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case u@UnresolvedRelation(nameParts, _, _)
if catalogManager.v1SessionCatalog.isTempView(Seq(nameParts.asIdentifier.name())) =>
if catalogManager.v1SessionCatalog.isTempView(nameParts) =>
u

case u@UnresolvedRelation(parts@CatalogAndIdentifier(catalog, ident), _, _) =>
loadView(catalog, ident)
.map(createViewRelation(parts.quoted, _))
.map(createViewRelation(parts, _))
.getOrElse(u)
}

Expand All @@ -60,23 +65,64 @@ case class ResolveViews(spark: SparkSession) extends Rule[LogicalPlan] with Look
case _ => None
}

private def createViewRelation(name: String, view: V2View): LogicalPlan = {
val child = parseViewText(name, view.query)
private def createViewRelation(nameParts: Seq[String], view: View): LogicalPlan = {
val parsed = parseViewText(nameParts.quoted, view.query)

// Apply any necessary rewrites to preserve correct resolution
val viewCatalogAndNamespace: Seq[String] = view.currentCatalog +: view.currentNamespace.toSeq
// Substitute CTEs within the view before qualifying table identifiers
SubqueryAlias(name, qualifyTableIdentifiers(CTESubstitution.apply(child), viewCatalogAndNamespace))
val rewritten = rewriteIdentifiers(parsed, viewCatalogAndNamespace);

// Apply the field aliases and column comments
// This logic differs from how Spark handles views in SessionCatalog.fromCatalogTable.
// This is more strict because it doesn't allow resolution by field name.
val aliases = view.schema.fields.zipWithIndex.map { case (expected, pos) =>
val attr = GetColumnByOrdinal(pos, expected.dataType)
Alias(UpCast(attr, expected.dataType), expected.name)(explicitMetadata = Some(expected.metadata))
}

SubqueryAlias(nameParts, Project(aliases, rewritten))
}

private def parseViewText(name: String, viewText: String): LogicalPlan = {
val origin = Origin(
objectType = Some("VIEW"),
objectName = Some(name)
)

try {
SparkSession.active.sessionState.sqlParser.parsePlan(viewText)
CurrentOrigin.withOrigin(origin) {
spark.sessionState.sqlParser.parseQuery(viewText)
}
} catch {
case _: ParseException =>
throw QueryCompilationErrors.invalidViewText(viewText, name)
}
}

/**
 * Rewrites a parsed view body so it resolves in the view's own context.
 *
 * CTE definitions are inlined first so their names are not mistaken for
 * relations, then unresolved function and table identifiers are qualified
 * with the view's current catalog and namespace.
 */
private def rewriteIdentifiers(
    plan: LogicalPlan,
    catalogAndNamespace: Seq[String]): LogicalPlan = {
  val withCtesInlined = CTESubstitution.apply(plan)
  val withQualifiedFunctions = qualifyFunctionIdentifiers(withCtesInlined, catalogAndNamespace)
  qualifyTableIdentifiers(withQualifiedFunctions, catalogAndNamespace)
}

/**
 * Qualifies unresolved function identifiers against the view's catalog and
 * namespace.
 *
 * A single-part name is left alone when it is a builtin function; otherwise
 * it is prefixed with the full catalog-and-namespace path. A multi-part name
 * whose head is not a registered catalog is treated as namespace-relative and
 * anchored under the view's catalog.
 */
private def qualifyFunctionIdentifiers(
    plan: LogicalPlan,
    catalogAndNamespace: Seq[String]): LogicalPlan = plan transformExpressions {
  // This case must match every single-part name (builtin or not) so that
  // builtin names never fall through to the multi-part rewrite below.
  case func @ UnresolvedFunction(Seq(name), _, _, _, _) =>
    if (isBuiltinFunction(name)) {
      func
    } else {
      func.copy(nameParts = catalogAndNamespace :+ name)
    }
  case func @ UnresolvedFunction(parts, _, _, _, _) if !isCatalog(parts.head) =>
    func.copy(nameParts = catalogAndNamespace.head +: parts)
}

/**
* Qualify table identifiers with default catalog and namespace if necessary.
*/
Expand All @@ -86,8 +132,15 @@ case class ResolveViews(spark: SparkSession) extends Rule[LogicalPlan] with Look
child transform {
case u@UnresolvedRelation(Seq(table), _, _) =>
u.copy(multipartIdentifier = catalogAndNamespace :+ table)
case u@UnresolvedRelation(parts, _, _)
if !SparkSession.active.sessionState.catalogManager.isCatalogRegistered(parts.head) =>
case u@UnresolvedRelation(parts, _, _) if !isCatalog(parts.head) =>
u.copy(multipartIdentifier = catalogAndNamespace.head +: parts)
}

/** Whether `name` is a catalog registered with this session's catalog manager. */
private def isCatalog(name: String): Boolean =
  spark.sessionState.catalogManager.isCatalogRegistered(name)

/** Whether `name` names a builtin function in the v1 session catalog. */
private def isBuiltinFunction(name: String): Boolean =
  spark.sessionState.catalogManager.v1SessionCatalog.isBuiltinFunction(FunctionIdentifier(name))
}
Loading