Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,27 +49,30 @@ public static Field toMessageFormat(Field field, DictionaryProvider provider, Se
return field;
}
DictionaryEncoding encoding = field.getDictionary();
List<Field> children = field.getChildren();
List<Field> children;

List<Field> updatedChildren = new ArrayList<>(children.size());
for (Field child : children) {
updatedChildren.add(toMessageFormat(child, provider, dictionaryIdsUsed));
}

ArrowType type;
if (encoding == null) {
type = field.getType();
children = field.getChildren();
} else {
long id = encoding.getId();
Dictionary dictionary = provider.lookup(id);
if (dictionary == null) {
throw new IllegalArgumentException("Could not find dictionary with ID " + id);
}
type = dictionary.getVectorType();
children = dictionary.getVector().getField().getChildren();

dictionaryIdsUsed.add(id);
}

final List<Field> updatedChildren = new ArrayList<>(children.size());
for (Field child : children) {
updatedChildren.add(toMessageFormat(child, provider, dictionaryIdsUsed));
}

return new Field(field.getName(), new FieldType(field.isNullable(), type, encoding, field.getMetadata()),
updatedChildren);
}
Expand Down Expand Up @@ -115,8 +118,10 @@ public static Field toMemoryFormat(Field field, BufferAllocator allocator, Map<L
}

ArrowType type;
List<Field> fieldChildren = null;
if (encoding == null) {
type = field.getType();
fieldChildren = updatedChildren;
} else {
// re-type the field for in-memory format
type = encoding.getIndexType();
Expand All @@ -127,13 +132,14 @@ public static Field toMemoryFormat(Field field, BufferAllocator allocator, Map<L
if (!dictionaries.containsKey(encoding.getId())) {
// create a new dictionary vector for the values
String dictName = "DICT" + encoding.getId();
Field dictionaryField = new Field(dictName, new FieldType(false, field.getType(), null, null), children);
Field dictionaryField = new Field(dictName,
new FieldType(field.isNullable(), field.getType(), null, null), updatedChildren);
FieldVector dictionaryVector = dictionaryField.createVector(allocator);
dictionaries.put(encoding.getId(), new Dictionary(dictionaryVector, encoding));
}
}

return new Field(field.getName(), new FieldType(field.isNullable(), type, encoding, field.getMetadata()),
updatedChildren);
fieldChildren);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@
import static java.util.Arrays.asList;
import static org.apache.arrow.memory.util.LargeMemoryUtil.checkedCastToInt;
import static org.apache.arrow.vector.TestUtils.newVarCharVector;
import static org.apache.arrow.vector.TestUtils.newVector;
import static org.apache.arrow.vector.testing.ValueVectorDataPopulator.setVector;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;

import java.io.ByteArrayInputStream;
Expand All @@ -41,6 +43,7 @@
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.stream.Collectors;

import org.apache.arrow.flatbuf.FieldNode;
Expand All @@ -55,11 +58,16 @@
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.NullVector;
import org.apache.arrow.vector.TestUtils;
import org.apache.arrow.vector.ValueVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorLoader;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.VectorUnloader;
import org.apache.arrow.vector.compare.Range;
import org.apache.arrow.vector.compare.RangeEqualsVisitor;
import org.apache.arrow.vector.compare.TypeEqualsVisitor;
import org.apache.arrow.vector.compare.VectorEqualsVisitor;
import org.apache.arrow.vector.complex.StructVector;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.dictionary.DictionaryEncoder;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
Expand All @@ -69,6 +77,7 @@
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.ipc.message.IpcOption;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.Types.MinorType;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
import org.apache.arrow.vector.types.pojo.Field;
Expand All @@ -87,10 +96,12 @@ public class TestArrowReaderWriter {
private VarCharVector dictionaryVector1;
private VarCharVector dictionaryVector2;
private VarCharVector dictionaryVector3;
private StructVector dictionaryVector4;

private Dictionary dictionary1;
private Dictionary dictionary2;
private Dictionary dictionary3;
private Dictionary dictionary4;

private Schema schema;
private Schema encodedSchema;
Expand Down Expand Up @@ -119,20 +130,29 @@ public void init() {
"aa".getBytes(StandardCharsets.UTF_8),
"bb".getBytes(StandardCharsets.UTF_8),
"cc".getBytes(StandardCharsets.UTF_8));

dictionaryVector4 = newVector(StructVector.class, "D4", MinorType.STRUCT, allocator);
final Map<String, List<Integer>> dictionaryValues4 = new HashMap<>();
dictionaryValues4.put("a", Arrays.asList(1, 2, 3));
dictionaryValues4.put("b", Arrays.asList(4, 5, 6));
setVector(dictionaryVector4, dictionaryValues4);

dictionary1 = new Dictionary(dictionaryVector1,
new DictionaryEncoding(/*id=*/1L, /*ordered=*/false, /*indexType=*/null));
dictionary2 = new Dictionary(dictionaryVector2,
new DictionaryEncoding(/*id=*/2L, /*ordered=*/false, /*indexType=*/null));
dictionary3 = new Dictionary(dictionaryVector3,
new DictionaryEncoding(/*id=*/1L, /*ordered=*/false, /*indexType=*/null));
dictionary4 = new Dictionary(dictionaryVector4,
new DictionaryEncoding(/*id=*/3L, /*ordered=*/false, /*indexType=*/null));
}

