diff --git a/pymemcache/serde.py b/pymemcache/serde.py index 23cbd7b9..f3cf55ac 100644 --- a/pymemcache/serde.py +++ b/pymemcache/serde.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial import logging from io import BytesIO import six @@ -30,16 +31,16 @@ FLAG_COMPRESSED = 1 << 3 # unused, to main compatibility with python-memcached FLAG_TEXT = 1 << 4 -# Pickle protocol version (-1 for highest available to runtime) +# Pickle protocol version (highest available to runtime) # Warning with `0`: If somewhere in your value lies a slotted object, # ie defines `__slots__`, even if you do not include it in your pickleable # state via `__getstate__`, python will complain with something like: # TypeError: a class that defines __slots__ without defining __getstate__ # cannot be pickled -PICKLE_VERSION = -1 +DEFAULT_PICKLE_VERSION = pickle.HIGHEST_PROTOCOL -def python_memcache_serializer(key, value): +def _python_memcache_serializer(key, value, pickle_version=None): flags = 0 value_type = type(value) @@ -63,13 +64,22 @@ def python_memcache_serializer(key, value): else: flags |= FLAG_PICKLE output = BytesIO() - pickler = pickle.Pickler(output, PICKLE_VERSION) + pickler = pickle.Pickler(output, pickle_version) pickler.dump(value) value = output.getvalue() return value, flags +def get_python_memcache_serializer_pickle_version(pickle_version=None): + """Return a serializer using a specific pickle version""" + return partial(_python_memcache_serializer, pickle_version=pickle_version) + + +python_memcache_serializer = partial( + _python_memcache_serializer, pickle_version=DEFAULT_PICKLE_VERSION) + + def python_memcache_deserializer(key, value, flags): if flags == 0: return value diff --git a/pymemcache/test/test_integration.py b/pymemcache/test/test_integration.py index 4a827cdf..c38e33fd 100644 --- a/pymemcache/test/test_integration.py +++ b/pymemcache/test/test_integration.py @@ -26,6 +26,7 @@ MemcacheClientError ) from pymemcache.serde import ( + get_python_memcache_serializer_pickle_version, python_memcache_serializer, python_memcache_deserializer ) @@ -250,15 +251,15 @@ def _des(key, value, flags): assert result == value -@pytest.mark.integration() -def test_serde_serialization(client_class, host, port, socket_module): +def serde_serialization_helper(client_class, host, port, + socket_module, serializer): def check(value): client.set(b'key', value, noreply=False) result = client.get(b'key') assert result == value assert type(result) is type(value) - client = client_class((host, port), serializer=python_memcache_serializer, + client = client_class((host, port), serializer=serializer, deserializer=python_memcache_deserializer, socket_module=socket_module) client.flush_all() @@ -277,6 +278,28 @@ def check(value): check(testdict) +@pytest.mark.integration() +def test_serde_serialization(client_class, host, port, socket_module): + serde_serialization_helper(client_class, host, port, + socket_module, python_memcache_serializer) + + +@pytest.mark.integration() +def test_serde_serialization0(client_class, host, port, socket_module): + serde_serialization_helper( + client_class, host, port, + socket_module, + get_python_memcache_serializer_pickle_version(pickle_version=0)) + + +@pytest.mark.integration() +def test_serde_serialization2(client_class, host, port, socket_module): + serde_serialization_helper( + client_class, host, port, + socket_module, + get_python_memcache_serializer_pickle_version(pickle_version=2)) + + @pytest.mark.integration() def test_errors(client_class, host, port, socket_module): client = client_class((host, port), socket_module=socket_module) diff --git a/pymemcache/test/test_serde.py b/pymemcache/test/test_serde.py index 04c8e079..538e177b 100644 --- a/pymemcache/test/test_serde.py +++ b/pymemcache/test/test_serde.py @@ -2,10 +2,12 @@ from unittest import TestCase from pymemcache.serde import (python_memcache_serializer, + get_python_memcache_serializer_pickle_version, python_memcache_deserializer, FLAG_BYTES, FLAG_PICKLE, FLAG_INTEGER, FLAG_LONG, FLAG_TEXT) import pytest import six +from six.moves import cPickle as pickle class CustomInt(int): @@ -20,9 +22,10 @@ class CustomInt(int): @pytest.mark.unit() class TestSerde(TestCase): + serializer = python_memcache_serializer def check(self, value, expected_flags): - serialized, flags = python_memcache_serializer(b'key', value) + serialized, flags = self.serializer(b'key', value) assert flags == expected_flags # pymemcache stores values as byte strings, so we immediately the value @@ -59,3 +62,24 @@ def test_pickleable(self): def test_subtype(self): # Subclass of a native type will be restored as the same type self.check(CustomInt(123123), FLAG_PICKLE) + + +@pytest.mark.unit() +class TestSerdePickleVersion0(TestCase): + serializer = get_python_memcache_serializer_pickle_version(pickle_version=0) + + +@pytest.mark.unit() +class TestSerdePickleVersion1(TestCase): + serializer = get_python_memcache_serializer_pickle_version(pickle_version=1) + + +@pytest.mark.unit() +class TestSerdePickleVersion2(TestCase): + serializer = get_python_memcache_serializer_pickle_version(pickle_version=2) + + +@pytest.mark.unit() +class TestSerdePickleVersionHighest(TestCase): + serializer = get_python_memcache_serializer_pickle_version( + pickle_version=pickle.HIGHEST_PROTOCOL)