diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTimeWindows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTimeWindows.scala
index df6b1c400bb7..be837d72c5aa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTimeWindows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTimeWindows.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Cast, CreateNamedStruct, Expression, GetStructField, IsNotNull, Literal, PreciseTimestampConversion, SessionWindow, Subtract, TimeWindow, WindowTime}
 import org.apache.spark.sql.catalyst.plans.logical.{Expand, Filter, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.TreePattern.TIME_WINDOW
+import org.apache.spark.sql.catalyst.trees.TreePattern.{SESSION_WINDOW, TIME_WINDOW}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.types.{CalendarIntervalType, DataType, LongType, Metadata, MetadataBuilder, StructType}
 import org.apache.spark.unsafe.types.CalendarInterval
@@ -187,7 +187,8 @@ object SessionWindowing extends Rule[LogicalPlan] {
    * This also adds a marker to the session column so that downstream can easily find the column
    * on session window.
    */
-  def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
+    _.containsPattern(SESSION_WINDOW), ruleId) {
     case p: LogicalPlan if p.children.size == 1 =>
       val child = p.children.head
       val sessionExpressions =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala
index 02273b0c4613..021f119e0a1a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import org.apache.spark.sql.catalyst.trees.TreePattern.{SESSION_WINDOW, TreePattern}
 import org.apache.spark.sql.catalyst.util.IntervalUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -91,6 +92,7 @@ case class SessionWindow(timeColumn: Expression, gapDuration: Expression) extend
   override def dataType: DataType = new StructType()
     .add(StructField("start", children.head.dataType))
     .add(StructField("end", children.head.dataType))
+  final override val nodePatterns: Seq[TreePattern] = Seq(SESSION_WINDOW)
 
   // This expression is replaced in the analyzer.
   override lazy val resolved = false
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
index 41aa68f0ec61..e824a0b533d0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
@@ -94,6 +94,7 @@ object RuleIdCollection {
       "org.apache.spark.sql.catalyst.analysis.ResolveOrderByAll" ::
       "org.apache.spark.sql.catalyst.analysis.ResolveTimeZone" ::
       "org.apache.spark.sql.catalyst.analysis.ResolveUnion" ::
+      "org.apache.spark.sql.catalyst.analysis.SessionWindowing" ::
       "org.apache.spark.sql.catalyst.analysis.SubstituteUnresolvedOrdinals" ::
       "org.apache.spark.sql.catalyst.analysis.TimeWindowing" ::
       "org.apache.spark.sql.catalyst.analysis.TypeCoercionBase$CombinedTypeCoercionRule" ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index fbe885bda064..e2e7fca27e05 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -82,6 +82,7 @@ object TreePattern extends Enumeration {
   val SCALAR_SUBQUERY: Value = Value
   val SCALAR_SUBQUERY_REFERENCE: Value = Value
   val SCALA_UDF: Value = Value
+  val SESSION_WINDOW: Value = Value
   val SORT: Value = Value
   val SUBQUERY_ALIAS: Value = Value
   val SUM: Value = Value
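
For context, a minimal sketch (not part of the patch) of a query that produces a SessionWindow expression and therefore sets the new SESSION_WINDOW tree pattern; with this change, the SessionWindowing rule only visits operators whose expressions contain that pattern rather than every node in the plan. It assumes Spark 3.2+ with spark-sql on the classpath; the object name, dataset, and column names are illustrative.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, count, session_window}

object SessionWindowExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("session-window-example")
      .getOrCreate()
    import spark.implicits._

    // Events from the same user within a 10-second gap fall into one session.
    val events = Seq(
      ("2024-01-01 00:00:01", "user1"),
      ("2024-01-01 00:00:05", "user1"),
      ("2024-01-01 00:01:00", "user1")
    ).toDF("time", "user")
      .withColumn("time", col("time").cast("timestamp"))

    // session_window(...) introduces a SessionWindow expression into the logical plan;
    // the analyzer rule patched above rewrites it into the session struct column.
    events
      .groupBy(session_window(col("time"), "10 seconds"), col("user"))
      .agg(count("*").as("events"))
      .show(truncate = false)

    spark.stop()
  }
}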