@After
public void terminate() throws Exception {
dictionaryVector1.close();
dictionaryVector2.close();
dictionaryVector3.close();
dictionaryVector4.close();
allocator.close();
}

Expand Down Expand Up @@ -305,6 +325,82 @@ public void testWriteReadWithDictionaries() throws IOException {
}
}

@Test
public void testWriteReadWithStructDictionaries() throws IOException {
DictionaryProvider.MapDictionaryProvider provider =
new DictionaryProvider.MapDictionaryProvider();
provider.put(dictionary4);

try (final StructVector vector =
newVector(StructVector.class, "D4", MinorType.STRUCT, allocator)) {
final Map<String, List<Integer>> values = new HashMap<>();
// Index: 0, 2, 1, 2, 1, 0, 0
values.put("a", Arrays.asList(1, 3, 2, 3, 2, 1, 1));
values.put("b", Arrays.asList(4, 6, 5, 6, 5, 4, 4));
setVector(vector, values);
FieldVector encodedVector = (FieldVector) DictionaryEncoder.encode(vector, dictionary4);

List<Field> fields = Arrays.asList(encodedVector.getField());
List<FieldVector> vectors = Collections2.asImmutableList(encodedVector);
try (
VectorSchemaRoot root =
new VectorSchemaRoot(fields, vectors, encodedVector.getValueCount());
ByteArrayOutputStream out = new ByteArrayOutputStream();
ArrowFileWriter writer = new ArrowFileWriter(root, provider, newChannel(out));) {

writer.start();
writer.writeBatch();
writer.end();

try (
SeekableReadChannel channel = new SeekableReadChannel(
new ByteArrayReadableSeekableByteChannel(out.toByteArray()));
ArrowFileReader reader = new ArrowFileReader(channel, allocator)) {
final VectorSchemaRoot readRoot = reader.getVectorSchemaRoot();
final Schema readSchema = readRoot.getSchema();
assertEquals(root.getSchema(), readSchema);
assertEquals(1, reader.getDictionaryBlocks().size());
assertEquals(1, reader.getRecordBlocks().size());

reader.loadNextBatch();
assertEquals(1, readRoot.getFieldVectors().size());
assertEquals(1, reader.getDictionaryVectors().size());

// Read the encoded vector and check it
final FieldVector readEncoded = readRoot.getVector(0);
assertEquals(encodedVector.getValueCount(), readEncoded.getValueCount());
assertTrue(new RangeEqualsVisitor(encodedVector, readEncoded)
.rangeEquals(new Range(0, 0, encodedVector.getValueCount())));

// Read the dictionary
final Map<Long, Dictionary> readDictionaryMap = reader.getDictionaryVectors();
final Dictionary readDictionary =
readDictionaryMap.get(readEncoded.getField().getDictionary().getId());
assertNotNull(readDictionary);

// Assert the dictionary vector is correct
final FieldVector readDictionaryVector = readDictionary.getVector();
assertEquals(dictionaryVector4.getValueCount(), readDictionaryVector.getValueCount());
final BiFunction<ValueVector, ValueVector, Boolean> typeComparatorIgnoreName =
(v1, v2) -> new TypeEqualsVisitor(v1, false, true).equals(v2);
assertTrue("Dictionary vectors are not equal",
new RangeEqualsVisitor(dictionaryVector4, readDictionaryVector,
typeComparatorIgnoreName)
.rangeEquals(new Range(0, 0, dictionaryVector4.getValueCount())));

// Assert the decoded vector is correct
try (final ValueVector readVector =
DictionaryEncoder.decode(readEncoded, readDictionary)) {
assertEquals(vector.getValueCount(), readVector.getValueCount());
assertTrue("Decoded vectors are not equal",
new RangeEqualsVisitor(vector, readVector, typeComparatorIgnoreName)
.rangeEquals(new Range(0, 0, vector.getValueCount())));
}
}
}
}
}

@Test
public void testEmptyStreamInFileIPC() throws IOException {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.BitVector;
Expand Down Expand Up @@ -60,8 +62,10 @@
import org.apache.arrow.vector.complex.FixedSizeListVector;
import org.apache.arrow.vector.complex.LargeListVector;
import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.complex.StructVector;
import org.apache.arrow.vector.holders.IntervalDayHolder;
import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.Types.MinorType;
import org.apache.arrow.vector.types.pojo.FieldType;

/**
Expand Down Expand Up @@ -673,4 +677,32 @@ public static void setVector(FixedSizeListVector vector, List<Integer>... values
dataVector.setValueCount(curPos);
vector.setValueCount(values.length);
}

/**
* Populate values for {@link StructVector}.
*/
public static void setVector(StructVector vector, Map<String, List<Integer>> values) {
vector.allocateNewSafe();

int valueCount = 0;
for (final Entry<String, List<Integer>> entry : values.entrySet()) {
// Add the child
final IntVector child = vector.addOrGet(entry.getKey(),
FieldType.nullable(MinorType.INT.getType()), IntVector.class);

// Write the values to the child
child.allocateNew();
final List<Integer> v = entry.getValue();
for (int i = 0; i < v.size(); i++) {
if (v.get(i) != null) {
child.set(i, v.get(i));
vector.setIndexDefined(i);
} else {
child.setNull(i);
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we also need to set the value count for the child vector

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. The struct vector sets the value count of its children in #setValueCount.

valueCount = Math.max(valueCount, v.size());
}
vector.setValueCount(valueCount);
}
}