Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,16 @@ public static ValueVector decode(ValueVector indices, Dictionary dictionary, Buf
// copy the dictionary values into the decoded vector
TransferPair transfer = dictionaryVector.getTransferPair(allocator);
transfer.getTo().allocateNewSafe();

BaseIntVector baseIntVector = (BaseIntVector) indices;
retrieveIndexVector(baseIntVector, transfer, dictionaryCount, 0, count);
ValueVector decoded = transfer.getTo();
decoded.setValueCount(count);
return decoded;
try {
BaseIntVector baseIntVector = (BaseIntVector) indices;
retrieveIndexVector(baseIntVector, transfer, dictionaryCount, 0, count);
ValueVector decoded = transfer.getTo();
decoded.setValueCount(count);
return decoded;
} catch (Exception e) {
transfer.getTo().close();
throw e;
}
}

/**
Expand Down Expand Up @@ -192,10 +196,14 @@ public ValueVector encode(ValueVector vector) {

BaseIntVector indices = (BaseIntVector) createdVector;
indices.allocateNew();

buildIndexVector(vector, indices, hashTable, 0, vector.getValueCount());
indices.setValueCount(vector.getValueCount());
return indices;
try {
buildIndexVector(vector, indices, hashTable, 0, vector.getValueCount());
indices.setValueCount(vector.getValueCount());
return indices;
} catch (Exception e) {
indices.close();
throw e;
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,9 @@

package org.apache.arrow.vector;

import static org.apache.arrow.vector.TestUtils.newVarBinaryVector;
import static org.apache.arrow.vector.TestUtils.newVarCharVector;
import static org.apache.arrow.vector.TestUtils.*;
import static org.apache.arrow.vector.testing.ValueVectorDataPopulator.setVector;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.*;

import java.nio.charset.StandardCharsets;
import java.util.Arrays;
Expand Down Expand Up @@ -881,6 +877,45 @@ public void testEncodeStructSubFieldWithCertainColumns() {
}
}

@Test
public void testNoMemoryLeak() {
// test no memory leak when encode
try (final VarCharVector vector = newVarCharVector("foo", allocator);
final VarCharVector dictionaryVector = newVarCharVector("dict", allocator)) {

setVector(vector, zero, one, two);
setVector(dictionaryVector, zero, one);

Dictionary dictionary =
new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null));

try (final ValueVector encoded = DictionaryEncoder.encode(vector, dictionary)) {
fail("There should be an exception when encoding");
} catch (Exception e) {
assertEquals("Dictionary encoding not defined for value:" + new Text(two), e.getMessage());
}
}
assertEquals("encode memory leak", 0, allocator.getAllocatedMemory());

// test no memory leak when decode
try (final IntVector indices = newVector(IntVector.class, "", Types.MinorType.INT, allocator);
final VarCharVector dictionaryVector = newVarCharVector("dict", allocator)) {

setVector(indices, 3);
setVector(dictionaryVector, zero, one);

Dictionary dictionary =
new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null));

try (final ValueVector decoded = DictionaryEncoder.decode(indices, dictionary, allocator)) {
fail("There should be an exception when decoding");
} catch (Exception e) {
assertEquals("Provided dictionary does not contain value for index 3", e.getMessage());
}
}
assertEquals("decode memory leak", 0, allocator.getAllocatedMemory());
}

private void testDictionary(Dictionary dictionary, ToIntBiFunction<ValueVector, Integer> valGetter) {
try (VarCharVector vector = new VarCharVector("vector", allocator)) {
setVector(vector, "1", "3", "5", "7", "9");
Expand Down