Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add global key prefix for keys set by Redis transporter #1349

Merged
merged 6 commits into from
Aug 25, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ Fernando Jorge Mota <[email protected]>
Flavio [FlaPer87] Percoco Premoli <[email protected]>
Florian Munz <[email protected]>
Franck Cuny <[email protected]>
Gábor Boros <[email protected]>
Germán M. Bravo <[email protected]>
Gregory Haskins <[email protected]>
Hank John <[email protected]>
Expand Down
125 changes: 124 additions & 1 deletion kombu/transport/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
* ``unacked_restore_limit``
* ``fanout_prefix``
* ``fanout_patterns``
* ``global_keyprefix``: (str) The global key prefix to be prepended to all keys
used by Kombu
* ``socket_timeout``
* ``socket_connect_timeout``
* ``socket_keepalive``
Expand All @@ -49,6 +51,7 @@
* ``priority_steps``
"""

import functools
import numbers
import socket
from bisect import bisect
Expand Down Expand Up @@ -179,6 +182,108 @@ def _after_fork_cleanup_channel(channel):
channel._after_fork()


class GlobalKeyPrefixMixin:
"""Mixin to provide common logic for global key prefixing.

Overriding all the methods used by Kombu with the same key prefixing logic
would be cumbersome and inefficient. Hence, we override the command
execution logic that is called by all commands.
"""

PREFIXED_SIMPLE_COMMANDS = [
"HDEL",
"HGET",
"HSET",
"LLEN",
"LPUSH",
"PUBLISH",
"SADD",
"SET",
"SMEMBERS",
"ZADD",
"ZREM",
"ZREVRANGEBYSCORE",
]

PREFIXED_COMPLEX_COMMANDS = {
"BRPOP": {"args_start": 0, "args_end": -1},
"EVALSHA": {"args_start": 2, "args_end": 3},
}

def _prefix_args(self, args):
args = list(args)
command = args.pop(0)

if command in self.PREFIXED_SIMPLE_COMMANDS:
args[0] = self.global_keyprefix + str(args[0])

if command in self.PREFIXED_COMPLEX_COMMANDS.keys():
args_start = self.PREFIXED_COMPLEX_COMMANDS[command]["args_start"]
args_end = self.PREFIXED_COMPLEX_COMMANDS[command]["args_end"]

pre_args = args[:args_start] if args_start > 0 else []

if args_end is not None:
post_args = args[args_end:]
elif args_end < 0:
post_args = args[len(args):]
else:
post_args = []

args = pre_args + [
self.global_keyprefix + str(arg)
for arg in args[args_start:args_end]
] + post_args

return [command, *args]

pomegranited marked this conversation as resolved.
Show resolved Hide resolved
def parse_response(self, connection, command_name, **options):
"""Parses a response from the Redis server.

Method wraps ``redis.parse_response()`` to remove prefixes of keys
returned by redis command.
"""
ret = super().parse_response(connection, command_name, **options)
if command_name == 'BRPOP' and ret:
key, value = ret
key = key[len(self.global_keyprefix):]
return key, value
return ret

def execute_command(self, *args, **kwargs):
return super().execute_command(*self._prefix_args(args), **kwargs)

def pipeline(self, transaction=True, shard_hint=None):
return PrefixedRedisPipeline(
self.connection_pool,
self.response_callbacks,
transaction,
shard_hint,
global_keyprefix=self.global_keyprefix,
)


class PrefixedStrictRedis(GlobalKeyPrefixMixin, redis.Redis):
"""Returns a ``StrictRedis`` client that prefixes the keys it uses."""

def __init__(self, *args, **kwargs):
self.global_keyprefix = kwargs.pop('global_keyprefix', '')
redis.Redis.__init__(self, *args, **kwargs)


class PrefixedRedisPipeline(GlobalKeyPrefixMixin, redis.client.Pipeline):
"""Custom Redis pipeline that takes global_keyprefix into consideration.

