Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 25 additions & 6 deletions distributed/protocol/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,15 @@
import logging
import random

from dask.context import _globals
from toolz import identity, partial

try:
import blosc
n = blosc.set_nthreads(2)
except ImportError:
blosc = False

from ..config import config
from ..utils import ignoring, ensure_bytes

Expand Down Expand Up @@ -92,8 +99,7 @@ def byte_sample(b, size, n):
return b''.join(map(ensure_bytes, parts))


def maybe_compress(payload, compression=default_compression, min_size=1e4,
sample_size=1e4, nsamples=5):
def maybe_compress(payload, min_size=1e4, sample_size=1e4, nsamples=5):
"""
Maybe compress payload

Expand All @@ -104,11 +110,13 @@ def maybe_compress(payload, compression=default_compression, min_size=1e4,
return the original
4. We return the compressed result
"""
compression = _globals.get('compression', default_compression)

if not compression:
return None, payload
if len(payload) < min_size:
return None, payload
if len(payload) > 2**31:
if len(payload) > 2**31: # Too large, compression libraries often fail
return None, payload

min_size = int(min_size)
Expand All @@ -118,11 +126,22 @@ def maybe_compress(payload, compression=default_compression, min_size=1e4,

# Compress a sample, return original if not very compressed
sample = byte_sample(payload, sample_size, nsamples)
if len(compress(sample)) > 0.9 * len(sample): # not very compressible
if len(compress(sample)) > 0.9 * len(sample): # sample not very compressible
return None, payload

compressed = compress(ensure_bytes(payload))
if len(compressed) > 0.9 * len(payload): # not very compressible
if type(payload) is memoryview:
nbytes = payload.itemsize * len(payload)
else:
nbytes = len(payload)

if blosc and type(payload) is memoryview:
compressed = blosc.compress(payload, typesize=payload.itemsize,
cname='lz4', clevel=5)
compression = 'blosc'
else:
compressed = compress(ensure_bytes(payload))

if len(compressed) > 0.9 * nbytes: # full data not very compressible
return None, payload
else:
return compression, compressed
26 changes: 2 additions & 24 deletions distributed/protocol/numpy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import print_function, division, absolute_import

import sys

import numpy as np
from numpy.lib import stride_tricks

Expand All @@ -11,12 +9,11 @@
except ImportError:
blosc = False

from .compression import byte_sample
from .utils import frame_split_size
from .serialize import register_serialization
from . import pickle

from ..utils import log_errors, ensure_bytes
from ..utils import log_errors


def itemsize(dt):
Expand All @@ -36,8 +33,6 @@ def serialize_numpy_ndarray(x):
frames = [pickle.dumps(x)]
return header, frames

size = itemsize(x.dtype)

if x.dtype.kind == 'V':
dt = x.dtype.descr
else:
Expand All @@ -59,25 +54,8 @@ def serialize_numpy_ndarray(x):
'shape': x.shape,
'strides': strides}

if blosc and x.nbytes > 1e5:
if x.nbytes > 1e5:
frames = frame_split_size([data])
if sys.version_info.major == 2:
frames = [ensure_bytes(frame) for frame in frames]

out = []
compression = []
for frame in frames:
sample = byte_sample(frame, 10000 // size * size, 5)
csample = blosc.compress(sample, typesize=size, cname='lz4', clevel=3)
if len(csample) < 0.8 * len(sample):
compressed = blosc.compress(frame, typesize=size, cname='lz4', clevel=5)
out.append(compressed)
compression.append('blosc')
else:
out.append(frame)
compression.append(None)
header['compression'] = compression
frames = out
else:
frames = [data]

Expand Down
19 changes: 15 additions & 4 deletions distributed/protocol/tests/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest

from distributed.protocol import (serialize, deserialize, decompress, dumps,
loads, to_serialize)
loads, to_serialize, msgpack)
from distributed.protocol.utils import BIG_BYTES_SHARD_SIZE
from distributed.utils import tmpfile
from distributed.utils_test import slow, gen_cluster
Expand Down Expand Up @@ -115,10 +115,20 @@ def test_itemsize(dt, size):


def test_compress_numpy():
pytest.importorskip('lz4')
x = np.ones(10000000, dtype='i4')
compression, compressed = maybe_compress(x.data)
if compression:
assert len(compressed) < x.nbytes
frames = dumps({'x': to_serialize(x)})
assert sum(map(len, frames)) < x.nbytes

header = msgpack.loads(frames[2], encoding='utf8', use_list=False)
try:
import blosc
except ImportError:
pass
else:
assert all(c == 'blosc' for c in
header['headers'][('x',)]['compression'])



def test_compress_memoryview():
Expand All @@ -128,6 +138,7 @@ def test_compress_memoryview():
assert len(compressed) < len(mv)


@pytest.mark.skip
def test_dont_compress_uncompressable_data():
blosc = pytest.importorskip('blosc')
x = np.random.randint(0, 255, size=100000).astype('uint8')
Expand Down
30 changes: 24 additions & 6 deletions distributed/protocol/tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import sys

import dask
import pytest

from distributed.protocol import (loads, dumps, msgpack, maybe_compress,
Expand Down Expand Up @@ -63,20 +64,22 @@ def test_small_and_big():
def test_maybe_compress():
import zlib
payload = b'123'
assert maybe_compress(payload, None) == (None, payload)
assert maybe_compress(payload, 'zlib') == (None, payload)
with dask.set_options(compression=None):
assert maybe_compress(payload) == (None, payload)

assert maybe_compress(b'111', 'zlib') == (None, b'111')
with dask.set_options(compression='zlib'):
assert maybe_compress(payload) == (None, payload)
assert maybe_compress(b'111') == (None, b'111')

payload = b'0' * 10000
assert maybe_compress(payload, 'zlib') == ('zlib', zlib.compress(payload))
payload = b'0' * 10000
assert maybe_compress(payload) == ('zlib', zlib.compress(payload))


def test_maybe_compress_sample():
np = pytest.importorskip('numpy')
lz4 = pytest.importorskip('lz4')
payload = np.random.randint(0, 255, size=10000).astype('u1').tobytes()
fmt, compressed = maybe_compress(payload, 'lz4')
fmt, compressed = maybe_compress(payload)
assert fmt == None
assert compressed == payload

Expand Down Expand Up @@ -191,3 +194,18 @@ def test_dumps_loads_Serialized():

result3 = loads(frames2)
assert result == result3


def test_maybe_compress_memoryviews():
np = pytest.importorskip('numpy')
pytest.importorskip('lz4')
x = np.arange(1000000, dtype='int64')
compression, payload = maybe_compress(x.data)
try:
import blosc
except ImportError:
assert compression == 'lz4'
assert len(payload) < x.nbytes * 0.75
else:
assert compression == 'blosc'
assert len(payload) < x.nbytes / 10