10 changes: 10 additions & 0 deletions cpp/src/arrow/ipc/writer.cc
@@ -951,6 +951,16 @@ Status GetTensorSize(const Tensor& tensor, int64_t* size) {

RecordBatchWriter::~RecordBatchWriter() {}

Status RecordBatchWriter::WriteRecordBatch(
const RecordBatch& batch,
const std::shared_ptr<const KeyValueMetadata>& custom_metadata) {
if (custom_metadata == nullptr) {
return WriteRecordBatch(batch);
}
return Status::NotImplemented(
"Write record batch with custom metadata not implemented");
}

Status RecordBatchWriter::WriteTable(const Table& table, int64_t max_chunksize) {
TableBatchReader reader(table);

5 changes: 1 addition & 4 deletions cpp/src/arrow/ipc/writer.h
@@ -103,10 +103,7 @@ class ARROW_EXPORT RecordBatchWriter {
/// \return Status
virtual Status WriteRecordBatch(
const RecordBatch& batch,
const std::shared_ptr<const KeyValueMetadata>& custom_metadata) {
return Status::NotImplemented(
"Write record batch with custom metadata not implemented");
}
const std::shared_ptr<const KeyValueMetadata>& custom_metadata);

/// \brief Write possibly-chunked table by creating sequence of record batches
/// \param[in] table table to write
6 changes: 5 additions & 1 deletion python/pyarrow/_flight.pyx
@@ -1086,14 +1086,18 @@ cdef class MetadataRecordBatchWriter(_CRecordBatchWriter):
----------
batch : RecordBatch
"""
cdef:
shared_ptr[const CKeyValueMetadata] custom_metadata

# Override superclass method to use check_flight_status so we
# can generate FlightWriteSizeExceededError. We don't do this
# for write_table as callers who intend to handle the error
# and retry with a smaller batch should be working with
# individual batches to have control.

with nogil:
check_flight_status(
self._writer().WriteRecordBatch(deref(batch.batch)))
self._writer().WriteRecordBatch(deref(batch.batch), custom_metadata))

def write_table(self, Table table, max_chunksize=None, **kwargs):
"""
11 changes: 11 additions & 0 deletions python/pyarrow/includes/libarrow.pxd
@@ -823,6 +823,11 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
shared_ptr[CRecordBatch] Slice(int64_t offset)
shared_ptr[CRecordBatch] Slice(int64_t offset, int64_t length)

cdef cppclass CRecordBatchWithMetadata" arrow::RecordBatchWithMetadata":
shared_ptr[CRecordBatch] batch
# The C++ struct does not actually have these two `const` qualifiers, but adding them keeps Cython from complaining
const shared_ptr[const CKeyValueMetadata] custom_metadata

cdef cppclass CTable" arrow::Table":
CTable(const shared_ptr[CSchema]& schema,
const vector[shared_ptr[CChunkedArray]]& columns)
@@ -887,6 +892,7 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
cdef cppclass CRecordBatchReader" arrow::RecordBatchReader":
shared_ptr[CSchema] schema()
CStatus Close()
CResult[CRecordBatchWithMetadata] ReadNext()
CStatus ReadNext(shared_ptr[CRecordBatch]* batch)
CResult[shared_ptr[CTable]] ToTable()

@@ -1590,6 +1596,9 @@ cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil:
cdef cppclass CRecordBatchWriter" arrow::ipc::RecordBatchWriter":
CStatus Close()
CStatus WriteRecordBatch(const CRecordBatch& batch)
CStatus WriteRecordBatch(
const CRecordBatch& batch,
const shared_ptr[const CKeyValueMetadata]& metadata)
CStatus WriteTable(const CTable& table, int64_t max_chunksize)

CIpcWriteStats stats()
@@ -1625,6 +1634,8 @@ cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil:

CResult[shared_ptr[CRecordBatch]] ReadRecordBatch(int i)

CResult[CRecordBatchWithMetadata] ReadRecordBatchWithCustomMetadata(int i)

CIpcReadStats stats()

CResult[shared_ptr[CRecordBatchWriter]] MakeStreamWriter(
97 changes: 95 additions & 2 deletions python/pyarrow/ipc.pxi
@@ -477,17 +477,22 @@ cdef class _CRecordBatchWriter(_Weakrefable):
else:
raise ValueError(type(table_or_batch))

def write_batch(self, RecordBatch batch):
def write_batch(self, RecordBatch batch, custom_metadata=None):
"""
Write RecordBatch to stream.

Parameters
----------
batch : RecordBatch
custom_metadata : mapping or KeyValueMetadata
Keys and values must be string-like / coercible to bytes
"""
metadata = ensure_metadata(custom_metadata, allow_none=True)
c_meta = pyarrow_unwrap_metadata(metadata)

with nogil:
check_status(self.writer.get()
.WriteRecordBatch(deref(batch.batch)))
.WriteRecordBatch(deref(batch.batch), c_meta))

def write_table(self, Table table, max_chunksize=None):
"""
@@ -683,6 +688,46 @@ cdef class RecordBatchReader(_Weakrefable):

return pyarrow_wrap_batch(batch)

def read_next_batch_with_custom_metadata(self):
"""
Read next RecordBatch from the stream along with its custom metadata.

Raises
------
StopIteration:
At end of stream.

Returns
-------
batch : RecordBatch
custom_metadata : KeyValueMetadata
"""
cdef:
CRecordBatchWithMetadata batch_with_metadata

with nogil:
batch_with_metadata = GetResultValue(self.reader.get().ReadNext())

if batch_with_metadata.batch.get() == NULL:
raise StopIteration

return _wrap_record_batch_with_metadata(batch_with_metadata)

def iter_batches_with_custom_metadata(self):
"""
Iterate over record batches from the stream along with their custom
metadata.

Yields
------
RecordBatchWithMetadata
"""
while True:
try:
yield self.read_next_batch_with_custom_metadata()
except StopIteration:
return

def read_all(self):
"""
Read all record batches as a pyarrow.Table.
@@ -828,6 +873,27 @@ cdef class _RecordBatchFileWriter(_RecordBatchStreamWriter):
self.writer = GetResultValue(
MakeFileWriter(c_sink, schema.sp_schema, self.options))

_RecordBatchWithMetadata = namedtuple(
'RecordBatchWithMetadata',
('batch', 'custom_metadata'))


class RecordBatchWithMetadata(_RecordBatchWithMetadata):
"""RecordBatch with its custom metadata

Parameters
----------
batch : RecordBatch
custom_metadata : KeyValueMetadata
"""
__slots__ = ()


@staticmethod
cdef _wrap_record_batch_with_metadata(CRecordBatchWithMetadata c):
return RecordBatchWithMetadata(pyarrow_wrap_batch(c.batch),
pyarrow_wrap_metadata(c.custom_metadata))


cdef class _RecordBatchFileReader(_Weakrefable):
cdef:
@@ -904,6 +970,33 @@ cdef class _RecordBatchFileReader(_Weakrefable):
# time has passed
get_record_batch = get_batch

def get_batch_with_custom_metadata(self, int i):
"""
Read the record batch with the given index along with
its custom metadata.

Parameters
----------
i : int
The index of the record batch in the IPC file.

Returns
-------
batch : RecordBatch
custom_metadata : KeyValueMetadata
"""
cdef:
CRecordBatchWithMetadata batch_with_metadata

if i < 0 or i >= self.num_record_batches:
raise ValueError('Batch number {0} out of range'.format(i))

with nogil:
batch_with_metadata = GetResultValue(
self.reader.get().ReadRecordBatchWithCustomMetadata(i))

return _wrap_record_batch_with_metadata(batch_with_metadata)

def read_all(self):
"""
Read all record batches as a pyarrow.Table
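For reference, here is a minimal usage sketch (not part of the diff) of the Python API added in ipc.pxi above. The column name and metadata key are illustrative; the calls themselves (write_batch with custom_metadata, iter_batches_with_custom_metadata, and the RecordBatchWithMetadata namedtuple) are the ones introduced by this change.

import pyarrow as pa

batch = pa.record_batch([pa.array([1, 2, 3])], names=["x"])
sink = pa.BufferOutputStream()

with pa.ipc.new_stream(sink, batch.schema) as writer:
    # Attach per-batch application metadata; keys and values must be
    # string-like / coercible to bytes.
    writer.write_batch(batch, custom_metadata={"batch_id": "0"})
    # Omitting custom_metadata preserves the previous behaviour.
    writer.write_batch(batch)

with pa.ipc.open_stream(sink.getvalue()) as reader:
    for item in reader.iter_batches_with_custom_metadata():
        # item is a RecordBatchWithMetadata namedtuple: item.batch is a
        # RecordBatch and item.custom_metadata is a KeyValueMetadata, or
        # None for batches written without custom metadata.
        print(item.batch.num_rows, item.custom_metadata)

For the file format, the equivalent random-access call is get_batch_with_custom_metadata(i) on the file reader, as exercised by the test below.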
39 changes: 39 additions & 0 deletions python/pyarrow/tests/test_ipc.py
@@ -946,6 +946,45 @@ def test_ipc_zero_copy_numpy():
assert_frame_equal(df, rdf)


@pytest.mark.pandas
@pytest.mark.parametrize("ipc_type", ["stream", "file"])
def test_batches_with_custom_metadata_roundtrip(ipc_type):
df = pd.DataFrame({'foo': [1.5]})

batch = pa.RecordBatch.from_pandas(df)
sink = pa.BufferOutputStream()

batch_count = 2
file_factory = {"stream": pa.ipc.new_stream,
"file": pa.ipc.new_file}[ipc_type]

with file_factory(sink, batch.schema) as writer:
for i in range(batch_count):
writer.write_batch(batch, custom_metadata={"batch_id": str(i)})
# write a batch without custom metadata
writer.write_batch(batch)

buffer = sink.getvalue()

if ipc_type == "stream":
with pa.ipc.open_stream(buffer) as reader:
batch_with_metas = list(reader.iter_batches_with_custom_metadata())
else:
with pa.ipc.open_file(buffer) as reader:
batch_with_metas = [reader.get_batch_with_custom_metadata(i)
for i in range(reader.num_record_batches)]

for i in range(batch_count):
assert batch_with_metas[i].batch.num_rows == 1
assert isinstance(
batch_with_metas[i].custom_metadata, pa.KeyValueMetadata)
assert batch_with_metas[i].custom_metadata == {"batch_id": str(i)}

# the last batch has no custom metadata
assert batch_with_metas[batch_count].batch.num_rows == 1
assert batch_with_metas[batch_count].custom_metadata is None


def test_ipc_stream_no_batches():
# ARROW-2307
table = pa.Table.from_arrays([pa.array([1, 2, 3, 4]),