diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 5bb1e562d281..bc0888e98df5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2375,8 +2375,8 @@ class Dataset[T] private[sql](
    * @since 2.2.0
    */
   @DeveloperApi
-  def collectAsArrow(
-      allocator: RootAllocator = new RootAllocator(Long.MaxValue)): ArrowRecordBatch = {
+  def collectAsArrow(rootAllocator: Option[RootAllocator] = None): ArrowRecordBatch = {
+    val allocator = rootAllocator.getOrElse(new RootAllocator(Long.MaxValue))
     withNewExecutionId {
       try {
         val collectedRows = queryExecution.executedPlan.executeCollect()
diff --git a/sql/core/src/test/resources/test-data/arrow/allNulls-ints.json b/sql/core/src/test/resources/test-data/arrow/allNulls-ints.json
new file mode 100644
index 000000000000..e12f546e461c
--- /dev/null
+++ b/sql/core/src/test/resources/test-data/arrow/allNulls-ints.json
@@ -0,0 +1,32 @@
+{
+  "schema": {
+    "fields": [
+      {
+        "name": "a",
+        "type": {"name": "int", "isSigned": true, "bitWidth": 32},
+        "nullable": true,
+        "children": [],
+        "typeLayout": {
+          "vectors": [
+            {"type": "VALIDITY", "typeBitWidth": 1},
+            {"type": "DATA", "typeBitWidth": 32}
+          ]
+        }
+      }
+    ]
+  },
+
+  "batches": [
+    {
+      "count": 4,
+      "columns": [
+        {
+          "name": "a",
+          "count": 4,
+          "VALIDITY": [0, 0, 0, 0],
+          "DATA": [0, 0, 0, 0]
+        }
+      ]
+    }
+  ]
+}
diff --git a/sql/core/src/test/resources/test-data/arrow/nanData-floating_point.json b/sql/core/src/test/resources/test-data/arrow/nanData-floating_point.json
new file mode 100644
index 000000000000..4a8407d45f37
--- /dev/null
+++ b/sql/core/src/test/resources/test-data/arrow/nanData-floating_point.json
@@ -0,0 +1,68 @@
+{
+  "schema": {
+    "fields": [
+      {
+        "name": "i",
+        "type": {"name": "int", "isSigned": true, "bitWidth": 32},
+        "nullable": false,
+        "children": [],
+        "typeLayout": {
+          "vectors": [
+            {"type": "VALIDITY", "typeBitWidth": 1},
+            {"type": "DATA", "typeBitWidth": 32}
+          ]
+        }
+      },
+      {
+        "name": "NaN_f",
+        "type": {"name": "floatingpoint", "precision": "SINGLE"},
+        "nullable": false,
+        "children": [],
+        "typeLayout": {
+          "vectors": [
+            {"type": "VALIDITY", "typeBitWidth": 1},
+            {"type": "DATA", "typeBitWidth": 32}
+          ]
+        }
+      },
+      {
+        "name": "NaN_d",
+        "type": {"name": "floatingpoint", "precision": "DOUBLE"},
+        "nullable": false,
+        "children": [],
+        "typeLayout": {
+          "vectors": [
+            {"type": "VALIDITY", "typeBitWidth": 1},
+            {"type": "DATA", "typeBitWidth": 64}
+          ]
+        }
+      }
+    ]
+  },
+
+  "batches": [
+    {
+      "count": 2,
+      "columns": [
+        {
+          "name": "i",
+          "count": 2,
+          "VALIDITY": [1, 1],
+          "DATA": [1, 2]
+        },
+        {
+          "name": "NaN_f",
+          "count": 2,
+          "VALIDITY": [1, 1],
+          "DATA": [1.2, "NaN"]
+        },
+        {
+          "name": "NaN_d",
+          "count": 2,
+          "VALIDITY": [1, 1],
+          "DATA": ["NaN", 1.23]
+        }
+      ]
+    }
+  ]
+}
diff --git a/sql/core/src/test/resources/test-data/arrow/timestampData.json b/sql/core/src/test/resources/test-data/arrow/timestampData.json
index 174c62e4a12d..6fe59975954d 100644
--- a/sql/core/src/test/resources/test-data/arrow/timestampData.json
+++ b/sql/core/src/test/resources/test-data/arrow/timestampData.json
@@ -2,7 +2,7 @@
   "schema": {
     "fields": [
       {
-        "name": "a_timestamp",
+        "name": "c_timestamp",
         "type": {"name": "timestamp", "unit": "MILLISECOND"},
         "nullable": true,
         "children": [],
@@ -21,7 +21,7 @@
       "count": 2,
       "columns": [
         {
-          "name": "a_timestamp",
+          "name": "c_timestamp",
           "count": 2,
           "VALIDITY": [1, 1],
           "DATA": [1365383415567, 1365426610789]
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala
index 13b38c8c8568..c784b3eefb74 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala
@@ -27,6 +27,7 @@ import org.apache.arrow.vector.file.json.JsonFileReader
 import org.apache.arrow.vector.util.Validator
 
 import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.unsafe.types.CalendarInterval
 
 // NOTE - nullable type can be declared as Option[*] or java.lang.*
 
@@ -88,25 +89,16 @@ class ArrowSuite extends SharedSQLContext {
   test("string type conversion") {
     collectAndValidate(upperCaseData, "test-data/arrow/uppercase-strings.json")
     collectAndValidate(lowerCaseData, "test-data/arrow/lowercase-strings.json")
+    val nullStringsColOnly = nullStrings.select(nullStrings.columns(1))
+    collectAndValidate(nullStringsColOnly, "test-data/arrow/null-strings.json")
   }
 
   ignore("date conversion") {
-    val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS", Locale.US)
-    val d1 = new Date(sdf.parse("2015-04-08 13:10:15.000").getTime)
-    val d2 = new Date(sdf.parse("2015-04-08 13:10:15.000").getTime)
-    val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567").getTime)
-    val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789").getTime)
-    val dateTimeData = Seq((d1, sdf.format(d1), ts1), (d2, sdf.format(d2), ts2))
-      .toDF("a_date", "b_string", "c_timestamp")
     collectAndValidate(dateTimeData, "test-data/arrow/datetimeData-strings.json")
   }
 
   test("timestamp conversion") {
-    val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US)
-    val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime)
-    val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789 UTC").getTime)
-    val dateTimeData = Seq((ts1), (ts2)).toDF("a_timestamp")
-    collectAndValidate(dateTimeData, "test-data/arrow/timestampData.json")
+    collectAndValidate(dateTimeData.select($"c_timestamp"), "test-data/arrow/timestampData.json")
  }
 
   // Arrow json reader doesn't support binary data
@@ -120,24 +112,15 @@ class ArrowSuite extends SharedSQLContext {
 
   test("mapped type conversion") { }
 
-  test("other type conversion") {
-    // half-precision
-    // byte type, or binary
-    // allNulls
+  test("floating-point NaN") {
+    val nanData = Seq((1, 1.2F, Double.NaN), (2, Float.NaN, 1.23)).toDF("i", "NaN_f", "NaN_d")
+    collectAndValidate(nanData, "test-data/arrow/nanData-floating_point.json")
   }
 
-  test("floating-point NaN") { }
-
-  test("other null conversion") { }
-
   test("convert int column with null to arrow") {
     collectAndValidate(nullInts, "test-data/arrow/null-ints.json")
     collectAndValidate(testData3, "test-data/arrow/null-ints-mixed.json")
-  }
-
-  test("convert string column with null to arrow") {
-    val nullStringsColOnly = nullStrings.select(nullStrings.columns(1))
-    collectAndValidate(nullStringsColOnly, "test-data/arrow/null-strings.json")
+    collectAndValidate(allNulls, "test-data/arrow/allNulls-ints.json")
   }
 
   test("empty frame collect") {
@@ -146,7 +129,14 @@
 
   test("unsupported types") {
-    intercept[UnsupportedOperationException] {
+    def runUnsupported(block: => Unit): Unit = {
+      val msg = intercept[UnsupportedOperationException] {
+        block
+      }
+      assert(msg.getMessage.contains("Unsupported data type"))
+    }
+
+    runUnsupported {
       collectAndValidate(decimalData, "test-data/arrow/decimalData-BigDecimal.json")
     }
   }
 
@@ -180,7 +170,7 @@ class ArrowSuite extends SharedSQLContext {
     val jsonSchema = jsonReader.start()
     Validator.compareSchemas(arrowSchema, jsonSchema)
 
-    val arrowRecordBatch = df.collectAsArrow(allocator)
+    val arrowRecordBatch = df.collectAsArrow(Some(allocator))
     val arrowRoot = new VectorSchemaRoot(arrowSchema, allocator)
     val vectorLoader = new VectorLoader(arrowRoot)
     vectorLoader.load(arrowRecordBatch)
@@ -240,4 +230,14 @@ class ArrowSuite extends SharedSQLContext {
       DoubleData(5, 0.0001, None) ::
       DoubleData(6, 20000.0, Some(3.3)) :: Nil).toDF()
   }
+
+  protected lazy val dateTimeData: DataFrame = {
+    val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US)
+    val d1 = new Date(sdf.parse("2015-04-08 13:10:15.000 UTC").getTime)
+    val d2 = new Date(sdf.parse("2015-04-08 13:10:15.000 UTC").getTime)
+    val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime)
+    val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789 UTC").getTime)
+    Seq((d1, sdf.format(d1), ts1), (d2, sdf.format(d2), ts2))
+      .toDF("a_date", "b_string", "c_timestamp")
+  }
 }
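
For reference, a minimal usage sketch of the revised `collectAsArrow` signature, which now takes an `Option[RootAllocator]` instead of a default-valued `RootAllocator` parameter. The `spark` session, the example DataFrame, and the omission of batch/allocator cleanup are illustrative assumptions, not part of the patch:

```scala
import org.apache.arrow.memory.RootAllocator

// Assumes an existing SparkSession named `spark`; the DataFrame is illustrative only.
val df = spark.range(10).toDF("id")

// Rely on the allocator created inside collectAsArrow (a RootAllocator sized to Long.MaxValue).
val batchWithDefault = df.collectAsArrow()

// Or pass a caller-managed allocator, as the updated ArrowSuite does, so the caller
// controls the allocator's lifetime (resource cleanup is elided in this sketch).
val allocator = new RootAllocator(Long.MaxValue)
val batchWithCustom = df.collectAsArrow(Some(allocator))
```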