diff --git a/.gitignore b/.gitignore
index 7e110237d4b..7510e74bcbf 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,3 +6,5 @@ docs/build
 continuous_integration/hdfs-initialized
 .cache
 .#*
+.idea/
+.pytest_cache/
diff --git a/distributed/protocol/__init__.py b/distributed/protocol/__init__.py
index a6a9afaf324..01ac7e8464a 100644
--- a/distributed/protocol/__init__.py
+++ b/distributed/protocol/__init__.py
@@ -37,3 +37,8 @@ def _register_keras():
 @partial(register_serialization_lazy, "sparse")
 def _register_sparse():
     from . import sparse
+
+
+@partial(register_serialization_lazy, "pyarrow")
+def _register_arrow():
+    from . import arrow
diff --git a/distributed/protocol/arrow.py b/distributed/protocol/arrow.py
new file mode 100644
index 00000000000..87c5d05c99f
--- /dev/null
+++ b/distributed/protocol/arrow.py
@@ -0,0 +1,53 @@
+from __future__ import print_function, division, absolute_import
+
+from .serialize import register_serialization
+
+
+def serialize_batch(batch):
+    import pyarrow as pa
+    sink = pa.BufferOutputStream()
+    writer = pa.RecordBatchStreamWriter(sink, batch.schema)
+    writer.write_batch(batch)
+    writer.close()
+    buf = sink.get_result()
+    header = {}
+    frames = [buf]
+    return header, frames
+
+
+def deserialize_batch(header, frames):
+    import pyarrow as pa
+    blob = frames[0]
+    reader = pa.RecordBatchStreamReader(pa.BufferReader(blob))
+    return reader.read_next_batch()
+
+
+def serialize_table(tbl):
+    import pyarrow as pa
+    sink = pa.BufferOutputStream()
+    writer = pa.RecordBatchStreamWriter(sink, tbl.schema)
+    writer.write_table(tbl)
+    writer.close()
+    buf = sink.get_result()
+    header = {}
+    frames = [buf]
+    return header, frames
+
+
+def deserialize_table(header, frames):
+    import pyarrow as pa
+    blob = frames[0]
+    reader = pa.RecordBatchStreamReader(pa.BufferReader(blob))
+    return reader.read_all()
+
+
+register_serialization(
+    'pyarrow.lib.RecordBatch',
+    serialize_batch,
+    deserialize_batch
+)
+register_serialization(
+    'pyarrow.lib.Table',
+    serialize_table,
+    deserialize_table
+)
diff --git a/distributed/protocol/tests/test_arrow.py b/distributed/protocol/tests/test_arrow.py
new file mode 100644
index 00000000000..6f014bae323
--- /dev/null
+++ b/distributed/protocol/tests/test_arrow.py
@@ -0,0 +1,44 @@
+import pandas as pd
+import pytest
+
+pa = pytest.importorskip('pyarrow')
+
+from distributed.utils_test import gen_cluster
+from distributed.protocol import deserialize, serialize
+from distributed.protocol.serialize import class_serializers, typename
+
+
+df = pd.DataFrame({'A': list('abc'), 'B': [1,2,3]})
+tbl = pa.Table.from_pandas(df, preserve_index=False)
+batch = pa.RecordBatch.from_pandas(df, preserve_index=False)
+
+
+@pytest.mark.parametrize('obj', [batch, tbl], ids=["RecordBatch", "Table"])
+def test_roundtrip(obj):
+    # Test that the serialize/deserialize functions actually
+    # work independent of distributed
+    header, frames = serialize(obj)
+    new_obj = deserialize(header, frames)
+    assert obj.equals(new_obj)
+
+
+@pytest.mark.parametrize('obj', [batch, tbl], ids=["RecordBatch", "Table"])
+def test_typename(obj):
+    # The typename used to register the custom serialization is hardcoded
+    # ensure that the typename hasn't changed
+    assert typename(type(obj)) in class_serializers
+
+
+def echo(arg):
+    return arg
+
+
+@pytest.mark.parametrize('obj', [batch, tbl], ids=["RecordBatch", "Table"])
+def test_scatter(obj):
+    @gen_cluster(client=True)
+    def run_test(client, scheduler, worker1, worker2):
+        obj_fut = yield client.scatter(obj)
+        fut = client.submit(echo, obj_fut)
+        result = yield fut
+        assert obj.equals(result)
+    run_test()
diff --git a/distributed/utils.py b/distributed/utils.py
index 53495ba60b4..666bb9dd26d 100644
--- a/distributed/utils.py
+++ b/distributed/utils.py
@@ -1266,7 +1266,10 @@ def nbytes(frame, _bytes_like=(bytes, bytearray)):
     if isinstance(frame, _bytes_like):
         return len(frame)
     else:
-        return frame.nbytes
+        try:
+            return frame.nbytes
+        except AttributeError:
+            return len(frame)
 
 
 def PeriodicCallback(callback, callback_time, io_loop=None):
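For context (not part of the patch): a minimal usage sketch of what this change enables from the client side. It assumes a scheduler is already running at the hypothetical address shown and that pandas and pyarrow are installed; client.scatter() goes through distributed.protocol.serialize(), which after this PR dispatches on the registered 'pyarrow.lib.Table' and 'pyarrow.lib.RecordBatch' typenames.

# Usage sketch only; not part of this diff.
import pandas as pd
import pyarrow as pa
from distributed import Client

client = Client('tcp://127.0.0.1:8786')  # hypothetical scheduler address

df = pd.DataFrame({'A': list('abc'), 'B': [1, 2, 3]})
tbl = pa.Table.from_pandas(df, preserve_index=False)

# scatter() serializes the Table via the handler registered in
# distributed/protocol/arrow.py and ships the Arrow stream bytes to a worker.
fut = client.scatter(tbl)
n_rows = client.submit(lambda t: t.num_rows, fut).result()
assert n_rows == 3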