Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions cpp/src/arrow/ipc/dictionary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <vector>

#include "arrow/array.h"
#include "arrow/array/concatenate.h"
#include "arrow/extension_type.h"
#include "arrow/record_batch.h"
#include "arrow/status.h"
Expand Down Expand Up @@ -142,6 +143,23 @@ Status DictionaryMemo::AddDictionary(int64_t id,
return Status::OK();
}

/// Extend the dictionary stored under `id` with the entries of `dictionary`.
///
/// The currently stored dictionary is fetched with GetDictionary(); any error
/// it reports is returned unchanged and the memo is left untouched. Otherwise
/// the stored values and the delta are concatenated (allocating from `pool`)
/// and the result replaces the entry for `id`.
Status DictionaryMemo::AddDictionaryDelta(int64_t id,
                                          const std::shared_ptr<Array>& dictionary,
                                          MemoryPool* pool) {
  std::shared_ptr<Array> existing;
  RETURN_NOT_OK(GetDictionary(id, &existing));

  // Existing values come first so that previously issued indices stay valid.
  const ArrayVector pieces = {existing, dictionary};
  ARROW_ASSIGN_OR_RAISE(auto merged, Concatenate(pieces, pool));

  id_to_dictionary_[id] = std::move(merged);
  return Status::OK();
}

/// Store `dictionary` under `id`, overwriting whatever was stored there before.
/// Always returns Status::OK().
Status DictionaryMemo::AddOrReplaceDictionary(int64_t id,
                                              const std::shared_ptr<Array>& dictionary) {
  // operator[] creates the entry when `id` is not present yet.
  auto& slot = id_to_dictionary_[id];
  slot = dictionary;
  return Status::OK();
}

