Skip to content

Commit

Permalink
Make pickle version for python_memcache_serializer adjustable
Browse files Browse the repository at this point in the history
It's unsafe to use the max pickle version when you are switching between
versions of python with different max versions.

Add a new function get_python_memcache_serializer_pickle_version that
returns a python_memcache_serializer with any pickle version.
  • Loading branch information
jogo committed Sep 7, 2018
1 parent 9e2a42c commit d77679a
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 8 deletions.
18 changes: 14 additions & 4 deletions pymemcache/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
import logging
from io import BytesIO
import six
Expand All @@ -30,16 +31,16 @@
FLAG_COMPRESSED = 1 << 3 # unused, to main compatibility with python-memcached
FLAG_TEXT = 1 << 4

# Pickle protocol version (-1 for highest available to runtime)
# Pickle protocol version (highest available to runtime)
# Warning with `0`: If somewhere in your value lies a slotted object,
# ie defines `__slots__`, even if you do not include it in your pickleable
# state via `__getstate__`, python will complain with something like:
# TypeError: a class that defines __slots__ without defining __getstate__
# cannot be pickled
PICKLE_VERSION = -1
DEFAULT_PICKLE_VERSION = pickle.HIGHEST_PROTOCOL


def python_memcache_serializer(key, value):
def _python_memcache_serializer(key, value, pickle_version=None):
flags = 0
value_type = type(value)

Expand All @@ -63,13 +64,22 @@ def python_memcache_serializer(key, value):
else:
flags |= FLAG_PICKLE
output = BytesIO()
pickler = pickle.Pickler(output, PICKLE_VERSION)
pickler = pickle.Pickler(output, pickle_version)
pickler.dump(value)
value = output.getvalue()

return value, flags


def get_python_memcache_serializer_pickle_version(pickle_version=None):
"""Return a serializer using a specific pickle version"""
return partial(_python_memcache_serializer, pickle_version=pickle_version)


python_memcache_serializer = partial(
_python_memcache_serializer, pickle_version=DEFAULT_PICKLE_VERSION)


def python_memcache_deserializer(key, value, flags):
if flags == 0:
return value
Expand Down
29 changes: 26 additions & 3 deletions pymemcache/test/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
MemcacheClientError
)
from pymemcache.serde import (
get_python_memcache_serializer_pickle_version,
python_memcache_serializer,
python_memcache_deserializer
)
Expand Down Expand Up @@ -250,15 +251,15 @@ def _des(key, value, flags):
assert result == value


@pytest.mark.integration()
def test_serde_serialization(client_class, host, port, socket_module):
def serde_serialization_helper(client_class, host, port,
socket_module, serializer):
def check(value):
client.set(b'key', value, noreply=False)
result = client.get(b'key')
assert result == value
assert type(result) is type(value)

client = client_class((host, port), serializer=python_memcache_serializer,
client = client_class((host, port), serializer=serializer,
deserializer=python_memcache_deserializer,
socket_module=socket_module)
client.flush_all()
Expand All @@ -277,6 +278,28 @@ def check(value):
check(testdict)


@pytest.mark.integration()
def test_serde_serialization(client_class, host, port, socket_module):
serde_serialization_helper(client_class, host, port,
socket_module, python_memcache_serializer)


@pytest.mark.integration()
def test_serde_serialization0(client_class, host, port, socket_module):
serde_serialization_helper(
client_class, host, port,
socket_module,
get_python_memcache_serializer_pickle_version(pickle_version=0))


@pytest.mark.integration()
def test_serde_serialization2(client_class, host, port, socket_module):
serde_serialization_helper(
client_class, host, port,
socket_module,
get_python_memcache_serializer_pickle_version(pickle_version=2))


@pytest.mark.integration()
def test_errors(client_class, host, port, socket_module):
client = client_class((host, port), socket_module=socket_module)
Expand Down
26 changes: 25 additions & 1 deletion pymemcache/test/test_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from unittest import TestCase

from pymemcache.serde import (python_memcache_serializer,
get_python_memcache_serializer_pickle_version,
python_memcache_deserializer, FLAG_BYTES,
FLAG_PICKLE, FLAG_INTEGER, FLAG_LONG, FLAG_TEXT)
import pytest
import six
from six.moves import cPickle as pickle


class CustomInt(int):
Expand All @@ -20,9 +22,10 @@ class CustomInt(int):

@pytest.mark.unit()
class TestSerde(TestCase):
serializer = python_memcache_serializer

def check(self, value, expected_flags):
serialized, flags = python_memcache_serializer(b'key', value)
serialized, flags = self.serializer(b'key', value)
assert flags == expected_flags

# pymemcache stores values as byte strings, so we immediately the value
Expand Down Expand Up @@ -59,3 +62,24 @@ def test_pickleable(self):
def test_subtype(self):
# Subclass of a native type will be restored as the same type
self.check(CustomInt(123123), FLAG_PICKLE)


@pytest.mark.unit()
class TestSerdePickleVersion0(TestCase):
serializer = get_python_memcache_serializer_pickle_version(pickle_version=0)


@pytest.mark.unit()
class TestSerdePickleVersion1(TestCase):
serializer = get_python_memcache_serializer_pickle_version(pickle_version=1)


@pytest.mark.unit()
class TestSerdePickleVersion2(TestCase):
serializer = get_python_memcache_serializer_pickle_version(pickle_version=2)


@pytest.mark.unit()
class TestSerdePickleVersionHighest(TestCase):
serializer = get_python_memcache_serializer_pickle_version(
pickle_version=pickle.HIGHEST_PROTOCOL)

0 comments on commit d77679a

Please sign in to comment.