Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -139,17 +139,12 @@ class PythonPartitionReaderFactory(
.map(_ -> new SQLMetric("sum", -1)).toMap
}

// Iterator over this partition's rows, produced by running the pickled
// Python read function through a map-in-batch evaluator.
private val outputIter = {
  val inputPartition = partition.asInstanceOf[PythonInputPartition]
  // The factory wraps the pickled Python reader for batch evaluation.
  val factory = source.createMapInBatchEvaluatorFactory(
    pickledReadFunc,
    outputSchema,
    metrics,
    jobArtifactUUID)
  // The single pickled partition value is fed in as one InternalRow.
  factory.createEvaluator().eval(
    inputPartition.index,
    Iterator.single(InternalRow(inputPartition.pickedPartition)))
}
// Iterator over this partition's rows: delegate to the data source, which
// executes the pickled Python read function for the given input partition.
private val outputIter = {
  val inputPartition = partition.asInstanceOf[PythonInputPartition]
  source.createPartitionReadIteratorInPython(
    inputPartition,
    pickledReadFunc,
    outputSchema,
    metrics,
    jobArtifactUUID)
}

override def next(): Boolean = outputIter.hasNext

Expand All @@ -164,41 +159,20 @@ class PythonPartitionReaderFactory(
}
}

class PythonCustomMetric extends CustomMetric {
private var initName: String = _
private var initDescription: String = _
def initialize(n: String, d: String): Unit = {
initName = n
initDescription = d
}
override def name(): String = {
assert(initName != null)
initName
}
override def description(): String = {
assert(initDescription != null)
initDescription
}
/**
 * Driver-side custom SQL metric for the Python data source. The (name,
 * description) pairs come from `PythonSQLMetrics.pythonSizeMetricsDesc`
 * (see `createPythonMetrics`).
 */
class PythonCustomMetric(
override val name: String,
override val description: String) extends CustomMetric {
// Zero-arg constructor so the metric can be instantiated when aggregation
// is invoked. See `SQLAppStatusListener.aggregateMetrics`.
def this() = this(null, null)

// Renders the per-task values as a "size"-style summary string.
override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = {
SQLMetrics.stringValue("size", taskMetrics, Array.empty[Long])
}
}

/**
 * Executor-side task metric for the Python data source. Instances start
 * empty and are populated later via `initialize`.
 */
class PythonCustomTaskMetric extends CustomTaskMetric {
  private var metricName: String = _
  private var metricValue: Long = -1L

  // Late-initialization hook; `name()` asserts this has been called first.
  def initialize(n: String, v: Long): Unit = {
    metricName = n
    metricValue = v
  }

  override def name(): String = {
    assert(metricName != null)
    metricName
  }

  // Returns the default -1 until `initialize` is called.
  override def value(): Long = metricValue
}
/**
 * Executor-side task metric reported for the Python data source: an
 * immutable (name, value) pair implementing `CustomTaskMetric`.
 */
class PythonCustomTaskMetric(
override val name: String,
override val value: Long) extends CustomTaskMetric

/**
* A user-defined Python data source. This is used by the Python API.
Expand Down Expand Up @@ -240,11 +214,12 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
/**
* (Executor-side) Create an iterator that reads the input partitions.
*/
def createMapInBatchEvaluatorFactory(
def createPartitionReadIteratorInPython(
partition: PythonInputPartition,
pickledReadFunc: Array[Byte],
outputSchema: StructType,
metrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String]): MapInBatchEvaluatorFactory = {
jobArtifactUUID: Option[String]): Iterator[InternalRow] = {
val readerFunc = createPythonFunction(pickledReadFunc)

val pythonEvalType = PythonEvalType.SQL_MAP_ARROW_ITER_UDF
Expand All @@ -260,7 +235,7 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
val conf = SQLConf.get

val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
new MapInBatchEvaluatorFactory(
val evaluatorFactory = new MapInBatchEvaluatorFactory(
toAttributes(outputSchema),
Seq(ChainedPythonFunctions(Seq(pythonUDF.func))),
inputSchema,
Expand All @@ -271,24 +246,21 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
pythonRunnerConf,
metrics,
jobArtifactUUID)

val part = partition
evaluatorFactory.createEvaluator().eval(
part.index, Iterator.single(InternalRow(part.pickedPartition)))
}

// Builds the driver-side custom metrics exposed for the Python data source.
// NOTE(review): this span is captured from a unified diff and shows BOTH the
// pre-refactor body (mutable `initialize`) and the post-refactor body
// (constructor-based) back to back; only one of them belongs in the real file.
def createPythonMetrics(): Array[CustomMetric] = {
// Do not add other metrics such as number of rows,
// that is already included via DSv2.
PythonSQLMetrics.pythonSizeMetricsDesc.map { case (k, v) =>
val m = new PythonCustomMetric()
m.initialize(k, v)
m
}.toArray
// Post-refactor form: construct the metric directly from the (name, desc) pair.
PythonSQLMetrics.pythonSizeMetricsDesc
.map { case (k, v) => new PythonCustomMetric(k, v)}.toArray
}

// Converts the collected (name -> value) task metrics into DSv2
// `CustomTaskMetric` instances.
// NOTE(review): diff capture — both the pre-refactor (mutable `initialize`)
// and post-refactor (constructor-based) bodies appear below; only one is live.
def createPythonTaskMetrics(taskMetrics: Map[String, Long]): Array[CustomTaskMetric] = {
taskMetrics.map { case (k, v) =>
val m = new PythonCustomTaskMetric()
m.initialize(k, v)
m
}.toArray
// Post-refactor form: construct the metric directly from the (name, value) pair.
taskMetrics.map { case (k, v) => new PythonCustomTaskMetric(k, v)}.toArray
}

private def createPythonFunction(pickledFunc: Array[Byte]): PythonFunction = {
Expand Down