Skip to content

Commit

Permalink
Merge pull request #407 from sontek/add_compression_serde
Browse files Browse the repository at this point in the history
add pluggable compression serde
  • Loading branch information
jogo authored Sep 12, 2022
2 parents 3f568ec + dc325b2 commit 6b85dea
Show file tree
Hide file tree
Showing 6 changed files with 443 additions and 36 deletions.
55 changes: 51 additions & 4 deletions pymemcache/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@
import logging
from io import BytesIO
import pickle
import zlib

# Bit flags that serializers return together with a value and that are handed
# back to deserializers so they know how the value was encoded.
FLAG_BYTES = 0
FLAG_PICKLE = 1 << 0
FLAG_INTEGER = 1 << 1
FLAG_LONG = 1 << 2
FLAG_COMPRESSED = 1 << 3 # unused, to maintain compatibility with python-memcached
# NOTE(review): the assignment above looks like the pre-change line of this
# diff; the line below re-binds the same value -- confirm only one should stay.
FLAG_COMPRESSED = 1 << 3
FLAG_TEXT = 1 << 4

# Pickle protocol version (highest available to runtime)
Expand Down Expand Up @@ -121,6 +122,55 @@ def deserialize(self, key, value, flags):
return python_memcache_deserializer(key, value, flags)


# Shared module-level instance. It is bound before CompressedSerde is defined
# because that class uses it as a default constructor argument.
pickle_serde = PickleSerde()


class CompressedSerde:
    """
    An object which implements the serialization/deserialization protocol for
    :py:class:`pymemcache.client.base.Client` and its descendants with
    configurable compression.

    Values are first serialized by the wrapped ``serde``. The serialized
    payload is then compressed when it is longer than ``min_compress_len``
    and compression does not make it bigger; such payloads are marked with
    ``FLAG_COMPRESSED``.
    """

    def __init__(
        self,
        compress=zlib.compress,
        decompress=zlib.decompress,
        serde=pickle_serde,
        # Discovered via the `test_optimal_compression_length` test.
        min_compress_len=400,
    ):
        """
        :param compress: callable taking the serialized payload and returning
            its compressed form.
        :param decompress: callable reversing ``compress``.
        :param serde: wrapped serde providing ``serialize(key, value)`` and
            ``deserialize(key, value, flags)``.
        :param min_compress_len: payloads must be longer than this to be
            compressed; ``0`` (or a negative number) disables compression.
        """
        self._serde = serde
        self._compress = compress
        self._decompress = decompress
        self._min_compress_len = min_compress_len

    def serialize(self, key, value):
        """
        Serialize ``value`` with the wrapped serde and compress the result
        when that is worthwhile.

        :return: ``(payload, flags)``; ``flags`` has ``FLAG_COMPRESSED`` set
            only when ``payload`` is the compressed form.
        """
        value, flags = self._serde.serialize(key, value)

        # The chained comparison also requires a positive threshold, so a
        # threshold of 0 turns compression off.
        if len(value) > self._min_compress_len > 0:
            compressed = self._compress(value)
            # Keep the compressed payload only when it is not larger than
            # the uncompressed one.
            if len(compressed) <= len(value):
                value = compressed
                flags |= FLAG_COMPRESSED

        return value, flags

    def deserialize(self, key, value, flags):
        """
        Decompress ``value`` when ``flags`` marks it as compressed, then
        delegate to the wrapped serde.
        """
        if flags & FLAG_COMPRESSED:
            value = self._decompress(value)
            # The wrapped serde never produced this bit; clear it so the
            # serde receives exactly the flags it returned from serialize().
            flags &= ~FLAG_COMPRESSED

        return self._serde.deserialize(key, value, flags)


# Ready-to-use instance built with the constructor defaults
# (zlib compress/decompress, pickle_serde, min_compress_len=400).
compressed_serde = CompressedSerde()


