diff --git a/benchmark.py b/benchmark.py
new file mode 100644
index 000000000000..bf29a65367b0
--- /dev/null
+++ b/benchmark.py
@@ -0,0 +1,41 @@
+import pyspark
+import timeit
+import random
+from pyspark.sql import SparkSession
+
+numPartition = 20
+
+def time(df, repeat, number):
+ print("toPandas with arrow")
+ print(timeit.repeat('df.toPandas(True)', repeat=repeat, number=number, globals={'df': df}))
+
+ print("toPandas without arrow")
+ print(timeit.repeat('df.toPandas(False)', repeat=repeat, number=number, globals={'df': df}))
+
+def long():
+ return random.randint(0, 10000)
+
+def double():
+ return random.random()
+
+def genDataLocal(spark, size, columns):
+ data = [list([fn() for fn in columns]) for x in range(0, size)]
+ df = spark.createDataFrame(data)
+ return df
+
+def genData(spark, size, columns):
+ rdd = (spark.sparkContext
+ .parallelize(range(0,numPartition), numPartition)
+ .flatMap(lambda index: [list([fn() for fn in columns]) for x in range(0, int(size / numPartition))]))
+ df = spark.createDataFrame(rdd)
+ return df
+
+if __name__ == "__main__":
+ spark = SparkSession.builder.appName("ArrowBenchmark").getOrCreate()
+ df = genData(spark, 1000 * 1000, [long, double])
+ df.cache()
+ df.count()
+
+ time(df, 10, 1)
+
+ df.unpersist()
diff --git a/pom.xml b/pom.xml
index c6b8e2fe028f..5432e11655a1 100644
--- a/pom.xml
+++ b/pom.xml
@@ -1883,26 +1883,6 @@
-      <dependency>
-        <groupId>org.apache.arrow</groupId>
-        <artifactId>arrow-tools</artifactId>
-        <version>${arrow.version}</version>
-        <scope>test</scope>
-        <exclusions>
-          <exclusion>
-            <groupId>com.fasterxml.jackson.core</groupId>
-            <artifactId>jackson-annotations</artifactId>
-          </exclusion>
-          <exclusion>
-            <groupId>com.fasterxml.jackson.core</groupId>
-            <artifactId>jackson-databind</artifactId>
-          </exclusion>
-          <exclusion>
-            <groupId>org.slf4j</groupId>
-            <artifactId>log4j-over-slf4j</artifactId>
-          </exclusion>
-        </exclusions>
-      </dependency>
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index 24a2f36bb9e4..cdc012aa78c3 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -95,11 +95,6 @@
       <groupId>org.apache.arrow</groupId>
       <artifactId>arrow-vector</artifactId>
     </dependency>
-    <dependency>
-      <groupId>org.apache.arrow</groupId>
-      <artifactId>arrow-tools</artifactId>
-      <scope>test</scope>
-    </dependency>
     <dependency>
       <groupId>org.scalacheck</groupId>
       <artifactId>scalacheck_${scala.binary.version}</artifactId>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 96f779127f36..9e90b009afe0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -27,6 +27,7 @@ import scala.util.control.NonFatal
import io.netty.buffer.ArrowBuf
import org.apache.arrow.memory.RootAllocator
+import org.apache.arrow.vector.BitVector
import org.apache.arrow.vector.file.ArrowWriter
import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch}
import org.apache.arrow.vector.types.FloatingPointPrecision
@@ -60,7 +61,6 @@ import org.apache.spark.storage.StorageLevel
import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.util.Utils
-
private[sql] object Dataset {
def apply[T: Encoder](sparkSession: SparkSession, logicalPlan: LogicalPlan): Dataset[T] = {
new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]])
@@ -2426,6 +2426,33 @@ class Dataset[T] private[sql](
Math.ceil(numOfRows / 64.0).toInt * 8
}
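+  /**
+   * Writes a default value of the given type into the Arrow data buffer so that the
+   * value buffer stays aligned with the validity bitmap when a row entry is null.
+   */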
+ private def fillArrow(buf: ArrowBuf, dataType: DataType): Unit = {
+ dataType match {
+ case NullType =>
+ case BooleanType =>
+ buf.writeBoolean(false)
+ case ShortType =>
+ buf.writeShort(0)
+ case IntegerType =>
+ buf.writeInt(0)
+ case LongType =>
+ buf.writeLong(0L)
+ case FloatType =>
+ buf.writeFloat(0f)
+ case DoubleType =>
+ buf.writeDouble(0d)
+ case ByteType =>
+ buf.writeByte(0)
+ case _ =>
+ throw new UnsupportedOperationException(
+ s"Unsupported data type ${dataType.simpleString}")
+ }
+ }
+
/**
* Get an entry from the InternalRow, and then set to ArrowBuf.
* Note: No Null check for the entry.
@@ -2466,20 +2489,26 @@ class Dataset[T] private[sql](
field.dataType match {
case IntegerType | LongType | DoubleType | FloatType | BooleanType | ByteType =>
- val validity = allocator.buffer(numBytesOfBitmap(numOfRows))
+ val validityVector = new BitVector("validity", allocator)
+ val validityMutator = validityVector.getMutator()
+ validityVector.allocateNew(numOfRows)
+ validityMutator.setValueCount(numOfRows)
val buf = allocator.buffer(numOfRows * field.dataType.defaultSize)
var nullCount = 0
- rows.foreach { row =>
+ rows.zipWithIndex.foreach { case (row, index) =>
if (row.isNullAt(ordinal)) {
nullCount += 1
+ validityMutator.set(index, 0)
+ fillArrow(buf, field.dataType)
} else {
+ validityMutator.set(index, 1)
getAndSetToArrow(row, buf, field.dataType, ordinal)
}
}
val fieldNode = new ArrowFieldNode(numOfRows, nullCount)
- (Array(validity, buf), Array(fieldNode))
+ (Array(validityVector.getBuffer, buf), Array(fieldNode))
case StringType =>
val validityOffset = allocator.buffer(numBytesOfBitmap(numOfRows))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala
index 123b0f56fb47..9b1786c83f16 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala
@@ -19,14 +19,13 @@ package org.apache.spark.sql
import java.io.File
import org.apache.arrow.memory.RootAllocator
-import org.apache.arrow.tools.Integration
-import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot}
+import org.apache.arrow.vector.{BitVector, VectorLoader, VectorSchemaRoot}
import org.apache.arrow.vector.file.json.JsonFileReader
+import org.apache.arrow.vector.util.Validator
import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
class ArrowSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
- import testImplicits._
private val nullIntsFile = "test-data/arrowNullInts.json"
private def testFile(fileName: String): String = {
@@ -43,8 +42,7 @@ class ArrowSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
val arrowSchema = df.schemaToArrowSchema(df.schema)
val jsonSchema = jsonReader.start()
- // TODO - requires changing to public API in arrow, will be addressed in ARROW-411
- //Integration.compareSchemas(arrowSchema, jsonSchema)
+ Validator.compareSchemas(arrowSchema, jsonSchema)
val arrowRecordBatch = df.collectAsArrow(allocator)
val arrowRoot = new VectorSchemaRoot(arrowSchema, allocator)
@@ -52,7 +50,6 @@ class ArrowSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
vectorLoader.load(arrowRecordBatch)
val jsonRoot = jsonReader.read()
- // TODO - requires changing to public API in arrow, will be addressed in ARROW-411
- //Integration.compare(arrowRoot, jsonRoot)
+ Validator.compareVectorSchemaRoot(arrowRoot, jsonRoot)
}
}