From 96b5d50f3efb97c734f8c370e263a82d34f78d1b Mon Sep 17 00:00:00 2001 From: Alex Balikov <91913242+alex-balikov@users.noreply.github.com> Date: Mon, 24 Oct 2022 08:12:42 +0900 Subject: [PATCH 01/22] [SPARK-40821][SQL][CORE][PYTHON][SS] Introduce window_time function to extract event time from the window column ### What changes were proposed in this pull request? This PR introduces a window_time function to extract streaming event time from a window column produced by the window aggregating operators. This is one step in sequence of fixes required to add support for multiple stateful operators in Spark Structured Streaming as described in https://issues.apache.org/jira/browse/SPARK-40821 ### Why are the changes needed? The window_time function is a convenience function to compute correct event time for a window aggregate records. Such records produced by window aggregating operators have no explicit event time but rather a window column of type StructType { start: TimestampType, end: TimestampType } where start is inclusive and end is exclusive. The correct event time for such record is window.end - 1. The event time is necessary when chaining other stateful operators after the window aggregating operators. ### Does this PR introduce _any_ user-facing change? Yes: The PR introduces a new window_time SQL function for both Scala and Python APIs. ### How was this patch tested? Added new unit tests. Closes #38288 from alex-balikov/SPARK-40821-time-window. Authored-by: Alex Balikov <91913242+alex-balikov@users.noreply.github.com> Signed-off-by: Jungtaek Lim --- .../reference/pyspark.sql/functions.rst | 1 + python/pyspark/sql/functions.py | 46 +++ python/pyspark/sql/tests/test_functions.py | 16 + .../sql/catalyst/analysis/Analyzer.scala | 238 +----------- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../analysis/ResolveTimeWindows.scala | 346 ++++++++++++++++++ .../sql/catalyst/expressions/TimeWindow.scala | 2 + .../sql/catalyst/expressions/WindowTime.scala | 62 ++++ .../org/apache/spark/sql/functions.scala | 17 + .../sql-functions/sql-expression-schema.md | 1 + .../sql/DataFrameTimeWindowingSuite.scala | 62 ++++ 11 files changed, 555 insertions(+), 237 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTimeWindows.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WindowTime.scala diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst index 5a64845598ea5..37ddbaf1673d7 100644 --- a/python/docs/source/reference/pyspark.sql/functions.rst +++ b/python/docs/source/reference/pyspark.sql/functions.rst @@ -142,6 +142,7 @@ Datetime Functions window session_window timestamp_seconds + window_time Collection Functions diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f01379afd6ef9..ad1bc488e876d 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -4884,6 +4884,52 @@ def check_string_field(field, fieldName): # type: ignore[no-untyped-def] return _invoke_function("window", time_col, windowDuration) +def window_time( + windowColumn: "ColumnOrName", +) -> Column: + """Computes the event time from a window column. The column window values are produced + by window aggregating operators and are of type `STRUCT` + where start is inclusive and end is exclusive. 
The event time of records produced by window + aggregating operators can be computed as ``window_time(window)`` and are + ``window.end - lit(1).alias("microsecond")`` (as microsecond is the minimal supported event + time precision). The window column must be one produced by a window aggregating operator. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + windowColumn : :class:`~pyspark.sql.Column` + The window column of a window aggregate records. + + Returns + ------- + :class:`~pyspark.sql.Column` + the column for computed results. + + Examples + -------- + >>> import datetime + >>> df = spark.createDataFrame( + ... [(datetime.datetime(2016, 3, 11, 9, 0, 7), 1)], + ... ).toDF("date", "val") + + Group the data into 5 second time windows and aggregate as sum. + + >>> w = df.groupBy(window("date", "5 seconds")).agg(sum("val").alias("sum")) + + Extract the window event time using the window_time function. + + >>> w.select( + ... w.window.end.cast("string").alias("end"), + ... window_time(w.window).cast("string").alias("window_time"), + ... "sum" + ... ).collect() + [Row(end='2016-03-11 09:00:10', window_time='2016-03-11 09:00:09.999999', sum=1)] + """ + window_col = _to_java_column(windowColumn) + return _invoke_function("window_time", window_col) + + def session_window(timeColumn: "ColumnOrName", gapDuration: Union[Column, str]) -> Column: """ Generates session window given a timestamp specifying column. diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index 32cc77e11155b..55ef012b6d021 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -894,6 +894,22 @@ def test_window_functions_cumulative_sum(self): for r, ex in zip(rs, expected): self.assertEqual(tuple(r), ex[: len(r)]) + def test_window_time(self): + df = self.spark.createDataFrame( + [(datetime.datetime(2016, 3, 11, 9, 0, 7), 1)], ["date", "val"] + ) + from pyspark.sql import functions as F + + w = df.groupBy(F.window("date", "5 seconds")).agg(F.sum("val").alias("sum")) + r = w.select( + w.window.end.cast("string").alias("end"), + F.window_time(w.window).cast("string").alias("window_time"), + "sum", + ).collect() + self.assertEqual( + r[0], Row(end="2016-03-11 09:00:10", window_time="2016-03-11 09:00:09.999999", sum=1) + ) + def test_collect_functions(self): df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) from pyspark.sql import functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b185b38797bb9..fc12b6522b41b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -56,7 +56,6 @@ import org.apache.spark.sql.internal.connector.V1Function import org.apache.spark.sql.types._ import org.apache.spark.sql.types.DayTimeIntervalType.DAY import org.apache.spark.sql.util.{CaseInsensitiveStringMap, SchemaUtils} -import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.Utils import org.apache.spark.util.collection.{Utils => CUtils} @@ -313,6 +312,7 @@ class Analyzer(override val catalogManager: CatalogManager) ResolveAggregateFunctions :: TimeWindowing :: SessionWindowing :: + ResolveWindowTime :: ResolveDefaultColumns(v1SessionCatalog) :: ResolveInlineTables :: ResolveLambdaVariables :: @@ 
-3965,242 +3965,6 @@ object EliminateEventTimeWatermark extends Rule[LogicalPlan] { } } -/** - * Maps a time column to multiple time windows using the Expand operator. Since it's non-trivial to - * figure out how many windows a time column can map to, we over-estimate the number of windows and - * filter out the rows where the time column is not inside the time window. - */ -object TimeWindowing extends Rule[LogicalPlan] { - import org.apache.spark.sql.catalyst.dsl.expressions._ - - private final val WINDOW_COL_NAME = "window" - private final val WINDOW_START = "start" - private final val WINDOW_END = "end" - - /** - * Generates the logical plan for generating window ranges on a timestamp column. Without - * knowing what the timestamp value is, it's non-trivial to figure out deterministically how many - * window ranges a timestamp will map to given all possible combinations of a window duration, - * slide duration and start time (offset). Therefore, we express and over-estimate the number of - * windows there may be, and filter the valid windows. We use last Project operator to group - * the window columns into a struct so they can be accessed as `window.start` and `window.end`. - * - * The windows are calculated as below: - * maxNumOverlapping <- ceil(windowDuration / slideDuration) - * for (i <- 0 until maxNumOverlapping) - * lastStart <- timestamp - (timestamp - startTime + slideDuration) % slideDuration - * windowStart <- lastStart - i * slideDuration - * windowEnd <- windowStart + windowDuration - * return windowStart, windowEnd - * - * This behaves as follows for the given parameters for the time: 12:05. The valid windows are - * marked with a +, and invalid ones are marked with a x. The invalid ones are filtered using the - * Filter operator. - * window: 12m, slide: 5m, start: 0m :: window: 12m, slide: 5m, start: 2m - * 11:55 - 12:07 + 11:52 - 12:04 x - * 12:00 - 12:12 + 11:57 - 12:09 + - * 12:05 - 12:17 + 12:02 - 12:14 + - * - * @param plan The logical plan - * @return the logical plan that will generate the time windows using the Expand operator, with - * the Filter operator for correctness and Project for usability. - */ - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( - _.containsPattern(TIME_WINDOW), ruleId) { - case p: LogicalPlan if p.children.size == 1 => - val child = p.children.head - val windowExpressions = - p.expressions.flatMap(_.collect { case t: TimeWindow => t }).toSet - - val numWindowExpr = p.expressions.flatMap(_.collect { - case s: SessionWindow => s - case t: TimeWindow => t - }).toSet.size - - // Only support a single window expression for now - if (numWindowExpr == 1 && windowExpressions.nonEmpty && - windowExpressions.head.timeColumn.resolved && - windowExpressions.head.checkInputDataTypes().isSuccess) { - - val window = windowExpressions.head - - val metadata = window.timeColumn match { - case a: Attribute => a.metadata - case _ => Metadata.empty - } - - def getWindow(i: Int, dataType: DataType): Expression = { - val timestamp = PreciseTimestampConversion(window.timeColumn, dataType, LongType) - val lastStart = timestamp - (timestamp - window.startTime - + window.slideDuration) % window.slideDuration - val windowStart = lastStart - i * window.slideDuration - val windowEnd = windowStart + window.windowDuration - - // We make sure value fields are nullable since the dataType of TimeWindow defines them - // as nullable. 
- CreateNamedStruct( - Literal(WINDOW_START) :: - PreciseTimestampConversion(windowStart, LongType, dataType).castNullable() :: - Literal(WINDOW_END) :: - PreciseTimestampConversion(windowEnd, LongType, dataType).castNullable() :: - Nil) - } - - val windowAttr = AttributeReference( - WINDOW_COL_NAME, window.dataType, metadata = metadata)() - - if (window.windowDuration == window.slideDuration) { - val windowStruct = Alias(getWindow(0, window.timeColumn.dataType), WINDOW_COL_NAME)( - exprId = windowAttr.exprId, explicitMetadata = Some(metadata)) - - val replacedPlan = p transformExpressions { - case t: TimeWindow => windowAttr - } - - // For backwards compatibility we add a filter to filter out nulls - val filterExpr = IsNotNull(window.timeColumn) - - replacedPlan.withNewChildren( - Project(windowStruct +: child.output, - Filter(filterExpr, child)) :: Nil) - } else { - val overlappingWindows = - math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt - val windows = - Seq.tabulate(overlappingWindows)(i => - getWindow(i, window.timeColumn.dataType)) - - val projections = windows.map(_ +: child.output) - - // When the condition windowDuration % slideDuration = 0 is fulfilled, - // the estimation of the number of windows becomes exact one, - // which means all produced windows are valid. - val filterExpr = - if (window.windowDuration % window.slideDuration == 0) { - IsNotNull(window.timeColumn) - } else { - window.timeColumn >= windowAttr.getField(WINDOW_START) && - window.timeColumn < windowAttr.getField(WINDOW_END) - } - - val substitutedPlan = Filter(filterExpr, - Expand(projections, windowAttr +: child.output, child)) - - val renamedPlan = p transformExpressions { - case t: TimeWindow => windowAttr - } - - renamedPlan.withNewChildren(substitutedPlan :: Nil) - } - } else if (numWindowExpr > 1) { - throw QueryCompilationErrors.multiTimeWindowExpressionsNotSupportedError(p) - } else { - p // Return unchanged. Analyzer will throw exception later - } - } -} - -/** Maps a time column to a session window. */ -object SessionWindowing extends Rule[LogicalPlan] { - import org.apache.spark.sql.catalyst.dsl.expressions._ - - private final val SESSION_COL_NAME = "session_window" - private final val SESSION_START = "start" - private final val SESSION_END = "end" - - /** - * Generates the logical plan for generating session window on a timestamp column. - * Each session window is initially defined as [timestamp, timestamp + gap). - * - * This also adds a marker to the session column so that downstream can easily find the column - * on session window. 
- */ - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { - case p: LogicalPlan if p.children.size == 1 => - val child = p.children.head - val sessionExpressions = - p.expressions.flatMap(_.collect { case s: SessionWindow => s }).toSet - - val numWindowExpr = p.expressions.flatMap(_.collect { - case s: SessionWindow => s - case t: TimeWindow => t - }).toSet.size - - // Only support a single session expression for now - if (numWindowExpr == 1 && sessionExpressions.nonEmpty && - sessionExpressions.head.timeColumn.resolved && - sessionExpressions.head.checkInputDataTypes().isSuccess) { - - val session = sessionExpressions.head - - val metadata = session.timeColumn match { - case a: Attribute => a.metadata - case _ => Metadata.empty - } - - val newMetadata = new MetadataBuilder() - .withMetadata(metadata) - .putBoolean(SessionWindow.marker, true) - .build() - - val sessionAttr = AttributeReference( - SESSION_COL_NAME, session.dataType, metadata = newMetadata)() - - val sessionStart = - PreciseTimestampConversion(session.timeColumn, session.timeColumn.dataType, LongType) - val gapDuration = session.gapDuration match { - case expr if Cast.canCast(expr.dataType, CalendarIntervalType) => - Cast(expr, CalendarIntervalType) - case other => - throw QueryCompilationErrors.sessionWindowGapDurationDataTypeError(other.dataType) - } - val sessionEnd = PreciseTimestampConversion(session.timeColumn + gapDuration, - session.timeColumn.dataType, LongType) - - // We make sure value fields are nullable since the dataType of SessionWindow defines them - // as nullable. - val literalSessionStruct = CreateNamedStruct( - Literal(SESSION_START) :: - PreciseTimestampConversion(sessionStart, LongType, session.timeColumn.dataType) - .castNullable() :: - Literal(SESSION_END) :: - PreciseTimestampConversion(sessionEnd, LongType, session.timeColumn.dataType) - .castNullable() :: - Nil) - - val sessionStruct = Alias(literalSessionStruct, SESSION_COL_NAME)( - exprId = sessionAttr.exprId, explicitMetadata = Some(newMetadata)) - - val replacedPlan = p transformExpressions { - case s: SessionWindow => sessionAttr - } - - val filterByTimeRange = session.gapDuration match { - case Literal(interval: CalendarInterval, CalendarIntervalType) => - interval == null || interval.months + interval.days + interval.microseconds <= 0 - case _ => true - } - - // As same as tumbling window, we add a filter to filter out nulls. - // And we also filter out events with negative or zero or invalid gap duration. - val filterExpr = if (filterByTimeRange) { - IsNotNull(session.timeColumn) && - (sessionAttr.getField(SESSION_END) > sessionAttr.getField(SESSION_START)) - } else { - IsNotNull(session.timeColumn) - } - - replacedPlan.withNewChildren( - Filter(filterExpr, - Project(sessionStruct +: child.output, child)) :: Nil) - } else if (numWindowExpr > 1) { - throw QueryCompilationErrors.multiTimeWindowExpressionsNotSupportedError(p) - } else { - p // Return unchanged. Analyzer will throw exception later - } - } -} - /** * Resolve expressions if they contains [[NamePlaceholder]]s. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index ef8ce3f48d5a5..f5e494e909671 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -639,6 +639,7 @@ object FunctionRegistry { expression[Year]("year"), expression[TimeWindow]("window"), expression[SessionWindow]("session_window"), + expression[WindowTime]("window_time"), expression[MakeDate]("make_date"), expression[MakeTimestamp]("make_timestamp"), // We keep the 2 expression builders below to have different function docs. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTimeWindows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTimeWindows.scala new file mode 100644 index 0000000000000..fd5da3ff13d88 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTimeWindows.scala @@ -0,0 +1,346 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Cast, CreateNamedStruct, Expression, GetStructField, IsNotNull, Literal, PreciseTimestampConversion, SessionWindow, Subtract, TimeWindow, WindowTime} +import org.apache.spark.sql.catalyst.plans.logical.{Expand, Filter, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.TIME_WINDOW +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.types.{CalendarIntervalType, DataType, LongType, Metadata, MetadataBuilder, StructType} +import org.apache.spark.unsafe.types.CalendarInterval + +/** + * Maps a time column to multiple time windows using the Expand operator. Since it's non-trivial to + * figure out how many windows a time column can map to, we over-estimate the number of windows and + * filter out the rows where the time column is not inside the time window. + */ +object TimeWindowing extends Rule[LogicalPlan] { + import org.apache.spark.sql.catalyst.dsl.expressions._ + + private final val WINDOW_COL_NAME = "window" + private final val WINDOW_START = "start" + private final val WINDOW_END = "end" + + /** + * Generates the logical plan for generating window ranges on a timestamp column. 
Without + * knowing what the timestamp value is, it's non-trivial to figure out deterministically how many + * window ranges a timestamp will map to given all possible combinations of a window duration, + * slide duration and start time (offset). Therefore, we express and over-estimate the number of + * windows there may be, and filter the valid windows. We use last Project operator to group + * the window columns into a struct so they can be accessed as `window.start` and `window.end`. + * + * The windows are calculated as below: + * maxNumOverlapping <- ceil(windowDuration / slideDuration) + * for (i <- 0 until maxNumOverlapping) + * lastStart <- timestamp - (timestamp - startTime + slideDuration) % slideDuration + * windowStart <- lastStart - i * slideDuration + * windowEnd <- windowStart + windowDuration + * return windowStart, windowEnd + * + * This behaves as follows for the given parameters for the time: 12:05. The valid windows are + * marked with a +, and invalid ones are marked with a x. The invalid ones are filtered using the + * Filter operator. + * window: 12m, slide: 5m, start: 0m :: window: 12m, slide: 5m, start: 2m + * 11:55 - 12:07 + 11:52 - 12:04 x + * 12:00 - 12:12 + 11:57 - 12:09 + + * 12:05 - 12:17 + 12:02 - 12:14 + + * + * @param plan The logical plan + * @return the logical plan that will generate the time windows using the Expand operator, with + * the Filter operator for correctness and Project for usability. + */ + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( + _.containsPattern(TIME_WINDOW), ruleId) { + case p: LogicalPlan if p.children.size == 1 => + val child = p.children.head + val windowExpressions = + p.expressions.flatMap(_.collect { case t: TimeWindow => t }).toSet + + val numWindowExpr = p.expressions.flatMap(_.collect { + case s: SessionWindow => s + case t: TimeWindow => t + }).toSet.size + + // Only support a single window expression for now + if (numWindowExpr == 1 && windowExpressions.nonEmpty && + windowExpressions.head.timeColumn.resolved && + windowExpressions.head.checkInputDataTypes().isSuccess) { + + val window = windowExpressions.head + + if (StructType.acceptsType(window.timeColumn.dataType)) { + return p.transformExpressions { + case t: TimeWindow => t.copy(timeColumn = WindowTime(window.timeColumn)) + } + } + + val metadata = window.timeColumn match { + case a: Attribute => a.metadata + case _ => Metadata.empty + } + + val newMetadata = new MetadataBuilder() + .withMetadata(metadata) + .putBoolean(TimeWindow.marker, true) + .build() + + def getWindow(i: Int, dataType: DataType): Expression = { + val timestamp = PreciseTimestampConversion(window.timeColumn, dataType, LongType) + val lastStart = timestamp - (timestamp - window.startTime + + window.slideDuration) % window.slideDuration + val windowStart = lastStart - i * window.slideDuration + val windowEnd = windowStart + window.windowDuration + + // We make sure value fields are nullable since the dataType of TimeWindow defines them + // as nullable. 
+ CreateNamedStruct( + Literal(WINDOW_START) :: + PreciseTimestampConversion(windowStart, LongType, dataType).castNullable() :: + Literal(WINDOW_END) :: + PreciseTimestampConversion(windowEnd, LongType, dataType).castNullable() :: + Nil) + } + + val windowAttr = AttributeReference( + WINDOW_COL_NAME, window.dataType, metadata = newMetadata)() + + if (window.windowDuration == window.slideDuration) { + val windowStruct = Alias(getWindow(0, window.timeColumn.dataType), WINDOW_COL_NAME)( + exprId = windowAttr.exprId, explicitMetadata = Some(newMetadata)) + + val replacedPlan = p transformExpressions { + case t: TimeWindow => windowAttr + } + + // For backwards compatibility we add a filter to filter out nulls + val filterExpr = IsNotNull(window.timeColumn) + + replacedPlan.withNewChildren( + Project(windowStruct +: child.output, + Filter(filterExpr, child)) :: Nil) + } else { + val overlappingWindows = + math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt + val windows = + Seq.tabulate(overlappingWindows)(i => + getWindow(i, window.timeColumn.dataType)) + + val projections = windows.map(_ +: child.output) + + // When the condition windowDuration % slideDuration = 0 is fulfilled, + // the estimation of the number of windows becomes exact one, + // which means all produced windows are valid. + val filterExpr = + if (window.windowDuration % window.slideDuration == 0) { + IsNotNull(window.timeColumn) + } else { + window.timeColumn >= windowAttr.getField(WINDOW_START) && + window.timeColumn < windowAttr.getField(WINDOW_END) + } + + val substitutedPlan = Filter(filterExpr, + Expand(projections, windowAttr +: child.output, child)) + + val renamedPlan = p transformExpressions { + case t: TimeWindow => windowAttr + } + + renamedPlan.withNewChildren(substitutedPlan :: Nil) + } + } else if (numWindowExpr > 1) { + throw QueryCompilationErrors.multiTimeWindowExpressionsNotSupportedError(p) + } else { + p // Return unchanged. Analyzer will throw exception later + } + } +} + +/** Maps a time column to a session window. */ +object SessionWindowing extends Rule[LogicalPlan] { + import org.apache.spark.sql.catalyst.dsl.expressions._ + + private final val SESSION_COL_NAME = "session_window" + private final val SESSION_START = "start" + private final val SESSION_END = "end" + + /** + * Generates the logical plan for generating session window on a timestamp column. + * Each session window is initially defined as [timestamp, timestamp + gap). + * + * This also adds a marker to the session column so that downstream can easily find the column + * on session window. 
+ */ + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { + case p: LogicalPlan if p.children.size == 1 => + val child = p.children.head + val sessionExpressions = + p.expressions.flatMap(_.collect { case s: SessionWindow => s }).toSet + + val numWindowExpr = p.expressions.flatMap(_.collect { + case s: SessionWindow => s + case t: TimeWindow => t + }).toSet.size + + // Only support a single session expression for now + if (numWindowExpr == 1 && sessionExpressions.nonEmpty && + sessionExpressions.head.timeColumn.resolved && + sessionExpressions.head.checkInputDataTypes().isSuccess) { + + val session = sessionExpressions.head + + if (StructType.acceptsType(session.timeColumn.dataType)) { + return p transformExpressions { + case t: SessionWindow => t.copy(timeColumn = WindowTime(session.timeColumn)) + } + } + + val metadata = session.timeColumn match { + case a: Attribute => a.metadata + case _ => Metadata.empty + } + + val newMetadata = new MetadataBuilder() + .withMetadata(metadata) + .putBoolean(SessionWindow.marker, true) + .build() + + val sessionAttr = AttributeReference( + SESSION_COL_NAME, session.dataType, metadata = newMetadata)() + + val sessionStart = + PreciseTimestampConversion(session.timeColumn, session.timeColumn.dataType, LongType) + val gapDuration = session.gapDuration match { + case expr if Cast.canCast(expr.dataType, CalendarIntervalType) => + Cast(expr, CalendarIntervalType) + case other => + throw QueryCompilationErrors.sessionWindowGapDurationDataTypeError(other.dataType) + } + val sessionEnd = PreciseTimestampConversion(session.timeColumn + gapDuration, + session.timeColumn.dataType, LongType) + + // We make sure value fields are nullable since the dataType of SessionWindow defines them + // as nullable. + val literalSessionStruct = CreateNamedStruct( + Literal(SESSION_START) :: + PreciseTimestampConversion(sessionStart, LongType, session.timeColumn.dataType) + .castNullable() :: + Literal(SESSION_END) :: + PreciseTimestampConversion(sessionEnd, LongType, session.timeColumn.dataType) + .castNullable() :: + Nil) + + val sessionStruct = Alias(literalSessionStruct, SESSION_COL_NAME)( + exprId = sessionAttr.exprId, explicitMetadata = Some(newMetadata)) + + val replacedPlan = p transformExpressions { + case s: SessionWindow => sessionAttr + } + + val filterByTimeRange = session.gapDuration match { + case Literal(interval: CalendarInterval, CalendarIntervalType) => + interval == null || interval.months + interval.days + interval.microseconds <= 0 + case _ => true + } + + // As same as tumbling window, we add a filter to filter out nulls. + // And we also filter out events with negative or zero or invalid gap duration. + val filterExpr = if (filterByTimeRange) { + IsNotNull(session.timeColumn) && + (sessionAttr.getField(SESSION_END) > sessionAttr.getField(SESSION_START)) + } else { + IsNotNull(session.timeColumn) + } + + replacedPlan.withNewChildren( + Filter(filterExpr, + Project(sessionStruct +: child.output, child)) :: Nil) + } else if (numWindowExpr > 1) { + throw QueryCompilationErrors.multiTimeWindowExpressionsNotSupportedError(p) + } else { + p // Return unchanged. Analyzer will throw exception later + } + } +} + +/** + * Resolves the window_time expression which extracts the correct window time from the + * window column generated as the output of the window aggregating operators. The + * window column is of type struct { start: TimestampType, end: TimestampType }. + * The correct representative event time of a window is ``window.end - 1``. 
+ * */ +object ResolveWindowTime extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { + case p: LogicalPlan if p.children.size == 1 => + val child = p.children.head + val windowTimeExpressions = + p.expressions.flatMap(_.collect { case w: WindowTime => w }).toSet + + if (windowTimeExpressions.size == 1 && + windowTimeExpressions.head.windowColumn.resolved && + windowTimeExpressions.head.checkInputDataTypes().isSuccess) { + + val windowTime = windowTimeExpressions.head + + val metadata = windowTime.windowColumn match { + case a: Attribute => a.metadata + case _ => Metadata.empty + } + + if (!metadata.contains(TimeWindow.marker) && + !metadata.contains(SessionWindow.marker)) { + // FIXME: error framework? + throw new AnalysisException( + "The input is not a correct window column: $windowTime", plan = Some(p)) + } + + val newMetadata = new MetadataBuilder() + .withMetadata(metadata) + .remove(TimeWindow.marker) + .remove(SessionWindow.marker) + .build() + + val attr = AttributeReference( + "window_time", windowTime.dataType, metadata = newMetadata)() + + // NOTE: "window.end" is "exclusive" upper bound of window, so if we use this value as + // it is, it is going to be bound to the different window even if we apply the same window + // spec. Decrease 1 microsecond from window.end to let the window_time be bound to the + // correct window range. + val subtractExpr = + PreciseTimestampConversion( + Subtract(PreciseTimestampConversion( + GetStructField(windowTime.windowColumn, 1), + windowTime.dataType, LongType), Literal(1L)), + LongType, + windowTime.dataType) + + val newColumn = Alias(subtractExpr, "window_time")( + exprId = attr.exprId, explicitMetadata = Some(newMetadata)) + + val replacedPlan = p transformExpressions { + case w: WindowTime => attr + } + + replacedPlan.withNewChildren(Project(newColumn +: child.output, child) :: Nil) + } else { + p // Return unchanged. Analyzer will throw exception later + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index d7deca2f7b765..53c79d1fd54bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -136,6 +136,8 @@ case class TimeWindow( } object TimeWindow { + val marker = "spark.timeWindow" + /** * Parses the interval string for a valid time duration. CalendarInterval expects interval * strings to start with the string `interval`. For usability, we prepend `interval` to the string diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WindowTime.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WindowTime.scala new file mode 100644 index 0000000000000..effc1506d741a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WindowTime.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.types._ + +// scalastyle:off line.size.limit line.contains.tab +@ExpressionDescription( + usage = """ + _FUNC_(window_column) - Extract the time value from time/session window column which can be used for event time value of window. + The extracted time is (window.end - 1) which reflects the fact that the the aggregating + windows have exclusive upper bound - [start, end) + See 'Window Operations on Event Time' in Structured Streaming guide doc for detailed explanation and examples. + """, + arguments = """ + Arguments: + * window_column - The column representing time/session window. + """, + examples = """ + Examples: + > SELECT a, window.start as start, window.end as end, _FUNC_(window), cnt FROM (SELECT a, window, count(*) as cnt FROM VALUES ('A1', '2021-01-01 00:00:00'), ('A1', '2021-01-01 00:04:30'), ('A1', '2021-01-01 00:06:00'), ('A2', '2021-01-01 00:01:00') AS tab(a, b) GROUP by a, window(b, '5 minutes') ORDER BY a, window.start); + A1 2021-01-01 00:00:00 2021-01-01 00:05:00 2021-01-01 00:04:59.999999 2 + A1 2021-01-01 00:05:00 2021-01-01 00:10:00 2021-01-01 00:09:59.999999 1 + A2 2021-01-01 00:00:00 2021-01-01 00:05:00 2021-01-01 00:04:59.999999 1 + """, + group = "datetime_funcs", + since = "3.3.0") +// scalastyle:on line.size.limit line.contains.tab +case class WindowTime(windowColumn: Expression) + extends UnaryExpression + with ImplicitCastInputTypes + with Unevaluable + with NonSQLExpression { + + override def child: Expression = windowColumn + override def inputTypes: Seq[AbstractDataType] = Seq(StructType) + + override def dataType: DataType = child.dataType.asInstanceOf[StructType].head.dataType + + override def prettyName: String = "window_time" + + // This expression is replaced in the analyzer. + override lazy val resolved = false + + override protected def withNewChildInternal(newChild: Expression): WindowTime = + copy(windowColumn = newChild) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 620e1c6072172..780bf925ad7e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3777,6 +3777,23 @@ object functions { window(timeColumn, windowDuration, windowDuration, "0 second") } + /** + * Extracts the event time from the window column. + * + * The window column is of StructType { start: Timestamp, end: Timestamp } where start is + * inclusive and end is exclusive. Since event time can support microsecond precision, + * window_time(window) = window.end - 1 microsecond. + * + * @param windowColumn The window column (typically produced by window aggregation) of type + * StructType { start: Timestamp, end: Timestamp } + * + * @group datetime_funcs + * @since 3.3.0 + */ + def window_time(windowColumn: Column): Column = withExpr { + WindowTime(windowColumn.expr) + } + /** * Generates session window given a timestamp specifying column. 
* diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 4ce4f1225ce64..6f111b777a6d0 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -345,6 +345,7 @@ | org.apache.spark.sql.catalyst.expressions.WeekDay | weekday | SELECT weekday('2009-07-30') | struct | | org.apache.spark.sql.catalyst.expressions.WeekOfYear | weekofyear | SELECT weekofyear('2008-02-20') | struct | | org.apache.spark.sql.catalyst.expressions.WidthBucket | width_bucket | SELECT width_bucket(5.3, 0.2, 10.6, 5) | struct | +| org.apache.spark.sql.catalyst.expressions.WindowTime | window_time | SELECT a, window.start as start, window.end as end, window_time(window), cnt FROM (SELECT a, window, count(*) as cnt FROM VALUES ('A1', '2021-01-01 00:00:00'), ('A1', '2021-01-01 00:04:30'), ('A1', '2021-01-01 00:06:00'), ('A2', '2021-01-01 00:01:00') AS tab(a, b) GROUP by a, window(b, '5 minutes') ORDER BY a, window.start) | struct | | org.apache.spark.sql.catalyst.expressions.XxHash64 | xxhash64 | SELECT xxhash64('Spark', array(123), 2) | struct | | org.apache.spark.sql.catalyst.expressions.Year | year | SELECT year('2016-07-30') | struct | | org.apache.spark.sql.catalyst.expressions.ZipWith | zip_with | SELECT zip_with(array(1, 2, 3), array('a', 'b', 'c'), (x, y) -> (y, x)) | struct>> | diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala index bd39453f5120e..f775eb9ecfc0d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala @@ -575,4 +575,66 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSparkSession { validateWindowColumnInSchema(schema2, "window") } } + + test("window_time function on raw window column") { + val df = Seq( + ("2016-03-27 19:38:18"), ("2016-03-27 19:39:25") + ).toDF("time") + + checkAnswer( + df.select(window($"time", "10 seconds").as("window")) + .select( + $"window.end".cast("string"), + window_time($"window").cast("string") + ), + Seq( + Row("2016-03-27 19:38:20", "2016-03-27 19:38:19.999999"), + Row("2016-03-27 19:39:30", "2016-03-27 19:39:29.999999") + ) + ) + } + + test("2 window_time functions on raw window column") { + val df = Seq( + ("2016-03-27 19:38:18"), ("2016-03-27 19:39:25") + ).toDF("time") + + val e = intercept[AnalysisException] { + df + .withColumn("time2", expr("time - INTERVAL 5 minutes")) + .select( + window($"time", "10 seconds").as("window1"), + window($"time2", "10 seconds").as("window2") + ) + .select( + $"window1.end".cast("string"), + window_time($"window1").cast("string"), + $"window2.end".cast("string"), + window_time($"window2").cast("string") + ) + } + assert(e.getMessage.contains( + "Multiple time/session window expressions would result in a cartesian product of rows, " + + "therefore they are currently not supported")) + } + + test("window_time function on agg output") { + val df = Seq( + ("2016-03-27 19:38:19", 1), ("2016-03-27 19:39:25", 2) + ).toDF("time", "value") + checkAnswer( + df.groupBy(window($"time", "10 seconds")) + .agg(count("*").as("counts")) + .orderBy($"window.start".asc) + .select( + $"window.start".cast("string"), + $"window.end".cast("string"), + window_time($"window").cast("string"), + $"counts"), + 
Seq( + Row("2016-03-27 19:38:10", "2016-03-27 19:38:20", "2016-03-27 19:38:19.999999", 1), + Row("2016-03-27 19:39:20", "2016-03-27 19:39:30", "2016-03-27 19:39:29.999999", 1) + ) + ) + } } From 02a2242a45062755bf7e20805958d5bdf1f5ed74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B8rn?= Date: Mon, 24 Oct 2022 10:32:18 +0900 Subject: [PATCH 02/22] [SPARK-40884][BUILD] Upgrade fabric8io - `kubernetes-client` to 6.2.0 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? Upgrade fabric8io - kubernetes-client from 6.1.1 to 6.2.0 ### Why are the changes needed? [Release notes](https://github.com/fabric8io/kubernetes-client/releases/tag/v6.2.0) [Snakeyaml version should be updated to mitigate CVE-2022-28857](https://github.com/fabric8io/kubernetes-client/issues/4383) ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA Closes #38348 from bjornjorgensen/kubernetes-client6.2.0. Authored-by: Bjørn Signed-off-by: Hyukjin Kwon --- dev/deps/spark-deps-hadoop-2-hive-2.3 | 48 +++++++++++++-------------- dev/deps/spark-deps-hadoop-3-hive-2.3 | 48 +++++++++++++-------------- pom.xml | 2 +- 3 files changed, 49 insertions(+), 49 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2-hive-2.3 b/dev/deps/spark-deps-hadoop-2-hive-2.3 index 6756dd5831277..2c1eab56f3331 100644 --- a/dev/deps/spark-deps-hadoop-2-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-2-hive-2.3 @@ -160,30 +160,30 @@ jsr305/3.0.0//jsr305-3.0.0.jar jta/1.1//jta-1.1.jar jul-to-slf4j/2.0.3//jul-to-slf4j-2.0.3.jar kryo-shaded/4.0.2//kryo-shaded-4.0.2.jar -kubernetes-client-api/6.1.1//kubernetes-client-api-6.1.1.jar -kubernetes-client/6.1.1//kubernetes-client-6.1.1.jar -kubernetes-httpclient-okhttp/6.1.1//kubernetes-httpclient-okhttp-6.1.1.jar -kubernetes-model-admissionregistration/6.1.1//kubernetes-model-admissionregistration-6.1.1.jar -kubernetes-model-apiextensions/6.1.1//kubernetes-model-apiextensions-6.1.1.jar -kubernetes-model-apps/6.1.1//kubernetes-model-apps-6.1.1.jar -kubernetes-model-autoscaling/6.1.1//kubernetes-model-autoscaling-6.1.1.jar -kubernetes-model-batch/6.1.1//kubernetes-model-batch-6.1.1.jar -kubernetes-model-certificates/6.1.1//kubernetes-model-certificates-6.1.1.jar -kubernetes-model-common/6.1.1//kubernetes-model-common-6.1.1.jar -kubernetes-model-coordination/6.1.1//kubernetes-model-coordination-6.1.1.jar -kubernetes-model-core/6.1.1//kubernetes-model-core-6.1.1.jar -kubernetes-model-discovery/6.1.1//kubernetes-model-discovery-6.1.1.jar -kubernetes-model-events/6.1.1//kubernetes-model-events-6.1.1.jar -kubernetes-model-extensions/6.1.1//kubernetes-model-extensions-6.1.1.jar -kubernetes-model-flowcontrol/6.1.1//kubernetes-model-flowcontrol-6.1.1.jar -kubernetes-model-gatewayapi/6.1.1//kubernetes-model-gatewayapi-6.1.1.jar -kubernetes-model-metrics/6.1.1//kubernetes-model-metrics-6.1.1.jar -kubernetes-model-networking/6.1.1//kubernetes-model-networking-6.1.1.jar -kubernetes-model-node/6.1.1//kubernetes-model-node-6.1.1.jar -kubernetes-model-policy/6.1.1//kubernetes-model-policy-6.1.1.jar -kubernetes-model-rbac/6.1.1//kubernetes-model-rbac-6.1.1.jar -kubernetes-model-scheduling/6.1.1//kubernetes-model-scheduling-6.1.1.jar -kubernetes-model-storageclass/6.1.1//kubernetes-model-storageclass-6.1.1.jar +kubernetes-client-api/6.2.0//kubernetes-client-api-6.2.0.jar +kubernetes-client/6.2.0//kubernetes-client-6.2.0.jar +kubernetes-httpclient-okhttp/6.2.0//kubernetes-httpclient-okhttp-6.2.0.jar 
+kubernetes-model-admissionregistration/6.2.0//kubernetes-model-admissionregistration-6.2.0.jar +kubernetes-model-apiextensions/6.2.0//kubernetes-model-apiextensions-6.2.0.jar +kubernetes-model-apps/6.2.0//kubernetes-model-apps-6.2.0.jar +kubernetes-model-autoscaling/6.2.0//kubernetes-model-autoscaling-6.2.0.jar +kubernetes-model-batch/6.2.0//kubernetes-model-batch-6.2.0.jar +kubernetes-model-certificates/6.2.0//kubernetes-model-certificates-6.2.0.jar +kubernetes-model-common/6.2.0//kubernetes-model-common-6.2.0.jar +kubernetes-model-coordination/6.2.0//kubernetes-model-coordination-6.2.0.jar +kubernetes-model-core/6.2.0//kubernetes-model-core-6.2.0.jar +kubernetes-model-discovery/6.2.0//kubernetes-model-discovery-6.2.0.jar +kubernetes-model-events/6.2.0//kubernetes-model-events-6.2.0.jar +kubernetes-model-extensions/6.2.0//kubernetes-model-extensions-6.2.0.jar +kubernetes-model-flowcontrol/6.2.0//kubernetes-model-flowcontrol-6.2.0.jar +kubernetes-model-gatewayapi/6.2.0//kubernetes-model-gatewayapi-6.2.0.jar +kubernetes-model-metrics/6.2.0//kubernetes-model-metrics-6.2.0.jar +kubernetes-model-networking/6.2.0//kubernetes-model-networking-6.2.0.jar +kubernetes-model-node/6.2.0//kubernetes-model-node-6.2.0.jar +kubernetes-model-policy/6.2.0//kubernetes-model-policy-6.2.0.jar +kubernetes-model-rbac/6.2.0//kubernetes-model-rbac-6.2.0.jar +kubernetes-model-scheduling/6.2.0//kubernetes-model-scheduling-6.2.0.jar +kubernetes-model-storageclass/6.2.0//kubernetes-model-storageclass-6.2.0.jar lapack/3.0.2//lapack-3.0.2.jar leveldbjni-all/1.8//leveldbjni-all-1.8.jar libfb303/0.9.3//libfb303-0.9.3.jar diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index d29a10c1230b1..c7f4e02d4dcb1 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -144,30 +144,30 @@ jsr305/3.0.0//jsr305-3.0.0.jar jta/1.1//jta-1.1.jar jul-to-slf4j/2.0.3//jul-to-slf4j-2.0.3.jar kryo-shaded/4.0.2//kryo-shaded-4.0.2.jar -kubernetes-client-api/6.1.1//kubernetes-client-api-6.1.1.jar -kubernetes-client/6.1.1//kubernetes-client-6.1.1.jar -kubernetes-httpclient-okhttp/6.1.1//kubernetes-httpclient-okhttp-6.1.1.jar -kubernetes-model-admissionregistration/6.1.1//kubernetes-model-admissionregistration-6.1.1.jar -kubernetes-model-apiextensions/6.1.1//kubernetes-model-apiextensions-6.1.1.jar -kubernetes-model-apps/6.1.1//kubernetes-model-apps-6.1.1.jar -kubernetes-model-autoscaling/6.1.1//kubernetes-model-autoscaling-6.1.1.jar -kubernetes-model-batch/6.1.1//kubernetes-model-batch-6.1.1.jar -kubernetes-model-certificates/6.1.1//kubernetes-model-certificates-6.1.1.jar -kubernetes-model-common/6.1.1//kubernetes-model-common-6.1.1.jar -kubernetes-model-coordination/6.1.1//kubernetes-model-coordination-6.1.1.jar -kubernetes-model-core/6.1.1//kubernetes-model-core-6.1.1.jar -kubernetes-model-discovery/6.1.1//kubernetes-model-discovery-6.1.1.jar -kubernetes-model-events/6.1.1//kubernetes-model-events-6.1.1.jar -kubernetes-model-extensions/6.1.1//kubernetes-model-extensions-6.1.1.jar -kubernetes-model-flowcontrol/6.1.1//kubernetes-model-flowcontrol-6.1.1.jar -kubernetes-model-gatewayapi/6.1.1//kubernetes-model-gatewayapi-6.1.1.jar -kubernetes-model-metrics/6.1.1//kubernetes-model-metrics-6.1.1.jar -kubernetes-model-networking/6.1.1//kubernetes-model-networking-6.1.1.jar -kubernetes-model-node/6.1.1//kubernetes-model-node-6.1.1.jar -kubernetes-model-policy/6.1.1//kubernetes-model-policy-6.1.1.jar -kubernetes-model-rbac/6.1.1//kubernetes-model-rbac-6.1.1.jar 
-kubernetes-model-scheduling/6.1.1//kubernetes-model-scheduling-6.1.1.jar -kubernetes-model-storageclass/6.1.1//kubernetes-model-storageclass-6.1.1.jar +kubernetes-client-api/6.2.0//kubernetes-client-api-6.2.0.jar +kubernetes-client/6.2.0//kubernetes-client-6.2.0.jar +kubernetes-httpclient-okhttp/6.2.0//kubernetes-httpclient-okhttp-6.2.0.jar +kubernetes-model-admissionregistration/6.2.0//kubernetes-model-admissionregistration-6.2.0.jar +kubernetes-model-apiextensions/6.2.0//kubernetes-model-apiextensions-6.2.0.jar +kubernetes-model-apps/6.2.0//kubernetes-model-apps-6.2.0.jar +kubernetes-model-autoscaling/6.2.0//kubernetes-model-autoscaling-6.2.0.jar +kubernetes-model-batch/6.2.0//kubernetes-model-batch-6.2.0.jar +kubernetes-model-certificates/6.2.0//kubernetes-model-certificates-6.2.0.jar +kubernetes-model-common/6.2.0//kubernetes-model-common-6.2.0.jar +kubernetes-model-coordination/6.2.0//kubernetes-model-coordination-6.2.0.jar +kubernetes-model-core/6.2.0//kubernetes-model-core-6.2.0.jar +kubernetes-model-discovery/6.2.0//kubernetes-model-discovery-6.2.0.jar +kubernetes-model-events/6.2.0//kubernetes-model-events-6.2.0.jar +kubernetes-model-extensions/6.2.0//kubernetes-model-extensions-6.2.0.jar +kubernetes-model-flowcontrol/6.2.0//kubernetes-model-flowcontrol-6.2.0.jar +kubernetes-model-gatewayapi/6.2.0//kubernetes-model-gatewayapi-6.2.0.jar +kubernetes-model-metrics/6.2.0//kubernetes-model-metrics-6.2.0.jar +kubernetes-model-networking/6.2.0//kubernetes-model-networking-6.2.0.jar +kubernetes-model-node/6.2.0//kubernetes-model-node-6.2.0.jar +kubernetes-model-policy/6.2.0//kubernetes-model-policy-6.2.0.jar +kubernetes-model-rbac/6.2.0//kubernetes-model-rbac-6.2.0.jar +kubernetes-model-scheduling/6.2.0//kubernetes-model-scheduling-6.2.0.jar +kubernetes-model-storageclass/6.2.0//kubernetes-model-storageclass-6.2.0.jar lapack/3.0.2//lapack-3.0.2.jar leveldbjni-all/1.8//leveldbjni-all-1.8.jar libfb303/0.9.3//libfb303-0.9.3.jar diff --git a/pom.xml b/pom.xml index 78936392b85d9..c7efd8f9f61bf 100644 --- a/pom.xml +++ b/pom.xml @@ -219,7 +219,7 @@ 9.0.0 org.fusesource.leveldbjni - 6.1.1 + 6.2.0 ${java.home} From 5d3b1e6ed549ce9b6aa1f13ff60bce290c2b5160 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Mon, 24 Oct 2022 10:33:04 +0900 Subject: [PATCH 03/22] [SPARK-40877][SQL] Reimplement `crosstab` with dataframe operations ### What changes were proposed in this pull request? Reimplement `crosstab` with dataframe operations ### Why are the changes needed? 1, do not truncate the sql plan; 2, much more scalable; 3, existing implementation (added in v1.5.0) collect distinct `col1, col2` pairs to driver, while `pivot` (added in v2.4.0) only collect distinct `col2` which is much smaller; ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing UTs and manually check Closes #38340 from zhengruifeng/sql_stat_crosstab. 
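For illustration, the reimplementation boils down to `groupBy(col1).pivot(col2).count().na.fill(0)` over null-safe, string-cast keys, so the user-facing behaviour is unchanged. A minimal spark-shell sketch (the `key`/`value` columns and the data are made up for this example):

```scala
import spark.implicits._

val df = Seq(("a", 1), ("a", 2), ("b", 1)).toDF("key", "value")

// Internally this now becomes groupBy(key).pivot(value).count().na.fill(0L),
// so only the distinct values of `value` are collected to the driver.
df.stat.crosstab("key", "value").show()
// +---------+---+---+
// |key_value|  1|  2|
// +---------+---+---+
// |        a|  1|  1|
// |        b|  1|  0|
// +---------+---+---+
// (row order is not guaranteed)
```

Because `pivot` only collects the distinct values of `col2`, the driver no longer has to materialize up to 1e6 `(col1, col2)` pairs as the old implementation did.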
Authored-by: Ruifeng Zheng Signed-off-by: Hyukjin Kwon --- .../sql/execution/stat/StatFunctions.scala | 50 +++---------------- 1 file changed, 8 insertions(+), 42 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index c9d3b99990830..484be76b99156 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -30,7 +30,6 @@ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.util.collection.Utils object StatFunctions extends Logging { @@ -188,47 +187,14 @@ object StatFunctions extends Logging { /** Generate a table of frequencies for the elements of two columns. */ def crossTabulate(df: DataFrame, col1: String, col2: String): DataFrame = { - val tableName = s"${col1}_$col2" - val counts = df.groupBy(col1, col2).agg(count("*")).take(1e6.toInt) - if (counts.length == 1e6.toInt) { - logWarning("The maximum limit of 1e6 pairs have been collected, which may not be all of " + - "the pairs. Please try reducing the amount of distinct items in your columns.") - } - def cleanElement(element: Any): String = { - if (element == null) "null" else element.toString - } - // get the distinct sorted values of column 2, so that we can make them the column names - val distinctCol2: Map[Any, Int] = - Utils.toMapWithIndex(counts.map(e => cleanElement(e.get(1))).distinct.sorted) - val columnSize = distinctCol2.size - require(columnSize < 1e4, s"The number of distinct values for $col2, can't " + - s"exceed 1e4. Currently $columnSize") - val table = counts.groupBy(_.get(0)).map { case (col1Item, rows) => - val countsRow = new GenericInternalRow(columnSize + 1) - rows.foreach { (row: Row) => - // row.get(0) is column 1 - // row.get(1) is column 2 - // row.get(2) is the frequency - val columnIndex = distinctCol2(cleanElement(row.get(1))) - countsRow.setLong(columnIndex + 1, row.getLong(2)) - } - // the value of col1 is the first value, the rest are the counts - countsRow.update(0, UTF8String.fromString(cleanElement(col1Item))) - countsRow - }.toSeq - // Back ticks can't exist in DataFrame column names, therefore drop them. To be able to accept - // special keywords and `.`, wrap the column names in ``. - def cleanColumnName(name: String): String = { - name.replace("`", "") - } - // In the map, the column names (._1) are not ordered by the index (._2). This was the bug in - // SPARK-8681. We need to explicitly sort by the column index and assign the column names. 
- val headerNames = distinctCol2.toSeq.sortBy(_._2).map { r => - StructField(cleanColumnName(r._1.toString), LongType) - } - val schema = StructType(StructField(tableName, StringType) +: headerNames) - - Dataset.ofRows(df.sparkSession, LocalRelation(schema.toAttributes, table)).na.fill(0.0) + df.groupBy( + when(isnull(col(col1)), "null") + .otherwise(col(col1).cast("string")) + .as(s"${col1}_$col2") + ).pivot( + when(isnull(col(col2)), "null") + .otherwise(regexp_replace(col(col2).cast("string"), "`", "")) + ).count().na.fill(0L) } /** Calculate selected summary statistics for a dataset */ From 6a0713a141fa98d83029d8388508cbbc40fd554e Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Mon, 24 Oct 2022 10:58:13 +0900 Subject: [PATCH 04/22] [SPARK-40880][SQL] Reimplement `summary` with dataframe operations ### What changes were proposed in this pull request? Reimplement `summary` with dataframe operations ### Why are the changes needed? 1, do not truncate the sql plan any more; 2, enable sql optimization like column pruning: ``` scala> val df = spark.range(0, 3, 1, 10).withColumn("value", lit("str")) df: org.apache.spark.sql.DataFrame = [id: bigint, value: string] scala> df.summary("max", "50%").show +-------+---+-----+ |summary| id|value| +-------+---+-----+ | max| 2| str| | 50%| 1| null| +-------+---+-----+ scala> df.summary("max", "50%").select("id").show +---+ | id| +---+ | 2| | 1| +---+ scala> df.summary("max", "50%").select("id").queryExecution.optimizedPlan res4: org.apache.spark.sql.catalyst.plans.logical.LogicalPlan = Project [element_at(id#367, summary#376, None, false) AS id#371] +- Generate explode([max,50%]), false, [summary#376] +- Aggregate [map(max, cast(max(id#153L) as string), 50%, cast(percentile_approx(id#153L, [0.5], 10000, 0, 0)[0] as string)) AS id#367] +- Range (0, 3, step=1, splits=Some(10)) ``` ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? existing UTs and manually check Closes #38346 from zhengruifeng/sql_stat_summary. 
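For illustration, the rewrite builds one `map(statName -> value)` per column, explodes the statistic names, and looks each value up with `element_at`. A minimal spark-shell sketch of that pattern (the hard-coded `max`/`min` statistics and the single `id` column are illustrative only; `summary()` derives both per call):

```scala
import org.apache.spark.sql.functions._
import spark.implicits._

val df = spark.range(0, 3)   // single bigint column "id"

// One map per column: statistic name -> value cast to string.
val perColumn = df.select(map(
  lit("max"), max($"id").cast("string"),
  lit("min"), min($"id").cast("string")).as("id"))

// Explode the statistic names into a "summary" column and look up each value.
val summaryDf = perColumn
  .withColumn("summary", explode(lit(Array("max", "min"))))
  .select($"summary", element_at($"id", $"summary").as("id"))

summaryDf.show()
// +-------+---+
// |summary| id|
// +-------+---+
// |    max|  2|
// |    min|  0|
// +-------+---+
```

Keeping everything as ordinary dataframe operations is what allows optimizations such as the column pruning shown in the `select("id")` plan above.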
Authored-by: Ruifeng Zheng Signed-off-by: Hyukjin Kwon --- .../sql/execution/stat/StatFunctions.scala | 122 +++++++++--------- 1 file changed, 59 insertions(+), 63 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 484be76b99156..508d2c64d0923 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -21,11 +21,10 @@ import java.util.Locale import org.apache.spark.internal.Logging import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, EvalMode, Expression, GenericInternalRow, GetArrayItem, Literal} +import org.apache.spark.sql.catalyst.expressions.{Cast, ElementAt, EvalMode, GenericInternalRow} import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.catalyst.util.{GenericArrayData, QuantileSummaries} +import org.apache.spark.sql.catalyst.util.QuantileSummaries import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -199,9 +198,11 @@ object StatFunctions extends Logging { /** Calculate selected summary statistics for a dataset */ def summary(ds: Dataset[_], statistics: Seq[String]): DataFrame = { - - val defaultStatistics = Seq("count", "mean", "stddev", "min", "25%", "50%", "75%", "max") - val selectedStatistics = if (statistics.nonEmpty) statistics else defaultStatistics + val selectedStatistics = if (statistics.nonEmpty) { + statistics.toArray + } else { + Array("count", "mean", "stddev", "min", "25%", "50%", "75%", "max") + } val percentiles = selectedStatistics.filter(a => a.endsWith("%")).map { p => try { @@ -213,71 +214,66 @@ object StatFunctions extends Logging { } require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]") - def castAsDoubleIfNecessary(e: Expression): Expression = if (e.dataType == StringType) { - Cast(e, DoubleType, evalMode = EvalMode.TRY) - } else { - e - } - var percentileIndex = 0 - val statisticFns = selectedStatistics.map { stats => - if (stats.endsWith("%")) { - val index = percentileIndex - percentileIndex += 1 - (child: Expression) => - GetArrayItem( - new ApproximatePercentile(castAsDoubleIfNecessary(child), - Literal(new GenericArrayData(percentiles), ArrayType(DoubleType, false))) - .toAggregateExpression(), - Literal(index)) - } else { - stats.toLowerCase(Locale.ROOT) match { - case "count" => (child: Expression) => Count(child).toAggregateExpression() - case "count_distinct" => (child: Expression) => - Count(child).toAggregateExpression(isDistinct = true) - case "approx_count_distinct" => (child: Expression) => - HyperLogLogPlusPlus(child).toAggregateExpression() - case "mean" => (child: Expression) => - Average(castAsDoubleIfNecessary(child)).toAggregateExpression() - case "stddev" => (child: Expression) => - StddevSamp(castAsDoubleIfNecessary(child)).toAggregateExpression() - case "min" => (child: Expression) => Min(child).toAggregateExpression() - case "max" => (child: Expression) => Max(child).toAggregateExpression() - case _ => throw QueryExecutionErrors.statisticNotRecognizedError(stats) + var mapColumns = Seq.empty[Column] + var columnNames = 
Seq.empty[String] + + ds.schema.fields.foreach { field => + if (field.dataType.isInstanceOf[NumericType] || field.dataType.isInstanceOf[StringType]) { + val column = col(field.name) + var casted = column + if (field.dataType.isInstanceOf[StringType]) { + casted = new Column(Cast(column.expr, DoubleType, evalMode = EvalMode.TRY)) } - } - } - val selectedCols = ds.logicalPlan.output - .filter(a => a.dataType.isInstanceOf[NumericType] || a.dataType.isInstanceOf[StringType]) + val percentilesCol = if (percentiles.nonEmpty) { + percentile_approx(casted, lit(percentiles), + lit(ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY)) + } else null - val aggExprs = statisticFns.flatMap { func => - selectedCols.map(c => Column(Cast(func(c), StringType)).as(c.name)) - } + var aggColumns = Seq.empty[Column] + var percentileIndex = 0 + selectedStatistics.foreach { stats => + aggColumns :+= lit(stats) - // If there is no selected columns, we don't need to run this aggregate, so make it a lazy val. - lazy val aggResult = ds.select(aggExprs: _*).queryExecution.toRdd.collect().head + stats.toLowerCase(Locale.ROOT) match { + case "count" => aggColumns :+= count(column) - // We will have one row for each selected statistic in the result. - val result = Array.fill[InternalRow](selectedStatistics.length) { - // each row has the statistic name, and statistic values of each selected column. - new GenericInternalRow(selectedCols.length + 1) - } + case "count_distinct" => aggColumns :+= count_distinct(column) + + case "approx_count_distinct" => aggColumns :+= approx_count_distinct(column) - var rowIndex = 0 - while (rowIndex < result.length) { - val statsName = selectedStatistics(rowIndex) - result(rowIndex).update(0, UTF8String.fromString(statsName)) - for (colIndex <- selectedCols.indices) { - val statsValue = aggResult.getUTF8String(rowIndex * selectedCols.length + colIndex) - result(rowIndex).update(colIndex + 1, statsValue) + case "mean" => aggColumns :+= avg(casted) + + case "stddev" => aggColumns :+= stddev(casted) + + case "min" => aggColumns :+= min(column) + + case "max" => aggColumns :+= max(column) + + case percentile if percentile.endsWith("%") => + aggColumns :+= get(percentilesCol, lit(percentileIndex)) + percentileIndex += 1 + + case _ => throw QueryExecutionErrors.statisticNotRecognizedError(stats) + } + } + + // map { "count" -> "1024", "min" -> "1.0", ... } + mapColumns :+= map(aggColumns.map(_.cast(StringType)): _*).as(field.name) + columnNames :+= field.name } - rowIndex += 1 } - // All columns are string type - val output = AttributeReference("summary", StringType)() +: - selectedCols.map(c => AttributeReference(c.name, StringType)()) - - Dataset.ofRows(ds.sparkSession, LocalRelation(output, result)) + if (mapColumns.isEmpty) { + ds.sparkSession.createDataFrame(selectedStatistics.map(Tuple1.apply)) + .withColumnRenamed("_1", "summary") + } else { + val valueColumns = columnNames.map { columnName => + new Column(ElementAt(col(columnName).expr, col("summary").expr)).as(columnName) + } + ds.select(mapColumns: _*) + .withColumn("summary", explode(lit(selectedStatistics))) + .select(Array(col("summary")) ++ valueColumns: _*) + } } } From 79aae64380ff83570549cb8c4ed85ffb022fc8eb Mon Sep 17 00:00:00 2001 From: Jerry Peng Date: Mon, 24 Oct 2022 11:09:40 +0900 Subject: [PATCH 05/22] [SPARK-40849][SS] Async log purge MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? 
Purging old entries in both the offset log and commit log will be done asynchronously. For every micro-batch, older entries in both offset log and commit log are deleted. This is done so that the offset log and commit log do not continually grow. Please reference logic here https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala#L539 The time spent performing these log purges is grouped with the “walCommit” execution time in the StreamingProgressListener metrics. Around two thirds of the “walCommit” execution time is performing these purge operations thus making these operations asynchronous will also reduce latency. Also, we do not necessarily need to perform the purges every micro-batch. When these purges are executed asynchronously, they do not need to block micro-batch execution and we don’t need to start another purge until the current one is finished. The purges can happen essentially in the background. We will just have to synchronize the purges with the offset WAL commits and completion commits so that we don’t have concurrent modifications of the offset log and commit log. ### Why are the changes needed? Decrease microbatch processing latency ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Unit tests Closes #38313 from jerrypeng/SPARK-40849. Authored-by: Jerry Peng Signed-off-by: Jungtaek Lim --- .../org/apache/spark/util/ThreadUtils.scala | 4 +- .../apache/spark/sql/internal/SQLConf.scala | 9 ++ .../execution/streaming/AsyncLogPurge.scala | 82 ++++++++++++++++++ .../execution/streaming/ErrorNotifier.scala | 46 ++++++++++ .../streaming/MicroBatchExecution.scala | 22 ++++- .../execution/streaming/StreamExecution.scala | 7 ++ .../streaming/MicroBatchExecutionSuite.scala | 85 ++++++++++++++++++- 7 files changed, 249 insertions(+), 6 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncLogPurge.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ErrorNotifier.scala diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index d45dc937910d9..99b4e894bf0a6 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -162,9 +162,9 @@ private[spark] object ThreadUtils { /** * Wrapper over newSingleThreadExecutor. 
*/ - def newDaemonSingleThreadExecutor(threadName: String): ExecutorService = { + def newDaemonSingleThreadExecutor(threadName: String): ThreadPoolExecutor = { val threadFactory = new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadName).build() - Executors.newSingleThreadExecutor(threadFactory) + Executors.newFixedThreadPool(1, threadFactory).asInstanceOf[ThreadPoolExecutor] } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 72eb420de3749..ebff9ce546d00 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1982,6 +1982,15 @@ object SQLConf { .booleanConf .createWithDefault(false) + val ASYNC_LOG_PURGE = + buildConf("spark.sql.streaming.asyncLogPurge.enabled") + .internal() + .doc("When true, purging the offset log and " + + "commit log of old entries will be done asynchronously.") + .version("3.4.0") + .booleanConf + .createWithDefault(true) + val VARIABLE_SUBSTITUTE_ENABLED = buildConf("spark.sql.variable.substitute") .doc("This enables substitution using syntax like `${var}`, `${system:var}`, " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncLogPurge.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncLogPurge.scala new file mode 100644 index 0000000000000..b3729dbc7b459 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncLogPurge.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.util.concurrent.atomic.AtomicBoolean + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.ThreadUtils + +/** + * Used to enable the capability to allow log purges to be done asynchronously + */ +trait AsyncLogPurge extends Logging { + + protected var currentBatchId: Long + + protected val minLogEntriesToMaintain: Int + + + protected[sql] val errorNotifier: ErrorNotifier + + protected val sparkSession: SparkSession + + private val asyncPurgeExecutorService + = ThreadUtils.newDaemonSingleThreadExecutor("async-log-purge") + + private val purgeRunning = new AtomicBoolean(false) + + protected def purge(threshold: Long): Unit + + protected lazy val useAsyncPurge: Boolean = sparkSession.conf.get(SQLConf.ASYNC_LOG_PURGE) + + protected def purgeAsync(): Unit = { + if (purgeRunning.compareAndSet(false, true)) { + // save local copy because currentBatchId may get updated. 
There are not really + // any concurrency issues here in regards to calculating the purge threshold + // but for the sake of defensive coding lets make a copy + val currentBatchIdCopy: Long = currentBatchId + asyncPurgeExecutorService.execute(() => { + try { + purge(currentBatchIdCopy - minLogEntriesToMaintain) + } catch { + case throwable: Throwable => + logError("Encountered error while performing async log purge", throwable) + errorNotifier.markError(throwable) + } finally { + purgeRunning.set(false) + } + }) + } else { + log.debug("Skipped log purging since there is already one in progress.") + } + } + + protected def asyncLogPurgeShutdown(): Unit = { + ThreadUtils.shutdown(asyncPurgeExecutorService) + } + + // used for testing + private[sql] def arePendingAsyncPurge: Boolean = { + purgeRunning.get() || + asyncPurgeExecutorService.getQueue.size() > 0 || + asyncPurgeExecutorService.getActiveCount > 0 + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ErrorNotifier.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ErrorNotifier.scala new file mode 100644 index 0000000000000..0f25d0667a0ef --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ErrorNotifier.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming + +import java.util.concurrent.atomic.AtomicReference + +import org.apache.spark.internal.Logging + +/** + * Class to notify of any errors that might have occurred out of band + */ +class ErrorNotifier extends Logging { + + private val error = new AtomicReference[Throwable] + + /** To indicate any errors that have occurred */ + def markError(th: Throwable): Unit = { + logError("A fatal error has occurred.", th) + error.set(th) + } + + /** Get any errors that have occurred */ + def getError(): Option[Throwable] = { + Option(error.get()) + } + + /** Throw errors that have occurred */ + def throwErrorIfExists(): Unit = { + getError().foreach({th => throw th}) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 153bc82f89286..5f8fb93827b32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -46,7 +46,9 @@ class MicroBatchExecution( plan: WriteToStream) extends StreamExecution( sparkSession, plan.name, plan.resolvedCheckpointLocation, plan.inputQuery, plan.sink, trigger, - triggerClock, plan.outputMode, plan.deleteCheckpointOnStop) { + triggerClock, plan.outputMode, plan.deleteCheckpointOnStop) with AsyncLogPurge { + + protected[sql] val errorNotifier = new ErrorNotifier() @volatile protected var sources: Seq[SparkDataStream] = Seq.empty @@ -210,6 +212,14 @@ class MicroBatchExecution( logInfo(s"Query $prettyIdString was stopped") } + override def cleanup(): Unit = { + super.cleanup() + + // shutdown and cleanup required for async log purge mechanism + asyncLogPurgeShutdown() + logInfo(s"Async log purge executor pool for query ${prettyIdString} has been shutdown") + } + /** Begins recording statistics about query progress for a given trigger. */ override protected def startTrigger(): Unit = { super.startTrigger() @@ -226,6 +236,10 @@ class MicroBatchExecution( triggerExecutor.execute(() => { if (isActive) { + + // check if there are any previous errors and bubble up any existing async operations + errorNotifier.throwErrorIfExists + var currentBatchHasNewData = false // Whether the current batch had new data startTrigger() @@ -536,7 +550,11 @@ class MicroBatchExecution( // It is now safe to discard the metadata beyond the minimum number to retain. // Note that purge is exclusive, i.e. it purges everything before the target ID. 
if (minLogEntriesToMaintain < currentBatchId) { - purge(currentBatchId - minLogEntriesToMaintain) + if (useAsyncPurge) { + purgeAsync() + } else { + purge(currentBatchId - minLogEntriesToMaintain) + } } } noNewData = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index eeaa37aa7ffb5..5afd744f5e9be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -347,6 +347,7 @@ abstract class StreamExecution( try { stopSources() + cleanup() state.set(TERMINATED) currentStatus = status.copy(isTriggerActive = false, isDataAvailable = false) @@ -410,6 +411,12 @@ abstract class StreamExecution( } } + + /** + * Any clean up that needs to happen when the query is stopped or exits + */ + protected def cleanup(): Unit = {} + /** * Interrupts the query execution thread and awaits its termination until until it exceeds the * timeout. The timeout can be set on "spark.sql.streaming.stopTimeout". diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala index 749ca9d06eaf9..0ddd48420ef3f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala @@ -21,17 +21,20 @@ import java.io.File import org.apache.commons.io.FileUtils import org.scalatest.BeforeAndAfter +import org.scalatest.matchers.should._ +import org.scalatest.time.{Seconds, Span} import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.catalyst.plans.logical.Range import org.apache.spark.sql.connector.read.streaming import org.apache.spark.sql.connector.read.streaming.SparkDataStream import org.apache.spark.sql.functions.{count, timestamp_seconds, window} -import org.apache.spark.sql.streaming.{StreamTest, Trigger} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest, Trigger} import org.apache.spark.sql.types.{LongType, StructType} import org.apache.spark.util.Utils -class MicroBatchExecutionSuite extends StreamTest with BeforeAndAfter { +class MicroBatchExecutionSuite extends StreamTest with BeforeAndAfter with Matchers { import testImplicits._ @@ -39,6 +42,84 @@ class MicroBatchExecutionSuite extends StreamTest with BeforeAndAfter { sqlContext.streams.active.foreach(_.stop()) } + def getListOfFiles(dir: String): List[File] = { + val d = new File(dir) + if (d.exists && d.isDirectory) { + d.listFiles.filter(_.isFile).toList + } else { + List[File]() + } + } + + test("async log purging") { + withSQLConf(SQLConf.MIN_BATCHES_TO_RETAIN.key -> "2", SQLConf.ASYNC_LOG_PURGE.key -> "true") { + withTempDir { checkpointLocation => + val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext) + val ds = inputData.toDS() + testStream(ds)( + StartStream(checkpointLocation = checkpointLocation.getCanonicalPath), + AddData(inputData, 0), + CheckNewAnswer(0), + AddData(inputData, 1), + CheckNewAnswer(1), + Execute { q => + getListOfFiles(checkpointLocation + "/offsets") + .filter(file => !file.isHidden) + .map(file => file.getName.toInt) + .sorted should equal(Array(0, 1)) + 
getListOfFiles(checkpointLocation + "/commits") + .filter(file => !file.isHidden) + .map(file => file.getName.toInt) + .sorted should equal(Array(0, 1)) + }, + AddData(inputData, 2), + CheckNewAnswer(2), + AddData(inputData, 3), + CheckNewAnswer(3), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + + getListOfFiles(checkpointLocation + "/offsets") + .filter(file => !file.isHidden) + .map(file => file.getName.toInt) + .sorted should equal(Array(1, 2, 3)) + getListOfFiles(checkpointLocation + "/commits") + .filter(file => !file.isHidden) + .map(file => file.getName.toInt) + .sorted should equal(Array(1, 2, 3)) + }, + StopStream + ) + } + } + } + + test("error notifier test") { + withSQLConf(SQLConf.MIN_BATCHES_TO_RETAIN.key -> "2", SQLConf.ASYNC_LOG_PURGE.key -> "true") { + withTempDir { checkpointLocation => + val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext) + val ds = inputData.toDS() + val e = intercept[StreamingQueryException] { + + testStream(ds)( + StartStream(checkpointLocation = checkpointLocation.getCanonicalPath), + AddData(inputData, 0), + CheckNewAnswer(0), + AddData(inputData, 1), + CheckNewAnswer(1), + Execute { q => + q.asInstanceOf[MicroBatchExecution].errorNotifier.markError(new Exception("test")) + }, + AddData(inputData, 2), + CheckNewAnswer(2)) + } + e.getCause.getMessage should include("test") + } + } + } + test("SPARK-24156: do not plan a no-data batch again after it has already been planned") { val inputData = MemoryStream[Int] val df = inputData.toDF() From f7eee0950493ede83f5f00be2030cb8111ae6aa1 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Mon, 24 Oct 2022 11:22:29 +0900 Subject: [PATCH 06/22] [SPARK-40880][SQL][FOLLOW-UP] Remove unused imports ### What changes were proposed in this pull request? remove unused imports ### Why are the changes needed? ``` [error] /home/runner/work/spark/spark/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala:24:78: Unused import [error] import org.apache.spark.sql.catalyst.expressions.{Cast, ElementAt, EvalMode, GenericInternalRow} [error] ^ [error] /home/runner/work/spark/spark/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala:26:52: Unused import [error] import org.apache.spark.sql.catalyst.plans.logical.LocalRelation [error] ^ [error] /home/runner/work/spark/spark/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala:31:38: Unused import [error] import org.apache.spark.unsafe.types.UTF8String [error] ^ [error] three errors found ``` ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? maunally build Closes #38362 from zhengruifeng/sql_clean_unused_imports. 
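For context, these imports became dead code because the rewritten `summary` path now goes through the public DataFrame API: it builds one `map<statistic, string value>` column per input column, explodes the statistic names into a `summary` column, and looks each value up with `element_at`. A simplified sketch of that shape (the input column `c1`, the statistic set, and the session setup are illustrative only; the real implementation also handles string columns, percentiles and the no-column case):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

val df = Seq(1.0, 2.0, 4.0).toDF("c1")    // illustrative input
val stats = Array("count", "mean", "max") // illustrative statistic set

// One aggregate row holding a map of statistic name -> stringified value for the column.
val perColumn = df.select(
  map(
    lit("count"), count($"c1").cast("string"),
    lit("mean"), avg($"c1").cast("string"),
    lit("max"), max($"c1").cast("string")
  ).as("c1"))

// Explode the statistic names into a `summary` column and look each one up in the map.
perColumn
  .withColumn("summary", explode(lit(stats)))
  .select(col("summary"), element_at($"c1", $"summary").as("c1"))
  .show()
// One row per statistic: (count, 3), (mean, 2.333...), (max, 4.0).
```

Keeping everything in `Column`/`functions` terms is what lets the old `GenericInternalRow`, `LocalRelation` and `UTF8String` plumbing, and hence these imports, go away.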
Authored-by: Ruifeng Zheng Signed-off-by: Hyukjin Kwon --- .../org/apache/spark/sql/execution/stat/StatFunctions.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 508d2c64d0923..80e8f6d734207 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -21,14 +21,12 @@ import java.util.Locale import org.apache.spark.internal.Logging import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} -import org.apache.spark.sql.catalyst.expressions.{Cast, ElementAt, EvalMode, GenericInternalRow} +import org.apache.spark.sql.catalyst.expressions.{Cast, ElementAt, EvalMode} import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.util.QuantileSummaries import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String object StatFunctions extends Logging { From 74c826429416493a6d1d0efdf83b0e561dc33591 Mon Sep 17 00:00:00 2001 From: Rui Wang Date: Mon, 24 Oct 2022 10:50:55 +0800 Subject: [PATCH 07/22] [SPARK-40812][CONNECT][PYTHON][FOLLOW-UP] Improve Deduplicate in Python client ### What changes were proposed in this pull request? Following up on https://github.com/apache/spark/pull/38276, this PR improve both `distinct()` and `dropDuplicates` DataFrame API in Python client, which both depends on `Deduplicate` plan in the Connect proto. ### Why are the changes needed? Improve API coverage. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? UT Closes #38327 from amaliujia/python_deduplicate. Authored-by: Rui Wang Signed-off-by: Wenchen Fan --- python/pyspark/sql/connect/dataframe.py | 41 +++++++++++++++++-- python/pyspark/sql/connect/plan.py | 39 ++++++++++++++++++ .../tests/connect/test_connect_plan_only.py | 19 +++++++++ 3 files changed, 95 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index eabcf433ae9bc..2b7e3d520391d 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -157,11 +157,44 @@ def coalesce(self, num_partitions: int) -> "DataFrame": def describe(self, cols: List[ColumnRef]) -> Any: ... + def dropDuplicates(self, subset: Optional[List[str]] = None) -> "DataFrame": + """Return a new :class:`DataFrame` with duplicate rows removed, + optionally only deduplicating based on certain columns. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + subset : List of column names, optional + List of columns to use for duplicate comparison (default All columns). + + Returns + ------- + :class:`DataFrame` + DataFrame without duplicated rows. 
+ """ + if subset is None: + return DataFrame.withPlan( + plan.Deduplicate(child=self._plan, all_columns_as_keys=True), session=self._session + ) + else: + return DataFrame.withPlan( + plan.Deduplicate(child=self._plan, column_names=subset), session=self._session + ) + def distinct(self) -> "DataFrame": - """Returns all distinct rows.""" - all_cols = self.columns - gf = self.groupBy(*all_cols) - return gf.agg() + """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`. + + .. versionadded:: 3.4.0 + + Returns + ------- + :class:`DataFrame` + DataFrame with distinct rows. + """ + return DataFrame.withPlan( + plan.Deduplicate(child=self._plan, all_columns_as_keys=True), session=self._session + ) def drop(self, *cols: "ColumnOrString") -> "DataFrame": all_cols = self.columns diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index 297b15994d3bc..d6b6f9e3b67dd 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -327,6 +327,45 @@ def _repr_html_(self) -> str: """ +class Deduplicate(LogicalPlan): + def __init__( + self, + child: Optional["LogicalPlan"], + all_columns_as_keys: bool = False, + column_names: Optional[List[str]] = None, + ) -> None: + super().__init__(child) + self.all_columns_as_keys = all_columns_as_keys + self.column_names = column_names + + def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation: + assert self._child is not None + plan = proto.Relation() + plan.deduplicate.all_columns_as_keys = self.all_columns_as_keys + if self.column_names is not None: + plan.deduplicate.column_names.extend(self.column_names) + return plan + + def print(self, indent: int = 0) -> str: + c_buf = self._child.print(indent + LogicalPlan.INDENT) if self._child else "" + return ( + f"{' ' * indent}\n{c_buf}" + ) + + def _repr_html_(self) -> str: + return f""" +
+        <ul>
+           <li>
+              <b>Deduplicate</b><br />
+              all_columns_as_keys: {self.all_columns_as_keys} <br />
+              column_names: {self.column_names} <br />
+              {self._child_repr_()}
+           </li>
+        </ul>
+ """ + + class Sort(LogicalPlan): def __init__( self, child: Optional["LogicalPlan"], *columns: Union[SortOrder, ColumnRef, str] diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py index 3b609db7a028d..450f5c70fabad 100644 --- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py +++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py @@ -72,6 +72,25 @@ def test_sample(self): self.assertEqual(plan.root.sample.with_replacement, True) self.assertEqual(plan.root.sample.seed.seed, -1) + def test_deduplicate(self): + df = self.connect.readTable(table_name=self.tbl_name) + + distinct_plan = df.distinct()._plan.to_proto(self.connect) + self.assertEqual(distinct_plan.root.deduplicate.all_columns_as_keys, True) + self.assertEqual(len(distinct_plan.root.deduplicate.column_names), 0) + + deduplicate_on_all_columns_plan = df.dropDuplicates()._plan.to_proto(self.connect) + self.assertEqual(deduplicate_on_all_columns_plan.root.deduplicate.all_columns_as_keys, True) + self.assertEqual(len(deduplicate_on_all_columns_plan.root.deduplicate.column_names), 0) + + deduplicate_on_subset_columns_plan = df.dropDuplicates(["name", "height"])._plan.to_proto( + self.connect + ) + self.assertEqual( + deduplicate_on_subset_columns_plan.root.deduplicate.all_columns_as_keys, False + ) + self.assertEqual(len(deduplicate_on_subset_columns_plan.root.deduplicate.column_names), 2) + def test_relation_alias(self): df = self.connect.readTable(table_name=self.tbl_name) plan = df.alias("table_alias")._plan.to_proto(self.connect) From 4d33ee072272f9e8876ea05f0c069b2e9977835c Mon Sep 17 00:00:00 2001 From: allisonwang-db Date: Mon, 24 Oct 2022 11:10:43 +0800 Subject: [PATCH 08/22] [SPARK-36114][SQL] Support subqueries with correlated non-equality predicates ### What changes were proposed in this pull request? This PR supports correlated non-equality predicates in subqueries. It leverages the DecorrelateInnerQuery framework to decorrelate subqueries with non-equality predicates. DecorrelateInnerQuery inserts domain joins in the query plan and the rule RewriteCorrelatedScalarSubquery rewrites the domain joins into actual joins with the outer query. Note, correlated non-equality predicates can lead to query plans with non-equality join conditions, which may be planned as a broadcast NL join or cartesian product. ### Why are the changes needed? To improve subquery support in Spark. ### Does this PR introduce _any_ user-facing change? Yes. Before this PR, Spark does not allow correlated non-equality predicates in subqueries. For example: ```sql SELECT (SELECT min(c2) FROM t2 WHERE t1.c1 > t2.c1) FROM t1 ``` This will throw an exception: `Correlated column is not allowed in a non-equality predicate` After this PR, this query can run successfully. ### How was this patch tested? Unit tests and SQL query tests. Closes #38135 from allisonwang-db/spark-36114-non-equality-pred. 
Authored-by: allisonwang-db Signed-off-by: Wenchen Fan --- .../sql/catalyst/analysis/CheckAnalysis.scala | 7 +- .../analysis/AnalysisErrorSuite.scala | 2 +- .../sql-tests/inputs/join-lateral.sql | 3 + .../scalar-subquery-select.sql | 45 ++++++++ .../sql-tests/results/join-lateral.sql.out | 9 ++ .../scalar-subquery-select.sql.out | 107 ++++++++++++++++++ .../sql-tests/results/udf/udf-except.sql.out | 17 +-- .../org/apache/spark/sql/SubquerySuite.scala | 59 ++++------ 8 files changed, 195 insertions(+), 54 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 4346f51b613a2..cad036a34e97c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -1066,7 +1066,12 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { // 1 | 2 | 4 // and the plan after rewrite will give the original query incorrect results. def failOnUnsupportedCorrelatedPredicate(predicates: Seq[Expression], p: LogicalPlan): Unit = { - if (predicates.nonEmpty) { + // Correlated non-equality predicates are only supported with the decorrelate + // inner query framework. Currently we only use this new framework for scalar + // and lateral subqueries. + val allowNonEqualityPredicates = + SQLConf.get.decorrelateInnerQueryEnabled && (isScalar || isLateral) + if (!allowNonEqualityPredicates && predicates.nonEmpty) { // Report a non-supported case as an exception p.failAnalysis( errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 8b71bb05550a6..c44a0852b85c3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -917,7 +917,7 @@ class AnalysisErrorSuite extends AnalysisTest { (And($"a" === $"c", Cast($"d", IntegerType) === $"c"), "CAST(d#x AS INT) = outer(c#x)")) conditions.foreach { case (cond, msg) => val plan = Project( - ScalarSubquery( + Exists( Aggregate(Nil, count(Literal(1)).as("cnt") :: Nil, Filter(cond, t1)) ).as("sub") :: Nil, diff --git a/sql/core/src/test/resources/sql-tests/inputs/join-lateral.sql b/sql/core/src/test/resources/sql-tests/inputs/join-lateral.sql index fc5776c46afdd..dc1a35072728f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/join-lateral.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/join-lateral.sql @@ -44,6 +44,9 @@ SELECT * FROM t1, LATERAL (SELECT c2 FROM t2 WHERE t1.c1 = t2.c1); -- lateral join with correlated non-equality predicates SELECT * FROM t1, LATERAL (SELECT c2 FROM t2 WHERE t1.c2 < t2.c2); +-- SPARK-36114: lateral join with aggregation and correlated non-equality predicates +SELECT * FROM t1, LATERAL (SELECT max(c2) AS m FROM t2 WHERE t1.c2 < t2.c2); + -- lateral join can reference preceding FROM clause items SELECT * FROM t1 JOIN t2 JOIN LATERAL (SELECT t1.c2 + t2.c2); -- expect error: cannot resolve `t2.c1` diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-select.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-select.sql 
index b999d1723c911..6d673f149cc95 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-select.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-select.sql @@ -190,3 +190,48 @@ SELECT c1, ( -- Multi-value subquery error SELECT (SELECT a FROM (SELECT 1 AS a UNION ALL SELECT 2 AS a) t) AS b; + +-- SPARK-36114: Support correlated non-equality predicates +CREATE OR REPLACE TEMP VIEW t1(c1, c2) AS (VALUES (0, 1), (1, 2)); +CREATE OR REPLACE TEMP VIEW t2(c1, c2) AS (VALUES (0, 2), (0, 3)); + +-- Neumann example Q2 +CREATE OR REPLACE TEMP VIEW students(id, name, major, year) AS (VALUES + (0, 'A', 'CS', 2022), + (1, 'B', 'CS', 2022), + (2, 'C', 'Math', 2022)); +CREATE OR REPLACE TEMP VIEW exams(sid, course, curriculum, grade, date) AS (VALUES + (0, 'C1', 'CS', 4, 2020), + (0, 'C2', 'CS', 3, 2021), + (1, 'C1', 'CS', 2, 2020), + (1, 'C2', 'CS', 1, 2021)); + +SELECT students.name, exams.course +FROM students, exams +WHERE students.id = exams.sid + AND (students.major = 'CS' OR students.major = 'Games Eng') + AND exams.grade >= ( + SELECT avg(exams.grade) + 1 + FROM exams + WHERE students.id = exams.sid + OR (exams.curriculum = students.major AND students.year > exams.date)); + +-- Correlated non-equality predicates +SELECT (SELECT min(c2) FROM t2 WHERE t1.c1 > t2.c1) FROM t1; +SELECT (SELECT min(c2) FROM t2 WHERE t1.c1 >= t2.c1 AND t1.c2 < t2.c2) FROM t1; + +-- Correlated non-equality predicates with the COUNT bug. +SELECT (SELECT count(*) FROM t2 WHERE t1.c1 > t2.c1) FROM t1; + +-- Correlated equality predicates that are not supported after SPARK-35080 +SELECT c, ( + SELECT count(*) + FROM (VALUES ('ab'), ('abc'), ('bc')) t2(c) + WHERE t1.c = substring(t2.c, 1, 1) +) FROM (VALUES ('a'), ('b')) t1(c); + +SELECT c, ( + SELECT count(*) + FROM (VALUES (0, 6), (1, 5), (2, 4), (3, 3)) t1(a, b) + WHERE a + b = c +) FROM (VALUES (6)) t2(c); diff --git a/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out b/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out index be07ba7bd9a1e..34c0543dfdda8 100644 --- a/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out @@ -272,6 +272,15 @@ struct 1 2 3 +-- !query +SELECT * FROM t1, LATERAL (SELECT max(c2) AS m FROM t2 WHERE t1.c2 < t2.c2) +-- !query schema +struct +-- !query output +0 1 3 +1 2 3 + + -- !query SELECT * FROM t1 JOIN t2 JOIN LATERAL (SELECT t1.c2 + t2.c2) -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-select.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-select.sql.out index d1e56786207ed..38ab365ef6941 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-select.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-select.sql.out @@ -433,3 +433,110 @@ org.apache.spark.SparkException "fragment" : "(SELECT a FROM (SELECT 1 AS a UNION ALL SELECT 2 AS a) t)" } ] } + + +-- !query +CREATE OR REPLACE TEMP VIEW t1(c1, c2) AS (VALUES (0, 1), (1, 2)) +-- !query schema +struct<> +-- !query output + + + +-- !query +CREATE OR REPLACE TEMP VIEW t2(c1, c2) AS (VALUES (0, 2), (0, 3)) +-- !query schema +struct<> +-- !query output + + + +-- !query +CREATE OR REPLACE TEMP VIEW students(id, name, major, year) AS (VALUES + (0, 'A', 'CS', 2022), + (1, 'B', 'CS', 
2022), + (2, 'C', 'Math', 2022)) +-- !query schema +struct<> +-- !query output + + + +-- !query +CREATE OR REPLACE TEMP VIEW exams(sid, course, curriculum, grade, date) AS (VALUES + (0, 'C1', 'CS', 4, 2020), + (0, 'C2', 'CS', 3, 2021), + (1, 'C1', 'CS', 2, 2020), + (1, 'C2', 'CS', 1, 2021)) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT students.name, exams.course +FROM students, exams +WHERE students.id = exams.sid + AND (students.major = 'CS' OR students.major = 'Games Eng') + AND exams.grade >= ( + SELECT avg(exams.grade) + 1 + FROM exams + WHERE students.id = exams.sid + OR (exams.curriculum = students.major AND students.year > exams.date)) +-- !query schema +struct +-- !query output +A C1 + + +-- !query +SELECT (SELECT min(c2) FROM t2 WHERE t1.c1 > t2.c1) FROM t1 +-- !query schema +struct +-- !query output +2 +NULL + + +-- !query +SELECT (SELECT min(c2) FROM t2 WHERE t1.c1 >= t2.c1 AND t1.c2 < t2.c2) FROM t1 +-- !query schema +struct +-- !query output +2 +3 + + +-- !query +SELECT (SELECT count(*) FROM t2 WHERE t1.c1 > t2.c1) FROM t1 +-- !query schema +struct +-- !query output +0 +2 + + +-- !query +SELECT c, ( + SELECT count(*) + FROM (VALUES ('ab'), ('abc'), ('bc')) t2(c) + WHERE t1.c = substring(t2.c, 1, 1) +) FROM (VALUES ('a'), ('b')) t1(c) +-- !query schema +struct +-- !query output +a 2 +b 1 + + +-- !query +SELECT c, ( + SELECT count(*) + FROM (VALUES (0, 6), (1, 5), (2, 4), (3, 3)) t1(a, b) + WHERE a + b = c +) FROM (VALUES (6)) t2(c) +-- !query schema +struct +-- !query output +6 4 diff --git a/sql/core/src/test/resources/sql-tests/results/udf/udf-except.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/udf-except.sql.out index f532b0d41e344..14ecf98c7a831 100644 --- a/sql/core/src/test/resources/sql-tests/results/udf/udf-except.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udf/udf-except.sql.out @@ -97,19 +97,6 @@ WHERE udf(t1.v) >= (SELECT min(udf(t2.v)) FROM t2 WHERE t2.k = t1.k) -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -{ - "errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.CORRELATED_COLUMN_IS_NOT_ALLOWED_IN_PREDICATE", - "messageParameters" : { - "treeNode" : "(cast(udf(cast(k#x as string)) as string) = cast(udf(cast(outer(k#x) as string)) as string))\nFilter (cast(udf(cast(k#x as string)) as string) = cast(udf(cast(outer(k#x) as string)) as string))\n+- SubqueryAlias t2\n +- View (`t2`, [k#x,v#x])\n +- Project [cast(k#x as string) AS k#x, cast(v#x as int) AS v#x]\n +- Project [k#x, v#x]\n +- SubqueryAlias t2\n +- LocalRelation [k#x, v#x]\n" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 39, - "stopIndex" : 141, - "fragment" : "SELECT udf(max(udf(t2.v)))\n FROM t2\n WHERE udf(t2.k) = udf(t1.k)" - } ] -} +two diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index ecb4bfd0ec41b..9d326b92b939f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -66,6 +66,11 @@ class SubquerySuite extends QueryTest t.createOrReplaceTempView("t") } + private def checkNumJoins(plan: LogicalPlan, numJoins: Int): Unit = { + val joins = plan.collect { case j: Join => j } + assert(joins.size == numJoins) + } + test("SPARK-18854 numberedTreeString for subquery") { val df = sql("select * from range(10) where id not in " + "(select id from range(2) union all select 
id from range(2))") @@ -562,17 +567,10 @@ class SubquerySuite extends QueryTest } test("non-equal correlated scalar subquery") { - val exception = intercept[AnalysisException] { - sql("select a, (select sum(b) from l l2 where l2.a < l1.a) sum_b from l l1") - } - checkErrorMatchPVals( - exception, - errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + - "CORRELATED_COLUMN_IS_NOT_ALLOWED_IN_PREDICATE", - parameters = Map("treeNode" -> "(?s).*"), - sqlState = None, - context = ExpectedContext( - fragment = "select sum(b) from l l2 where l2.a < l1.a", start = 11, stop = 51)) + checkAnswer( + sql("select a, (select sum(b) from l l2 where l2.a < l1.a) sum_b from l l1"), + Seq(Row(1, null), Row(1, null), Row(2, 4), Row(2, 4), Row(3, 6), Row(null, null), + Row(null, null), Row(6, 9))) } test("disjunctive correlated scalar subquery") { @@ -2105,25 +2103,17 @@ class SubquerySuite extends QueryTest } } - test("SPARK-38155: disallow distinct aggregate in lateral subqueries") { + test("SPARK-36114: distinct aggregate in lateral subqueries") { withTempView("t1", "t2") { Seq((0, 1)).toDF("c1", "c2").createOrReplaceTempView("t1") Seq((1, 2), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2") - val exception = intercept[AnalysisException] { - sql("SELECT * FROM t1 JOIN LATERAL (SELECT DISTINCT c2 FROM t2 WHERE c1 > t1.c1)") - } - checkErrorMatchPVals( - exception, - errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + - "CORRELATED_COLUMN_IS_NOT_ALLOWED_IN_PREDICATE", - parameters = Map("treeNode" -> "(?s).*"), - sqlState = None, - context = ExpectedContext( - fragment = "SELECT DISTINCT c2 FROM t2 WHERE c1 > t1.c1", start = 31, stop = 73)) + checkAnswer( + sql("SELECT * FROM t1 JOIN LATERAL (SELECT DISTINCT c2 FROM t2 WHERE c1 > t1.c1)"), + Row(0, 1, 2) :: Nil) } } - test("SPARK-38180: allow safe cast expressions in correlated equality conditions") { + test("SPARK-38180, SPARK-36114: allow safe cast expressions in correlated equality conditions") { withTempView("t1", "t2") { Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t1") Seq((0, 2), (0, 3)).toDF("c1", "c2").createOrReplaceTempView("t2") @@ -2139,19 +2129,14 @@ class SubquerySuite extends QueryTest |FROM (SELECT CAST(c1 AS STRING) a FROM t1) |""".stripMargin), Row(5) :: Row(null) :: Nil) - val exception1 = intercept[AnalysisException] { - sql( - """SELECT (SELECT SUM(c2) FROM t2 WHERE CAST(c1 AS SHORT) = a) - |FROM (SELECT CAST(c1 AS SHORT) a FROM t1)""".stripMargin) - } - checkErrorMatchPVals( - exception1, - errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + - "CORRELATED_COLUMN_IS_NOT_ALLOWED_IN_PREDICATE", - parameters = Map("treeNode" -> "(?s).*"), - sqlState = None, - context = ExpectedContext( - fragment = "SELECT SUM(c2) FROM t2 WHERE CAST(c1 AS SHORT) = a", start = 8, stop = 57)) + // SPARK-36114: we now allow non-safe cast expressions in correlated predicates. + val df = sql( + """SELECT (SELECT SUM(c2) FROM t2 WHERE CAST(c1 AS SHORT) = a) + |FROM (SELECT CAST(c1 AS SHORT) a FROM t1) + |""".stripMargin) + checkAnswer(df, Row(5) :: Row(null) :: Nil) + // The optimized plan should have one left outer join and one domain (inner) join. + checkNumJoins(df.queryExecution.optimizedPlan, 2) } } From 58490da6d2ef83cfc47438cf078c82d4d008fdae Mon Sep 17 00:00:00 2001 From: allisonwang-db Date: Mon, 24 Oct 2022 11:12:45 +0800 Subject: [PATCH 09/22] [SPARK-40800][SQL] Always inline expressions in OptimizeOneRowRelationSubquery ### What changes were proposed in this pull request? 
This PR modifies the optimizer rule `OptimizeOneRowRelationSubquery` to always collapse projects and inline non-volatile expressions. ### Why are the changes needed? SPARK-39699 made `CollpaseProjects` more conservative. This has impacted correlated subqueries that Spark used to be able to support. For example, Spark used to be able to execute this correlated subquery: ```sql SELECT ( SELECT array_sort(a, (i, j) -> rank[i] - rank[j]) AS sorted FROM (SELECT MAP('a', 1, 'b', 2) rank) ) FROM t1 ``` But after SPARK-39699, it will throw an exception `Unexpected operator Join Inner` because the projects inside the subquery can no longer be collapsed. We should always inline expressions if possible to support a broader range of correlated subqueries and avoid adding expensive domain joins. ### Does this PR introduce _any_ user-facing change? Yes. It will allow Spark to execute more types of correlated subqueries. ### How was this patch tested? Unit test. Closes #38260 from allisonwang-db/spark-40800-inline-expr-subquery. Authored-by: allisonwang-db Signed-off-by: Wenchen Fan --- .../sql/catalyst/optimizer/Optimizer.scala | 5 ++- .../sql/catalyst/optimizer/subquery.scala | 4 +- .../org/apache/spark/sql/SubquerySuite.scala | 39 +++++++++++++++++-- 3 files changed, 42 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 2664fd638062d..afbf73027277e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -989,7 +989,10 @@ object ColumnPruning extends Rule[LogicalPlan] { object CollapseProject extends Rule[LogicalPlan] with AliasHelper { def apply(plan: LogicalPlan): LogicalPlan = { - val alwaysInline = conf.getConf(SQLConf.COLLAPSE_PROJECT_ALWAYS_INLINE) + apply(plan, conf.getConf(SQLConf.COLLAPSE_PROJECT_ALWAYS_INLINE)) + } + + def apply(plan: LogicalPlan, alwaysInline: Boolean): LogicalPlan = { plan.transformUpWithPruning(_.containsPattern(PROJECT), ruleId) { case p1 @ Project(_, p2: Project) if canCollapseExpressions(p1.projectList, p2.projectList, alwaysInline) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 9a1d20ed9b21d..6665d885554fc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -730,7 +730,9 @@ object OptimizeOneRowRelationSubquery extends Rule[LogicalPlan] { object OneRowSubquery { def unapply(plan: LogicalPlan): Option[Seq[NamedExpression]] = { - CollapseProject(EliminateSubqueryAliases(plan)) match { + // SPARK-40800: always inline expressions to support a broader range of correlated + // subqueries and avoid expensive domain joins. 
+ CollapseProject(EliminateSubqueryAliases(plan), alwaysInline = true) match { case Project(projectList, _: OneRowRelation) => Some(stripOuterReferences(projectList)) case _ => None } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 9d326b92b939f..4b58635636771 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -2454,10 +2454,41 @@ class SubquerySuite extends QueryTest Row(2)) // Cannot use non-orderable data type in one row subquery that cannot be collapsed. - val error = intercept[AnalysisException] { - sql("select (select concat(a, a) from (select upper(x['a']) as a)) from v1").collect() - } - assert(error.getMessage.contains("Correlated column reference 'v1.x' cannot be map type")) + val error = intercept[AnalysisException] { + sql( + """ + |select ( + | select concat(a, a) from + | (select upper(x['a'] + rand()) as a) + |) from v1 + |""".stripMargin).collect() + } + assert(error.getMessage.contains("Correlated column reference 'v1.x' cannot be map type")) + } + } + + test("SPARK-40800: always inline expressions in OptimizeOneRowRelationSubquery") { + withTempView("t1") { + sql("CREATE TEMP VIEW t1 AS SELECT ARRAY('a', 'b') a") + // Scalar subquery. + checkAnswer(sql( + """ + |SELECT ( + | SELECT array_sort(a, (i, j) -> rank[i] - rank[j])[0] AS sorted + | FROM (SELECT MAP('a', 1, 'b', 2) rank) + |) FROM t1 + |""".stripMargin), + Row("a")) + // Lateral subquery. + checkAnswer( + sql(""" + |SELECT sorted[0] FROM t1 + |JOIN LATERAL ( + | SELECT array_sort(a, (i, j) -> rank[i] - rank[j]) AS sorted + | FROM (SELECT MAP('a', 1, 'b', 2) rank) + |) + |""".stripMargin), + Row("a")) } } } From c721c7299d8821d1f15dbac2d156d8936c71522d Mon Sep 17 00:00:00 2001 From: Yikun Jiang Date: Mon, 24 Oct 2022 14:41:12 +0800 Subject: [PATCH 10/22] [SPARK-40881][INFRA] Upgrade actions/cache to v3 and actions/upload-artifact to v3 ### What changes were proposed in this pull request? Upgrade actions/cache to v3 and actions/upload-artifact to v3 ### Why are the changes needed? - Since actions/cachev3: support from node 12 -> node 16 and cleanup `set-output` warning - Since actions/upload-artifactv3: support from node 12 -> node 16 and cleanup `set-output` warning ### Does this PR introduce _any_ user-facing change? No, dev only ### How was this patch tested? CI passed Closes #38353 from Yikun/SPARK-40881. 
Authored-by: Yikun Jiang Signed-off-by: Yikun Jiang --- .github/workflows/benchmark.yml | 14 +++--- .github/workflows/build_and_test.yml | 60 +++++++++++++------------- .github/workflows/publish_snapshot.yml | 2 +- 3 files changed, 38 insertions(+), 38 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index f73267a95fa32..227c444a7a419 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -70,7 +70,7 @@ jobs: with: fetch-depth: 0 - name: Cache Scala, SBT and Maven - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: | build/apache-maven-* @@ -81,7 +81,7 @@ jobs: restore-keys: | build- - name: Cache Coursier local repository - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/.cache/coursier key: benchmark-coursier-${{ github.event.inputs.jdk }}-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }} @@ -89,7 +89,7 @@ jobs: benchmark-coursier-${{ github.event.inputs.jdk }} - name: Cache TPC-DS generated data id: cache-tpcds-sf-1 - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ./tpcds-sf-1 key: tpcds-${{ hashFiles('.github/workflows/benchmark.yml', 'sql/core/src/test/scala/org/apache/spark/sql/TPCDSSchema.scala') }} @@ -138,7 +138,7 @@ jobs: with: fetch-depth: 0 - name: Cache Scala, SBT and Maven - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: | build/apache-maven-* @@ -149,7 +149,7 @@ jobs: restore-keys: | build- - name: Cache Coursier local repository - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/.cache/coursier key: benchmark-coursier-${{ github.event.inputs.jdk }}-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }} @@ -162,7 +162,7 @@ jobs: - name: Cache TPC-DS generated data if: contains(github.event.inputs.class, 'TPCDSQueryBenchmark') || contains(github.event.inputs.class, '*') id: cache-tpcds-sf-1 - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ./tpcds-sf-1 key: tpcds-${{ hashFiles('.github/workflows/benchmark.yml', 'sql/core/src/test/scala/org/apache/spark/sql/TPCDSSchema.scala') }} @@ -186,7 +186,7 @@ jobs: echo "Preparing the benchmark results:" tar -cvf benchmark-results-${{ github.event.inputs.jdk }}-${{ github.event.inputs.scala }}.tar `git diff --name-only` `git ls-files --others --exclude=tpcds-sf-1 --exclude-standard` - name: Upload benchmark results - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: benchmark-results-${{ github.event.inputs.jdk }}-${{ github.event.inputs.scala }}-${{ matrix.split }} path: benchmark-results-${{ github.event.inputs.jdk }}-${{ github.event.inputs.scala }}.tar diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 8411170e2d5cd..0e0314e29506f 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -209,7 +209,7 @@ jobs: git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged commit" --allow-empty # Cache local repositories. Note that GitHub Actions cache has a 2G limit. 
- name: Cache Scala, SBT and Maven - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: | build/apache-maven-* @@ -220,7 +220,7 @@ jobs: restore-keys: | build- - name: Cache Coursier local repository - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/.cache/coursier key: ${{ matrix.java }}-${{ matrix.hadoop }}-coursier-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }} @@ -254,13 +254,13 @@ jobs: ./dev/run-tests --parallelism 1 --modules "$MODULES_TO_TEST" --included-tags "$INCLUDED_TAGS" --excluded-tags "$EXCLUDED_TAGS" - name: Upload test results to report if: always() - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: test-results-${{ matrix.modules }}-${{ matrix.comment }}-${{ matrix.java }}-${{ matrix.hadoop }}-${{ matrix.hive }} path: "**/target/test-reports/*.xml" - name: Upload unit tests log files if: failure() - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: unit-tests-log-${{ matrix.modules }}-${{ matrix.comment }}-${{ matrix.java }}-${{ matrix.hadoop }}-${{ matrix.hive }} path: "**/target/unit-tests.log" @@ -366,7 +366,7 @@ jobs: git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged commit" --allow-empty # Cache local repositories. Note that GitHub Actions cache has a 2G limit. - name: Cache Scala, SBT and Maven - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: | build/apache-maven-* @@ -377,7 +377,7 @@ jobs: restore-keys: | build- - name: Cache Coursier local repository - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/.cache/coursier key: pyspark-coursier-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }} @@ -410,13 +410,13 @@ jobs: name: PySpark - name: Upload test results to report if: always() - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: test-results-${{ matrix.modules }}--8-${{ inputs.hadoop }}-hive2.3 path: "**/target/test-reports/*.xml" - name: Upload unit tests log files if: failure() - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: unit-tests-log-${{ matrix.modules }}--8-${{ inputs.hadoop }}-hive2.3 path: "**/target/unit-tests.log" @@ -455,7 +455,7 @@ jobs: git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged commit" --allow-empty # Cache local repositories. Note that GitHub Actions cache has a 2G limit. - name: Cache Scala, SBT and Maven - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: | build/apache-maven-* @@ -466,7 +466,7 @@ jobs: restore-keys: | build- - name: Cache Coursier local repository - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/.cache/coursier key: sparkr-coursier-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }} @@ -486,7 +486,7 @@ jobs: ./dev/run-tests --parallelism 1 --modules sparkr - name: Upload test results to report if: always() - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: test-results-sparkr--8-${{ inputs.hadoop }}-hive2.3 path: "**/target/test-reports/*.xml" @@ -523,7 +523,7 @@ jobs: git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged commit" --allow-empty # Cache local repositories. Note that GitHub Actions cache has a 2G limit. 
- name: Cache Scala, SBT and Maven - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: | build/apache-maven-* @@ -534,14 +534,14 @@ jobs: restore-keys: | build- - name: Cache Coursier local repository - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/.cache/coursier key: docs-coursier-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }} restore-keys: | docs-coursier- - name: Cache Maven local repository - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/.m2/repository key: docs-maven-${{ hashFiles('**/pom.xml') }} @@ -646,7 +646,7 @@ jobs: git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' merge --no-commit --progress --squash FETCH_HEAD git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged commit" --allow-empty - name: Cache Scala, SBT and Maven - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: | build/apache-maven-* @@ -657,7 +657,7 @@ jobs: restore-keys: | build- - name: Cache Maven local repository - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/.m2/repository key: java${{ matrix.java }}-maven-${{ hashFiles('**/pom.xml') }} @@ -695,7 +695,7 @@ jobs: git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' merge --no-commit --progress --squash FETCH_HEAD git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged commit" --allow-empty - name: Cache Scala, SBT and Maven - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: | build/apache-maven-* @@ -706,7 +706,7 @@ jobs: restore-keys: | build- - name: Cache Coursier local repository - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/.cache/coursier key: scala-213-coursier-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }} @@ -743,7 +743,7 @@ jobs: git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' merge --no-commit --progress --squash FETCH_HEAD git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged commit" --allow-empty - name: Cache Scala, SBT and Maven - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: | build/apache-maven-* @@ -754,7 +754,7 @@ jobs: restore-keys: | build- - name: Cache Coursier local repository - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/.cache/coursier key: tpcds-coursier-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }} @@ -766,7 +766,7 @@ jobs: java-version: 8 - name: Cache TPC-DS generated data id: cache-tpcds-sf-1 - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ./tpcds-sf-1 key: tpcds-${{ hashFiles('.github/workflows/build_and_test.yml', 'sql/core/src/test/scala/org/apache/spark/sql/TPCDSSchema.scala') }} @@ -808,13 +808,13 @@ jobs: spark.sql.join.forceApplyShuffledHashJoin=true - name: Upload test results to report if: always() - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: test-results-tpcds--8-${{ inputs.hadoop }}-hive2.3 path: "**/target/test-reports/*.xml" - name: Upload unit tests log files if: failure() - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: unit-tests-log-tpcds--8-${{ inputs.hadoop }}-hive2.3 path: "**/target/unit-tests.log" @@ -846,7 +846,7 @@ jobs: git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' merge --no-commit --progress --squash FETCH_HEAD git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged 
commit" --allow-empty - name: Cache Scala, SBT and Maven - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: | build/apache-maven-* @@ -857,7 +857,7 @@ jobs: restore-keys: | build- - name: Cache Coursier local repository - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/.cache/coursier key: docker-integration-coursier-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }} @@ -872,13 +872,13 @@ jobs: ./dev/run-tests --parallelism 1 --modules docker-integration-tests --included-tags org.apache.spark.tags.DockerTest - name: Upload test results to report if: always() - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: test-results-docker-integration--8-${{ inputs.hadoop }}-hive2.3 path: "**/target/test-reports/*.xml" - name: Upload unit tests log files if: failure() - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: unit-tests-log-docker-integration--8-${{ inputs.hadoop }}-hive2.3 path: "**/target/unit-tests.log" @@ -903,7 +903,7 @@ jobs: git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' merge --no-commit --progress --squash FETCH_HEAD git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged commit" --allow-empty - name: Cache Scala, SBT and Maven - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: | build/apache-maven-* @@ -914,7 +914,7 @@ jobs: restore-keys: | build- - name: Cache Coursier local repository - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/.cache/coursier key: k8s-integration-coursier-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }} @@ -948,7 +948,7 @@ jobs: build/sbt -Psparkr -Pkubernetes -Pkubernetes-integration-tests -Dspark.kubernetes.test.driverRequestCores=0.5 -Dspark.kubernetes.test.executorRequestCores=0.2 "kubernetes-integration-tests/test" - name: Upload Spark on K8S integration tests log files if: failure() - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: spark-on-kubernetes-it-log path: "**/target/integration-tests.log" diff --git a/.github/workflows/publish_snapshot.yml b/.github/workflows/publish_snapshot.yml index a8251aa5b67c7..93bda0c2a772e 100644 --- a/.github/workflows/publish_snapshot.yml +++ b/.github/workflows/publish_snapshot.yml @@ -41,7 +41,7 @@ jobs: with: ref: ${{ matrix.branch }} - name: Cache Maven local repository - uses: actions/cache@c64c572235d810460d0d6876e9c705ad5002b353 # pin@v2 + uses: actions/cache@v3 with: path: ~/.m2/repository key: snapshot-maven-${{ hashFiles('**/pom.xml') }} From 825f2190bd826a8a877739454393e79ef163fdf1 Mon Sep 17 00:00:00 2001 From: Yikun Jiang Date: Mon, 24 Oct 2022 14:51:26 +0800 Subject: [PATCH 11/22] [SPARK-40882][INFRA] Upgrade actions/setup-java to v3 with distribution specified ### What changes were proposed in this pull request? Upgrade actions/setup-java to v3 with distribution specified ### Why are the changes needed? - The `distribution` is required after v2, now just keep `zulu` (same distribution with v1): https://github.com/actions/setup-java/releases/tag/v2.0.0 - https://github.com/actions/setup-java/releases/tag/v3.0.0: Upgrade node - https://github.com/actions/setup-java/releases/tag/v3.6.0: Cleanup set-output warning ### Does this PR introduce _any_ user-facing change? No,dev only ### How was this patch tested? CI passed Closes #38354 from Yikun/SPARK-40882. 
Authored-by: Yikun Jiang Signed-off-by: Yikun Jiang --- .github/workflows/benchmark.yml | 6 ++++-- .github/workflows/build_and_test.yml | 27 +++++++++++++++++--------- .github/workflows/publish_snapshot.yml | 3 ++- 3 files changed, 24 insertions(+), 12 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 227c444a7a419..8671cff054bb8 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -105,8 +105,9 @@ jobs: run: cd tpcds-kit/tools && make OS=LINUX - name: Install Java ${{ github.event.inputs.jdk }} if: steps.cache-tpcds-sf-1.outputs.cache-hit != 'true' - uses: actions/setup-java@v1 + uses: actions/setup-java@v3 with: + distribution: temurin java-version: ${{ github.event.inputs.jdk }} - name: Generate TPC-DS (SF=1) table data if: steps.cache-tpcds-sf-1.outputs.cache-hit != 'true' @@ -156,8 +157,9 @@ jobs: restore-keys: | benchmark-coursier-${{ github.event.inputs.jdk }} - name: Install Java ${{ github.event.inputs.jdk }} - uses: actions/setup-java@v1 + uses: actions/setup-java@v3 with: + distribution: temurin java-version: ${{ github.event.inputs.jdk }} - name: Cache TPC-DS generated data if: contains(github.event.inputs.class, 'TPCDSQueryBenchmark') || contains(github.event.inputs.class, '*') diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 0e0314e29506f..688c40cc3b63e 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -227,8 +227,9 @@ jobs: restore-keys: | ${{ matrix.java }}-${{ matrix.hadoop }}-coursier- - name: Install Java ${{ matrix.java }} - uses: actions/setup-java@v1 + uses: actions/setup-java@v3 with: + distribution: temurin java-version: ${{ matrix.java }} - name: Install Python 3.8 uses: actions/setup-python@v2 @@ -384,8 +385,9 @@ jobs: restore-keys: | pyspark-coursier- - name: Install Java ${{ matrix.java }} - uses: actions/setup-java@v1 + uses: actions/setup-java@v3 with: + distribution: temurin java-version: ${{ matrix.java }} - name: List Python packages (Python 3.9, PyPy3) run: | @@ -473,8 +475,9 @@ jobs: restore-keys: | sparkr-coursier- - name: Install Java ${{ inputs.java }} - uses: actions/setup-java@v1 + uses: actions/setup-java@v3 with: + distribution: temurin java-version: ${{ inputs.java }} - name: Run tests env: ${{ fromJSON(inputs.envs) }} @@ -597,8 +600,9 @@ jobs: cd docs bundle install - name: Install Java 8 - uses: actions/setup-java@v1 + uses: actions/setup-java@v3 with: + distribution: temurin java-version: 8 - name: Scala linter run: ./dev/lint-scala @@ -664,8 +668,9 @@ jobs: restore-keys: | java${{ matrix.java }}-maven- - name: Install Java ${{ matrix.java }} - uses: actions/setup-java@v1 + uses: actions/setup-java@v3 with: + distribution: temurin java-version: ${{ matrix.java }} - name: Build with Maven run: | @@ -713,8 +718,9 @@ jobs: restore-keys: | scala-213-coursier- - name: Install Java 8 - uses: actions/setup-java@v1 + uses: actions/setup-java@v3 with: + distribution: temurin java-version: 8 - name: Build with SBT run: | @@ -761,8 +767,9 @@ jobs: restore-keys: | tpcds-coursier- - name: Install Java 8 - uses: actions/setup-java@v1 + uses: actions/setup-java@v3 with: + distribution: temurin java-version: 8 - name: Cache TPC-DS generated data id: cache-tpcds-sf-1 @@ -864,8 +871,9 @@ jobs: restore-keys: | docker-integration-coursier- - name: Install Java 8 - uses: actions/setup-java@v1 + uses: actions/setup-java@v3 with: + distribution: temurin java-version: 8 - name: Run tests run: | @@ 
-921,8 +929,9 @@ jobs: restore-keys: | k8s-integration-coursier- - name: Install Java ${{ inputs.java }} - uses: actions/setup-java@v1 + uses: actions/setup-java@v3 with: + distribution: temurin java-version: ${{ inputs.java }} - name: start minikube run: | diff --git a/.github/workflows/publish_snapshot.yml b/.github/workflows/publish_snapshot.yml index 93bda0c2a772e..33de10b5b2e66 100644 --- a/.github/workflows/publish_snapshot.yml +++ b/.github/workflows/publish_snapshot.yml @@ -48,8 +48,9 @@ jobs: restore-keys: | snapshot-maven- - name: Install Java 8 - uses: actions/setup-java@d202f5dbf7256730fb690ec59f6381650114feb2 # pin@v1 + uses: actions/setup-java@v3 with: + distribution: temurin java-version: 8 - name: Publish snapshot env: From 91407958b49bc04b4d1aa8609a9064056a25d2df Mon Sep 17 00:00:00 2001 From: ulysses-you Date: Mon, 24 Oct 2022 15:20:54 +0800 Subject: [PATCH 12/22] [SPARK-40798][SQL] Alter partition should verify value follow storeAssignmentPolicy ### What changes were proposed in this pull request? extract the check insertion field cast methold so that we can do validate patition value at PartitioningUtils.normalizePartitionSpec ### Why are the changes needed? Insertion follow the behavior of config `spark.sql.storeAssignmentPolicy`, which will fail if the value can not cast to target data type by default. Alter partition should also follow it. For example: ```SQL CREATE TABLE t (c int) USING PARQUET PARTITIONED BY(p int); -- This DDL should fail but worked: ALTER TABLE t ADD PARTITION(p='aaa'); -- FAILED which follows spark.sql.storeAssignmentPolicy INSERT INTO t PARTITION(p='aaa') SELECT 1 ``` ### Does this PR introduce _any_ user-facing change? yes, the added partition value will follow the behavior of `storeAssignmentPolicy`. To restore the previous behavior, set spark.sql.legacy.skipPartitionSpecTypeValidation = true; ### How was this patch tested? add test Closes #38257 from ulysses-you/verify-partition. Authored-by: ulysses-you Signed-off-by: Wenchen Fan --- docs/sql-migration-guide.md | 1 + .../apache/spark/sql/internal/SQLConf.scala | 10 +++++ .../spark/sql/util/PartitioningUtils.scala | 43 +++++++++++++++++- .../datasources/DataSourceStrategy.scala | 20 ++------- .../AlterTableAddPartitionSuiteBase.scala | 44 +++++++++++++++++++ .../v1/AlterTableAddPartitionSuite.scala | 15 +++++++ .../v2/AlterTableAddPartitionSuite.scala | 15 +++++++ .../sql/hive/execution/HiveDDLSuite.scala | 2 +- 8 files changed, 131 insertions(+), 19 deletions(-) diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index 18cc579e4f9ea..aaad2a3280919 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -34,6 +34,7 @@ license: | - Valid hexadecimal strings should include only allowed symbols (0-9A-Fa-f). - Valid values for `fmt` are case-insensitive `hex`, `base64`, `utf-8`, `utf8`. - Since Spark 3.4, Spark throws only `PartitionsAlreadyExistException` when it creates partitions but some of them exist already. In Spark 3.3 or earlier, Spark can throw either `PartitionsAlreadyExistException` or `PartitionAlreadyExistsException`. + - Since Spark 3.4, Spark will do validation for partition spec in ALTER PARTITION to follow the behavior of `spark.sql.storeAssignmentPolicy` which may cause an exception if type conversion fails, e.g. `ALTER TABLE .. ADD PARTITION(p='a')` if column `p` is int type. To restore the legacy behavior, set `spark.sql.legacy.skipPartitionSpecTypeValidation` to `true`. 
## Upgrading from Spark SQL 3.2 to 3.3 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index ebff9ce546d00..0a60c6b0265af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -3004,6 +3004,16 @@ object SQLConf { .booleanConf .createWithDefault(false) + val SKIP_TYPE_VALIDATION_ON_ALTER_PARTITION = + buildConf("spark.sql.legacy.skipTypeValidationOnAlterPartition") + .internal() + .doc("When true, skip validation for partition spec in ALTER PARTITION. E.g., " + + "`ALTER TABLE .. ADD PARTITION(p='a')` would work even the partition type is int. " + + s"When false, the behavior follows ${STORE_ASSIGNMENT_POLICY.key}") + .version("3.4.0") + .booleanConf + .createWithDefault(false) + val SORT_BEFORE_REPARTITION = buildConf("spark.sql.execution.sortBeforeRepartition") .internal() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/PartitioningUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/PartitioningUtils.scala index 1f5e225324efc..87f140cb3c4a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/PartitioningUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/PartitioningUtils.scala @@ -20,14 +20,47 @@ package org.apache.spark.sql.util import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.DEFAULT_PARTITION_NAME +import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal} import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{CharType, StructType, VarcharType} +import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy +import org.apache.spark.sql.types.{CharType, DataType, StringType, StructField, StructType, VarcharType} import org.apache.spark.unsafe.types.UTF8String private[sql] object PartitioningUtils { + + def castPartitionSpec(value: String, dt: DataType, conf: SQLConf): Expression = { + conf.storeAssignmentPolicy match { + // SPARK-30844: try our best to follow StoreAssignmentPolicy for static partition + // values but not completely follow because we can't do static type checking due to + // the reason that the parser has erased the type info of static partition values + // and converted them to string. + case StoreAssignmentPolicy.ANSI | StoreAssignmentPolicy.STRICT => + val cast = Cast(Literal(value), dt, Option(conf.sessionLocalTimeZone), + ansiEnabled = true) + cast.setTagValue(Cast.BY_TABLE_INSERTION, ()) + cast + case _ => + Cast(Literal(value), dt, Option(conf.sessionLocalTimeZone), + ansiEnabled = false) + } + } + + private def normalizePartitionStringValue(value: String, field: StructField): String = { + val casted = Cast( + castPartitionSpec(value, field.dataType, SQLConf.get), + StringType, + Option(SQLConf.get.sessionLocalTimeZone) + ).eval() + if (casted != null) { + casted.asInstanceOf[UTF8String].toString + } else { + null + } + } + /** * Normalize the column names in partition specification, w.r.t. the real partition column names * and case sensitivity. 
e.g., if the partition spec has a column named `monTh`, and there is a @@ -61,6 +94,14 @@ private[sql] object PartitioningUtils { case other => other } v.asInstanceOf[T] + case _ if !SQLConf.get.getConf(SQLConf.SKIP_TYPE_VALIDATION_ON_ALTER_PARTITION) && + value != null && value != DEFAULT_PARTITION_NAME => + val v = value match { + case Some(str: String) => Some(normalizePartitionStringValue(str, normalizedFiled)) + case str: String => normalizePartitionStringValue(str, normalizedFiled) + case other => other + } + v.asInstanceOf[T] case _ => value } normalizedFiled.name -> normalizedVal diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 0216503fba0f4..8b985e82963e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -48,9 +48,9 @@ import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.v2.PushedDownOperators import org.apache.spark.sql.execution.streaming.StreamingRelation -import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.{PartitioningUtils => CatalystPartitioningUtils} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.unsafe.types.UTF8String @@ -106,22 +106,8 @@ case class DataSourceAnalysis(analyzer: Analyzer) extends Rule[LogicalPlan] { None } else if (potentialSpecs.size == 1) { val partValue = potentialSpecs.head._2 - conf.storeAssignmentPolicy match { - // SPARK-30844: try our best to follow StoreAssignmentPolicy for static partition - // values but not completely follow because we can't do static type checking due to - // the reason that the parser has erased the type info of static partition values - // and converted them to string. 
- case StoreAssignmentPolicy.ANSI | StoreAssignmentPolicy.STRICT => - val cast = Cast(Literal(partValue), field.dataType, Option(conf.sessionLocalTimeZone), - ansiEnabled = true) - cast.setTagValue(Cast.BY_TABLE_INSERTION, ()) - Some(Alias(cast, field.name)()) - case _ => - val castExpression = - Cast(Literal(partValue), field.dataType, Option(conf.sessionLocalTimeZone), - ansiEnabled = false) - Some(Alias(castExpression, field.name)()) - } + Some(Alias(CatalystPartitioningUtils.castPartitionSpec( + partValue, field.dataType, conf), field.name)()) } else { throw QueryCompilationErrors.multiplePartitionColumnValuesSpecifiedError( field, potentialSpecs) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddPartitionSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddPartitionSuiteBase.scala index e113499ec685e..f414de1b87c48 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddPartitionSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddPartitionSuiteBase.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.command import java.time.{Duration, Period} +import org.apache.spark.SparkNumberFormatException import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.quoteIdentifier @@ -40,6 +41,7 @@ import org.apache.spark.sql.internal.SQLConf */ trait AlterTableAddPartitionSuiteBase extends QueryTest with DDLCommandTestUtils { override val command = "ALTER TABLE .. ADD PARTITION" + def defaultPartitionName: String test("one partition") { withNamespaceAndTable("ns", "tbl") { t => @@ -213,4 +215,46 @@ trait AlterTableAddPartitionSuiteBase extends QueryTest with DDLCommandTestUtils Row(Period.ofYears(1), Duration.ofDays(-1), "bbb"))) } } + + test("SPARK-40798: Alter partition should verify partition value") { + def shouldThrowException(policy: SQLConf.StoreAssignmentPolicy.Value): Boolean = policy match { + case SQLConf.StoreAssignmentPolicy.ANSI | SQLConf.StoreAssignmentPolicy.STRICT => + true + case SQLConf.StoreAssignmentPolicy.LEGACY => + false + } + + SQLConf.StoreAssignmentPolicy.values.foreach { policy => + withNamespaceAndTable("ns", "tbl") { t => + sql(s"CREATE TABLE $t (c int) $defaultUsing PARTITIONED BY (p int)") + + withSQLConf(SQLConf.STORE_ASSIGNMENT_POLICY.key -> policy.toString) { + if (shouldThrowException(policy)) { + checkError( + exception = intercept[SparkNumberFormatException] { + sql(s"ALTER TABLE $t ADD PARTITION (p='aaa')") + }, + errorClass = "CAST_INVALID_INPUT", + parameters = Map( + "ansiConfig" -> "\"spark.sql.ansi.enabled\"", + "expression" -> "'aaa'", + "sourceType" -> "\"STRING\"", + "targetType" -> "\"INT\""), + context = ExpectedContext( + fragment = s"ALTER TABLE $t ADD PARTITION (p='aaa')", + start = 0, + stop = 35 + t.length)) + } else { + sql(s"ALTER TABLE $t ADD PARTITION (p='aaa')") + checkPartitions(t, Map("p" -> defaultPartitionName)) + sql(s"ALTER TABLE $t DROP PARTITION (p=null)") + } + + sql(s"ALTER TABLE $t ADD PARTITION (p=null)") + checkPartitions(t, Map("p" -> defaultPartitionName)) + sql(s"ALTER TABLE $t DROP PARTITION (p=null)") + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableAddPartitionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableAddPartitionSuite.scala index 
11df5ede8bbf4..d41fd6b00f8aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableAddPartitionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableAddPartitionSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.command.v1 import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.analysis.PartitionsAlreadyExistException +import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.DEFAULT_PARTITION_NAME import org.apache.spark.sql.execution.command import org.apache.spark.sql.internal.SQLConf @@ -33,6 +34,8 @@ import org.apache.spark.sql.internal.SQLConf * `org.apache.spark.sql.hive.execution.command.AlterTableAddPartitionSuite` */ trait AlterTableAddPartitionSuiteBase extends command.AlterTableAddPartitionSuiteBase { + override def defaultPartitionName: String = DEFAULT_PARTITION_NAME + test("empty string as partition value") { withNamespaceAndTable("ns", "tbl") { t => sql(s"CREATE TABLE $t (col1 INT, p1 STRING) $defaultUsing PARTITIONED BY (p1)") @@ -157,6 +160,18 @@ trait AlterTableAddPartitionSuiteBase extends command.AlterTableAddPartitionSuit checkPartitions(t, Map("id" -> "1"), Map("id" -> "2")) } } + + test("SPARK-40798: Alter partition should verify partition value - legacy") { + withNamespaceAndTable("ns", "tbl") { t => + sql(s"CREATE TABLE $t (c int) $defaultUsing PARTITIONED BY (p int)") + + withSQLConf(SQLConf.SKIP_TYPE_VALIDATION_ON_ALTER_PARTITION.key -> "true") { + sql(s"ALTER TABLE $t ADD PARTITION (p='aaa')") + checkPartitions(t, Map("p" -> "aaa")) + sql(s"ALTER TABLE $t DROP PARTITION (p='aaa')") + } + } + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableAddPartitionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableAddPartitionSuite.scala index 835be8573fdc4..a9ab11e483fd7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableAddPartitionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableAddPartitionSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.command.v2 import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.analysis.PartitionsAlreadyExistException import org.apache.spark.sql.execution.command +import org.apache.spark.sql.internal.SQLConf /** * The class contains tests for the `ALTER TABLE .. 
ADD PARTITION` command @@ -28,6 +29,8 @@ import org.apache.spark.sql.execution.command class AlterTableAddPartitionSuite extends command.AlterTableAddPartitionSuiteBase with CommandSuiteBase { + override def defaultPartitionName: String = "null" + test("SPARK-33650: add partition into a table which doesn't support partition management") { withNamespaceAndTable("ns", "tbl", s"non_part_$catalog") { t => sql(s"CREATE TABLE $t (id bigint, data string) $defaultUsing") @@ -121,4 +124,16 @@ class AlterTableAddPartitionSuite checkPartitions(t, Map("id" -> "1"), Map("id" -> "2")) } } + + test("SPARK-40798: Alter partition should verify partition value - legacy") { + withNamespaceAndTable("ns", "tbl") { t => + sql(s"CREATE TABLE $t (c int) $defaultUsing PARTITIONED BY (p int)") + + withSQLConf(SQLConf.SKIP_TYPE_VALIDATION_ON_ALTER_PARTITION.key -> "true") { + sql(s"ALTER TABLE $t ADD PARTITION (p='aaa')") + checkPartitions(t, Map("p" -> defaultPartitionName)) + sql(s"ALTER TABLE $t DROP PARTITION (p=null)") + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 99aef0e47de9b..ef99b06a46a68 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -608,7 +608,7 @@ class HiveDDLSuite } test("SPARK-19129: drop partition with a empty string will drop the whole table") { - val df = spark.createDataFrame(Seq((0, "a"), (1, "b"))).toDF("partCol1", "name") + val df = spark.createDataFrame(Seq(("0", "a"), ("1", "b"))).toDF("partCol1", "name") df.write.mode("overwrite").partitionBy("partCol1").saveAsTable("partitionedTable") assertAnalysisError( "alter table partitionedTable drop partition(partCol1='')", From b7a88cd7ccba287f88f2aae1b9353868655658ff Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Mon, 24 Oct 2022 16:45:11 +0900 Subject: [PATCH 13/22] [SPARK-40821][SQL][SS][FOLLOWUP] Fix available version for new function window_time ### What changes were proposed in this pull request? This PR fixes the incorrect available version for new function `window_time` to 3.4.0 which is upcoming release for master branch. ### Why are the changes needed? The version information is incorrect. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? N/A Closes #38368 from HeartSaVioR/SPARK-40821-follow-up-minor-version-fix. 
Authored-by: Jungtaek Lim Signed-off-by: Jungtaek Lim --- .../org/apache/spark/sql/catalyst/expressions/WindowTime.scala | 2 +- sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WindowTime.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WindowTime.scala index effc1506d741a..1bb934cb2023c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WindowTime.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WindowTime.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.types._ A2 2021-01-01 00:00:00 2021-01-01 00:05:00 2021-01-01 00:04:59.999999 1 """, group = "datetime_funcs", - since = "3.3.0") + since = "3.4.0") // scalastyle:on line.size.limit line.contains.tab case class WindowTime(windowColumn: Expression) extends UnaryExpression diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 780bf925ad7e5..f38f24920faf9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3788,7 +3788,7 @@ object functions { * StructType { start: Timestamp, end: Timestamp } * * @group datetime_funcs - * @since 3.3.0 + * @since 3.4.0 */ def window_time(windowColumn: Column): Column = withExpr { WindowTime(windowColumn.expr) From e966c387f624da3ece6e507f95f00cf20b42f45e Mon Sep 17 00:00:00 2001 From: Luca Canali Date: Mon, 24 Oct 2022 17:04:37 +0800 Subject: [PATCH 14/22] [SPARK-34265][PYTHON][SQL] Instrument Python UDFs using SQL metrics ### What changes are proposed in this pull request? This proposes to add SQLMetrics instrumentation for Python UDF execution, including Pandas UDF, and related operations such as MapInPandas and MapInArrow. The proposed metrics are: - data sent to Python workers - data returned from Python workers - number of output rows ### Why are the changes needed? This aims at improving monitoring and performance troubleshooting of Python UDFs. In particular it is intended as an aid to answer performance-related questions such as: why is the UDF slow?, how much work has been done so far?, etc. ### Does this PR introduce _any_ user-facing change? SQL metrics are made available in the WEB UI. See the following examples: ![image1](https://issues.apache.org/jira/secure/attachment/13038693/PandasUDF_ArrowEvalPython_Metrics.png) ### How was this patch tested? Manually tested + a Python unit test and a Scala unit test have been added. 
Example code used for testing: ``` from pyspark.sql.functions import col, pandas_udf import time pandas_udf("long") def test_pandas(col1): time.sleep(0.02) return col1 * col1 spark.udf.register("test_pandas", test_pandas) spark.sql("select rand(42)*rand(51)*rand(12) col1 from range(10000000)").createOrReplaceTempView("t1") spark.sql("select max(test_pandas(col1)) from t1").collect() ``` This is used to test with more data pushed to the Python workers: ``` from pyspark.sql.functions import col, pandas_udf import time pandas_udf("long") def test_pandas(col1,col2,col3,col4,col5,col6,col7,col8,col9,col10,col11,col12,col13,col14,col15,col16,col17): time.sleep(0.02) return col1 spark.udf.register("test_pandas", test_pandas) spark.sql("select rand(42)*rand(51)*rand(12) col1 from range(10000000)").createOrReplaceTempView("t1") spark.sql("select max(test_pandas(col1,col1+1,col1+2,col1+3,col1+4,col1+5,col1+6,col1+7,col1+8,col1+9,col1+10,col1+11,col1+12,col1+13,col1+14,col1+15,col1+16)) from t1").collect() ``` This (from the Spark doc) has been used to test with MapInPandas, where the number of output rows is different from the number of input rows: ``` import pandas as pd from pyspark.sql.functions import pandas_udf, PandasUDFType df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age")) def filter_func(iterator): for pdf in iterator: yield pdf[pdf.id == 1] df.mapInPandas(filter_func, schema=df.schema).show() ``` This for testing BatchEvalPython and metrics related to data transfer (bytes sent and received): ``` from pyspark.sql.functions import udf udf def test_udf(col1, col2): return col1 * col1 spark.sql("select id, 'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa' col2 from range(10)").select(test_udf("id", "col2")).collect() ``` Closes #33559 from LucaCanali/pythonUDFKeySQLMetrics. 
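For reference, the new metrics can also be read programmatically once this patch is applied. The snippet below is a minimal sketch, not part of the patch itself: it assumes an active `SparkSession` named `spark` with pandas and pyarrow installed, uses a hypothetical UDF named `square`, and reaches the SQL status store through the same internal JVM accessors that the Python unit test added in this patch relies on.
```python
# Minimal sketch (assumes an active SparkSession `spark` with pandas/pyarrow available).
# It mirrors the approach of the unit test added in this patch: run a pandas UDF,
# then look up the Python SQL metrics of the last execution in the SQL status store.
from pyspark.sql.functions import pandas_udf

@pandas_udf("long")
def square(col1):  # hypothetical example UDF
    return col1 * col1

spark.range(10).select(square("id")).collect()

# The status store is reached through the JVM session; this is internal API.
status_store = spark._jsparkSession.sharedState().statusStore()
last_exec_id = status_store.executionsList().last().executionId()
execution_metrics = status_store.execution(last_exec_id).get().metrics().mkString()

for metric in (
    "data sent to Python workers",
    "data returned from Python workers",
    "number of output rows",
):
    assert metric in execution_metrics
```
In the Web UI, the same values appear on the Python-related nodes (for example ArrowEvalPython) in the SQL tab, as described in the web-ui.md changes below.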
Authored-by: Luca Canali Signed-off-by: Wenchen Fan --- dev/sparktestsupport/modules.py | 1 + docs/web-ui.md | 2 + .../sql/tests/test_pandas_sqlmetrics.py | 68 +++++++++++++++++++ .../python/AggregateInPandasExec.scala | 5 +- .../ApplyInPandasWithStatePythonRunner.scala | 7 +- .../python/ArrowEvalPythonExec.scala | 5 +- .../execution/python/ArrowPythonRunner.scala | 4 +- .../python/BatchEvalPythonExec.scala | 6 +- .../python/CoGroupedArrowPythonRunner.scala | 8 ++- .../python/FlatMapCoGroupsInPandasExec.scala | 6 +- .../python/FlatMapGroupsInPandasExec.scala | 5 +- .../FlatMapGroupsInPandasWithStateExec.scala | 6 +- .../sql/execution/python/MapInBatchExec.scala | 5 +- .../execution/python/PythonArrowInput.scala | 6 ++ .../execution/python/PythonArrowOutput.scala | 8 +++ .../execution/python/PythonSQLMetrics.scala | 35 ++++++++++ .../execution/python/PythonUDFRunner.scala | 10 ++- .../execution/python/WindowInPandasExec.scala | 5 +- .../streaming/statefulOperators.scala | 5 +- .../sql/execution/python/PythonUDFSuite.scala | 19 ++++++ 20 files changed, 193 insertions(+), 23 deletions(-) create mode 100644 python/pyspark/sql/tests/test_pandas_sqlmetrics.py create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonSQLMetrics.scala diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 2a427139148ad..a439b4cbbed05 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -484,6 +484,7 @@ def __hash__(self): "pyspark.sql.tests.pandas.test_pandas_udf_typehints", "pyspark.sql.tests.pandas.test_pandas_udf_typehints_with_future_annotations", "pyspark.sql.tests.pandas.test_pandas_udf_window", + "pyspark.sql.tests.test_pandas_sqlmetrics", "pyspark.sql.tests.test_readwriter", "pyspark.sql.tests.test_serde", "pyspark.sql.tests.test_session", diff --git a/docs/web-ui.md b/docs/web-ui.md index d3356ec5a43fe..e228d7fe2a987 100644 --- a/docs/web-ui.md +++ b/docs/web-ui.md @@ -406,6 +406,8 @@ Here is the list of SQL metrics: time to build hash map the time spent on building hash map ShuffledHashJoin task commit time the time spent on committing the output of a task after the writes succeed any write operation on a file-based table job commit time the time spent on committing the output of a job after the writes succeed any write operation on a file-based table + data sent to Python workers the number of bytes of serialized data sent to the Python workers ArrowEvalPython, AggregateInPandas, BatchEvalPython, FlatMapGroupsInPandas, FlatMapsCoGroupsInPandas, FlatMapsCoGroupsInPandasWithState, MapInPandas, PythonMapInArrow, WindowsInPandas + data returned from Python workers the number of bytes of serialized data received back from the Python workers ArrowEvalPython, AggregateInPandas, BatchEvalPython, FlatMapGroupsInPandas, FlatMapsCoGroupsInPandas, FlatMapsCoGroupsInPandasWithState, MapInPandas, PythonMapInArrow, WindowsInPandas ## Structured Streaming Tab diff --git a/python/pyspark/sql/tests/test_pandas_sqlmetrics.py b/python/pyspark/sql/tests/test_pandas_sqlmetrics.py new file mode 100644 index 0000000000000..d182bafd8b543 --- /dev/null +++ b/python/pyspark/sql/tests/test_pandas_sqlmetrics.py @@ -0,0 +1,68 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest +from typing import cast + +from pyspark.sql.functions import pandas_udf +from pyspark.testing.sqlutils import ( + ReusedSQLTestCase, + have_pandas, + have_pyarrow, + pandas_requirement_message, + pyarrow_requirement_message, +) + + +@unittest.skipIf( + not have_pandas or not have_pyarrow, + cast(str, pandas_requirement_message or pyarrow_requirement_message), +) +class PandasSQLMetrics(ReusedSQLTestCase): + def test_pandas_sql_metrics_basic(self): + # SPARK-34265: Instrument Python UDFs using SQL metrics + + python_sql_metrics = [ + "data sent to Python workers", + "data returned from Python workers", + "number of output rows", + ] + + @pandas_udf("long") + def test_pandas(col1): + return col1 * col1 + + self.spark.range(10).select(test_pandas("id")).collect() + + statusStore = self.spark._jsparkSession.sharedState().statusStore() + lastExecId = statusStore.executionsList().last().executionId() + executionMetrics = statusStore.execution(lastExecId).get().metrics().mkString() + + for metric in python_sql_metrics: + self.assertIn(metric, executionMetrics) + + +if __name__ == "__main__": + from pyspark.sql.tests.test_pandas_sqlmetrics import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index 2f85149ee8e13..6a8b197742d1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -46,7 +46,7 @@ case class AggregateInPandasExec( udfExpressions: Seq[PythonUDF], resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends UnaryExecNode { + extends UnaryExecNode with PythonSQLMetrics { override val output: Seq[Attribute] = resultExpressions.map(_.toAttribute) @@ -163,7 +163,8 @@ case class AggregateInPandasExec( argOffsets, aggInputSchema, sessionLocalTimeZone, - pythonRunnerConf).compute(projectedRowIter, context.partitionId(), context) + pythonRunnerConf, + pythonMetrics).compute(projectedRowIter, context.partitionId(), context) val joinedAttributes = groupingExpressions.map(_.toAttribute) ++ udfExpressions.map(_.resultAttribute) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala index bd8c72029dcbe..f3531668c8e65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.api.python.PythonSQLUtils import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.python.ApplyInPandasWithStatePythonRunner.{InType, OutType, OutTypeForState, STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER} import org.apache.spark.sql.execution.python.ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA import org.apache.spark.sql.execution.streaming.GroupStateImpl @@ -58,7 +59,8 @@ class ApplyInPandasWithStatePythonRunner( stateEncoder: ExpressionEncoder[Row], keySchema: StructType, outputSchema: StructType, - stateValueSchema: StructType) + stateValueSchema: StructType, + val pythonMetrics: Map[String, SQLMetric]) extends BasePythonRunner[InType, OutType](funcs, evalType, argOffsets) with PythonArrowInput[InType] with PythonArrowOutput[OutType] { @@ -116,6 +118,7 @@ class ApplyInPandasWithStatePythonRunner( val w = new ApplyInPandasWithStateWriter(root, writer, arrowMaxRecordsPerBatch) while (inputIterator.hasNext) { + val startData = dataOut.size() val (keyRow, groupState, dataIter) = inputIterator.next() assert(dataIter.hasNext, "should have at least one data row!") w.startNewGroup(keyRow, groupState) @@ -126,6 +129,8 @@ class ApplyInPandasWithStatePythonRunner( } w.finalizeGroup() + val deltaData = dataOut.size() - startData + pythonMetrics("pythonDataSent") += deltaData } w.finalizeData() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index 096712cf93529..b11dd4947af6b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -61,7 +61,7 @@ private[spark] class BatchIterator[T](iter: Iterator[T], batchSize: Int) */ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute], child: SparkPlan, evalType: Int) - extends EvalPythonExec { + extends EvalPythonExec with PythonSQLMetrics { private val batchSize = conf.arrowMaxRecordsPerBatch private val sessionLocalTimeZone = conf.sessionLocalTimeZone @@ -85,7 +85,8 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute] argOffsets, schema, sessionLocalTimeZone, - pythonRunnerConf).compute(batchIter, context.partitionId(), context) + pythonRunnerConf, + pythonMetrics).compute(batchIter, context.partitionId(), context) columnarBatchIter.flatMap { batch => val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 8467feb91d144..dbafc444281e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.python import org.apache.spark.api.python._ import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.metric.SQLMetric import 
org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch @@ -32,7 +33,8 @@ class ArrowPythonRunner( argOffsets: Array[Array[Int]], protected override val schema: StructType, protected override val timeZoneId: String, - protected override val workerConf: Map[String, String]) + protected override val workerConf: Map[String, String], + val pythonMetrics: Map[String, SQLMetric]) extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](funcs, evalType, argOffsets) with BasicPythonArrowInput with BasicPythonArrowOutput { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala index 10f7966b93d1a..ca7ca2e2f80a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.types.{StructField, StructType} * A physical plan that evaluates a [[PythonUDF]] */ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute], child: SparkPlan) - extends EvalPythonExec { + extends EvalPythonExec with PythonSQLMetrics { protected override def evaluate( funcs: Seq[ChainedPythonFunctions], @@ -77,7 +77,8 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute] }.grouped(100).map(x => pickle.dumps(x.toArray)) // Output iterator for results from Python. - val outputIterator = new PythonUDFRunner(funcs, PythonEvalType.SQL_BATCHED_UDF, argOffsets) + val outputIterator = + new PythonUDFRunner(funcs, PythonEvalType.SQL_BATCHED_UDF, argOffsets, pythonMetrics) .compute(inputIterator, context.partitionId(), context) val unpickle = new Unpickler @@ -94,6 +95,7 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute] val unpickledBatch = unpickle.loads(pickedResult) unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala }.map { result => + pythonMetrics("pythonNumRowsReceived") += 1 if (udfs.length == 1) { // fast path for single UDF mutableRow(0) = fromJava(result) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala index 2661896ececc9..1df9f37188a7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala @@ -27,6 +27,7 @@ import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonRDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.arrow.ArrowWriter +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.ArrowUtils @@ -45,7 +46,8 @@ class CoGroupedArrowPythonRunner( leftSchema: StructType, rightSchema: StructType, timeZoneId: String, - conf: Map[String, String]) + conf: Map[String, String], + val pythonMetrics: Map[String, SQLMetric]) extends BasePythonRunner[ (Iterator[InternalRow], Iterator[InternalRow]), ColumnarBatch](funcs, evalType, argOffsets) with BasicPythonArrowOutput { @@ -77,10 +79,14 @@ class 
CoGroupedArrowPythonRunner( // For each we first send the number of dataframes in each group then send // first df, then send second df. End of data is marked by sending 0. while (inputIterator.hasNext) { + val startData = dataOut.size() dataOut.writeInt(2) val (nextLeft, nextRight) = inputIterator.next() writeGroup(nextLeft, leftSchema, dataOut, "left") writeGroup(nextRight, rightSchema, dataOut, "right") + + val deltaData = dataOut.size() - startData + pythonMetrics("pythonDataSent") += deltaData } dataOut.writeInt(0) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala index b39787b12a484..629df51e18ae3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala @@ -54,7 +54,7 @@ case class FlatMapCoGroupsInPandasExec( output: Seq[Attribute], left: SparkPlan, right: SparkPlan) - extends SparkPlan with BinaryExecNode { + extends SparkPlan with BinaryExecNode with PythonSQLMetrics { private val sessionLocalTimeZone = conf.sessionLocalTimeZone private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) @@ -77,7 +77,6 @@ case class FlatMapCoGroupsInPandasExec( } override protected def doExecute(): RDD[InternalRow] = { - val (leftDedup, leftArgOffsets) = resolveArgOffsets(left.output, leftGroup) val (rightDedup, rightArgOffsets) = resolveArgOffsets(right.output, rightGroup) @@ -97,7 +96,8 @@ case class FlatMapCoGroupsInPandasExec( StructType.fromAttributes(leftDedup), StructType.fromAttributes(rightDedup), sessionLocalTimeZone, - pythonRunnerConf) + pythonRunnerConf, + pythonMetrics) executePython(data, output, runner) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index f0e815e966e79..271ccdb6b271f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -50,7 +50,7 @@ case class FlatMapGroupsInPandasExec( func: Expression, output: Seq[Attribute], child: SparkPlan) - extends SparkPlan with UnaryExecNode { + extends SparkPlan with UnaryExecNode with PythonSQLMetrics { private val sessionLocalTimeZone = conf.sessionLocalTimeZone private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) @@ -89,7 +89,8 @@ case class FlatMapGroupsInPandasExec( Array(argOffsets), StructType.fromAttributes(dedupAttributes), sessionLocalTimeZone, - pythonRunnerConf) + pythonRunnerConf, + pythonMetrics) executePython(data, output, runner) }} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala index 09123344c2e2c..3b096f07241fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala @@ -62,7 +62,8 @@ case class FlatMapGroupsInPandasWithStateExec( timeoutConf: GroupStateTimeout, batchTimestampMs: Option[Long], eventTimeWatermark: Option[Long], - child: 
SparkPlan) extends UnaryExecNode with FlatMapGroupsWithStateExecBase { + child: SparkPlan) + extends UnaryExecNode with PythonSQLMetrics with FlatMapGroupsWithStateExecBase { // TODO(SPARK-40444): Add the support of initial state. override protected val initialStateDeserializer: Expression = null @@ -166,7 +167,8 @@ case class FlatMapGroupsInPandasWithStateExec( stateEncoder.asInstanceOf[ExpressionEncoder[Row]], groupingAttributes.toStructType, outAttributes.toStructType, - stateType) + stateType, + pythonMetrics) val context = TaskContext.get() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala index d25c138354077..450891c69483a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} * This is somewhat similar with [[FlatMapGroupsInPandasExec]] and * `org.apache.spark.sql.catalyst.plans.logical.MapPartitionsInRWithArrow` */ -trait MapInBatchExec extends UnaryExecNode { +trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics { protected val func: Expression protected val pythonEvalType: Int @@ -75,7 +75,8 @@ trait MapInBatchExec extends UnaryExecNode { argOffsets, StructType(StructField("struct", outputTypes) :: Nil), sessionLocalTimeZone, - pythonRunnerConf).compute(batchIter, context.partitionId(), context) + pythonRunnerConf, + pythonMetrics).compute(batchIter, context.partitionId(), context) val unsafeProj = UnsafeProjection.create(output, output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala index bf66791183ece..5a0541d11cbe6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala @@ -26,6 +26,7 @@ import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.api.python.{BasePythonRunner, PythonRDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.arrow.ArrowWriter +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.util.Utils @@ -41,6 +42,8 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] => protected val timeZoneId: String + protected def pythonMetrics: Map[String, SQLMetric] + protected def writeIteratorToArrowStream( root: VectorSchemaRoot, writer: ArrowStreamWriter, @@ -115,6 +118,7 @@ private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[In val arrowWriter = ArrowWriter.create(root) while (inputIterator.hasNext) { + val startData = dataOut.size() val nextBatch = inputIterator.next() while (nextBatch.hasNext) { @@ -124,6 +128,8 @@ private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[In arrowWriter.finish() writer.writeBatch() arrowWriter.reset() + val deltaData = dataOut.size() - startData + pythonMetrics("pythonDataSent") += deltaData } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala index 339f114539c28..c12c690f776a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala @@ -27,6 +27,7 @@ import org.apache.arrow.vector.ipc.ArrowStreamReader import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.api.python.{BasePythonRunner, SpecialLengths} +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} @@ -37,6 +38,8 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, Column */ private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[_, OUT] => + protected def pythonMetrics: Map[String, SQLMetric] + protected def handleMetadataAfterExec(stream: DataInputStream): Unit = { } protected def deserializeColumnarBatch(batch: ColumnarBatch, schema: StructType): OUT @@ -82,10 +85,15 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[ } try { if (reader != null && batchLoaded) { + val bytesReadStart = reader.bytesRead() batchLoaded = reader.loadNextBatch() if (batchLoaded) { val batch = new ColumnarBatch(vectors) + val rowCount = root.getRowCount batch.setNumRows(root.getRowCount) + val bytesReadEnd = reader.bytesRead() + pythonMetrics("pythonNumRowsReceived") += rowCount + pythonMetrics("pythonDataReceived") += bytesReadEnd - bytesReadStart deserializeColumnarBatch(batch, schema) } else { reader.close(false) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonSQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonSQLMetrics.scala new file mode 100644 index 0000000000000..a748c1bc10082 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonSQLMetrics.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.python + +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.SQLMetrics + +private[sql] trait PythonSQLMetrics { self: SparkPlan => + + val pythonMetrics = Map( + "pythonDataSent" -> SQLMetrics.createSizeMetric(sparkContext, + "data sent to Python workers"), + "pythonDataReceived" -> SQLMetrics.createSizeMetric(sparkContext, + "data returned from Python workers"), + "pythonNumRowsReceived" -> SQLMetrics.createMetric(sparkContext, + "number of output rows") + ) + + override lazy val metrics = pythonMetrics +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala index d1109d251c284..09e06b55df3e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala @@ -23,6 +23,7 @@ import java.util.concurrent.atomic.AtomicBoolean import org.apache.spark._ import org.apache.spark.api.python._ +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf /** @@ -31,7 +32,8 @@ import org.apache.spark.sql.internal.SQLConf class PythonUDFRunner( funcs: Seq[ChainedPythonFunctions], evalType: Int, - argOffsets: Array[Array[Int]]) + argOffsets: Array[Array[Int]], + pythonMetrics: Map[String, SQLMetric]) extends BasePythonRunner[Array[Byte], Array[Byte]]( funcs, evalType, argOffsets) { @@ -50,8 +52,13 @@ class PythonUDFRunner( } protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + val startData = dataOut.size() + PythonRDD.writeIteratorToStream(inputIterator, dataOut) dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) + + val deltaData = dataOut.size() - startData + pythonMetrics("pythonDataSent") += deltaData } } } @@ -77,6 +84,7 @@ class PythonUDFRunner( case length if length > 0 => val obj = new Array[Byte](length) stream.readFully(obj) + pythonMetrics("pythonDataReceived") += length obj case 0 => Array.emptyByteArray case SpecialLengths.TIMING_DATA => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala index ccb1ed92525d1..dcaffed89cca9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala @@ -84,7 +84,7 @@ case class WindowInPandasExec( partitionSpec: Seq[Expression], orderSpec: Seq[SortOrder], child: SparkPlan) - extends WindowExecBase { + extends WindowExecBase with PythonSQLMetrics { /** * Helper functions and data structures for window bounds @@ -375,7 +375,8 @@ case class WindowInPandasExec( argOffsets, pythonInputSchema, sessionLocalTimeZone, - pythonRunnerConf).compute(pythonInput, context.partitionId(), context) + pythonRunnerConf, + pythonMetrics).compute(pythonInput, context.partitionId(), context) val joined = new JoinedRow diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 2b8fc6515618d..b540f9f00939a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.python.PythonSQLMetrics import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.streaming.{OutputMode, StateOperatorProgress} import org.apache.spark.sql.types._ @@ -93,7 +94,7 @@ trait StateStoreReader extends StatefulOperator { } /** An operator that writes to a StateStore. */ -trait StateStoreWriter extends StatefulOperator { self: SparkPlan => +trait StateStoreWriter extends StatefulOperator with PythonSQLMetrics { self: SparkPlan => override lazy val metrics = statefulOperatorCustomMetrics ++ Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), @@ -109,7 +110,7 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan => "numShufflePartitions" -> SQLMetrics.createMetric(sparkContext, "number of shuffle partitions"), "numStateStoreInstances" -> SQLMetrics.createMetric(sparkContext, "number of state store instances") - ) ++ stateStoreCustomMetrics + ) ++ stateStoreCustomMetrics ++ pythonMetrics /** * Get the progress made by this stateful operator after execution. This should be called in diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala index 70784c20a8eb3..7850b2d79b045 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala @@ -84,4 +84,23 @@ class PythonUDFSuite extends QueryTest with SharedSparkSession { checkAnswer(actual, expected) } + + test("SPARK-34265: Instrument Python UDF execution using SQL Metrics") { + + val pythonSQLMetrics = List( + "data sent to Python workers", + "data returned from Python workers", + "number of output rows") + + val df = base.groupBy(pythonTestUDF(base("a") + 1)) + .agg(pythonTestUDF(pythonTestUDF(base("a") + 1))) + df.count() + + val statusStore = spark.sharedState.statusStore + val lastExecId = statusStore.executionsList.last.executionId + val executionMetrics = statusStore.execution(lastExecId).get.metrics.mkString + for (metric <- pythonSQLMetrics) { + assert(executionMetrics.contains(metric)) + } + } } From 6edcafc8aa114f53cb7c05666aacffd21b21dcaa Mon Sep 17 00:00:00 2001 From: panbingkun Date: Mon, 24 Oct 2022 15:13:58 +0500 Subject: [PATCH 15/22] [SPARK-40891][SQL][TESTS] Check error classes in TableIdentifierParserSuite ### What changes were proposed in this pull request? his PR aims to replace 'intercept' with 'Check error classes' in TableIdentifierParserSuite. ### Why are the changes needed? The changes improve the error framework. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? By running the modified test suite: ``` $ build/sbt "test:testOnly *TableIdentifierParserSuite" ``` Closes #38364 from panbingkun/SPARK-40891. 
Authored-by: panbingkun Signed-off-by: Max Gekk --- .../parser/TableIdentifierParserSuite.scala | 26 ++++++++++++++----- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala index c2b240b3c496e..62557ead1d2ee 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala @@ -290,8 +290,17 @@ class TableIdentifierParserSuite extends SQLKeywordUtils { assert(TableIdentifier("q", Option("d")) === parseTableIdentifier("d.q")) // Illegal names. - Seq("", "d.q.g", "t:", "${some.var.x}", "tab:1").foreach { identifier => - intercept[ParseException](parseTableIdentifier(identifier)) + Seq( + "" -> ("PARSE_EMPTY_STATEMENT", Map.empty[String, String]), + "d.q.g" -> ("PARSE_SYNTAX_ERROR", Map("error" -> "'.'", "hint" -> "")), + "t:" -> ("PARSE_SYNTAX_ERROR", Map("error" -> "':'", "hint" -> ": extra input ':'")), + "${some.var.x}" -> ("PARSE_SYNTAX_ERROR", Map("error" -> "'$'", "hint" -> "")), + "tab:1" -> ("PARSE_SYNTAX_ERROR", Map("error" -> "':'", "hint" -> "")) + ).foreach { case (identifier, (errorClass, parameters)) => + checkError( + exception = intercept[ParseException](parseTableIdentifier(identifier)), + errorClass = errorClass, + parameters = parameters) } } @@ -307,10 +316,10 @@ class TableIdentifierParserSuite extends SQLKeywordUtils { withSQLConf(SQLConf.ANSI_ENABLED.key -> "true", SQLConf.ENFORCE_RESERVED_KEYWORDS.key -> "true") { reservedKeywordsInAnsiMode.foreach { keyword => - val errMsg = intercept[ParseException] { - parseTableIdentifier(keyword) - }.getMessage - assert(errMsg.contains("Syntax error at or near")) + checkError( + exception = intercept[ParseException](parseTableIdentifier(keyword)), + errorClass = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> s"'$keyword'", "hint" -> "")) assert(TableIdentifier(keyword) === parseTableIdentifier(s"`$keyword`")) assert(TableIdentifier(keyword, Option("db")) === parseTableIdentifier(s"db.`$keyword`")) } @@ -363,7 +372,10 @@ class TableIdentifierParserSuite extends SQLKeywordUtils { val complexName = TableIdentifier("`weird`table`name", Some("`d`b`1")) assert(complexName === parseTableIdentifier("```d``b``1`.```weird``table``name`")) assert(complexName === parseTableIdentifier(complexName.quotedString)) - intercept[ParseException](parseTableIdentifier(complexName.unquotedString)) + checkError( + exception = intercept[ParseException](parseTableIdentifier(complexName.unquotedString)), + errorClass = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "'b'", "hint" -> "")) // Table identifier contains continuous backticks should be treated correctly. val complexName2 = TableIdentifier("x``y", Some("d``b")) assert(complexName2 === parseTableIdentifier(complexName2.quotedString)) From e2e449e83cdc743e7701ed93eae782ef9042f2d1 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Mon, 24 Oct 2022 19:45:21 +0900 Subject: [PATCH 16/22] [SPARK-40897][DOCS] Add some PySpark APIs to References ### What changes were proposed in this pull request? add following missing APIs to references: - StorageLevel.MEMORY_AND_DISK_DESER - TaskContext.cpus - BarrierTaskContext.cpus ### Why are the changes needed? they were missing in Reference ### Does this PR introduce _any_ user-facing change? 
Yes ### How was this patch tested? Manually checked; for `BarrierTaskContext.cpus`: ``` In [10]: from pyspark import BarrierTaskContext In [11]: rdd = spark.sparkContext.parallelize([1]) In [12]: rdd.barrier().mapPartitions(lambda _: [BarrierTaskContext.get().cpus()]).collect() Out[12]: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1] ``` Closes #38373 from zhengruifeng/py_doc_missing. Authored-by: Ruifeng Zheng Signed-off-by: Hyukjin Kwon --- python/docs/source/reference/pyspark.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/docs/source/reference/pyspark.rst b/python/docs/source/reference/pyspark.rst index c3afae10ddb61..ec3df07163921 100644 --- a/python/docs/source/reference/pyspark.rst +++ b/python/docs/source/reference/pyspark.rst @@ -262,10 +262,12 @@ Management StorageLevel.DISK_ONLY_3 StorageLevel.MEMORY_AND_DISK StorageLevel.MEMORY_AND_DISK_2 + StorageLevel.MEMORY_AND_DISK_DESER StorageLevel.MEMORY_ONLY StorageLevel.MEMORY_ONLY_2 StorageLevel.OFF_HEAP TaskContext.attemptNumber + TaskContext.cpus TaskContext.get TaskContext.getLocalProperty TaskContext.partitionId @@ -277,6 +279,7 @@ Management BarrierTaskContext.allGather BarrierTaskContext.attemptNumber BarrierTaskContext.barrier + BarrierTaskContext.cpus BarrierTaskContext.get BarrierTaskContext.getLocalProperty BarrierTaskContext.getTaskInfos From 363b8539059183e422f288acf58d2b043c2fc603 Mon Sep 17 00:00:00 2001 From: Cheng Pan Date: Mon, 24 Oct 2022 08:27:44 -0500 Subject: [PATCH 17/22] [SPARK-39977][BUILD] Remove unnecessary guava exclusion from jackson-module-scala ### What changes were proposed in this pull request? Remove unnecessary guava exclusion from jackson-module-scala ### Why are the changes needed? The exclusion was added in SPARK-6149; recent versions of jackson-module-scala do not depend on guava any more, so we can remove this exclusion. https://mvnrepository.com/artifact/com.fasterxml.jackson.module/jackson-module-scala_2.12/2.13.3 ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing UT. Closes #37405 from pan3793/SPARK-39977. Authored-by: Cheng Pan Signed-off-by: Sean Owen --- pom.xml | 8 -------- 1 file changed, 8 deletions(-) diff --git a/pom.xml b/pom.xml index c7efd8f9f61bf..77b89ddb76ac3 100644 --- a/pom.xml +++ b/pom.xml @@ -990,18 +990,10 @@ jackson-datatype-jsr310 ${fasterxml.jackson.version} - com.fasterxml.jackson.module jackson-module-scala_${scala.binary.version} ${fasterxml.jackson.version} - - - com.google.guava - guava - - com.fasterxml.jackson.module From 880d9bb3fcb69001512886496f2988ed17cc4c50 Mon Sep 17 00:00:00 2001 From: Phil Date: Mon, 24 Oct 2022 08:28:54 -0500 Subject: [PATCH 18/22] [SPARK-40739][SPARK-40738] Fixes for cygwin/msys2/mingw sbt build and bash scripts This fixes two problems that affect development in a Windows shell environment, such as `cygwin` or `msys2`. ### The fixed build error Running `./build/sbt packageBin` from a Windows cygwin `bash` session fails. This occurs if `WSL` is installed, because `project\SparkBuild.scala` creates a `bash` process, but `WSL bash` is called, even though `cygwin bash` appears earlier in the `PATH`. In addition, file path arguments to bash contain backslashes. The fix is to ensure that the correct `bash` is called, and that arguments passed to `bash` are passed with forward slashes rather than **backslashes**. ### The build error message: ```bash ./build/sbt packageBin ```
[info] compiling 9 Java sources to C:\Users\philwalk\workspace\spark\common\sketch\target\scala-2.12\classes ...
/bin/bash: C:Usersphilwalkworkspacesparkcore/../build/spark-build-info: No such file or directory
[info] compiling 1 Scala source to C:\Users\philwalk\workspace\spark\tools\target\scala-2.12\classes ...
[info] compiling 5 Scala sources to C:\Users\philwalk\workspace\spark\mllib-local\target\scala-2.12\classes ...
[info] Compiling 5 protobuf files to C:\Users\philwalk\workspace\spark\connector\connect\target\scala-2.12\src_managed\main
[error] stack trace is suppressed; run last core / Compile / managedResources for the full output
[error] (core / Compile / managedResources) Nonzero exit value: 127
[error] Total time: 42 s, completed Oct 8, 2022, 4:49:12 PM
sbt:spark-parent>
sbt:spark-parent> last core /Compile /managedResources
last core /Compile /managedResources
[error] java.lang.RuntimeException: Nonzero exit value: 127
[error]         at scala.sys.package$.error(package.scala:30)
[error]         at scala.sys.process.ProcessBuilderImpl$AbstractBuilder.slurp(ProcessBuilderImpl.scala:138)
[error]         at scala.sys.process.ProcessBuilderImpl$AbstractBuilder.$bang$bang(ProcessBuilderImpl.scala:108)
[error]         at Core$.$anonfun$settings$4(SparkBuild.scala:604)
[error]         at scala.Function1.$anonfun$compose$1(Function1.scala:49)
[error]         at sbt.internal.util.$tilde$greater.$anonfun$$u2219$1(TypeFunctions.scala:62)
[error]         at sbt.std.Transform$$anon$4.work(Transform.scala:68)
[error]         at sbt.Execute.$anonfun$submit$2(Execute.scala:282)
[error]         at sbt.internal.util.ErrorHandling$.wideConvert(ErrorHandling.scala:23)
[error]         at sbt.Execute.work(Execute.scala:291)
[error]         at sbt.Execute.$anonfun$submit$1(Execute.scala:282)
[error]         at sbt.ConcurrentRestrictions$$anon$4.$anonfun$submitValid$1(ConcurrentRestrictions.scala:265)
[error]         at sbt.CompletionService$$anon$2.call(CompletionService.scala:64)
[error]         at java.base/java.util.concurrent.FutureTask.run(FutureTask.java:264)
[error]         at java.base/java.util.concurrent.Executors$RunnableAdapter.call(Executors.java:515)
[error]         at java.base/java.util.concurrent.FutureTask.run(FutureTask.java:264)
[error]         at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
[error]         at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
[error]         at java.base/java.lang.Thread.run(Thread.java:834)
[error] (core / Compile / managedResources) Nonzero exit value: 127
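For reference, the sketch below isolates the `bash`-resolution logic this patch adds to `project/SparkBuild.scala` (object and method names here are illustrative, not part of the actual change; it assumes a Windows shell environment where both `uname` and the Windows `where` command are on `PATH`):

```scala
import scala.sys.process.Process

object WindowsBashResolver {
  // Normalized shell environment name, e.g. "cygwin", "msys2", "mingw64".
  def buildenv: String =
    Process(Seq("uname")).!!.trim.replaceFirst("[^A-Za-z0-9].*", "").toLowerCase

  // First `bash` found on the Windows PATH, with backslashes converted to forward slashes.
  def bashpath: String =
    Process(Seq("where", "bash")).!!.split("[\r\n]+").head.replace('\\', '/')

  // Resolve the cygwin/msys2 bash explicitly so an installed WSL bash is not picked up instead.
  def bash: String = buildenv match {
    case "cygwin" | "msys2" | "mingw64" | "clang64" => bashpath
    case _ => "bash"
  }
}
```

On non-Windows environments the match falls through to plain `bash`, so the resolution is a no-op there.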
### bash scripts fail when run from `cygwin` or `msys2` The other fix addresses problems that prevent the `bash` scripts (`spark-shell`, `spark-submit`, etc.) from being used in Windows `SHELL` environments. The problem is that the bash version of `spark-class` fails in a Windows shell environment, because `launcher/src/main/java/org/apache/spark/launcher/Main.java` does not follow the convention expected by `spark-class` and also appends CR to line endings. The resulting error message is not helpful. There are two parts to this fix: 1. modify `Main.java` to treat a `SHELL` session on Windows as a `bash` session 2. remove the appended CR character when parsing the output produced by `Main.java` ### Does this PR introduce _any_ user-facing change? These changes should NOT affect anyone who is not trying to build or run bash scripts from a Windows SHELL environment. ### How was this patch tested? Manual tests were performed to verify both changes. ### Related JIRA issues The following two JIRA issues were created. Both are fixed by this PR and linked to it. - Bug SPARK-40739 "sbt packageBin" fails in cygwin or other windows bash session - Bug SPARK-40738 spark-shell fails with "bad array" Closes #38228 from philwalk/windows-shell-env-fixes. Authored-by: Phil Signed-off-by: Sean Owen --- bin/spark-class | 3 ++- bin/spark-class2.cmd | 2 ++ build/spark-build-info | 2 +- .../src/main/java/org/apache/spark/launcher/Main.java | 6 ++++-- project/SparkBuild.scala | 9 ++++++++- 5 files changed, 17 insertions(+), 5 deletions(-) diff --git a/bin/spark-class b/bin/spark-class index c1461a7712289..fc343ca29fddd 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -77,7 +77,8 @@ set +o posix CMD=() DELIM=$'\n' CMD_START_FLAG="false" -while IFS= read -d "$DELIM" -r ARG; do +while IFS= read -d "$DELIM" -r _ARG; do + ARG=${_ARG//$'\r'} if [ "$CMD_START_FLAG" == "true" ]; then CMD+=("$ARG") else diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index 68b271d1d05d9..800ec0c02c22f 100755 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -69,6 +69,8 @@ rem SPARK-28302: %RANDOM% would return the same number if we call it instantly a rem so we should make it sure to generate unique file to avoid process collision of writing into rem the same file concurrently.
if exist %LAUNCHER_OUTPUT% goto :gen +rem unset SHELL to indicate non-bash environment to launcher/Main +set SHELL= "%RUNNER%" -Xmx128m -cp "%LAUNCH_CLASSPATH%" org.apache.spark.launcher.Main %* > %LAUNCHER_OUTPUT% for /f "tokens=*" %%i in (%LAUNCHER_OUTPUT%) do ( set SPARK_CMD=%%i diff --git a/build/spark-build-info b/build/spark-build-info index eb0e3d730e23e..26157e8cf8cb1 100755 --- a/build/spark-build-info +++ b/build/spark-build-info @@ -24,7 +24,7 @@ RESOURCE_DIR="$1" mkdir -p "$RESOURCE_DIR" -SPARK_BUILD_INFO="${RESOURCE_DIR}"/spark-version-info.properties +SPARK_BUILD_INFO="${RESOURCE_DIR%/}"/spark-version-info.properties echo_build_properties() { echo version=$1 diff --git a/launcher/src/main/java/org/apache/spark/launcher/Main.java b/launcher/src/main/java/org/apache/spark/launcher/Main.java index e1054c7060f12..6501fc1764c25 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/Main.java +++ b/launcher/src/main/java/org/apache/spark/launcher/Main.java @@ -87,7 +87,9 @@ public static void main(String[] argsArray) throws Exception { cmd = buildCommand(builder, env, printLaunchCommand); } - if (isWindows()) { + // test for shell environments, to enable non-Windows treatment of command line prep + boolean shellflag = !isEmpty(System.getenv("SHELL")); + if (isWindows() && !shellflag) { System.out.println(prepareWindowsCommand(cmd, env)); } else { // A sequence of NULL character and newline separates command-strings and others. @@ -96,7 +98,7 @@ public static void main(String[] argsArray) throws Exception { // In bash, use NULL as the arg separator since it cannot be used in an argument. List bashCmd = prepareBashCommand(cmd, env); for (String c : bashCmd) { - System.out.print(c); + System.out.print(c.replaceFirst("\r$","")); System.out.print('\0'); } } diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index cc103e4ab00ac..33883a2efaa51 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -599,11 +599,18 @@ object SparkParallelTestGrouping { object Core { import scala.sys.process.Process + def buildenv = Process(Seq("uname")).!!.trim.replaceFirst("[^A-Za-z0-9].*", "").toLowerCase + def bashpath = Process(Seq("where", "bash")).!!.split("[\r\n]+").head.replace('\\', '/') lazy val settings = Seq( (Compile / resourceGenerators) += Def.task { val buildScript = baseDirectory.value + "/../build/spark-build-info" val targetDir = baseDirectory.value + "/target/extra-resources/" - val command = Seq("bash", buildScript, targetDir, version.value) + // support Windows build under cygwin/mingw64, etc + val bash = buildenv match { + case "cygwin" | "msys2" | "mingw64" | "clang64" => bashpath + case _ => "bash" + } + val command = Seq(bash, buildScript, targetDir, version.value) Process(command).!! val propsFile = baseDirectory.value / "target" / "extra-resources" / "spark-version-info.properties" Seq(propsFile) From 05ad1027a897b63a9f82f7131f6a024732a7e64d Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Mon, 24 Oct 2022 08:30:34 -0500 Subject: [PATCH 19/22] [SPARK-40391][SQL][TESTS][FOLLOWUP] Change to use `mockito-inline` instead of manually write MockMaker ### What changes were proposed in this pull request? This pr aims use `mockito-inline` instead of manually write `MockMaker` ### Why are the changes needed? 
`mockito-inline` is the recommended [way](https://javadoc.io/doc/org.mockito/mockito-core/latest/org/mockito/Mockito.html#39) to use Mockito for mocking final types, enums, and final methods, and the `mllib` and `mllib-local` modules already use `mockito-inline`. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? - Pass GitHub Actions - Manual test: run `build/sbt clean "sql/testOnly *QueryExecutionErrorsSuite"` with Java 8u352, 11.0.17 and 17.0.5; all 3 Java versions passed Closes #38372 from LuciferYang/SPARK-40391. Authored-by: yangjie01 Signed-off-by: Sean Owen --- sql/core/pom.xml | 5 +++++ .../org.mockito.plugins.MockMaker | 18 ------------------ 2 files changed, 5 insertions(+), 18 deletions(-) delete mode 100644 sql/core/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 7203fc591081a..cfcf7455ad030 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -193,6 +193,11 @@ mockito-core test
+ + org.mockito + mockito-inline + test + org.seleniumhq.selenium selenium-java diff --git a/sql/core/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker b/sql/core/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker deleted file mode 100644 index eb074c6ae3fca..0000000000000 --- a/sql/core/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker +++ /dev/null @@ -1,18 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -mock-maker-inline From 4ba7ce2136ab4dcdb5e4b28e52e4c3864f5f8e3f Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Mon, 24 Oct 2022 23:09:30 +0800 Subject: [PATCH 20/22] [SPARK-40857][CONNECT] Enable configurable GPRC Interceptors ### What changes were proposed in this pull request? To be able to modify the incoming requests for the Spark Connect GRPC service, for example to be able to translate metadata from the HTTP/2 request to values in the proto message the GRPC service needs to be configured using an interceptor. This patch adds two ways to configure interceptors for the GRPC service. First, we can now configure interceptors in the `SparkConnectInterceptorRegistry` by adding a value to the `interceptorChain` like in the example below: ``` object SparkConnectInterceptorRegistry { // Contains the list of configured interceptors. private lazy val interceptorChain: Seq[InterceptorBuilder] = Seq( interceptor[LoggingInterceptor](classOf[LoggingInterceptor]) ) // ... } ``` The second way to configure interceptors is by configuring them using Spark configuration values at startup. Therefore a new config key has been added called: `spark.connect.grpc.interceptor.classes`. This config value contains a comma-separated list of classes that are added as interceptors to the system. ``` ./bin/pyspark --conf spark.connect.grpc.interceptor.classes=com.my.important.LoggingInterceptor ``` During startup all of the interceptors are added in order to the `NettyServerBuilder`. ``` // Add all registered interceptors to the server builder. SparkConnectInterceptorRegistry.chainInterceptors(sb) ``` ### Why are the changes needed? Provide a configurable and extensible way to configure interceptors. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Unit Tests Closes #38320 from grundprinzip/SPARK-40857. 
Lead-authored-by: Martin Grund Co-authored-by: Martin Grund Signed-off-by: Wenchen Fan --- .../spark/sql/connect/config/Connect.scala | 8 + .../SparkConnectInterceptorRegistry.scala | 109 ++++++++++++ .../connect/service/SparkConnectService.scala | 3 + .../service/InterceptorRegistrySuite.scala | 167 ++++++++++++++++++ .../main/resources/error/error-classes.json | 19 +- 5 files changed, 305 insertions(+), 1 deletion(-) create mode 100644 connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectInterceptorRegistry.scala create mode 100644 connector/connect/src/test/scala/org/apache/spark/sql/connect/service/InterceptorRegistrySuite.scala diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala index 81c5328c9b29b..76d159cfd159a 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala @@ -26,4 +26,12 @@ private[spark] object Connect { .intConf .createWithDefault(15002) + val CONNECT_GRPC_INTERCEPTOR_CLASSES = + ConfigBuilder("spark.connect.grpc.interceptor.classes") + .doc( + "Comma separated list of class names that must " + + "implement the io.grpc.ServerInterceptor interface.") + .version("3.4.0") + .stringConf + .createOptional } diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectInterceptorRegistry.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectInterceptorRegistry.scala new file mode 100644 index 0000000000000..cddd4b976638d --- /dev/null +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectInterceptorRegistry.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.service + +import java.lang.reflect.InvocationTargetException + +import io.grpc.ServerInterceptor +import io.grpc.netty.NettyServerBuilder + +import org.apache.spark.{SparkEnv, SparkException} +import org.apache.spark.sql.connect.config.Connect +import org.apache.spark.util.Utils + +/** + * This object provides a global list of configured interceptors for GRPC. The interceptors are + * added to the GRPC server in order of their position in the list. Once the statically compiled + * interceptors are added, dynamically configured interceptors are added. + */ +object SparkConnectInterceptorRegistry { + + // Contains the list of configured interceptors. 
+ private lazy val interceptorChain: Seq[InterceptorBuilder] = Seq( + // Adding a new interceptor at compile time works like the eaxmple below with the dummy + // interceptor: + // interceptor[DummyInterceptor](classOf[DummyInterceptor]) + ) + + /** + * Given a NettyServerBuilder instance, will chain all interceptors to it in reverse order. + * @param sb + */ + def chainInterceptors(sb: NettyServerBuilder): Unit = { + interceptorChain.foreach(i => sb.intercept(i())) + createConfiguredInterceptors().foreach(sb.intercept(_)) + } + + // Type used to identify the closure responsible to instantiate a ServerInterceptor. + type InterceptorBuilder = () => ServerInterceptor + + /** + * Exposed for testing only. + */ + def createConfiguredInterceptors(): Seq[ServerInterceptor] = { + // Check all values from the Spark conf. + val classes = SparkEnv.get.conf.get(Connect.CONNECT_GRPC_INTERCEPTOR_CLASSES) + if (classes.nonEmpty) { + classes.get + .split(",") + .map(_.trim) + .filter(_.nonEmpty) + .map(Utils.classForName[ServerInterceptor](_)) + .map(createInstance(_)) + } else { + Seq.empty + } + } + + /** + * Creates a new instance of T using the default constructor. + * @param cls + * @tparam T + * @return + */ + private def createInstance[T <: ServerInterceptor](cls: Class[T]): ServerInterceptor = { + val ctorOpt = cls.getConstructors.find(_.getParameterCount == 0) + if (ctorOpt.isEmpty) { + throw new SparkException( + errorClass = "CONNECT.INTERCEPTOR_CTOR_MISSING", + messageParameters = Map("cls" -> cls.getName), + cause = null) + } + try { + ctorOpt.get.newInstance().asInstanceOf[T] + } catch { + case e: InvocationTargetException => + throw new SparkException( + errorClass = "CONNECT.INTERCEPTOR_RUNTIME_ERROR", + messageParameters = Map("msg" -> e.getTargetException.getMessage), + cause = e) + case e: Exception => + throw new SparkException( + errorClass = "CONNECT.INTERCEPTOR_RUNTIME_ERROR", + messageParameters = Map("msg" -> e.getMessage), + cause = e) + } + } + + /** + * Creates a callable expression that instantiates the configured GPRC interceptor + * implementation. + */ + private def interceptor[T <: ServerInterceptor](cls: Class[T]): InterceptorBuilder = + () => createInstance(cls) +} diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala index 7c494e39a69a0..20776a29edab4 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala @@ -192,6 +192,9 @@ object SparkConnectService { .forPort(port) .addService(new SparkConnectService(debugMode)) + // Add all registered interceptors to the server builder. + SparkConnectInterceptorRegistry.chainInterceptors(sb) + // If debug mode is configured, load the ProtoReflection service so that tools like // grpcurl can introspect the API for debugging. 
if (debugMode) { diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/service/InterceptorRegistrySuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/service/InterceptorRegistrySuite.scala new file mode 100644 index 0000000000000..bac02ec7af695 --- /dev/null +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/service/InterceptorRegistrySuite.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.service + +import io.grpc.{Metadata, ServerCall, ServerCallHandler, ServerInterceptor} +import io.grpc.ForwardingServerCallListener.SimpleForwardingServerCallListener +import io.grpc.netty.NettyServerBuilder + +import org.apache.spark.{SparkEnv, SparkException} +import org.apache.spark.sql.connect.config.Connect +import org.apache.spark.sql.test.SharedSparkSession + +/** + * Used for testing only, does not do anything. + */ +class DummyInterceptor extends ServerInterceptor { + override def interceptCall[ReqT, RespT]( + call: ServerCall[ReqT, RespT], + headers: Metadata, + next: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] = { + val listener = next.startCall(call, headers) + new SimpleForwardingServerCallListener[ReqT](listener) { + override def onMessage(message: ReqT): Unit = { + delegate().onMessage(message) + } + } + } +} + +/** + * Used for testing only. + */ +class TestingInterceptorNoTrivialCtor(id: Int) extends ServerInterceptor { + override def interceptCall[ReqT, RespT]( + call: ServerCall[ReqT, RespT], + headers: Metadata, + next: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] = { + val listener = next.startCall(call, headers) + new SimpleForwardingServerCallListener[ReqT](listener) { + override def onMessage(message: ReqT): Unit = { + delegate().onMessage(message) + } + } + } +} + +/** + * Used for testing only. 
+ */ +class TestingInterceptorInstantiationError extends ServerInterceptor { + throw new ArrayIndexOutOfBoundsException("Bad Error") + + override def interceptCall[ReqT, RespT]( + call: ServerCall[ReqT, RespT], + headers: Metadata, + next: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] = { + val listener = next.startCall(call, headers) + new SimpleForwardingServerCallListener[ReqT](listener) { + override def onMessage(message: ReqT): Unit = { + delegate().onMessage(message) + } + } + } +} + +class InterceptorRegistrySuite extends SharedSparkSession { + + override def beforeEach(): Unit = { + if (SparkEnv.get.conf.contains(Connect.CONNECT_GRPC_INTERCEPTOR_CLASSES)) { + SparkEnv.get.conf.remove(Connect.CONNECT_GRPC_INTERCEPTOR_CLASSES) + } + } + + def withSparkConf(pairs: (String, String)*)(f: => Unit): Unit = { + val conf = SparkEnv.get.conf + pairs.foreach { kv => conf.set(kv._1, kv._2) } + try f + finally { + pairs.foreach { kv => conf.remove(kv._1) } + } + } + + test("Check that the empty registry works") { + val sb = NettyServerBuilder.forPort(9999) + SparkConnectInterceptorRegistry.chainInterceptors(sb) + } + + test("Test server builder and configured interceptor") { + withSparkConf( + Connect.CONNECT_GRPC_INTERCEPTOR_CLASSES.key -> + "org.apache.spark.sql.connect.service.DummyInterceptor") { + val sb = NettyServerBuilder.forPort(9999) + SparkConnectInterceptorRegistry.chainInterceptors(sb) + } + } + + test("Test server build throws when using bad configured interceptor") { + withSparkConf( + Connect.CONNECT_GRPC_INTERCEPTOR_CLASSES.key -> + "org.apache.spark.sql.connect.service.TestingInterceptorNoTrivialCtor") { + val sb = NettyServerBuilder.forPort(9999) + assertThrows[SparkException] { + SparkConnectInterceptorRegistry.chainInterceptors(sb) + } + } + } + + test("Exception handling for interceptor classes") { + withSparkConf( + Connect.CONNECT_GRPC_INTERCEPTOR_CLASSES.key -> + "org.apache.spark.sql.connect.service.TestingInterceptorNoTrivialCtor") { + assertThrows[SparkException] { + SparkConnectInterceptorRegistry.createConfiguredInterceptors + } + } + + withSparkConf( + Connect.CONNECT_GRPC_INTERCEPTOR_CLASSES.key -> + "org.apache.spark.sql.connect.service.TestingInterceptorInstantiationError") { + assertThrows[SparkException] { + SparkConnectInterceptorRegistry.createConfiguredInterceptors + } + } + } + + test("No configured interceptors returns empty list") { + // Not set. 
+ assert(SparkConnectInterceptorRegistry.createConfiguredInterceptors.isEmpty) + // Set to empty string + withSparkConf(Connect.CONNECT_GRPC_INTERCEPTOR_CLASSES.key -> "") { + assert(SparkConnectInterceptorRegistry.createConfiguredInterceptors.isEmpty) + } + } + + test("Configured classes can have multiple entries") { + withSparkConf( + Connect.CONNECT_GRPC_INTERCEPTOR_CLASSES.key -> + (" org.apache.spark.sql.connect.service.DummyInterceptor," + + " org.apache.spark.sql.connect.service.DummyInterceptor ")) { + assert(SparkConnectInterceptorRegistry.createConfiguredInterceptors.size == 2) + } + } + + test("Configured class not found is properly thrown") { + withSparkConf(Connect.CONNECT_GRPC_INTERCEPTOR_CLASSES.key -> "this.class.does.not.exist") { + assertThrows[ClassNotFoundException] { + SparkConnectInterceptorRegistry.createConfiguredInterceptors + } + } + } + +} diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json index 0f9b665718ca6..804b95c65502d 100644 --- a/core/src/main/resources/error/error-classes.json +++ b/core/src/main/resources/error/error-classes.json @@ -76,6 +76,23 @@ "Another instance of this query was just started by a concurrent session." ] }, + "CONNECT" : { + "message" : [ + "Generic Spark Connect error." + ], + "subClass" : { + "INTERCEPTOR_CTOR_MISSING" : { + "message" : [ + "Cannot instantiate GRPC interceptor because is missing a default constructor without arguments." + ] + }, + "INTERCEPTOR_RUNTIME_ERROR" : { + "message" : [ + "Error instantiating GRPC interceptor: " + ] + } + } + }, "CONVERSION_INVALID_INPUT" : { "message" : [ "The value () cannot be converted to because it is malformed. Correct the value as per the syntax, or change its format. Use to tolerate malformed input and return NULL instead." @@ -4291,4 +4308,4 @@ "Not enough memory to build and broadcast the table to all worker nodes. As a workaround, you can either disable broadcast by setting to -1 or increase the spark driver memory by setting to a higher value" ] } -} +} \ No newline at end of file From 9d2757c6c4e511d8230ea345b155eedb2328b9c8 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Mon, 24 Oct 2022 19:15:16 +0300 Subject: [PATCH 21/22] [SPARK-40750][SQL] Migrate type check failures of math expressions onto error classes ### What changes were proposed in this pull request? This pr replaces TypeCheckFailure by DataTypeMismatch in type checks in the math expressions, includes: - hash.scala (HashExpression) - mathExpressions.scala (RoundBase) ### Why are the changes needed? Migration onto error classes unifies Spark SQL error messages. ### Does this PR introduce _any_ user-facing change? Yes. The PR changes user-facing error messages. ### How was this patch tested? - Add new UT - Update existed UT - Pass GA. Closes #38332 from panbingkun/SPARK-40750. 
Authored-by: panbingkun Signed-off-by: Max Gekk --- .../main/resources/error/error-classes.json | 5 ++ .../spark/sql/catalyst/expressions/hash.scala | 18 +++-- .../expressions/mathExpressions.scala | 10 ++- .../ExpressionTypeCheckingSuite.scala | 65 +++++++++++++++++-- .../spark/sql/DataFrameFunctionsSuite.scala | 62 ++++++++++++++++-- 5 files changed, 139 insertions(+), 21 deletions(-) diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json index 804b95c65502d..6f5b3b5a1347b 100644 --- a/core/src/main/resources/error/error-classes.json +++ b/core/src/main/resources/error/error-classes.json @@ -160,6 +160,11 @@ "Offset expression must be a literal." ] }, + "HASH_MAP_TYPE" : { + "message" : [ + "Input to the function cannot contain elements of the \"MAP\" type. In Spark, same maps may have different hashcode, thus hash expressions are prohibited on \"MAP\" elements. To restore previous behavior set \"spark.sql.legacy.allowHashOnMapType\" to \"true\"." + ] + }, "INVALID_JSON_MAP_KEY_TYPE" : { "message" : [ "Input schema can only contain STRING as a key type for a MAP." diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 7ac486f05af1b..4f8ed1953f409 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -28,6 +28,8 @@ import org.apache.commons.codec.digest.MessageDigestAlgorithms import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch +import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} @@ -268,15 +270,17 @@ abstract class HashExpression[E] extends Expression { override def checkInputDataTypes(): TypeCheckResult = { if (children.length < 1) { - TypeCheckResult.TypeCheckFailure( - s"input to function $prettyName requires at least one argument") + DataTypeMismatch( + errorSubClass = "WRONG_NUM_PARAMS", + messageParameters = Map( + "functionName" -> toSQLId(prettyName), + "expectedNum" -> "> 0", + "actualNum" -> children.length.toString)) } else if (children.exists(child => hasMapType(child.dataType)) && !SQLConf.get.getConf(SQLConf.LEGACY_ALLOW_HASH_ON_MAPTYPE)) { - TypeCheckResult.TypeCheckFailure( - s"input to function $prettyName cannot contain elements of MapType. In Spark, same maps " + - "may have different hashcode, thus hash expressions are prohibited on MapType elements." 
+ - s" To restore previous behavior set ${SQLConf.LEGACY_ALLOW_HASH_ON_MAPTYPE.key} " + - "to true.") + DataTypeMismatch( + errorSubClass = "HASH_MAP_TYPE", + messageParameters = Map("functionName" -> toSQLId(prettyName))) } else { TypeCheckResult.TypeCheckSuccess } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 5643598b4bd56..28739fb47a2b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -22,7 +22,8 @@ import java.util.Locale import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult} -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{NumberConverter, TypeUtils} @@ -1481,7 +1482,12 @@ abstract class RoundBase(child: Expression, scale: Expression, if (scale.foldable) { TypeCheckSuccess } else { - TypeCheckFailure("Only foldable Expression is allowed for scale arguments") + DataTypeMismatch( + errorSubClass = "NON_FOLDABLE_INPUT", + messageParameters = Map( + "inputName" -> "scala", + "inputType" -> toSQLType(scale.dataType), + "inputExpr" -> toSQLExpr(scale))) } case f => f } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index b41f627bac94e..0d66ad4b06848 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ @@ -440,8 +441,31 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer ) assertError(Coalesce(Nil), "function coalesce requires at least one argument") - assertError(new Murmur3Hash(Nil), "function hash requires at least one argument") - assertError(new XxHash64(Nil), "function xxhash64 requires at least one argument") + + val murmur3Hash = new Murmur3Hash(Nil) + checkError( + exception = intercept[AnalysisException] { + assertSuccess(murmur3Hash) + }, + errorClass = "DATATYPE_MISMATCH.WRONG_NUM_PARAMS", + parameters = Map( + "sqlExpr" -> "\"hash()\"", + "functionName" -> toSQLId(murmur3Hash.prettyName), + "expectedNum" -> "> 0", + "actualNum" -> "0")) + + val xxHash64 = new XxHash64(Nil) + checkError( + exception = intercept[AnalysisException] { + assertSuccess(xxHash64) + }, + errorClass = "DATATYPE_MISMATCH.WRONG_NUM_PARAMS", + parameters = Map( + "sqlExpr" -> "\"xxhash64()\"", + 
"functionName" -> toSQLId(xxHash64.prettyName), + "expectedNum" -> "> 0", + "actualNum" -> "0")) + assertError(Explode($"intField"), "input to function explode should be array or map type") assertError(PosExplode($"intField"), @@ -478,8 +502,17 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer assertSuccess(Round(Literal(null), Literal(null))) assertSuccess(Round($"intField", Literal(1))) - assertError(Round($"intField", $"intField"), - "Only foldable Expression is allowed") + checkError( + exception = intercept[AnalysisException] { + assertSuccess(Round($"intField", $"intField")) + }, + errorClass = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + parameters = Map( + "sqlExpr" -> "\"round(intField, intField)\"", + "inputName" -> "scala", + "inputType" -> "\"INT\"", + "inputExpr" -> "\"intField\"")) + checkError( exception = intercept[AnalysisException] { assertSuccess(Round($"intField", $"booleanField")) @@ -516,9 +549,16 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer assertSuccess(BRound(Literal(null), Literal(null))) assertSuccess(BRound($"intField", Literal(1))) - - assertError(BRound($"intField", $"intField"), - "Only foldable Expression is allowed") + checkError( + exception = intercept[AnalysisException] { + assertSuccess(BRound($"intField", $"intField")) + }, + errorClass = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + parameters = Map( + "sqlExpr" -> "\"bround(intField, intField)\"", + "inputName" -> "scala", + "inputType" -> "\"INT\"", + "inputExpr" -> "\"intField\"")) checkError( exception = intercept[AnalysisException] { assertSuccess(BRound($"intField", $"booleanField")) @@ -602,4 +642,15 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer assert(Literal.create(Map(42L -> null), MapType(LongType, NullType)).sql == "MAP(42L, NULL)") } + + test("hash expressions are prohibited on MapType elements") { + val argument = Literal.create(Map(42L -> true), MapType(LongType, BooleanType)) + val murmur3Hash = new Murmur3Hash(Seq(argument)) + assert(murmur3Hash.checkInputDataTypes() == + DataTypeMismatch( + errorSubClass = "HASH_MAP_TYPE", + messageParameters = Map("functionName" -> toSQLId(murmur3Hash.prettyName)) + ) + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 7dea7799b666d..c52cb85e119d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -4219,16 +4219,68 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { val funcsMustHaveAtLeastOneArg = ("coalesce", (df: DataFrame) => df.select(coalesce())) :: - ("coalesce", (df: DataFrame) => df.selectExpr("coalesce()")) :: - ("hash", (df: DataFrame) => df.select(hash())) :: - ("hash", (df: DataFrame) => df.selectExpr("hash()")) :: - ("xxhash64", (df: DataFrame) => df.select(xxhash64())) :: - ("xxhash64", (df: DataFrame) => df.selectExpr("xxhash64()")) :: Nil + ("coalesce", (df: DataFrame) => df.selectExpr("coalesce()")) :: Nil funcsMustHaveAtLeastOneArg.foreach { case (name, func) => val errMsg = intercept[AnalysisException] { func(df) }.getMessage assert(errMsg.contains(s"input to function $name requires at least one argument")) } + checkError( + exception = intercept[AnalysisException] { + df.select(hash()) + }, + errorClass = "DATATYPE_MISMATCH.WRONG_NUM_PARAMS", + sqlState = 
None, + parameters = Map( + "sqlExpr" -> "\"hash()\"", + "functionName" -> "`hash`", + "expectedNum" -> "> 0", + "actualNum" -> "0")) + + checkError( + exception = intercept[AnalysisException] { + df.selectExpr("hash()") + }, + errorClass = "DATATYPE_MISMATCH.WRONG_NUM_PARAMS", + sqlState = None, + parameters = Map( + "sqlExpr" -> "\"hash()\"", + "functionName" -> "`hash`", + "expectedNum" -> "> 0", + "actualNum" -> "0"), + context = ExpectedContext( + fragment = "hash()", + start = 0, + stop = 5)) + + checkError( + exception = intercept[AnalysisException] { + df.select(xxhash64()) + }, + errorClass = "DATATYPE_MISMATCH.WRONG_NUM_PARAMS", + sqlState = None, + parameters = Map( + "sqlExpr" -> "\"xxhash64()\"", + "functionName" -> "`xxhash64`", + "expectedNum" -> "> 0", + "actualNum" -> "0")) + + checkError( + exception = intercept[AnalysisException] { + df.selectExpr("xxhash64()") + }, + errorClass = "DATATYPE_MISMATCH.WRONG_NUM_PARAMS", + sqlState = None, + parameters = Map( + "sqlExpr" -> "\"xxhash64()\"", + "functionName" -> "`xxhash64`", + "expectedNum" -> "> 0", + "actualNum" -> "0"), + context = ExpectedContext( + fragment = "xxhash64()", + start = 0, + stop = 9)) + checkError( exception = intercept[AnalysisException] { df.select(greatest()) From 60b1056307b3ee9d880a936f3a97c5fb16a2b698 Mon Sep 17 00:00:00 2001 From: Mridul Date: Mon, 24 Oct 2022 10:51:45 -0700 Subject: [PATCH 22/22] [SPARK-40902][MESOS][TESTS] Fix issue with mesos tests failing due to quick submission of drivers ### What changes were proposed in this pull request? ##### Quick submission of drivers in tests to mesos scheduler results in dropping drivers Queued drivers in `MesosClusterScheduler` are ordered based on `MesosDriverDescription` - and the ordering used checks for priority (if different), followed by comparison of submission time. For two driver submissions with same priority, if made in quick succession (such that submission time is same due to millisecond granularity of Date), this results in dropping the second `MesosDriverDescription` from `queuedDrivers` (since `driverOrdering` returns `0` when comparing the descriptions). This PR fixes the more immediate issue with tests. ### Why are the changes needed? Flakey tests, [see here](https://lists.apache.org/thread/jof098qxp0s6qqmt9qwv52f9665b1pjg) for an example. ### Does this PR introduce _any_ user-facing change? No. Fixing only tests for now - as mesos support is deprecated, not changing scheduler itself to address this. ### How was this patch tested? Fixes unit tests Closes #38378 from mridulm/fix_MesosClusterSchedulerSuite. 
Authored-by: Mridul Signed-off-by: Dongjoon Hyun --- .../mesos/MesosClusterSchedulerSuite.scala | 66 +++++++++++-------- 1 file changed, 40 insertions(+), 26 deletions(-) diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala index 9a1862d32dc13..102dd4b76d237 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler.cluster.mesos import java.util.{Collection, Collections, Date} +import java.util.concurrent.atomic.AtomicLong import scala.collection.JavaConverters._ @@ -40,6 +41,19 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi private var driver: SchedulerDriver = _ private var scheduler: MesosClusterScheduler = _ + private val submissionTime = new AtomicLong(System.currentTimeMillis()) + + // Queued drivers in MesosClusterScheduler are ordered based on MesosDriverDescription + // The default ordering checks for priority, followed by submission time. For two driver + // submissions with same priority and if made in quick succession (such that submission + // time is same due to millisecond granularity), this results in dropping the + // second MesosDriverDescription from the queuedDrivers - as driverOrdering + // returns 0 when comparing the descriptions. Ensure two seperate submissions + // have differnt dates + private def getDate: Date = { + new Date(submissionTime.incrementAndGet()) + } + private def setScheduler(sparkConfVars: Map[String, String] = null): Unit = { val conf = new SparkConf() conf.setMaster("mesos://localhost:5050") @@ -68,7 +82,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi command, Map[String, String](), submissionId, - new Date()) + getDate) } test("can queue drivers") { @@ -108,7 +122,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi Map((config.EXECUTOR_HOME.key, "test"), ("spark.app.name", "test"), (config.DRIVER_MEMORY_OVERHEAD.key, "0")), "s1", - new Date())) + getDate)) assert(response.success) val offer = Offer.newBuilder() .addResources( @@ -213,7 +227,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi Map("spark.mesos.executor.home" -> "test", "spark.app.name" -> "test"), "s1", - new Date())) + getDate)) assert(response.success) val offer = Utils.createOffer("o1", "s1", mem*2, cpu) @@ -240,7 +254,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi Map("spark.mesos.executor.home" -> "test", "spark.app.name" -> "test"), "s1", - new Date())) + getDate)) assert(response.success) val offer = Utils.createOffer("o1", "s1", mem*2, cpu) @@ -270,7 +284,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi config.DRIVER_MEMORY_OVERHEAD.key -> "0" ), "s1", - new Date())) + getDate)) assert(response.success) val offer = Utils.createOffer("o1", "s1", mem, cpu) @@ -296,7 +310,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi config.NETWORK_LABELS.key -> "key1:val1,key2:val2", config.DRIVER_MEMORY_OVERHEAD.key -> "0"), "s1", - new Date())) + getDate)) assert(response.success) @@ -327,7 +341,7 @@ class 
MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi "spark.app.name" -> "test", config.DRIVER_MEMORY_OVERHEAD.key -> "0"), "s1", - new Date())) + getDate)) assert(response.success) @@ -352,7 +366,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi "spark.app.name" -> "test", config.DRIVER_MEMORY_OVERHEAD.key -> "0"), "s1", - new Date())) + getDate)) assert(response.success) @@ -378,7 +392,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi "spark.app.name" -> "test", config.DRIVER_MEMORY_OVERHEAD.key -> "0"), "s1", - new Date())) + getDate)) assert(response.success) @@ -413,7 +427,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi config.DRIVER_CONSTRAINTS.key -> driverConstraints, config.DRIVER_MEMORY_OVERHEAD.key -> "0"), "s1", - new Date())) + getDate)) assert(response.success) } @@ -452,7 +466,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi config.DRIVER_LABELS.key -> "key:value", config.DRIVER_MEMORY_OVERHEAD.key -> "0"), "s1", - new Date())) + getDate)) assert(response.success) @@ -474,7 +488,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi val response = scheduler.submitDriver( new MesosDriverDescription("d1", "jar", 100, 1, true, command, - Map((config.EXECUTOR_HOME.key, "test"), ("spark.app.name", "test")), "s1", new Date())) + Map((config.EXECUTOR_HOME.key, "test"), ("spark.app.name", "test")), "s1", getDate)) assert(response.success) val agentId = SlaveID.newBuilder().setValue("s1").build() val offer = Offer.newBuilder() @@ -533,7 +547,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi val response = scheduler.submitDriver( new MesosDriverDescription("d1", "jar", 100, 1, true, command, - Map(("spark.mesos.executor.home", "test"), ("spark.app.name", "test")), "sub1", new Date())) + Map(("spark.mesos.executor.home", "test"), ("spark.app.name", "test")), "sub1", getDate)) assert(response.success) // Offer a resource to launch the submitted driver @@ -651,7 +665,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi config.EXECUTOR_URI.key -> "s3a://bucket/spark-version.tgz", "another.conf" -> "\\value"), "s1", - new Date()) + getDate) val expectedCmd = "cd spark-version*; " + "bin/spark-submit --name \"app name\" --master mesos://mesos://localhost:5050 " + @@ -691,7 +705,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi command, Map("spark.mesos.dispatcher.queue" -> "dummy"), "s1", - new Date()) + getDate) assertThrows[NoSuchElementException] { scheduler.getDriverPriority(desc) @@ -702,7 +716,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi command, Map[String, String](), "s2", - new Date()) + getDate) assert(scheduler.getDriverPriority(desc) == 0.0f) @@ -711,7 +725,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi command, Map("spark.mesos.dispatcher.queue" -> "default"), "s3", - new Date()) + getDate) assert(scheduler.getDriverPriority(desc) == 0.0f) @@ -720,7 +734,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi command, Map("spark.mesos.dispatcher.queue" -> "ROUTINE"), "s4", - new Date()) + getDate) assert(scheduler.getDriverPriority(desc) == 1.0f) @@ -729,7 +743,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi command, 
Map("spark.mesos.dispatcher.queue" -> "URGENT"), "s5", - new Date()) + getDate) assert(scheduler.getDriverPriority(desc) == 2.0f) } @@ -746,22 +760,22 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi val response0 = scheduler.submitDriver( new MesosDriverDescription("d1", "jar", 100, 1, true, command, - Map("spark.mesos.dispatcher.queue" -> "ROUTINE"), "s0", new Date())) + Map("spark.mesos.dispatcher.queue" -> "ROUTINE"), "s0", getDate)) assert(response0.success) val response1 = scheduler.submitDriver( new MesosDriverDescription("d1", "jar", 100, 1, true, command, - Map[String, String](), "s1", new Date())) + Map[String, String](), "s1", getDate)) assert(response1.success) val response2 = scheduler.submitDriver( new MesosDriverDescription("d1", "jar", 100, 1, true, command, - Map("spark.mesos.dispatcher.queue" -> "EXCEPTIONAL"), "s2", new Date())) + Map("spark.mesos.dispatcher.queue" -> "EXCEPTIONAL"), "s2", getDate)) assert(response2.success) val response3 = scheduler.submitDriver( new MesosDriverDescription("d1", "jar", 100, 1, true, command, - Map("spark.mesos.dispatcher.queue" -> "URGENT"), "s3", new Date())) + Map("spark.mesos.dispatcher.queue" -> "URGENT"), "s3", getDate)) assert(response3.success) val state = scheduler.getSchedulerState() @@ -782,12 +796,12 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi val response0 = scheduler.submitDriver( new MesosDriverDescription("d1", "jar", 100, 1, true, command, - Map("spark.mesos.dispatcher.queue" -> "LOWER"), "s0", new Date())) + Map("spark.mesos.dispatcher.queue" -> "LOWER"), "s0", getDate)) assert(response0.success) val response1 = scheduler.submitDriver( new MesosDriverDescription("d1", "jar", 100, 1, true, command, - Map[String, String](), "s1", new Date())) + Map[String, String](), "s1", getDate)) assert(response1.success) val state = scheduler.getSchedulerState() @@ -812,7 +826,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi config.DRIVER_MEMORY_OVERHEAD.key -> "0") ++ addlSparkConfVars, "s1", - new Date()) + getDate) val response = scheduler.submitDriver(driverDesc) assert(response.success) val offer = Utils.createOffer("o1", "s1", mem, cpu)