diff --git a/java/vector/src/main/java/org/apache/arrow/vector/compare/ApproxEqualsVisitor.java b/java/vector/src/main/java/org/apache/arrow/vector/compare/ApproxEqualsVisitor.java index 6e74c212116..4b4000cef8e 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/compare/ApproxEqualsVisitor.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/compare/ApproxEqualsVisitor.java @@ -69,8 +69,14 @@ public void setDoubleDiffFunction(DiffFunctionDouble doubleDiffFunction) { @Override public Boolean visit(BaseFixedWidthVector left, Range range) { if (left instanceof Float4Vector) { + if (!validate(left)) { + return false; + } return float4ApproxEquals(range); } else if (left instanceof Float8Vector) { + if (!validate(left)) { + return false; + } return float8ApproxEquals(range); } else { return super.visit(left, range); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java b/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java index 5d43031ffb5..d6c9ac7b4f3 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java @@ -58,7 +58,11 @@ public RangeEqualsVisitor(ValueVector left, ValueVector right, boolean isTypeChe Preconditions.checkArgument(right != null, "right vector cannot be null"); - // types cannot change for a visitor instance. so, the check is done only once. + // type usually checks only once unless the left vector is changed. + checkType(); + } + + private void checkType() { if (!isTypeCheckNeeded) { typeCompareResult = true; } else if (left == right) { @@ -68,6 +72,17 @@ public RangeEqualsVisitor(ValueVector left, ValueVector right, boolean isTypeChe } } + /** + * Validate the passed left vector, if it is changed, reset and check type. + */ + protected boolean validate(ValueVector left) { + if (left != this.left) { + this.left = left; + checkType(); + } + return typeCompareResult; + } + /** * Constructs a new instance. * @@ -79,7 +94,7 @@ public RangeEqualsVisitor(ValueVector left, ValueVector right) { } /** - * Check range equals without passing IN param in VectorVisitor. + * Check range equals. */ public boolean rangeEquals(Range range) { if (!typeCompareResult) { @@ -107,42 +122,59 @@ public ValueVector getRight() { return right; } - public boolean isTypeCheckNeeded() { - return isTypeCheckNeeded; - } - @Override public Boolean visit(BaseFixedWidthVector left, Range range) { + if (!validate(left)) { + return false; + } return compareBaseFixedWidthVectors(range); } @Override public Boolean visit(BaseVariableWidthVector left, Range range) { + if (!validate(left)) { + return false; + } return compareBaseVariableWidthVectors(range); } @Override public Boolean visit(ListVector left, Range range) { + if (!validate(left)) { + return false; + } return compareListVectors(range); } @Override public Boolean visit(FixedSizeListVector left, Range range) { + if (!validate(left)) { + return false; + } return compareFixedSizeListVectors(range); } @Override public Boolean visit(NonNullableStructVector left, Range range) { + if (!validate(left)) { + return false; + } return compareStructVectors(range); } @Override public Boolean visit(UnionVector left, Range range) { + if (!validate(left)) { + return false; + } return compareUnionVectors(range); } @Override public Boolean visit(ZeroVector left, Range range) { + if (!validate(left)) { + return false; + } return true; } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/compare/TestRangeEqualsVisitor.java b/java/vector/src/test/java/org/apache/arrow/vector/compare/TestRangeEqualsVisitor.java index 04d73e231cd..6cfd70ddc77 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/compare/TestRangeEqualsVisitor.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/compare/TestRangeEqualsVisitor.java @@ -86,6 +86,31 @@ public void testIntVectorEqualsWithNull() { } } + @Test + public void testEqualsWithTypeChange() { + try (final IntVector vector1 = new IntVector("intVector1", allocator); + final IntVector vector2 = new IntVector("intVector2", allocator); + final BigIntVector vector3 = new BigIntVector("bigIntVector", allocator)) { + + vector1.allocateNew(2); + vector1.setValueCount(2); + vector2.allocateNew(2); + vector2.setValueCount(2); + + vector1.setSafe(0, 1); + vector1.setSafe(1, 2); + + vector2.setSafe(0, 1); + vector2.setSafe(1, 2); + + RangeEqualsVisitor visitor = new RangeEqualsVisitor(vector1, vector2); + Range range = new Range(0, 0, 2); + assertTrue(vector1.accept(visitor, range)); + // visitor left vector changed, will reset and check type again + assertFalse(vector3.accept(visitor, range)); + } + } + @Test public void testBaseFixedWidthVectorRangeEqual() { try (final IntVector vector1 = new IntVector("int", allocator);