Skip to content

Commit

Permalink
refactor: Simplify code
Browse files Browse the repository at this point in the history
  • Loading branch information
null2264 committed Jul 19, 2023
1 parent 58ef200 commit 9d685d7
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 54 deletions.
84 changes: 35 additions & 49 deletions nexus/core/api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import asyncio
import os
import secrets
from logging import getLogger
from typing import Optional
from typing import TYPE_CHECKING, Optional

import zmq
import zmq.asyncio
Expand All @@ -16,18 +17,14 @@


class Nexus(FastAPI):
if TYPE_CHECKING:
reqSocket: zmq.asyncio.Socket
subSocket: Optional[zmq.asyncio.Socket]

def __init__(self, context: Optional[zmq.asyncio.Context] = None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.context = context or zmq.asyncio.Context.instance()
self._reqSocket: Optional[zmq.asyncio.Socket] = None
self._subSocket: Optional[zmq.asyncio.Socket] = None
sockets = [self._reqSocket, self._subSocket]
timeout = 1000
for socket in sockets:
if not socket:
continue
socket.setsockopt(zmq.RCVTIMEO, timeout)
socket.setsockopt(zmq.IPV6, True)
self.subSocket = None

# Auth related stuff
self.clientId = int(os.getenv("DISCORD_CLIENT_ID", 0))
Expand Down Expand Up @@ -67,6 +64,29 @@ def __init__(self, context: Optional[zmq.asyncio.Context] = None, *args, **kwarg

self.logger = getLogger("uvicorn")

asyncio.create_task(self.__ainit__())

async def __ainit__(self) -> None:
dest = os.getenv("DASHBOARD_ZMQ_REQ")
if not dest:
raise RuntimeError("Nexus requires at least a request socket to function properly!")
self.reqSocket = self.context.socket(zmq.REQ)
self.reqSocket.setsockopt(zmq.IPV6, True)
self.reqSocket.setsockopt(zmq.RCVTIMEO, constants.REQUEST_TIMEOUT)
self.reqSocket.connect(f"tcp://{dest}")

subDest = os.getenv("DASHBOARD_ZMQ_SUB")
if subDest:
self.subSocket = self.context.socket(zmq.SUB)
self.subSocket.setsockopt(zmq.IPV6, True)
self.subSocket.setsockopt(zmq.SUBSCRIBE, b"guild.update")
self.subSocket.connect(f"tcp://{subDest}")

def reconnectReqSocket(self):
self.reqSocket.close(linger=0)
self.logger.info("Reconnecting to bot...")
self.reqSocket.connect(f"tcp://{os.getenv('DASHBOARD_ZMQ_REQ')}")

def getTokenUpdater(self, request: Optional[Request] = None):
if not request:
return None
Expand All @@ -90,52 +110,18 @@ def session(self, token=None, state=None, request: Optional[Request] = None) ->
tokenUpdater=self.getTokenUpdater(request),
)

def initRequestSocket(self):
self._reqSocket = self.context.socket(zmq.REQ)
self._reqSocket.setsockopt(zmq.RCVTIMEO, constants.REQUEST_TIMEOUT)
self._reqSocket.connect("tcp://" + os.getenv("DASHBOARD_ZMQ_REQ", "127.0.0.1:5556"))

def initSubscriptionSocket(self):
self._subSocket = self.context.socket(zmq.SUB)
self._subSocket.setsockopt(zmq.SUBSCRIBE, b"guild.update")
self._subSocket.connect("tcp://" + os.getenv("DASHBOARD_ZMQ_SUB", "127.0.0.1:5554"))

def initSockets(self):
self.initRequestSocket()
self.initSubscriptionSocket()

@property
def isZMQAvailable(self) -> bool:
return self._reqSocket is not None or self._subSocket is not None

def _getSocket(self, socket: str) -> zmq.asyncio.Socket:
_socket = getattr(self, f"_{socket}Socket", None)
if not _socket:
self.initSockets()
_socket = getattr(self, socket)

return _socket

@property
def reqSocket(self) -> zmq.asyncio.Socket:
return self._getSocket("req")

@property
def subSocket(self) -> zmq.asyncio.Socket:
return self._getSocket("sub")

def attachIsLoggedIn(self, response: Response):
response.set_cookie("loggedIn", "yes", domain=os.getenv("DASHBOARD_HOSTNAME"), max_age=31556926)

def detachIsLoggedIn(self, response: Response):
response.delete_cookie("loggedIn", domain=os.getenv("DASHBOARD_HOSTNAME"))

async def onStartup(self):
self.initSockets()
pass

def close(self):
async def closeSockets(self):
self.logger.info("Closing sockets...")
sockets = (self._reqSocket, self._subSocket)
sockets = (self.reqSocket, self.subSocket)
for socket in sockets:
if not socket:
continue
Expand All @@ -146,5 +132,5 @@ def close(self):
self.context.term()
self.logger.info("ZeroMQ has been closed")

def onShutdown(self):
self.close()
async def onShutdown(self):
await self.closeSockets()
7 changes: 2 additions & 5 deletions nexus/core/routes/ng/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import traceback
from typing import TYPE_CHECKING, Any, List, Optional, Union, overload

import zmq
import zmq # type: ignore
import zmq.asyncio
from fastapi import HTTPException
from fastapi.routing import APIRouter
Expand Down Expand Up @@ -61,10 +61,7 @@ async def requestBot(app: "Nexus", requestMessage: dict, userId: Optional[str] =
if retries >= constants.REQUEST_RETRIES:
raise HTTPException(502, str(e))

if app._reqSocket:
app._reqSocket.close(linger=0)
app.logger.info("Reconnecting to bot...")
app.initRequestSocket()
app.reconnectReqSocket()
retries += 1
app.logger.info("Retrying...")
continue # we let the loop retry the send request
Expand Down

0 comments on commit 9d685d7

Please sign in to comment.