diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 6f201ba3a842f..acc9637abaed8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -240,6 +240,11 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
    */
   val extendedResolutionRules: Seq[Rule[LogicalPlan]] = Nil
 
+  /**
+   * Override to provide additional rules for the "Substitution" batch.
+   */
+  val extendedSubstitutionRules: Seq[Rule[LogicalPlan]] = Nil
+
   /**
    * Override to provide rules to do post-hoc resolution. Note that these rules will be executed
    * in an individual batch. This batch is to run right after the normal resolution batch and
@@ -259,11 +264,12 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       // However, when manipulating deeply nested schema, `UpdateFields` expression tree could be
       // very complex and make analysis impossible. Thus we need to optimize `UpdateFields` early
       // at the beginning of analysis.
-      OptimizeUpdateFields,
-      CTESubstitution,
-      WindowsSubstitution,
-      EliminateUnions,
-      SubstituteUnresolvedOrdinals),
+      OptimizeUpdateFields +:
+      CTESubstitution +:
+      WindowsSubstitution +:
+      EliminateUnions +:
+      SubstituteUnresolvedOrdinals +:
+      extendedSubstitutionRules: _*),
     Batch("Disable Hints", Once,
       new ResolveHints.DisableHints),
     Batch("Hints", fixedPoint,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
index b7c86ab7de6b4..6392db5d5bc46 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
@@ -284,6 +284,20 @@ class SparkSessionExtensions {
     preCBORules += builder
   }
 
+  private[this] val substitutionRules = mutable.Buffer.empty[RuleBuilder]
+
+  private[sql] def buildSubstitutionRules(session: SparkSession): Seq[Rule[LogicalPlan]] = {
+    substitutionRules.map(_.apply(session)).toSeq
+  }
+
+  /**
+   * Inject a substitution `Rule` builder into the [[SparkSession]]. The injected rules will be
+   * executed during the substitution batch.
+   */
+  def injectSubstitutionRule(builder: RuleBuilder): Unit = {
+    substitutionRules += builder
+  }
+
   private[this] val plannerStrategyBuilders = mutable.Buffer.empty[StrategyBuilder]
 
   private[sql] def buildPlannerStrategies(session: SparkSession): Seq[Strategy] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 630e1202f6d36..cc78dd10f62b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -184,6 +184,9 @@ abstract class BaseSessionStateBuilder(
    * Note: this depends on the `conf` and `catalog` fields.
    */
   protected def analyzer: Analyzer = new Analyzer(catalogManager) {
+    override val extendedSubstitutionRules: Seq[Rule[LogicalPlan]] =
+      customSubstitutionRules
+
     override val extendedResolutionRules: Seq[Rule[LogicalPlan]] =
       new FindDataSourceTable(session) +:
       new ResolveSQLOnFile(session) +:
@@ -223,6 +226,16 @@ abstract class BaseSessionStateBuilder(
     extensions.buildResolutionRules(session)
   }
 
+  /**
+   * Custom substitution rules to add to the Analyzer. Prefer overriding this instead of creating
+   * your own Analyzer.
+   *
+   * Note that this may NOT depend on the `analyzer` function.
+   */
+  protected def customSubstitutionRules: Seq[Rule[LogicalPlan]] = {
+    extensions.buildSubstitutionRules(session)
+  }
+
   /**
    * Custom post resolution rules to add to the Analyzer. Prefer overriding this instead of
    * creating your own Analyzer.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index c1b5d2761f7b4..dc55680eccb2f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -77,6 +77,12 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt
     }
   }
 
+  test("SPARK-46050: inject substitution rule") {
+    withSession(Seq(_.injectSubstitutionRule(MyRule))) { session =>
+      assert(session.sessionState.analyzer.extendedSubstitutionRules.contains(MyRule(session)))
+    }
+  }
+
   test("inject post hoc resolution analyzer rule") {
     withSession(Seq(_.injectPostHocResolutionRule(MyRule))) { session =>
       assert(session.sessionState.analyzer.postHocResolutionRules.contains(MyRule(session)))
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
index e991665e2887c..ec17fbc0a3278 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
@@ -84,6 +84,9 @@ class HiveSessionStateBuilder(
    * A logical query plan `Analyzer` with rules specific to Hive.
    */
   override protected def analyzer: Analyzer = new Analyzer(catalogManager) {
+    override val extendedSubstitutionRules: Seq[Rule[LogicalPlan]] =
+      customSubstitutionRules
+
     override val extendedResolutionRules: Seq[Rule[LogicalPlan]] =
       new ResolveHiveSerdeTable(session) +:
       new FindDataSourceTable(session) +:
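
For context, a minimal sketch of how an end user could exercise the new injectSubstitutionRule API once this patch is applied. The rule name MySubstitutionRule and its no-op body are hypothetical, for illustration only; a real substitution rule would rewrite the plan (for example, splicing in a view definition) before resolution runs:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

// Hypothetical no-op substitution rule, standing in for a real plan rewrite.
case class MySubstitutionRule(session: SparkSession) extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan
}

val spark = SparkSession.builder()
  .master("local[1]")
  // The case-class companion conforms to RuleBuilder (SparkSession => Rule[LogicalPlan]),
  // mirroring how MyRule is passed in SparkSessionExtensionSuite above.
  .withExtensions(_.injectSubstitutionRule(MySubstitutionRule))
  .getOrCreate()

Because extendedSubstitutionRules is appended after the built-in rules, the injected rule runs in the same fixed-point "Substitution" batch as CTESubstitution and WindowsSubstitution, before the normal resolution batch.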