DictionaryMemo::DictionaryVector DictionaryMemo::dictionaries() const {
// Sort dictionaries by ascending id. This ensures that, in the case
// of nested dictionaries, inner dictionaries are emitted before outer
Expand Down
10 changes: 10 additions & 0 deletions cpp/src/arrow/ipc/dictionary.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <utility>
#include <vector>

#include "arrow/memory_pool.h"
#include "arrow/status.h"
#include "arrow/util/macros.h"
#include "arrow/util/visibility.h"
Expand Down Expand Up @@ -78,6 +79,15 @@ class ARROW_EXPORT DictionaryMemo {
/// KeyError if that dictionary already exists
Status AddDictionary(int64_t id, const std::shared_ptr<Array>& dictionary);

/// \brief Append a dictionary delta to the memo with a particular id. Returns
/// KeyError if that dictionary does not exist
///
/// \param[in] id the id of the dictionary to extend
/// \param[in] dictionary the delta values, appended after the stored ones
/// \param[in] pool the memory pool used when concatenating the dictionaries
Status AddDictionaryDelta(int64_t id, const std::shared_ptr<Array>& dictionary,
MemoryPool* pool);

/// \brief Add a dictionary to the memo if it does not have one with the id,
/// otherwise, replace the dictionary with the new one.
///
/// \param[in] id the id under which the dictionary is stored
/// \param[in] dictionary the dictionary to store
Status AddOrReplaceDictionary(int64_t id, const std::shared_ptr<Array>& dictionary);

/// \brief The stored dictionaries, in ascending id order.
DictionaryVector dictionaries() const;

Expand Down
5 changes: 3 additions & 2 deletions cpp/src/arrow/ipc/metadata_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1209,15 +1209,16 @@ Result<std::shared_ptr<Buffer>> WriteSparseTensorMessage(
}

Status WriteDictionaryMessage(
int64_t id, int64_t length, int64_t body_length,
int64_t id, bool is_delta, int64_t length, int64_t body_length,
const std::shared_ptr<const KeyValueMetadata>& custom_metadata,
const std::vector<FieldMetadata>& nodes, const std::vector<BufferMetadata>& buffers,
const IpcWriteOptions& options, std::shared_ptr<Buffer>* out) {
FBB fbb;
RecordBatchOffset record_batch;
RETURN_NOT_OK(
MakeRecordBatch(fbb, length, body_length, nodes, buffers, options, &record_batch));
auto dictionary_batch = flatbuf::CreateDictionaryBatch(fbb, id, record_batch).Union();
auto dictionary_batch =
flatbuf::CreateDictionaryBatch(fbb, id, record_batch, is_delta).Union();
return WriteFBMessage(fbb, flatbuf::MessageHeader::DictionaryBatch, dictionary_batch,
body_length, custom_metadata)
.Value(out);
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/arrow/ipc/metadata_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,8 @@ Status WriteFileFooter(const Schema& schema, const std::vector<FileBlock>& dicti
io::OutputStream* out);

// Serialize the metadata for a dictionary batch message into `out`.
//
// `id` is the dictionary id and `is_delta` is the delta flag recorded in the
// message; `length`, `body_length`, `nodes` and `buffers` describe the
// record batch that carries the dictionary values.
Status WriteDictionaryMessage(
    const int64_t id, const bool is_delta, const int64_t length,
    const int64_t body_length,
    const std::shared_ptr<const KeyValueMetadata>& custom_metadata,
    const std::vector<FieldMetadata>& nodes, const std::vector<BufferMetadata>& buffers,
    const IpcWriteOptions& options, std::shared_ptr<Buffer>* out);
Expand Down
151 changes: 151 additions & 0 deletions cpp/src/arrow/ipc/read_write_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1228,6 +1228,157 @@ TEST_P(TestFileFormat, RoundTrip) {
TestZeroLengthRoundTrip(*GetParam(), options);
}

// Build a six-row record batch with two dictionary-encoded string columns,
// "dict1" (int32 indices) and "dict2" (int8 indices), and store it in `out`.
// Both columns use the same value dictionary ["foo", "bar", "baz"].
Status MakeDictionaryBatch(std::shared_ptr<RecordBatch>* out) {
  const int64_t kNumRows = 6;

  // Shared dictionary values
  auto values = ArrayFromJSON(utf8(), R"(["foo", "bar", "baz"])");

  // First column: int32 indices, including one null
  auto wide_type = arrow::dictionary(int32(), utf8());
  auto wide_indices = ArrayFromJSON(int32(), "[1, 2, null, 0, 2, 0]");
  auto wide_column = std::make_shared<DictionaryArray>(wide_type, wide_indices, values);

  // Second column: int8 indices
  auto narrow_type = arrow::dictionary(int8(), utf8());
  auto narrow_indices = ArrayFromJSON(int8(), "[0, 0, 2, 2, 1, 1]");
  auto narrow_column =
      std::make_shared<DictionaryArray>(narrow_type, narrow_indices, values);

  auto batch_schema =
      ::arrow::schema({field("dict1", wide_type), field("dict2", narrow_type)});

  *out = RecordBatch::Make(batch_schema, kNumRows, {wide_column, narrow_column});
  return Status::OK();
}

// A utility that supports reading/writing record batches,
// and manually specifying dictionaries.
//
// Messages are written one payload at a time into an in-memory buffer, so a
// test can decide exactly which dictionary batches (including deltas and
// replacements) appear in the stream and in which order.
//
// The schema is held by reference: the caller must keep it alive for as long
// as the helper is used.
class DictionaryBatchHelper {
 public:
  explicit DictionaryBatchHelper(const Schema& schema) : schema_(schema) {
    buffer_ = *AllocateResizableBuffer(0);
    sink_.reset(new io::BufferOutputStream(buffer_));
    payload_writer_ = *internal::MakePayloadStreamWriter(sink_.get());
  }

  // Start the payload writer and emit the schema message.
  Status Start() {
    RETURN_NOT_OK(payload_writer_->Start());

    // write schema
    IpcPayload payload;
    RETURN_NOT_OK(GetSchemaPayload(schema_, IpcWriteOptions::Defaults(),
                                   &dictionary_memo_, &payload));
    return payload_writer_->WritePayload(payload);
  }

  // Emit one dictionary batch for `dictionary_id`, flagged as a delta or not.
  Status WriteDictionary(int64_t dictionary_id, const std::shared_ptr<Array>& dictionary,
                         bool is_delta) {
    IpcPayload payload;
    RETURN_NOT_OK(GetDictionaryPayload(dictionary_id, is_delta, dictionary,
                                       IpcWriteOptions::Defaults(), &payload));
    RETURN_NOT_OK(payload_writer_->WritePayload(payload));
    return Status::OK();
  }

  // Emit the record batch message only; no dictionary batches are written here.
  Status WriteBatchPayload(const RecordBatch& batch) {
    // write record batch payload only
    IpcPayload payload;
    RETURN_NOT_OK(GetRecordBatchPayload(batch, IpcWriteOptions::Defaults(), &payload));
    return payload_writer_->WritePayload(payload);
  }

  // Close the payload writer, then the underlying output stream.
  Status Close() {
    RETURN_NOT_OK(payload_writer_->Close());
    return sink_->Close();
  }

  // Open a stream reader over everything written so far and read the next
  // (i.e. first) record batch into `out_batch`.
  Status ReadBatch(std::shared_ptr<RecordBatch>* out_batch) {
    auto buf_reader = std::make_shared<io::BufferReader>(buffer_);
    std::shared_ptr<RecordBatchReader> reader;
    ARROW_ASSIGN_OR_RAISE(
        reader, RecordBatchStreamReader::Open(buf_reader, IpcReadOptions::Defaults()));
    return reader->ReadNext(out_batch);
  }

  // Declaration order matters: members are destroyed in reverse order, and
  // payload_writer_ is created from a raw pointer to sink_, so the writer must
  // be declared after the sink in order to be destroyed before it.
  const Schema& schema_;
  DictionaryMemo dictionary_memo_;
  std::shared_ptr<ResizableBuffer> buffer_;
  std::unique_ptr<io::BufferOutputStream> sink_;
  std::unique_ptr<internal::IpcPayloadWriter> payload_writer_;
};

// A dictionary sent as a base batch followed by a delta batch must read back
// as the concatenation of the two.
TEST(TestDictionaryBatch, DictionaryDelta) {
  std::shared_ptr<RecordBatch> expected;
  ASSERT_OK(MakeDictionaryBatch(&expected));

  // The batch's dictionary ["foo", "bar", "baz"], split into base + delta
  auto base_values = ArrayFromJSON(utf8(), R"(["foo", "bar"])");
  auto delta_values = ArrayFromJSON(utf8(), R"(["baz"])");

  DictionaryBatchHelper helper(*expected->schema());
  ASSERT_OK(helper.Start());

  // Dictionary ids 0 and 1 each receive the base, then the delta
  for (int64_t dictionary_id = 0; dictionary_id < 2; ++dictionary_id) {
    ASSERT_OK(helper.WriteDictionary(dictionary_id, base_values, /*is_delta=*/false));
    ASSERT_OK(helper.WriteDictionary(dictionary_id, delta_values, /*is_delta=*/true));
  }

  ASSERT_OK(helper.WriteBatchPayload(*expected));
  ASSERT_OK(helper.Close());

  std::shared_ptr<RecordBatch> actual;
  ASSERT_OK(helper.ReadBatch(&actual));

  ASSERT_BATCHES_EQUAL(*expected, *actual);
}

// A delta for a dictionary id that has no base dictionary earlier in the
// stream must make reading fail with KeyError.
TEST(TestDictionaryBatch, DictionaryDeltaWithUnknownId) {
  std::shared_ptr<RecordBatch> written;
  ASSERT_OK(MakeDictionaryBatch(&written));

  auto base_values = ArrayFromJSON(utf8(), R"(["foo", "bar"])");
  auto delta_values = ArrayFromJSON(utf8(), R"(["baz"])");

  DictionaryBatchHelper helper(*written->schema());
  ASSERT_OK(helper.Start());

  // Dictionary 0 is well-formed: base followed by delta
  const int64_t kKnownId = 0;
  ASSERT_OK(helper.WriteDictionary(kKnownId, base_values, /*is_delta=*/false));
  ASSERT_OK(helper.WriteDictionary(kKnownId, delta_values, /*is_delta=*/true));

  // Dictionary 1 only ever receives a delta, with no base before it
  const int64_t kUnknownId = 1;
  ASSERT_OK(helper.WriteDictionary(kUnknownId, delta_values, /*is_delta=*/true));

  ASSERT_OK(helper.WriteBatchPayload(*written));
  ASSERT_OK(helper.Close());

  std::shared_ptr<RecordBatch> read_back;
  ASSERT_RAISES(KeyError, helper.ReadBatch(&read_back));
}

// A non-delta dictionary batch for an id that already has a dictionary
// replaces it: only the last dictionary written per id is used on read.
TEST(TestDictionaryBatch, DictionaryReplacement) {
  std::shared_ptr<RecordBatch> expected;
  ASSERT_OK(MakeDictionaryBatch(&expected));

  // Same values as the batch's own dictionary
  auto final_values = ArrayFromJSON(utf8(), R"(["foo", "bar", "baz"])");
  // Placeholder dictionaries that are expected to be overwritten
  auto stale_values0 = ArrayFromJSON(utf8(), R"(["foo1", "bar1", "baz1"])");
  auto stale_values1 = ArrayFromJSON(utf8(), R"(["foo2", "bar2", "baz2"])");

  DictionaryBatchHelper helper(*expected->schema());
  ASSERT_OK(helper.Start());

  // For each id, write the stale dictionary first and the final one second
  const int64_t kFirstId = 0;
  ASSERT_OK(helper.WriteDictionary(kFirstId, stale_values0, /*is_delta=*/false));
  ASSERT_OK(helper.WriteDictionary(kFirstId, final_values, /*is_delta=*/false));

  const int64_t kSecondId = 1;
  ASSERT_OK(helper.WriteDictionary(kSecondId, stale_values1, /*is_delta=*/false));
  ASSERT_OK(helper.WriteDictionary(kSecondId, final_values, /*is_delta=*/false));

  ASSERT_OK(helper.WriteBatchPayload(*expected));
  ASSERT_OK(helper.Close());

  std::shared_ptr<RecordBatch> actual;
  ASSERT_OK(helper.ReadBatch(&actual));

  ASSERT_BATCHES_EQUAL(*expected, *actual);
}

TEST_P(TestStreamFormat, RoundTrip) {
TestRoundTrip(*GetParam(), IpcWriteOptions::Defaults());
TestZeroLengthRoundTrip(*GetParam(), IpcWriteOptions::Defaults());
Expand Down
30 changes: 17 additions & 13 deletions cpp/src/arrow/ipc/reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -684,7 +684,10 @@ Status ReadDictionary(const Buffer& metadata, DictionaryMemo* dictionary_memo,
return Status::Invalid("Dictionary record batch must only contain one field");
}
auto dictionary = batch->column(0);
return dictionary_memo->AddDictionary(id, dictionary);
if (dictionary_batch->isDelta()) {
return dictionary_memo->AddDictionaryDelta(id, dictionary, options.memory_pool);
}
return dictionary_memo->AddOrReplaceDictionary(id, dictionary);
}

Status ParseDictionary(const Message& message, DictionaryMemo* dictionary_memo,
Expand All @@ -698,8 +701,7 @@ Status ParseDictionary(const Message& message, DictionaryMemo* dictionary_memo,

// Apply a dictionary batch message encountered while reading a stream to
// `dictionary_memo` by delegating to ParseDictionary.
//
// This previously returned NotImplemented unconditionally; that early return
// made the delegation unreachable and caused every stream containing a
// dictionary batch after the initial ones to fail.
Status UpdateDictionaries(const Message& message, DictionaryMemo* dictionary_memo,
                          const IpcReadOptions& options) {
  return ParseDictionary(message, dictionary_memo, options);
}

// ----------------------------------------------------------------------
Expand Down Expand Up @@ -735,23 +737,25 @@ class RecordBatchStreamReaderImpl : public RecordBatchStreamReader {
return Status::OK();
}

ARROW_ASSIGN_OR_RAISE(std::unique_ptr<Message> message,
message_reader_->ReadNextMessage());
std::unique_ptr<Message> message;
ARROW_ASSIGN_OR_RAISE(message, message_reader_->ReadNextMessage());
if (message == nullptr) {
// End of stream
*batch = nullptr;
return Status::OK();
}

if (message->type() == MessageType::DICTIONARY_BATCH) {
return UpdateDictionaries(*message, &dictionary_memo_, options_);
} else {
CHECK_HAS_BODY(*message);
ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body()));
return ReadRecordBatchInternal(*message->metadata(), schema_, field_inclusion_mask_,
&dictionary_memo_, options_, reader.get())
.Value(batch);
// continue to read other dictionaries, if any
while (message->type() == MessageType::DICTIONARY_BATCH) {
RETURN_NOT_OK(UpdateDictionaries(*message, &dictionary_memo_, options_));
ARROW_ASSIGN_OR_RAISE(message, message_reader_->ReadNextMessage());
}

CHECK_HAS_BODY(*message);
ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body()));
return ReadRecordBatchInternal(*message->metadata(), schema_, field_inclusion_mask_,
&dictionary_memo_, options_, reader.get())
.Value(batch);
}

std::shared_ptr<Schema> schema() const override { return out_schema_; }
Expand Down
30 changes: 26 additions & 4 deletions cpp/src/arrow/ipc/writer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -532,13 +532,14 @@ class RecordBatchSerializer {

class DictionarySerializer : public RecordBatchSerializer {
public:
DictionarySerializer(int64_t dictionary_id, int64_t buffer_start_offset,
DictionarySerializer(int64_t dictionary_id, bool is_delta, int64_t buffer_start_offset,
const IpcWriteOptions& options, IpcPayload* out)
: RecordBatchSerializer(buffer_start_offset, options, out),
dictionary_id_(dictionary_id) {}
dictionary_id_(dictionary_id),
is_delta_(is_delta) {}

Status SerializeMetadata(int64_t num_rows) override {
return WriteDictionaryMessage(dictionary_id_, num_rows, out_->body_length,
return WriteDictionaryMessage(dictionary_id_, is_delta_, num_rows, out_->body_length,
custom_metadata_, field_nodes_, buffer_meta_, options_,
&out_->metadata);
}
Expand All @@ -552,6 +553,7 @@ class DictionarySerializer : public RecordBatchSerializer {

private:
int64_t dictionary_id_;
bool is_delta_;
};

} // namespace internal
Expand Down Expand Up @@ -600,9 +602,16 @@ Status GetSchemaPayload(const Schema& schema, const IpcWriteOptions& options,

// Convenience overload: build the payload for a complete (non-delta)
// dictionary by forwarding to the overload that takes an explicit delta flag.
Status GetDictionaryPayload(int64_t id, const std::shared_ptr<Array>& dictionary,
                            const IpcWriteOptions& options, IpcPayload* out) {
  const bool kIsDelta = false;
  return GetDictionaryPayload(id, kIsDelta, dictionary, options, out);
}

// Build the IPC payload for a dictionary batch.
//
// `id` identifies the dictionary and `is_delta` is forwarded to the
// serializer so it can be recorded in the message. `out->type` is set to
// DICTIONARY_BATCH before the dictionary array is assembled into `out`.
Status GetDictionaryPayload(int64_t id, bool is_delta,
                            const std::shared_ptr<Array>& dictionary,
                            const IpcWriteOptions& options, IpcPayload* out) {
  out->type = MessageType::DICTIONARY_BATCH;
  // Frame of reference is 0, see ARROW-384
  // Only one serializer is constructed: a second, stale declaration of
  // `assembler` without the delta flag would not match the constructor.
  internal::DictionarySerializer assembler(id, is_delta, /*buffer_start_offset=*/0,
                                           options, out);
  return assembler.Assemble(dictionary);
}

Expand Down Expand Up @@ -1213,6 +1222,19 @@ Result<std::unique_ptr<RecordBatchWriter>> OpenRecordBatchWriter(
schema, options);
}

// Create a payload writer of the stream-format implementation
// (PayloadStreamWriter) that writes to `sink` using `options`.
// `sink` is passed through as a raw, non-owning pointer.
Result<std::unique_ptr<IpcPayloadWriter>> MakePayloadStreamWriter(
    io::OutputStream* sink, const IpcWriteOptions& options) {
  auto stream_writer =
      ::arrow::internal::make_unique<internal::PayloadStreamWriter>(sink, options);
  // Hand the concrete writer back through the IpcPayloadWriter interface.
  return std::unique_ptr<IpcPayloadWriter>(std::move(stream_writer));
}

// Create a payload writer of the file-format implementation
// (PayloadFileWriter) for `schema`, writing to `sink` with the given
// `options` and `metadata`. `sink` is passed through as a raw, non-owning
// pointer.
Result<std::unique_ptr<IpcPayloadWriter>> MakePayloadFileWriter(
    io::OutputStream* sink, const std::shared_ptr<Schema>& schema,
    const IpcWriteOptions& options,
    const std::shared_ptr<const KeyValueMetadata>& metadata) {
  // Note the constructor's argument order differs from this function's.
  auto file_writer = ::arrow::internal::make_unique<internal::PayloadFileWriter>(
      options, schema, metadata, sink);
  return std::unique_ptr<IpcPayloadWriter>(std::move(file_writer));
}

} // namespace internal

// ----------------------------------------------------------------------
Expand Down
Loading