class LegacyWrappingSerde:
"""
This class defines how to wrap legacy de/serialization functions into a
Expand All @@ -141,6 +191,3 @@ def _default_serialize(self, key, value):

def _default_deserialize(self, key, value, flags):
return value


pickle_serde = PickleSerde()
8 changes: 5 additions & 3 deletions pymemcache/test/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,16 @@ def benchmark(count, func, *args, **kwargs):

@pytest.mark.benchmark()
def test_bench_get(request, client, pairs, count):
    """Benchmark ``client.get`` on a single key that is stored beforehand."""
    # `pairs` is indexed like a mapping here (presumably a dict fixture --
    # confirm against conftest), so it cannot be consumed with `next()`.
    # A fixed key is looked up instead.
    key = "pymemcache_test:0"
    value = pairs[key]
    client.set(key, value)
    benchmark(count, client.get, key)


@pytest.mark.benchmark()
def test_bench_set(request, client, pairs, count):
    """Benchmark ``client.set`` by repeatedly storing one key/value pair."""
    # `next()` cannot be applied to a dict view (it is not an iterator), so
    # a fixed key is looked up in `pairs` instead.
    key = "pymemcache_test:0"
    value = pairs[key]
    benchmark(count, client.set, key, value)


Expand All @@ -113,4 +115,4 @@ def test_bench_delete(request, client, pairs, count):
@pytest.mark.benchmark()
def test_bench_delete_multi(request, client, pairs, count):
    """Benchmark ``client.delete_multi`` over every key in ``pairs``."""
    # deleting missing key takes the same work client-side as real keys
    benchmark(count, client.delete_multi, list(pairs.keys()))
220 changes: 220 additions & 0 deletions pymemcache/test/test_compression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
from pymemcache.client.base import Client
from pymemcache.serde import (
CompressedSerde,
pickle_serde,
)

from faker import Faker

import pytest
import random
import string
import time
import zstd # type: ignore
import zlib

# Shared Faker instance configured with the it_IT, en_US and ja_JP locales;
# the fixtures and CustomObject below draw their sample data from it.
fake = Faker(["it_IT", "en_US", "ja_JP"])


def get_random_string(length):
    """Return a random string of ``length`` characters.

    Characters are drawn from ASCII letters, punctuation and digits.
    """
    alphabet = string.ascii_letters + string.punctuation + string.digits
    picked = [random.choice(alphabet) for _ in range(length)]
    return "".join(picked)


class CustomObject:
    """
    Sample class with randomly generated attributes, used for verifying
    serialization of arbitrary objects.
    """

    def __init__(self):
        # Generate the data first (same order: number, text, profile), then
        # bind the attributes in one step.
        number = random.randint(0, 100)
        text = fake.text()
        profile = fake.profile()
        self.number, self.string, self.object = number, text, profile


class CustomObjectValue:
    """Minimal wrapper object that holds a single ``value`` attribute."""

    def __init__(self, value):
        self.value = value


def benchmark(count, func, *args, **kwargs):
    """Call ``func(*args, **kwargs)`` ``count`` times and print the elapsed seconds.

    Returns the result of the last call, or ``None`` when ``count`` is not
    positive and ``func`` is therefore never called.
    """
    # Pre-bind the result so a zero-iteration run does not raise
    # UnboundLocalError at the return statement.
    result = None
    # perf_counter is a monotonic, high-resolution clock meant for measuring
    # intervals; wall-clock time can jump while the loop is running.
    start = time.perf_counter()

    for _ in range(count):
        result = func(*args, **kwargs)

    duration = time.perf_counter() - start
    print(str(duration))

    return result


@pytest.fixture(scope="session")
def names():
    """Session-scoped list of 15 values produced by ``fake.name()``."""
    return [fake.name() for _ in range(15)]


@pytest.fixture(scope="session")
def paragraphs():
    """Session-scoped list of 15 values produced by ``fake.text()``."""
    return [fake.text() for _ in range(15)]


@pytest.fixture(scope="session")
def objects():
    """Session-scoped list of 15 freshly built ``CustomObject`` instances."""
    return [CustomObject() for _ in range(15)]


# Always run compression for the benchmarks
min_compress_len = 1

# Constructor-default compress/decompress callables with the low threshold.
default_serde = CompressedSerde(min_compress_len=min_compress_len)

# zlib with the compression level raised to 9.
zlib_serde = CompressedSerde(
    compress=lambda data: zlib.compress(data, 9),
    decompress=zlib.decompress,
    min_compress_len=min_compress_len,
)

# zstd with the library's default arguments.
zstd_serde = CompressedSerde(
    compress=zstd.compress,
    decompress=zstd.decompress,
    min_compress_len=min_compress_len,
)

# `ids` labels the entries of `serializers` position by position; `None`
# means the client is created without a serde.
serializers = [None, default_serde, zlib_serde, zstd_serde]
ids = ["none", "zlib ", "zlib9", "zstd "]


@pytest.mark.benchmark()
@pytest.mark.parametrize("serde", serializers, ids=ids)
def test_bench_compress_set_strings(count, host, port, serde, names):
    """Benchmark storing every entry of ``names`` with the given serde."""
    client = Client((host, port), serde=serde, encoding="utf-8")

    def store_all():
        # One pass stores each name under its positional key.
        for index, text in enumerate(names):
            client.set(f"name_{index}", text)

    benchmark(count, store_all)


@pytest.mark.benchmark()
@pytest.mark.parametrize("serde", serializers, ids=ids)
def test_bench_compress_get_strings(count, host, port, serde, names):
    """Benchmark reading back every entry of ``names`` with the given serde."""
    client = Client((host, port), serde=serde, encoding="utf-8")

    # Populate the keys once, outside the measured function.
    for index, text in enumerate(names):
        client.set(f"name_{index}", text)

    def fetch_all():
        for index in range(len(names)):
            client.get(f"name_{index}")

    benchmark(count, fetch_all)


@pytest.mark.benchmark()
@pytest.mark.parametrize("serde", serializers, ids=ids)
def test_bench_compress_set_large_strings(count, host, port, serde, paragraphs):
    """Benchmark storing every entry of ``paragraphs`` with the given serde."""
    client = Client((host, port), serde=serde, encoding="utf-8")

    def store_all():
        # One pass stores each paragraph under its positional key.
        for index, text in enumerate(paragraphs):
            client.set(f"paragraph_{index}", text)

    benchmark(count, store_all)


@pytest.mark.benchmark()
@pytest.mark.parametrize("serde", serializers, ids=ids)
def test_bench_compress_get_large_strings(count, host, port, serde, paragraphs):
    """Benchmark reading back every entry of ``paragraphs`` with the given serde."""
    client = Client((host, port), serde=serde, encoding="utf-8")

    # Populate the keys once, outside the measured function.
    for index, text in enumerate(paragraphs):
        client.set(f"paragraphs_{index}", text)

    def fetch_all():
        for index in range(len(paragraphs)):
            client.get(f"paragraphs_{index}")

    benchmark(count, fetch_all)


@pytest.mark.benchmark()
@pytest.mark.parametrize("serde", serializers, ids=ids)
def test_bench_compress_set_objects(count, host, port, serde, objects):
    """Benchmark storing every entry of ``objects`` with the given serde."""
    client = Client((host, port), serde=serde, encoding="utf-8")

    def store_all():
        # One pass stores each object under its positional key.
        for index, item in enumerate(objects):
            client.set(f"objects_{index}", item)

    benchmark(count, store_all)


@pytest.mark.benchmark()
@pytest.mark.parametrize("serde", serializers, ids=ids)
def test_bench_compress_get_objects(count, host, port, serde, objects):
    """Benchmark reading back every entry of ``objects`` with the given serde."""
    client = Client((host, port), serde=serde, encoding="utf-8")

    # Populate the keys once, outside the measured function.
    for index, item in enumerate(objects):
        client.set(f"objects_{index}", item)

    def fetch_all():
        for index in range(len(objects)):
            client.get(f"objects_{index}")

    benchmark(count, fetch_all)


@pytest.mark.benchmark()
def test_optimal_compression_length():
    """Print input and serialized sizes for random strings of 5..1999 chars.

    Every serde except the leading ``None`` entry is reported for each length.
    """
    # Pair each serde with its label once; the first entry (no serde) is skipped.
    labelled = list(zip(ids[1:], serializers[1:]))

    for length in range(5, 2000):
        input_data = get_random_string(length)
        start = len(input_data)

        for name, serializer in labelled:
            payload, _ = serializer.serialize("foo", input_data)
            end = len(payload)
            print(f"serializer={name}\t start={start}\t end={end}")


@pytest.mark.benchmark()
def test_optimal_compression_length_objects():
    """Print pickled and serialized sizes for objects wrapping random strings.

    The baseline size is the length of the ``pickle_serde`` payload; every
    serde except the leading ``None`` entry is reported for each length.
    """
    # Pair each serde with its label once; the first entry (no serde) is skipped.
    labelled = list(zip(ids[1:], serializers[1:]))

    for length in range(5, 2000):
        input_data = get_random_string(length)
        obj = CustomObjectValue(input_data)
        baseline, _ = pickle_serde.serialize("foo", obj)
        start = len(baseline)

        for name, serializer in labelled:
            payload, _ = serializer.serialize("foo", obj)
            end = len(payload)
            print(f"serializer={name}\t start={start}\t end={end}")
Loading

0 comments on commit 6b85dea

Please sign in to comment.