2 changes: 1 addition & 1 deletion python/pyspark/sql/datasource.py
@@ -407,7 +407,7 @@ def register(
# Serialize the data source class.
wrapped = _wrap_function(sc, dataSource)
assert sc._jvm is not None
-ds = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonDataSource(wrapped)
+ds = sc._jvm.org.apache.spark.sql.execution.datasources.v2.python.UserDefinedPythonDataSource(wrapped)
self.sparkSession._jsparkSession.dataSource().registerPython(name, ds)
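For context, a rough sketch of the user-facing path that exercises this registration hook, assuming the Python Data Source API in pyspark.sql.datasource; the MyDataSource/MyReader classes, the "my_source" short name, and the schema are illustrative only and not part of this change:

# Hypothetical example; names and schema are made up for illustration.
from pyspark.sql import SparkSession
from pyspark.sql.datasource import DataSource, DataSourceReader

class MyReader(DataSourceReader):
    def read(self, partition):
        # Yield tuples matching the schema declared below.
        yield (0, "placeholder")

class MyDataSource(DataSource):
    @classmethod
    def name(cls):
        return "my_source"

    def schema(self):
        return "id INT, value STRING"

    def reader(self, schema):
        return MyReader()

spark = SparkSession.builder.getOrCreate()
# register() pickles the class and passes it to the JVM-side
# UserDefinedPythonDataSource constructed in the diff above.
spark.dataSource.register(MyDataSource)
spark.read.format("my_source").load().show()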


@@ -22,7 +22,7 @@ import org.apache.spark.annotation.Evolving
import org.apache.spark.internal.Logging
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.{DataSource, DataSourceManager}
-import org.apache.spark.sql.execution.python.UserDefinedPythonDataSource
+import org.apache.spark.sql.execution.datasources.v2.python.UserDefinedPythonDataSource
import org.apache.spark.sql.internal.SQLConf

/**
@@ -43,8 +43,8 @@ import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2
+import org.apache.spark.sql.execution.datasources.v2.python.PythonDataSourceV2
import org.apache.spark.sql.execution.datasources.xml.XmlFileFormat
-import org.apache.spark.sql.execution.python.PythonTableProvider
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, TextSocketSourceProvider}
import org.apache.spark.sql.internal.SQLConf
@@ -661,7 +661,7 @@ object DataSource extends Logging {
} else if (provider1.toLowerCase(Locale.ROOT) == "kafka") {
  throw QueryCompilationErrors.failedToFindKafkaDataSourceError(provider1)
} else if (isUserDefinedDataSource) {
-  classOf[PythonTableProvider]
+  classOf[PythonDataSourceV2]
} else {
  throw QueryExecutionErrors.dataSourceNotFoundError(provider1, error)
}
@@ -24,7 +24,7 @@ import java.util.concurrent.ConcurrentHashMap
import org.apache.spark.api.python.PythonUtils
import org.apache.spark.internal.Logging
import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.python.UserDefinedPythonDataSource
+import org.apache.spark.sql.execution.datasources.v2.python.UserDefinedPythonDataSource
import org.apache.spark.util.Utils


@@ -0,0 +1,80 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.v2.python

import scala.jdk.CollectionConverters.IteratorHasAsScala

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.metric.CustomTaskMetric
import org.apache.spark.sql.connector.write._
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.StructType

case class PythonWriterCommitMessage(pickledMessage: Array[Byte]) extends WriterCommitMessage

case class PythonBatchWriterFactory(
    source: UserDefinedPythonDataSource,
    pickledWriteFunc: Array[Byte],
    inputSchema: StructType,
    jobArtifactUUID: Option[String]) extends DataWriterFactory {
  override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = {
    new DataWriter[InternalRow] {

      private[this] val metrics: Map[String, SQLMetric] = PythonCustomMetric.pythonMetrics

      private var commitMessage: PythonWriterCommitMessage = _

      override def writeAll(records: java.util.Iterator[InternalRow]): Unit = {
        val evaluatorFactory = source.createMapInBatchEvaluatorFactory(
          pickledWriteFunc,
          "write_to_data_source",
          inputSchema,
          UserDefinedPythonDataSource.writeOutputSchema,
          metrics,
          jobArtifactUUID)
        val outputIter = evaluatorFactory.createEvaluator().eval(partitionId, records.asScala)
        // Each partition must produce exactly one pickled commit message row.
        outputIter.foreach { row =>
          if (commitMessage == null) {
            commitMessage = PythonWriterCommitMessage(row.getBinary(0))
          } else {
            throw QueryExecutionErrors.invalidWriterCommitMessageError(details = "more than one")
          }
        }
        if (commitMessage == null) {
          throw QueryExecutionErrors.invalidWriterCommitMessageError(details = "zero")
        }
      }

      override def write(record: InternalRow): Unit =
        throw SparkException.internalError("write method for Python data source should not be called.")

      override def commit(): WriterCommitMessage = {
        commitMessage.asInstanceOf[WriterCommitMessage]
      }

      override def abort(): Unit = {}

      override def close(): Unit = {}

      override def currentMetricsValues(): Array[CustomTaskMetric] = {
        source.createPythonTaskMetrics(metrics.map { case (k, v) => k -> v.value })
      }
    }
  }
}
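The factory above runs the pickled write function over each partition and expects exactly one pickled commit message back per task. A hedged sketch of the Python-side counterpart that produces that single message, assuming the DataSourceWriter and WriterCommitMessage classes in pyspark.sql.datasource (the names here are illustrative):

# Hypothetical writer; it only illustrates the one-commit-message-per-partition
# contract enforced by PythonBatchWriterFactory.writeAll above.
from dataclasses import dataclass
from pyspark.sql.datasource import DataSourceWriter, WriterCommitMessage

@dataclass
class MyCommitMessage(WriterCommitMessage):
    rows_written: int

class MyWriter(DataSourceWriter):
    def write(self, iterator):
        count = sum(1 for _ in iterator)
        # Return exactly one message per partition; the JVM side receives it as
        # the single pickled row checked for "zero" / "more than one".
        return MyCommitMessage(rows_written=count)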
@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.v2.python

import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.execution.python.PythonSQLMetrics


class PythonCustomMetric(
    override val name: String,
    override val description: String) extends CustomMetric {
  // A no-arg constructor so that aggregation can be invoked.
  // See `SQLAppStatusListener.aggregateMetrics`.
  def this() = this(null, null)

  override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = {
    SQLMetrics.stringValue("size", taskMetrics, Array.empty[Long])
  }
}

class PythonCustomTaskMetric(
    override val name: String,
    override val value: Long) extends CustomTaskMetric


object PythonCustomMetric {
  val pythonMetrics: Map[String, SQLMetric] = {
    // Dummy SQLMetrics. The results are reported manually through the DSv2 interface
    // by passing the values to `CustomTaskMetric`. Note that `pythonOtherMetricsDesc`
    // is not used when the values are reported; it is kept to reuse the existing
    // Python runner. See also `UserDefinedPythonDataSource.createPythonMetrics`.
    PythonSQLMetrics.pythonSizeMetricsDesc.keys
      .map(_ -> new SQLMetric("size", -1)).toMap ++
      PythonSQLMetrics.pythonOtherMetricsDesc.keys
        .map(_ -> new SQLMetric("sum", -1)).toMap
  }
}
@@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.v2.python

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

/**
* Data Source V2 wrapper for Python Data Source.
*/
class PythonDataSourceV2 extends TableProvider {
  private var name: String = _

  def setShortName(str: String): Unit = {
    assert(name == null)
    name = str
  }

  private def shortName: String = {
    assert(name != null)
    name
  }

  private var dataSourceInPython: PythonDataSourceCreationResult = _
  private[python] lazy val source: UserDefinedPythonDataSource =
    SparkSession.active.sessionState.dataSourceManager.lookupDataSource(shortName)

  def getOrCreateDataSourceInPython(
      shortName: String,
      options: CaseInsensitiveStringMap,
      userSpecifiedSchema: Option[StructType]): PythonDataSourceCreationResult = {
    if (dataSourceInPython == null) {
      dataSourceInPython = source.createDataSourceInPython(shortName, options, userSpecifiedSchema)
    }
    dataSourceInPython
  }

  override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
    getOrCreateDataSourceInPython(shortName, options, None).schema
  }

  override def getTable(
      schema: StructType,
      partitioning: Array[Transform],
      properties: java.util.Map[String, String]): Table = {
    new PythonTable(this, shortName, schema)
  }

  override def supportsExternalMetadata(): Boolean = true
}
@@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.datasources.v2.python

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.metric.CustomTaskMetric
import org.apache.spark.sql.connector.read._
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.StructType


case class PythonInputPartition(index: Int, pickedPartition: Array[Byte]) extends InputPartition

class PythonPartitionReaderFactory(
    source: UserDefinedPythonDataSource,
    pickledReadFunc: Array[Byte],
    outputSchema: StructType,
    jobArtifactUUID: Option[String])
  extends PartitionReaderFactory {

  override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
    new PartitionReader[InternalRow] {

      private[this] val metrics: Map[String, SQLMetric] = PythonCustomMetric.pythonMetrics

      private val outputIter = {
        val evaluatorFactory = source.createMapInBatchEvaluatorFactory(
          pickledReadFunc,
          "read_from_data_source",
          UserDefinedPythonDataSource.readInputSchema,
          outputSchema,
          metrics,
          jobArtifactUUID)

        val part = partition.asInstanceOf[PythonInputPartition]
        evaluatorFactory.createEvaluator().eval(
          part.index, Iterator.single(InternalRow(part.pickedPartition)))
      }

      override def next(): Boolean = outputIter.hasNext

      override def get(): InternalRow = outputIter.next()

      override def close(): Unit = {}

      override def currentMetricsValues(): Array[CustomTaskMetric] = {
        source.createPythonTaskMetrics(metrics.map { case (k, v) => k -> v.value })
      }
    }
  }
}
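Each PythonInputPartition above carries the pickled value of a partition object produced on the Python side; the reader factory wraps it in a single-field InternalRow and hands it back to the Python read function. A rough sketch of that Python side, assuming the DataSourceReader and InputPartition classes in pyspark.sql.datasource (the RangeReader name and row layout are illustrative):

# Hypothetical reader showing where the pickled partitions come from.
from pyspark.sql.datasource import DataSourceReader, InputPartition

class RangeReader(DataSourceReader):
    def partitions(self):
        # Each InputPartition is pickled and shipped as PythonInputPartition's payload.
        return [InputPartition(i) for i in range(3)]

    def read(self, partition):
        # partition is the same InputPartition object, round-tripped through the JVM.
        start = partition.value * 10
        for i in range(start, start + 10):
            yield (i,)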
@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.v2.python

import org.apache.spark.JobArtifactSet
import org.apache.spark.sql.connector.metric.CustomMetric
import org.apache.spark.sql.connector.read._
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap


class PythonScan(
    ds: PythonDataSourceV2,
    shortName: String,
    outputSchema: StructType,
    options: CaseInsensitiveStringMap) extends Batch with Scan {

  private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)

  private lazy val infoInPython: PythonDataSourceReadInfo = {
    ds.source.createReadInfoInPython(
      ds.getOrCreateDataSourceInPython(shortName, options, Some(outputSchema)),
      outputSchema)
  }

  override def planInputPartitions(): Array[InputPartition] =
    infoInPython.partitions.zipWithIndex.map(p => PythonInputPartition(p._2, p._1)).toArray

  override def createReaderFactory(): PartitionReaderFactory = {
    val readerFunc = infoInPython.func
    new PythonPartitionReaderFactory(
      ds.source, readerFunc, outputSchema, jobArtifactUUID)
  }

  override def toBatch: Batch = this

  override def description: String = "(Python)"

  override def readSchema(): StructType = outputSchema

  override def supportedCustomMetrics(): Array[CustomMetric] =
    ds.source.createPythonMetrics()
}