2 changes: 1 addition & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala
@@ -41,7 +41,7 @@ object Arrow {
case LongType => new ArrowType.Int(8 * LongType.defaultSize, true)
case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)
case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
case ByteType => new ArrowType.Int(8, false)
case ByteType => new ArrowType.Int(8, true)
Owner Author:
@icexelloss, I think this should be signed, right? Or is it better to use the Arrow binary type for this?

Collaborator:
Looks like Scala treats bytes as signed ints:

scala> scala.Byte.MinValue
res2: Byte = -128

scala> scala.Byte.MaxValue
res3: Byte = 127

So yes, I think this is right.
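
For context, a minimal standalone sketch (not part of this diff) showing that an 8-bit signed Arrow Int lines up with Scala's signed Byte range. It only assumes the ArrowType.Int(bitWidth, isSigned) constructor already used in Arrow.scala and its getBitWidth/getIsSigned getters; the object name SignedByteSketch is just for illustration.

import org.apache.arrow.vector.types.pojo.ArrowType

object SignedByteSketch {
  def main(args: Array[String]): Unit = {
    // Same construction as the changed line above: 8 bits, signed.
    val arrowByteType = new ArrowType.Int(8, true)
    println(s"bitWidth = ${arrowByteType.getBitWidth}, signed = ${arrowByteType.getIsSigned}")
    // Scala's Byte is signed, so both sides cover -128 .. 127.
    println(s"Scala Byte range: ${Byte.MinValue} .. ${Byte.MaxValue}")
  }
}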

case StringType => ArrowType.Utf8.INSTANCE
case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dataType}")
}
32 changes: 32 additions & 0 deletions sql/core/src/test/resources/test-data/arrow/boolData.json
@@ -0,0 +1,32 @@
{
"schema": {
"fields": [
{
"name": "a_bool",
"type": {"name": "bool"},
"nullable": false,
"children": [],
"typeLayout": {
"vectors": [
{"type": "VALIDITY", "typeBitWidth": 1},
{"type": "DATA", "typeBitWidth": 8}
]
}
}
]
},

"batches": [
{
"count": 4,
"columns": [
{
"name": "a_bool",
"count": 4,
"VALIDITY": [1, 1, 1, 1],
"DATA": [true, true, false, true]
}
]
}
]
}
32 changes: 32 additions & 0 deletions sql/core/src/test/resources/test-data/arrow/byteData.json
@@ -0,0 +1,32 @@
{
"schema": {
"fields": [
{
"name": "a_byte",
"type": {"name": "int", "isSigned": true, "bitWidth": 8},
"nullable": false,
"children": [],
"typeLayout": {
"vectors": [
{"type": "VALIDITY", "typeBitWidth": 1},
{"type": "DATA", "typeBitWidth": 8}
]
}
}
]
},

"batches": [
{
"count": 4,
"columns": [
{
"name": "a_byte",
"count": 4,
"VALIDITY": [1, 1, 1, 1],
"DATA": [1, -1, 64, 127]
}
]
}
]
}
41 changes: 32 additions & 9 deletions sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala
@@ -17,6 +17,9 @@
package org.apache.spark.sql

import java.io.File
import java.sql.{Date, Timestamp}
import java.text.SimpleDateFormat
import java.util.Locale

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot}
@@ -62,9 +65,19 @@ class ArrowSuite extends SharedSQLContext {
collectAndValidate(doubleData, "test-data/arrow/doubleData-double_precision-nullable.json")
}

test("boolean type conversion") {
val boolData = Seq(true, true, false, true).toDF("a_bool")
collectAndValidate(boolData, "test-data/arrow/boolData.json")
}

test("byte type conversion") {
val byteData = Seq(1.toByte, (-1).toByte, 64.toByte, Byte.MaxValue).toDF("a_byte")
collectAndValidate(byteData, "test-data/arrow/byteData.json")
}

test("mixed standard type nullable conversion") {
val mixedData = shortData.join(intData, "i").join(longData, "i").join(floatData, "i")
.join(doubleData, "i").sort("i")
val mixedData = Seq(shortData, intData, longData, floatData, doubleData)
.reduce((a, b) => a.join(b, "i")).sort("i")
collectAndValidate(mixedData, "test-data/arrow/mixedData-standard-nullable.json")
}

@@ -77,7 +90,16 @@
collectAndValidate(lowerCaseData, "test-data/arrow/lowercase-strings.json")
}

test("time and date conversion") { }
ignore("time and date conversion") {
Collaborator:
FYI, I am adding support for time and date. I will rebase on this and finish this test.

Owner Author:
Great!

val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US)
val d1 = new Date(sdf.parse("2015-04-08 13:10:15").getTime)
val d2 = new Date(sdf.parse("2015-04-08 13:10:15").getTime)
val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15").getTime)
val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10").getTime)
val dateTimeData = Seq((d1, sdf.format(d1), ts1), (d2, sdf.format(d2), ts2))
.toDF("a_date", "b_string", "c_timestamp")
collectAndValidate(dateTimeData, "test-data/arrow/datetimeData-strings.json")
}

test("nested type conversion") { }

@@ -93,11 +115,6 @@

test("floating-point NaN") { }

// Arrow currently supports single or double precision
ignore("arbitrary precision floating point") {
collectAndValidate(decimalData, "test-data/arrow/decimalData-BigDecimal.json")
}

test("other null conversion") { }

test("convert int column with null to arrow") {
@@ -115,7 +132,13 @@
assert(emptyBatch.getLength == 0)
}

test("negative tests") {
test("unsupported types") {
intercept[UnsupportedOperationException] {
collectAndValidate(decimalData, "test-data/arrow/decimalData-BigDecimal.json")
}
}

test("test Arrow Validator") {

// Missing test file
intercept[NullPointerException] {