diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
index 936ab866f5bf..8795374b2a72 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
@@ -61,14 +61,12 @@ class ApplyInPandasWithStatePythonRunner(
     keySchema: StructType,
     outputSchema: StructType,
     stateValueSchema: StructType,
-    pyMetrics: Map[String, SQLMetric],
+    override val pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String])
   extends BasePythonRunner[InType, OutType](funcs, evalType, argOffsets, jobArtifactUUID)
   with PythonArrowInput[InType]
   with PythonArrowOutput[OutType] {
 
-  override val pythonMetrics: Option[Map[String, SQLMetric]] = Some(pyMetrics)
-
   override val pythonExec: String = SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(
     funcs.head.funcs.head.pythonExec)
@@ -151,7 +149,7 @@ class ApplyInPandasWithStatePythonRunner(
         pandasWriter.finalizeGroup()
         val deltaData = dataOut.size() - startData
-        pythonMetrics.foreach(_("pythonDataSent") += deltaData)
+        pythonMetrics("pythonDataSent") += deltaData
         true
       } else {
         pandasWriter.finalizeData()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
index 2503deae7d5a..9e210bf5241b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
@@ -70,7 +70,7 @@ case class ArrowEvalPythonUDTFExec(
       sessionLocalTimeZone,
       largeVarTypes,
       pythonRunnerConf,
-      Some(pythonMetrics),
+      pythonMetrics,
       jobArtifactUUID).compute(batchIter, context.partitionId(), context)
 
     columnarBatchIter.map { batch =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index 5dcb79cc2b91..33933b64bbaf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -35,7 +35,7 @@ abstract class BaseArrowPythonRunner(
     _timeZoneId: String,
     protected override val largeVarTypes: Boolean,
     protected override val workerConf: Map[String, String],
-    override val pythonMetrics: Option[Map[String, SQLMetric]],
+    override val pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String])
   extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](
     funcs, evalType, argOffsets, jobArtifactUUID)
@@ -74,7 +74,7 @@ class ArrowPythonRunner(
     _timeZoneId: String,
     largeVarTypes: Boolean,
     workerConf: Map[String, String],
-    pythonMetrics: Option[Map[String, SQLMetric]],
+    pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String])
   extends BaseArrowPythonRunner(
     funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes, workerConf,
@@ -100,7 +100,7 @@ class ArrowPythonWithNamedArgumentRunner(
     jobArtifactUUID: Option[String])
   extends BaseArrowPythonRunner(
     funcs, evalType, argMetas.map(_.map(_.offset)), _schema, _timeZoneId, largeVarTypes, workerConf,
-      Some(pythonMetrics), jobArtifactUUID) {
+      pythonMetrics, jobArtifactUUID) {
 
   override protected def writeUDF(dataOut: DataOutputStream): Unit =
     PythonUDFRunner.writeUDFs(dataOut, funcs, argMetas)
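The runner-side hunks above all make the same change: `pythonMetrics` becomes a required `Map[String, SQLMetric]` (passed straight through, or overridden as a constructor `val`) instead of an `Option` that every call site had to unwrap. Below is a self-contained toy sketch of the before/after update pattern; the counter class and object names are illustrative only, with a plain counter standing in for `SQLMetric`.

    // Hypothetical stand-in types; a real SQLMetric is registered with a SparkContext.
    object MetricsMapSketch {
      final class LongCounter {
        private var v = 0L
        def +=(delta: Long): Unit = v += delta
        def value: Long = v
      }

      // Old shape: every update had to unwrap the optional map.
      def reportOld(metrics: Option[Map[String, LongCounter]], deltaData: Long): Unit =
        metrics.foreach(_("pythonDataSent") += deltaData)

      // New shape: the map is always present, so the update is a direct lookup.
      def reportNew(metrics: Map[String, LongCounter], deltaData: Long): Unit =
        metrics("pythonDataSent") += deltaData

      def main(args: Array[String]): Unit = {
        val metrics = Map("pythonDataSent" -> new LongCounter)
        reportNew(metrics, 1024L)
        println(metrics("pythonDataSent").value) // 1024
      }
    }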
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
index df2e89128124..f52b01b6646a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
@@ -39,7 +39,7 @@ class ArrowPythonUDTFRunner(
     protected override val timeZoneId: String,
     protected override val largeVarTypes: Boolean,
     protected override val workerConf: Map[String, String],
-    override val pythonMetrics: Option[Map[String, SQLMetric]],
+    override val pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String])
   extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](
     Seq(ChainedPythonFunctions(Seq(udtf.func))), evalType, Array(argMetas.map(_.offset)),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
index 70bd1ce82e2e..7e1c8c2ffc13 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
@@ -46,15 +46,13 @@ class CoGroupedArrowPythonRunner(
     rightSchema: StructType,
     timeZoneId: String,
     conf: Map[String, String],
-    pyMetrics: Map[String, SQLMetric],
+    override val pythonMetrics: Map[String, SQLMetric],
    jobArtifactUUID: Option[String])
   extends BasePythonRunner[
     (Iterator[InternalRow], Iterator[InternalRow]), ColumnarBatch](
     funcs, evalType, argOffsets, jobArtifactUUID)
   with BasicPythonArrowOutput {
 
-  override val pythonMetrics: Option[Map[String, SQLMetric]] = Some(pyMetrics)
-
   override val pythonExec: String = SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(
     funcs.head.funcs.head.pythonExec)
@@ -95,7 +93,7 @@ class CoGroupedArrowPythonRunner(
         writeGroup(nextRight, rightSchema, dataOut, "right")
         val deltaData = dataOut.size() - startData
-        pythonMetrics.foreach(_("pythonDataSent") += deltaData)
+        pythonMetrics("pythonDataSent") += deltaData
         true
       } else {
         dataOut.writeInt(0)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInBatchExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInBatchExec.scala
index 5550ddf72a14..facf7bc49c5a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInBatchExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInBatchExec.scala
@@ -88,7 +88,7 @@ trait FlatMapGroupsInBatchExec extends SparkPlan with UnaryExecNode with PythonS
       sessionLocalTimeZone,
       largeVarTypes,
       pythonRunnerConf,
-      Some(pythonMetrics),
+      pythonMetrics,
       jobArtifactUUID)
 
     executePython(data, output, runner)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
index 00990ee46ea5..29dc6e0aa541 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
@@ -36,7 +36,7 @@ class MapInBatchEvaluatorFactory(
     sessionLocalTimeZone: String,
     largeVarTypes: Boolean,
     pythonRunnerConf: Map[String, String],
-    pythonMetrics: Option[Map[String, SQLMetric]],
+    val pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String])
   extends PartitionEvaluatorFactory[InternalRow, InternalRow] {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
index 6db6c96b426a..8db389f02667 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
@@ -57,7 +57,7 @@ trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics {
       conf.sessionLocalTimeZone,
       conf.arrowUseLargeVarTypes,
       pythonRunnerConf,
-      Some(pythonMetrics),
+      pythonMetrics,
       jobArtifactUUID)
 
     if (isBarrier) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
index 6d0f31f35ff7..1e075cab9224 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
@@ -46,7 +46,7 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
 
   protected val largeVarTypes: Boolean
 
-  protected def pythonMetrics: Option[Map[String, SQLMetric]]
+  protected def pythonMetrics: Map[String, SQLMetric]
 
   protected def writeNextInputToArrowStream(
       root: VectorSchemaRoot,
@@ -132,7 +132,7 @@ private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[In
         writer.writeBatch()
         arrowWriter.reset()
         val deltaData = dataOut.size() - startData
-        pythonMetrics.foreach(_("pythonDataSent") += deltaData)
+        pythonMetrics("pythonDataSent") += deltaData
         true
       } else {
         super[PythonArrowInput].close()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
index 82e8e7aa4f64..90922d89ad10 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, Column
  */
 private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[_, OUT] =>
 
-  protected def pythonMetrics: Option[Map[String, SQLMetric]]
+  protected def pythonMetrics: Map[String, SQLMetric]
 
   protected def handleMetadataAfterExec(stream: DataInputStream): Unit = { }
@@ -91,8 +91,8 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[
           val rowCount = root.getRowCount
           batch.setNumRows(root.getRowCount)
           val bytesReadEnd = reader.bytesRead()
-          pythonMetrics.foreach(_("pythonNumRowsReceived") += rowCount)
-          pythonMetrics.foreach(_("pythonDataReceived") += bytesReadEnd - bytesReadStart)
+          pythonMetrics("pythonNumRowsReceived") += rowCount
+          pythonMetrics("pythonDataReceived") += bytesReadEnd - bytesReadStart
           deserializeColumnarBatch(batch, schema)
         } else {
           reader.close(false)
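With `PythonArrowInput` and `PythonArrowOutput` now requiring a concrete `protected def pythonMetrics: Map[String, SQLMetric]`, the runners above satisfy it directly from a constructor parameter (`override val pythonMetrics: ...`), which is why the separate `override val pythonMetrics = Some(pyMetrics)` members could be deleted. A minimal, self-contained sketch of that Scala idiom; the names here are made up for illustration:

    // A trait member declared as a def can be implemented by a stable val,
    // including a val that is itself a constructor parameter.
    trait MetricsHolder {
      protected def metricValues: Map[String, Long]
    }

    // The constructor parameter both stores the map and overrides the trait member.
    class ExampleRunner(override val metricValues: Map[String, Long]) extends MetricsHolder {
      def sent: Long = metricValues("pythonDataSent")
    }

    object OverrideValSketch {
      def main(args: Array[String]): Unit = {
        val runner = new ExampleRunner(Map("pythonDataSent" -> 2048L))
        println(runner.sent) // 2048
      }
    }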
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonSQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonSQLMetrics.scala
index a748c1bc1008..4df6d821c014 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonSQLMetrics.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonSQLMetrics.scala
@@ -18,18 +18,29 @@ package org.apache.spark.sql.execution.python
 
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 
-private[sql] trait PythonSQLMetrics { self: SparkPlan =>
+trait PythonSQLMetrics { self: SparkPlan =>
+  protected val pythonMetrics: Map[String, SQLMetric] = {
+    PythonSQLMetrics.pythonSizeMetricsDesc.map { case (k, v) =>
+      k -> SQLMetrics.createSizeMetric(sparkContext, v)
+    } ++ PythonSQLMetrics.pythonOtherMetricsDesc.map { case (k, v) =>
+      k -> SQLMetrics.createMetric(sparkContext, v)
+    }
+  }
 
-  val pythonMetrics = Map(
-    "pythonDataSent" -> SQLMetrics.createSizeMetric(sparkContext,
-      "data sent to Python workers"),
-    "pythonDataReceived" -> SQLMetrics.createSizeMetric(sparkContext,
-      "data returned from Python workers"),
-    "pythonNumRowsReceived" -> SQLMetrics.createMetric(sparkContext,
-      "number of output rows")
-  )
+  override lazy val metrics: Map[String, SQLMetric] = pythonMetrics
+}
+
+object PythonSQLMetrics {
+  val pythonSizeMetricsDesc: Map[String, String] = {
+    Map(
+      "pythonDataSent" -> "data sent to Python workers",
+      "pythonDataReceived" -> "data returned from Python workers"
+    )
+  }
 
-  override lazy val metrics = pythonMetrics
+  val pythonOtherMetricsDesc: Map[String, String] = {
+    Map("pythonNumRowsReceived" -> "number of output rows")
+  }
 }
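The point of moving the metric descriptions into the `PythonSQLMetrics` companion object is that the key/description pairs become a single source of truth: plan nodes build real `SQLMetric`s from them (as in the trait body above), while the DSv2 reader below builds placeholder metrics from the same keys. A standalone sketch of that sharing pattern, with a plain counter class standing in for `SQLMetric` and illustrative object names:

    // Descriptor maps mirror PythonSQLMetrics.pythonSizeMetricsDesc / pythonOtherMetricsDesc.
    object MetricDescriptors {
      val sizeMetricsDesc: Map[String, String] = Map(
        "pythonDataSent" -> "data sent to Python workers",
        "pythonDataReceived" -> "data returned from Python workers")
      val otherMetricsDesc: Map[String, String] = Map(
        "pythonNumRowsReceived" -> "number of output rows")
    }

    object SharedDescriptorSketch {
      // Stand-in for SQLMetric: just a metric type plus a mutable value.
      final class Counter(val metricType: String) { var value = 0L }

      // Any consumer can derive one metric object per described key.
      def buildMetrics(): Map[String, Counter] =
        MetricDescriptors.sizeMetricsDesc.map { case (k, _) => k -> new Counter("size") } ++
          MetricDescriptors.otherMetricsDesc.map { case (k, _) => k -> new Counter("sum") }

      def main(args: Array[String]): Unit = {
        val metrics = buildMetrics()
        metrics("pythonDataSent").value += 4096L
        println(metrics.keys.toSeq.sorted.mkString(", "))
      }
    }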
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
index 047a133a322a..f6a473adf08d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
@@ -34,8 +34,10 @@ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider}
 import org.apache.spark.sql.connector.catalog.TableCapability.BATCH_READ
 import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}
 import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReader, PartitionReaderFactory, Scan, ScanBuilder}
 import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{BinaryType, DataType, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -101,8 +103,10 @@ class PythonTableProvider extends TableProvider {
         new PythonPartitionReaderFactory(
           source, readerFunc, outputSchema, jobArtifactUUID)
       }
-      override def description: String = "(Python)"
+
+      override def supportedCustomMetrics(): Array[CustomMetric] =
+        source.createPythonMetrics()
     }
   }
@@ -124,21 +128,78 @@ class PythonPartitionReaderFactory(
 
   override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
     new PartitionReader[InternalRow] {
-      private val outputIter = source.createPartitionReadIteratorInPython(
-        partition.asInstanceOf[PythonInputPartition],
-        pickledReadFunc,
-        outputSchema,
-        jobArtifactUUID)
+      // Dummy SQLMetrics. The result is manually reported via DSv2 interface
+      // via passing the value to `CustomTaskMetric`. Note that `pythonOtherMetricsDesc`
+      // is not used when it is reported. It is to reuse existing Python runner.
+      // See also `UserDefinedPythonDataSource.createPythonMetrics`.
+      private[this] val metrics: Map[String, SQLMetric] = {
+        PythonSQLMetrics.pythonSizeMetricsDesc.keys
+          .map(_ -> new SQLMetric("size", -1)).toMap ++
+        PythonSQLMetrics.pythonOtherMetricsDesc.keys
+          .map(_ -> new SQLMetric("sum", -1)).toMap
+      }
+
+      private val outputIter = {
+        val evaluatorFactory = source.createMapInBatchEvaluatorFactory(
+          pickledReadFunc,
+          outputSchema,
+          metrics,
+          jobArtifactUUID)
+
+        val part = partition.asInstanceOf[PythonInputPartition]
+        evaluatorFactory.createEvaluator().eval(
+          part.index, Iterator.single(InternalRow(part.pickedPartition)))
+      }
 
       override def next(): Boolean = outputIter.hasNext
 
       override def get(): InternalRow = outputIter.next()
 
       override def close(): Unit = {}
+
+      override def currentMetricsValues(): Array[CustomTaskMetric] = {
+        source.createPythonTaskMetrics(metrics.map { case (k, v) => k -> v.value})
+      }
     }
   }
 }
 
+class PythonCustomMetric extends CustomMetric {
+  private var initName: String = _
+  private var initDescription: String = _
+  def initialize(n: String, d: String): Unit = {
+    initName = n
+    initDescription = d
+  }
+  override def name(): String = {
+    assert(initName != null)
+    initName
+  }
+  override def description(): String = {
+    assert(initDescription != null)
+    initDescription
+  }
+  override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = {
+    SQLMetrics.stringValue("size", taskMetrics, Array.empty[Long])
+  }
+}
+
+class PythonCustomTaskMetric extends CustomTaskMetric {
+  private var initName: String = _
+  private var initValue: Long = -1L
+  def initialize(n: String, v: Long): Unit = {
+    initName = n
+    initValue = v
+  }
+  override def name(): String = {
+    assert(initName != null)
+    initName
+  }
+  override def value(): Long = {
+    initValue
+  }
+}
+
 /**
  * A user-defined Python data source. This is used by the Python API.
  * Defines the interation between Python and JVM.
@@ -179,11 +240,11 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
   /**
    * (Executor-side) Create an iterator that reads the input partitions.
    */
-  def createPartitionReadIteratorInPython(
-      partition: PythonInputPartition,
+  def createMapInBatchEvaluatorFactory(
       pickledReadFunc: Array[Byte],
       outputSchema: StructType,
-      jobArtifactUUID: Option[String]): Iterator[InternalRow] = {
+      metrics: Map[String, SQLMetric],
+      jobArtifactUUID: Option[String]): MapInBatchEvaluatorFactory = {
     val readerFunc = createPythonFunction(pickledReadFunc)
 
     val pythonEvalType = PythonEvalType.SQL_MAP_ARROW_ITER_UDF
@@ -199,7 +260,7 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
     val conf = SQLConf.get
     val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
 
-    val evaluatorFactory = new MapInBatchEvaluatorFactory(
+    new MapInBatchEvaluatorFactory(
      toAttributes(outputSchema),
       Seq(ChainedPythonFunctions(Seq(pythonUDF.func))),
       inputSchema,
@@ -208,11 +269,26 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
       conf.sessionLocalTimeZone,
       conf.arrowUseLargeVarTypes,
       pythonRunnerConf,
-      None,
+      metrics,
       jobArtifactUUID)
+  }
+
+  def createPythonMetrics(): Array[CustomMetric] = {
+    // Do not add other metrics such as number of rows,
+    // that is already included via DSv2.
+    PythonSQLMetrics.pythonSizeMetricsDesc.map { case (k, v) =>
+      val m = new PythonCustomMetric()
+      m.initialize(k, v)
+      m
+    }.toArray
+  }
 
-    evaluatorFactory.createEvaluator().eval(
-      partition.index, Iterator.single(InternalRow(partition.pickedPartition)))
+  def createPythonTaskMetrics(taskMetrics: Map[String, Long]): Array[CustomTaskMetric] = {
+    taskMetrics.map { case (k, v) =>
+      val m = new PythonCustomTaskMetric()
+      m.initialize(k, v)
+      m
+    }.toArray
   }
 
   private def createPythonFunction(pickledFunc: Array[Byte]): PythonFunction = {
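A rough sketch of how the DSv2 pieces above fit together, based on the DataSource V2 metric API: the scan declares driver-side `CustomMetric`s via `supportedCustomMetrics()`, each task reports `CustomTaskMetric` values via `currentMetricsValues()`, and Spark pairs the two by `name()` before calling `aggregateTaskMetrics()`. The classes and the summing aggregation below are illustrative stand-ins (the real `PythonCustomMetric` delegates to `SQLMetrics.stringValue`):

    import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}

    // Driver-side declaration: one named, described metric per key.
    class SentBytesMetric extends CustomMetric {
      override def name(): String = "pythonDataSent"
      override def description(): String = "data sent to Python workers"
      override def aggregateTaskMetrics(taskMetrics: Array[Long]): String =
        s"total: ${taskMetrics.sum} B" // simplified aggregation for the sketch
    }

    // Executor-side value: the name must match the driver-side declaration.
    class SentBytesTaskMetric(bytes: Long) extends CustomTaskMetric {
      override def name(): String = "pythonDataSent"
      override def value(): Long = bytes
    }

    object Dsv2MetricFlowSketch {
      def main(args: Array[String]): Unit = {
        val declared = new SentBytesMetric
        val perTask = Seq(new SentBytesTaskMetric(1024L), new SentBytesTaskMetric(4096L))
        val aggregated = declared.aggregateTaskMetrics(perTask.map(_.value()).toArray)
        println(s"${declared.name()}: $aggregated")
      }
    }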
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index 53a54abf8392..e8a46449ac20 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.python
 
 import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTest, Row}
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
+import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2ScanRelation}
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.StructType
@@ -396,4 +396,61 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
       assert(err.getMessage.contains("PYTHON_DATA_SOURCE_CREATE_ERROR"))
     }
   }
+
+  test("SPARK-46424: Support Python metrics") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+        |from pyspark.sql.datasource import DataSource, DataSourceReader
+        |class SimpleDataSourceReader(DataSourceReader):
+        |    def partitions(self):
+        |        return []
+        |
+        |    def read(self, partition):
+        |        if partition is None:
+        |            yield ("success", )
+        |        else:
+        |            yield ("failed", )
+        |
+        |class $dataSourceName(DataSource):
+        |    def schema(self) -> str:
+        |        return "status STRING"
+        |
+        |    def reader(self, schema):
+        |        return SimpleDataSourceReader()
+        |""".stripMargin
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    val df = spark.read.format(dataSourceName).load()
+
+    val statusStore = spark.sharedState.statusStore
+    val oldCount = statusStore.executionsList().size
+
+    df.collect()
+
+    // Wait until the new execution is started and being tracked.
+    while (statusStore.executionsCount() < oldCount) {
+      Thread.sleep(100)
+    }
+
+    // Wait for listener to finish computing the metrics for the execution.
+    while (statusStore.executionsList().isEmpty ||
+        statusStore.executionsList().last.metricValues == null) {
+      Thread.sleep(100)
+    }
+
+    val executedPlan = df.queryExecution.executedPlan.collectFirst {
+      case p: BatchScanExec => p
+    }
+    assert(executedPlan.isDefined)
+
+    val execId = statusStore.executionsList().last.executionId
+    val metrics = statusStore.executionMetrics(execId)
+    val pythonDataSent = executedPlan.get.metrics("pythonDataSent")
+    val pythonDataReceived = executedPlan.get.metrics("pythonDataReceived")
+    assert(metrics.contains(pythonDataSent.id))
+    assert(metrics(pythonDataSent.id).asInstanceOf[String].endsWith("B"))
+    assert(metrics.contains(pythonDataReceived.id))
+    assert(metrics(pythonDataReceived.id).asInstanceOf[String].endsWith("B"))
+  }
 }