From 5ef0ba874dd879705f79ab1a986245df9605cad0 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Wed, 31 Jan 2018 14:03:01 +0000
Subject: [PATCH 1/7] Document the behavior of several ColumnVector get APIs
 when accessing a null slot.

---
 .../datasources/orc/OrcColumnVector.java         |  3 +++
 .../vectorized/MutableColumnarRow.java           |  6 ------
 .../vectorized/WritableColumnVector.java         |  4 ++++
 .../spark/sql/vectorized/ArrowColumnVector.java  |  4 ++++
 .../spark/sql/vectorized/ColumnVector.java       | 17 +++++++++--------
 .../spark/sql/vectorized/ColumnarRow.java        |  6 ------
 6 files changed, 20 insertions(+), 20 deletions(-)

diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java
index 5078bc7922ee2..7ef67db367e63 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java
@@ -148,12 +148,14 @@ public double getDouble(int rowId) {
 
   @Override
   public Decimal getDecimal(int rowId, int precision, int scale) {
+    if (isNullAt(rowId)) return null;
     BigDecimal data = decimalData.vector[getRowIndex(rowId)].getHiveDecimal().bigDecimalValue();
     return Decimal.apply(data, precision, scale);
   }
 
   @Override
   public UTF8String getUTF8String(int rowId) {
+    if (isNullAt(rowId)) return null;
     int index = getRowIndex(rowId);
     BytesColumnVector col = bytesData;
     return UTF8String.fromBytes(col.vector[index], col.start[index], col.length[index]);
@@ -161,6 +163,7 @@ public UTF8String getUTF8String(int rowId) {
 
   @Override
   public byte[] getBinary(int rowId) {
+    if (isNullAt(rowId)) return null;
     int index = getRowIndex(rowId);
     byte[] binary = new byte[bytesData.length[index]];
     System.arraycopy(bytesData.vector[index], bytesData.start[index], binary, 0, binary.length);
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
index 66668f3753604..56dbccfb694cf 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
@@ -127,37 +127,31 @@ public boolean anyNull() {
 
   @Override
   public Decimal getDecimal(int ordinal, int precision, int scale) {
-    if (columns[ordinal].isNullAt(rowId)) return null;
     return columns[ordinal].getDecimal(rowId, precision, scale);
   }
 
   @Override
   public UTF8String getUTF8String(int ordinal) {
-    if (columns[ordinal].isNullAt(rowId)) return null;
     return columns[ordinal].getUTF8String(rowId);
   }
 
   @Override
   public byte[] getBinary(int ordinal) {
-    if (columns[ordinal].isNullAt(rowId)) return null;
     return columns[ordinal].getBinary(rowId);
   }
 
   @Override
   public CalendarInterval getInterval(int ordinal) {
-    if (columns[ordinal].isNullAt(rowId)) return null;
     return columns[ordinal].getInterval(rowId);
   }
 
   @Override
   public ColumnarRow getStruct(int ordinal, int numFields) {
-    if (columns[ordinal].isNullAt(rowId)) return null;
     return columns[ordinal].getStruct(rowId);
   }
 
   @Override
   public ColumnarArray getArray(int ordinal) {
-    if (columns[ordinal].isNullAt(rowId)) return null;
     return columns[ordinal].getArray(rowId);
   }
 
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
index a8ec8ef2aadf8..d515e372666b5 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
@@ -332,6 +332,7 @@ public final int putByteArray(int rowId, byte[] value) {
 
   @Override
   public Decimal getDecimal(int rowId, int precision, int scale) {
+    if (isNullAt(rowId)) return null;
     if (precision <= Decimal.MAX_INT_DIGITS()) {
       return Decimal.createUnsafe(getInt(rowId), precision, scale);
     } else if (precision <= Decimal.MAX_LONG_DIGITS()) {
@@ -358,6 +359,7 @@ public void putDecimal(int rowId, Decimal value, int precision) {
 
   @Override
   public UTF8String getUTF8String(int rowId) {
+    if (isNullAt(rowId)) return null;
     if (dictionary == null) {
       return arrayData().getBytesAsUTF8String(getArrayOffset(rowId), getArrayLength(rowId));
     } else {
@@ -375,6 +377,7 @@ public UTF8String getUTF8String(int rowId) {
 
   @Override
   public byte[] getBinary(int rowId) {
+    if (isNullAt(rowId)) return null;
     if (dictionary == null) {
       return arrayData().getBytes(getArrayOffset(rowId), getArrayLength(rowId));
     } else {
@@ -604,6 +607,7 @@ public final int appendStruct(boolean isNull) {
   // array offsets and lengths in the current column vector.
   @Override
   public final ColumnarArray getArray(int rowId) {
+    if (isNullAt(rowId)) return null;
     return new ColumnarArray(arrayData(), getArrayOffset(rowId), getArrayLength(rowId));
   }
 
diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
index a75d76bd0f82e..17054cd81ab3d 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
@@ -96,21 +96,25 @@ public double getDouble(int rowId) {
 
   @Override
   public Decimal getDecimal(int rowId, int precision, int scale) {
+    if (isNullAt(rowId)) return null;
     return accessor.getDecimal(rowId, precision, scale);
   }
 
   @Override
   public UTF8String getUTF8String(int rowId) {
+    if (isNullAt(rowId)) return null;
     return accessor.getUTF8String(rowId);
   }
 
   @Override
   public byte[] getBinary(int rowId) {
+    if (isNullAt(rowId)) return null;
     return accessor.getBinary(rowId);
   }
 
   @Override
   public ColumnarArray getArray(int rowId) {
+    if (isNullAt(rowId)) return null;
     return accessor.getArray(rowId);
   }
 
diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java
index 111f5d9b358d4..91d555a548f97 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java
@@ -188,7 +188,7 @@ public double[] getDoubles(int rowId, int count) {
   }
 
   /**
-   * Returns the struct type value for rowId.
+   * Returns the struct type value for rowId. If the slot for rowId is null, it should return null.
    *
    * To support struct type, implementations must implement {@link #getChild(int)} and make this
    * vector a tree structure. The number of child vectors must be same as the number of fields of
@@ -201,7 +201,7 @@ public final ColumnarRow getStruct(int rowId) {
   }
 
   /**
-   * Returns the array type value for rowId.
+   * Returns the array type value for rowId. If the slot for rowId is null, it should return null.
    *
    * To support array type, implementations must construct an {@link ColumnarArray} and return it in
    * this method. {@link ColumnarArray} requires a {@link ColumnVector} that stores the data of all
@@ -221,24 +221,25 @@ public MapData getMap(int ordinal) {
   }
 
   /**
-   * Returns the decimal type value for rowId.
+   * Returns the decimal type value for rowId. If the slot for rowId is null, it should return null.
    */
   public abstract Decimal getDecimal(int rowId, int precision, int scale);
 
   /**
-   * Returns the string type value for rowId. Note that the returned UTF8String may point to the
-   * data of this column vector, please copy it if you want to keep it after this column vector is
-   * freed.
+   * Returns the string type value for rowId. If the slot for rowId is null, it should return null.
+   * Note that the returned UTF8String may point to the data of this column vector; copy it
+   * if you want to keep it after this column vector is freed.
    */
   public abstract UTF8String getUTF8String(int rowId);
 
   /**
-   * Returns the binary type value for rowId.
+   * Returns the binary type value for rowId. If the slot for rowId is null, it should return null.
    */
   public abstract byte[] getBinary(int rowId);
 
   /**
-   * Returns the calendar interval type value for rowId.
+   * Returns the calendar interval type value for rowId. If the slot for rowId is null, it should
+   * return null.
    *
    * In Spark, calendar interval type value is basically an integer value representing the number of
    * months in this interval, and a long value representing the number of microseconds in this
diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java
index 6ca749d7c6e85..294b4920f5de4 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java
@@ -120,37 +120,31 @@ public boolean anyNull() {
 
   @Override
   public Decimal getDecimal(int ordinal, int precision, int scale) {
-    if (data.getChild(ordinal).isNullAt(rowId)) return null;
     return data.getChild(ordinal).getDecimal(rowId, precision, scale);
   }
 
   @Override
   public UTF8String getUTF8String(int ordinal) {
-    if (data.getChild(ordinal).isNullAt(rowId)) return null;
     return data.getChild(ordinal).getUTF8String(rowId);
   }
 
   @Override
   public byte[] getBinary(int ordinal) {
-    if (data.getChild(ordinal).isNullAt(rowId)) return null;
     return data.getChild(ordinal).getBinary(rowId);
   }
 
   @Override
   public CalendarInterval getInterval(int ordinal) {
-    if (data.getChild(ordinal).isNullAt(rowId)) return null;
     return data.getChild(ordinal).getInterval(rowId);
   }
 
   @Override
   public ColumnarRow getStruct(int ordinal, int numFields) {
-    if (data.getChild(ordinal).isNullAt(rowId)) return null;
     return data.getChild(ordinal).getStruct(rowId);
   }
 
   @Override
   public ColumnarArray getArray(int ordinal) {
-    if (data.getChild(ordinal).isNullAt(rowId)) return null;
     return data.getChild(ordinal).getArray(rowId);
   }
 

From febdf9b2f97ae6077d3d2f874223c6a8c6c0b864 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Thu, 1 Feb 2018 05:56:51 +0000
Subject: [PATCH 2/7] Also document the behavior of the primitive get APIs.
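
For reference, the contract documented here is that callers must consult isNullAt()
before trusting a primitive getter. A minimal usage sketch, assuming an
OnHeapColumnVector (the setup below is illustrative only and not part of this patch):

    import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
    import org.apache.spark.sql.types.IntegerType

    val vector = new OnHeapColumnVector(4, IntegerType)
    vector.putInt(0, 42)
    vector.putNull(1)
    assert(vector.getInt(0) == 42)
    // getInt(1) is undefined for the null slot; check isNullAt(1) first
    // instead of relying on any particular return value.
    assert(vector.isNullAt(1))
    vector.close()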
---
 .../spark/sql/vectorized/ColumnVector.java | 42 ++++++++++++-------
 1 file changed, 28 insertions(+), 14 deletions(-)

diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java
index 91d555a548f97..bf712560609f1 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java
@@ -76,12 +76,14 @@ public abstract class ColumnVector implements AutoCloseable {
   public abstract boolean isNullAt(int rowId);
 
   /**
-   * Returns the boolean type value for rowId.
+   * Returns the boolean type value for rowId. The return value is undefined and can be anything
+   * if the slot for rowId is null.
    */
   public abstract boolean getBoolean(int rowId);
 
   /**
-   * Gets boolean type values from [rowId, rowId + count)
+   * Gets boolean type values from [rowId, rowId + count). The return values for the null slots
+   * are undefined and can be anything.
    */
   public boolean[] getBooleans(int rowId, int count) {
     boolean[] res = new boolean[count];
@@ -92,12 +94,14 @@ public boolean[] getBooleans(int rowId, int count) {
   }
 
   /**
-   * Returns the byte type value for rowId.
+   * Returns the byte type value for rowId. The return value is undefined and can be anything
+   * if the slot for rowId is null.
    */
   public abstract byte getByte(int rowId);
 
   /**
-   * Gets byte type values from [rowId, rowId + count)
+   * Gets byte type values from [rowId, rowId + count). The return values for the null slots
+   * are undefined and can be anything.
    */
   public byte[] getBytes(int rowId, int count) {
     byte[] res = new byte[count];
@@ -108,12 +112,14 @@ public byte[] getBytes(int rowId, int count) {
   }
 
   /**
-   * Returns the short type value for rowId.
+   * Returns the short type value for rowId. The return value is undefined and can be anything
+   * if the slot for rowId is null.
    */
   public abstract short getShort(int rowId);
 
   /**
-   * Gets short type values from [rowId, rowId + count)
+   * Gets short type values from [rowId, rowId + count). The return values for the null slots
+   * are undefined and can be anything.
    */
   public short[] getShorts(int rowId, int count) {
     short[] res = new short[count];
@@ -124,12 +130,14 @@ public short[] getShorts(int rowId, int count) {
   }
 
   /**
-   * Returns the int type value for rowId.
+   * Returns the int type value for rowId. The return value is undefined and can be anything
+   * if the slot for rowId is null.
    */
   public abstract int getInt(int rowId);
 
   /**
-   * Gets int type values from [rowId, rowId + count)
+   * Gets int type values from [rowId, rowId + count). The return values for the null slots
+   * are undefined and can be anything.
    */
   public int[] getInts(int rowId, int count) {
     int[] res = new int[count];
@@ -140,12 +148,14 @@ public int[] getInts(int rowId, int count) {
   }
 
   /**
-   * Returns the long type value for rowId.
+   * Returns the long type value for rowId. The return value is undefined and can be anything
+   * if the slot for rowId is null.
    */
   public abstract long getLong(int rowId);
 
   /**
-   * Gets long type values from [rowId, rowId + count)
+   * Gets long type values from [rowId, rowId + count). The return values for the null slots
+   * are undefined and can be anything.
    */
   public long[] getLongs(int rowId, int count) {
     long[] res = new long[count];
@@ -156,12 +166,14 @@ public long[] getLongs(int rowId, int count) {
   }
 
   /**
-   * Returns the float type value for rowId.
+   * Returns the float type value for rowId. The return value is undefined and can be anything
+   * if the slot for rowId is null.
    */
   public abstract float getFloat(int rowId);
 
   /**
-   * Gets float type values from [rowId, rowId + count)
+   * Gets float type values from [rowId, rowId + count). The return values for the null slots
+   * are undefined and can be anything.
    */
   public float[] getFloats(int rowId, int count) {
     float[] res = new float[count];
@@ -172,12 +184,14 @@ public float[] getFloats(int rowId, int count) {
   }
 
   /**
-   * Returns the double type value for rowId.
+   * Returns the double type value for rowId. The return value is undefined and can be anything
+   * if the slot for rowId is null.
    */
   public abstract double getDouble(int rowId);
 
   /**
-   * Gets double type values from [rowId, rowId + count)
+   * Gets double type values from [rowId, rowId + count). The return values for the null slots
+   * are undefined and can be anything.
    */
   public double[] getDoubles(int rowId, int count) {
     double[] res = new double[count];

From 7a1fd57925a080116c288ca1793af86258019494 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Thu, 1 Feb 2018 05:57:44 +0000
Subject: [PATCH 3/7] Add null tests against ColumnVector APIs.

---
 .../vectorized/ColumnarBatchSuite.scala | 112 ++++++++++++++++++
 1 file changed, 112 insertions(+)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
index 925c101fe1fee..08c905b49fb71 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
@@ -1214,4 +1214,116 @@ class ColumnarBatchSuite extends SparkFunSuite {
     batch.close()
     allocator.close()
   }
+
+  testVector("getUTF8String should return null for null slot", 4, StringType) {
+    column =>
+      assert(column.numNulls() == 0)
+
+      var idx = 0
+      column.putNull(idx)
+      assert(column.getUTF8String(idx) == null)
+      idx += 1
+      column.putNull(idx)
+      assert(column.getUTF8String(idx) == null)
+      assert(column.numNulls() == 2)
+
+      idx += 1
+      column.putByteArray(idx, "Hello".getBytes(StandardCharsets.UTF_8),
+        0, "Hello".getBytes(StandardCharsets.UTF_8).length)
+      assert(column.getUTF8String(idx) != null)
+  }
+
+  testVector("getInterval should return null for null slot", 4, CalendarIntervalType) {
+    column =>
+      assert(column.numNulls() == 0)
+
+      var idx = 0
+      column.putNull(idx)
+      assert(column.getInterval(idx) == null)
+      idx += 1
+      column.putNull(idx)
+      assert(column.getInterval(idx) == null)
+      assert(column.numNulls() == 2)
+
+      idx += 1
+      val months = column.getChild(0)
+      val microseconds = column.getChild(1)
+      months.putInt(idx, 1)
+      microseconds.putLong(idx, 100)
+      assert(column.getInterval(idx) != null)
+  }
+
+  testVector("getArray should return null for null slot", 4, new ArrayType(IntegerType, true)) {
+    column =>
+      assert(column.numNulls() == 0)
+
+      var idx = 0
+      column.putNull(idx)
+      assert(column.getArray(idx) == null)
+      idx += 1
+      column.putNull(idx)
+      assert(column.getArray(idx) == null)
+      assert(column.numNulls() == 2)
+
+      idx += 1
+      val data = column.arrayData()
+      data.putInt(0, 0)
+      data.putInt(1, 1)
+      column.putArray(idx, 0, 2)
+      assert(column.getArray(idx) != null)
+  }
+
+  testVector("getDecimal should return null for null slot", 4, DecimalType.IntDecimal) {
+    column =>
+      assert(column.numNulls() == 0)
+
+      var idx = 0
+      column.putNull(idx)
+      assert(column.getDecimal(idx, 10, 0) == null)
+      idx += 1
+      column.putNull(idx)
+      assert(column.getDecimal(idx, 10, 0) == null)
+      assert(column.numNulls() == 2)
+
+      idx += 1
+      column.putDecimal(idx, new Decimal().set(10), 10)
+      assert(column.getDecimal(idx, 10, 0) != null)
+  }
+
+  testVector("getStruct should return null for null slot", 4,
+    new StructType().add("int", IntegerType).add("double", DoubleType)) { column =>
+    assert(column.numNulls() == 0)
+
+    var idx = 0
+    column.putNull(idx)
+    assert(column.getStruct(idx) == null)
+    idx += 1
+    column.putNull(idx)
+    assert(column.getStruct(idx) == null)
+    assert(column.numNulls() == 2)
+
+    idx += 1
+    val c1 = column.getChild(0)
+    val c2 = column.getChild(1)
+    c1.putInt(idx, 123)
+    c2.putDouble(idx, 3.45)
+    assert(column.getStruct(idx) != null)
+  }
+
+  testVector("getBinary should return null for null slot", 4, BinaryType) {
+    column =>
+      assert(column.numNulls() == 0)
+
+      var idx = 0
+      column.putNull(idx)
+      assert(column.getBinary(idx) == null)
+      idx += 1
+      column.putNull(idx)
+      assert(column.getBinary(idx) == null)
+      assert(column.numNulls() == 2)
+
+      idx += 1
+      column.putByteArray(idx, "Hello".getBytes(StandardCharsets.UTF_8))
+      assert(column.getBinary(idx) != null)
+  }
 }

From 35548e6d30211cf155a366da2ad736d1281367bf Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Thu, 1 Feb 2018 07:55:30 +0000
Subject: [PATCH 4/7] Apply the same change to map support.

---
 .../vectorized/WritableColumnVector.java   |  1 +
 .../spark/sql/vectorized/ColumnVector.java |  4 ++--
 .../spark/sql/vectorized/ColumnarRow.java  |  1 -
 .../vectorized/ColumnarBatchSuite.scala    | 24 +++++++++++++++
 4 files changed, 27 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
index 54b527f74ea5e..5275e4a91eac0 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
@@ -624,6 +624,7 @@ public final ColumnarArray getArray(int rowId) {
   // second child column vector, and puts the offsets and lengths in the current column vector.
   @Override
   public final ColumnarMap getMap(int rowId) {
+    if (isNullAt(rowId)) return null;
     return new ColumnarMap(getChild(0), getChild(1), getArrayOffset(rowId), getArrayLength(rowId));
   }
 
diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java
index 834c26b705bd6..a02a600676f1c 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java
@@ -233,13 +233,13 @@ public final ColumnarRow getStruct(int rowId) {
   public abstract ColumnarArray getArray(int rowId);
 
   /**
-   * Returns the map type value for rowId.
+   * Returns the map type value for rowId. If the slot for rowId is null, it should return null.
    *
    * In Spark, map type value is basically a key data array and a value data array. A key from the
    * key array with a index and a value from the value array with the same index contribute to
    * an entry of this map type value.
    *
-   * To support map type, implementations must construct an {@link ColumnarMap} and return it in
+   * To support map type, implementations must construct a {@link ColumnarMap} and return it in
    * this method. {@link ColumnarMap} requires a {@link ColumnVector} that stores the data of all
    * the keys of all the maps in this vector, and another {@link ColumnVector} that stores the data
    * of all the values of all the maps in this vector, and a pair of offset and length which
diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java
index 32433fdb7b9ff..62631efc85b14 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java
@@ -150,7 +150,6 @@ public ColumnarArray getArray(int ordinal) {
 
   @Override
   public ColumnarMap getMap(int ordinal) {
-    if (data.getChild(ordinal).isNullAt(rowId)) return null;
     return data.getChild(ordinal).getMap(rowId);
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
index 06715a5cd01d2..ec1e51ad78e7c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
@@ -1373,4 +1373,28 @@ class ColumnarBatchSuite extends SparkFunSuite {
       column.putByteArray(idx, "Hello".getBytes(StandardCharsets.UTF_8))
       assert(column.getBinary(idx) != null)
   }
+
+  testVector("getMap should return null for null slot", 4,
+    new MapType(IntegerType, IntegerType, false)) { column =>
+    assert(column.numNulls() == 0)
+
+    var idx = 0
+    column.putNull(idx)
+    assert(column.getMap(idx) == null)
+    idx += 1
+    column.putNull(idx)
+    assert(column.getMap(idx) == null)
+    assert(column.numNulls() == 2)
+
+    idx += 1
+    val keyCol = column.getChild(0)
+    keyCol.putInt(0, 0)
+    keyCol.putInt(1, 1)
+    val valueCol = column.getChild(1)
+    valueCol.putInt(0, 0)
+    valueCol.putInt(1, 2)
+
+    column.putArray(idx, 0, 2)
+    assert(column.getMap(idx) != null)
+  }
 }

From 923d0fe042befe722905791fd8dfcb42003f5e15 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Thu, 1 Feb 2018 09:40:16 +0000
Subject: [PATCH 5/7] Add null checks into individual tests if possible.
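
The standalone null tests added in the previous commits are folded into the existing
per-type tests below. The recurring pattern, sketched here for the String APIs test
(column and idx come from the surrounding test):

    column.putNull(idx)
    assert(column.getUTF8String(idx) == null)
    idx += 1

This keeps a single test per type that exercises both the value path and the null path.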
---
 .../vectorized/ColumnarBatchSuite.scala | 112 ++----------------
 1 file changed, 9 insertions(+), 103 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
index ec1e51ad78e7c..688b9053980e6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
@@ -572,7 +572,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
     }
   }
 
-  testVector("String APIs", 6, StringType) {
+  testVector("String APIs", 7, StringType) {
     column =>
       val reference = mutable.ArrayBuffer.empty[String]
 
@@ -619,6 +619,10 @@ class ColumnarBatchSuite extends SparkFunSuite {
       idx += 1
       assert(column.arrayData().elementsAppended == 17 + (s + s).length)
 
+      column.putNull(idx)
+      assert(column.getUTF8String(idx) == null)
+      idx += 1
+
       reference.zipWithIndex.foreach { v =>
         val errMsg = "VectorType=" + column.getClass.getSimpleName
         assert(v._1.length == column.getArrayLength(v._2), errMsg)
@@ -647,6 +651,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
       reference += new CalendarInterval(0, 2000)
 
       column.putNull(2)
+      assert(column.getInterval(2) == null)
       reference += null
 
       months.putInt(3, 20)
@@ -683,6 +688,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
       assert(column.getArray(0).numElements == 1)
       assert(column.getArray(1).numElements == 2)
       assert(column.isNullAt(2))
+      assert(column.getArray(2) == null)
       assert(column.getArray(3).numElements == 0)
       assert(column.getArray(4).numElements == 3)
 
@@ -785,6 +791,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
       column.putArray(0, 0, 1)
       column.putArray(1, 1, 2)
       column.putNull(2)
+      assert(column.getMap(2) == null)
       column.putArray(3, 3, 0)
       column.putArray(4, 3, 3)
 
@@ -821,6 +828,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
       c2.putDouble(0, 3.45)
 
       column.putNull(1)
+      assert(column.getStruct(1) == null)
 
       c1.putInt(2, 456)
       c2.putDouble(2, 5.67)
@@ -1262,64 +1270,6 @@ class ColumnarBatchSuite extends SparkFunSuite {
     allocator.close()
   }
 
-  testVector("getUTF8String should return null for null slot", 4, StringType) {
-    column =>
-      assert(column.numNulls() == 0)
-
-      var idx = 0
-      column.putNull(idx)
-      assert(column.getUTF8String(idx) == null)
-      idx += 1
-      column.putNull(idx)
-      assert(column.getUTF8String(idx) == null)
-      assert(column.numNulls() == 2)
-
-      idx += 1
-      column.putByteArray(idx, "Hello".getBytes(StandardCharsets.UTF_8),
-        0, "Hello".getBytes(StandardCharsets.UTF_8).length)
-      assert(column.getUTF8String(idx) != null)
-  }
-
-  testVector("getInterval should return null for null slot", 4, CalendarIntervalType) {
-    column =>
-      assert(column.numNulls() == 0)
-
-      var idx = 0
-      column.putNull(idx)
-      assert(column.getInterval(idx) == null)
-      idx += 1
-      column.putNull(idx)
-      assert(column.getInterval(idx) == null)
-      assert(column.numNulls() == 2)
-
-      idx += 1
-      val months = column.getChild(0)
-      val microseconds = column.getChild(1)
-      months.putInt(idx, 1)
-      microseconds.putLong(idx, 100)
-      assert(column.getInterval(idx) != null)
-  }
-
-  testVector("getArray should return null for null slot", 4, new ArrayType(IntegerType, true)) {
-    column =>
-      assert(column.numNulls() == 0)
-
-      var idx = 0
-      column.putNull(idx)
-      assert(column.getArray(idx) == null)
-      idx += 1
-      column.putNull(idx)
-      assert(column.getArray(idx) == null)
-      assert(column.numNulls() == 2)
-
-      idx += 1
-      val data = column.arrayData()
-      data.putInt(0, 0)
-      data.putInt(1, 1)
-      column.putArray(idx, 0, 2)
-      assert(column.getArray(idx) != null)
-  }
-
   testVector("getDecimal should return null for null slot", 4, DecimalType.IntDecimal) {
     column =>
       assert(column.numNulls() == 0)
@@ -1337,26 +1287,6 @@ class ColumnarBatchSuite extends SparkFunSuite {
       assert(column.getDecimal(idx, 10, 0) != null)
   }
 
-  testVector("getStruct should return null for null slot", 4,
-    new StructType().add("int", IntegerType).add("double", DoubleType)) { column =>
-    assert(column.numNulls() == 0)
-
-    var idx = 0
-    column.putNull(idx)
-    assert(column.getStruct(idx) == null)
-    idx += 1
-    column.putNull(idx)
-    assert(column.getStruct(idx) == null)
-    assert(column.numNulls() == 2)
-
-    idx += 1
-    val c1 = column.getChild(0)
-    val c2 = column.getChild(1)
-    c1.putInt(idx, 123)
-    c2.putDouble(idx, 3.45)
-    assert(column.getStruct(idx) != null)
-  }
-
   testVector("getBinary should return null for null slot", 4, BinaryType) {
     column =>
       assert(column.numNulls() == 0)
@@ -1373,28 +1303,4 @@ class ColumnarBatchSuite extends SparkFunSuite {
       column.putByteArray(idx, "Hello".getBytes(StandardCharsets.UTF_8))
       assert(column.getBinary(idx) != null)
   }
-
-  testVector("getMap should return null for null slot", 4,
-    new MapType(IntegerType, IntegerType, false)) { column =>
-    assert(column.numNulls() == 0)
-
-    var idx = 0
-    column.putNull(idx)
-    assert(column.getMap(idx) == null)
-    idx += 1
-    column.putNull(idx)
-    assert(column.getMap(idx) == null)
-    assert(column.numNulls() == 2)
-
-    idx += 1
-    val keyCol = column.getChild(0)
-    keyCol.putInt(0, 0)
-    keyCol.putInt(1, 1)
-    val valueCol = column.getChild(1)
-    valueCol.putInt(0, 0)
-    valueCol.putInt(1, 2)
-
-    column.putArray(idx, 0, 2)
-    assert(column.getMap(idx) != null)
-  }
 }

From 369db00ce9d821e757e7b7045a63898d805a06a6 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Thu, 1 Feb 2018 14:49:10 +0000
Subject: [PATCH 6/7] Make the decimal and binary null tests normal column
 vector tests.
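
These two tests now follow the reference-buffer pattern used elsewhere in the suite:
write values (including a null) into the vector while recording the expected results
in a mutable.ArrayBuffer, then compare slot by slot. A condensed sketch of the
decimal case:

    val reference = mutable.ArrayBuffer.empty[Decimal]
    column.putDecimal(0, new Decimal().set(10), 10)
    reference += new Decimal().set(10)
    column.putNull(1)
    reference += null
    reference.zipWithIndex.foreach { case (v, i) =>
      assert(v == column.getDecimal(i, 10, 0))
      if (v == null) assert(column.isNullAt(i))
    }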
---
 .../vectorized/ColumnarBatchSuite.scala | 64 ++++++++++++-----
 1 file changed, 47 insertions(+), 17 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
index 688b9053980e6..772f687526008 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
@@ -1270,37 +1270,67 @@ class ColumnarBatchSuite extends SparkFunSuite {
     allocator.close()
   }
 
-  testVector("getDecimal should return null for null slot", 4, DecimalType.IntDecimal) {
+  testVector("Decimal API", 4, DecimalType.IntDecimal) {
     column =>
-      assert(column.numNulls() == 0)
+
+      val reference = mutable.ArrayBuffer.empty[Decimal]
 
       var idx = 0
-      column.putNull(idx)
-      assert(column.getDecimal(idx, 10, 0) == null)
+      column.putDecimal(idx, new Decimal().set(10), 10)
+      reference += new Decimal().set(10)
       idx += 1
+
+      column.putDecimal(idx, new Decimal().set(20), 10)
+      reference += new Decimal().set(20)
+      idx += 1
+
       column.putNull(idx)
       assert(column.getDecimal(idx, 10, 0) == null)
-      assert(column.numNulls() == 2)
-
+      reference += null
       idx += 1
-      column.putDecimal(idx, new Decimal().set(10), 10)
-      assert(column.getDecimal(idx, 10, 0) != null)
+
+      column.putDecimal(idx, new Decimal().set(30), 10)
+      reference += new Decimal().set(30)
+
+      reference.zipWithIndex.foreach { case (v, i) =>
+        val errMsg = "VectorType=" + column.getClass.getSimpleName
+        assert(v == column.getDecimal(i, 10, 0), errMsg)
+        if (v == null) assert(column.isNullAt(i), errMsg)
+      }
+
+      column.close()
   }
 
-  testVector("getBinary should return null for null slot", 4, BinaryType) {
+  testVector("Binary APIs", 4, BinaryType) {
     column =>
-      assert(column.numNulls() == 0)
+
+      val reference = mutable.ArrayBuffer.empty[String]
 
       var idx = 0
-      column.putNull(idx)
-      assert(column.getBinary(idx) == null)
+      column.putByteArray(idx, "Hello".getBytes(StandardCharsets.UTF_8))
+      reference += "Hello"
       idx += 1
-      column.putNull(idx)
-      assert(column.getBinary(idx) == null)
-      assert(column.numNulls() == 2)
 
+      column.putByteArray(idx, "World".getBytes(StandardCharsets.UTF_8))
+      reference += "World"
       idx += 1
-      column.putByteArray(idx, "Hello".getBytes(StandardCharsets.UTF_8))
-      assert(column.getBinary(idx) != null)
+
+      column.putNull(idx)
+      reference += null
+      idx += 1
+
+      column.putByteArray(idx, "abc".getBytes(StandardCharsets.UTF_8))
+      reference += "abc"
+
+      reference.zipWithIndex.foreach { case (v, i) =>
+        val errMsg = "VectorType=" + column.getClass.getSimpleName
+        if (v != null) {
+          assert(v == new String(column.getBinary(i)), errMsg)
+        } else {
+          assert(column.isNullAt(i), errMsg)
+          assert(column.getBinary(i) == null, errMsg)
+        }
+      }
+
+      column.close()
   }
 }

From 6d5f7ec3e6f25e683628370350cfb865aac29d65 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Thu, 1 Feb 2018 14:55:13 +0000
Subject: [PATCH 7/7] Remove null check.
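
The row-level check is redundant now that WritableColumnVector.getMap performs it.
A sketch of the resulting behavior (the vector setup is illustrative only):

    val vector = new OnHeapColumnVector(4, new MapType(IntegerType, IntegerType, false))
    vector.putNull(0)
    // The null now comes from the vector's own getMap, so the row wrapper
    // can delegate unconditionally.
    assert(vector.getMap(0) == null)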
---
 .../spark/sql/execution/vectorized/MutableColumnarRow.java | 1 -
 1 file changed, 1 deletion(-)

diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
index d4ba31ddf883b..4e4242fe8d9b9 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
@@ -157,7 +157,6 @@ public ColumnarArray getArray(int ordinal) {
 
   @Override
   public ColumnarMap getMap(int ordinal) {
-    if (columns[ordinal].isNullAt(rowId)) return null;
     return columns[ordinal].getMap(rowId);
   }
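
A final sketch tying the two documented contracts together (stringVector and
intVector are hypothetical vectors built as in the tests above): object-typed
getters encode null in their return value, while primitive getters do not.

    // Object-typed slot: null is signalled by the return value itself.
    val s = stringVector.getUTF8String(rowId)
    if (s != null) {
      // safe to use s
    }

    // Primitive slot: null must be checked explicitly, because the int
    // returned for a null slot is undefined.
    if (!intVector.isNullAt(rowId)) {
      val i = intVector.getInt(rowId)
    }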