diff --git a/cpp/src/arrow/ipc/writer.cc b/cpp/src/arrow/ipc/writer.cc
index b89604e6fe1..585b86fd847 100644
--- a/cpp/src/arrow/ipc/writer.cc
+++ b/cpp/src/arrow/ipc/writer.cc
@@ -951,6 +951,16 @@ Status GetTensorSize(const Tensor& tensor, int64_t* size) {
 
 RecordBatchWriter::~RecordBatchWriter() {}
 
+Status RecordBatchWriter::WriteRecordBatch(
+    const RecordBatch& batch,
+    const std::shared_ptr<const KeyValueMetadata>& custom_metadata) {
+  if (custom_metadata == nullptr) {
+    return WriteRecordBatch(batch);
+  }
+  return Status::NotImplemented(
+      "Write record batch with custom metadata not implemented");
+}
+
 Status RecordBatchWriter::WriteTable(const Table& table, int64_t max_chunksize) {
   TableBatchReader reader(table);
 
diff --git a/cpp/src/arrow/ipc/writer.h b/cpp/src/arrow/ipc/writer.h
index 6dc62f41761..9e18a213ba3 100644
--- a/cpp/src/arrow/ipc/writer.h
+++ b/cpp/src/arrow/ipc/writer.h
@@ -103,10 +103,7 @@ class ARROW_EXPORT RecordBatchWriter {
   /// \return Status
   virtual Status WriteRecordBatch(
       const RecordBatch& batch,
-      const std::shared_ptr<const KeyValueMetadata>& custom_metadata) {
-    return Status::NotImplemented(
-        "Write record batch with custom metadata not implemented");
-  }
+      const std::shared_ptr<const KeyValueMetadata>& custom_metadata);
 
   /// \brief Write possibly-chunked table by creating sequence of record batches
   /// \param[in] table table to write
diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx
index f46c1d3aa83..7feee8cf7b4 100644
--- a/python/pyarrow/_flight.pyx
+++ b/python/pyarrow/_flight.pyx
@@ -1086,14 +1086,18 @@ cdef class MetadataRecordBatchWriter(_CRecordBatchWriter):
         ----------
         batch : RecordBatch
         """
+        cdef:
+            shared_ptr[const CKeyValueMetadata] custom_metadata
+
         # Override superclass method to use check_flight_status so we
         # can generate FlightWriteSizeExceededError. We don't do this
        # for write_table as callers who intend to handle the error
         # and retry with a smaller batch should be working with
         # individual batches to have control.
+
         with nogil:
             check_flight_status(
-                self._writer().WriteRecordBatch(deref(batch.batch)))
+                self._writer().WriteRecordBatch(deref(batch.batch), custom_metadata))
 
     def write_table(self, Table table, max_chunksize=None, **kwargs):
         """
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index 306128ce35c..b60452fd1d6 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -823,6 +823,11 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
         shared_ptr[CRecordBatch] Slice(int64_t offset)
         shared_ptr[CRecordBatch] Slice(int64_t offset, int64_t length)
 
+    cdef cppclass CRecordBatchWithMetadata" arrow::RecordBatchWithMetadata":
+        shared_ptr[CRecordBatch] batch
+        # The C++ struct does not actually have these two `const` qualifiers, but adding `const` keeps Cython from complaining
+        const shared_ptr[const CKeyValueMetadata] custom_metadata
+
     cdef cppclass CTable" arrow::Table":
         CTable(const shared_ptr[CSchema]& schema,
                const vector[shared_ptr[CChunkedArray]]& columns)
 
@@ -887,6 +892,7 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
     cdef cppclass CRecordBatchReader" arrow::RecordBatchReader":
         shared_ptr[CSchema] schema()
         CStatus Close()
+        CResult[CRecordBatchWithMetadata] ReadNext()
         CStatus ReadNext(shared_ptr[CRecordBatch]* batch)
         CResult[shared_ptr[CTable]] ToTable()
 
@@ -1590,6 +1596,9 @@ cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil:
     cdef cppclass CRecordBatchWriter" arrow::ipc::RecordBatchWriter":
         CStatus Close()
         CStatus WriteRecordBatch(const CRecordBatch& batch)
+        CStatus WriteRecordBatch(
+            const CRecordBatch& batch,
+            const shared_ptr[const CKeyValueMetadata]& metadata)
         CStatus WriteTable(const CTable& table, int64_t max_chunksize)
 
         CIpcWriteStats stats()
 
@@ -1625,6 +1634,8 @@ cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil:
 
         CResult[shared_ptr[CRecordBatch]] ReadRecordBatch(int i)
 
+        CResult[CRecordBatchWithMetadata] ReadRecordBatchWithCustomMetadata(int i)
+
         CIpcReadStats stats()
 
 CResult[shared_ptr[CRecordBatchWriter]] MakeStreamWriter(
diff --git a/python/pyarrow/ipc.pxi b/python/pyarrow/ipc.pxi
index ee5f3454762..9b13e71dde9 100644
--- a/python/pyarrow/ipc.pxi
+++ b/python/pyarrow/ipc.pxi
@@ -477,17 +477,22 @@ cdef class _CRecordBatchWriter(_Weakrefable):
         else:
             raise ValueError(type(table_or_batch))
 
-    def write_batch(self, RecordBatch batch):
+    def write_batch(self, RecordBatch batch, custom_metadata=None):
         """
         Write RecordBatch to stream.
 
         Parameters
         ----------
         batch : RecordBatch
+        custom_metadata : mapping or KeyValueMetadata
+            Keys and values must be string-like / coercible to bytes
         """
+        metadata = ensure_metadata(custom_metadata, allow_none=True)
+        c_meta = pyarrow_unwrap_metadata(metadata)
+
         with nogil:
             check_status(self.writer.get()
-                         .WriteRecordBatch(deref(batch.batch)))
+                         .WriteRecordBatch(deref(batch.batch), c_meta))
 
     def write_table(self, Table table, max_chunksize=None):
         """
@@ -683,6 +688,46 @@ cdef class RecordBatchReader(_Weakrefable):
 
         return pyarrow_wrap_batch(batch)
 
+    def read_next_batch_with_custom_metadata(self):
+        """
+        Read next RecordBatch from the stream along with its custom metadata.
+
+        Raises
+        ------
+        StopIteration:
+            At end of stream.
+
+        Returns
+        -------
+        batch : RecordBatch
+        custom_metadata : KeyValueMetadata
+        """
+        cdef:
+            CRecordBatchWithMetadata batch_with_metadata
+
+        with nogil:
+            batch_with_metadata = GetResultValue(self.reader.get().ReadNext())
+
+        if batch_with_metadata.batch.get() == NULL:
+            raise StopIteration
+
+        return _wrap_record_batch_with_metadata(batch_with_metadata)
+
+    def iter_batches_with_custom_metadata(self):
+        """
+        Iterate over record batches from the stream along with their custom
+        metadata.
+
+        Yields
+        ------
+        RecordBatchWithMetadata
+        """
+        while True:
+            try:
+                yield self.read_next_batch_with_custom_metadata()
+            except StopIteration:
+                return
+
     def read_all(self):
         """
         Read all record batches as a pyarrow.Table.
@@ -828,6 +873,27 @@ cdef class _RecordBatchFileWriter(_RecordBatchStreamWriter):
         self.writer = GetResultValue(
             MakeFileWriter(c_sink, schema.sp_schema, self.options))
 
+_RecordBatchWithMetadata = namedtuple(
+    'RecordBatchWithMetadata',
+    ('batch', 'custom_metadata'))
+
+
+class RecordBatchWithMetadata(_RecordBatchWithMetadata):
+    """RecordBatch with its custom metadata
+
+    Parameters
+    ----------
+    batch : RecordBatch
+    custom_metadata : KeyValueMetadata
+    """
+    __slots__ = ()
+
+
+@staticmethod
+cdef _wrap_record_batch_with_metadata(CRecordBatchWithMetadata c):
+    return RecordBatchWithMetadata(pyarrow_wrap_batch(c.batch),
+                                   pyarrow_wrap_metadata(c.custom_metadata))
+
 
 cdef class _RecordBatchFileReader(_Weakrefable):
     cdef:
@@ -904,6 +970,33 @@ cdef class _RecordBatchFileReader(_Weakrefable):
     # time has passed
     get_record_batch = get_batch
 
+    def get_batch_with_custom_metadata(self, int i):
+        """
+        Read the record batch with the given index along with
+        its custom metadata.
+
+        Parameters
+        ----------
+        i : int
+            The index of the record batch in the IPC file.
+
+        Returns
+        -------
+        batch : RecordBatch
+        custom_metadata : KeyValueMetadata
+        """
+        cdef:
+            CRecordBatchWithMetadata batch_with_metadata
+
+        if i < 0 or i >= self.num_record_batches:
+            raise ValueError('Batch number {0} out of range'.format(i))
+
+        with nogil:
+            batch_with_metadata = GetResultValue(
+                self.reader.get().ReadRecordBatchWithCustomMetadata(i))
+
+        return _wrap_record_batch_with_metadata(batch_with_metadata)
+
     def read_all(self):
         """
         Read all record batches as a pyarrow.Table
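For context, here is a minimal usage sketch of the Python API added above,
exercising the stream path; the column name, sample data, and metadata key
are illustrative only and not part of the patch:

    import pyarrow as pa

    batch = pa.record_batch([pa.array([1, 2, 3])], names=["x"])

    sink = pa.BufferOutputStream()
    with pa.ipc.new_stream(sink, batch.schema) as writer:
        # Attach application-defined key/value pairs to this batch's message
        writer.write_batch(batch, custom_metadata={"source": "sensor-1"})
        # A batch written without the keyword carries no custom metadata
        writer.write_batch(batch)

    with pa.ipc.open_stream(sink.getvalue()) as reader:
        for item in reader.iter_batches_with_custom_metadata():
            # item is a RecordBatchWithMetadata namedtuple; custom_metadata is
            # a KeyValueMetadata, or None when nothing was attached
            print(item.batch.num_rows, item.custom_metadata)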
diff --git a/python/pyarrow/tests/test_ipc.py b/python/pyarrow/tests/test_ipc.py
index d9abe987ae6..0df302a8de7 100644
--- a/python/pyarrow/tests/test_ipc.py
+++ b/python/pyarrow/tests/test_ipc.py
@@ -946,6 +946,45 @@ def test_ipc_zero_copy_numpy():
     assert_frame_equal(df, rdf)
 
 
+@pytest.mark.pandas
+@pytest.mark.parametrize("ipc_type", ["stream", "file"])
+def test_batches_with_custom_metadata_roundtrip(ipc_type):
+    df = pd.DataFrame({'foo': [1.5]})
+
+    batch = pa.RecordBatch.from_pandas(df)
+    sink = pa.BufferOutputStream()
+
+    batch_count = 2
+    file_factory = {"stream": pa.ipc.new_stream,
+                    "file": pa.ipc.new_file}[ipc_type]
+
+    with file_factory(sink, batch.schema) as writer:
+        for i in range(batch_count):
+            writer.write_batch(batch, custom_metadata={"batch_id": str(i)})
+        # write a batch without custom metadata
+        writer.write_batch(batch)
+
+    buffer = sink.getvalue()
+
+    if ipc_type == "stream":
+        with pa.ipc.open_stream(buffer) as reader:
+            batch_with_metas = list(reader.iter_batches_with_custom_metadata())
+    else:
+        with pa.ipc.open_file(buffer) as reader:
+            batch_with_metas = [reader.get_batch_with_custom_metadata(i)
+                                for i in range(reader.num_record_batches)]
+
+    for i in range(batch_count):
+        assert batch_with_metas[i].batch.num_rows == 1
+        assert isinstance(
+            batch_with_metas[i].custom_metadata, pa.KeyValueMetadata)
+        assert batch_with_metas[i].custom_metadata == {"batch_id": str(i)}
+
+    # the last batch has no custom metadata
+    assert batch_with_metas[batch_count].batch.num_rows == 1
+    assert batch_with_metas[batch_count].custom_metadata is None
+
+
 def test_ipc_stream_no_batches():
     # ARROW-2307
     table = pa.Table.from_arrays([pa.array([1, 2, 3, 4]),
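The file format additionally allows random access to the same information
through the new get_batch_with_custom_metadata method; a short sketch
mirroring the test above (names again illustrative):

    import pyarrow as pa

    batch = pa.record_batch([pa.array([1.5])], names=["foo"])
    sink = pa.BufferOutputStream()
    with pa.ipc.new_file(sink, batch.schema) as writer:
        writer.write_batch(batch, custom_metadata={"batch_id": "0"})

    with pa.ipc.open_file(sink.getvalue()) as reader:
        # Index-based random access, unlike the stream reader's iterator
        wrapped = reader.get_batch_with_custom_metadata(0)
        assert wrapped.custom_metadata["batch_id"] == b"0"  # values are bytes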