diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index c6429077b07fb..e55d504653fa2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -4017,7 +4017,8 @@ object SessionWindowing extends Rule[LogicalPlan] {
    * This also adds a marker to the session column so that downstream can easily find the column
    * on session window.
    */
-  def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
+    _.containsPattern(SESSION_WINDOW), ruleId) {
     case p: LogicalPlan if p.children.size == 1 =>
       val child = p.children.head
       val sessionExpressions =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala
index 77e8dfde87bbb..e1d9588c46328 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import org.apache.spark.sql.catalyst.trees.TreePattern.{SESSION_WINDOW, TreePattern}
 import org.apache.spark.sql.catalyst.util.IntervalUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -73,6 +74,7 @@ case class SessionWindow(timeColumn: Expression, gapDuration: Expression) extend
   override def dataType: DataType = new StructType()
     .add(StructField("start", timeColumn.dataType))
     .add(StructField("end", timeColumn.dataType))
+  final override val nodePatterns: Seq[TreePattern] = Seq(SESSION_WINDOW)
 
   // This expression is replaced in the analyzer.
   override lazy val resolved = false
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
index 5b710e6e137b2..12296926410db 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
@@ -88,6 +88,7 @@ object RuleIdCollection {
       "org.apache.spark.sql.catalyst.analysis.ResolveLambdaVariables" ::
       "org.apache.spark.sql.catalyst.analysis.ResolveTimeZone" ::
       "org.apache.spark.sql.catalyst.analysis.ResolveUnion" ::
+      "org.apache.spark.sql.catalyst.analysis.SessionWindowing" ::
       "org.apache.spark.sql.catalyst.analysis.SubstituteUnresolvedOrdinals" ::
       "org.apache.spark.sql.catalyst.analysis.TimeWindowing" ::
       "org.apache.spark.sql.catalyst.analysis.TypeCoercionBase$CombinedTypeCoercionRule" ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index 93273b5a2c7a7..25f6a66a219a4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -79,6 +79,7 @@ object TreePattern extends Enumeration {
  val SCALAR_SUBQUERY: Value = Value
  val SCALAR_SUBQUERY_REFERENCE: Value = Value
  val SCALA_UDF: Value = Value
+ val SESSION_WINDOW: Value = Value
  val SORT: Value = Value
  val SUBQUERY_ALIAS: Value = Value
  val SUM: Value = Value
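
Reviewer note (not part of the patch): the change swaps `resolveOperatorsUp` for `resolveOperatorsUpWithPruning`, so SessionWindowing now skips any subtree whose cached tree-pattern bits do not contain `SESSION_WINDOW` (populated via the new `nodePatterns` override), and registering the rule in RuleIdCollection gives it a `ruleId` so subtrees already marked ineffective for this rule can be skipped on re-runs. The sketch below is a self-contained, simplified model of the bitset-based pruning idea only; `Pattern`, `Node`, and `transformUpWithPruning` are toy stand-ins, not Catalyst's actual TreeNode API.

```scala
import scala.collection.immutable.BitSet

// Toy pattern ids, standing in for TreePattern values such as SESSION_WINDOW.
object Pattern {
  val SESSION_WINDOW = 0
  val OTHER = 1
}

// Minimal tree node that caches which patterns occur anywhere in its subtree,
// mirroring how tree-pattern bits let a rule skip whole subtrees.
case class Node(name: String, selfPatterns: BitSet, children: Seq[Node] = Nil) {
  // Union of this node's patterns and all descendants' patterns, computed once.
  lazy val treePatternBits: BitSet =
    children.foldLeft(selfPatterns)((acc, c) => acc | c.treePatternBits)

  def containsPattern(p: Int): Boolean = treePatternBits(p)

  // Bottom-up transform with pruning: if the condition fails, nothing in this
  // subtree can match, so the partial function is never applied below it.
  def transformUpWithPruning(cond: Node => Boolean)(rule: PartialFunction[Node, Node]): Node = {
    if (!cond(this)) {
      this // prune: skip traversal of this subtree entirely
    } else {
      val afterChildren = copy(children = children.map(_.transformUpWithPruning(cond)(rule)))
      rule.applyOrElse(afterChildren, identity[Node])
    }
  }
}

object PruningDemo extends App {
  val plan = Node("Project", BitSet(Pattern.OTHER), Seq(
    Node("Aggregate", BitSet(Pattern.SESSION_WINDOW), Seq(
      Node("Relation", BitSet.empty)))))

  // Analogous to resolveOperatorsUpWithPruning(_.containsPattern(SESSION_WINDOW), ruleId) { ... }:
  // "Relation" is pruned, "Aggregate" is rewritten, "Project" passes through unchanged.
  val rewritten = plan.transformUpWithPruning(_.containsPattern(Pattern.SESSION_WINDOW)) {
    case n if n.selfPatterns(Pattern.SESSION_WINDOW) => n.copy(name = n.name + " [session rewritten]")
  }

  println(rewritten)
}
```

The net effect of the patch is the same traversal discipline: plans that contain no `SessionWindow` expression are never walked by this rule, which is the point of adding the `SESSION_WINDOW` tree pattern and wiring the rule into RuleIdCollection.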