Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.util.hash.ArrowBufHasher;
import org.apache.arrow.memory.util.hash.SimpleHasher;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.util.Preconditions;
import org.apache.arrow.vector.BaseIntVector;
import org.apache.arrow.vector.FieldVector;
Expand Down Expand Up @@ -101,7 +102,7 @@ public static ValueVector decode(ValueVector indices, Dictionary dictionary, Buf
decoded.setValueCount(count);
return decoded;
} catch (Exception e) {
transfer.getTo().close();
AutoCloseables.close(e, transfer.getTo());
throw e;
}
}
Expand Down Expand Up @@ -201,7 +202,7 @@ public ValueVector encode(ValueVector vector) {
indices.setValueCount(vector.getValueCount());
return indices;
} catch (Exception e) {
indices.close();
AutoCloseables.close(e, indices);
throw e;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.util.hash.ArrowBufHasher;
import org.apache.arrow.memory.util.hash.SimpleHasher;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.BaseIntVector;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.ValueVector;
Expand Down Expand Up @@ -85,20 +86,25 @@ public BaseListVector encodeListSubField(BaseListVector vector) {

// clone list vector and initialize data vector
BaseListVector encoded = cloneVector(vector, allocator);
encoded.initializeChildrenFromFields(Collections.singletonList(valueField));
BaseIntVector indices = (BaseIntVector) getDataVector(encoded);

ValueVector dataVector = getDataVector(vector);
for (int i = 0; i < valueCount; i++) {
if (!vector.isNull(i)) {
int start = vector.getElementStartIndex(i);
int end = vector.getElementEndIndex(i);

DictionaryEncoder.buildIndexVector(dataVector, indices, hashTable, start, end);
try {
encoded.initializeChildrenFromFields(Collections.singletonList(valueField));
BaseIntVector indices = (BaseIntVector) getDataVector(encoded);

ValueVector dataVector = getDataVector(vector);
for (int i = 0; i < valueCount; i++) {
if (!vector.isNull(i)) {
int start = vector.getElementStartIndex(i);
int end = vector.getElementEndIndex(i);

DictionaryEncoder.buildIndexVector(dataVector, indices, hashTable, start, end);
}
}
}

return encoded;
return encoded;
} catch (Exception e) {
AutoCloseables.close(e, encoded);
throw e;
}
}

/**
Expand Down Expand Up @@ -132,24 +138,29 @@ public static BaseListVector decodeListSubField(BaseListVector vector,

// clone list vector and initialize data vector
BaseListVector decoded = cloneVector(vector, allocator);
Field dataVectorField = getDataVector(dictionaryVector).getField();
decoded.initializeChildrenFromFields(Collections.singletonList(dataVectorField));
try {
Field dataVectorField = getDataVector(dictionaryVector).getField();
decoded.initializeChildrenFromFields(Collections.singletonList(dataVectorField));

// get data vector
ValueVector dataVector = getDataVector(decoded);
// get data vector
ValueVector dataVector = getDataVector(decoded);

TransferPair transfer = getDataVector(dictionaryVector).makeTransferPair(dataVector);
BaseIntVector indices = (BaseIntVector) getDataVector(vector);
TransferPair transfer = getDataVector(dictionaryVector).makeTransferPair(dataVector);
BaseIntVector indices = (BaseIntVector) getDataVector(vector);

for (int i = 0; i < valueCount; i++) {
for (int i = 0; i < valueCount; i++) {

if (!vector.isNull(i)) {
int start = vector.getElementStartIndex(i);
int end = vector.getElementEndIndex(i);
if (!vector.isNull(i)) {
int start = vector.getElementStartIndex(i);
int end = vector.getElementEndIndex(i);

DictionaryEncoder.retrieveIndexVector(indices, transfer, dictionaryValueCount, start, end);
DictionaryEncoder.retrieveIndexVector(indices, transfer, dictionaryValueCount, start, end);
}
}
return decoded;
} catch (Exception e) {
AutoCloseables.close(e, decoded);
throw e;
}
return decoded;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.util.hash.ArrowBufHasher;
import org.apache.arrow.memory.util.hash.SimpleHasher;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.util.Preconditions;
import org.apache.arrow.vector.BaseIntVector;
import org.apache.arrow.vector.FieldVector;
Expand Down Expand Up @@ -118,23 +119,28 @@ public StructVector encode(StructVector vector, Map<Integer, Long> columnToDicti

// clone struct vector and initialize child vectors
StructVector encoded = cloneVector(vector, allocator);
encoded.initializeChildrenFromFields(childrenFields);
encoded.setValueCount(valueCount);

for (int index = 0; index < childCount; index++) {
FieldVector childVector = getChildVector(vector, index);
FieldVector encodedChildVector = getChildVector(encoded, index);
Long dictionaryId = columnToDictionaryId.get(index);
if (dictionaryId != null) {
BaseIntVector indices = (BaseIntVector) encodedChildVector;
DictionaryEncoder.buildIndexVector(childVector, indices, dictionaryIdToHashTable.get(dictionaryId),
0, valueCount);
} else {
childVector.makeTransferPair(encodedChildVector).splitAndTransfer(0, valueCount);
try {
encoded.initializeChildrenFromFields(childrenFields);
encoded.setValueCount(valueCount);

for (int index = 0; index < childCount; index++) {
FieldVector childVector = getChildVector(vector, index);
FieldVector encodedChildVector = getChildVector(encoded, index);
Long dictionaryId = columnToDictionaryId.get(index);
if (dictionaryId != null) {
BaseIntVector indices = (BaseIntVector) encodedChildVector;
DictionaryEncoder.buildIndexVector(childVector, indices, dictionaryIdToHashTable.get(dictionaryId),
0, valueCount);
} else {
childVector.makeTransferPair(encodedChildVector).splitAndTransfer(0, valueCount);
}
}
}

return encoded;
return encoded;
} catch (Exception e) {
AutoCloseables.close(e, encoded);
throw e;
}
}

/**
Expand Down Expand Up @@ -167,36 +173,41 @@ public static StructVector decode(StructVector vector,

// clone struct vector and initialize child vectors
StructVector decoded = cloneVector(vector, allocator);
List<Field> childFields = new ArrayList<>();
for (int i = 0; i < childCount; i++) {
FieldVector childVector = getChildVector(vector, i);
Dictionary dictionary = getChildVectorDictionary(childVector, provider);
// childVector is not encoded.
if (dictionary == null) {
childFields.add(childVector.getField());
} else {
childFields.add(dictionary.getVector().getField());
try {
List<Field> childFields = new ArrayList<>();
for (int i = 0; i < childCount; i++) {
FieldVector childVector = getChildVector(vector, i);
Dictionary dictionary = getChildVectorDictionary(childVector, provider);
// childVector is not encoded.
if (dictionary == null) {
childFields.add(childVector.getField());
} else {
childFields.add(dictionary.getVector().getField());
}
}
}
decoded.initializeChildrenFromFields(childFields);
decoded.setValueCount(valueCount);

for (int index = 0; index < childCount; index++) {
// get child vector
FieldVector childVector = getChildVector(vector, index);
FieldVector decodedChildVector = getChildVector(decoded, index);
Dictionary dictionary = getChildVectorDictionary(childVector, provider);
if (dictionary == null) {
childVector.makeTransferPair(decodedChildVector).splitAndTransfer(0, valueCount);
} else {
TransferPair transfer = dictionary.getVector().makeTransferPair(decodedChildVector);
BaseIntVector indices = (BaseIntVector) childVector;

DictionaryEncoder.retrieveIndexVector(indices, transfer, valueCount, 0, valueCount);
decoded.initializeChildrenFromFields(childFields);
decoded.setValueCount(valueCount);

for (int index = 0; index < childCount; index++) {
// get child vector
FieldVector childVector = getChildVector(vector, index);
FieldVector decodedChildVector = getChildVector(decoded, index);
Dictionary dictionary = getChildVectorDictionary(childVector, provider);
if (dictionary == null) {
childVector.makeTransferPair(decodedChildVector).splitAndTransfer(0, valueCount);
} else {
TransferPair transfer = dictionary.getVector().makeTransferPair(decodedChildVector);
BaseIntVector indices = (BaseIntVector) childVector;

DictionaryEncoder.retrieveIndexVector(indices, transfer, valueCount, 0, valueCount);
}
}
}

return decoded;
return decoded;
} catch (Exception e) {
AutoCloseables.close(e, decoded);
throw e;
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -916,6 +916,132 @@ public void testNoMemoryLeak() {
assertEquals("decode memory leak", 0, allocator.getAllocatedMemory());
}

/**
 * Checks that a failing list-subfield encode and a failing list-subfield decode each leave
 * the test allocator with zero allocated bytes once the input vectors have been closed.
 */
@Test
public void testListNoMemoryLeak() {
// Encode case: the expected error message below names value 20, which is written to the
// input vector but not to the dictionary vector.
try (final ListVector vector = ListVector.empty("vector", allocator);
final ListVector dictionaryVector = ListVector.empty("dict", allocator)) {

// NOTE(review): writeListVector is defined elsewhere in this test class; presumably it
// writes a single list entry holding the given ints -- confirm.
UnionListWriter writer = vector.getWriter();
writer.allocate();
writeListVector(writer, new int[]{10, 20});
writer.setValueCount(1);

UnionListWriter dictWriter = dictionaryVector.getWriter();
dictWriter.allocate();
writeListVector(dictWriter, new int[]{10});
dictionaryVector.setValueCount(1);

Dictionary dictionary = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null));
ListSubfieldEncoder encoder = new ListSubfieldEncoder(dictionary, allocator);

// The encode call is expected to throw; reaching fail(...) means it returned normally.
try (final ListVector encoded = (ListVector) encoder.encodeListSubField(vector)) {
fail("There should be an exception when encoding");
} catch (Exception e) {
assertEquals("Dictionary encoding not defined for value:20", e.getMessage());
}
}
// Checked after the try-with-resources above has closed both vectors, so any bytes still
// allocated here were not released by the failed encode.
assertEquals("list encode memory leak", 0, allocator.getAllocatedMemory());

// Decode case: the index vector holds 3, which the expected error message below reports
// as missing from the dictionary.
try (final ListVector indices = ListVector.empty("indices", allocator);
final ListVector dictionaryVector = ListVector.empty("dict", allocator)) {

UnionListWriter writer = indices.getWriter();
writer.allocate();
writeListVector(writer, new int[]{3});
writer.setValueCount(1);

UnionListWriter dictWriter = dictionaryVector.getWriter();
dictWriter.allocate();
writeListVector(dictWriter, new int[]{10, 20});
dictionaryVector.setValueCount(1);

Dictionary dictionary =
new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null));