As the ``PrefixedStrictRedis`` client uses the `global_keyprefix` to prefix
the keys it uses, the pipeline called by the client must be able to prefix
the keys as well.
"""

def __init__(self, *args, **kwargs):
self.global_keyprefix = kwargs.pop('global_keyprefix', '')
redis.client.Pipeline.__init__(self, *args, **kwargs)
thedrow marked this conversation as resolved.
Show resolved Hide resolved


class QoS(virtual.QoS):
"""Redis Ack Emulation."""

Expand Down Expand Up @@ -485,6 +590,11 @@ class Channel(virtual.Channel):
#: Disable for backwards compatibility with Kombu 3.x.
fanout_patterns = True

#: The global key prefix will be prepended to all keys used
#: by Kombu, which can be useful when a redis database is shared
#: by different users. By default, no prefix is prepended.
global_keyprefix = ''
gabor-boros marked this conversation as resolved.
Show resolved Hide resolved

#: Order in which we consume from queues.
#:
#: Can be either string alias, or a cycle strategy class
Expand Down Expand Up @@ -526,6 +636,7 @@ class Channel(virtual.Channel):
'unacked_restore_limit',
'fanout_prefix',
'fanout_patterns',
'global_keyprefix',
'socket_timeout',
'socket_connect_timeout',
'socket_keepalive',
Expand Down Expand Up @@ -769,7 +880,12 @@ def _brpop_start(self, timeout=1):
keys = [self._q_for_pri(queue, pri) for pri in self.priority_steps
for queue in queues] + [timeout or 0]
self._in_poll = self.client.connection
self.client.connection.send_command('BRPOP', *keys)

command_args = ['BRPOP', *keys]
if self.global_keyprefix:
command_args = self.client._prefix_args(command_args)

self.client.connection.send_command(*command_args)

def _brpop_read(self, **options):
try:
Expand Down Expand Up @@ -1025,6 +1141,13 @@ def _get_client(self):
raise VersionMismatch(
'Redis transport requires redis-py versions 3.2.0 or later. '
'You have {0.__version__}'.format(redis))

if self.global_keyprefix:
return functools.partial(
PrefixedStrictRedis,
global_keyprefix=self.global_keyprefix,
)

return redis.StrictRedis

@contextmanager
Expand Down
94 changes: 93 additions & 1 deletion t/unit/transport/test_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,7 @@ def test_rotate_cycle_ValueError(self):
def test_get_client(self):
import redis as R
KombuRedis = redis.Channel._get_client(self.channel)
assert KombuRedis
assert isinstance(KombuRedis(), R.StrictRedis)

Rv = getattr(R, 'VERSION', None)
try:
Expand All @@ -757,6 +757,12 @@ def test_get_client(self):
if Rv is not None:
R.VERSION = Rv

def test_get_prefixed_client(self):
from kombu.transport.redis import PrefixedStrictRedis
self.channel.global_keyprefix = "test_"
PrefixedRedis = redis.Channel._get_client(self.channel)
assert isinstance(PrefixedRedis(), PrefixedStrictRedis)

def test_get_response_error(self):
from redis.exceptions import ResponseError
assert redis.Channel._get_response_error(self.channel) is ResponseError
Expand Down Expand Up @@ -926,6 +932,43 @@ def test_sep_transport_option(self):
('celery', '', 'celery'),
]

@patch("redis.StrictRedis.execute_command")
def test_global_keyprefix(self, mock_execute_command):
from kombu.transport.redis import PrefixedStrictRedis

with Connection(transport=Transport) as conn:
client = PrefixedStrictRedis(global_keyprefix='foo_')

channel = conn.channel()
channel._create_client = Mock()
channel._create_client.return_value = client

body = {'hello': 'world'}
channel._put_fanout('exchange', body, '')
mock_execute_command.assert_called_with(
'PUBLISH',
'foo_/{db}.exchange',
dumps(body)
)

@patch("redis.StrictRedis.execute_command")
def test_global_keyprefix_queue_bind(self, mock_execute_command):
from kombu.transport.redis import PrefixedStrictRedis

with Connection(transport=Transport) as conn:
client = PrefixedStrictRedis(global_keyprefix='foo_')

channel = conn.channel()
channel._create_client = Mock()
channel._create_client.return_value = client

channel._queue_bind('default', '', None, 'queue')
mock_execute_command.assert_called_with(
'SADD',
'foo__kombu.binding.default',
'\x06\x16\x06\x16queue'
)


class test_Redis:

Expand Down Expand Up @@ -1500,3 +1543,52 @@ def test_sentinel_with_ssl(self):
from kombu.transport.redis import SentinelManagedSSLConnection
assert (params['connection_class'] is
SentinelManagedSSLConnection)


class test_GlobalKeyPrefixMixin:

from kombu.transport.redis import GlobalKeyPrefixMixin

global_keyprefix = "prefix_"
mixin = GlobalKeyPrefixMixin()
mixin.global_keyprefix = global_keyprefix

def test_prefix_simple_args(self):
for command in self.mixin.PREFIXED_SIMPLE_COMMANDS:
prefixed_args = self.mixin._prefix_args([command, "fake_key"])
assert prefixed_args == [
command,
f"{self.global_keyprefix}fake_key"
]

def test_prefix_brpop_args(self):
prefixed_args = self.mixin._prefix_args([
"BRPOP",
"fake_key",
"fake_key2",
"not_prefixed"
])

assert prefixed_args == [
"BRPOP",
f"{self.global_keyprefix}fake_key",
f"{self.global_keyprefix}fake_key2",
"not_prefixed",
]

def test_prefix_evalsha_args(self):
prefixed_args = self.mixin._prefix_args([
"EVALSHA",
"not_prefixed",
"not_prefixed",
"fake_key",
"not_prefixed",
])

assert prefixed_args == [
"EVALSHA",
"not_prefixed",
"not_prefixed",
f"{self.global_keyprefix}fake_key",
"not_prefixed",
]