|
| 1 | +import asyncio |
| 2 | +import Utils |
| 3 | +import websockets |
| 4 | +import functools |
| 5 | +from copy import deepcopy |
| 6 | +from typing import List, Any, Iterable |
| 7 | +from NetUtils import decode, encode, JSONtoTextParser, JSONMessagePart, NetworkItem |
| 8 | +from MultiServer import Endpoint |
| 9 | +from CommonClient import CommonContext, gui_enabled, ClientCommandProcessor, logger, get_base_parser |
| 10 | + |
| 11 | +DEBUG = False |
| 12 | + |
| 13 | + |
| 14 | +class AHITJSONToTextParser(JSONtoTextParser): |
| 15 | + def _handle_color(self, node: JSONMessagePart): |
| 16 | + return self._handle_text(node) # No colors for the in-game text |
| 17 | + |
| 18 | + |
| 19 | +class AHITCommandProcessor(ClientCommandProcessor): |
| 20 | + def _cmd_ahit(self): |
| 21 | + """Check AHIT Connection State""" |
| 22 | + if isinstance(self.ctx, AHITContext): |
| 23 | + logger.info(f"AHIT Status: {self.ctx.get_ahit_status()}") |
| 24 | + |
| 25 | + |
| 26 | +class AHITContext(CommonContext): |
| 27 | + command_processor = AHITCommandProcessor |
| 28 | + game = "A Hat in Time" |
| 29 | + |
| 30 | + def __init__(self, server_address, password): |
| 31 | + super().__init__(server_address, password) |
| 32 | + self.proxy = None |
| 33 | + self.proxy_task = None |
| 34 | + self.gamejsontotext = AHITJSONToTextParser(self) |
| 35 | + self.autoreconnect_task = None |
| 36 | + self.endpoint = None |
| 37 | + self.items_handling = 0b111 |
| 38 | + self.room_info = None |
| 39 | + self.connected_msg = None |
| 40 | + self.game_connected = False |
| 41 | + self.awaiting_info = False |
| 42 | + self.full_inventory: List[Any] = [] |
| 43 | + self.server_msgs: List[Any] = [] |
| 44 | + |
| 45 | + async def server_auth(self, password_requested: bool = False): |
| 46 | + if password_requested and not self.password: |
| 47 | + await super(AHITContext, self).server_auth(password_requested) |
| 48 | + |
| 49 | + await self.get_username() |
| 50 | + await self.send_connect() |
| 51 | + |
| 52 | + def get_ahit_status(self) -> str: |
| 53 | + if not self.is_proxy_connected(): |
| 54 | + return "Not connected to A Hat in Time" |
| 55 | + |
| 56 | + return "Connected to A Hat in Time" |
| 57 | + |
| 58 | + async def send_msgs_proxy(self, msgs: Iterable[dict]) -> bool: |
| 59 | + """ `msgs` JSON serializable """ |
| 60 | + if not self.endpoint or not self.endpoint.socket.open or self.endpoint.socket.closed: |
| 61 | + return False |
| 62 | + |
| 63 | + if DEBUG: |
| 64 | + logger.info(f"Outgoing message: {msgs}") |
| 65 | + |
| 66 | + await self.endpoint.socket.send(msgs) |
| 67 | + return True |
| 68 | + |
| 69 | + async def disconnect(self, allow_autoreconnect: bool = False): |
| 70 | + await super().disconnect(allow_autoreconnect) |
| 71 | + |
| 72 | + async def disconnect_proxy(self): |
| 73 | + if self.endpoint and not self.endpoint.socket.closed: |
| 74 | + await self.endpoint.socket.close() |
| 75 | + if self.proxy_task is not None: |
| 76 | + await self.proxy_task |
| 77 | + |
| 78 | + def is_connected(self) -> bool: |
| 79 | + return self.server and self.server.socket.open |
| 80 | + |
| 81 | + def is_proxy_connected(self) -> bool: |
| 82 | + return self.endpoint and self.endpoint.socket.open |
| 83 | + |
| 84 | + def on_print_json(self, args: dict): |
| 85 | + text = self.gamejsontotext(deepcopy(args["data"])) |
| 86 | + msg = {"cmd": "PrintJSON", "data": [{"text": text}], "type": "Chat"} |
| 87 | + self.server_msgs.append(encode([msg])) |
| 88 | + |
| 89 | + if self.ui: |
| 90 | + self.ui.print_json(args["data"]) |
| 91 | + else: |
| 92 | + text = self.jsontotextparser(args["data"]) |
| 93 | + logger.info(text) |
| 94 | + |
| 95 | + def update_items(self): |
| 96 | + # just to be safe - we might still have an inventory from a different room |
| 97 | + if not self.is_connected(): |
| 98 | + return |
| 99 | + |
| 100 | + self.server_msgs.append(encode([{"cmd": "ReceivedItems", "index": 0, "items": self.full_inventory}])) |
| 101 | + |
| 102 | + def on_package(self, cmd: str, args: dict): |
| 103 | + if cmd == "Connected": |
| 104 | + self.connected_msg = encode([args]) |
| 105 | + if self.awaiting_info: |
| 106 | + self.server_msgs.append(self.room_info) |
| 107 | + self.update_items() |
| 108 | + self.awaiting_info = False |
| 109 | + |
| 110 | + elif cmd == "ReceivedItems": |
| 111 | + if args["index"] == 0: |
| 112 | + self.full_inventory.clear() |
| 113 | + |
| 114 | + for item in args["items"]: |
| 115 | + self.full_inventory.append(NetworkItem(*item)) |
| 116 | + |
| 117 | + self.server_msgs.append(encode([args])) |
| 118 | + |
| 119 | + elif cmd == "RoomInfo": |
| 120 | + self.seed_name = args["seed_name"] |
| 121 | + self.room_info = encode([args]) |
| 122 | + |
| 123 | + else: |
| 124 | + if cmd != "PrintJSON": |
| 125 | + self.server_msgs.append(encode([args])) |
| 126 | + |
| 127 | + def run_gui(self): |
| 128 | + from kvui import GameManager |
| 129 | + |
| 130 | + class AHITManager(GameManager): |
| 131 | + logging_pairs = [ |
| 132 | + ("Client", "Archipelago") |
| 133 | + ] |
| 134 | + base_title = "Archipelago A Hat in Time Client" |
| 135 | + |
| 136 | + self.ui = AHITManager(self) |
| 137 | + self.ui_task = asyncio.create_task(self.ui.async_run(), name="UI") |
| 138 | + |
| 139 | + |
| 140 | +async def proxy(websocket, path: str = "/", ctx: AHITContext = None): |
| 141 | + ctx.endpoint = Endpoint(websocket) |
| 142 | + try: |
| 143 | + await on_client_connected(ctx) |
| 144 | + |
| 145 | + if ctx.is_proxy_connected(): |
| 146 | + async for data in websocket: |
| 147 | + if DEBUG: |
| 148 | + logger.info(f"Incoming message: {data}") |
| 149 | + |
| 150 | + for msg in decode(data): |
| 151 | + if msg["cmd"] == "Connect": |
| 152 | + # Proxy is connecting, make sure it is valid |
| 153 | + if msg["game"] != "A Hat in Time": |
| 154 | + logger.info("Aborting proxy connection: game is not A Hat in Time") |
| 155 | + await ctx.disconnect_proxy() |
| 156 | + break |
| 157 | + |
| 158 | + if ctx.seed_name: |
| 159 | + seed_name = msg.get("seed_name", "") |
| 160 | + if seed_name != "" and seed_name != ctx.seed_name: |
| 161 | + logger.info("Aborting proxy connection: seed mismatch from save file") |
| 162 | + logger.info(f"Expected: {ctx.seed_name}, got: {seed_name}") |
| 163 | + text = encode([{"cmd": "PrintJSON", |
| 164 | + "data": [{"text": "Connection aborted - save file to seed mismatch"}]}]) |
| 165 | + await ctx.send_msgs_proxy(text) |
| 166 | + await ctx.disconnect_proxy() |
| 167 | + break |
| 168 | + |
| 169 | + if ctx.connected_msg and ctx.is_connected(): |
| 170 | + await ctx.send_msgs_proxy(ctx.connected_msg) |
| 171 | + ctx.update_items() |
| 172 | + continue |
| 173 | + |
| 174 | + if not ctx.is_proxy_connected(): |
| 175 | + break |
| 176 | + |
| 177 | + await ctx.send_msgs([msg]) |
| 178 | + |
| 179 | + except Exception as e: |
| 180 | + if not isinstance(e, websockets.WebSocketException): |
| 181 | + logger.exception(e) |
| 182 | + finally: |
| 183 | + await ctx.disconnect_proxy() |
| 184 | + |
| 185 | + |
| 186 | +async def on_client_connected(ctx: AHITContext): |
| 187 | + if ctx.room_info and ctx.is_connected(): |
| 188 | + await ctx.send_msgs_proxy(ctx.room_info) |
| 189 | + else: |
| 190 | + ctx.awaiting_info = True |
| 191 | + |
| 192 | + |
| 193 | +async def proxy_loop(ctx: AHITContext): |
| 194 | + try: |
| 195 | + while not ctx.exit_event.is_set(): |
| 196 | + if len(ctx.server_msgs) > 0: |
| 197 | + for msg in ctx.server_msgs: |
| 198 | + await ctx.send_msgs_proxy(msg) |
| 199 | + |
| 200 | + ctx.server_msgs.clear() |
| 201 | + await asyncio.sleep(0.1) |
| 202 | + except Exception as e: |
| 203 | + logger.exception(e) |
| 204 | + logger.info("Aborting AHIT Proxy Client due to errors") |
| 205 | + |
| 206 | + |
| 207 | +def launch(): |
| 208 | + async def main(): |
| 209 | + parser = get_base_parser() |
| 210 | + args = parser.parse_args() |
| 211 | + |
| 212 | + ctx = AHITContext(args.connect, args.password) |
| 213 | + logger.info("Starting A Hat in Time proxy server") |
| 214 | + ctx.proxy = websockets.serve(functools.partial(proxy, ctx=ctx), |
| 215 | + host="localhost", port=11311, ping_timeout=999999, ping_interval=999999) |
| 216 | + ctx.proxy_task = asyncio.create_task(proxy_loop(ctx), name="ProxyLoop") |
| 217 | + |
| 218 | + if gui_enabled: |
| 219 | + ctx.run_gui() |
| 220 | + ctx.run_cli() |
| 221 | + |
| 222 | + await ctx.proxy |
| 223 | + await ctx.proxy_task |
| 224 | + await ctx.exit_event.wait() |
| 225 | + |
| 226 | + Utils.init_logging("AHITClient") |
| 227 | + # options = Utils.get_options() |
| 228 | + |
| 229 | + import colorama |
| 230 | + colorama.init() |
| 231 | + asyncio.run(main()) |
| 232 | + colorama.deinit() |
0 commit comments