6 changes: 4 additions & 2 deletions python/pyspark/sql/dataframe.py
@@ -1573,6 +1573,9 @@ def toPandas(self, useArrow=False):

This is only available if Pandas is installed and available.

:param useArrow: Make use of Apache Arrow for the conversion; pyarrow must be installed
    on the calling Python process.

.. note:: This method should only be used if the resulting Pandas DataFrame is expected
    to be small, as all the data is loaded into the driver's memory.

@@ -1581,11 +1584,10 @@ def toPandas(self, useArrow=False):
0 2 Alice
1 5 Bob
"""
import pandas as pd

if useArrow:
return self.collectAsArrow().to_pandas()
else:
import pandas as pd
return pd.DataFrame.from_records(self.collect(), columns=self.columns)
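# A hypothetical usage sketch of the new flag (assumes pyarrow is installed on
# the driver); illustration only, not part of this diff:
#   pdf = df.toPandas(useArrow=True)   # convert via Arrow record batches
#   pdf = df.toPandas()                # default path via collect()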

##########################################################################################
156 changes: 73 additions & 83 deletions sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala
@@ -54,8 +54,8 @@ private[sql] class ArrowConverters {
override def hasNext: Boolean = iter.hasNext
}

def internalRowsToPayload(rows: Array[InternalRow], schema: StructType): ArrowPayload = {
val batch = ArrowConverters.internalRowsToArrowRecordBatch(rows, schema, allocator)
def internalRowIterToPayload(rowIter: Iterator[InternalRow], schema: StructType): ArrowPayload = {
val batch = ArrowConverters.internalRowIterToArrowBatch(rowIter, schema, allocator, None)
new ArrowStaticPayload(batch)
}
}
@@ -83,52 +83,40 @@ private[sql] object ArrowConverters {
}

/**
* Transfer an array of InternalRow to an ArrowRecordBatch.
* Iterate over InternalRows and write to an ArrowRecordBatch.
*/
private[sql] def internalRowsToArrowRecordBatch(
rows: Array[InternalRow],
private def internalRowIterToArrowBatch(
rowIter: Iterator[InternalRow],
schema: StructType,
allocator: RootAllocator): ArrowRecordBatch = {
val fieldAndBuf = schema.fields.zipWithIndex.map { case (field, ordinal) =>
internalRowToArrowBuf(rows, ordinal, field, allocator)
allocator: RootAllocator,
initialSize: Option[Int]): ArrowRecordBatch = {
Collaborator: I think we can get rid of initialSize for an iterator-oriented implementation.


val columnWriters = schema.fields.zipWithIndex.map { case (field, ordinal) =>
ColumnWriter(ordinal, allocator, field.dataType)
.init(initialSize)
}

rowIter.foreach { row =>
Author: Yeah, we even discussed this before, I think. Even here it might add a little overhead to wrap in a function object, so I'll change it to a while loop just to be sure.
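// A hypothetical while-loop form of this write loop, as proposed in the thread
// above to avoid the overhead of a function object; illustration only, not part
// of this diff:
//   while (rowIter.hasNext) {
//     val row = rowIter.next()
//     var i = 0
//     while (i < columnWriters.length) {
//       columnWriters(i).write(row)
//       i += 1
//     }
//   }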

columnWriters.foreach { writer =>
writer.write(row)
}
}

val fieldAndBuf = columnWriters.map { writer =>
writer.finish()
}.unzip
val fieldNodes = fieldAndBuf._1.flatten
val buffers = fieldAndBuf._2.flatten

val recordBatch = new ArrowRecordBatch(rows.length,
val rowLength = if (fieldNodes.nonEmpty) fieldNodes.head.getLength else 0
Author: Does this seem acceptable to get the row length for creating an ArrowRecordBatch, or is it better to keep a counter myself?

Collaborator: I don't think we need another counter here. When is fieldNodes empty?

Author: It will probably never be empty, but it might be good to keep the check just in case.


val recordBatch = new ArrowRecordBatch(rowLength,
fieldNodes.toList.asJava, buffers.toList.asJava)

buffers.foreach(_.release())
recordBatch
}

/**
* Write a Field from array of InternalRow to an ArrowBuf.
*/
private def internalRowToArrowBuf(
rows: Array[InternalRow],
ordinal: Int,
field: StructField,
allocator: RootAllocator): (Array[ArrowFieldNode], Array[ArrowBuf]) = {
val numOfRows = rows.length
val columnWriter = ColumnWriter(allocator, field.dataType)
columnWriter.init(numOfRows)
var index = 0

while(index < numOfRows) {
val row = rows(index)
if (row.isNullAt(ordinal)) {
columnWriter.writeNull()
} else {
columnWriter.write(row, ordinal)
}
index += 1
}

val (arrowFieldNodes, arrowBufs) = columnWriter.finish()
(arrowFieldNodes.toArray, arrowBufs.toArray)
}

/**
* Convert a Spark Dataset schema to Arrow schema.
*/
@@ -160,9 +148,8 @@ private[sql] object ArrowConverters {
}

private[sql] trait ColumnWriter {
def init(initialSize: Int): Unit
def writeNull(): Unit
def write(row: InternalRow, ordinal: Int): Unit
def init(initialSize: Option[Int]): this.type
def write(row: InternalRow): Unit

/**
* Clear the column writer and return the ArrowFieldNode and ArrowBuf.
@@ -174,7 +161,9 @@ private[sql] trait ColumnWriter {
/**
* Base class for flat arrow column writer, i.e., column without children.
*/
private[sql] abstract class PrimitiveColumnWriter(protected val allocator: BaseAllocator)
private[sql] abstract class PrimitiveColumnWriter(
val ordinal: Int,
protected val allocator: BaseAllocator)
extends ColumnWriter {
protected def valueVector: BaseDataValueVector
protected def valueMutator: BaseMutator
@@ -185,18 +174,19 @@ private[sql] abstract class PrimitiveColumnWriter(protected val allocator: BaseA
protected var count = 0
protected var nullCount = 0

override def init(initialSize: Int): Unit = {
override def init(initialSize: Option[Int]): this.type = {
initialSize.foreach(valueVector.setInitialCapacity)
Author: I noticed this was never used before; would it increase performance much to set the exact capacity? Although I'm not sure it's possible when using Iterator[InternalRow] without making a first pass.

Collaborator: For an iterator, we cannot know the size. I would just set a constant initial capacity.

Author: OK, I'll remove the option to set an initial capacity now. The default capacity seems to be fine.

valueVector.allocateNew()
this
}
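// A hypothetical follow-up once the initial-capacity Option is dropped, per the
// thread above; illustration only, not part of this diff:
//   override def init(): this.type = {
//     valueVector.allocateNew()   // rely on the vector's default capacity
//     this
//   }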

override def writeNull(): Unit = {
setNull()
nullCount += 1
count += 1
}

override def write(row: InternalRow, ordinal: Int): Unit = {
setValue(row, ordinal)
override def write(row: InternalRow): Unit = {
if (row.isNullAt(ordinal)) {
setNull()
nullCount += 1
} else {
setValue(row, ordinal)
}
count += 1
}

@@ -208,8 +198,8 @@ private[sql] abstract class PrimitiveColumnWriter(protected val allocator: BaseA
}
}

private[sql] class BooleanColumnWriter(allocator: BaseAllocator)
extends PrimitiveColumnWriter(allocator) {
private[sql] class BooleanColumnWriter(ordinal: Int, allocator: BaseAllocator)
extends PrimitiveColumnWriter(ordinal, allocator) {
private def bool2int(b: Boolean): Int = if (b) 1 else 0

override protected val valueVector: NullableBitVector
@@ -221,8 +211,8 @@ private[sql] class BooleanColumnWriter(allocator: BaseAllocator)
= valueMutator.setSafe(count, bool2int(row.getBoolean(ordinal)))
}

private[sql] class ShortColumnWriter(allocator: BaseAllocator)
extends PrimitiveColumnWriter(allocator) {
private[sql] class ShortColumnWriter(ordinal: Int, allocator: BaseAllocator)
extends PrimitiveColumnWriter(ordinal, allocator) {
override protected val valueVector: NullableSmallIntVector
= new NullableSmallIntVector("ShortValue", allocator)
override protected val valueMutator: NullableSmallIntVector#Mutator = valueVector.getMutator
@@ -232,8 +222,8 @@ private[sql] class ShortColumnWriter(allocator: BaseAllocator)
= valueMutator.setSafe(count, row.getShort(ordinal))
}

private[sql] class IntegerColumnWriter(allocator: BaseAllocator)
extends PrimitiveColumnWriter(allocator) {
private[sql] class IntegerColumnWriter(ordinal: Int, allocator: BaseAllocator)
extends PrimitiveColumnWriter(ordinal, allocator) {
override protected val valueVector: NullableIntVector
= new NullableIntVector("IntValue", allocator)
override protected val valueMutator: NullableIntVector#Mutator = valueVector.getMutator
@@ -243,8 +233,8 @@ private[sql] class IntegerColumnWriter(allocator: BaseAllocator)
= valueMutator.setSafe(count, row.getInt(ordinal))
}

private[sql] class LongColumnWriter(allocator: BaseAllocator)
extends PrimitiveColumnWriter(allocator) {
private[sql] class LongColumnWriter(ordinal: Int, allocator: BaseAllocator)
extends PrimitiveColumnWriter(ordinal, allocator) {
override protected val valueVector: NullableBigIntVector
= new NullableBigIntVector("LongValue", allocator)
override protected val valueMutator: NullableBigIntVector#Mutator = valueVector.getMutator
@@ -254,8 +244,8 @@ private[sql] class LongColumnWriter(allocator: BaseAllocator)
= valueMutator.setSafe(count, row.getLong(ordinal))
}

private[sql] class FloatColumnWriter(allocator: BaseAllocator)
extends PrimitiveColumnWriter(allocator) {
private[sql] class FloatColumnWriter(ordinal: Int, allocator: BaseAllocator)
extends PrimitiveColumnWriter(ordinal, allocator) {
override protected val valueVector: NullableFloat4Vector
= new NullableFloat4Vector("FloatValue", allocator)
override protected val valueMutator: NullableFloat4Vector#Mutator = valueVector.getMutator
@@ -265,8 +255,8 @@ private[sql] class FloatColumnWriter(allocator: BaseAllocator)
= valueMutator.setSafe(count, row.getFloat(ordinal))
}

private[sql] class DoubleColumnWriter(allocator: BaseAllocator)
extends PrimitiveColumnWriter(allocator) {
private[sql] class DoubleColumnWriter(ordinal: Int, allocator: BaseAllocator)
extends PrimitiveColumnWriter(ordinal, allocator) {
override protected val valueVector: NullableFloat8Vector
= new NullableFloat8Vector("DoubleValue", allocator)
override protected val valueMutator: NullableFloat8Vector#Mutator = valueVector.getMutator
@@ -276,8 +266,8 @@ private[sql] class DoubleColumnWriter(allocator: BaseAllocator)
= valueMutator.setSafe(count, row.getDouble(ordinal))
}

private[sql] class ByteColumnWriter(allocator: BaseAllocator)
extends PrimitiveColumnWriter(allocator) {
private[sql] class ByteColumnWriter(ordinal: Int, allocator: BaseAllocator)
extends PrimitiveColumnWriter(ordinal, allocator) {
override protected val valueVector: NullableUInt1Vector
= new NullableUInt1Vector("ByteValue", allocator)
override protected val valueMutator: NullableUInt1Vector#Mutator = valueVector.getMutator
@@ -287,8 +277,8 @@ private[sql] class ByteColumnWriter(allocator: BaseAllocator)
= valueMutator.setSafe(count, row.getByte(ordinal))
}

private[sql] class UTF8StringColumnWriter(allocator: BaseAllocator)
extends PrimitiveColumnWriter(allocator) {
private[sql] class UTF8StringColumnWriter(ordinal: Int, allocator: BaseAllocator)
extends PrimitiveColumnWriter(ordinal, allocator) {
override protected val valueVector: NullableVarBinaryVector
= new NullableVarBinaryVector("UTF8StringValue", allocator)
override protected val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator
@@ -300,8 +290,8 @@ private[sql] class UTF8StringColumnWriter(allocator: BaseAllocator)
}
}

private[sql] class BinaryColumnWriter(allocator: BaseAllocator)
extends PrimitiveColumnWriter(allocator) {
private[sql] class BinaryColumnWriter(ordinal: Int, allocator: BaseAllocator)
extends PrimitiveColumnWriter(ordinal, allocator) {
override protected val valueVector: NullableVarBinaryVector
= new NullableVarBinaryVector("BinaryValue", allocator)
override protected val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator
@@ -313,8 +303,8 @@ private[sql] class BinaryColumnWriter(allocator: BaseAllocator)
}
}

private[sql] class DateColumnWriter(allocator: BaseAllocator)
extends PrimitiveColumnWriter(allocator) {
private[sql] class DateColumnWriter(ordinal: Int, allocator: BaseAllocator)
extends PrimitiveColumnWriter(ordinal, allocator) {
override protected val valueVector: NullableDateVector
= new NullableDateVector("DateValue", allocator)
override protected val valueMutator: NullableDateVector#Mutator = valueVector.getMutator
@@ -326,8 +316,8 @@ private[sql] class DateColumnWriter(allocator: BaseAllocator)
}
}

private[sql] class TimeStampColumnWriter(allocator: BaseAllocator)
extends PrimitiveColumnWriter(allocator) {
private[sql] class TimeStampColumnWriter(ordinal: Int, allocator: BaseAllocator)
extends PrimitiveColumnWriter(ordinal, allocator) {
override protected val valueVector: NullableTimeStampVector
= new NullableTimeStampVector("TimeStampValue", allocator)
override protected val valueMutator: NullableTimeStampVector#Mutator = valueVector.getMutator
@@ -341,19 +331,19 @@ private[sql] class TimeStampColumnWriter(allocator: BaseAllocator)
}

private[sql] object ColumnWriter {
def apply(allocator: BaseAllocator, dataType: DataType): ColumnWriter = {
def apply(ordinal: Int, allocator: BaseAllocator, dataType: DataType): ColumnWriter = {
dataType match {
case BooleanType => new BooleanColumnWriter(allocator)
case ShortType => new ShortColumnWriter(allocator)
case IntegerType => new IntegerColumnWriter(allocator)
case LongType => new LongColumnWriter(allocator)
case FloatType => new FloatColumnWriter(allocator)
case DoubleType => new DoubleColumnWriter(allocator)
case ByteType => new ByteColumnWriter(allocator)
case StringType => new UTF8StringColumnWriter(allocator)
case BinaryType => new BinaryColumnWriter(allocator)
case DateType => new DateColumnWriter(allocator)
case TimestampType => new TimeStampColumnWriter(allocator)
case BooleanType => new BooleanColumnWriter(ordinal, allocator)
case ShortType => new ShortColumnWriter(ordinal, allocator)
case IntegerType => new IntegerColumnWriter(ordinal, allocator)
case LongType => new LongColumnWriter(ordinal, allocator)
case FloatType => new FloatColumnWriter(ordinal, allocator)
case DoubleType => new DoubleColumnWriter(ordinal, allocator)
case ByteType => new ByteColumnWriter(ordinal, allocator)
case StringType => new UTF8StringColumnWriter(ordinal, allocator)
case BinaryType => new BinaryColumnWriter(ordinal, allocator)
case DateType => new DateColumnWriter(ordinal, allocator)
case TimestampType => new TimeStampColumnWriter(ordinal, allocator)
case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType")
}
}
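// A hypothetical end-to-end sketch of how the writers compose, assuming the
// signatures above; illustration only, not part of this diff:
//   val allocator = new RootAllocator(Long.MaxValue)
//   val writers = schema.fields.zipWithIndex.map { case (field, ordinal) =>
//     ColumnWriter(ordinal, allocator, field.dataType).init(None)
//   }
//   rowIter.foreach(row => writers.foreach(_.write(row)))
//   val (nodes, bufs) = writers.map(_.finish()).unzip
//   // then flatten nodes/bufs and build the ArrowRecordBatch as in
//   // internalRowIterToArrowBatch above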
4 changes: 2 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2375,8 +2375,8 @@ class Dataset[T] private[sql](
val cnvtr = converter.getOrElse(new ArrowConverters)
withNewExecutionId {
try {
val collectedRows = queryExecution.executedPlan.executeCollect()
cnvtr.internalRowsToPayload(collectedRows, this.schema)
val rowIter = queryExecution.executedPlan.executeToIterator()
Collaborator: Whoops, was looking at an old commit. Please ignore this.

cnvtr.internalRowIterToPayload(rowIter, this.schema)
} catch {
case e: Exception =>
throw e