5 changes: 4 additions & 1 deletion python/pyspark/sql/datasource.py
@@ -407,7 +407,10 @@ def register(
# Serialize the data source class.
wrapped = _wrap_function(sc, dataSource)
assert sc._jvm is not None
ds = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonDataSource(wrapped)
jvm = sc._jvm
ds = jvm.org.apache.spark.sql.execution.datasources.v2.python.UserDefinedPythonDataSource(
wrapped
)
self.sparkSession._jsparkSession.dataSource().registerPython(name, ds)
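For context, a minimal sketch of the Python-side flow that ends in this registration call, assuming the pyspark.sql.datasource API this change builds on (MyDataSource, MyReader, and the exact register() signature are illustrative assumptions, not part of this diff):

# Illustrative sketch, not part of this diff; assumes the pyspark.sql.datasource API.
from pyspark.sql import SparkSession
from pyspark.sql.datasource import DataSource, DataSourceReader

class MyReader(DataSourceReader):
    def read(self, partition):
        # Yield rows as tuples matching the declared schema.
        for i in range(3):
            yield (i, str(i))

class MyDataSource(DataSource):
    @classmethod
    def name(cls):
        # Short name later used with spark.read.format(...).
        return "my_source"

    def schema(self):
        # DDL string describing the output schema.
        return "id INT, value STRING"

    def reader(self, schema):
        return MyReader()

spark = SparkSession.builder.getOrCreate()
# register() pickles the data source class and hands it to the JVM-side
# UserDefinedPythonDataSource constructed above.
spark.dataSource.register(MyDataSource)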


@@ -22,7 +22,7 @@ import org.apache.spark.annotation.Evolving
import org.apache.spark.internal.Logging
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.{DataSource, DataSourceManager}
import org.apache.spark.sql.execution.python.UserDefinedPythonDataSource
import org.apache.spark.sql.execution.datasources.v2.python.UserDefinedPythonDataSource
import org.apache.spark.sql.internal.SQLConf

/**
@@ -43,8 +43,8 @@ import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2
import org.apache.spark.sql.execution.datasources.v2.python.PythonDataSourceV2
import org.apache.spark.sql.execution.datasources.xml.XmlFileFormat
import org.apache.spark.sql.execution.python.PythonTableProvider
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, TextSocketSourceProvider}
import org.apache.spark.sql.internal.SQLConf
@@ -661,7 +661,7 @@ object DataSource extends Logging {
} else if (provider1.toLowerCase(Locale.ROOT) == "kafka") {
throw QueryCompilationErrors.failedToFindKafkaDataSourceError(provider1)
} else if (isUserDefinedDataSource) {
classOf[PythonTableProvider]
classOf[PythonDataSourceV2]
} else {
throw QueryExecutionErrors.dataSourceNotFoundError(provider1, error)
}
@@ -734,7 +734,7 @@
case t: TableProvider
if !useV1Sources.contains(cls.getCanonicalName.toLowerCase(Locale.ROOT)) =>
t match {
case p: PythonTableProvider => p.setShortName(provider)
case p: PythonDataSourceV2 => p.setShortName(provider)
case _ =>
}
Some(t)
@@ -24,7 +24,7 @@ import java.util.concurrent.ConcurrentHashMap
import org.apache.spark.api.python.PythonUtils
import org.apache.spark.internal.Logging
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.python.UserDefinedPythonDataSource
import org.apache.spark.sql.execution.datasources.v2.python.UserDefinedPythonDataSource
import org.apache.spark.util.Utils


@@ -0,0 +1,80 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.v2.python

import scala.jdk.CollectionConverters.IteratorHasAsScala

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.metric.CustomTaskMetric
import org.apache.spark.sql.connector.write._
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.StructType

case class PythonWriterCommitMessage(pickledMessage: Array[Byte]) extends WriterCommitMessage

case class PythonBatchWriterFactory(
source: UserDefinedPythonDataSource,
pickledWriteFunc: Array[Byte],
inputSchema: StructType,
jobArtifactUUID: Option[String]) extends DataWriterFactory {
override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = {
new DataWriter[InternalRow] {

private[this] val metrics: Map[String, SQLMetric] = PythonCustomMetric.pythonMetrics

private var commitMessage: PythonWriterCommitMessage = _

override def writeAll(records: java.util.Iterator[InternalRow]): Unit = {
val evaluatorFactory = source.createMapInBatchEvaluatorFactory(
pickledWriteFunc,
"write_to_data_source",
inputSchema,
UserDefinedPythonDataSource.writeOutputSchema,
metrics,
jobArtifactUUID)
val outputIter = evaluatorFactory.createEvaluator().eval(partitionId, records.asScala)
outputIter.foreach { row =>
if (commitMessage == null) {
commitMessage = PythonWriterCommitMessage(row.getBinary(0))
} else {
throw QueryExecutionErrors.invalidWriterCommitMessageError(details = "more than one")
}
}
if (commitMessage == null) {
throw QueryExecutionErrors.invalidWriterCommitMessageError(details = "zero")
}
}

override def write(record: InternalRow): Unit =
  throw SparkException.internalError("write method for Python data source should not be called.")

override def commit(): WriterCommitMessage = {
commitMessage.asInstanceOf[WriterCommitMessage]
}

override def abort(): Unit = {}

override def close(): Unit = {}

override def currentMetricsValues(): Array[CustomTaskMetric] = {
source.createPythonTaskMetrics(metrics.map { case (k, v) => k -> v.value })
}
}
}
}
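On the Python side, the pickled write function evaluated here would come from a user-defined writer. A hedged sketch, assuming the DataSourceWriter and WriterCommitMessage classes from pyspark.sql.datasource (MySink and MyWriter are made-up names); it returns exactly one commit message per partition, which is what the writeAll check above enforces:

# Illustrative sketch, not part of this diff; assumes pyspark.sql.datasource writer classes.
from pyspark.sql.datasource import DataSource, DataSourceWriter, WriterCommitMessage

class MyWriter(DataSourceWriter):
    def write(self, iterator):
        # Consume this partition's rows, then return a single commit message;
        # the JVM-side writeAll above rejects zero or more than one message.
        for _ in iterator:
            pass
        return WriterCommitMessage()

    def commit(self, messages):
        # Runs on the driver with one message per partition.
        pass

class MySink(DataSource):
    @classmethod
    def name(cls):
        return "my_sink"

    def writer(self, schema, overwrite):
        return MyWriter()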
@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.v2.python

import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.execution.python.PythonSQLMetrics


class PythonCustomMetric(
override val name: String,
override val description: String) extends CustomMetric {
// Allow aggregation to be called; see `SQLAppStatusListener.aggregateMetrics`.
def this() = this(null, null)

override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = {
SQLMetrics.stringValue("size", taskMetrics, Array.empty[Long])
}
}

class PythonCustomTaskMetric(
override val name: String,
override val value: Long) extends CustomTaskMetric


object PythonCustomMetric {
val pythonMetrics: Map[String, SQLMetric] = {
// Dummy SQLMetrics. The results are reported manually through the DSv2 interface
// by passing the values to `CustomTaskMetric`. Note that `pythonOtherMetricsDesc`
// is not used when reporting; it exists only to reuse the existing Python runner.
// See also `UserDefinedPythonDataSource.createPythonMetrics`.
PythonSQLMetrics.pythonSizeMetricsDesc.keys
.map(_ -> new SQLMetric("size", -1)).toMap ++
PythonSQLMetrics.pythonOtherMetricsDesc.keys
.map(_ -> new SQLMetric("sum", -1)).toMap
}
}
@@ -0,0 +1,67 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.v2.python

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

/**
* Data Source V2 wrapper for Python Data Source.
*/
class PythonDataSourceV2 extends TableProvider {
private var name: String = _

def setShortName(str: String): Unit = {
assert(name == null)
name = str
}

private def shortName: String = {
assert(name != null)
name
}

private var dataSourceInPython: PythonDataSourceCreationResult = _
private[python] lazy val source: UserDefinedPythonDataSource =
SparkSession.active.sessionState.dataSourceManager.lookupDataSource(shortName)

def getOrCreateDataSourceInPython(
shortName: String,
options: CaseInsensitiveStringMap,
userSpecifiedSchema: Option[StructType]): PythonDataSourceCreationResult = {
if (dataSourceInPython == null) {
dataSourceInPython = source.createDataSourceInPython(shortName, options, userSpecifiedSchema)
}
dataSourceInPython
}

override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
getOrCreateDataSourceInPython(shortName, options, None).schema
}

override def getTable(
schema: StructType,
partitioning: Array[Transform],
properties: java.util.Map[String, String]): Table = {
new PythonTable(this, shortName, schema)
}

override def supportsExternalMetadata(): Boolean = true
}
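From the user's perspective this provider is exercised through an ordinary DataFrame read. A usage sketch, assuming the "my_source" registration from the earlier example (not part of this diff):

# Illustrative usage sketch, not part of this diff.
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# With no user schema, inferSchema() above runs the Python data source to get one.
df = spark.read.format("my_source").load()

# supportsExternalMetadata() lets a user-specified schema skip that inference call.
df2 = spark.read.format("my_source").schema("id INT, value STRING").load()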
@@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.datasources.v2.python

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.metric.CustomTaskMetric
import org.apache.spark.sql.connector.read._
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.StructType


case class PythonInputPartition(index: Int, pickedPartition: Array[Byte]) extends InputPartition

class PythonPartitionReaderFactory(
source: UserDefinedPythonDataSource,
pickledReadFunc: Array[Byte],
outputSchema: StructType,
jobArtifactUUID: Option[String])
extends PartitionReaderFactory {

override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
new PartitionReader[InternalRow] {

private[this] val metrics: Map[String, SQLMetric] = PythonCustomMetric.pythonMetrics

private val outputIter = {
val evaluatorFactory = source.createMapInBatchEvaluatorFactory(
pickledReadFunc,
"read_from_data_source",
UserDefinedPythonDataSource.readInputSchema,
outputSchema,
metrics,
jobArtifactUUID)

val part = partition.asInstanceOf[PythonInputPartition]
evaluatorFactory.createEvaluator().eval(
part.index, Iterator.single(InternalRow(part.pickedPartition)))
}

override def next(): Boolean = outputIter.hasNext

override def get(): InternalRow = outputIter.next()

override def close(): Unit = {}

override def currentMetricsValues(): Array[CustomTaskMetric] = {
source.createPythonTaskMetrics(metrics.map { case (k, v) => k -> v.value})
}
}
}
}
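The pickled partition bytes consumed here originate from the reader's partitions() on the Python side. A hedged sketch of a partitioned reader, assuming the InputPartition helper in pyspark.sql.datasource (class and value names are illustrative):

# Illustrative sketch, not part of this diff; assumes pyspark.sql.datasource reader classes.
from pyspark.sql.datasource import DataSourceReader, InputPartition

class MyPartitionedReader(DataSourceReader):
    def partitions(self):
        # Each InputPartition is pickled and carried as the pickedPartition
        # bytes wrapped in PythonInputPartition above.
        return [InputPartition(i) for i in range(4)]

    def read(self, partition):
        # Runs once per partition on an executor; partition.value is the value
        # given to InputPartition in partitions().
        start = partition.value * 10
        for i in range(start, start + 10):
            yield (i, str(i))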
@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.v2.python

import org.apache.spark.JobArtifactSet
import org.apache.spark.sql.connector.metric.CustomMetric
import org.apache.spark.sql.connector.read._
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap


class PythonScan(
ds: PythonDataSourceV2,
shortName: String,
outputSchema: StructType,
options: CaseInsensitiveStringMap) extends Batch with Scan {

private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)

private lazy val infoInPython: PythonDataSourceReadInfo = {
ds.source.createReadInfoInPython(
ds.getOrCreateDataSourceInPython(shortName, options, Some(outputSchema)),
outputSchema)
}

override def planInputPartitions(): Array[InputPartition] =
infoInPython.partitions.zipWithIndex.map(p => PythonInputPartition(p._2, p._1)).toArray

override def createReaderFactory(): PartitionReaderFactory = {
val readerFunc = infoInPython.func
new PythonPartitionReaderFactory(
ds.source, readerFunc, outputSchema, jobArtifactUUID)
}

override def toBatch: Batch = this

override def description: String = "(Python)"

override def readSchema(): StructType = outputSchema

override def supportedCustomMetrics(): Array[CustomMetric] =
ds.source.createPythonMetrics()
}