11import pickle
2- from typing import Dict , Optional , TypeVar , Union
2+ from typing import Dict , Optional , Type , TypeVar , Union
33
4- from redis .asyncio import ConnectionPool , Redis
4+ from redis .asyncio import Redis , RedisCluster
55from taskiq import AsyncResultBackend
66from taskiq .abc .result_backend import TaskiqResult
77
@@ -23,11 +23,15 @@ def __init__(
2323 keep_results : bool = True ,
2424 result_ex_time : Optional [int ] = None ,
2525 result_px_time : Optional [int ] = None ,
26+ * ,
27+ redis_cls : Union [Type [Redis ], Type [RedisCluster ], None ] = None ,
2628 ) -> None :
2729 """
2830 Constructs a new result backend.
2931
3032 :param redis_url: url to redis.
33+ :param redis_cls: async redis class, should be either redis.asyncio.Redis
34+ or redis.asyncio.RedisCluster.
3135 :param keep_results: flag to not remove results from Redis after reading.
3236 :param result_ex_time: expire time in seconds for result.
3337 :param result_px_time: expire time in milliseconds for result.
@@ -37,7 +41,10 @@ def __init__(
3741 :raises ExpireTimeMustBeMoreThanZeroError: if result_ex_time
3842 and result_px_time are equal zero.
3943 """
40- self .redis_pool = ConnectionPool .from_url (redis_url )
44+ if redis_cls is None :
45+ redis_cls = Redis
46+
47+ self .redis = redis_cls .from_url (redis_url )
4148 self .keep_results = keep_results
4249 self .result_ex_time = result_ex_time
4350 self .result_px_time = result_px_time
@@ -58,11 +65,6 @@ def __init__(
5865 "Choose either result_ex_time or result_px_time." ,
5966 )
6067
61- async def shutdown (self ) -> None :
62- """Closes redis connection."""
63- await self .redis_pool .disconnect ()
64- await super ().shutdown ()
65-
6668 async def set_result (
6769 self ,
6870 task_id : str ,
@@ -86,8 +88,7 @@ async def set_result(
8688 elif self .result_px_time :
8789 redis_set_params ["px" ] = self .result_px_time
8890
89- async with Redis (connection_pool = self .redis_pool ) as redis :
90- await redis .set (** redis_set_params ) # type: ignore
91+ await self .redis .set (** redis_set_params ) # type: ignore
9192
9293 async def is_result_ready (self , task_id : str ) -> bool :
9394 """
@@ -97,8 +98,7 @@ async def is_result_ready(self, task_id: str) -> bool:
9798
9899 :returns: True if the result is ready else False.
99100 """
100- async with Redis (connection_pool = self .redis_pool ) as redis :
101- return bool (await redis .exists (task_id ))
101+ return bool (await self .redis .exists (task_id ))
102102
103103 async def get_result (
104104 self ,
@@ -113,15 +113,14 @@ async def get_result(
113113 :raises ResultIsMissingError: if there is no result when trying to get it.
114114 :return: task's return value.
115115 """
116- async with Redis (connection_pool = self .redis_pool ) as redis :
117- if self .keep_results :
118- result_value = await redis .get (
119- name = task_id ,
120- )
121- else :
122- result_value = await redis .getdel (
123- name = task_id ,
124- )
116+ if self .keep_results :
117+ result_value = await self .redis .get (
118+ name = task_id ,
119+ )
120+ else :
121+ result_value = await self .redis .getdel (
122+ name = task_id ,
123+ )
125124
126125 if result_value is None :
127126 raise ResultIsMissingError
0 commit comments