@@ -65,15 +65,35 @@ public final class ColumnarBatch {
final Row row;

public static ColumnarBatch allocate(StructType schema, MemoryMode memMode) {
return new ColumnarBatch(schema, DEFAULT_BATCH_SIZE, memMode);
return allocate(schema, memMode, DEFAULT_BATCH_SIZE);
}

public static ColumnarBatch allocate(StructType type) {
return new ColumnarBatch(type, DEFAULT_BATCH_SIZE, DEFAULT_MEMORY_MODE);
return allocate(type, DEFAULT_MEMORY_MODE, DEFAULT_BATCH_SIZE);
}

public static ColumnarBatch allocate(StructType schema, MemoryMode memMode, int maxRows) {
return new ColumnarBatch(schema, maxRows, memMode);
ColumnVector[] columns = allocateCols(schema, maxRows, memMode);
return new ColumnarBatch(schema, columns, maxRows);
}

private static ColumnVector[] allocateCols(StructType schema, int maxRows, MemoryMode memMode) {
ColumnVector[] columns = new ColumnVector[schema.size()];
for (int i = 0; i < schema.fields().length; ++i) {
StructField field = schema.fields()[i];
columns[i] = ColumnVector.allocate(maxRows, field.dataType(), memMode);
}
return columns;
}

public static ColumnarBatch createReadOnly(
StructType schema,
ReadOnlyColumnVector[] columns,
Member: Do we need to restrict this to only ReadOnlyColumnVector?

Is it necessary? What impact will it cause?

Member Author: It doesn't need to be restricted, but if they are ReadOnlyColumnVectors then it means they are already populated and it is safe to call setNumRows(numRows) here. If it took in any ColumnVector then it might cause issues by someone passing in unallocated vectors.

int numRows) {
assert(schema.length() == columns.length);
ColumnarBatch batch = new ColumnarBatch(schema, columns, numRows);
Contributor: Why is the capacity set to numRows inside the ctor, but batch.setNumRows() still needs to be called manually?

Member Author: The max capacity only has meaning when allocating ColumnVectors, so it doesn't really do anything for read-only vectors. You need to call setNumRows to tell the batch how many rows there are for the given columns; it doesn't look at the capacity in the individual vectors.

(A short usage sketch follows this method below.)

batch.setNumRows(numRows);
Contributor: Do we need to check that each ReadOnlyColumnVector has numRows?

Member Author: The ArrowColumnVector.valueCount here would need to be moved to ReadOnlyColumnVector, which could go in place of the capacity. If @ueshin thinks that's ok to do so here, I can add that.

return batch;
}
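
To illustrate the review discussion above: a minimal Scala usage sketch (not part of the diff, mirroring the test suites later in this PR) of how createReadOnly is meant to be called with already-populated, Arrow-backed read-only vectors. The column name and row count are illustrative.

import scala.collection.JavaConverters._
import org.apache.arrow.vector.NullableIntVector
import org.apache.spark.sql.execution.arrow.ArrowUtils
import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ReadOnlyColumnVector}
import org.apache.spark.sql.types._

val allocator = ArrowUtils.rootAllocator.newChildAllocator("example", 0, Long.MaxValue)
val vector = ArrowUtils.toArrowField("i", IntegerType, nullable = true)
  .createVector(allocator).asInstanceOf[NullableIntVector]
vector.allocateNew()
val mutator = vector.getMutator()
(0 until 3).foreach(i => mutator.setSafe(i, i))
mutator.setValueCount(3)

// The vector is populated up front, so only the row count is needed;
// capacity has no meaning for read-only vectors.
val schema = StructType(Seq(StructField("i", IntegerType)))
val column = new ArrowColumnVector(vector)
val batch = ColumnarBatch.createReadOnly(schema, Array[ReadOnlyColumnVector](column), 3)
assert(batch.numRows() == 3)
batch.rowIterator().asScala.foreach(row => println(row.getInt(0)))

column.close()
allocator.close()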

/**
@@ -505,18 +525,12 @@ public void filterNullsInColumn(int ordinal) {
nullFilteredColumns.add(ordinal);
}

private ColumnarBatch(StructType schema, int maxRows, MemoryMode memMode) {
private ColumnarBatch(StructType schema, ColumnVector[] columns, int capacity) {
this.schema = schema;
this.capacity = maxRows;
this.columns = new ColumnVector[schema.size()];
this.columns = columns;
this.capacity = capacity;
this.nullFilteredColumns = new HashSet<>();
this.filteredRows = new boolean[maxRows];

for (int i = 0; i < schema.fields().length; ++i) {
StructField field = schema.fields()[i];
columns[i] = ColumnVector.allocate(maxRows, field.dataType(), memMode);
}

this.filteredRows = new boolean[this.capacity];
this.row = new Row(this);
}
}
@@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.arrow
import java.io.ByteArrayOutputStream
import java.nio.channels.Channels

import scala.collection.JavaConverters._

import org.apache.arrow.memory.BufferAllocator
import org.apache.arrow.vector._
import org.apache.arrow.vector.file._
@@ -28,14 +30,15 @@ import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel

import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ReadOnlyColumnVector}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils


/**
* Store Arrow data in a form that can be serialized by Spark and served to a Python process.
*/
private[sql] class ArrowPayload private[arrow] (payload: Array[Byte]) extends Serializable {
private[sql] class ArrowPayload private[sql] (payload: Array[Byte]) extends Serializable {

/**
* Convert the ArrowPayload to an ArrowRecordBatch.
@@ -110,6 +113,65 @@ private[sql] object ArrowConverters {
}
}

private[sql] def fromPayloadIterator(
payloadIter: Iterator[ArrowPayload],
context: TaskContext): (Iterator[InternalRow], StructType) = {
val allocator =
ArrowUtils.rootAllocator.newChildAllocator("fromPayloadIterator", 0, Long.MaxValue)
var reader: ArrowFileReader = null

def nextBatch(): (Iterator[InternalRow], StructType) = {
val in = new ByteArrayReadableSeekableByteChannel(payloadIter.next().asPythonSerializable)
reader = new ArrowFileReader(in, allocator)
reader.loadNextBatch() // throws IOException
val root = reader.getVectorSchemaRoot // throws IOException
val schemaRead = ArrowUtils.fromArrowSchema(root.getSchema)

val columns = root.getFieldVectors.asScala.map { vector =>
new ArrowColumnVector(vector).asInstanceOf[ReadOnlyColumnVector]
}.toArray

(ColumnarBatch.createReadOnly(schemaRead, columns, root.getRowCount).rowIterator().asScala,
schemaRead)
}

var (rowIter, schemaRead) = if (payloadIter.hasNext) {
nextBatch()
} else {
(Iterator.empty, StructType(Seq.empty))
}

val outputIterator = new Iterator[InternalRow] {

context.addTaskCompletionListener { _ =>
closeReader()
allocator.close()
}

override def hasNext: Boolean = rowIter.hasNext || {
closeReader()
if (payloadIter.hasNext) {
rowIter = nextBatch()._1
true
} else {
allocator.close()
false
}
}

override def next(): InternalRow = rowIter.next()

private def closeReader(): Unit = {
if (reader != null) {
reader.close()
reader = null
}
}
}

(outputIterator, schemaRead)
}
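
As a rough illustration (not part of the diff): fromPayloadIterator pairs with toPayloadIterator to round-trip rows through Arrow payloads, reading the schema back from the payload itself. The "roundtrip payloads" test later in this PR does the same with an Arrow-backed batch; the schema, values, and use of InternalRow(i) below are illustrative assumptions.

import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

val schema = StructType(Seq(StructField("id", IntegerType)))
val context = TaskContext.empty()
val inputRows = (0 until 5).map(i => InternalRow(i)).iterator

// Rows -> Arrow payloads -> rows
val payloads = ArrowConverters.toPayloadIterator(inputRows, schema, 0, context)
val (rowsRead, schemaRead) = ArrowConverters.fromPayloadIterator(payloads, context)
assert(schema.equals(schemaRead))
rowsRead.foreach(row => println(row.getInt(0)))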

/**
* Convert a byte array to an ArrowRecordBatch.
*/
@@ -22,15 +22,18 @@ import java.sql.{Date, Timestamp}
import java.text.SimpleDateFormat
import java.util.Locale

import scala.collection.JavaConverters._

import com.google.common.io.Files
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot}
import org.apache.arrow.vector.{NullableIntVector, VectorLoader, VectorSchemaRoot}
import org.apache.arrow.vector.file.json.JsonFileReader
import org.apache.arrow.vector.util.Validator
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.SparkException
import org.apache.spark.{SparkException, TaskContext}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ReadOnlyColumnVector}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{BinaryType, IntegerType, StructField, StructType}
import org.apache.spark.util.Utils
@@ -1629,6 +1632,42 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
}
}

test("roundtrip payloads") {
val allocator = ArrowUtils.rootAllocator.newChildAllocator("int", 0, Long.MaxValue)
val vector = ArrowUtils.toArrowField("int", IntegerType, nullable = true)
.createVector(allocator).asInstanceOf[NullableIntVector]
Member: Should the allocator and the vector be closed at the end of this test?

Member Author (@BryanCutler, Aug 29, 2017): Yes, thanks for catching that. I close them now.

vector.allocateNew()
val mutator = vector.getMutator()

(0 until 10).foreach { i =>
mutator.setSafe(i, i)
}
mutator.setNull(10)
mutator.setValueCount(11)

val schema = StructType(Seq(StructField("int", IntegerType)))

val columnarBatch = ColumnarBatch.createReadOnly(
schema, Array[ReadOnlyColumnVector](new ArrowColumnVector(vector)), 11)

val context = TaskContext.empty()

val payloadIter = ArrowConverters.toPayloadIterator(
columnarBatch.rowIterator().asScala, schema, 0, context)

val (rowIter, schemaRead) = ArrowConverters.fromPayloadIterator(payloadIter, context)

assert(schema.equals(schemaRead))

rowIter.zipWithIndex.foreach { case (row, i) =>
if (i == 10) {
assert(row.isNullAt(0))
} else {
assert(row.getInt(0) == i)
}
}
}

/** Test that a converted DataFrame to Arrow record batch equals batch read from JSON file */
private def collectAndValidate(df: DataFrame, json: String, file: String): Unit = {
// NOTE: coalesce to single partition because can only load 1 batch in validator
@@ -25,10 +25,13 @@ import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.Random

import org.apache.arrow.vector.NullableIntVector

import org.apache.spark.SparkFunSuite
import org.apache.spark.memory.MemoryMode
import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.arrow.ArrowUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.types.CalendarInterval
@@ -1248,4 +1251,55 @@ class ColumnarBatchSuite extends SparkFunSuite {
s"vectorized reader"))
}
}

test("create read-only batch") {
Member: Name this test "create a columnar batch from Arrow column vectors" or something?

val allocator = ArrowUtils.rootAllocator.newChildAllocator("int", 0, Long.MaxValue)
val vector1 = ArrowUtils.toArrowField("int1", IntegerType, nullable = true)
.createVector(allocator).asInstanceOf[NullableIntVector]
vector1.allocateNew()
val mutator1 = vector1.getMutator()
val vector2 = ArrowUtils.toArrowField("int2", IntegerType, nullable = true)
.createVector(allocator).asInstanceOf[NullableIntVector]
vector2.allocateNew()
val mutator2 = vector2.getMutator()

(0 until 10).foreach { i =>
mutator1.setSafe(i, i)
mutator2.setSafe(i + 1, i)
}
mutator1.setNull(10)
mutator1.setValueCount(11)
mutator2.setNull(0)
mutator2.setValueCount(11)

val columnVectors = Seq(new ArrowColumnVector(vector1), new ArrowColumnVector(vector2))

val schema = StructType(Seq(StructField("int1", IntegerType), StructField("int2", IntegerType)))
val batch = ColumnarBatch.createReadOnly(
schema, columnVectors.toArray[ReadOnlyColumnVector], 11)

assert(batch.numCols() == 2)
assert(batch.numRows() == 11)

val rowIter = batch.rowIterator().asScala
rowIter.zipWithIndex.foreach { case (row, i) =>
if (i == 10) {
assert(row.isNullAt(0))
} else {
assert(row.getInt(0) == i)
}
if (i == 0) {
assert(row.isNullAt(1))
} else {
assert(row.getInt(1) == i - 1)
}
}

intercept[java.lang.AssertionError] {
batch.getRow(100)
Member Author: Hmm, that is strange. I'll take a look, thanks.

Member Author: It's probably because the assert is being compiled out. This should probably not be in the test then.

Member (@dongjoon-hyun, Aug 31, 2017): Then, please check the error message here. Please ignore this.

Member Author (@BryanCutler, Aug 31, 2017): I think the problem is that if the Java assertion is compiled out, then no error is produced and the test fails.

Member Author: I just made #19098 to remove this check - it's not really testing the functionality added here anyway, but maybe another test should be added for checking index-out-of-bounds errors.

}
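// Note, per the review discussion above: the AssertionError here comes from a
// Java `assert`, which is only active when the JVM runs with assertions
// enabled (-ea); if assertions are disabled, getRow(100) simply does not throw
// and the intercept above fails. #19098 removes this check.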

columnVectors.foreach(_.close())
Member: We can use batch.close() here.

allocator.close()
}
}