diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala b/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala
index 7100a8f03515..d58a25fd05e2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala
@@ -41,7 +41,7 @@ object Arrow {
     case LongType => new ArrowType.Int(8 * LongType.defaultSize, true)
     case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)
     case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
-    case ByteType => new ArrowType.Int(8, false)
+    case ByteType => new ArrowType.Int(8, true)
     case StringType => ArrowType.Utf8.INSTANCE
     case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dataType}")
   }
diff --git a/sql/core/src/test/resources/test-data/arrow/boolData.json b/sql/core/src/test/resources/test-data/arrow/boolData.json
new file mode 100644
index 000000000000..f402e5118cef
--- /dev/null
+++ b/sql/core/src/test/resources/test-data/arrow/boolData.json
@@ -0,0 +1,32 @@
+{
+  "schema": {
+    "fields": [
+      {
+        "name": "a_bool",
+        "type": {"name": "bool"},
+        "nullable": false,
+        "children": [],
+        "typeLayout": {
+          "vectors": [
+            {"type": "VALIDITY", "typeBitWidth": 1},
+            {"type": "DATA", "typeBitWidth": 8}
+          ]
+        }
+      }
+    ]
+  },
+
+  "batches": [
+    {
+      "count": 4,
+      "columns": [
+        {
+          "name": "a_bool",
+          "count": 4,
+          "VALIDITY": [1, 1, 1, 1],
+          "DATA": [true, true, false, true]
+        }
+      ]
+    }
+  ]
+}
diff --git a/sql/core/src/test/resources/test-data/arrow/byteData.json b/sql/core/src/test/resources/test-data/arrow/byteData.json
new file mode 100644
index 000000000000..d0a6ceb818f7
--- /dev/null
+++ b/sql/core/src/test/resources/test-data/arrow/byteData.json
@@ -0,0 +1,32 @@
+{
+  "schema": {
+    "fields": [
+      {
+        "name": "a_byte",
+        "type": {"name": "int", "isSigned": true, "bitWidth": 8},
+        "nullable": false,
+        "children": [],
+        "typeLayout": {
+          "vectors": [
+            {"type": "VALIDITY", "typeBitWidth": 1},
+            {"type": "DATA", "typeBitWidth": 8}
+          ]
+        }
+      }
+    ]
+  },
+
+  "batches": [
+    {
+      "count": 4,
+      "columns": [
+        {
+          "name": "a_byte",
+          "count": 4,
+          "VALIDITY": [1, 1, 1, 1],
+          "DATA": [1, -1, 64, 127]
+        }
+      ]
+    }
+  ]
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala
index 7b5231824b2a..f51a74084a10 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala
@@ -17,6 +17,9 @@
 package org.apache.spark.sql
 
 import java.io.File
+import java.sql.{Date, Timestamp}
+import java.text.SimpleDateFormat
+import java.util.Locale
 
 import org.apache.arrow.memory.RootAllocator
 import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot}
@@ -62,9 +65,19 @@ class ArrowSuite extends SharedSQLContext {
     collectAndValidate(doubleData, "test-data/arrow/doubleData-double_precision-nullable.json")
   }
 
+  test("boolean type conversion") {
+    val boolData = Seq(true, true, false, true).toDF("a_bool")
+    collectAndValidate(boolData, "test-data/arrow/boolData.json")
+  }
+
+  test("byte type conversion") {
+    val byteData = Seq(1.toByte, (-1).toByte, 64.toByte, Byte.MaxValue).toDF("a_byte")
+    collectAndValidate(byteData, "test-data/arrow/byteData.json")
+  }
+
   test("mixed standard type nullable conversion") {
-    val mixedData = shortData.join(intData, "i").join(longData, "i").join(floatData, "i")
-      .join(doubleData, "i").sort("i")
+    val mixedData = Seq(shortData, intData, longData, floatData, doubleData)
+      .reduce((a, b) => a.join(b, "i")).sort("i")
     collectAndValidate(mixedData, "test-data/arrow/mixedData-standard-nullable.json")
   }
 
@@ -77,7 +90,16 @@ class ArrowSuite extends SharedSQLContext {
     collectAndValidate(lowerCaseData, "test-data/arrow/lowercase-strings.json")
   }
 
-  test("time and date conversion") { }
+  ignore("time and date conversion") {
+    val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US)
+    val d1 = new Date(sdf.parse("2015-04-08 13:10:15").getTime)
+    val d2 = new Date(sdf.parse("2015-04-08 13:10:15").getTime)
+    val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15").getTime)
+    val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10").getTime)
+    val dateTimeData = Seq((d1, sdf.format(d1), ts1), (d2, sdf.format(d2), ts2))
+      .toDF("a_date", "b_string", "c_timestamp")
+    collectAndValidate(dateTimeData, "test-data/arrow/datetimeData-strings.json")
+  }
 
   test("nested type conversion") { }
 
@@ -93,11 +115,6 @@ class ArrowSuite extends SharedSQLContext {
 
   test("floating-point NaN") { }
 
-  // Arrow currently supports single or double precision
-  ignore("arbitrary precision floating point") {
-    collectAndValidate(decimalData, "test-data/arrow/decimalData-BigDecimal.json")
-  }
-
   test("other null conversion") { }
 
   test("convert int column with null to arrow") {
@@ -115,7 +132,13 @@ class ArrowSuite extends SharedSQLContext {
     assert(emptyBatch.getLength == 0)
   }
 
-  test("negative tests") {
+  test("unsupported types") {
+    intercept[UnsupportedOperationException] {
+      collectAndValidate(decimalData, "test-data/arrow/decimalData-BigDecimal.json")
+    }
+  }
+
+  test("test Arrow Validator") {
     // Missing test file
     intercept[NullPointerException] {