From 0220a8fee53ee245c1af91a40baa64e56aab32eb Mon Sep 17 00:00:00 2001 From: Eduard Tudenhoefner Date: Thu, 19 Dec 2024 18:42:08 +0100 Subject: [PATCH] [SPARK-50624][SQL] Add TimestampNTZType to ColumnarRow/MutableColumnarRow --- .../spark/sql/vectorized/ColumnarRow.java | 2 ++ .../vectorized/MutableColumnarRow.java | 4 ++++ .../vectorized/ColumnVectorSuite.scala | 19 +++++++++++++++ .../vectorized/ArrowColumnVectorSuite.scala | 24 +++++++++++++++++++ 4 files changed, 49 insertions(+) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java index c4fbc2ff64229..fdea42c6ea14a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java @@ -183,6 +183,8 @@ public Object get(int ordinal, DataType dataType) { return getInt(ordinal); } else if (dataType instanceof TimestampType) { return getLong(ordinal); + } else if (dataType instanceof TimestampNTZType) { + return getLong(ordinal); } else if (dataType instanceof ArrayType) { return getArray(ordinal); } else if (dataType instanceof StructType) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java index 64568f18f6858..7be68366481b5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java @@ -80,6 +80,8 @@ public InternalRow copy() { row.setInt(i, getInt(i)); } else if (dt instanceof TimestampType) { row.setLong(i, getLong(i)); + } else if (dt instanceof TimestampNTZType) { + row.setLong(i, getLong(i)); } else if (dt instanceof StructType) { row.update(i, getStruct(i, ((StructType) dt).fields().length).copy()); } else if (dt instanceof ArrayType) { @@ -185,6 +187,8 @@ public Object get(int ordinal, DataType dataType) { return getInt(ordinal); } else if (dataType instanceof TimestampType) { return getLong(ordinal); + } else if (dataType instanceof TimestampNTZType) { + return getLong(ordinal); } else if (dataType instanceof ArrayType) { return getArray(ordinal); } else if (dataType instanceof StructType) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala index a40a416bbb5a1..9591201383e59 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala @@ -271,6 +271,19 @@ class ColumnVectorSuite extends SparkFunSuite { } } + testVectors("mutable ColumnarRow with TimestampNTZType", 10, TimestampNTZType) { testVector => + val mutableRow = new MutableColumnarRow(Array(testVector)) + (0 until 10).foreach { i => + mutableRow.rowId = i + mutableRow.setLong(0, 10 - i) + } + (0 until 10).foreach { i => + mutableRow.rowId = i + assert(mutableRow.get(0, TimestampNTZType) === (10 - i)) + assert(mutableRow.copy().get(0, TimestampNTZType) === (10 - i)) + } + } + val arrayType: ArrayType = ArrayType(IntegerType, containsNull = true) testVectors("array", 10, arrayType) { testVector => @@ -381,18 +394,24 @@ class ColumnVectorSuite extends SparkFunSuite { } val structType: StructType = new StructType().add("int", IntegerType).add("double", DoubleType) + .add("ts", TimestampNTZType) testVectors("struct", 10, structType) { testVector => val c1 = testVector.getChild(0) val c2 = testVector.getChild(1) + val c3 = testVector.getChild(2) c1.putInt(0, 123) c2.putDouble(0, 3.45) + c3.putLong(0, 1000L) c1.putInt(1, 456) c2.putDouble(1, 5.67) + c3.putLong(1, 2000L) assert(testVector.getStruct(0).get(0, IntegerType) === 123) assert(testVector.getStruct(0).get(1, DoubleType) === 3.45) + assert(testVector.getStruct(0).get(2, TimestampNTZType) === 1000L) assert(testVector.getStruct(1).get(0, IntegerType) === 456) assert(testVector.getStruct(1).get(1, DoubleType) === 5.67) + assert(testVector.getStruct(1).get(2, TimestampNTZType) === 2000L) } testVectors("SPARK-44805: getInts with dictionary", 3, IntegerType) { testVector => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/vectorized/ArrowColumnVectorSuite.scala index 436cea50ad972..9180ce1aee198 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/vectorized/ArrowColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/vectorized/ArrowColumnVectorSuite.scala @@ -515,4 +515,28 @@ class ArrowColumnVectorSuite extends SparkFunSuite { columnVector.close() allocator.close() } + + test("struct with TimestampNTZType") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("struct", 0, Long.MaxValue) + val schema = new StructType().add("ts", TimestampNTZType) + val vector = ArrowUtils.toArrowField("struct", schema, nullable = true, null) + .createVector(allocator).asInstanceOf[StructVector] + vector.allocateNew() + val timestampVector = vector.getChildByOrdinal(0).asInstanceOf[TimeStampMicroVector] + + vector.setIndexDefined(0) + timestampVector.setSafe(0, 1000L) + + timestampVector.setValueCount(1) + vector.setValueCount(1) + + val columnVector = new ArrowColumnVector(vector) + assert(columnVector.dataType === schema) + + val row0 = columnVector.getStruct(0) + assert(row0.get(0, TimestampNTZType) === 1000L) + + columnVector.close() + allocator.close() + } }