diff --git a/pymemcache/client/hash.py b/pymemcache/client/hash.py index 9e273ee8..6ec8d1f9 100644 --- a/pymemcache/client/hash.py +++ b/pymemcache/client/hash.py @@ -112,6 +112,11 @@ def __init__( self.encoding = encoding self.tls_context = tls_context + def _make_client_key(self, server): + if isinstance(server, (list, tuple)) and len(server) == 2: + return '%s:%s' % server + return server + def add_server(self, server, port=None): # To maintain backward compatibility, if a port is provided, assume # that server wasn't provided as a (host, port) tuple. @@ -120,16 +125,12 @@ def add_server(self, server, port=None): raise TypeError('Server must be a string when passing port.') server = (server, port) - if isinstance(server, six.string_types): - key = server - else: - key = '%s:%s' % server - _class = PooledClient if self.use_pooling else self.client_class client = _class(server, **self.default_kwargs) if self.use_pooling: client.client_class = self.client_class + key = self._make_client_key(server) self.clients[key] = client self.hasher.add_node(key) @@ -141,11 +142,7 @@ def remove_server(self, server, port=None): raise TypeError('Server must be a string when passing port.') server = (server, port) - if isinstance(server, six.string_types): - key = server - else: - key = '%s:%s' % server - + key = self._make_client_key(server) dead_time = time.time() self._failed_clients.pop(server) self._dead_clients[server] = dead_time @@ -181,8 +178,7 @@ def _get_client(self, key): return raise MemcacheError('All servers seem to be down right now') - client = self.clients[server] - return client + return self.clients[server] def _safely_run_func(self, client, func, default_val, *args, **kwargs): try: @@ -383,8 +379,7 @@ def set_many(self, values, *args, **kwargs): client_batches[client.server][key] = value for server, values in client_batches.items(): - client = self.clients['%s:%s' % server] - + client = self.clients[self._make_client_key(server)] failed += self._safely_run_set_many( client, values, *args, **kwargs ) @@ -406,7 +401,7 @@ def get_many(self, keys, gets=False, *args, **kwargs): client_batches[client.server].append(key) for server, keys in client_batches.items(): - client = self.clients['%s:%s' % server] + client = self.clients[self._make_client_key(server)] new_args = list(args) new_args.insert(0, keys) diff --git a/pymemcache/test/test_client_hash.py b/pymemcache/test/test_client_hash.py index 5dd4ec47..04b51230 100644 --- a/pymemcache/test/test_client_hash.py +++ b/pymemcache/test/test_client_hash.py @@ -39,6 +39,20 @@ def make_client(self, *mock_socket_values, **kwargs): return client + def make_unix_client(self, sockets, *mock_socket_values, **kwargs): + client = HashClient([], **kwargs) + + for socket_, vals in zip(sockets, mock_socket_values): + c = self.make_client_pool( + socket_, + vals, + **kwargs + ) + client.clients[socket_] = c + client.hasher.add_node(socket_) + + return client + def test_setup_client_without_pooling(self): client_class = 'pymemcache.client.hash.HashClient.client_class' with mock.patch(client_class) as internal_client: @@ -50,6 +64,30 @@ def test_setup_client_without_pooling(self): assert kwargs['timeout'] == 999 assert kwargs['key_prefix'] == 'foo_bar_baz' + def test_get_many_unix(self): + pid = os.getpid() + sockets = [ + '/tmp/pymemcache.1.%d' % pid, + '/tmp/pymemcache.2.%d' % pid, + ] + client = self.make_unix_client(sockets, *[ + [b'STORED\r\n', b'VALUE key3 0 6\r\nvalue2\r\nEND\r\n', ], + [b'STORED\r\n', b'VALUE key1 0 6\r\nvalue1\r\nEND\r\n', ], + ]) + + def get_clients(key): + if key == b'key3': + return client.clients['/tmp/pymemcache.1.%d' % pid] + else: + return client.clients['/tmp/pymemcache.2.%d' % pid] + + client._get_client = get_clients + + result = client.set(b'key1', b'value1', noreply=False) + result = client.set(b'key3', b'value2', noreply=False) + result = client.get_many([b'key1', b'key3']) + assert result == {b'key1': b'value1', b'key3': b'value2'} + def test_get_many_all_found(self): client = self.make_client(*[ [b'STORED\r\n', b'VALUE key3 0 6\r\nvalue2\r\nEND\r\n', ], @@ -284,6 +322,22 @@ def test_noreply_set_many(self): result = client.set_many(values, noreply=True) assert result == [] + def test_set_many_unix(self): + values = { + 'key1': 'value1', + 'key2': 'value2', + 'key3': 'value3' + } + + pid = os.getpid() + sockets = ['/tmp/pymemcache.%d' % pid] + client = self.make_unix_client(sockets, *[ + [b'STORED\r\n', b'NOT_STORED\r\n', b'STORED\r\n'], + ]) + + result = client.set_many(values, noreply=False) + assert result == ['key2'] + def test_server_encoding_pooled(self): """ test passed encoding from hash client to pooled clients