diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 617eb173f4f49..259d30a259196 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -458,7 +458,10 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.FetchFailed.this"), // [SPARK-28957][SQL] Copy any "spark.hive.foo=bar" spark properties into hadoop conf as "hive.foo=bar" - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.appendS3AndSparkHadoopConfigurations") + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.appendS3AndSparkHadoopConfigurations"), + + // [SPARK-29348] Add observable metrics. + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryProgress.this") ) // Exclude rules for 2.4.x diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 625ef2153c711..83fa405e521ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2432,6 +2432,10 @@ class Analyzer( nondeterToAttr.get(e).map(_.toAttribute).getOrElse(e) }.copy(child = newChild) + // Don't touch collect metrics. Top-level metrics are not supported (check analysis will fail) + // and we want to retain them inside the aggregate functions. + case m: CollectMetrics => m + // todo: It's hard to write a general rule to pull out nondeterministic expressions // from LogicalPlan, currently we only do it for UnaryNode which has same output // schema with its child. @@ -2932,6 +2936,12 @@ object CleanupAliases extends Rule[LogicalPlan] { Window(cleanedWindowExprs, partitionSpec.map(trimAliases), orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child) + case CollectMetrics(name, metrics, child) => + val cleanedMetrics = metrics.map { + e => trimNonTopLevelAliases(e).asInstanceOf[NamedExpression] + } + CollectMetrics(name, cleanedMetrics, child) + // Operators that operate on objects should only have expressions from encoders, which should // never have extra aliases. case o: ObjectConsumer => o diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 72612d1dc76c9..cfb16233b3940 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -14,9 +14,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.spark.sql.catalyst.analysis +import scala.collection.mutable + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ @@ -280,6 +281,41 @@ trait CheckAnalysis extends PredicateHelper { groupingExprs.foreach(checkValidGroupingExprs) aggregateExprs.foreach(checkValidAggregateExpression) + case CollectMetrics(name, metrics, _) => + if (name == null || name.isEmpty) { + operator.failAnalysis(s"observed metrics should be named: $operator") + } + // Check if an expression is a valid metric. 
A metric must meet the following criteria: + // - Is not a window function; + // - Is not a nested aggregate function; + // - Is not a distinct aggregate function; + // - Has only non-deterministic functions that are nested inside an aggregate function; + // - Has only attributes that are nested inside an aggregate function. + def checkMetric(s: Expression, e: Expression, seenAggregate: Boolean = false): Unit = { + e match { + case _: WindowExpression => + e.failAnalysis( + "window expressions are not allowed in observed metrics, but found: " + s.sql) + case _ if !e.deterministic && !seenAggregate => + e.failAnalysis(s"non-deterministic expression ${s.sql} can only be used " + + "as an argument to an aggregate function.") + case a: AggregateExpression if seenAggregate => + e.failAnalysis( + "nested aggregates are not allowed in observed metrics, but found: " + s.sql) + case a: AggregateExpression if a.isDistinct => + e.failAnalysis( + "distinct aggregates are not allowed in observed metrics, but found: " + s.sql) + case _: Attribute if !seenAggregate => + e.failAnalysis(s"attribute ${s.sql} can only be used as an argument to an " + + "aggregate function.") + case _: AggregateExpression => + e.children.foreach(checkMetric(s, _, seenAggregate = true)) + case _ => + e.children.foreach(checkMetric(s, _, seenAggregate)) + } + } + metrics.foreach(m => checkMetric(m, m)) + case Sort(orders, _, _) => orders.foreach { order => if (!RowOrdering.isOrderable(order.dataType)) { @@ -534,6 +570,7 @@ trait CheckAnalysis extends PredicateHelper { case _ => // Analysis successful! } } + checkCollectedMetrics(plan) extendedCheckRules.foreach(_(plan)) plan.foreachUp { case o if !o.resolved => @@ -627,6 +664,38 @@ trait CheckAnalysis extends PredicateHelper { checkCorrelationsInSubquery(expr.plan) } + /** + * Validate that collected metrics names are unique. The same name cannot be used for metrics + * with different results. However, multiple instances of metrics with the same result and name + * are allowed (e.g. self-joins). + */ + private def checkCollectedMetrics(plan: LogicalPlan): Unit = { + val metricsMap = mutable.Map.empty[String, LogicalPlan] + def check(plan: LogicalPlan): Unit = plan.foreach { node => + node match { + case metrics @ CollectMetrics(name, _, _) => + metricsMap.get(name) match { + case Some(other) => + // Exact duplicates are allowed. They can be the result + // of a CTE that is used multiple times or a self join. + if (!metrics.sameResult(other)) { + failAnalysis( + s"Multiple definitions of observed metrics named '$name': $plan") + } + case None => + metricsMap.put(name, metrics) + } + case _ => + } + node.expressions.foreach(_.foreach { + case subquery: SubqueryExpression => + check(subquery.plan) + case _ => + }) + } + check(plan) + } + /** + * Validates to make sure the outer references appearing inside the subquery + * are allowed.
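For reference, the checks added above surface through the user-facing `Dataset.observe` API introduced later in this patch. A minimal sketch of what they accept and reject, mirroring the cases in the analysis test suite — the DataFrame `df` and its `value`/`error` columns are hypothetical, not taken from the patch:

// Assumes: import org.apache.spark.sql.functions._ and import spark.implicits._ (for the $"..." syntax).

// Accepted: literals, and expressions whose column references are wrapped in an aggregate function.
df.observe("valid_event",
  lit(42).as("answer"),
  count(lit(1)).as("rows"),
  sum($"value").as("sum_val"),
  count(when($"error".isNotNull, 1)).as("errors"))

// Rejected by the new CheckAnalysis rule:
df.observe("bad_event", $"value".as("v"))                 // bare attribute outside an aggregate
df.observe("bad_event", rand().as("r"))                   // non-deterministic expression outside an aggregate
df.observe("bad_event", countDistinct($"value").as("d"))  // distinct aggregate
df.observe("bad_event", sum(sum($"value")).as("s"))       // nested aggregate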
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/PlanHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/PlanHelper.scala index 4a28d879d1145..63348f766a5b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/PlanHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/PlanHelper.scala @@ -43,7 +43,9 @@ object PlanHelper { case e: WindowExpression if !plan.isInstanceOf[Window] => e case e: AggregateExpression - if !(plan.isInstanceOf[Aggregate] || plan.isInstanceOf[Window]) => e + if !(plan.isInstanceOf[Aggregate] || + plan.isInstanceOf[Window] || + plan.isInstanceOf[CollectMetrics]) => e case e: Generator if !plan.isInstanceOf[Generate] => e } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 93b314d4e54a5..67438a47e8daa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -971,3 +971,25 @@ case class Deduplicate( * This is used to whitelist such commands in the subquery-related checks. */ trait SupportsSubquery extends LogicalPlan + +/** + * Collect arbitrary (named) metrics from a dataset. As soon as the query reaches a completion + * point (batch query completes or streaming query epoch completes) an event is emitted on the + * driver which can be observed by attaching a listener to the spark session. The metrics are named + * so we can collect metrics at multiple places in a single dataset. + * + * This node behaves like a global aggregate. All the metrics collected must be aggregate functions + * or be literals. 
+ */ +case class CollectMetrics( + name: String, + metrics: Seq[NamedExpression], + child: LogicalPlan) + extends UnaryNode { + + override lazy val resolved: Boolean = { + name.nonEmpty && metrics.nonEmpty && metrics.forall(_.resolved) && childrenResolved + } + + override def output: Seq[Attribute] = child.output +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e1b8192fba213..426ec6efb6ad6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -160,7 +160,12 @@ object SQLConf { confGetter.get()() } } else { - confGetter.get()() + val conf = existingConf.get() + if (conf != null) { + conf + } else { + confGetter.get()() + } } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index ad1d6b62ef3a1..de062f7efd8d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -21,9 +21,14 @@ import java.util.Locale import scala.util.control.NonFatal +import com.fasterxml.jackson.core.{JsonGenerator, JsonParser} +import com.fasterxml.jackson.databind.{DeserializationContext, JsonDeserializer, JsonSerializer, SerializerProvider} +import com.fasterxml.jackson.databind.`type`.TypeFactory +import com.fasterxml.jackson.databind.annotation.{JsonDeserialize, JsonSerialize} import org.json4s._ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ +import org.json4s.jackson.{JValueDeserializer, JValueSerializer} import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.Stable @@ -40,7 +45,10 @@ import org.apache.spark.util.Utils * * @since 1.3.0 */ + @Stable +@JsonSerialize(using = classOf[DataTypeJsonSerializer]) +@JsonDeserialize(using = classOf[DataTypeJsonDeserializer]) abstract class DataType extends AbstractDataType { /** * Enables matching against DataType for expressions: @@ -475,3 +483,30 @@ object DataType { } } } + +/** + * Jackson serializer for [[DataType]]. Internally this delegates to json4s based serialization. + */ +class DataTypeJsonSerializer extends JsonSerializer[DataType] { + private val delegate = new JValueSerializer + override def serialize( + value: DataType, + gen: JsonGenerator, + provider: SerializerProvider): Unit = { + delegate.serialize(value.jsonValue, gen, provider) + } +} + +/** + * Jackson deserializer for [[DataType]]. Internally this delegates to json4s based deserialization. 
+ */ +class DataTypeJsonDeserializer extends JsonDeserializer[DataType] { + private val delegate = new JValueDeserializer(classOf[Any]) + + override def deserialize( + jsonParser: JsonParser, + deserializationContext: DeserializationContext): DataType = { + val json = delegate.deserialize(jsonParser, deserializationContext) + DataType.parseDataType(json.asInstanceOf[JValue]) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 74445a111e4d7..ae474cac9f748 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -29,11 +29,11 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{Count, Sum} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan import org.apache.spark.sql.catalyst.plans.{Cross, Inner} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, - RangePartitioning, RoundRobinPartitioning} +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.internal.SQLConf @@ -650,4 +650,87 @@ class AnalysisSuite extends AnalysisTest with Matchers { assertAnalysisError(parsePlan("INSERT INTO test VALUES (1)"), Seq("Table not found: test")) } + + test("check CollectMetrics resolved") { + val a = testRelation.output.head + val sum = Sum(a).toAggregateExpression().as("sum") + val random_sum = Sum(Rand(1L)).toAggregateExpression().as("rand_sum") + val literal = Literal(1).as("lit") + + // Ok + assert(CollectMetrics("event", literal :: sum :: random_sum :: Nil, testRelation).resolved) + + // Bad name + assert(!CollectMetrics("", sum :: Nil, testRelation).resolved) + assertAnalysisError(CollectMetrics("", sum :: Nil, testRelation), + "observed metrics should be named" :: Nil) + + // No columns + assert(!CollectMetrics("evt", Nil, testRelation).resolved) + + def checkAnalysisError(exprs: Seq[NamedExpression], errors: String*): Unit = { + assertAnalysisError(CollectMetrics("event", exprs, testRelation), errors) + } + + // Unwrapped attribute + checkAnalysisError( + a :: Nil, + "Attribute", "can only be used as an argument to an aggregate function") + + // Unwrapped non-deterministic expression + checkAnalysisError( + Rand(10).as("rnd") :: Nil, + "non-deterministic expression", "can only be used as an argument to an aggregate function") + + // Distinct aggregate + checkAnalysisError( + Sum(a).toAggregateExpression(isDistinct = true).as("sum") :: Nil, + "distinct aggregates are not allowed in observed metrics, but found") + + // Nested aggregate + checkAnalysisError( + Sum(Sum(a).toAggregateExpression()).toAggregateExpression().as("sum") :: Nil, + "nested aggregates are not allowed in observed metrics, but found") + + // Windowed aggregate + val windowExpr = WindowExpression( + RowNumber(), + WindowSpecDefinition(Nil, a.asc :: Nil, + SpecifiedWindowFrame(RowFrame, 
UnboundedPreceding, CurrentRow))) + checkAnalysisError( + windowExpr.as("rn") :: Nil, + "window expressions are not allowed in observed metrics, but found") + } + + test("check CollectMetrics duplicates") { + val a = testRelation.output.head + val sum = Sum(a).toAggregateExpression().as("sum") + val count = Count(Literal(1)).toAggregateExpression().as("cnt") + + // Same result - duplicate names are allowed + assertAnalysisSuccess(Union( + CollectMetrics("evt1", count :: Nil, testRelation) :: + CollectMetrics("evt1", count :: Nil, testRelation) :: Nil)) + + // Same children, structurally different metrics - fail + assertAnalysisError(Union( + CollectMetrics("evt1", count :: Nil, testRelation) :: + CollectMetrics("evt1", sum :: Nil, testRelation) :: Nil), + "Multiple definitions of observed metrics" :: "evt1" :: Nil) + + // Different children, same metrics - fail + val b = 'b.string + val tblB = LocalRelation(b) + assertAnalysisError(Union( + CollectMetrics("evt1", count :: Nil, testRelation) :: + CollectMetrics("evt1", count :: Nil, tblB) :: Nil), + "Multiple definitions of observed metrics" :: "evt1" :: Nil) + + // Subquery different tree - fail + val subquery = Aggregate(Nil, sum :: Nil, CollectMetrics("evt1", count :: Nil, testRelation)) + val query = Project( + b :: ScalarSubquery(subquery, Nil).as("sum") :: Nil, + CollectMetrics("evt1", count :: Nil, tblB)) + assertAnalysisError(query, "Multiple definitions of observed metrics" :: "evt1" :: Nil) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 1b75fccbdb7b2..635f5f87b5a2f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1848,6 +1848,54 @@ class Dataset[T] private[sql]( @scala.annotation.varargs def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs : _*) + /** + * Define (named) metrics to observe on the Dataset. This method returns an 'observed' Dataset + * that returns the same result as the input, with the following guarantees: + * - It will compute the defined aggregates (metrics) on all the data that is flowing through the + * Dataset at that point. + * - It will report the value of the defined aggregate columns as soon as we reach a completion + * point. A completion point is either the end of a query (batch mode) or the end of a streaming + * epoch. The value of the aggregates only reflects the data processed since the previous + * completion point. + * Please note that continuous execution is currently not supported. + * + * The metrics columns must either contain a literal (e.g. lit(42)) or contain one or + * more aggregate functions (e.g. sum(a) or sum(a + b) + avg(c) - lit(1)). Expressions that + * contain references to the input Dataset's columns must always be wrapped in an aggregate + * function. + * + * A user can observe these metrics by adding either a + * [[org.apache.spark.sql.streaming.StreamingQueryListener]] or a + * [[org.apache.spark.sql.util.QueryExecutionListener]] to the Spark session. + * + * {{{ + * // Observe row count (rc) and error row count (erc) in the streaming Dataset + * val observed_ds = ds.observe("my_event", count(lit(1)).as("rc"), count($"error").as("erc")) + * observed_ds.writeStream.format("...").start() + * + * // Monitor the metrics using a listener.
+ * spark.streams.addListener(new StreamingQueryListener() { + * override def onQueryStarted(event: QueryStartedEvent): Unit = {} + * override def onQueryProgress(event: QueryProgressEvent): Unit = { + * event.progress.observedMetrics.asScala.get("my_event").foreach { row => + * // Trigger if the number of errors exceeds 5 percent + * val num_rows = row.getAs[Long]("rc") + * val num_error_rows = row.getAs[Long]("erc") + * val ratio = num_error_rows.toDouble / num_rows + * if (ratio > 0.05) { + * // Trigger alert + * } + * } + * } + * override def onQueryTerminated(event: QueryTerminatedEvent): Unit = {} + * }) + * }}} + * + * @group typedrel + * @since 3.0.0 + */ + def observe(name: String, expr: Column, exprs: Column*): Dataset[T] = withTypedPlan { + CollectMetrics(name, (expr +: exprs).map(_.named), logicalPlan) + } + /** * Returns a new Dataset by taking the first `n` rows. The difference between this function * and `head` is that `head` is an action and returns an array (by triggering query execution) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala index fa5ba1a691cd6..9aab5b390fe13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala @@ -43,7 +43,8 @@ class AggregatingAccumulator private( assert(bufferSchema.size == updateExpressions.size) assert(mergeExpressions == null || bufferSchema.size == mergeExpressions.size) - private[this] var joinedRow: JoinedRow = _ + @transient + private var joinedRow: JoinedRow = _ private var buffer: SpecificInternalRow = _ @@ -184,7 +185,6 @@ class AggregatingAccumulator private( resultProjection(input) } - /** * Get the output schema of the aggregating accumulator. */ @@ -194,6 +194,17 @@ class AggregatingAccumulator private( case (e, i) => StructField(s"c_$i", e.dataType, e.nullable) }) } + + /** + * Set the state of the accumulator to the state of another accumulator. This is used in cases + * where we only want to publish the state of the accumulator when the task completes; see + * [[CollectMetricsExec]] for an example. + */ + private[execution] def setState(other: AggregatingAccumulator): Unit = { + assert(buffer == null || (buffer eq other.buffer)) + buffer = other.buffer + joinedRow = other.joinedRow + } } object AggregatingAccumulator { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala new file mode 100644 index 0000000000000..e482bc9941ea9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.spark.sql.execution + +import scala.collection.mutable + +import org.apache.spark.TaskContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.types.StructType + +/** + * Collect arbitrary (named) metrics from a [[SparkPlan]]. + */ +case class CollectMetricsExec( + name: String, + metricExpressions: Seq[NamedExpression], + child: SparkPlan) + extends UnaryExecNode { + + private lazy val accumulator: AggregatingAccumulator = { + val acc = AggregatingAccumulator(metricExpressions, child.output) + acc.register(sparkContext, Option("Collected metrics")) + acc + } + + val metricsSchema: StructType = { + StructType.fromAttributes(metricExpressions.map(_.toAttribute)) + } + + // This is not used very frequently (once a query); it is not useful to use code generation here. + private lazy val toRowConverter: InternalRow => Row = { + CatalystTypeConverters.createToScalaConverter(metricsSchema) + .asInstanceOf[InternalRow => Row] + } + + def collectedMetrics: Row = toRowConverter(accumulator.value) + + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override protected def doExecute(): RDD[InternalRow] = { + val collector = accumulator + collector.reset() + child.execute().mapPartitions { rows => + // Only publish the value of the accumulator when the task has completed. This is done by + // updating a task local accumulator ('updater') which will be merged with the actual + // accumulator as soon as the task completes. This avoids the following problems during the + // heartbeat: + // - Correctness issues due to partially completed/visible updates. + // - Performance issues due to excessive serialization. + val updater = collector.copyAndReset() + TaskContext.get().addTaskCompletionListener[Unit] { _ => + collector.setState(updater) + } + + rows.map { r => + updater.add(r) + r + } + } + } +} + +object CollectMetricsExec { + /** + * Recursively collect all collected metrics from a query tree. 
+ */ + def collect(plan: SparkPlan): Map[String, Row] = { + val metrics = plan.collectInPlanAndSubqueries { + case collector: CollectMetricsExec => collector.name -> collector.collectedMetrics + } + metrics.toMap + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 190c74297e9f2..28bbe4fb4993d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -22,7 +22,7 @@ import java.io.{BufferedWriter, OutputStreamWriter} import org.apache.hadoop.fs.Path import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.{InternalRow, QueryPlanningTracker} import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker import org.apache.spark.sql.catalyst.expressions.codegen.ByteCodeStats @@ -106,6 +106,9 @@ class QueryExecution( lazy val toRdd: RDD[InternalRow] = new SQLExecutionRDD( executedPlan.execute(), sparkSession.sessionState.conf) + /** Get the metrics observed during the execution of the query plan. */ + def observedMetrics: Map[String, Row] = CollectMetricsExec.collect(executedPlan) + protected def preparations: Seq[Rule[SparkPlan]] = { QueryExecution.preparations(sparkSession) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 76af81abdb0d4..8eb0e2262e670 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -750,6 +750,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { throw new UnsupportedOperationException(s"UPDATE TABLE is not supported temporarily.") case _: MergeIntoTable => throw new UnsupportedOperationException(s"MERGE INTO TABLE is not supported temporarily.") + case logical.CollectMetrics(name, metrics, child) => + execution.CollectMetricsExec(name, metrics, planLater(child)) :: Nil case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index 4f42992126c49..71bcd53435850 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan} import org.apache.spark.sql.catalyst.util.DateTimeConstants.MILLIS_PER_SECOND import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -173,6 +173,7 @@ trait ProgressReporter extends Logging { val sinkProgress = SinkProgress( sink.toString, sinkCommitProgress.map(_.numOutputRows)) + val observedMetrics = extractObservedMetrics(hasNewData, lastExecution) val newProgress = new StreamingQueryProgress( id = id, @@ -184,7 +185,8 @@ trait ProgressReporter extends Logging { eventTime = new 
java.util.HashMap(executionStats.eventTimeStats.asJava), stateOperators = executionStats.stateOperators.toArray, sources = sourceProgress.toArray, - sink = sinkProgress) + sink = sinkProgress, + observedMetrics = new java.util.HashMap(observedMetrics.asJava)) if (hasNewData) { // Reset noDataEventTimestamp if we processed any data @@ -323,6 +325,16 @@ trait ProgressReporter extends Logging { } } + /** Extracts observed metrics from the most recent query execution. */ + private def extractObservedMetrics( + hasNewData: Boolean, + lastExecution: QueryExecution): Map[String, Row] = { + if (!hasNewData || lastExecution == null) { + return Map.empty + } + lastExecution.observedMetrics + } + /** Records the duration of running `body` for the next query progress update. */ protected def reportTimeTaken[T](triggerDetailKey: String)(body: => T): T = { val startTime = triggerClock.getTimeMillis() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala index e2fea8c9dd6ab..a9681dbd0c676 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala @@ -24,12 +24,15 @@ import java.util.UUID import scala.collection.JavaConverters._ import scala.util.control.NonFatal +import com.fasterxml.jackson.databind.annotation.JsonDeserialize import org.json4s._ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.Evolving +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.streaming.SinkProgress.DEFAULT_NUM_OUTPUT_ROWS /** @@ -106,7 +109,9 @@ class StreamingQueryProgress private[sql]( val eventTime: ju.Map[String, String], val stateOperators: Array[StateOperatorProgress], val sources: Array[SourceProgress], - val sink: SinkProgress) extends Serializable { + val sink: SinkProgress, + @JsonDeserialize(contentAs = classOf[GenericRowWithSchema]) + val observedMetrics: ju.Map[String, Row]) extends Serializable { /** The aggregate (across all sources) number of records processed in a trigger. 
*/ def numInputRows: Long = sources.map(_.numInputRows).sum @@ -149,7 +154,8 @@ class StreamingQueryProgress private[sql]( ("eventTime" -> safeMapToJValue[String](eventTime, s => JString(s))) ~ ("stateOperators" -> JArray(stateOperators.map(_.jsonValue).toList)) ~ ("sources" -> JArray(sources.map(_.jsonValue).toList)) ~ - ("sink" -> sink.jsonValue) + ("sink" -> sink.jsonValue) ~ + ("observedMetrics" -> safeMapToJValue[Row](observedMetrics, row => row.jsonValue)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregatingAccumulatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregatingAccumulatorSuite.scala index aaec6a9761d63..a33b9fad7ff4f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregatingAccumulatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregatingAccumulatorSuite.scala @@ -73,9 +73,12 @@ class AggregatingAccumulatorSuite // Idempotency of result checkResult(acc1.value, InternalRow(73L, str("baz"), 3L), expectedSchema, false) - // A few updates to the copied accumulator - acc2.add(InternalRow(-2L, str("qwerty"), -6773.9d)) - acc2.add(InternalRow(-35L, str("zzz-top"), -323.9d)) + // A few updates to the copied accumulator using an updater + val updater = acc2.copyAndReset() + updater.add(InternalRow(-2L, str("qwerty"), -6773.9d)) + updater.add(InternalRow(-35L, str("zzz-top"), -323.9d)) + assert(acc2.isZero) + acc2.setState(updater) checkResult(acc2.value, InternalRow(-36L, str("zzz-top"), 2L), expectedSchema, false) // Merge accumulators diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index 4c58cb85c4d36..2f66dd3255b11 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -29,7 +29,7 @@ import org.scalatest.concurrent.Waiters.Waiter import org.apache.spark.SparkException import org.apache.spark.scheduler._ -import org.apache.spark.sql.{Encoder, SparkSession} +import org.apache.spark.sql.{Encoder, Row, SparkSession} import org.apache.spark.sql.connector.read.streaming.{Offset => OffsetV2} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.internal.SQLConf @@ -404,6 +404,63 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { testReplayListenerBusWithBorkenEventJsons("query-event-logs-version-2.0.2.txt") } + test("listener propagates observable metrics") { + import org.apache.spark.sql.functions._ + val clock = new StreamManualClock + val inputData = new MemoryStream[Int](0, sqlContext) + val df = inputData.toDF() + .observe( + name = "my_event", + min($"value").as("min_val"), + max($"value").as("max_val"), + sum($"value").as("sum_val"), + count(when($"value" % 2 === 0, 1)).as("num_even")) + .observe( + name = "other_event", + avg($"value").cast("int").as("avg_val")) + val listener = new EventCollector + def checkMetrics(f: java.util.Map[String, Row] => Unit): StreamAction = { + AssertOnQuery { _ => + eventually(Timeout(streamingTimeout)) { + assert(listener.allProgressEvents.nonEmpty) + f(listener.allProgressEvents.last.observedMetrics) + true + } + } + } + + try { + spark.streams.addListener(listener) + testStream(df, OutputMode.Append)( + StartStream(Trigger.ProcessingTime(100), triggerClock = clock), + // Batch 1 + AddData(inputData, 1, 
2), + AdvanceManualClock(100), + checkMetrics { metrics => + assert(metrics.get("my_event") === Row(1, 2, 3L, 1L)) + assert(metrics.get("other_event") === Row(1)) + }, + + // Batch 2 + AddData(inputData, 10, 30, -10, 5), + AdvanceManualClock(100), + checkMetrics { metrics => + assert(metrics.get("my_event") === Row(-10, 30, 35L, 3L)) + assert(metrics.get("other_event") === Row(8)) + }, + + // Batch 3 - no data + AdvanceManualClock(100), + checkMetrics { metrics => + assert(metrics.isEmpty) + }, + StopStream + ) + } finally { + spark.streams.removeListener(listener) + } + } + private def testReplayListenerBusWithBorkenEventJsons(fileName: String): Unit = { val input = getClass.getResourceAsStream(s"/structured-streaming/$fileName") val events = mutable.ArrayBuffer[SparkListenerEvent]() @@ -454,6 +511,10 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { _progressEvents.filter(_.numInputRows > 0) } + def allProgressEvents: Seq[StreamingQueryProgress] = _progressEvents.synchronized { + _progressEvents.clone() + } + def reset(): Unit = { startEvent = null terminationEvent = null diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala index da291f490b76c..b6a6be2bb0312 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala @@ -27,12 +27,15 @@ import org.scalatest.concurrent.Eventually import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.SpanSugar._ +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.StreamingQueryStatusAndProgressSuite._ import org.apache.spark.sql.streaming.StreamingQuerySuite.clock import org.apache.spark.sql.streaming.util.StreamManualClock +import org.apache.spark.sql.types.StructType class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually { test("StreamingQueryProgress - prettyJson") { @@ -77,6 +80,17 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually { | "sink" : { | "description" : "sink", | "numOutputRows" : -1 + | }, + | "observedMetrics" : { + | "event1" : { + | "c1" : 1, + | "c2" : 3.0 + | }, + | "event2" : { + | "rc" : 1, + | "min_q" : "hello", + | "max_q" : "world" + | } | } |} """.stripMargin.trim) @@ -110,6 +124,22 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually { | "sink" : { | "description" : "sink", | "numOutputRows" : -1 + | }, + | "observedMetrics" : { + | "event_a" : { + | "c1" : null, + | "c2" : -20.7 + | }, + | "event_b1" : { + | "rc" : 33, + | "min_q" : "foo", + | "max_q" : "bar" + | }, + | "event_b2" : { + | "rc" : 200, + | "min_q" : "fzo", + | "max_q" : "baz" + | } | } |} """.stripMargin.trim) @@ -265,6 +295,17 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually { } object StreamingQueryStatusAndProgressSuite { + private val schema1 = new StructType() + .add("c1", "long") + .add("c2", "double") + private val schema2 = new StructType() + .add("rc", "long") + .add("min_q", "string") + .add("max_q", "string") + private def row(schema: 
StructType, elements: Any*): Row = { + new GenericRowWithSchema(elements.toArray, schema) + } + val testProgress1 = new StreamingQueryProgress( id = UUID.randomUUID, runId = UUID.randomUUID, @@ -293,7 +334,10 @@ object StreamingQueryStatusAndProgressSuite { processedRowsPerSecond = Double.PositiveInfinity // should not be present in the json ) ), - sink = SinkProgress("sink", None) + sink = SinkProgress("sink", None), + observedMetrics = new java.util.HashMap(Map( + "event1" -> row(schema1, 1L, 3.0d), + "event2" -> row(schema2, 1L, "hello", "world")).asJava) ) val testProgress2 = new StreamingQueryProgress( @@ -317,7 +361,11 @@ object StreamingQueryStatusAndProgressSuite { processedRowsPerSecond = Double.NegativeInfinity // should not be present in the json ) ), - sink = SinkProgress("sink", None) + sink = SinkProgress("sink", None), + observedMetrics = new java.util.HashMap(Map( + "event_a" -> row(schema1, null, -20.7d), + "event_b1" -> row(schema2, 33L, "foo", "bar"), + "event_b2" -> row(schema2, 200L, "fzo", "baz")).asJava) ) val testStatus = new StreamingQueryStatus("active", true, false) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala index 083b40d0680aa..f4ab232af28b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.util import scala.collection.mutable.ArrayBuffer import org.apache.spark._ -import org.apache.spark.sql.{functions, AnalysisException, QueryTest} +import org.apache.spark.sql.{functions, AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, InsertIntoStatement, LogicalPlan, Project} import org.apache.spark.sql.execution.{QueryExecution, WholeStageCodegenExec} @@ -226,4 +226,52 @@ class DataFrameCallbackSuite extends QueryTest with SharedSparkSession { assert(errors.head._2 == e) } } + + test("get observable metrics by callback") { + val metricMaps = ArrayBuffer.empty[Map[String, Row]] + val listener = new QueryExecutionListener { + override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { + metricMaps += qe.observedMetrics + } + + override def onFailure(funcName: String, qe: QueryExecution, exception: Throwable): Unit = { + // No-op + } + } + spark.listenerManager.register(listener) + try { + val df = spark.range(100) + .observe( + name = "my_event", + min($"id").as("min_val"), + max($"id").as("max_val"), + sum($"id").as("sum_val"), + count(when($"id" % 2 === 0, 1)).as("num_even")) + .observe( + name = "other_event", + avg($"id").cast("int").as("avg_val")) + + def checkMetrics(metrics: Map[String, Row]): Unit = { + assert(metrics.size === 2) + assert(metrics("my_event") === Row(0L, 99L, 4950L, 50L)) + assert(metrics("other_event") === Row(49)) + } + + // First run + df.collect() + sparkContext.listenerBus.waitUntilEmpty() + assert(metricMaps.size === 1) + checkMetrics(metricMaps.head) + metricMaps.clear() + + // Second run should produce the same result as the first run. + df.collect() + sparkContext.listenerBus.waitUntilEmpty() + assert(metricMaps.size === 1) + checkMetrics(metricMaps.head) + + } finally { + spark.listenerManager.unregister(listener) + } + } }
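The listener-based flows exercised above (a QueryExecutionListener for batch, a StreamingQueryListener for streaming) are push-style. Since the metrics also ride on StreamingQueryProgress, a caller can poll them directly from a running query as well; a minimal sketch, reusing the hypothetical my_event/rc/erc names from the Dataset.observe scaladoc:

// Assumes `observed_ds` was built with .observe("my_event", count(lit(1)).as("rc"), count($"error").as("erc")).
val query = observed_ds.writeStream.format("console").start()

// Poll the most recent progress instead of registering a listener.
val progress = query.lastProgress
if (progress != null) {
  // observedMetrics is a java.util.Map[String, Row]; get() returns null when the metric was not reported.
  val row = progress.observedMetrics.get("my_event")
  if (row != null) {
    val ratio = row.getAs[Long]("erc").toDouble / row.getAs[Long]("rc")
    if (ratio > 0.05) {
      // Trigger alert
    }
  }
}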