diff --git a/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/VectorValueComparator.java b/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/VectorValueComparator.java index ed32e16ca26..d2c772ca8a8 100644 --- a/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/VectorValueComparator.java +++ b/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/VectorValueComparator.java @@ -41,6 +41,18 @@ public abstract class VectorValueComparator { */ protected int valueWidth; + + private boolean checkNullsOnCompare = true; + + /** + * Reports whether {@code compare(i1, i2)} performs null checks. The flag is {@code true} by default and is + * re-computed in {@code attachVectors}; it is {@code false} only when both vectors are non-empty and are either + * non-nullable or have a null count of zero, so that {@code compare} can go straight to {@code compareNotNull}. + */ + public boolean checkNullsOnCompare() { + return this.checkNullsOnCompare; + } + /** * Constructor for variable-width vectors. */ @@ -76,6 +88,21 @@ public void attachVector(V vector) { public void attachVectors(V vector1, V vector2) { this.vector1 = vector1; this.vector2 = vector2; + + final boolean v1MayHaveNulls = mayHaveNulls(vector1); + final boolean v2MayHaveNulls = mayHaveNulls(vector2); + + this.checkNullsOnCompare = v1MayHaveNulls || v2MayHaveNulls; + } + + private boolean mayHaveNulls(V v) { + if (v.getValueCount() == 0) { + return true; + } + if (! v.getField().isNullable()) { + return false; + } + return v.getNullCount() > 0; } /** @@ -87,17 +114,19 @@ public void attachVectors(V vector1, V vector2) { * values are equal. 
*/ public int compare(int index1, int index2) { - boolean isNull1 = vector1.isNull(index1); - boolean isNull2 = vector2.isNull(index2); - - if (isNull1 || isNull2) { - if (isNull1 && isNull2) { - return 0; - } else if (isNull1) { - // null is smaller - return -1; - } else { - return 1; + if (checkNullsOnCompare) { + boolean isNull1 = vector1.isNull(index1); + boolean isNull2 = vector2.isNull(index2); + + if (isNull1 || isNull2) { + if (isNull1 && isNull2) { + return 0; + } else if (isNull1) { + // null is smaller + return -1; + } else { + return 1; + } } } return compareNotNull(index1, index2); diff --git a/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestDefaultVectorComparator.java b/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestDefaultVectorComparator.java index 2fbf598bf33..818bb60d116 100644 --- a/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestDefaultVectorComparator.java +++ b/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestDefaultVectorComparator.java @@ -19,6 +19,7 @@ import static org.apache.arrow.vector.complex.BaseRepeatedValueVector.OFFSET_WIDTH; import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; import org.apache.arrow.memory.BufferAllocator; @@ -34,6 +35,7 @@ import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.testing.ValueVectorDataPopulator; import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.FieldType; import org.junit.After; import org.junit.Before; @@ -390,4 +392,76 @@ public void testCompareByte() { assertTrue(comparator.compare(7, 7) == 0); } } + + @Test + public void testCheckNullsOnCompareIsFalseForNonNullableVector() { + try (IntVector vec = new IntVector("not nullable", + FieldType.notNullable(new ArrowType.Int(32, false)), allocator)) { + 
ValueVectorDataPopulator.setVector(vec, 1, 2, 3, 4); + + final VectorValueComparator comparator = DefaultVectorComparators.createDefaultComparator(vec); + comparator.attachVector(vec); + + assertFalse(comparator.checkNullsOnCompare()); + } + } + + @Test + public void testCheckNullsOnCompareIsTrueForNullableVector() { + try (IntVector vec = new IntVector("nullable", FieldType.nullable( + new ArrowType.Int(32, false)), allocator); + IntVector vec2 = new IntVector("not-nullable", FieldType.notNullable( + new ArrowType.Int(32, false)), allocator) + ) { + + ValueVectorDataPopulator.setVector(vec, 1, null, 3, 4); + ValueVectorDataPopulator.setVector(vec2, 1, 2, 3, 4); + + final VectorValueComparator comparator = DefaultVectorComparators.createDefaultComparator(vec); + comparator.attachVector(vec); + assertTrue(comparator.checkNullsOnCompare()); + + comparator.attachVectors(vec, vec2); + assertTrue(comparator.checkNullsOnCompare()); + } + } + + @Test + public void testCheckNullsOnCompareIsFalseWithNoNulls() { + try (IntVector vec = new IntVector("nullable", FieldType.nullable( + new ArrowType.Int(32, false)), allocator); + IntVector vec2 = new IntVector("also-nullable", FieldType.nullable( + new ArrowType.Int(32, false)), allocator) + ) { + + // both vectors are nullable by type but are populated without any null values + ValueVectorDataPopulator.setVector(vec, 1, 2, 3, 4); + ValueVectorDataPopulator.setVector(vec2, 1, 2, 3, 4); + + final VectorValueComparator comparator = DefaultVectorComparators.createDefaultComparator(vec); + comparator.attachVector(vec); + assertFalse(comparator.checkNullsOnCompare()); + + comparator.attachVectors(vec, vec2); + assertFalse(comparator.checkNullsOnCompare()); + } + } + + @Test + public void testCheckNullsOnCompareIsTrueWithEmptyVectors() { + try (IntVector vec = new IntVector("nullable", FieldType.nullable( + new ArrowType.Int(32, false)), allocator); + IntVector vec2 = new IntVector("also-nullable", FieldType.nullable( + new 
ArrowType.Int(32, false)), allocator) + ) { + + final 
VectorValueComparator comparator = DefaultVectorComparators.createDefaultComparator(vec); + comparator.attachVector(vec2); + assertTrue(comparator.checkNullsOnCompare()); + + comparator.attachVectors(vec, vec2); + assertTrue(comparator.checkNullsOnCompare()); + } + } }