Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3893,11 +3893,13 @@ object TimeWindowing extends Rule[LogicalPlan] {
val windowStart = lastStart - i * window.slideDuration
val windowEnd = windowStart + window.windowDuration

// We make sure value fields are nullable since the dataType of TimeWindow defines them
// as nullable.
CreateNamedStruct(
Literal(WINDOW_START) ::
PreciseTimestampConversion(windowStart, LongType, dataType) ::
PreciseTimestampConversion(windowStart, LongType, dataType).castNullable() ::
Literal(WINDOW_END) ::
PreciseTimestampConversion(windowEnd, LongType, dataType) ::
PreciseTimestampConversion(windowEnd, LongType, dataType).castNullable() ::
Nil)
}

Expand Down Expand Up @@ -4012,11 +4014,15 @@ object SessionWindowing extends Rule[LogicalPlan] {
val sessionEnd = PreciseTimestampConversion(session.timeColumn + gapDuration,
session.timeColumn.dataType, LongType)

// We make sure value fields are nullable since the dataType of SessionWindow defines them
// as nullable.
val literalSessionStruct = CreateNamedStruct(
Literal(SESSION_START) ::
PreciseTimestampConversion(sessionStart, LongType, session.timeColumn.dataType) ::
PreciseTimestampConversion(sessionStart, LongType, session.timeColumn.dataType)
.castNullable() ::
Literal(SESSION_END) ::
PreciseTimestampConversion(sessionEnd, LongType, session.timeColumn.dataType) ::
PreciseTimestampConversion(sessionEnd, LongType, session.timeColumn.dataType)
.castNullable() ::
Nil)

val sessionStruct = Alias(literalSessionStruct, SESSION_COL_NAME)(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,14 @@ package object dsl {
}
}

/**
 * Returns an expression that reports itself as nullable.
 *
 * When `expr` is already resolved and nullable it is returned unchanged; in every other
 * case it is wrapped in [[KnownNullable]].
 */
def castNullable(): Expression = {
  // `nullable` is only consulted once `resolved` holds (short-circuit evaluation).
  val alreadyNullable = expr.resolved && expr.nullable
  if (alreadyNullable) expr else KnownNullable(expr)
}

def asc: SortOrder = SortOrder(expr, Ascending)
def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast, Seq.empty)
def desc: SortOrder = SortOrder(expr, Descending)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,17 @@ trait TaggingExpression extends UnaryExpression {
override def eval(input: InternalRow): Any = child.eval(input)
}

/**
 * A [[TaggingExpression]] that always reports `nullable = true`, whatever its child reports.
 *
 * Code generation is handed over to the child as is.
 */
case class KnownNullable(child: Expression) extends TaggingExpression {
  // The whole purpose of this tag: force the nullability flag to true.
  override def nullable: Boolean = true

  // `ev` is not used here; the child's generated code is returned unchanged.
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    child.genCode(ctx)

  override protected def withNewChildInternal(newChild: Expression): KnownNullable = {
    copy(child = newChild)
  }
}

case class KnownNotNull(child: Expression) extends TaggingExpression {
override def nullable: Boolean = false

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import java.time.LocalDateTime

import org.scalatest.BeforeAndAfterEach

import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand}
import org.apache.spark.sql.functions._
Expand Down Expand Up @@ -406,4 +407,64 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession

checkAnswer(aggDF, Seq(Row("2016-03-27 19:39:25", "2016-03-27 19:39:40", 2)))
}

