From 315b8312c3201f7856df77fa401d48ef4d0ac07e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=AD=E5=B3=B0?= Date: Sun, 11 Dec 2022 21:40:16 +0800 Subject: [PATCH 1/2] prevent potential memory leak --- .../dictionary/ListSubfieldEncoder.java | 58 ++++---- .../dictionary/StructSubfieldEncoder.java | 94 +++++++------ .../arrow/vector/TestDictionaryVector.java | 126 ++++++++++++++++++ 3 files changed, 212 insertions(+), 66 deletions(-) diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/ListSubfieldEncoder.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/ListSubfieldEncoder.java index 00d7c8af179..f6a12a8f833 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/ListSubfieldEncoder.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/ListSubfieldEncoder.java @@ -85,20 +85,25 @@ public BaseListVector encodeListSubField(BaseListVector vector) { // clone list vector and initialize data vector BaseListVector encoded = cloneVector(vector, allocator); - encoded.initializeChildrenFromFields(Collections.singletonList(valueField)); - BaseIntVector indices = (BaseIntVector) getDataVector(encoded); - - ValueVector dataVector = getDataVector(vector); - for (int i = 0; i < valueCount; i++) { - if (!vector.isNull(i)) { - int start = vector.getElementStartIndex(i); - int end = vector.getElementEndIndex(i); - - DictionaryEncoder.buildIndexVector(dataVector, indices, hashTable, start, end); + try { + encoded.initializeChildrenFromFields(Collections.singletonList(valueField)); + BaseIntVector indices = (BaseIntVector) getDataVector(encoded); + + ValueVector dataVector = getDataVector(vector); + for (int i = 0; i < valueCount; i++) { + if (!vector.isNull(i)) { + int start = vector.getElementStartIndex(i); + int end = vector.getElementEndIndex(i); + + DictionaryEncoder.buildIndexVector(dataVector, indices, hashTable, start, end); + } } - } - return encoded; + return encoded; + } catch (Exception e) { + encoded.close(); + throw e; + } } /** @@ -132,24 +137,29 @@ public static BaseListVector decodeListSubField(BaseListVector vector, // clone list vector and initialize data vector BaseListVector decoded = cloneVector(vector, allocator); - Field dataVectorField = getDataVector(dictionaryVector).getField(); - decoded.initializeChildrenFromFields(Collections.singletonList(dataVectorField)); + try { + Field dataVectorField = getDataVector(dictionaryVector).getField(); + decoded.initializeChildrenFromFields(Collections.singletonList(dataVectorField)); - // get data vector - ValueVector dataVector = getDataVector(decoded); + // get data vector + ValueVector dataVector = getDataVector(decoded); - TransferPair transfer = getDataVector(dictionaryVector).makeTransferPair(dataVector); - BaseIntVector indices = (BaseIntVector) getDataVector(vector); + TransferPair transfer = getDataVector(dictionaryVector).makeTransferPair(dataVector); + BaseIntVector indices = (BaseIntVector) getDataVector(vector); - for (int i = 0; i < valueCount; i++) { + for (int i = 0; i < valueCount; i++) { - if (!vector.isNull(i)) { - int start = vector.getElementStartIndex(i); - int end = vector.getElementEndIndex(i); + if (!vector.isNull(i)) { + int start = vector.getElementStartIndex(i); + int end = vector.getElementEndIndex(i); - DictionaryEncoder.retrieveIndexVector(indices, transfer, dictionaryValueCount, start, end); + DictionaryEncoder.retrieveIndexVector(indices, transfer, dictionaryValueCount, start, end); + } } + return decoded; + } catch (Exception e) { + decoded.close(); + throw e; } - return decoded; } } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/StructSubfieldEncoder.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/StructSubfieldEncoder.java index d19e261490d..2f58067c2f2 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/StructSubfieldEncoder.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/StructSubfieldEncoder.java @@ -118,23 +118,28 @@ public StructVector encode(StructVector vector, Map columnToDicti // clone list vector and initialize data vector StructVector encoded = cloneVector(vector, allocator); - encoded.initializeChildrenFromFields(childrenFields); - encoded.setValueCount(valueCount); - - for (int index = 0; index < childCount; index++) { - FieldVector childVector = getChildVector(vector, index); - FieldVector encodedChildVector = getChildVector(encoded, index); - Long dictionaryId = columnToDictionaryId.get(index); - if (dictionaryId != null) { - BaseIntVector indices = (BaseIntVector) encodedChildVector; - DictionaryEncoder.buildIndexVector(childVector, indices, dictionaryIdToHashTable.get(dictionaryId), - 0, valueCount); - } else { - childVector.makeTransferPair(encodedChildVector).splitAndTransfer(0, valueCount); + try { + encoded.initializeChildrenFromFields(childrenFields); + encoded.setValueCount(valueCount); + + for (int index = 0; index < childCount; index++) { + FieldVector childVector = getChildVector(vector, index); + FieldVector encodedChildVector = getChildVector(encoded, index); + Long dictionaryId = columnToDictionaryId.get(index); + if (dictionaryId != null) { + BaseIntVector indices = (BaseIntVector) encodedChildVector; + DictionaryEncoder.buildIndexVector(childVector, indices, dictionaryIdToHashTable.get(dictionaryId), + 0, valueCount); + } else { + childVector.makeTransferPair(encodedChildVector).splitAndTransfer(0, valueCount); + } } - } - return encoded; + return encoded; + } catch (Exception e) { + encoded.close(); + throw e; + } } /** @@ -167,36 +172,41 @@ public static StructVector decode(StructVector vector, // clone list vector and initialize child vectors StructVector decoded = cloneVector(vector, allocator); - List childFields = new ArrayList<>(); - for (int i = 0; i < childCount; i++) { - FieldVector childVector = getChildVector(vector, i); - Dictionary dictionary = getChildVectorDictionary(childVector, provider); - // childVector is not encoded. - if (dictionary == null) { - childFields.add(childVector.getField()); - } else { - childFields.add(dictionary.getVector().getField()); + try { + List childFields = new ArrayList<>(); + for (int i = 0; i < childCount; i++) { + FieldVector childVector = getChildVector(vector, i); + Dictionary dictionary = getChildVectorDictionary(childVector, provider); + // childVector is not encoded. + if (dictionary == null) { + childFields.add(childVector.getField()); + } else { + childFields.add(dictionary.getVector().getField()); + } } - } - decoded.initializeChildrenFromFields(childFields); - decoded.setValueCount(valueCount); - - for (int index = 0; index < childCount; index++) { - // get child vector - FieldVector childVector = getChildVector(vector, index); - FieldVector decodedChildVector = getChildVector(decoded, index); - Dictionary dictionary = getChildVectorDictionary(childVector, provider); - if (dictionary == null) { - childVector.makeTransferPair(decodedChildVector).splitAndTransfer(0, valueCount); - } else { - TransferPair transfer = dictionary.getVector().makeTransferPair(decodedChildVector); - BaseIntVector indices = (BaseIntVector) childVector; - - DictionaryEncoder.retrieveIndexVector(indices, transfer, valueCount, 0, valueCount); + decoded.initializeChildrenFromFields(childFields); + decoded.setValueCount(valueCount); + + for (int index = 0; index < childCount; index++) { + // get child vector + FieldVector childVector = getChildVector(vector, index); + FieldVector decodedChildVector = getChildVector(decoded, index); + Dictionary dictionary = getChildVectorDictionary(childVector, provider); + if (dictionary == null) { + childVector.makeTransferPair(decodedChildVector).splitAndTransfer(0, valueCount); + } else { + TransferPair transfer = dictionary.getVector().makeTransferPair(decodedChildVector); + BaseIntVector indices = (BaseIntVector) childVector; + + DictionaryEncoder.retrieveIndexVector(indices, transfer, valueCount, 0, valueCount); + } } - } - return decoded; + return decoded; + } catch (Exception e) { + decoded.close(); + throw e; + } } /** diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java index 7e188185bcf..f65e95d506a 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java @@ -916,6 +916,132 @@ public void testNoMemoryLeak() { assertEquals("decode memory leak", 0, allocator.getAllocatedMemory()); } + @Test + public void testListNoMemoryLeak() { + // Create a new value vector + try (final ListVector vector = ListVector.empty("vector", allocator); + final ListVector dictionaryVector = ListVector.empty("dict", allocator)) { + + UnionListWriter writer = vector.getWriter(); + writer.allocate(); + writeListVector(writer, new int[]{10, 20}); + writer.setValueCount(1); + + UnionListWriter dictWriter = dictionaryVector.getWriter(); + dictWriter.allocate(); + writeListVector(dictWriter, new int[]{10}); + dictionaryVector.setValueCount(1); + + Dictionary dictionary = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null)); + ListSubfieldEncoder encoder = new ListSubfieldEncoder(dictionary, allocator); + + try (final ListVector encoded = (ListVector) encoder.encodeListSubField(vector)) { + fail("There should be an exception when encoding"); + } catch (Exception e) { + assertEquals("Dictionary encoding not defined for value:" + 20, e.getMessage()); + } + } + assertEquals("list encode memory leak", 0, allocator.getAllocatedMemory()); + + try (final ListVector indices = ListVector.empty("indices", allocator); + final ListVector dictionaryVector = ListVector.empty("dict", allocator)) { + + UnionListWriter writer = indices.getWriter(); + writer.allocate(); + writeListVector(writer, new int[]{3}); + writer.setValueCount(1); + + UnionListWriter dictWriter = dictionaryVector.getWriter(); + dictWriter.allocate(); + writeListVector(dictWriter, new int[]{10, 20}); + dictionaryVector.setValueCount(1); + + Dictionary dictionary = + new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null)); + + try (final ValueVector decoded = ListSubfieldEncoder.decodeListSubField(indices, dictionary, allocator)) { + fail("There should be an exception when decoding"); + } catch (Exception e) { + assertEquals("Provided dictionary does not contain value for index 3", e.getMessage()); + } + } + assertEquals("list decode memory leak", 0, allocator.getAllocatedMemory()); + } + + @Test + public void testStructNoMemoryLeak() { + try (final StructVector vector = StructVector.empty("vector", allocator); + final VarCharVector dictVector1 = new VarCharVector("f0", allocator); + final VarCharVector dictVector2 = new VarCharVector("f1", allocator)) { + + vector.addOrGet("f0", FieldType.nullable(ArrowType.Utf8.INSTANCE), VarCharVector.class); + vector.addOrGet("f1", FieldType.nullable(ArrowType.Utf8.INSTANCE), VarCharVector.class); + + NullableStructWriter writer = vector.getWriter(); + writer.allocate(); + writeStructVector(writer, "aa", "baz"); + writer.setValueCount(1); + + DictionaryProvider.MapDictionaryProvider provider = new DictionaryProvider.MapDictionaryProvider(); + setVector(dictVector1, + "aa".getBytes(StandardCharsets.UTF_8)); + setVector(dictVector2, + "foo".getBytes(StandardCharsets.UTF_8)); + + provider.put(new Dictionary(dictVector1, new DictionaryEncoding(1L, false, null))); + provider.put(new Dictionary(dictVector2, new DictionaryEncoding(2L, false, null))); + + StructSubfieldEncoder encoder = new StructSubfieldEncoder(allocator, provider); + Map columnToDictionaryId = new HashMap<>(); + columnToDictionaryId.put(0, 1L); + columnToDictionaryId.put(1, 2L); + + try (final StructVector encoded = (StructVector) encoder.encode(vector, columnToDictionaryId)) { + fail("There should be an exception when encoding"); + } catch (Exception e) { + assertEquals("Dictionary encoding not defined for value:" + "baz", e.getMessage()); + } + } + assertEquals("struct encode memory leak", 0, allocator.getAllocatedMemory()); + + try (final StructVector indices = StructVector.empty("indices", allocator); + final VarCharVector dictVector1 = new VarCharVector("f0", allocator); + final VarCharVector dictVector2 = new VarCharVector("f1", allocator)) { + + DictionaryProvider.MapDictionaryProvider provider = new DictionaryProvider.MapDictionaryProvider(); + setVector(dictVector1, + "aa".getBytes(StandardCharsets.UTF_8)); + setVector(dictVector2, + "foo".getBytes(StandardCharsets.UTF_8)); + + provider.put(new Dictionary(dictVector1, new DictionaryEncoding(1L, false, null))); + provider.put(new Dictionary(dictVector2, new DictionaryEncoding(2L, false, null))); + + ArrowType int32 = new ArrowType.Int(32, true); + indices.addOrGet("f0", + new FieldType(true, int32, provider.lookup(1L).getEncoding()), + IntVector.class); + indices.addOrGet("f1", + new FieldType(true, int32, provider.lookup(2L).getEncoding()), + IntVector.class); + + NullableStructWriter writer = indices.getWriter(); + writer.allocate(); + writer.start(); + writer.integer("f0").writeInt(1); + writer.integer("f1").writeInt(3); + writer.end(); + writer.setValueCount(1); + + try (final StructVector decode = StructSubfieldEncoder.decode(indices, provider, allocator)) { + fail("There should be an exception when decoding"); + } catch (Exception e) { + assertEquals("Provided dictionary does not contain value for index 3", e.getMessage()); + } + } + assertEquals("struct decode memory leak", 0, allocator.getAllocatedMemory()); + } + private void testDictionary(Dictionary dictionary, ToIntBiFunction valGetter) { try (VarCharVector vector = new VarCharVector("vector", allocator)) { setVector(vector, "1", "3", "5", "7", "9"); From e1f9f16cfae6eed3a795885ace6b2dcd7bc6dac9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=AD=E5=B3=B0?= Date: Tue, 13 Dec 2022 01:12:53 +0800 Subject: [PATCH 2/2] review iter --- .../apache/arrow/vector/dictionary/DictionaryEncoder.java | 5 +++-- .../apache/arrow/vector/dictionary/ListSubfieldEncoder.java | 5 +++-- .../arrow/vector/dictionary/StructSubfieldEncoder.java | 5 +++-- .../java/org/apache/arrow/vector/TestDictionaryVector.java | 4 ++-- 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java index 63ae41379e2..c44d106f536 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java @@ -20,6 +20,7 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.util.hash.ArrowBufHasher; import org.apache.arrow.memory.util.hash.SimpleHasher; +import org.apache.arrow.util.AutoCloseables; import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.BaseIntVector; import org.apache.arrow.vector.FieldVector; @@ -101,7 +102,7 @@ public static ValueVector decode(ValueVector indices, Dictionary dictionary, Buf decoded.setValueCount(count); return decoded; } catch (Exception e) { - transfer.getTo().close(); + AutoCloseables.close(e, transfer.getTo()); throw e; } } @@ -201,7 +202,7 @@ public ValueVector encode(ValueVector vector) { indices.setValueCount(vector.getValueCount()); return indices; } catch (Exception e) { - indices.close(); + AutoCloseables.close(e, indices); throw e; } } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/ListSubfieldEncoder.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/ListSubfieldEncoder.java index f6a12a8f833..7f3514798d9 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/ListSubfieldEncoder.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/ListSubfieldEncoder.java @@ -22,6 +22,7 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.util.hash.ArrowBufHasher; import org.apache.arrow.memory.util.hash.SimpleHasher; +import org.apache.arrow.util.AutoCloseables; import org.apache.arrow.vector.BaseIntVector; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.ValueVector; @@ -101,7 +102,7 @@ public BaseListVector encodeListSubField(BaseListVector vector) { return encoded; } catch (Exception e) { - encoded.close(); + AutoCloseables.close(e, encoded); throw e; } } @@ -158,7 +159,7 @@ public static BaseListVector decodeListSubField(BaseListVector vector, } return decoded; } catch (Exception e) { - decoded.close(); + AutoCloseables.close(e, decoded); throw e; } } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/StructSubfieldEncoder.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/StructSubfieldEncoder.java index 2f58067c2f2..8500528a62b 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/StructSubfieldEncoder.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/StructSubfieldEncoder.java @@ -25,6 +25,7 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.util.hash.ArrowBufHasher; import org.apache.arrow.memory.util.hash.SimpleHasher; +import org.apache.arrow.util.AutoCloseables; import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.BaseIntVector; import org.apache.arrow.vector.FieldVector; @@ -137,7 +138,7 @@ public StructVector encode(StructVector vector, Map columnToDicti return encoded; } catch (Exception e) { - encoded.close(); + AutoCloseables.close(e, encoded); throw e; } } @@ -204,7 +205,7 @@ public static StructVector decode(StructVector vector, return decoded; } catch (Exception e) { - decoded.close(); + AutoCloseables.close(e, decoded); throw e; } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java index f65e95d506a..501059733c6 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java @@ -938,7 +938,7 @@ public void testListNoMemoryLeak() { try (final ListVector encoded = (ListVector) encoder.encodeListSubField(vector)) { fail("There should be an exception when encoding"); } catch (Exception e) { - assertEquals("Dictionary encoding not defined for value:" + 20, e.getMessage()); + assertEquals("Dictionary encoding not defined for value:20", e.getMessage()); } } assertEquals("list encode memory leak", 0, allocator.getAllocatedMemory()); @@ -999,7 +999,7 @@ public void testStructNoMemoryLeak() { try (final StructVector encoded = (StructVector) encoder.encode(vector, columnToDictionaryId)) { fail("There should be an exception when encoding"); } catch (Exception e) { - assertEquals("Dictionary encoding not defined for value:" + "baz", e.getMessage()); + assertEquals("Dictionary encoding not defined for value:baz", e.getMessage()); } } assertEquals("struct encode memory leak", 0, allocator.getAllocatedMemory());