Skip to content

Commit

Permalink
refactor: WebSocket
Browse files Browse the repository at this point in the history
  • Loading branch information
null2264 committed Jul 22, 2023
1 parent b84f4b8 commit d559f16
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 49 deletions.
3 changes: 3 additions & 0 deletions nexus/core/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from nexus.core import constants
from nexus.core.middleware import SessionMiddleware
from nexus.core.oauth import OAuth2Session
from nexus.core.websocket import WebSocketManager
from nexus.utils import cache


Expand Down Expand Up @@ -64,6 +65,8 @@ def __init__(self, context: Optional[zmq.asyncio.Context] = None, *args, **kwarg

self.initSockets()

self.websocketManager = WebSocketManager(self)

def initSockets(self) -> None:
self.connectReqSocket()
self.connectSubSocket()
Expand Down
63 changes: 14 additions & 49 deletions nexus/core/routes/ng/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
import json
import traceback
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Any, Optional

from fastapi import WebSocket, WebSocketDisconnect, WebSocketException, status
from fastapi.responses import HTMLResponse, Response
Expand Down Expand Up @@ -67,53 +67,18 @@ def generateResponse(doReload: bool = True) -> Response:
return resp


async def websocketSubcribeLoop(websocket: WebSocket, guildId: int):
try:
while True:
_, msg = await websocket.app.subSocket.recv_multipart()
decodedMsg = msg.decode()
if json.loads(decodedMsg).get("guildId") != guildId:
return
await websocket.send_text(f"{decodedMsg}")
except Exception as e:
print(e)


@router.websocket("/ws")
async def ws(websocket: WebSocket):
# Auth checker for WebSocket
scope = websocket.scope
scope["type"] = "http"
request = Request(scope=scope, receive=websocket._receive)
if not request.session.get("userId"):
scope["type"] = "websocket" # WebSocketException would raise error without this
raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)

await websocket.accept()
task: Optional[asyncio.Task] = None
@router.websocket("/ws/{_type}/{_data}")
async def ws(websocket: WebSocket, _type: str, _data: Any):
app: "Nexus" = websocket.app

if _type not in ("guild",):
raise WebSocketException(status.WS_1003_UNSUPPORTED_DATA)

conn = await app.websocketManager.connect(websocket, type=_type, data=_data)
try:
while True:
msg = await websocket.receive_json()
_type = msg.get("t")
if _type == "ping":
await websocket.send_json({"t": "pong"})
elif _type == "guild":
if task:
continue

try:
id = int(msg["i"])
except ValueError:
await websocket.send_json(json.dumps({"e": "Invalid ID"}))
continue

task = asyncio.create_task(websocketSubcribeLoop(websocket, id))
await websocket.send_json({"i": id})
else:
await websocket.send_json({"o": f"{msg}"})
except Exception as e:
if task:
task.cancel()

if not isinstance(e, WebSocketDisconnect):
await websocket.close()
# Just to keep the connection alive
await websocket.receive_json()
await asyncio.sleep(1)
except WebSocketDisconnect:
app.websocketManager.disconnect(conn)
76 changes: 76 additions & 0 deletions nexus/core/websocket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import asyncio
import json
from typing import TYPE_CHECKING, Generic, List, TypeVar

from fastapi import WebSocket, WebSocketException, status


if TYPE_CHECKING:
from nexus.core.api import Nexus


T = TypeVar("T")


class Connection(Generic[T]):
def __init__(self, type: str, data: T, websocket: WebSocket) -> None:
self.type: str = type
self.data: T = data
self.websocket: WebSocket = websocket


class WebSocketManager:
def __init__(self, app):
self.app: "Nexus" = app
self.activeConnections: List[Connection] = []

asyncio.create_task(self.updatePublishLoop())

async def handleGuildUpdate(self, decodedData: str):
data = json.loads(decodedData)

for conn in self.activeConnections:
if conn.type != "guild":
continue

if data["before"].get("guildId") != conn.data:
continue

await self.send(decodedData, conn.websocket)

async def updatePublishLoop(self):
if not self.app.subSocket:
return

try:
while True:
msgType, msg = await self.app.subSocket.recv_multipart()

if msgType.startswith(b"guild"):
if msgType.endswith(b"update"):
await self.handleGuildUpdate(msg.decode("utf-8"))
except Exception as e:
print(e)

async def connect(self, websocket: WebSocket, **kwargs):
if not self.app.subSocket:
raise WebSocketException(status.WS_1011_INTERNAL_ERROR, "Nexus is not connected to any sub sockets.")

# Auth checker for WebSocket
if not self.app.validateAuth(websocket.session.get("authToken", {})) or not websocket.session.get("userId"):
raise WebSocketException(status.WS_1008_POLICY_VIOLATION)

await websocket.accept()
conn = Connection(kwargs["type"], kwargs["data"], websocket)
self.activeConnections.append(conn)
return conn

def disconnect(self, connection: Connection):
self.activeConnections.remove(connection)

async def send(self, message: str, websocket: WebSocket):
await websocket.send_text(message)

async def broadcast(self, message: str):
for conn in self.activeConnections:
await self.send(message, conn.websocket)

0 comments on commit d559f16

Please sign in to comment.