diff --git a/cpp/src/arrow/ipc/dictionary.cc b/cpp/src/arrow/ipc/dictionary.cc index 4d4f60575b6..8a4f4b642e1 100644 --- a/cpp/src/arrow/ipc/dictionary.cc +++ b/cpp/src/arrow/ipc/dictionary.cc @@ -24,6 +24,7 @@ #include #include "arrow/array.h" +#include "arrow/array/concatenate.h" #include "arrow/extension_type.h" #include "arrow/record_batch.h" #include "arrow/status.h" @@ -142,6 +143,23 @@ Status DictionaryMemo::AddDictionary(int64_t id, return Status::OK(); } +Status DictionaryMemo::AddDictionaryDelta(int64_t id, + const std::shared_ptr& dictionary, + MemoryPool* pool) { + std::shared_ptr originalDict, combinedDict; + RETURN_NOT_OK(GetDictionary(id, &originalDict)); + ArrayVector dictsToCombine{originalDict, dictionary}; + ARROW_ASSIGN_OR_RAISE(combinedDict, Concatenate(dictsToCombine, pool)); + id_to_dictionary_[id] = combinedDict; + return Status::OK(); +} + +Status DictionaryMemo::AddOrReplaceDictionary(int64_t id, + const std::shared_ptr& dictionary) { + id_to_dictionary_[id] = dictionary; + return Status::OK(); +} + DictionaryMemo::DictionaryVector DictionaryMemo::dictionaries() const { // Sort dictionaries by ascending id. This ensures that, in the case // of nested dictionaries, inner dictionaries are emitted before outer diff --git a/cpp/src/arrow/ipc/dictionary.h b/cpp/src/arrow/ipc/dictionary.h index dc2c716559b..c8b347cf182 100644 --- a/cpp/src/arrow/ipc/dictionary.h +++ b/cpp/src/arrow/ipc/dictionary.h @@ -25,6 +25,7 @@ #include #include +#include "arrow/memory_pool.h" #include "arrow/status.h" #include "arrow/util/macros.h" #include "arrow/util/visibility.h" @@ -78,6 +79,15 @@ class ARROW_EXPORT DictionaryMemo { /// KeyError if that dictionary already exists Status AddDictionary(int64_t id, const std::shared_ptr& dictionary); + /// \brief Append a dictionary delta to the memo with a particular id. 
Returns + /// KeyError if that dictionary does not exist + Status AddDictionaryDelta(int64_t id, const std::shared_ptr& dictionary, + MemoryPool* pool); + + /// \brief Add a dictionary to the memo if it does not have one with the id, + /// otherwise, replace the dictionary with the new one. + Status AddOrReplaceDictionary(int64_t id, const std::shared_ptr& dictionary); + /// \brief The stored dictionaries, in ascending id order. DictionaryVector dictionaries() const; diff --git a/cpp/src/arrow/ipc/metadata_internal.cc b/cpp/src/arrow/ipc/metadata_internal.cc index 3c110dd50cc..dd642d5bd65 100644 --- a/cpp/src/arrow/ipc/metadata_internal.cc +++ b/cpp/src/arrow/ipc/metadata_internal.cc @@ -1209,7 +1209,7 @@ Result> WriteSparseTensorMessage( } Status WriteDictionaryMessage( - int64_t id, int64_t length, int64_t body_length, + int64_t id, bool is_delta, int64_t length, int64_t body_length, const std::shared_ptr& custom_metadata, const std::vector& nodes, const std::vector& buffers, const IpcWriteOptions& options, std::shared_ptr* out) { @@ -1217,7 +1217,8 @@ Status WriteDictionaryMessage( RecordBatchOffset record_batch; RETURN_NOT_OK( MakeRecordBatch(fbb, length, body_length, nodes, buffers, options, &record_batch)); - auto dictionary_batch = flatbuf::CreateDictionaryBatch(fbb, id, record_batch).Union(); + auto dictionary_batch = + flatbuf::CreateDictionaryBatch(fbb, id, record_batch, is_delta).Union(); return WriteFBMessage(fbb, flatbuf::MessageHeader::DictionaryBatch, dictionary_batch, body_length, custom_metadata) .Value(out); diff --git a/cpp/src/arrow/ipc/metadata_internal.h b/cpp/src/arrow/ipc/metadata_internal.h index b0da188363f..5c1a032042b 100644 --- a/cpp/src/arrow/ipc/metadata_internal.h +++ b/cpp/src/arrow/ipc/metadata_internal.h @@ -195,7 +195,8 @@ Status WriteFileFooter(const Schema& schema, const std::vector& dicti io::OutputStream* out); Status WriteDictionaryMessage( - const int64_t id, const int64_t length, const int64_t body_length, + const int64_t 
id, const bool is_delta, const int64_t length, + const int64_t body_length, const std::shared_ptr& custom_metadata, const std::vector& nodes, const std::vector& buffers, const IpcWriteOptions& options, std::shared_ptr* out); diff --git a/cpp/src/arrow/ipc/read_write_test.cc b/cpp/src/arrow/ipc/read_write_test.cc index 2c1bf1c73a2..374cf9deacb 100644 --- a/cpp/src/arrow/ipc/read_write_test.cc +++ b/cpp/src/arrow/ipc/read_write_test.cc @@ -1228,6 +1228,157 @@ TEST_P(TestFileFormat, RoundTrip) { TestZeroLengthRoundTrip(*GetParam(), options); } +Status MakeDictionaryBatch(std::shared_ptr* out) { + auto f0_type = arrow::dictionary(int32(), utf8()); + auto f1_type = arrow::dictionary(int8(), utf8()); + + auto dict = ArrayFromJSON(utf8(), "[\"foo\", \"bar\", \"baz\"]"); + + auto indices0 = ArrayFromJSON(int32(), "[1, 2, null, 0, 2, 0]"); + auto indices1 = ArrayFromJSON(int8(), "[0, 0, 2, 2, 1, 1]"); + + auto a0 = std::make_shared(f0_type, indices0, dict); + auto a1 = std::make_shared(f1_type, indices1, dict); + + // construct batch + auto schema = ::arrow::schema({field("dict1", f0_type), field("dict2", f1_type)}); + + *out = RecordBatch::Make(schema, 6, {a0, a1}); + return Status::OK(); +} + +// A utility that supports reading/writing record batches, +// and manually specifying dictionaries. 
+class DictionaryBatchHelper { + public: + explicit DictionaryBatchHelper(const Schema& schema) : schema_(schema) { + buffer_ = *AllocateResizableBuffer(0); + sink_.reset(new io::BufferOutputStream(buffer_)); + payload_writer_ = *internal::MakePayloadStreamWriter(sink_.get()); + } + + Status Start() { + RETURN_NOT_OK(payload_writer_->Start()); + + // write schema + IpcPayload payload; + RETURN_NOT_OK(GetSchemaPayload(schema_, IpcWriteOptions::Defaults(), + &dictionary_memo_, &payload)); + return payload_writer_->WritePayload(payload); + } + + Status WriteDictionary(int64_t dictionary_id, const std::shared_ptr& dictionary, + bool is_delta) { + IpcPayload payload; + RETURN_NOT_OK(GetDictionaryPayload(dictionary_id, is_delta, dictionary, + IpcWriteOptions::Defaults(), &payload)); + RETURN_NOT_OK(payload_writer_->WritePayload(payload)); + return Status::OK(); + } + + Status WriteBatchPayload(const RecordBatch& batch) { + // write record batch payload only + IpcPayload payload; + RETURN_NOT_OK(GetRecordBatchPayload(batch, IpcWriteOptions::Defaults(), &payload)); + return payload_writer_->WritePayload(payload); + } + + Status Close() { + RETURN_NOT_OK(payload_writer_->Close()); + return sink_->Close(); + } + + Status ReadBatch(std::shared_ptr* out_batch) { + auto buf_reader = std::make_shared(buffer_); + std::shared_ptr reader; + ARROW_ASSIGN_OR_RAISE( + reader, RecordBatchStreamReader::Open(buf_reader, IpcReadOptions::Defaults())); + return reader->ReadNext(out_batch); + } + + std::unique_ptr payload_writer_; + const Schema& schema_; + DictionaryMemo dictionary_memo_; + std::shared_ptr buffer_; + std::unique_ptr sink_; +}; + +TEST(TestDictionaryBatch, DictionaryDelta) { + std::shared_ptr in_batch; + std::shared_ptr out_batch; + ASSERT_OK(MakeDictionaryBatch(&in_batch)); + + auto dict1 = ArrayFromJSON(utf8(), "[\"foo\", \"bar\"]"); + auto dict2 = ArrayFromJSON(utf8(), "[\"baz\"]"); + + DictionaryBatchHelper helper(*in_batch->schema()); + ASSERT_OK(helper.Start()); + + 
ASSERT_OK(helper.WriteDictionary(0L, dict1, /*is_delta=*/false)); + ASSERT_OK(helper.WriteDictionary(0L, dict2, /*is_delta=*/true)); + + ASSERT_OK(helper.WriteDictionary(1L, dict1, /*is_delta=*/false)); + ASSERT_OK(helper.WriteDictionary(1L, dict2, /*is_delta=*/true)); + + ASSERT_OK(helper.WriteBatchPayload(*in_batch)); + ASSERT_OK(helper.Close()); + + ASSERT_OK(helper.ReadBatch(&out_batch)); + + ASSERT_BATCHES_EQUAL(*in_batch, *out_batch); +} + +TEST(TestDictionaryBatch, DictionaryDeltaWithUnknownId) { + std::shared_ptr in_batch; + std::shared_ptr out_batch; + ASSERT_OK(MakeDictionaryBatch(&in_batch)); + + auto dict1 = ArrayFromJSON(utf8(), "[\"foo\", \"bar\"]"); + auto dict2 = ArrayFromJSON(utf8(), "[\"baz\"]"); + + DictionaryBatchHelper helper(*in_batch->schema()); + ASSERT_OK(helper.Start()); + + ASSERT_OK(helper.WriteDictionary(0L, dict1, /*is_delta=*/false)); + ASSERT_OK(helper.WriteDictionary(0L, dict2, /*is_delta=*/true)); + + /* This delta dictionary does not have a base dictionary previously in stream */ + ASSERT_OK(helper.WriteDictionary(1L, dict2, /*is_delta=*/true)); + + ASSERT_OK(helper.WriteBatchPayload(*in_batch)); + ASSERT_OK(helper.Close()); + + ASSERT_RAISES(KeyError, helper.ReadBatch(&out_batch)); +} + +TEST(TestDictionaryBatch, DictionaryReplacement) { + std::shared_ptr in_batch; + std::shared_ptr out_batch; + ASSERT_OK(MakeDictionaryBatch(&in_batch)); + + auto dict = ArrayFromJSON(utf8(), "[\"foo\", \"bar\", \"baz\"]"); + auto dict1 = ArrayFromJSON(utf8(), "[\"foo1\", \"bar1\", \"baz1\"]"); + auto dict2 = ArrayFromJSON(utf8(), "[\"foo2\", \"bar2\", \"baz2\"]"); + + DictionaryBatchHelper helper(*in_batch->schema()); + ASSERT_OK(helper.Start()); + + // the old dictionaries will be overwritten by + // the new dictionaries with the same ids. 
+ ASSERT_OK(helper.WriteDictionary(0L, dict1, /*is_delta=*/false)); + ASSERT_OK(helper.WriteDictionary(0L, dict, /*is_delta=*/false)); + + ASSERT_OK(helper.WriteDictionary(1L, dict2, /*is_delta=*/false)); + ASSERT_OK(helper.WriteDictionary(1L, dict, /*is_delta=*/false)); + + ASSERT_OK(helper.WriteBatchPayload(*in_batch)); + ASSERT_OK(helper.Close()); + + ASSERT_OK(helper.ReadBatch(&out_batch)); + + ASSERT_BATCHES_EQUAL(*in_batch, *out_batch); +} + TEST_P(TestStreamFormat, RoundTrip) { TestRoundTrip(*GetParam(), IpcWriteOptions::Defaults()); TestZeroLengthRoundTrip(*GetParam(), IpcWriteOptions::Defaults()); diff --git a/cpp/src/arrow/ipc/reader.cc b/cpp/src/arrow/ipc/reader.cc index d35236cc238..20cab041e1e 100644 --- a/cpp/src/arrow/ipc/reader.cc +++ b/cpp/src/arrow/ipc/reader.cc @@ -684,7 +684,10 @@ Status ReadDictionary(const Buffer& metadata, DictionaryMemo* dictionary_memo, return Status::Invalid("Dictionary record batch must only contain one field"); } auto dictionary = batch->column(0); - return dictionary_memo->AddDictionary(id, dictionary); + if (dictionary_batch->isDelta()) { + return dictionary_memo->AddDictionaryDelta(id, dictionary, options.memory_pool); + } + return dictionary_memo->AddOrReplaceDictionary(id, dictionary); } Status ParseDictionary(const Message& message, DictionaryMemo* dictionary_memo, @@ -698,8 +701,7 @@ Status ParseDictionary(const Message& message, DictionaryMemo* dictionary_memo, Status UpdateDictionaries(const Message& message, DictionaryMemo* dictionary_memo, const IpcReadOptions& options) { - // TODO(wesm): implement delta dictionaries - return Status::NotImplemented("Delta dictionaries not yet implemented"); + return ParseDictionary(message, dictionary_memo, options); } // ---------------------------------------------------------------------- @@ -735,23 +737,30 @@ class RecordBatchStreamReaderImpl : public RecordBatchStreamReader { return Status::OK(); } - ARROW_ASSIGN_OR_RAISE(std::unique_ptr message, - 
message_reader_->ReadNextMessage()); + std::unique_ptr message; + ARROW_ASSIGN_OR_RAISE(message, message_reader_->ReadNextMessage()); if (message == nullptr) { // End of stream *batch = nullptr; return Status::OK(); } - if (message->type() == MessageType::DICTIONARY_BATCH) { - return UpdateDictionaries(*message, &dictionary_memo_, options_); - } else { - CHECK_HAS_BODY(*message); - ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body())); - return ReadRecordBatchInternal(*message->metadata(), schema_, field_inclusion_mask_, - &dictionary_memo_, options_, reader.get()) - .Value(batch); + // continue to read other dictionaries, if any + while (message->type() == MessageType::DICTIONARY_BATCH) { + RETURN_NOT_OK(UpdateDictionaries(*message, &dictionary_memo_, options_)); + ARROW_ASSIGN_OR_RAISE(message, message_reader_->ReadNextMessage()); + if (message == nullptr) { + // End of stream right after a dictionary batch: do not dereference null + *batch = nullptr; + return Status::OK(); + } } + + CHECK_HAS_BODY(*message); + ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body())); + return ReadRecordBatchInternal(*message->metadata(), schema_, field_inclusion_mask_, + &dictionary_memo_, options_, reader.get()) + .Value(batch); } std::shared_ptr schema() const override { return out_schema_; } diff --git a/cpp/src/arrow/ipc/writer.cc b/cpp/src/arrow/ipc/writer.cc index 3587490b203..4db61364383 100644 --- a/cpp/src/arrow/ipc/writer.cc +++ b/cpp/src/arrow/ipc/writer.cc @@ -532,13 +532,14 @@ class RecordBatchSerializer { class DictionarySerializer : public RecordBatchSerializer { public: - DictionarySerializer(int64_t dictionary_id, int64_t buffer_start_offset, + DictionarySerializer(int64_t dictionary_id, bool is_delta, int64_t buffer_start_offset, const IpcWriteOptions& options, IpcPayload* out) : RecordBatchSerializer(buffer_start_offset, options, out), - dictionary_id_(dictionary_id) {} + dictionary_id_(dictionary_id), + is_delta_(is_delta) {} Status 
SerializeMetadata(int64_t num_rows) override { - return WriteDictionaryMessage(dictionary_id_, num_rows, out_->body_length, + return 
WriteDictionaryMessage(dictionary_id_, is_delta_, num_rows, out_->body_length, custom_metadata_, field_nodes_, buffer_meta_, options_, &out_->metadata); } @@ -552,6 +553,7 @@ class DictionarySerializer : public RecordBatchSerializer { private: int64_t dictionary_id_; + bool is_delta_; }; } // namespace internal @@ -600,9 +602,16 @@ Status GetSchemaPayload(const Schema& schema, const IpcWriteOptions& options, Status GetDictionaryPayload(int64_t id, const std::shared_ptr& dictionary, const IpcWriteOptions& options, IpcPayload* out) { + return GetDictionaryPayload(id, false, dictionary, options, out); +} + +Status GetDictionaryPayload(int64_t id, bool is_delta, + const std::shared_ptr& dictionary, + const IpcWriteOptions& options, IpcPayload* out) { out->type = MessageType::DICTIONARY_BATCH; // Frame of reference is 0, see ARROW-384 - internal::DictionarySerializer assembler(id, /*buffer_start_offset=*/0, options, out); + internal::DictionarySerializer assembler(id, is_delta, /*buffer_start_offset=*/0, + options, out); return assembler.Assemble(dictionary); } @@ -1213,6 +1222,19 @@ Result> OpenRecordBatchWriter( schema, options); } +Result> MakePayloadStreamWriter( + io::OutputStream* sink, const IpcWriteOptions& options) { + return ::arrow::internal::make_unique(sink, options); +} + +Result> MakePayloadFileWriter( + io::OutputStream* sink, const std::shared_ptr& schema, + const IpcWriteOptions& options, + const std::shared_ptr& metadata) { + return ::arrow::internal::make_unique(options, schema, + metadata, sink); +} + } // namespace internal // ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/ipc/writer.h b/cpp/src/arrow/ipc/writer.h index 23bb17175b9..9fd782ae1af 100644 --- a/cpp/src/arrow/ipc/writer.h +++ b/cpp/src/arrow/ipc/writer.h @@ -297,6 +297,18 @@ ARROW_EXPORT Status GetDictionaryPayload(int64_t id, const std::shared_ptr& dictionary, const IpcWriteOptions& options, IpcPayload* payload); +/// \brief Compute 
IpcPayload for a dictionary +/// \param[in] id the dictionary id +/// \param[in] is_delta whether the dictionary is a delta dictionary +/// \param[in] dictionary the dictionary values +/// \param[in] options options for serialization +/// \param[out] payload the output IpcPayload +/// \return Status +ARROW_EXPORT +Status GetDictionaryPayload(int64_t id, bool is_delta, + const std::shared_ptr& dictionary, + const IpcWriteOptions& options, IpcPayload* payload); + /// \brief Compute IpcPayload for the given record batch /// \param[in] batch the RecordBatch that is being serialized /// \param[in] options options for serialization @@ -341,6 +353,29 @@ class ARROW_EXPORT IpcPayloadWriter { virtual Status Close() = 0; }; +/// Create a new IPC payload stream writer from stream sink. User is +/// responsible for closing the actual OutputStream. +/// +/// \param[in] sink output stream to write to +/// \param[in] options options for serialization +/// \return Result> +ARROW_EXPORT +Result> MakePayloadStreamWriter( + io::OutputStream* sink, const IpcWriteOptions& options = IpcWriteOptions::Defaults()); + +/// Create a new IPC payload file writer from stream sink. +/// +/// \param[in] sink output stream to write to +/// \param[in] schema the schema of the record batches to be written +/// \param[in] options options for serialization, optional +/// \param[in] metadata custom metadata for File Footer, optional +/// \return Result holding the new payload file writer +ARROW_EXPORT +Result> MakePayloadFileWriter( + io::OutputStream* sink, const std::shared_ptr& schema, + const IpcWriteOptions& options = IpcWriteOptions::Defaults(), + const std::shared_ptr& metadata = NULLPTR); + /// Create a new RecordBatchWriter from IpcPayloadWriter and schema. /// /// \param[in] sink the IpcPayloadWriter to write to