Async redis subscription inside websocket doesn't shut down properly #2523

Open
wholmen opened this issue Dec 20, 2022 · 4 comments
@wholmen

wholmen commented Dec 20, 2022

Version: 4.4.0

Platform: Python 3.11.0, Ubuntu

Description:

I am using the async Redis client to subscribe to a channel inside a FastAPI websocket endpoint. It works fine, but I cannot make it shut down properly: when the server shuts down, the loop awaiting get_message does not stop.

I have tried sending SIGTERM, but the reader function does not seem to yield to it, because the server logs

INFO:     Waiting for background tasks to complete. (CTRL+C to force quit)

before the SIGTERM is handled.

Does anyone have a suggestion for where to start looking?

@router.websocket("/ws")
async def websocket_labtest(websocket: WebSocket):
    await websocket.accept()

    async_redis_client: aioredis.client.Redis = aioredis.from_url(
        url=settings.redis_url, port=settings.redis_port, password=settings.redis_password, decode_responses=True
    )
    channelname: str = f"channel:{events.LabtestAddedEvent.__name__}"
    redis_channel = async_redis_client.pubsub()

    async def reader(channel: aioredis.client.PubSub):
        while True:
            try:
                async with async_timeout.timeout(1):
                    message = await channel.get_message(ignore_subscribe_messages=True)
                    if message is not None:
                        response = json.loads(message["data"])
                        await websocket.send_json(response)
                    await asyncio.sleep(0.01)

            except WebSocketDisconnect:
                print("hit websocket disconnect")
            except asyncio.TimeoutError:
                pass
            except asyncio.CancelledError:
                print("Cancelled")
                break
            except aioredis.PubSubError:
                print("Pubsub error")
                break

    async with redis_channel as p:
        await p.subscribe(channelname)
        await reader(p)
        await p.unsubscribe(channelname)

    await redis_channel.close()
@Andrew-Chen-Wang
Contributor

Andrew-Chen-Wang commented Mar 17, 2023

The problem is the blocking read in parse_response. The current workaround is to pass a small timeout to get_message.

Still not working. Does anyone have a solution? This seems to make PubSub completely unusable if asyncio.CancelledError isn't raised.
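
For readers following along, the small-timeout workaround mentioned above can be sketched roughly as follows. This is only an illustration, not the library's recommended pattern: `read_until_stopped`, `stop_event`, `handle` and `poll_timeout` are hypothetical names, and `pubsub` is assumed to behave like `redis.asyncio.client.PubSub`, whose `get_message` accepts `ignore_subscribe_messages` and `timeout`.

```python
import asyncio


async def read_until_stopped(pubsub, stop_event, handle, poll_timeout=0.5):
    """Poll `pubsub` until `stop_event` is set (hypothetical helper).

    `pubsub` is assumed to expose
    `get_message(ignore_subscribe_messages=..., timeout=...)`.
    """
    while not stop_event.is_set():
        # A bounded timeout makes get_message return None when nothing
        # arrives in time, so the loop gets a chance to re-check stop_event.
        message = await pubsub.get_message(
            ignore_subscribe_messages=True, timeout=poll_timeout
        )
        if message is None:
            continue
        await handle(message)
```

With this shape, shutdown code only needs to call `stop_event.set()`; the loop should exit within roughly one `poll_timeout` even if the task is never cancelled.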

@tomer555

+1

@Andrew-Chen-Wang
Contributor

Andrew-Chen-Wang commented Jun 25, 2023

This has been working for me. It references some project-specific names such as "readers" because it was copied from a project, so pick out the parts you need:

import asyncio
from collections.abc import Callable, Coroutine
from typing import Literal

import orjson
from redis.asyncio.client import PubSub
# Assumed alias; the original snippet did not show this import.
from redis.exceptions import ConnectionError as RedisConnectionError

from app.utils.redis.subscriber.typing import ChatReadersT, ReadersT, ReaderT


async def reader(channel: PubSub, readers: ReadersT):
    try:
        while True:
            # https://github.com/redis/redis-py/issues/2523
            # A non-zero timeout keeps get_message from blocking indefinitely.
            message: dict | None = await channel.get_message(
                ignore_subscribe_messages=True, timeout=0.5
            )
            if message is None:
                continue
            # Keep the part of the channel name after the first ":".
            _channel = message["channel"].decode().split(":", 1)[1]
            if ":" not in _channel:
                _channel = int(_channel)
            wsr = readers.get().get(_channel, {}).items()
            data: dict = orjson.loads(message["data"])
            reader_id = data.pop("id", None)
            data = data["data"]
            # Fan the payload out to every reader except the one that sent it.
            for k, w in wsr:
                if k != reader_id:
                    w.messages.put_nowait(data)
    except RedisConnectionError:
        pass


class BaseRedisConnection:
    def __init__(
        self,
        *,
        channel: str,
        include_wildcard: bool = True,
        reader: Callable,
        subscription_type: Literal["psubscribe", "subscribe"] = "psubscribe",
    ):
        """
        Handler for PubSub connection

        :param channel: PubSub channel name
        :param include_wildcard: Whether to make PubSub channel name include a wildcard.
        Applicable only to when subscription_type is "psubscribe"
        :param reader: An infinite loop callable that reads from the PubSub channel
        :param subscription_type: the type of PubSub subscription to use
        """
        self.reader_task: asyncio.Task | None = None
        self.r = None
        self.pubsub: PubSub | None = None
        self.reader = reader
        self.subscription_type = subscription_type
        if subscription_type == "psubscribe" and include_wildcard:
            self.channel = f"{channel}:*"
        else:
            self.channel = channel

    async def create_reader(self, pubsub: PubSub) -> Coroutine:
        raise NotImplementedError("create_reader() must be implemented")

    async def start(self):
        self.pubsub = self.r.pubsub()
        await getattr(self.pubsub, self.subscription_type)(self.channel)
        self.reader_task = asyncio.create_task(await self.create_reader(self.pubsub))
        await self.reader_task

    async def close(self):
        await self.pubsub.close()
        if self.reader_task is not None:
            self.reader_task.cancel()


class RedisConnection(BaseRedisConnection):
    def __init__(
        self,
        *,
        channel: str,
        include_wildcard: bool = True,
        subscription_type: Literal["psubscribe", "subscribe"],
        reader: ReaderT,
        readers: ReadersT | ChatReadersT,
    ):
        super().__init__(
            channel=channel,
            include_wildcard=include_wildcard,
            subscription_type=subscription_type,
            reader=reader,
        )
        self.readers = readers

    async def create_reader(self, pubsub: PubSub) -> Coroutine:
        return self.reader(pubsub, self.readers)

@nextmat

nextmat commented Jan 17, 2024

What worked for me is scheduling the reader method with create_task to start it:

reader_task = asyncio.create_task(reader(...))  # pass your reader's arguments here

Then I have a lifespan function to make sure it gets shut down:

from contextlib import asynccontextmanager


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Ensure redis gets closed cleanly when shutting down"""
    yield
    reader_task.cancel()

Register it like this:

app = FastAPI(
    lifespan=lifespan,
    # ...
)
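
A slightly more defensive variant of the lifespan approach above also waits for the cancelled task to finish, so any cleanup inside the reader runs before the server exits. This is a minimal sketch under assumptions: `reader` here is a stand-in for the real PubSub loop, and the task is created inside the lifespan rather than at module level.

```python
import asyncio
import contextlib


async def reader():
    """Stand-in for the real PubSub reader loop (hypothetical)."""
    while True:
        await asyncio.sleep(0.1)


@contextlib.asynccontextmanager
async def lifespan(app):
    # Start the reader once the application starts up.
    reader_task = asyncio.create_task(reader())
    try:
        yield
    finally:
        # Request cancellation, then wait until the task has actually ended.
        reader_task.cancel()
        with contextlib.suppress(asyncio.CancelledError):
            await reader_task
```

It can be registered the same way as above, with `FastAPI(lifespan=lifespan)`. Awaiting the task after `cancel()` matters because `cancel()` only schedules the cancellation; the task ends the next time it reaches an await point.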