// The decode call is expected to throw; reaching fail(...) means it returned normally.
try (final ValueVector decoded = ListSubfieldEncoder.decodeListSubField(indices, dictionary, allocator)) {
fail("There should be an exception when decoding");
} catch (Exception e) {
assertEquals("Provided dictionary does not contain value for index 3", e.getMessage());
}
}
// Same check as for encode: nothing may remain allocated after the failed decode.
assertEquals("list decode memory leak", 0, allocator.getAllocatedMemory());
}

/**
 * Checks that a failing struct-subfield encode and a failing struct-subfield decode each
 * leave the test allocator with zero allocated bytes once the input vectors have been closed.
 */
@Test
public void testStructNoMemoryLeak() {
// Encode case: a struct with two Utf8 children, each mapped to its own dictionary.
try (final StructVector vector = StructVector.empty("vector", allocator);
final VarCharVector dictVector1 = new VarCharVector("f0", allocator);
final VarCharVector dictVector2 = new VarCharVector("f1", allocator)) {

vector.addOrGet("f0", FieldType.nullable(ArrowType.Utf8.INSTANCE), VarCharVector.class);
vector.addOrGet("f1", FieldType.nullable(ArrowType.Utf8.INSTANCE), VarCharVector.class);

// NOTE(review): writeStructVector is defined elsewhere in this test class; presumably it
// writes one struct row with "aa" in f0 and "baz" in f1 -- confirm.
NullableStructWriter writer = vector.getWriter();
writer.allocate();
writeStructVector(writer, "aa", "baz");
writer.setValueCount(1);

// Dictionary 1 is built from "aa" and dictionary 2 from "foo"; the expected error message
// below names "baz" as the value without an encoding.
DictionaryProvider.MapDictionaryProvider provider = new DictionaryProvider.MapDictionaryProvider();
setVector(dictVector1,
"aa".getBytes(StandardCharsets.UTF_8));
setVector(dictVector2,
"foo".getBytes(StandardCharsets.UTF_8));

provider.put(new Dictionary(dictVector1, new DictionaryEncoding(1L, false, null)));
provider.put(new Dictionary(dictVector2, new DictionaryEncoding(2L, false, null)));

// Child column 0 is encoded with dictionary id 1, child column 1 with dictionary id 2.
StructSubfieldEncoder encoder = new StructSubfieldEncoder(allocator, provider);
Map<Integer, Long> columnToDictionaryId = new HashMap<>();
columnToDictionaryId.put(0, 1L);
columnToDictionaryId.put(1, 2L);

// The encode call is expected to throw; reaching fail(...) means it returned normally.
try (final StructVector encoded = (StructVector) encoder.encode(vector, columnToDictionaryId)) {
fail("There should be an exception when encoding");
} catch (Exception e) {
assertEquals("Dictionary encoding not defined for value:baz", e.getMessage());
}
}
// Checked after the try-with-resources above has closed all three vectors, so any bytes
// still allocated here were not released by the failed encode.
assertEquals("struct encode memory leak", 0, allocator.getAllocatedMemory());

// Decode case: a struct of int32 index children whose fields carry the dictionary encodings.
try (final StructVector indices = StructVector.empty("indices", allocator);
final VarCharVector dictVector1 = new VarCharVector("f0", allocator);
final VarCharVector dictVector2 = new VarCharVector("f1", allocator)) {

DictionaryProvider.MapDictionaryProvider provider = new DictionaryProvider.MapDictionaryProvider();
setVector(dictVector1,
"aa".getBytes(StandardCharsets.UTF_8));
setVector(dictVector2,
"foo".getBytes(StandardCharsets.UTF_8));

provider.put(new Dictionary(dictVector1, new DictionaryEncoding(1L, false, null)));
provider.put(new Dictionary(dictVector2, new DictionaryEncoding(2L, false, null)));

// Each index child is declared with the encoding looked up from the provider, tying
// f0 to dictionary 1 and f1 to dictionary 2.
ArrowType int32 = new ArrowType.Int(32, true);
indices.addOrGet("f0",
new FieldType(true, int32, provider.lookup(1L).getEncoding()),
IntVector.class);
indices.addOrGet("f1",
new FieldType(true, int32, provider.lookup(2L).getEncoding()),
IntVector.class);

// One row of indices: f0 = 1 and f1 = 3. The expected error message below reports
// index 3 as missing from the dictionary.
NullableStructWriter writer = indices.getWriter();
writer.allocate();
writer.start();
writer.integer("f0").writeInt(1);
writer.integer("f1").writeInt(3);
writer.end();
writer.setValueCount(1);

// The decode call is expected to throw; reaching fail(...) means it returned normally.
try (final StructVector decode = StructSubfieldEncoder.decode(indices, provider, allocator)) {
fail("There should be an exception when decoding");
} catch (Exception e) {
assertEquals("Provided dictionary does not contain value for index 3", e.getMessage());
}
}
// Same check as for encode: nothing may remain allocated after the failed decode.
assertEquals("struct decode memory leak", 0, allocator.getAllocatedMemory());
}

private void testDictionary(Dictionary dictionary, ToIntBiFunction<ValueVector, Integer> valGetter) {
try (VarCharVector vector = new VarCharVector("vector", allocator)) {
setVector(vector, "1", "3", "5", "7", "9");
Expand Down