Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@ docs/build
continuous_integration/hdfs-initialized
.cache
.#*
.idea/
.pytest_cache/
5 changes: 5 additions & 0 deletions distributed/protocol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,8 @@ def _register_keras():
@partial(register_serialization_lazy, "sparse")
def _register_sparse():
    # Lazy registration hook: importing the submodule has the side effect
    # of registering (de)serializers for ``sparse`` objects.
    from . import sparse


@partial(register_serialization_lazy, "pyarrow")
def _register_arrow():
    # Lazy registration hook: importing the submodule has the side effect
    # of registering (de)serializers for pyarrow RecordBatch/Table objects.
    from . import arrow
53 changes: 53 additions & 0 deletions distributed/protocol/arrow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from __future__ import print_function, division, absolute_import

from .serialize import register_serialization


def serialize_batch(batch):
    """Serialize a ``pyarrow.RecordBatch`` into ``(header, frames)``.

    pyarrow is imported lazily so this module can be registered even when
    the dependency is not installed.

    Returns
    -------
    header : dict
        Empty — all information lives in the Arrow stream itself.
    frames : list
        A single Arrow buffer holding the record-batch stream.
    """
    import pyarrow as pa
    sink = pa.BufferOutputStream()
    writer = pa.RecordBatchStreamWriter(sink, batch.schema)
    try:
        writer.write_batch(batch)
    finally:
        # Close even if write_batch raises, so the writer's resources are
        # released; on success this also finalizes the stream footer.
        writer.close()
    buf = sink.get_result()
    header = {}
    frames = [buf]
    return header, frames


def deserialize_batch(header, frames):
    """Reconstruct a ``pyarrow.RecordBatch`` from its serialized frames."""
    import pyarrow as pa
    # The single frame is the Arrow record-batch stream produced by
    # serialize_batch; wrap it in a BufferReader to stream it back.
    stream = pa.BufferReader(frames[0])
    reader = pa.RecordBatchStreamReader(stream)
    return reader.read_next_batch()


def serialize_table(tbl):
    """Serialize a ``pyarrow.Table`` into ``(header, frames)``.

    Mirrors ``serialize_batch``: the table is written as an Arrow
    record-batch stream into one buffer; the header carries no state.
    """
    import pyarrow as pa
    sink = pa.BufferOutputStream()
    writer = pa.RecordBatchStreamWriter(sink, tbl.schema)
    try:
        writer.write_table(tbl)
    finally:
        # Close even if write_table raises, so the writer's resources are
        # released; on success this also finalizes the stream footer.
        writer.close()
    buf = sink.get_result()
    header = {}
    frames = [buf]
    return header, frames


def deserialize_table(header, frames):
    """Reconstruct a ``pyarrow.Table`` from its serialized frames."""
    import pyarrow as pa
    # Read every batch in the stream back into a single Table.
    stream = pa.BufferReader(frames[0])
    reader = pa.RecordBatchStreamReader(stream)
    return reader.read_all()


# Register by fully-qualified typename (string) rather than by class object,
# so pyarrow is only imported when a RecordBatch/Table actually needs to be
# (de)serialized.
register_serialization(
    'pyarrow.lib.RecordBatch',
    serialize_batch,
    deserialize_batch
)
register_serialization(
    'pyarrow.lib.Table',
    serialize_table,
    deserialize_table
)
44 changes: 44 additions & 0 deletions distributed/protocol/tests/test_arrow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import pandas as pd
import pytest

pa = pytest.importorskip('pyarrow')

from distributed.utils_test import gen_cluster
from distributed.protocol import deserialize, serialize
from distributed.protocol.serialize import class_serializers, typename


# Shared module-level fixtures: a small DataFrame and its Arrow Table and
# RecordBatch equivalents (pandas index deliberately dropped).
df = pd.DataFrame({'A': list('abc'), 'B': [1,2,3]})
tbl = pa.Table.from_pandas(df, preserve_index=False)
batch = pa.RecordBatch.from_pandas(df, preserve_index=False)


@pytest.mark.parametrize('obj', [batch, tbl], ids=["RecordBatch", "Table"])
def test_roundtrip(obj):
    """serialize/deserialize round-trips Arrow objects without a cluster."""
    restored = deserialize(*serialize(obj))
    assert obj.equals(restored)


@pytest.mark.parametrize('obj', [batch, tbl], ids=["RecordBatch", "Table"])
def test_typename(obj):
    """The registration typenames are hardcoded strings; catch drift."""
    name = typename(type(obj))
    assert name in class_serializers


def echo(arg):
    """Identity task used to round-trip a scattered object through a worker."""
    return arg


@pytest.mark.parametrize('obj', [batch, tbl], ids=["RecordBatch", "Table"])
def test_scatter(obj):
    # End-to-end check through a real (in-process) cluster: scatter the Arrow
    # object to a worker, run a task that echoes it back, and verify the copy
    # received on the client side equals the original.
    @gen_cluster(client=True)
    def run_test(client, scheduler, worker1, worker2):
        obj_fut = yield client.scatter(obj)
        fut = client.submit(echo, obj_fut)
        result = yield fut
        assert obj.equals(result)
    run_test()
5 changes: 4 additions & 1 deletion distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1266,7 +1266,10 @@ def nbytes(frame, _bytes_like=(bytes, bytearray)):
if isinstance(frame, _bytes_like):
return len(frame)
else:
return frame.nbytes
try:
return frame.nbytes
except AttributeError:
return len(frame)


def PeriodicCallback(callback, callback_time, io_loop=None):
Expand Down