|  | 
|  | 1 | +import datetime | 
|  | 2 | +import json | 
|  | 3 | +import logging | 
|  | 4 | +import sys | 
|  | 5 | +from typing import Any, Tuple | 
|  | 6 | + | 
|  | 7 | +import aiohttp | 
|  | 8 | +from eth_account import Account | 
|  | 9 | +from eth_account.messages import encode_defunct | 
|  | 10 | +from jwcrypto import jwk | 
|  | 11 | +from jwcrypto.jwa import JWA | 
|  | 12 | + | 
|  | 13 | +logger = logging.getLogger(__name__) | 
|  | 14 | + | 
|  | 15 | + | 
|  | 16 | +class VmClient: | 
|  | 17 | +    def __init__(self, account: Account, domain: str = ""): | 
|  | 18 | +        self.account: Account = account | 
|  | 19 | +        self.ephemeral_key = jwk.JWK.generate(kty="EC", crv="P-256") | 
|  | 20 | +        self.expected_domain = domain | 
|  | 21 | +        self.pubkey_payload = self._generate_pubkey_payload() | 
|  | 22 | +        self.pubkey_signature_header = None | 
|  | 23 | +        self.session = aiohttp.ClientSession() | 
|  | 24 | + | 
|  | 25 | +    def _generate_pubkey_payload(self): | 
|  | 26 | +        return { | 
|  | 27 | +            "pubkey": json.loads(self.ephemeral_key.export_public()), | 
|  | 28 | +            "alg": "ECDSA", | 
|  | 29 | +            "domain": self.expected_domain, | 
|  | 30 | +            "address": self.account.address, | 
|  | 31 | +            "expires": ( | 
|  | 32 | +                datetime.datetime.utcnow() + datetime.timedelta(days=1) | 
|  | 33 | +            ).isoformat() | 
|  | 34 | +            + "Z", | 
|  | 35 | +        } | 
|  | 36 | + | 
|  | 37 | +    def _generate_pubkey_signature_header(self): | 
|  | 38 | +        pubkey_payload = json.dumps(self.pubkey_payload).encode("utf-8").hex() | 
|  | 39 | +        signable_message = encode_defunct(hexstr=pubkey_payload) | 
|  | 40 | +        signed_message = self.account.sign_message(signable_message) | 
|  | 41 | +        pubkey_signature = self.to_0x_hex(signed_message.signature) | 
|  | 42 | +        return json.dumps( | 
|  | 43 | +            { | 
|  | 44 | +                "sender": self.account.address, | 
|  | 45 | +                "payload": pubkey_payload, | 
|  | 46 | +                "signature": pubkey_signature, | 
|  | 47 | +                "content": {"domain": self.expected_domain}, | 
|  | 48 | +            } | 
|  | 49 | +        ) | 
|  | 50 | + | 
|  | 51 | +    @staticmethod | 
|  | 52 | +    def to_0x_hex(b: bytes) -> str: | 
|  | 53 | +        return "0x" + bytes.hex(b) | 
|  | 54 | + | 
|  | 55 | +    @staticmethod | 
|  | 56 | +    def on_message(content): | 
|  | 57 | +        try: | 
|  | 58 | +            msg = json.loads(content) | 
|  | 59 | +            fd = sys.stderr if msg["type"] == "stderr" else sys.stdout | 
|  | 60 | +            logger.info(f"< {msg['message']}") | 
|  | 61 | +        except Exception as e: | 
|  | 62 | +            logger.error(f"Unable to parse content: {content}, Error: {str(e)}") | 
|  | 63 | + | 
|  | 64 | +    async def perform_operation(self, vm_id, operation): | 
|  | 65 | +        if self.pubkey_signature_header is None: | 
|  | 66 | +            self.pubkey_signature_header = self._generate_pubkey_signature_header() | 
|  | 67 | + | 
|  | 68 | +        hostname = f"https://{self.expected_domain}" | 
|  | 69 | +        path = f"/control/machine/{vm_id}/{operation}" | 
|  | 70 | + | 
|  | 71 | +        payload = { | 
|  | 72 | +            "time": datetime.datetime.utcnow().isoformat() + "Z", | 
|  | 73 | +            "method": "POST", | 
|  | 74 | +            "path": path, | 
|  | 75 | +        } | 
|  | 76 | +        payload_as_bytes = json.dumps(payload).encode("utf-8") | 
|  | 77 | +        headers = {"X-SignedPubKey": self.pubkey_signature_header} | 
|  | 78 | +        payload_signature = JWA.signing_alg("ES256").sign( | 
|  | 79 | +            self.ephemeral_key, payload_as_bytes | 
|  | 80 | +        ) | 
|  | 81 | +        headers["X-SignedOperation"] = json.dumps( | 
|  | 82 | +            { | 
|  | 83 | +                "payload": payload_as_bytes.hex(), | 
|  | 84 | +                "signature": payload_signature.hex(), | 
|  | 85 | +            } | 
|  | 86 | +        ) | 
|  | 87 | + | 
|  | 88 | +        url = f"{hostname}{path}" | 
|  | 89 | + | 
|  | 90 | +        try: | 
|  | 91 | +            async with self.session.post(url, headers=headers) as response: | 
|  | 92 | +                response_text = await response.text() | 
|  | 93 | +                return response.status, response_text | 
|  | 94 | +        except aiohttp.ClientError as e: | 
|  | 95 | +            logger.error(f"HTTP error during operation {operation}: {str(e)}") | 
|  | 96 | +            return None, str(e) | 
|  | 97 | + | 
|  | 98 | +    async def get_logs(self, vm_id): | 
|  | 99 | +        if self.pubkey_signature_header is None: | 
|  | 100 | +            self.pubkey_signature_header = self._generate_pubkey_signature_header() | 
|  | 101 | + | 
|  | 102 | +        ws_url = f"https://{self.expected_domain}/control/machine/{vm_id}/logs" | 
|  | 103 | + | 
|  | 104 | +        payload = { | 
|  | 105 | +            "time": datetime.datetime.utcnow().isoformat() + "Z", | 
|  | 106 | +            "method": "GET", | 
|  | 107 | +            "path": f"/control/machine/{vm_id}/logs", | 
|  | 108 | +        } | 
|  | 109 | +        payload_as_bytes = json.dumps(payload).encode("utf-8") | 
|  | 110 | +        headers = {"X-SignedPubKey": self.pubkey_signature_header} | 
|  | 111 | +        payload_signature = JWA.signing_alg("ES256").sign( | 
|  | 112 | +            self.ephemeral_key, payload_as_bytes | 
|  | 113 | +        ) | 
|  | 114 | +        headers["X-SignedOperation"] = json.dumps( | 
|  | 115 | +            { | 
|  | 116 | +                "payload": payload_as_bytes.hex(), | 
|  | 117 | +                "signature": payload_signature.hex(), | 
|  | 118 | +            } | 
|  | 119 | +        ) | 
|  | 120 | + | 
|  | 121 | +        try: | 
|  | 122 | +            async with aiohttp.ClientSession() as session: | 
|  | 123 | +                async with session.ws_connect(ws_url) as ws: | 
|  | 124 | +                    logger.error(f"Connecting to WebSocket URL: {ws_url}") | 
|  | 125 | + | 
|  | 126 | +                    auth_message = { | 
|  | 127 | +                        "auth": { | 
|  | 128 | +                            "X-SignedPubKey": headers["X-SignedPubKey"], | 
|  | 129 | +                            "X-SignedOperation": headers["X-SignedOperation"], | 
|  | 130 | +                        } | 
|  | 131 | +                    } | 
|  | 132 | +                    logger.error(f"Sending auth message: {auth_message}") | 
|  | 133 | +                    await ws.send_json(auth_message) | 
|  | 134 | +                    response = await ws.receive() | 
|  | 135 | +                    logger.error(response.data) | 
|  | 136 | +        except Exception as e: | 
|  | 137 | +            logger.error(f"error : {e}") | 
|  | 138 | + | 
|  | 139 | +    async def get_logs_as_text(self, vm_id): | 
|  | 140 | +        logs = [] | 
|  | 141 | + | 
|  | 142 | +        async def collect_logs(content): | 
|  | 143 | +            try: | 
|  | 144 | +                msg = json.loads(content) | 
|  | 145 | +                logs.append(msg["message"]) | 
|  | 146 | +            except Exception as e: | 
|  | 147 | +                logger.error(f"Unable to parse content: {content}, Error: {str(e)}") | 
|  | 148 | + | 
|  | 149 | +        original_on_message = self.on_message | 
|  | 150 | +        self.on_message = collect_logs | 
|  | 151 | + | 
|  | 152 | +        await self.get_logs(vm_id) | 
|  | 153 | + | 
|  | 154 | +        self.on_message = original_on_message | 
|  | 155 | +        return "\n".join(logs) | 
|  | 156 | + | 
|  | 157 | +    async def start_instance(self, vm_id): | 
|  | 158 | +        return await self.notify_allocation(vm_id) | 
|  | 159 | + | 
|  | 160 | +    async def stop_instance(self, vm_id): | 
|  | 161 | +        return await self.perform_operation(vm_id, "stop") | 
|  | 162 | + | 
|  | 163 | +    async def reboot_instance(self, vm_id): | 
|  | 164 | + | 
|  | 165 | +        return await self.perform_operation(vm_id, "reboot") | 
|  | 166 | + | 
|  | 167 | +    async def erase_instance(self, vm_id): | 
|  | 168 | +        return await self.perform_operation(vm_id, "erase") | 
|  | 169 | + | 
|  | 170 | +    async def expire_instance(self, vm_id): | 
|  | 171 | +        return await self.perform_operation(vm_id, "expire") | 
|  | 172 | + | 
|  | 173 | +    async def notify_allocation(self, vm_id) -> Tuple[Any, str]: | 
|  | 174 | +        json_data = {"instance": vm_id} | 
|  | 175 | +        async with self.session.post( | 
|  | 176 | +            f"https://{self.expected_domain}/control/allocation/notify", json=json_data | 
|  | 177 | +        ) as s: | 
|  | 178 | +            form_response_text = await s.text() | 
|  | 179 | +            return s.status, form_response_text | 
|  | 180 | + | 
|  | 181 | +    async def manage_instance(self, vm_id, operations): | 
|  | 182 | +        for operation in operations: | 
|  | 183 | +            logger.info(f"Performing operation: {operation}") | 
|  | 184 | +            status, response = await self.perform_operation(vm_id, operation) | 
|  | 185 | +            if status != 200: | 
|  | 186 | +                return status, response | 
|  | 187 | +        return | 
|  | 188 | + | 
|  | 189 | +    async def close(self): | 
|  | 190 | +        await self.session.close() | 
|  | 191 | + | 
|  | 192 | +    async def __aenter__(self): | 
|  | 193 | +        return self | 
|  | 194 | + | 
|  | 195 | +    async def __aexit__(self, exc_type, exc_value, traceback): | 
|  | 196 | +        await self.close() | 
0 commit comments