From aba7f3ffd0fe76f0a6068c05133b5bb2fef94909 Mon Sep 17 00:00:00 2001 From: Joe Gordon Date: Wed, 17 Aug 2022 10:01:18 -0700 Subject: [PATCH] Start to add type hints First pass at adding some type hints to pymemcache to make it easier to develop against etc. --- pymemcache/client/base.py | 178 +++++++++++++++++++++++++------------- pymemcache/client/hash.py | 2 +- pymemcache/serde.py | 10 +-- pyproject.toml | 1 + 4 files changed, 126 insertions(+), 65 deletions(-) diff --git a/pymemcache/client/base.py b/pymemcache/client/base.py index 3d4234c1..ef6bcacb 100644 --- a/pymemcache/client/base.py +++ b/pymemcache/client/base.py @@ -16,20 +16,18 @@ from functools import partial import platform import socket -from typing import Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union, Callable, Iterable from pymemcache import pool - -from pymemcache.serde import LegacyWrappingSerde from pymemcache.exceptions import ( MemcacheClientError, - MemcacheUnknownCommandError, MemcacheIllegalInputError, MemcacheServerError, - MemcacheUnknownError, MemcacheUnexpectedCloseError, + MemcacheUnknownCommandError, + MemcacheUnknownError, ) - +from pymemcache.serde import LegacyWrappingSerde RECV_SIZE = 4096 VALID_STORE_RESULTS = { @@ -53,23 +51,24 @@ } ServerSpec = Union[Tuple[str, int], str] +Key = Union[bytes, str] # Some of the values returned by the "stats" command # need mapping into native Python types -def _parse_bool_int(value): +def _parse_bool_int(value: bytes) -> bool: return int(value) != 0 -def _parse_bool_string_is_yes(value): +def _parse_bool_string_is_yes(value: bytes) -> bool: return value == b"yes" -def _parse_float(value): +def _parse_float(value: bytes) -> float: return float(value.replace(b":", b".")) -def _parse_hex(value): +def _parse_hex(value: bytes) -> int: return int(value, 8) @@ -96,7 +95,9 @@ def _parse_hex(value): # Common helper functions. -def check_key_helper(key, allow_unicode_keys, key_prefix=b""): +def check_key_helper( + key: Key, allow_unicode_keys: bool, key_prefix: bytes = b"" +) -> bytes: """Checks key and add key_prefix.""" if allow_unicode_keys: if isinstance(key, str): @@ -160,7 +161,7 @@ class KeepaliveOpts: __slots__ = ("idle", "intvl", "cnt") - def __init__(self, idle=1, intvl=1, cnt=5): + def __init__(self, idle: int = 1, intvl: int = 1, cnt: int = 5) -> None: if idle < 1: raise ValueError("The idle parameter must be greater or equal to 1.") self.idle = idle @@ -275,15 +276,15 @@ def __init__( serializer=None, deserializer=None, connect_timeout=None, - timeout=None, - no_delay=False, - ignore_exc=False, + timeout: Optional[float] = None, + no_delay: bool = False, + ignore_exc: bool = False, socket_module=socket, - socket_keepalive=None, - key_prefix=b"", + socket_keepalive: Optional[KeepaliveOpts] = None, + key_prefix: bytes = b"", default_noreply=True, - allow_unicode_keys=False, - encoding="ascii", + allow_unicode_keys: bool = False, + encoding: str = "ascii", tls_context=None, ): """ @@ -354,7 +355,7 @@ def __init__( "KeepaliveOpts object. That's the only supported type " "of structure." ) - self.sock = None + self.sock: Optional[socket.socket] = None if isinstance(key_prefix, str): key_prefix = key_prefix.encode("ascii") if not isinstance(key_prefix, bytes): @@ -365,13 +366,13 @@ def __init__( self.encoding = encoding self.tls_context = tls_context - def check_key(self, key): + def check_key(self, key: Key) -> bytes: """Checks key and add key_prefix.""" return check_key_helper( key, allow_unicode_keys=self.allow_unicode_keys, key_prefix=self.key_prefix ) - def _connect(self): + def _connect(self) -> None: self.close() s = self.socket_module @@ -426,7 +427,7 @@ def _connect(self): self.sock = sock - def close(self): + def close(self) -> None: """Close the connection to memcached, if it is open. The next call to a method that requires a connection will re-open it.""" if self.sock is not None: @@ -439,7 +440,14 @@ def close(self): disconnect_all = close - def set(self, key, value, expire=0, noreply=None, flags=None): + def set( + self, + key: Key, + value: Any, + expire: int = 0, + noreply: Optional[bool] = None, + flags: Optional[int] = None, + ) -> Optional[bool]: """ The memcached "set" command. @@ -460,9 +468,17 @@ def set(self, key, value, expire=0, noreply=None, flags=None): """ if noreply is None: noreply = self.default_noreply + # Optional because _store_cmd lookup in STORE_RESULTS_VALUE can return None in some cases. + # TODO: refactor to fix return self._store_cmd(b"set", {key: value}, expire, noreply, flags=flags)[key] - def set_many(self, values, expire=0, noreply=None, flags=None): + def set_many( + self, + values: Dict[Key, Any], + expire: int = 0, + noreply: Optional[bool] = None, + flags: Optional[int] = None, + ) -> List[Key]: """ A convenience function for setting multiple values. @@ -487,7 +503,14 @@ def set_many(self, values, expire=0, noreply=None, flags=None): set_multi = set_many - def add(self, key, value, expire=0, noreply=None, flags=None): + def add( + self, + key: Key, + value: Any, + expire: int = 0, + noreply: Optional[bool] = None, + flags: Optional[int] = None, + ): """ The memcached "add" command. @@ -608,7 +631,7 @@ def cas(self, key, value, cas, expire=0, noreply=False, flags=None): b"cas", {key: value}, expire, noreply, flags=flags, cas=cas )[key] - def get(self, key, default=None): + def get(self, key: Key, default: Optional[Any] = None): """ The memcached "get" command, but only for one key, as a convenience. @@ -673,7 +696,7 @@ def gets_many(self, keys): return self._fetch_cmd(b"gets", keys, True) - def delete(self, key, noreply=None): + def delete(self, key: Key, noreply=None): """ The memcached "delete" command. @@ -698,7 +721,7 @@ def delete(self, key, noreply=None): return True return results[0] == b"DELETED" - def delete_many(self, keys, noreply=None): + def delete_many(self, keys: Iterable[Key], noreply: Optional[bool] = None) -> bool: """ A convenience function to delete multiple keys. @@ -732,7 +755,9 @@ def delete_many(self, keys, noreply=None): delete_multi = delete_many - def incr(self, key, value, noreply=False): + def incr( + self, key: Key, value: int, noreply: Optional[bool] = False + ) -> Optional[int]: """ The memcached "incr" command. @@ -746,8 +771,8 @@ def incr(self, key, value, noreply=False): value of the key, or None if the key wasn't found. """ key = self.check_key(key) - value = self._check_integer(value, "value") - cmd = b"incr " + key + b" " + value + val = self._check_integer(value, "value") + cmd = b"incr " + key + b" " + val if noreply: cmd += b" noreply" cmd += b"\r\n" @@ -758,7 +783,9 @@ def incr(self, key, value, noreply=False): return None return int(results[0]) - def decr(self, key, value, noreply=False): + def decr( + self, key: Key, value: int, noreply: Optional[bool] = False + ) -> Optional[int]: """ The memcached "decr" command. @@ -772,8 +799,8 @@ def decr(self, key, value, noreply=False): value of the key, or None if the key wasn't found. """ key = self.check_key(key) - value = self._check_integer(value, "value") - cmd = b"decr " + key + b" " + value + val = self._check_integer(value, "value") + cmd = b"decr " + key + b" " + val if noreply: cmd += b" noreply" cmd += b"\r\n" @@ -784,7 +811,7 @@ def decr(self, key, value, noreply=False): return None return int(results[0]) - def touch(self, key, expire=0, noreply=None): + def touch(self, key: Key, expire: int = 0, noreply: Optional[bool] = None) -> bool: """ The memcached "touch" command. @@ -802,8 +829,8 @@ def touch(self, key, expire=0, noreply=None): if noreply is None: noreply = self.default_noreply key = self.check_key(key) - expire = self._check_integer(expire, "expire") - cmd = b"touch " + key + b" " + expire + expire_bytes = self._check_integer(expire, "expire") + cmd = b"touch " + key + b" " + expire_bytes if noreply: cmd += b" noreply" cmd += b"\r\n" @@ -914,7 +941,7 @@ def flush_all(self, delay=0, noreply=None): return True return results[0] == b"OK" - def quit(self): + def quit(self) -> None: """ The memcached "quit" command. @@ -964,7 +991,7 @@ def _raise_errors(self, line, name): error = line[line.find(b" ") + 1 :] raise MemcacheServerError(error) - def _check_integer(self, value, name): + def _check_integer(self, value: int, name: str) -> bytes: """Check that a value is an integer and encode it as a binary string""" if not isinstance(value, int): raise MemcacheIllegalInputError( @@ -973,7 +1000,7 @@ def _check_integer(self, value, name): return str(value).encode(self.encoding) - def _check_cas(self, cas): + def _check_cas(self, cas: Union[int, str, bytes]) -> bytes: """Check that a value is a valid input for 'cas' -- either an int or a string containing only 0-9 @@ -997,7 +1024,14 @@ def _check_cas(self, cas): return cas - def _extract_value(self, expect_cas, line, buf, remapped_keys, prefixed_keys): + def _extract_value( + self, + expect_cas: bool, + line: bytes, + buf: bytes, + remapped_keys, + prefixed_keys: List[bytes], + ): """ This function is abstracted from _fetch_cmd to support different ways of value extraction. In order to use this feature, _extract_value needs @@ -1009,7 +1043,7 @@ def _extract_value(self, expect_cas, line, buf, remapped_keys, prefixed_keys): try: _, key, flags, size = line.split() except Exception as e: - raise ValueError(f"Unable to parse line {line}: {e}") + raise ValueError(f"Unable to parse line {line!r}: {e}") value = None try: @@ -1025,7 +1059,7 @@ def _extract_value(self, expect_cas, line, buf, remapped_keys, prefixed_keys): else: return key, value, buf - def _fetch_cmd(self, name, keys, expect_cas): + def _fetch_cmd(self, name: bytes, keys: Iterable[Key], expect_cas: bool): prefixed_keys = [self.check_key(k) for k in keys] remapped_keys = dict(zip(prefixed_keys, keys)) @@ -1039,11 +1073,14 @@ def _fetch_cmd(self, name, keys, expect_cas): if self.sock is None: self._connect() + # For typing + assert self.sock is not None + self.sock.sendall(cmd) buf = b"" line = None - result = {} + result: Dict[bytes, bytes] = {} while True: try: buf, line = _readline(self.sock, buf) @@ -1073,7 +1110,15 @@ def _fetch_cmd(self, name, keys, expect_cas): return {} raise - def _store_cmd(self, name, values, expire, noreply, flags=None, cas=None): + def _store_cmd( + self, + name: bytes, + values: Dict[Key, Any], + expire: int, + noreply: bool, + flags: Optional[int] = None, + cas: Optional[bytes] = None, + ) -> Dict[Key, Optional[bool]]: cmds = [] keys = [] @@ -1082,7 +1127,7 @@ def _store_cmd(self, name, values, expire, noreply, flags=None, cas=None): extra += b" " + cas if noreply: extra += b" noreply" - expire = self._check_integer(expire, "expire") + expire_bytes = self._check_integer(expire, "expire") for key, data in values.items(): # must be able to reliably map responses back to the original order @@ -1111,7 +1156,7 @@ def _store_cmd(self, name, values, expire, noreply, flags=None, cas=None): + b" " + str(data_flags).encode(self.encoding) + b" " - + expire + + expire_bytes + b" " + str(len(data)).encode(self.encoding) + extra @@ -1123,6 +1168,9 @@ def _store_cmd(self, name, values, expire, noreply, flags=None, cas=None): if self.sock is None: self._connect() + # For typing + assert self.sock is not None + try: self.sock.sendall(b"".join(cmds)) if noreply: @@ -1148,10 +1196,17 @@ def _store_cmd(self, name, values, expire, noreply, flags=None, cas=None): self.close() raise - def _misc_cmd(self, cmds, cmd_name, noreply, end_tokens=None): + def _misc_cmd( + self, + cmds: Iterable[bytes], + cmd_name: bytes, + noreply: Optional[bool], + end_tokens=None, + ) -> List[bytes]: # If no end_tokens have been given, just assume standard memcached # operations, which end in "\r\n", use regular code for that. + _reader: Callable[[socket.socket, bytes], Tuple[bytes, bytes]] if end_tokens: _reader = partial(_readsegment, end_tokens=end_tokens) else: @@ -1160,6 +1215,9 @@ def _misc_cmd(self, cmds, cmd_name, noreply, end_tokens=None): if self.sock is None: self._connect() + # For typing + assert self.sock is not None + try: self.sock.sendall(b"".join(cmds)) @@ -1236,7 +1294,7 @@ def __init__( max_pool_size=None, pool_idle_timeout=0, lock_generator=None, - default_noreply=True, + default_noreply: bool = True, allow_unicode_keys=False, encoding="ascii", tls_context=None, @@ -1266,7 +1324,7 @@ def __init__( self.encoding = encoding self.tls_context = tls_context - def check_key(self, key): + def check_key(self, key: Key) -> bytes: """Checks key and add key_prefix.""" return check_key_helper( key, allow_unicode_keys=self.allow_unicode_keys, key_prefix=self.key_prefix @@ -1443,7 +1501,7 @@ def __delitem__(self, key): self.delete(key, noreply=True) -def _readline(sock, buf): +def _readline(sock: socket.socket, buf: bytes) -> Tuple[bytes, bytes]: """Read line of text from the socket. Read a line of text (delimited by "\r\n") from the socket, and @@ -1452,18 +1510,18 @@ def _readline(sock, buf): Args: sock: Socket object, should be connected. - buf: String, zero or more characters, returned from an earlier - call to _readline or _readvalue (pass an empty string on the + buf: Bytes, zero or more characters, returned from an earlier + call to _readline or _readvalue (pass an empty byte string on the first call). Returns: A tuple of (buf, line) where line is the full line read from the socket (minus the "\r\n" characters) and buf is any trailing characters read after the "\r\n" was found (which may be an empty - string). + byte string). """ - chunks = [] + chunks: List[bytes] = [] last_char = b"" while True: @@ -1494,7 +1552,7 @@ def _readline(sock, buf): raise MemcacheUnexpectedCloseError() -def _readvalue(sock, buf, size): +def _readvalue(sock, buf, size: int): """Read specified amount of bytes from the socket. Read size bytes, followed by the "\r\n" characters, from the socket, @@ -1539,7 +1597,9 @@ def _readvalue(sock, buf, size): return buf[rlen:], b"".join(chunks) -def _readsegment(sock, buf, end_tokens): +def _readsegment( + sock: socket.socket, buf: bytes, end_tokens: bytes +) -> Tuple[bytes, bytes]: """Read a segment from the socket. Read a segment from the socket, up to the first end_token sub-string/bytes, @@ -1575,7 +1635,7 @@ def _readsegment(sock, buf, end_tokens): raise MemcacheUnexpectedCloseError() -def _recv(sock, size): +def _recv(sock: socket.socket, size: int) -> bytes: """sock.recv() with retry on EINTR""" while True: try: diff --git a/pymemcache/client/hash.py b/pymemcache/client/hash.py index db441086..57b07dbd 100644 --- a/pymemcache/client/hash.py +++ b/pymemcache/client/hash.py @@ -154,7 +154,7 @@ def remove_server(self, server, port=None): self._dead_clients[server] = dead_time self.hasher.remove_node(key) - def _retry_dead(self): + def _retry_dead(self) -> None: current_time = time.time() ldc = self._last_dead_check_time # We have reached the retry timeout diff --git a/pymemcache/serde.py b/pymemcache/serde.py index 6e77766a..42ec922f 100644 --- a/pymemcache/serde.py +++ b/pymemcache/serde.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial import logging -from io import BytesIO import pickle import zlib +from functools import partial +from io import BytesIO FLAG_BYTES = 0 FLAG_PICKLE = 1 << 0 @@ -61,7 +61,7 @@ def _python_memcache_serializer(key, value, pickle_version=None): return value, flags -def get_python_memcache_serializer(pickle_version=DEFAULT_PICKLE_VERSION): +def get_python_memcache_serializer(pickle_version: int = DEFAULT_PICKLE_VERSION): """Return a serializer using a specific pickle version""" return partial(_python_memcache_serializer, pickle_version=pickle_version) @@ -112,7 +112,7 @@ class PickleSerde: for :py:class:`pymemcache.client.base.Client` """ - def __init__(self, pickle_version=DEFAULT_PICKLE_VERSION): + def __init__(self, pickle_version: int = DEFAULT_PICKLE_VERSION) -> None: self._serialize_func = get_python_memcache_serializer(pickle_version) def serialize(self, key, value): @@ -182,7 +182,7 @@ class LegacyWrappingSerde: case that they are missing. """ - def __init__(self, serializer_func, deserializer_func): + def __init__(self, serializer_func, deserializer_func) -> None: self.serialize = serializer_func or self._default_serialize self.deserialize = deserializer_func or self._default_deserialize diff --git a/pyproject.toml b/pyproject.toml index 8d26def5..d08806ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,3 +3,4 @@ target-version = ['py37', 'py38', 'py39', 'py310'] [tool.mypy] python_version = 3.7 +ignore_missing_imports = true