diff --git a/distributed/protocol/compression.py b/distributed/protocol/compression.py index fd1763e5131..c4dd80d22be 100644 --- a/distributed/protocol/compression.py +++ b/distributed/protocol/compression.py @@ -8,8 +8,15 @@ import logging import random +from dask.context import _globals from toolz import identity, partial +try: + import blosc + n = blosc.set_nthreads(2) +except ImportError: + blosc = False + from ..config import config from ..utils import ignoring, ensure_bytes @@ -92,8 +99,7 @@ def byte_sample(b, size, n): return b''.join(map(ensure_bytes, parts)) -def maybe_compress(payload, compression=default_compression, min_size=1e4, - sample_size=1e4, nsamples=5): +def maybe_compress(payload, min_size=1e4, sample_size=1e4, nsamples=5): """ Maybe compress payload @@ -104,11 +110,13 @@ def maybe_compress(payload, compression=default_compression, min_size=1e4, return the original 4. We return the compressed result """ + compression = _globals.get('compression', default_compression) + if not compression: return None, payload if len(payload) < min_size: return None, payload - if len(payload) > 2**31: + if len(payload) > 2**31: # Too large, compression libraries often fail return None, payload min_size = int(min_size) @@ -118,11 +126,22 @@ def maybe_compress(payload, compression=default_compression, min_size=1e4, # Compress a sample, return original if not very compressed sample = byte_sample(payload, sample_size, nsamples) - if len(compress(sample)) > 0.9 * len(sample): # not very compressible + if len(compress(sample)) > 0.9 * len(sample): # sample not very compressible return None, payload - compressed = compress(ensure_bytes(payload)) - if len(compressed) > 0.9 * len(payload): # not very compressible + if type(payload) is memoryview: + nbytes = payload.itemsize * len(payload) + else: + nbytes = len(payload) + + if blosc and type(payload) is memoryview: + compressed = blosc.compress(payload, typesize=payload.itemsize, + cname='lz4', clevel=5) + compression = 'blosc' + else: + compressed = compress(ensure_bytes(payload)) + + if len(compressed) > 0.9 * nbytes: # full data not very compressible return None, payload else: return compression, compressed diff --git a/distributed/protocol/numpy.py b/distributed/protocol/numpy.py index 7960c9ba322..6827cda7875 100644 --- a/distributed/protocol/numpy.py +++ b/distributed/protocol/numpy.py @@ -1,7 +1,5 @@ from __future__ import print_function, division, absolute_import -import sys - import numpy as np from numpy.lib import stride_tricks @@ -11,12 +9,11 @@ except ImportError: blosc = False -from .compression import byte_sample from .utils import frame_split_size from .serialize import register_serialization from . import pickle -from ..utils import log_errors, ensure_bytes +from ..utils import log_errors def itemsize(dt): @@ -36,8 +33,6 @@ def serialize_numpy_ndarray(x): frames = [pickle.dumps(x)] return header, frames - size = itemsize(x.dtype) - if x.dtype.kind == 'V': dt = x.dtype.descr else: @@ -59,25 +54,8 @@ def serialize_numpy_ndarray(x): 'shape': x.shape, 'strides': strides} - if blosc and x.nbytes > 1e5: + if x.nbytes > 1e5: frames = frame_split_size([data]) - if sys.version_info.major == 2: - frames = [ensure_bytes(frame) for frame in frames] - - out = [] - compression = [] - for frame in frames: - sample = byte_sample(frame, 10000 // size * size, 5) - csample = blosc.compress(sample, typesize=size, cname='lz4', clevel=3) - if len(csample) < 0.8 * len(sample): - compressed = blosc.compress(frame, typesize=size, cname='lz4', clevel=5) - out.append(compressed) - compression.append('blosc') - else: - out.append(frame) - compression.append(None) - header['compression'] = compression - frames = out else: frames = [data] diff --git a/distributed/protocol/tests/test_numpy.py b/distributed/protocol/tests/test_numpy.py index 47373f54852..507abe91609 100644 --- a/distributed/protocol/tests/test_numpy.py +++ b/distributed/protocol/tests/test_numpy.py @@ -6,7 +6,7 @@ import pytest from distributed.protocol import (serialize, deserialize, decompress, dumps, - loads, to_serialize) + loads, to_serialize, msgpack) from distributed.protocol.utils import BIG_BYTES_SHARD_SIZE from distributed.utils import tmpfile from distributed.utils_test import slow, gen_cluster @@ -115,10 +115,20 @@ def test_itemsize(dt, size): def test_compress_numpy(): + pytest.importorskip('lz4') x = np.ones(10000000, dtype='i4') - compression, compressed = maybe_compress(x.data) - if compression: - assert len(compressed) < x.nbytes + frames = dumps({'x': to_serialize(x)}) + assert sum(map(len, frames)) < x.nbytes + + header = msgpack.loads(frames[2], encoding='utf8', use_list=False) + try: + import blosc + except ImportError: + pass + else: + assert all(c == 'blosc' for c in + header['headers'][('x',)]['compression']) + def test_compress_memoryview(): @@ -128,6 +138,7 @@ def test_compress_memoryview(): assert len(compressed) < len(mv) +@pytest.mark.skip def test_dont_compress_uncompressable_data(): blosc = pytest.importorskip('blosc') x = np.random.randint(0, 255, size=100000).astype('uint8') diff --git a/distributed/protocol/tests/test_protocol.py b/distributed/protocol/tests/test_protocol.py index 3e31029ba65..9a57f7131a8 100644 --- a/distributed/protocol/tests/test_protocol.py +++ b/distributed/protocol/tests/test_protocol.py @@ -2,6 +2,7 @@ import sys +import dask import pytest from distributed.protocol import (loads, dumps, msgpack, maybe_compress, @@ -63,20 +64,22 @@ def test_small_and_big(): def test_maybe_compress(): import zlib payload = b'123' - assert maybe_compress(payload, None) == (None, payload) - assert maybe_compress(payload, 'zlib') == (None, payload) + with dask.set_options(compression=None): + assert maybe_compress(payload) == (None, payload) - assert maybe_compress(b'111', 'zlib') == (None, b'111') + with dask.set_options(compression='zlib'): + assert maybe_compress(payload) == (None, payload) + assert maybe_compress(b'111') == (None, b'111') - payload = b'0' * 10000 - assert maybe_compress(payload, 'zlib') == ('zlib', zlib.compress(payload)) + payload = b'0' * 10000 + assert maybe_compress(payload) == ('zlib', zlib.compress(payload)) def test_maybe_compress_sample(): np = pytest.importorskip('numpy') lz4 = pytest.importorskip('lz4') payload = np.random.randint(0, 255, size=10000).astype('u1').tobytes() - fmt, compressed = maybe_compress(payload, 'lz4') + fmt, compressed = maybe_compress(payload) assert fmt == None assert compressed == payload @@ -191,3 +194,18 @@ def test_dumps_loads_Serialized(): result3 = loads(frames2) assert result == result3 + + +def test_maybe_compress_memoryviews(): + np = pytest.importorskip('numpy') + pytest.importorskip('lz4') + x = np.arange(1000000, dtype='int64') + compression, payload = maybe_compress(x.data) + try: + import blosc + except ImportError: + assert compression == 'lz4' + assert len(payload) < x.nbytes * 0.75 + else: + assert compression == 'blosc' + assert len(payload) < x.nbytes / 10