test("SPARK-38227: 'start' and 'end' fields should be nullable") {
  // The dataType of SessionWindow declares the fields of the window struct as nullable, so
  // the schema produced by the rule 'SessionWindowing' is expected to declare them likewise.
  val stringTimeDf = Seq(
    ("hello", "2016-03-27 09:00:05", 1),
    ("structured", "2016-03-27 09:00:32", 2)).toDF("id", "time", "value")
  val localDateTimeDf = Seq(
    ("world", LocalDateTime.parse("2016-03-27T09:00:05"), 1),
    ("spark", LocalDateTime.parse("2016-03-27T09:00:32"), 2)).toDF("id", "time", "value")

  // Maps the 'id' column to a gap duration string, used for the dynamic gap case below.
  val gapDurationUdf = spark.udf.register("gapDuration", (s: String) => s match {
    case "hello" => "1 second"
    // zero gap duration will be filtered out from aggregation
    case "structured" => "0 second"
    // negative gap duration will be filtered out from aggregation
    case "world" => "-10 seconds"
    case _ => "10 seconds"
  })

  // Asserts that the struct has a field named `fieldName` and that it is marked nullable.
  def assertFieldInWindowStruct(windowType: StructType, fieldName: String): Unit = {
    val maybeField = windowType.fields.find(_.name == fieldName)
    assert(maybeField.isDefined, s"'$fieldName' field should exist in window struct")
    assert(maybeField.exists(_.nullable), s"'$fieldName' field should be nullable")
  }

  // Looks up column `colName`, requires it to be a struct, and checks 'start' and 'end'.
  def validateWindowColumnInSchema(schema: StructType, colName: String): Unit = {
    schema.find(_.name == colName).map(_.dataType) match {
      case Some(windowType: StructType) =>
        Seq("start", "end").foreach(assertFieldInWindowStruct(windowType, _))

      case _ => fail("Failed to find suitable window column from DataFrame!")
    }
  }

  Seq(stringTimeDf, localDateTimeDf).foreach { df =>
    Seq(true, false).foreach { nullable =>
      val fieldsWithNullability = df.schema.fields.map(_.copy(nullable = nullable))
      val dfWithDesiredNullability =
        new DataFrame(df.queryExecution, RowEncoder(StructType(fieldsWithNullability)))

      // session window without dynamic gap
      val staticGapSchema = dfWithDesiredNullability
        .select(session_window($"time", "10 seconds").as("session"), $"value")
        .queryExecution.optimizedPlan.schema
      validateWindowColumnInSchema(staticGapSchema, "session")

      // session window with dynamic gap
      val dynamicGapSchema = dfWithDesiredNullability
        .select(session_window($"time", gapDurationUdf($"id")).as("session"), $"value")
        .queryExecution.optimizedPlan.schema
      validateWindowColumnInSchema(dynamicGapSchema, "session")
    }
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql

import java.time.LocalDateTime

import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, Filter}
import org.apache.spark.sql.functions._
Expand Down Expand Up @@ -527,4 +528,51 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSparkSession {
"when windowDuration is multiple of slideDuration")
}
}

test("SPARK-38227: 'start' and 'end' fields should be nullable") {
  // The dataType of TimeWindow declares the fields of the window struct as nullable, so the
  // schema produced by the rule 'TimeWindowing' is expected to declare them likewise.
  val stringTimeDf = Seq(
    ("2016-03-27 09:00:05", 1),
    ("2016-03-27 09:00:32", 2)).toDF("time", "value")
  val localDateTimeDf = Seq(
    (LocalDateTime.parse("2016-03-27T09:00:05"), 1),
    (LocalDateTime.parse("2016-03-27T09:00:32"), 2)).toDF("time", "value")

  // Asserts that the struct has a field named `fieldName` and that it is marked nullable.
  def assertFieldInWindowStruct(windowType: StructType, fieldName: String): Unit = {
    val maybeField = windowType.fields.find(_.name == fieldName)
    assert(maybeField.isDefined, s"'$fieldName' field should exist in window struct")
    assert(maybeField.exists(_.nullable), s"'$fieldName' field should be nullable")
  }

  // Looks up column `colName`, requires it to be a struct, and checks 'start' and 'end'.
  def validateWindowColumnInSchema(schema: StructType, colName: String): Unit = {
    schema.find(_.name == colName).map(_.dataType) match {
      case Some(windowType: StructType) =>
        Seq("start", "end").foreach(assertFieldInWindowStruct(windowType, _))

      case _ => fail("Failed to find suitable window column from DataFrame!")
    }
  }

  Seq(stringTimeDf, localDateTimeDf).foreach { df =>
    Seq(true, false).foreach { nullable =>
      val fieldsWithNullability = df.schema.fields.map(_.copy(nullable = nullable))
      val dfWithDesiredNullability =
        new DataFrame(df.queryExecution, RowEncoder(StructType(fieldsWithNullability)))

      // tumbling windows
      val tumblingSchema = dfWithDesiredNullability
        .select(window($"time", "10 seconds").as("window"), $"value")
        .queryExecution.optimizedPlan.schema
      validateWindowColumnInSchema(tumblingSchema, "window")

      // sliding windows
      val slidingSchema = dfWithDesiredNullability
        .select(window($"time", "10 seconds", "3 seconds").as("window"), $"value")
        .queryExecution.optimizedPlan.schema
      validateWindowColumnInSchema(slidingSchema, "window")
    }
  }
}
}