41 changes: 41 additions & 0 deletions benchmark.py
@@ -0,0 +1,41 @@
import pyspark
import timeit
import random
from pyspark.sql import SparkSession

numPartition = 20

def time(df, repeat, number):
    print("toPandas with arrow")
    print(timeit.repeat('df.toPandas(True)', repeat=repeat, number=number, globals={'df': df}))

    print("toPandas without arrow")
    print(timeit.repeat('df.toPandas(False)', repeat=repeat, number=number, globals={'df': df}))

def long():
    return random.randint(0, 10000)

def double():
    return random.random()

def genDataLocal(spark, size, columns):
    data = [list([fn() for fn in columns]) for x in range(0, size)]
    df = spark.createDataFrame(data)
    return df

def genData(spark, size, columns):
    rdd = (spark.sparkContext
           .parallelize(range(0, numPartition), numPartition)
           .flatMap(lambda index: [list([fn() for fn in columns]) for x in range(0, int(size / numPartition))]))
    df = spark.createDataFrame(rdd)
    return df

if __name__ == "__main__":
    spark = SparkSession.builder.appName("ArrowBenchmark").getOrCreate()
    df = genData(spark, 1000 * 1000, [long, double])
    df.cache()
    df.count()

    time(df, 10, 1)

    df.unpersist()
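A note on running this: the script is submitted like any other PySpark application, for example with spark-submit benchmark.py, and it relies on the boolean toPandas(useArrow) argument exercised above, which is only available with this change applied.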
20 changes: 0 additions & 20 deletions pom.xml
@@ -1883,26 +1883,6 @@
          </exclusion>
        </exclusions>
      </dependency>
      <dependency>
        <groupId>org.apache.arrow</groupId>
        <artifactId>arrow-tools</artifactId>
        <version>${arrow.version}</version>
        <scope>test</scope>
        <exclusions>
          <exclusion>
            <groupId>com.fasterxml.jackson.core</groupId>
            <artifactId>jackson-annotations</artifactId>
          </exclusion>
          <exclusion>
            <groupId>com.fasterxml.jackson.core</groupId>
            <artifactId>jackson-databind</artifactId>
          </exclusion>
          <exclusion>
            <groupId>org.slf4j</groupId>
            <artifactId>log4j-over-slf4j</artifactId>
          </exclusion>
        </exclusions>
      </dependency>
    </dependencies>
  </dependencyManagement>

5 changes: 0 additions & 5 deletions sql/core/pom.xml
@@ -95,11 +95,6 @@
      <groupId>org.apache.arrow</groupId>
      <artifactId>arrow-vector</artifactId>
    </dependency>
    <dependency>
      <groupId>org.apache.arrow</groupId>
      <artifactId>arrow-tools</artifactId>
      <scope>test</scope>
    </dependency>
    <dependency>
      <groupId>org.scalacheck</groupId>
      <artifactId>scalacheck_${scala.binary.version}</artifactId>
37 changes: 33 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -27,6 +27,7 @@ import scala.util.control.NonFatal

import io.netty.buffer.ArrowBuf
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.BitVector
import org.apache.arrow.vector.file.ArrowWriter
import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch}
import org.apache.arrow.vector.types.FloatingPointPrecision
@@ -60,7 +61,6 @@ import org.apache.spark.storage.StorageLevel
import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.util.Utils


private[sql] object Dataset {
  def apply[T: Encoder](sparkSession: SparkSession, logicalPlan: LogicalPlan): Dataset[T] = {
    new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]])
@@ -2426,6 +2426,29 @@ class Dataset[T] private[sql](
    Math.ceil(numOfRows / 64.0).toInt * 8
  }

  /**
   * Write a zero/default value of the given data type into the ArrowBuf. This is
   * called for null entries so that the fixed-width data buffer still gets a slot
   * for every row; the validity vector is what marks the entry as null.
   */
  private def fillArrow(buf: ArrowBuf, dataType: DataType): Unit = {
    dataType match {
      case NullType =>
      case BooleanType =>
        buf.writeBoolean(false)
      case ShortType =>
        buf.writeShort(0)
      case IntegerType =>
        buf.writeInt(0)
      case LongType =>
        buf.writeLong(0L)
      case FloatType =>
        buf.writeFloat(0f)
      case DoubleType =>
        buf.writeDouble(0d)
      case ByteType =>
        buf.writeByte(0)
      case _ =>
        throw new UnsupportedOperationException(
          s"Unsupported data type ${dataType.simpleString}")
    }
  }

  /**
   * Get an entry from the InternalRow and write it to the ArrowBuf.
   * Note: no null check is performed on the entry.
@@ -2466,20 +2489,26 @@

      field.dataType match {
        case IntegerType | LongType | DoubleType | FloatType | BooleanType | ByteType =>
          val validity = allocator.buffer(numBytesOfBitmap(numOfRows))
          val validityVector = new BitVector("validity", allocator)
          val validityMutator = validityVector.getMutator()
          validityVector.allocateNew(numOfRows)
          validityMutator.setValueCount(numOfRows)
          val buf = allocator.buffer(numOfRows * field.dataType.defaultSize)
          var nullCount = 0
          rows.foreach { row =>
          rows.zipWithIndex.foreach { case (row, index) =>
Owner

I think it would be better to just use a while loop here, since zipWithIndex will iterate over and copy the items into an array:

var index = 0
while (index < rows.length) {
  ..
  index += 1
}
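
For reference, a fuller version of that sketch applied to the loop in this diff might look like the following (illustrative only; it assumes rows supports indexed access and reuses the variables already in scope above):

// Single pass over the collected rows without the intermediate array created
// by zipWithIndex; the body mirrors the null handling used in this diff.
var index = 0
while (index < rows.length) {
  val row = rows(index)
  if (row.isNullAt(ordinal)) {
    // Null entry: mark the validity bit as 0 and write a placeholder value.
    nullCount += 1
    validityMutator.set(index, 0)
    fillArrow(buf, field.dataType)
  } else {
    validityMutator.set(index, 1)
    getAndSetToArrow(row, buf, field.dataType, ordinal)
  }
  index += 1
}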

Collaborator Author

Good point. I am going to refactor this bit of code later to be more efficient. Do you want to wait until that is done or do you want to merge this first?

Owner

I'll leave this open for a day or so, so Xusen can have a look since he wrote the conversion code. If you want to update, go ahead, or I can before I merge; no biggie. I also would like to have fillArrow and getAndSetToArrow use the same data type cases to avoid duplication, but I can do that later.

            if (row.isNullAt(ordinal)) {
              nullCount += 1
              validityMutator.set(index, 0)
              fillArrow(buf, field.dataType)
Owner (@BryanCutler, Jan 10, 2017)

Just to clarify, so the buffer must contain values at each "null" position? Is the case for StringType below done correctly?

            } else {
              validityMutator.set(index, 1)
              getAndSetToArrow(row, buf, field.dataType, ordinal)
            }
          }

          val fieldNode = new ArrowFieldNode(numOfRows, nullCount)

          (Array(validity, buf), Array(fieldNode))
          (Array(validityVector.getBuffer, buf), Array(fieldNode))

        case StringType =>
          val validityOffset = allocator.buffer(numBytesOfBitmap(numOfRows))
11 changes: 4 additions & 7 deletions sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala
@@ -19,14 +19,13 @@ package org.apache.spark.sql
import java.io.File

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.tools.Integration
import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot}
import org.apache.arrow.vector.{BitVector, VectorLoader, VectorSchemaRoot}
import org.apache.arrow.vector.file.json.JsonFileReader
import org.apache.arrow.vector.util.Validator

import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}

class ArrowSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
  import testImplicits._
  private val nullIntsFile = "test-data/arrowNullInts.json"

  private def testFile(fileName: String): String = {
@@ -43,16 +42,14 @@ class ArrowSuite extends QueryTest with SharedSQLContext with SQLTestUtils {

    val arrowSchema = df.schemaToArrowSchema(df.schema)
    val jsonSchema = jsonReader.start()
    // TODO - requires changing to public API in arrow, will be addressed in ARROW-411
    //Integration.compareSchemas(arrowSchema, jsonSchema)
    Validator.compareSchemas(arrowSchema, jsonSchema)

    val arrowRecordBatch = df.collectAsArrow(allocator)
    val arrowRoot = new VectorSchemaRoot(arrowSchema, allocator)
    val vectorLoader = new VectorLoader(arrowRoot)
    vectorLoader.load(arrowRecordBatch)
    val jsonRoot = jsonReader.read()

    // TODO - requires changing to public API in arrow, will be addressed in ARROW-411
    //Integration.compare(arrowRoot, jsonRoot)
    Validator.compareVectorSchemaRoot(arrowRoot, jsonRoot)
  }
}