diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index c560062d5b09a..5cb5f21e9f710 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -3893,11 +3893,13 @@ object TimeWindowing extends Rule[LogicalPlan] {
             val windowStart = lastStart - i * window.slideDuration
             val windowEnd = windowStart + window.windowDuration
 
+            // We make sure value fields are nullable since the dataType of TimeWindow defines them
+            // as nullable.
             CreateNamedStruct(
               Literal(WINDOW_START) ::
-                PreciseTimestampConversion(windowStart, LongType, dataType) ::
+                PreciseTimestampConversion(windowStart, LongType, dataType).castNullable() ::
                 Literal(WINDOW_END) ::
-                PreciseTimestampConversion(windowEnd, LongType, dataType) ::
+                PreciseTimestampConversion(windowEnd, LongType, dataType).castNullable() ::
                 Nil)
           }
 
@@ -4012,11 +4014,15 @@ object SessionWindowing extends Rule[LogicalPlan] {
         val sessionEnd = PreciseTimestampConversion(session.timeColumn + gapDuration,
           session.timeColumn.dataType, LongType)
 
+        // We make sure value fields are nullable since the dataType of SessionWindow defines them
+        // as nullable.
         val literalSessionStruct = CreateNamedStruct(
           Literal(SESSION_START) ::
-            PreciseTimestampConversion(sessionStart, LongType, session.timeColumn.dataType) ::
+            PreciseTimestampConversion(sessionStart, LongType, session.timeColumn.dataType)
+              .castNullable() ::
             Literal(SESSION_END) ::
-            PreciseTimestampConversion(sessionEnd, LongType, session.timeColumn.dataType) ::
+            PreciseTimestampConversion(sessionEnd, LongType, session.timeColumn.dataType)
+              .castNullable() ::
             Nil)
 
         val sessionStruct = Alias(literalSessionStruct, SESSION_COL_NAME)(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index dda0d193e7483..0988bef30290c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -138,6 +138,14 @@ package object dsl {
       }
     }
 
+    def castNullable(): Expression = {
+      if (expr.resolved && expr.nullable) {
+        expr
+      } else {
+        KnownNullable(expr)
+      }
+    }
+
     def asc: SortOrder = SortOrder(expr, Ascending)
     def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast, Seq.empty)
     def desc: SortOrder = SortOrder(expr, Descending)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala
index 8feaf52ecb134..75d912633a0fc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala
@@ -30,6 +30,17 @@ trait TaggingExpression extends UnaryExpression {
   override def eval(input: InternalRow): Any = child.eval(input)
 }
 
+case class KnownNullable(child: Expression) extends TaggingExpression {
+  override def nullable: Boolean = true
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    child.genCode(ctx)
+  }
+
+  override protected def withNewChildInternal(newChild: Expression): KnownNullable =
+    copy(child = newChild)
+}
+
 case class KnownNotNull(child: Expression) extends TaggingExpression {
   override def nullable: Boolean = false
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala
index b3d212716dd9a..076b64cde8c66 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala
@@ -21,6 +21,7 @@ import java.time.LocalDateTime
 
 import org.scalatest.BeforeAndAfterEach
 
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand}
 import org.apache.spark.sql.functions._
@@ -406,4 +407,64 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession
 
     checkAnswer(aggDF, Seq(Row("2016-03-27 19:39:25", "2016-03-27 19:39:40", 2)))
   }
+
+  test("SPARK-38227: 'start' and 'end' fields should be nullable") {
+    // We expect the fields in window struct as nullable since the dataType of SessionWindow
+    // defines them as nullable. The rule 'SessionWindowing' should respect the dataType.
+    val df1 = Seq(
+      ("hello", "2016-03-27 09:00:05", 1),
+      ("structured", "2016-03-27 09:00:32", 2)).toDF("id", "time", "value")
+    val df2 = Seq(
+      ("world", LocalDateTime.parse("2016-03-27T09:00:05"), 1),
+      ("spark", LocalDateTime.parse("2016-03-27T09:00:32"), 2)).toDF("id", "time", "value")
+
+    val udf = spark.udf.register("gapDuration", (s: String) => {
+      if (s == "hello") {
+        "1 second"
+      } else if (s == "structured") {
+        // zero gap duration will be filtered out from aggregation
+        "0 second"
+      } else if (s == "world") {
+        // negative gap duration will be filtered out from aggregation
+        "-10 seconds"
+      } else {
+        "10 seconds"
+      }
+    })
+
+    def validateWindowColumnInSchema(schema: StructType, colName: String): Unit = {
+      schema.find(_.name == colName) match {
+        case Some(StructField(_, st: StructType, _, _)) =>
+          assertFieldInWindowStruct(st, "start")
+          assertFieldInWindowStruct(st, "end")
+
+        case _ => fail("Failed to find suitable window column from DataFrame!")
+      }
+    }
+
+    def assertFieldInWindowStruct(windowType: StructType, fieldName: String): Unit = {
+      val field = windowType.fields.find(_.name == fieldName)
+      assert(field.isDefined, s"'$fieldName' field should exist in window struct")
+      assert(field.get.nullable, s"'$fieldName' field should be nullable")
+    }
+
+    for {
+      df <- Seq(df1, df2)
+      nullable <- Seq(true, false)
+    } {
+      val dfWithDesiredNullability = new DataFrame(df.queryExecution, RowEncoder(
+        StructType(df.schema.fields.map(_.copy(nullable = nullable)))))
+      // session window without dynamic gap
+      val windowedProject = dfWithDesiredNullability
+        .select(session_window($"time", "10 seconds").as("session"), $"value")
+      val schema = windowedProject.queryExecution.optimizedPlan.schema
+      validateWindowColumnInSchema(schema, "session")
+
+      // session window with dynamic gap
+      val windowedProject2 = dfWithDesiredNullability
+        .select(session_window($"time", udf($"id")).as("session"), $"value")
+      val schema2 = windowedProject2.queryExecution.optimizedPlan.schema
+      validateWindowColumnInSchema(schema2, "session")
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
index e9a145cec01c2..bd39453f5120e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
 
 import java.time.LocalDateTime
 
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, Filter}
 import org.apache.spark.sql.functions._
@@ -527,4 +528,51 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSparkSession {
        "when windowDuration is multiple of slideDuration")
    }
  }
+
+  test("SPARK-38227: 'start' and 'end' fields should be nullable") {
+    // We expect the fields in window struct as nullable since the dataType of TimeWindow defines
+    // them as nullable. The rule 'TimeWindowing' should respect the dataType.
+    val df1 = Seq(
+      ("2016-03-27 09:00:05", 1),
+      ("2016-03-27 09:00:32", 2)).toDF("time", "value")
+    val df2 = Seq(
+      (LocalDateTime.parse("2016-03-27T09:00:05"), 1),
+      (LocalDateTime.parse("2016-03-27T09:00:32"), 2)).toDF("time", "value")
+
+    def validateWindowColumnInSchema(schema: StructType, colName: String): Unit = {
+      schema.find(_.name == colName) match {
+        case Some(StructField(_, st: StructType, _, _)) =>
+          assertFieldInWindowStruct(st, "start")
+          assertFieldInWindowStruct(st, "end")
+
+        case _ => fail("Failed to find suitable window column from DataFrame!")
+      }
+    }
+
+    def assertFieldInWindowStruct(windowType: StructType, fieldName: String): Unit = {
+      val field = windowType.fields.find(_.name == fieldName)
+      assert(field.isDefined, s"'$fieldName' field should exist in window struct")
+      assert(field.get.nullable, s"'$fieldName' field should be nullable")
+    }
+
+    for {
+      df <- Seq(df1, df2)
+      nullable <- Seq(true, false)
+    } {
+      val dfWithDesiredNullability = new DataFrame(df.queryExecution, RowEncoder(
+        StructType(df.schema.fields.map(_.copy(nullable = nullable)))))
+      // tumbling windows
+      val windowedProject = dfWithDesiredNullability
+        .select(window($"time", "10 seconds").as("window"), $"value")
+      val schema = windowedProject.queryExecution.optimizedPlan.schema
+      validateWindowColumnInSchema(schema, "window")
+
+      // sliding windows
+      val windowedProject2 = dfWithDesiredNullability
+        .select(window($"time", "10 seconds", "3 seconds").as("window"),
+          $"value")
+      val schema2 = windowedProject2.queryExecution.optimizedPlan.schema
+      validateWindowColumnInSchema(schema2, "window")
+    }
+  }
 }
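
Note (not part of the patch): a minimal sketch of how the fix is observable from user code, assuming a local SparkSession. With `TimeWindowing` now tagging the struct values via `castNullable()` (which wraps a resolved non-nullable expression in `KnownNullable`), the `start`/`end` fields of the window column report `nullable = true`, consistent with `TimeWindow.dataType`. The object and app names below are illustrative only:

```scala
// Illustrative sketch; prints the field nullability that this patch guarantees.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StructType

object WindowNullabilityCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]")
      .appName("WindowNullabilityCheck").getOrCreate()
    import spark.implicits._

    val df = Seq(
      ("2016-03-27 09:00:05", 1),
      ("2016-03-27 09:00:32", 2)).toDF("time", "value")

    // 'time' is a string here; the analyzer casts it to timestamp for window(),
    // as the existing tests in DataFrameTimeWindowingSuite also rely on.
    val windowed = df.select(window($"time", "10 seconds").as("window"), $"value")

    // With this change, both struct fields report nullable = true, matching
    // TimeWindow.dataType; previously they could leak through as non-nullable.
    val windowType = windowed.schema("window").dataType.asInstanceOf[StructType]
    windowType.fields.foreach(f => println(s"${f.name}: nullable = ${f.nullable}"))

    spark.stop()
  }
}
```

Design-wise, `KnownNullable` mirrors the existing `KnownNotNull`: a zero-cost tagging expression that only overrides `nullable` and delegates both `eval` and codegen to its child.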