diff --git a/benchmark.py b/benchmark.py
new file mode 100644
index 000000000000..bf29a65367b0
--- /dev/null
+++ b/benchmark.py
@@ -0,0 +1,41 @@
+import pyspark
+import timeit
+import random
+from pyspark.sql import SparkSession
+
+numPartition = 20
+
+def time(df, repeat, number):
+    print("toPandas with arrow")
+    print(timeit.repeat('df.toPandas(True)', repeat=repeat, number=number, globals={'df': df}))
+
+    print("toPandas without arrow")
+    print(timeit.repeat('df.toPandas(False)', repeat=repeat, number=number, globals={'df': df}))
+
+def long():
+    return random.randint(0, 10000)
+
+def double():
+    return random.random()
+
+def genDataLocal(spark, size, columns):
+    data = [list([fn() for fn in columns]) for x in range(0, size)]
+    df = spark.createDataFrame(data)
+    return df
+
+def genData(spark, size, columns):
+    rdd = (spark.sparkContext
+           .parallelize(range(0, numPartition), numPartition)
+           .flatMap(lambda index: [list([fn() for fn in columns]) for x in range(0, int(size / numPartition))]))
+    df = spark.createDataFrame(rdd)
+    return df
+
+if __name__ == "__main__":
+    spark = SparkSession.builder.appName("ArrowBenchmark").getOrCreate()
+    df = genData(spark, 1000 * 1000, [long, double])
+    df.cache()
+    df.count()
+
+    time(df, 10, 1)
+
+    df.unpersist()
diff --git a/pom.xml b/pom.xml
index c6b8e2fe028f..5432e11655a1 100644
--- a/pom.xml
+++ b/pom.xml
@@ -1883,26 +1883,6 @@
-      <dependency>
-        <groupId>org.apache.arrow</groupId>
-        <artifactId>arrow-tools</artifactId>
-        <version>${arrow.version}</version>
-        <scope>test</scope>
-        <exclusions>
-          <exclusion>
-            <groupId>com.fasterxml.jackson.core</groupId>
-            <artifactId>jackson-annotations</artifactId>
-          </exclusion>
-          <exclusion>
-            <groupId>com.fasterxml.jackson.core</groupId>
-            <artifactId>jackson-databind</artifactId>
-          </exclusion>
-          <exclusion>
-            <groupId>org.slf4j</groupId>
-            <artifactId>log4j-over-slf4j</artifactId>
-          </exclusion>
-        </exclusions>
-      </dependency>
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index 24a2f36bb9e4..cdc012aa78c3 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -95,11 +95,6 @@
       <groupId>org.apache.arrow</groupId>
       <artifactId>arrow-vector</artifactId>
     </dependency>
-    <dependency>
-      <groupId>org.apache.arrow</groupId>
-      <artifactId>arrow-tools</artifactId>
-      <scope>test</scope>
-    </dependency>
     <dependency>
       <groupId>org.scalacheck</groupId>
       <artifactId>scalacheck_${scala.binary.version}</artifactId>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 96f779127f36..9e90b009afe0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -27,6 +27,7 @@ import scala.util.control.NonFatal
 
 import io.netty.buffer.ArrowBuf
 import org.apache.arrow.memory.RootAllocator
+import org.apache.arrow.vector.BitVector
 import org.apache.arrow.vector.file.ArrowWriter
 import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch}
 import org.apache.arrow.vector.types.FloatingPointPrecision
@@ -60,7 +61,6 @@ import org.apache.spark.storage.StorageLevel
 import org.apache.spark.unsafe.types.CalendarInterval
 import org.apache.spark.util.Utils
 
-
 private[sql] object Dataset {
   def apply[T: Encoder](sparkSession: SparkSession, logicalPlan: LogicalPlan): Dataset[T] = {
     new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]])
@@ -2426,6 +2426,29 @@ class Dataset[T] private[sql](
     Math.ceil(numOfRows / 64.0).toInt * 8
   }
 
+  private def fillArrow(buf: ArrowBuf, dataType: DataType): Unit = {
+    dataType match {
+      case NullType =>
+      case BooleanType =>
+        buf.writeBoolean(false)
+      case ShortType =>
+        buf.writeShort(0)
+      case IntegerType =>
+        buf.writeInt(0)
+      case LongType =>
+        buf.writeLong(0L)
+      case FloatType =>
+        buf.writeFloat(0f)
+      case DoubleType =>
+        buf.writeDouble(0d)
+      case ByteType =>
+        buf.writeByte(0)
+      case _ =>
+        throw new UnsupportedOperationException(
+          s"Unsupported data type ${dataType.simpleString}")
+    }
+  }
+
   /**
    * Get an entry from the InternalRow, and then set to ArrowBuf.
   * Note: No Null check for the entry.
@@ -2466,20 +2489,26 @@ class Dataset[T] private[sql](
 
       field.dataType match {
         case IntegerType | LongType | DoubleType | FloatType | BooleanType | ByteType =>
-          val validity = allocator.buffer(numBytesOfBitmap(numOfRows))
+          val validityVector = new BitVector("validity", allocator)
+          val validityMutator = validityVector.getMutator()
+          validityVector.allocateNew(numOfRows)
+          validityMutator.setValueCount(numOfRows)
           val buf = allocator.buffer(numOfRows * field.dataType.defaultSize)
           var nullCount = 0
-          rows.foreach { row =>
+          rows.zipWithIndex.foreach { case (row, index) =>
             if (row.isNullAt(ordinal)) {
               nullCount += 1
+              validityMutator.set(index, 0)
+              fillArrow(buf, field.dataType)
             } else {
+              validityMutator.set(index, 1)
               getAndSetToArrow(row, buf, field.dataType, ordinal)
             }
           }
 
           val fieldNode = new ArrowFieldNode(numOfRows, nullCount)
 
-          (Array(validity, buf), Array(fieldNode))
+          (Array(validityVector.getBuffer, buf), Array(fieldNode))
         case StringType =>
           val validityOffset = allocator.buffer(numBytesOfBitmap(numOfRows))
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala
index 123b0f56fb47..9b1786c83f16 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala
@@ -19,14 +19,13 @@ package org.apache.spark.sql
 import java.io.File
 
 import org.apache.arrow.memory.RootAllocator
-import org.apache.arrow.tools.Integration
-import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot}
+import org.apache.arrow.vector.{BitVector, VectorLoader, VectorSchemaRoot}
 import org.apache.arrow.vector.file.json.JsonFileReader
+import org.apache.arrow.vector.util.Validator
 
 import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
 
 class ArrowSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
-  import testImplicits._
 
   private val nullIntsFile = "test-data/arrowNullInts.json"
   private def testFile(fileName: String): String = {
@@ -43,8 +42,7 @@ class ArrowSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
 
     val arrowSchema = df.schemaToArrowSchema(df.schema)
     val jsonSchema = jsonReader.start()
-    // TODO - requires changing to public API in arrow, will be addressed in ARROW-411
-    //Integration.compareSchemas(arrowSchema, jsonSchema)
+    Validator.compareSchemas(arrowSchema, jsonSchema)
 
     val arrowRecordBatch = df.collectAsArrow(allocator)
     val arrowRoot = new VectorSchemaRoot(arrowSchema, allocator)
@@ -52,7 +50,6 @@ class ArrowSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     vectorLoader.load(arrowRecordBatch)
     val jsonRoot = jsonReader.read()
 
-    // TODO - requires changing to public API in arrow, will be addressed in ARROW-411
-    //Integration.compare(arrowRoot, jsonRoot)
+    Validator.compareVectorSchemaRoot(arrowRoot, jsonRoot)
   }
 }