diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java index 477b1511f84..10e37a5cb82 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java @@ -94,12 +94,16 @@ public static ValueVector decode(ValueVector indices, Dictionary dictionary, Buf // copy the dictionary values into the decoded vector TransferPair transfer = dictionaryVector.getTransferPair(allocator); transfer.getTo().allocateNewSafe(); - - BaseIntVector baseIntVector = (BaseIntVector) indices; - retrieveIndexVector(baseIntVector, transfer, dictionaryCount, 0, count); - ValueVector decoded = transfer.getTo(); - decoded.setValueCount(count); - return decoded; + try { + BaseIntVector baseIntVector = (BaseIntVector) indices; + retrieveIndexVector(baseIntVector, transfer, dictionaryCount, 0, count); + ValueVector decoded = transfer.getTo(); + decoded.setValueCount(count); + return decoded; + } catch (Exception e) { + transfer.getTo().close(); + throw e; + } } /** @@ -192,10 +196,14 @@ public ValueVector encode(ValueVector vector) { BaseIntVector indices = (BaseIntVector) createdVector; indices.allocateNew(); - - buildIndexVector(vector, indices, hashTable, 0, vector.getValueCount()); - indices.setValueCount(vector.getValueCount()); - return indices; + try { + buildIndexVector(vector, indices, hashTable, 0, vector.getValueCount()); + indices.setValueCount(vector.getValueCount()); + return indices; + } catch (Exception e) { + indices.close(); + throw e; + } } /** diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java index bc6cddf3674..7e188185bcf 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java @@ -17,13 +17,9 @@ package org.apache.arrow.vector; -import static org.apache.arrow.vector.TestUtils.newVarBinaryVector; -import static org.apache.arrow.vector.TestUtils.newVarCharVector; +import static org.apache.arrow.vector.TestUtils.*; import static org.apache.arrow.vector.testing.ValueVectorDataPopulator.setVector; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; import java.nio.charset.StandardCharsets; import java.util.Arrays; @@ -881,6 +877,45 @@ public void testEncodeStructSubFieldWithCertainColumns() { } } + @Test + public void testNoMemoryLeak() { + // test no memory leak when encode + try (final VarCharVector vector = newVarCharVector("foo", allocator); + final VarCharVector dictionaryVector = newVarCharVector("dict", allocator)) { + + setVector(vector, zero, one, two); + setVector(dictionaryVector, zero, one); + + Dictionary dictionary = + new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null)); + + try (final ValueVector encoded = DictionaryEncoder.encode(vector, dictionary)) { + fail("There should be an exception when encoding"); + } catch (Exception e) { + assertEquals("Dictionary encoding not defined for value:" + new Text(two), e.getMessage()); + } + } + assertEquals("encode memory leak", 0, allocator.getAllocatedMemory()); + + // test no memory leak when decode + try (final IntVector indices = newVector(IntVector.class, "", Types.MinorType.INT, allocator); + final VarCharVector dictionaryVector = newVarCharVector("dict", allocator)) { + + setVector(indices, 3); + setVector(dictionaryVector, zero, one); + + Dictionary dictionary = + new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null)); + + try (final ValueVector decoded = DictionaryEncoder.decode(indices, dictionary, allocator)) { + fail("There should be an exception when decoding"); + } catch (Exception e) { + assertEquals("Provided dictionary does not contain value for index 3", e.getMessage()); + } + } + assertEquals("decode memory leak", 0, allocator.getAllocatedMemory()); + } + private void testDictionary(Dictionary dictionary, ToIntBiFunction valGetter) { try (VarCharVector vector = new VarCharVector("vector", allocator)) { setVector(vector, "1", "3", "5", "7", "9");