diff --git a/pymemcache/client/base.py b/pymemcache/client/base.py index 05db0818..7dcad12e 100644 --- a/pymemcache/client/base.py +++ b/pymemcache/client/base.py @@ -90,29 +90,19 @@ def _check_key(key, allow_unicode_keys, key_prefix=b''): try: key = key.encode('ascii') except (UnicodeEncodeError, UnicodeDecodeError): - raise MemcacheIllegalInputError("Non-ASCII key: '%r'" % (key,)) + raise MemcacheIllegalInputError("Non-ASCII key: '%r'" % key) + key = key_prefix + key + parts = key.split() if len(key) > 250: - raise MemcacheIllegalInputError("Key is too long: '%r'" % (key,)) - - for c in bytearray(key): - if c == ord(b' '): - raise MemcacheIllegalInputError( - "Key contains space: '%r'" % (key,) - ) - elif c == ord(b'\n'): - raise MemcacheIllegalInputError( - "Key contains newline: '%r'" % (key,) - ) - elif c == ord(b'\00'): - raise MemcacheIllegalInputError( - "Key contains null character: '%r'" % (key,) - ) - elif c == ord(b'\r'): - raise MemcacheIllegalInputError( - "Key contains carriage return: '%r'" % (key,) - ) + raise MemcacheIllegalInputError("Key is too long: '%r'" % key) + # second statement catches leading or trailing whitespace + elif len(parts) > 1 or parts[0] != key: + raise MemcacheIllegalInputError("Key contains whitespace: '%r'" % key) + elif b'\00' in key: + raise MemcacheIllegalInputError("Key contains null: '%r'" % key) + return key @@ -765,6 +755,14 @@ def _fetch_cmd(self, name, keys, expect_cas): def _store_cmd(self, name, values, expire, noreply, cas=None): cmds = [] keys = [] + + extra = b'' + if cas is not None: + extra += b' ' + cas + if noreply: + extra += b' noreply' + expire = six.text_type(expire).encode('ascii') + for key, data in six.iteritems(values): # must be able to reliably map responses back to the original order keys.append(key) @@ -781,15 +779,9 @@ def _store_cmd(self, name, values, expire, noreply, cas=None): except UnicodeEncodeError as e: raise MemcacheIllegalInputError(str(e)) - extra = b'' - if cas is not None: - extra += b' ' + cas - if noreply: - extra += b' noreply' - cmds.append(name + b' ' + key + b' ' + six.text_type(flags).encode('ascii') + - b' ' + six.text_type(expire).encode('ascii') + + b' ' + expire + b' ' + six.text_type(len(data)).encode('ascii') + extra + b'\r\n' + data + b'\r\n')