diff --git a/nexus/core/api.py b/nexus/core/api.py index fa4523a..3880e15 100644 --- a/nexus/core/api.py +++ b/nexus/core/api.py @@ -1,7 +1,8 @@ +import asyncio import os import secrets from logging import getLogger -from typing import Optional +from typing import TYPE_CHECKING, Optional import zmq import zmq.asyncio @@ -16,18 +17,14 @@ class Nexus(FastAPI): + if TYPE_CHECKING: + reqSocket: zmq.asyncio.Socket + subSocket: Optional[zmq.asyncio.Socket] + def __init__(self, context: Optional[zmq.asyncio.Context] = None, *args, **kwargs): super().__init__(*args, **kwargs) self.context = context or zmq.asyncio.Context.instance() - self._reqSocket: Optional[zmq.asyncio.Socket] = None - self._subSocket: Optional[zmq.asyncio.Socket] = None - sockets = [self._reqSocket, self._subSocket] - timeout = 1000 - for socket in sockets: - if not socket: - continue - socket.setsockopt(zmq.RCVTIMEO, timeout) - socket.setsockopt(zmq.IPV6, True) + self.subSocket = None # Auth related stuff self.clientId = int(os.getenv("DISCORD_CLIENT_ID", 0)) @@ -67,6 +64,29 @@ def __init__(self, context: Optional[zmq.asyncio.Context] = None, *args, **kwarg self.logger = getLogger("uvicorn") + asyncio.create_task(self.__ainit__()) + + async def __ainit__(self) -> None: + dest = os.getenv("DASHBOARD_ZMQ_REQ") + if not dest: + raise RuntimeError("Nexus requires at least a request socket to function properly!") + self.reqSocket = self.context.socket(zmq.REQ) + self.reqSocket.setsockopt(zmq.IPV6, True) + self.reqSocket.setsockopt(zmq.RCVTIMEO, constants.REQUEST_TIMEOUT) + self.reqSocket.connect(f"tcp://{dest}") + + subDest = os.getenv("DASHBOARD_ZMQ_SUB") + if subDest: + self.subSocket = self.context.socket(zmq.SUB) + self.subSocket.setsockopt(zmq.IPV6, True) + self.subSocket.setsockopt(zmq.SUBSCRIBE, b"guild.update") + self.subSocket.connect(f"tcp://{subDest}") + + def reconnectReqSocket(self): + self.reqSocket.close(linger=0) + self.logger.info("Reconnecting to bot...") + self.reqSocket.connect(f"tcp://{os.getenv('DASHBOARD_ZMQ_REQ')}") + def getTokenUpdater(self, request: Optional[Request] = None): if not request: return None @@ -90,40 +110,6 @@ def session(self, token=None, state=None, request: Optional[Request] = None) -> tokenUpdater=self.getTokenUpdater(request), ) - def initRequestSocket(self): - self._reqSocket = self.context.socket(zmq.REQ) - self._reqSocket.setsockopt(zmq.RCVTIMEO, constants.REQUEST_TIMEOUT) - self._reqSocket.connect("tcp://" + os.getenv("DASHBOARD_ZMQ_REQ", "127.0.0.1:5556")) - - def initSubscriptionSocket(self): - self._subSocket = self.context.socket(zmq.SUB) - self._subSocket.setsockopt(zmq.SUBSCRIBE, b"guild.update") - self._subSocket.connect("tcp://" + os.getenv("DASHBOARD_ZMQ_SUB", "127.0.0.1:5554")) - - def initSockets(self): - self.initRequestSocket() - self.initSubscriptionSocket() - - @property - def isZMQAvailable(self) -> bool: - return self._reqSocket is not None or self._subSocket is not None - - def _getSocket(self, socket: str) -> zmq.asyncio.Socket: - _socket = getattr(self, f"_{socket}Socket", None) - if not _socket: - self.initSockets() - _socket = getattr(self, socket) - - return _socket - - @property - def reqSocket(self) -> zmq.asyncio.Socket: - return self._getSocket("req") - - @property - def subSocket(self) -> zmq.asyncio.Socket: - return self._getSocket("sub") - def attachIsLoggedIn(self, response: Response): response.set_cookie("loggedIn", "yes", domain=os.getenv("DASHBOARD_HOSTNAME"), max_age=31556926) @@ -131,11 +117,11 @@ def detachIsLoggedIn(self, response: Response): response.delete_cookie("loggedIn", domain=os.getenv("DASHBOARD_HOSTNAME")) async def onStartup(self): - self.initSockets() + pass - def close(self): + async def closeSockets(self): self.logger.info("Closing sockets...") - sockets = (self._reqSocket, self._subSocket) + sockets = (self.reqSocket, self.subSocket) for socket in sockets: if not socket: continue @@ -146,5 +132,5 @@ def close(self): self.context.term() self.logger.info("ZeroMQ has been closed") - def onShutdown(self): - self.close() + async def onShutdown(self): + await self.closeSockets() diff --git a/nexus/core/routes/ng/meta.py b/nexus/core/routes/ng/meta.py index 2b7e323..a9fbd68 100644 --- a/nexus/core/routes/ng/meta.py +++ b/nexus/core/routes/ng/meta.py @@ -7,7 +7,7 @@ import traceback from typing import TYPE_CHECKING, Any, List, Optional, Union, overload -import zmq +import zmq # type: ignore import zmq.asyncio from fastapi import HTTPException from fastapi.routing import APIRouter @@ -61,10 +61,7 @@ async def requestBot(app: "Nexus", requestMessage: dict, userId: Optional[str] = if retries >= constants.REQUEST_RETRIES: raise HTTPException(502, str(e)) - if app._reqSocket: - app._reqSocket.close(linger=0) - app.logger.info("Reconnecting to bot...") - app.initRequestSocket() + app.reconnectReqSocket() retries += 1 app.logger.info("Retrying...") continue # we let the loop retry the send request