diff --git a/java/vector/src/main/codegen/templates/NullableValueVectors.java b/java/vector/src/main/codegen/templates/NullableValueVectors.java index 6a9ce65392f..0f38181b4c6 100644 --- a/java/vector/src/main/codegen/templates/NullableValueVectors.java +++ b/java/vector/src/main/codegen/templates/NullableValueVectors.java @@ -125,6 +125,11 @@ public final class ${className} extends BaseDataValueVector implements <#if type } + @Override + public BitVector getValidityVector() { + return bits; + } + @Override public List getFieldInnerVectors() { return innerVectors; @@ -420,7 +425,7 @@ public void copyFromSafe(int fromIndex, int thisIndex, ${valuesName} from){ mutator.fillEmpties(thisIndex); values.copyFromSafe(fromIndex, thisIndex, from); - bits.getMutator().setSafe(thisIndex, 1); + bits.getMutator().setSafeToOne(thisIndex); } public void copyFromSafe(int fromIndex, int thisIndex, ${className} from){ @@ -519,7 +524,7 @@ private Mutator(){ @Override public void setIndexDefined(int index){ - bits.getMutator().set(index, 1); + bits.getMutator().setToOne(index); } /** @@ -537,7 +542,7 @@ public void set(int index, <#if type.major == "VarLen">byte[]<#elseif (type.widt valuesMutator.set(i, emptyByteArray); } - bitsMutator.set(index, 1); + bitsMutator.setToOne(index); valuesMutator.set(index, value); <#if type.major == "VarLen">lastSet = index; } @@ -568,7 +573,7 @@ public void setSafe(int index, byte[] value, int start, int length) { <#else> fillEmpties(index); - bits.getMutator().setSafe(index, 1); + bits.getMutator().setSafeToOne(index); values.getMutator().setSafe(index, value, start, length); setCount++; <#if type.major == "VarLen">lastSet = index; @@ -581,7 +586,7 @@ public void setSafe(int index, ByteBuffer value, int start, int length) { <#else> fillEmpties(index); - bits.getMutator().setSafe(index, 1); + bits.getMutator().setSafeToOne(index); values.getMutator().setSafe(index, value, start, length); setCount++; <#if type.major == "VarLen">lastSet = index; @@ -620,7 +625,7 @@ public void set(int index, ${minor.class}Holder holder){ valuesMutator.set(i, emptyByteArray); } - bits.getMutator().set(index, 1); + bits.getMutator().setToOne(index); valuesMutator.set(index, holder); <#if type.major == "VarLen">lastSet = index; } @@ -670,7 +675,7 @@ public void setSafe(int index, ${minor.class}Holder value) { <#if type.major == "VarLen"> fillEmpties(index); - bits.getMutator().setSafe(index, 1); + bits.getMutator().setSafeToOne(index); values.getMutator().setSafe(index, value); setCount++; <#if type.major == "VarLen">lastSet = index; @@ -681,7 +686,7 @@ public void setSafe(int index, ${minor.javaType!type.javaType} value) { <#if type.major == "VarLen"> fillEmpties(index); - bits.getMutator().setSafe(index, 1); + bits.getMutator().setSafeToOne(index); values.getMutator().setSafe(index, value); setCount++; } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BitVector.java b/java/vector/src/main/java/org/apache/arrow/vector/BitVector.java index 9beabcbe46b..d1e9abe5dd1 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/BitVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/BitVector.java @@ -423,8 +423,8 @@ private Mutator() { * value to set (either 1 or 0) */ public final void set(int index, int value) { - int byteIndex = index >> 3; - int bitIndex = index & 7; + int byteIndex = byteIndex(index); + int bitIndex = bitIndex(index); byte currentByte = data.getByte(byteIndex); byte bitMask = (byte) (1L << bitIndex); if (value != 0) { @@ -432,10 +432,87 @@ public final void set(int index, int value) { } else { currentByte -= (bitMask & currentByte); } + data.setByte(byteIndex, currentByte); + } + /** + * Set the bit at the given index to 1. + * + * @param index position of the bit to set + */ + public final void setToOne(int index) { + int byteIndex = byteIndex(index); + int bitIndex = bitIndex(index); + byte currentByte = data.getByte(byteIndex); + byte bitMask = (byte) (1L << bitIndex); + currentByte |= bitMask; data.setByte(byteIndex, currentByte); } + /** + * set count bits to 1 in data starting at firstBitIndex + * @param data the buffer to set + * @param firstBitIndex the index of the first bit to set + * @param count the number of bits to set + */ + public void setRangeToOne(int firstBitIndex, int count) { + int starByteIndex = byteIndex(firstBitIndex); + final int lastBitIndex = firstBitIndex + count; + final int endByteIndex = byteIndex(lastBitIndex); + final int startByteBitIndex = bitIndex(firstBitIndex); + final int endBytebitIndex = bitIndex(lastBitIndex); + if (count < 8 && starByteIndex == endByteIndex) { + // handles the case where we don't have a first and a last byte + byte bitMask = 0; + for (int i = startByteBitIndex; i < endBytebitIndex; ++i) { + bitMask |= (byte) (1L << i); + } + byte currentByte = data.getByte(starByteIndex); + currentByte |= bitMask; + data.setByte(starByteIndex, currentByte); + } else { + // fill in first byte (if it's not full) + if (startByteBitIndex != 0) { + byte currentByte = data.getByte(starByteIndex); + final byte bitMask = (byte) (0xFFL << startByteBitIndex); + currentByte |= bitMask; + data.setByte(starByteIndex, currentByte); + ++ starByteIndex; + } + + // fill in one full byte at a time + for (int i = starByteIndex; i < endByteIndex; i++) { + data.setByte(i, 0xFF); + } + + // fill in the last byte (if it's not full) + if (endBytebitIndex != 0) { + final int byteIndex = byteIndex(lastBitIndex - endBytebitIndex); + byte currentByte = data.getByte(byteIndex); + final byte bitMask = (byte) (0xFFL >>> ((8 - endBytebitIndex) & 7)); + currentByte |= bitMask; + data.setByte(byteIndex, currentByte); + } + + } + } + + /** + * @param absoluteBitIndex the index of the bit in the buffer + * @return the index of the byte containing that bit + */ + private int byteIndex(int absoluteBitIndex) { + return absoluteBitIndex >> 3; + } + + /** + * @param absoluteBitIndex the index of the bit in the buffer + * @return the index of the bit inside the byte + */ + private int bitIndex(int absoluteBitIndex) { + return absoluteBitIndex & 7; + } + public final void set(int index, BitHolder holder) { set(index, holder.value); } @@ -451,6 +528,13 @@ public void setSafe(int index, int value) { set(index, value); } + public void setSafeToOne(int index) { + while(index >= getValueCapacity()) { + reAlloc(); + } + setToOne(index); + } + public void setSafe(int index, BitHolder holder) { while(index >= getValueCapacity()) { reAlloc(); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/NullableVector.java b/java/vector/src/main/java/org/apache/arrow/vector/NullableVector.java index 0212b3c0d7b..b49e9167c25 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/NullableVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/NullableVector.java @@ -19,5 +19,7 @@ public interface NullableVector extends ValueVector { + BitVector getValidityVector(); + ValueVector getValuesVector(); } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestValueVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestValueVector.java index b33919b2790..774b59e3683 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestValueVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestValueVector.java @@ -30,6 +30,7 @@ import org.apache.arrow.vector.types.Types.MinorType; import org.apache.arrow.vector.types.pojo.Field; import org.junit.After; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -364,6 +365,41 @@ public void testBitVector() { } } + @Test + public void testBitVectorRangeSetAllOnes() { + validateRange(1000, 0, 1000); + validateRange(1000, 0, 1); + validateRange(1000, 1, 2); + validateRange(1000, 5, 6); + validateRange(1000, 5, 10); + validateRange(1000, 5, 150); + validateRange(1000, 5, 27); + for (int i = 0; i < 8; i++) { + for (int j = 0; j < 8; j++) { + validateRange(1000, 10 + i, 27 + j); + validateRange(1000, i, j); + } + } + } + + private void validateRange(int length, int start, int count) { + String desc = "[" + start + ", " + (start + count) + ") "; + try (BitVector bitVector = new BitVector("bits", allocator)) { + bitVector.reset(); + bitVector.allocateNew(length); + bitVector.getMutator().setRangeToOne(start, count); + for (int i = 0; i < start; i++) { + Assert.assertEquals(desc + i, 0, bitVector.getAccessor().get(i)); + } + for (int i = start; i < start + count; i++) { + Assert.assertEquals(desc + i, 1, bitVector.getAccessor().get(i)); + } + for (int i = start + count; i < length; i++) { + Assert.assertEquals(desc + i, 0, bitVector.getAccessor().get(i)); + } + } + } + @Test public void testReAllocNullableFixedWidthVector() { // Create a new value vector for 1024 integers