Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,11 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
*/
val extendedResolutionRules: Seq[Rule[LogicalPlan]] = Nil

/**
 * Override to provide additional rules for the "Substitution" batch.
 *
 * Defaults to `Nil`, i.e. no extra rules. Any rules supplied here are placed after the
 * built-in substitution rules (`OptimizeUpdateFields`, `CTESubstitution`,
 * `WindowsSubstitution`, `EliminateUnions`, `SubstituteUnresolvedOrdinals`) in that batch.
 */
val extendedSubstitutionRules: Seq[Rule[LogicalPlan]] = Nil

/**
* Override to provide rules to do post-hoc resolution. Note that these rules will be executed
* in an individual batch. This batch is to run right after the normal resolution batch and
Expand All @@ -259,11 +264,12 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
// However, when manipulating deeply nested schema, `UpdateFields` expression tree could be
// very complex and make analysis impossible. Thus we need to optimize `UpdateFields` early
// at the beginning of analysis.
OptimizeUpdateFields,
CTESubstitution,
WindowsSubstitution,
EliminateUnions,
SubstituteUnresolvedOrdinals),
OptimizeUpdateFields +:
CTESubstitution +:
WindowsSubstitution +:
EliminateUnions +:
SubstituteUnresolvedOrdinals +:
extendedSubstitutionRules: _*),
Batch("Disable Hints", Once,
new ResolveHints.DisableHints),
Batch("Hints", fixedPoint,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,20 @@ class SparkSessionExtensions {
preCBORules += builder
}

// Builders registered through `injectSubstitutionRule`; each one is applied to a session by
// `buildSubstitutionRules` to obtain the concrete rule.
private[this] val substitutionRules = mutable.Buffer.empty[RuleBuilder]

/**
 * Build the substitution `Rule`s by applying every registered builder to the given session.
 * The resulting rules keep the order in which their builders were injected.
 */
private[sql] def buildSubstitutionRules(session: SparkSession): Seq[Rule[LogicalPlan]] = {
  val builtRules = for (ruleBuilder <- substitutionRules) yield ruleBuilder(session)
  builtRules.toSeq
}

/**
 * Register a substitution `Rule` builder with these extensions. The builder is invoked with
 * the [[SparkSession]] when the substitution rules are built, and the rule it produces is
 * executed as part of the substitution batch.
 */
def injectSubstitutionRule(builder: RuleBuilder): Unit = {
  substitutionRules.append(builder)
}

private[this] val plannerStrategyBuilders = mutable.Buffer.empty[StrategyBuilder]

private[sql] def buildPlannerStrategies(session: SparkSession): Seq[Strategy] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,9 @@ abstract class BaseSessionStateBuilder(
* Note: this depends on the `conf` and `catalog` fields.
*/
protected def analyzer: Analyzer = new Analyzer(catalogManager) {
override val extendedSubstitutionRules: Seq[Rule[LogicalPlan]] =
customSubstitutionRules

override val extendedResolutionRules: Seq[Rule[LogicalPlan]] =
new FindDataSourceTable(session) +:
new ResolveSQLOnFile(session) +:
Expand Down Expand Up @@ -223,6 +226,16 @@ abstract class BaseSessionStateBuilder(
extensions.buildResolutionRules(session)
}

/**
 * Additional substitution rules handed to the Analyzer; by default these are the rules built
 * from the session extensions. Override this rather than constructing your own Analyzer.
 *
 * Note that this may NOT depend on the `analyzer` function.
 */
protected def customSubstitutionRules: Seq[Rule[LogicalPlan]] =
  extensions.buildSubstitutionRules(session)

/**
* Custom post resolution rules to add to the Analyzer. Prefer overriding this instead of
* creating your own Analyzer.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt
}
}

test("SPARK-46050: inject substitution rule") {
  withSession(Seq(_.injectSubstitutionRule(MyRule))) { session =>
    // The injected builder must surface as one of the analyzer's extended substitution rules.
    val injectedRules = session.sessionState.analyzer.extendedSubstitutionRules
    assert(injectedRules.contains(MyRule(session)))
  }
}

test("inject post hoc resolution analyzer rule") {
withSession(Seq(_.injectPostHocResolutionRule(MyRule))) { session =>
assert(session.sessionState.analyzer.postHocResolutionRules.contains(MyRule(session)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ class HiveSessionStateBuilder(
* A logical query plan `Analyzer` with rules specific to Hive.
*/
override protected def analyzer: Analyzer = new Analyzer(catalogManager) {
override val extendedSubstitutionRules: Seq[Rule[LogicalPlan]] =
customSubstitutionRules

override val extendedResolutionRules: Seq[Rule[LogicalPlan]] =
new ResolveHiveSerdeTable(session) +:
new FindDataSourceTable(session) +:
Expand Down