Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed HashClient.{get,set}_many() with UNIX sockets. #315

Merged
merged 1 commit into from
May 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 10 additions & 15 deletions pymemcache/client/hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,11 @@ def __init__(
self.encoding = encoding
self.tls_context = tls_context

def _make_client_key(self, server):
if isinstance(server, (list, tuple)) and len(server) == 2:
return '%s:%s' % server
return server

def add_server(self, server, port=None):
# To maintain backward compatibility, if a port is provided, assume
# that server wasn't provided as a (host, port) tuple.
Expand All @@ -120,16 +125,12 @@ def add_server(self, server, port=None):
raise TypeError('Server must be a string when passing port.')
server = (server, port)

if isinstance(server, six.string_types):
key = server
else:
key = '%s:%s' % server

_class = PooledClient if self.use_pooling else self.client_class
client = _class(server, **self.default_kwargs)
if self.use_pooling:
client.client_class = self.client_class

key = self._make_client_key(server)
self.clients[key] = client
self.hasher.add_node(key)

Expand All @@ -141,11 +142,7 @@ def remove_server(self, server, port=None):
raise TypeError('Server must be a string when passing port.')
server = (server, port)

if isinstance(server, six.string_types):
key = server
else:
key = '%s:%s' % server

key = self._make_client_key(server)
dead_time = time.time()
self._failed_clients.pop(server)
self._dead_clients[server] = dead_time
Expand Down Expand Up @@ -181,8 +178,7 @@ def _get_client(self, key):
return
raise MemcacheError('All servers seem to be down right now')

client = self.clients[server]
return client
return self.clients[server]

def _safely_run_func(self, client, func, default_val, *args, **kwargs):
try:
Expand Down Expand Up @@ -383,8 +379,7 @@ def set_many(self, values, *args, **kwargs):
client_batches[client.server][key] = value

for server, values in client_batches.items():
client = self.clients['%s:%s' % server]

client = self.clients[self._make_client_key(server)]
failed += self._safely_run_set_many(
client, values, *args, **kwargs
)
Expand All @@ -406,7 +401,7 @@ def get_many(self, keys, gets=False, *args, **kwargs):
client_batches[client.server].append(key)

for server, keys in client_batches.items():
client = self.clients['%s:%s' % server]
client = self.clients[self._make_client_key(server)]
new_args = list(args)
new_args.insert(0, keys)

Expand Down
54 changes: 54 additions & 0 deletions pymemcache/test/test_client_hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,20 @@ def make_client(self, *mock_socket_values, **kwargs):

return client

def make_unix_client(self, sockets, *mock_socket_values, **kwargs):
client = HashClient([], **kwargs)

for socket_, vals in zip(sockets, mock_socket_values):
c = self.make_client_pool(
socket_,
vals,
**kwargs
)
client.clients[socket_] = c
client.hasher.add_node(socket_)

return client

def test_setup_client_without_pooling(self):
client_class = 'pymemcache.client.hash.HashClient.client_class'
with mock.patch(client_class) as internal_client:
Expand All @@ -50,6 +64,30 @@ def test_setup_client_without_pooling(self):
assert kwargs['timeout'] == 999
assert kwargs['key_prefix'] == 'foo_bar_baz'

def test_get_many_unix(self):
pid = os.getpid()
sockets = [
'/tmp/pymemcache.1.%d' % pid,
'/tmp/pymemcache.2.%d' % pid,
]
client = self.make_unix_client(sockets, *[
[b'STORED\r\n', b'VALUE key3 0 6\r\nvalue2\r\nEND\r\n', ],
[b'STORED\r\n', b'VALUE key1 0 6\r\nvalue1\r\nEND\r\n', ],
])

def get_clients(key):
if key == b'key3':
return client.clients['/tmp/pymemcache.1.%d' % pid]
else:
return client.clients['/tmp/pymemcache.2.%d' % pid]

client._get_client = get_clients

result = client.set(b'key1', b'value1', noreply=False)
result = client.set(b'key3', b'value2', noreply=False)
result = client.get_many([b'key1', b'key3'])
assert result == {b'key1': b'value1', b'key3': b'value2'}

def test_get_many_all_found(self):
client = self.make_client(*[
[b'STORED\r\n', b'VALUE key3 0 6\r\nvalue2\r\nEND\r\n', ],
Expand Down Expand Up @@ -284,6 +322,22 @@ def test_noreply_set_many(self):
result = client.set_many(values, noreply=True)
assert result == []

def test_set_many_unix(self):
values = {
'key1': 'value1',
'key2': 'value2',
'key3': 'value3'
}

pid = os.getpid()
sockets = ['/tmp/pymemcache.%d' % pid]
client = self.make_unix_client(sockets, *[
[b'STORED\r\n', b'NOT_STORED\r\n', b'STORED\r\n'],
])

result = client.set_many(values, noreply=False)
assert result == ['key2']

def test_server_encoding_pooled(self):
"""
test passed encoding from hash client to pooled clients
Expand Down