diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index efe667ea..8d5456c6 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -14,8 +14,11 @@ on: jobs: build: strategy: + fail-fast: false matrix: - python-version: [ "3.9", "3.10", "3.11", "3.12" ] + python-version: [ "3.9", "3.10", "3.11" ] + # An issue with secp256k1 prevents Python 3.12 from working + # See https://github.com/baking-bad/pytezos/issues/370 runs-on: ubuntu-latest steps: diff --git a/README.md b/README.md index cfc7e1a4..3d2aea9c 100644 --- a/README.md +++ b/README.md @@ -67,7 +67,7 @@ $ pip install -e .[all] You can use the test env defined for hatch to run the tests: ```shell -$ hatch run test:run +$ hatch run testing:run ``` See `hatch env show` for more information about all the environments and their scripts. diff --git a/pyproject.toml b/pyproject.toml index 8a70e9c8..1070a7f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,14 +23,16 @@ classifiers = [ ] dependencies = [ "aiohttp>=3.8.3", - "aleph-message~=0.4.4", + "aleph-message>=0.4.7", "coincurve; python_version<\"3.11\"", "coincurve>=19.0.0; python_version>=\"3.11\"", "eth_abi>=4.0.0; python_version>=\"3.11\"", "eth_account>=0.4.0,<0.11.0", + "jwcrypto==1.5.6", "python-magic", "typer", "typing_extensions", + "aioresponses>=0.7.6" ] [project.optional-dependencies] @@ -122,6 +124,8 @@ dependencies = [ "pytest-cov==4.1.0", "pytest-mock==3.12.0", "pytest-asyncio==0.23.5", + "pytest-aiohttp==1.0.5", + "aioresponses==0.7.6", "fastapi", "httpx", "secp256k1", @@ -150,13 +154,13 @@ dependencies = [ [tool.hatch.envs.linting.scripts] typing = "mypy --config-file=pyproject.toml {args:} ./src/ ./tests/ ./examples/" style = [ - "ruff {args:.} ./src/ ./tests/ ./examples/", + "ruff check {args:.} ./src/ ./tests/ ./examples/", "black --check --diff {args:} ./src/ ./tests/ ./examples/", "isort --check-only --profile black {args:} ./src/ ./tests/ ./examples/", ] fmt = [ "black {args:} ./src/ ./tests/ ./examples/", - "ruff --fix {args:.} ./src/ ./tests/ ./examples/", + "ruff check --fix {args:.} ./src/ ./tests/ ./examples/", "isort --profile black {args:} ./src/ ./tests/ ./examples/", "style", ] diff --git a/src/aleph/sdk/chains/common.py b/src/aleph/sdk/chains/common.py index b73d6e41..0a90183c 100644 --- a/src/aleph/sdk/chains/common.py +++ b/src/aleph/sdk/chains/common.py @@ -170,10 +170,3 @@ def get_fallback_private_key(path: Optional[Path] = None) -> bytes: if not default_key_path.exists(): default_key_path.symlink_to(path) return private_key - - -def bytes_from_hex(hex_string: str) -> bytes: - if hex_string.startswith("0x"): - hex_string = hex_string[2:] - hex_string = bytes.fromhex(hex_string) - return hex_string diff --git a/src/aleph/sdk/chains/ethereum.py b/src/aleph/sdk/chains/ethereum.py index 124fbee7..b0fa5fbe 100644 --- a/src/aleph/sdk/chains/ethereum.py +++ b/src/aleph/sdk/chains/ethereum.py @@ -7,12 +7,8 @@ from eth_keys.exceptions import BadSignature as EthBadSignatureError from ..exceptions import BadSignatureError -from .common import ( - BaseAccount, - bytes_from_hex, - get_fallback_private_key, - get_public_key, -) +from ..utils import bytes_from_hex +from .common import BaseAccount, get_fallback_private_key, get_public_key class ETHAccount(BaseAccount): diff --git a/src/aleph/sdk/chains/substrate.py b/src/aleph/sdk/chains/substrate.py index 13795568..f4d18a0d 100644 --- a/src/aleph/sdk/chains/substrate.py +++ b/src/aleph/sdk/chains/substrate.py @@ -9,7 +9,8 @@ from ..conf import settings from ..exceptions import BadSignatureError -from .common import BaseAccount, bytes_from_hex, get_verification_buffer +from ..utils import bytes_from_hex +from .common import BaseAccount, get_verification_buffer logger = logging.getLogger(__name__) diff --git a/src/aleph/sdk/client/authenticated_http.py b/src/aleph/sdk/client/authenticated_http.py index 60d42b2b..6d44b526 100644 --- a/src/aleph/sdk/client/authenticated_http.py +++ b/src/aleph/sdk/client/authenticated_http.py @@ -30,6 +30,7 @@ from aleph_message.models.execution.environment import ( FunctionEnvironment, HypervisorType, + InstanceEnvironment, MachineResources, ) from aleph_message.models.execution.instance import RootfsVolume @@ -534,16 +535,17 @@ async def create_instance( timeout_seconds = timeout_seconds or settings.DEFAULT_VM_TIMEOUT payment = payment or Payment(chain=Chain.ETH, type=PaymentType.hold) - hypervisor = hypervisor or HypervisorType.firecracker + + # Default to the QEMU hypervisor for instances. + selected_hypervisor: HypervisorType = hypervisor or HypervisorType.qemu content = InstanceContent( address=address, allow_amend=allow_amend, - environment=FunctionEnvironment( - reproducible=False, + environment=InstanceEnvironment( internet=internet, aleph_api=aleph_api, - hypervisor=hypervisor, + hypervisor=selected_hypervisor, ), variables=environment_variables, resources=MachineResources( diff --git a/src/aleph/sdk/client/vm_client.py b/src/aleph/sdk/client/vm_client.py new file mode 100644 index 00000000..4092851d --- /dev/null +++ b/src/aleph/sdk/client/vm_client.py @@ -0,0 +1,192 @@ +import datetime +import json +import logging +from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple +from urllib.parse import urlparse + +import aiohttp +from aleph_message.models import ItemHash +from eth_account.messages import encode_defunct +from jwcrypto import jwk + +from aleph.sdk.types import Account +from aleph.sdk.utils import ( + create_vm_control_payload, + sign_vm_control_payload, + to_0x_hex, +) + +logger = logging.getLogger(__name__) + + +class VmClient: + account: Account + ephemeral_key: jwk.JWK + node_url: str + pubkey_payload: Dict[str, Any] + pubkey_signature_header: str + session: aiohttp.ClientSession + + def __init__( + self, + account: Account, + node_url: str = "", + session: Optional[aiohttp.ClientSession] = None, + ): + self.account = account + self.ephemeral_key = jwk.JWK.generate(kty="EC", crv="P-256") + self.node_url = node_url + self.pubkey_payload = self._generate_pubkey_payload() + self.pubkey_signature_header = "" + self.session = session or aiohttp.ClientSession() + + def _generate_pubkey_payload(self) -> Dict[str, Any]: + return { + "pubkey": json.loads(self.ephemeral_key.export_public()), + "alg": "ECDSA", + "domain": self.node_domain, + "address": self.account.get_address(), + "expires": ( + datetime.datetime.utcnow() + datetime.timedelta(days=1) + ).isoformat() + + "Z", + } + + async def _generate_pubkey_signature_header(self) -> str: + pubkey_payload = json.dumps(self.pubkey_payload).encode("utf-8").hex() + signable_message = encode_defunct(hexstr=pubkey_payload) + buffer_to_sign = signable_message.body + + signed_message = await self.account.sign_raw(buffer_to_sign) + pubkey_signature = to_0x_hex(signed_message) + + return json.dumps( + { + "sender": self.account.get_address(), + "payload": pubkey_payload, + "signature": pubkey_signature, + "content": {"domain": self.node_domain}, + } + ) + + async def _generate_header( + self, vm_id: ItemHash, operation: str, method: str + ) -> Tuple[str, Dict[str, str]]: + payload = create_vm_control_payload( + vm_id, operation, domain=self.node_domain, method=method + ) + signed_operation = sign_vm_control_payload(payload, self.ephemeral_key) + + if not self.pubkey_signature_header: + self.pubkey_signature_header = ( + await self._generate_pubkey_signature_header() + ) + + headers = { + "X-SignedPubKey": self.pubkey_signature_header, + "X-SignedOperation": signed_operation, + } + + path = payload["path"] + return f"{self.node_url}{path}", headers + + @property + def node_domain(self) -> str: + domain = urlparse(self.node_url).hostname + if not domain: + raise Exception("Could not parse node domain") + return domain + + async def perform_operation( + self, vm_id: ItemHash, operation: str, method: str = "POST" + ) -> Tuple[Optional[int], str]: + if not self.pubkey_signature_header: + self.pubkey_signature_header = ( + await self._generate_pubkey_signature_header() + ) + + url, header = await self._generate_header( + vm_id=vm_id, operation=operation, method=method + ) + + try: + async with self.session.request( + method=method, url=url, headers=header + ) as response: + response_text = await response.text() + return response.status, response_text + + except aiohttp.ClientError as e: + logger.error(f"HTTP error during operation {operation}: {str(e)}") + return None, str(e) + + async def get_logs(self, vm_id: ItemHash) -> AsyncGenerator[str, None]: + if not self.pubkey_signature_header: + self.pubkey_signature_header = ( + await self._generate_pubkey_signature_header() + ) + + payload = create_vm_control_payload( + vm_id, "stream_logs", method="get", domain=self.node_domain + ) + signed_operation = sign_vm_control_payload(payload, self.ephemeral_key) + path = payload["path"] + ws_url = f"{self.node_url}{path}" + + async with self.session.ws_connect(ws_url) as ws: + auth_message = { + "auth": { + "X-SignedPubKey": json.loads(self.pubkey_signature_header), + "X-SignedOperation": json.loads(signed_operation), + } + } + await ws.send_json(auth_message) + + async for msg in ws: # msg is of type aiohttp.WSMessage + if msg.type == aiohttp.WSMsgType.TEXT: + yield msg.data + elif msg.type == aiohttp.WSMsgType.ERROR: + break + + async def start_instance(self, vm_id: ItemHash) -> Tuple[int, str]: + return await self.notify_allocation(vm_id) + + async def stop_instance(self, vm_id: ItemHash) -> Tuple[Optional[int], str]: + return await self.perform_operation(vm_id, "stop") + + async def reboot_instance(self, vm_id: ItemHash) -> Tuple[Optional[int], str]: + return await self.perform_operation(vm_id, "reboot") + + async def erase_instance(self, vm_id: ItemHash) -> Tuple[Optional[int], str]: + return await self.perform_operation(vm_id, "erase") + + async def expire_instance(self, vm_id: ItemHash) -> Tuple[Optional[int], str]: + return await self.perform_operation(vm_id, "expire") + + async def notify_allocation(self, vm_id: ItemHash) -> Tuple[int, str]: + json_data = {"instance": vm_id} + + async with self.session.post( + f"{self.node_url}/control/allocation/notify", json=json_data + ) as session: + form_response_text = await session.text() + + return session.status, form_response_text + + async def manage_instance( + self, vm_id: ItemHash, operations: List[str] + ) -> Tuple[int, str]: + for operation in operations: + status, response = await self.perform_operation(vm_id, operation) + if status != 200 and status: + return status, response + return 200, "All operations completed successfully" + + async def close(self): + await self.session.close() + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.close() diff --git a/src/aleph/sdk/client/vm_confidential_client.py b/src/aleph/sdk/client/vm_confidential_client.py new file mode 100644 index 00000000..a100de8c --- /dev/null +++ b/src/aleph/sdk/client/vm_confidential_client.py @@ -0,0 +1,216 @@ +import base64 +import json +import logging +import os +import tempfile +from pathlib import Path +from typing import Any, Dict, Optional, Tuple + +import aiohttp +from aleph_message.models import ItemHash + +from aleph.sdk.client.vm_client import VmClient +from aleph.sdk.types import Account, SEVMeasurement +from aleph.sdk.utils import ( + compute_confidential_measure, + encrypt_secret_table, + get_vm_measure, + make_packet_header, + make_secret_table, + run_in_subprocess, +) + +logger = logging.getLogger(__name__) + + +class VmConfidentialClient(VmClient): + sevctl_path: Path + + def __init__( + self, + account: Account, + sevctl_path: Path, + node_url: str = "", + session: Optional[aiohttp.ClientSession] = None, + ): + super().__init__(account, node_url, session) + self.sevctl_path = sevctl_path + + async def get_certificates(self) -> Tuple[Optional[int], str]: + """ + Get platform confidential certificate + """ + + url = f"{self.node_url}/about/certificates" + try: + async with self.session.get(url) as response: + data = await response.read() + with tempfile.NamedTemporaryFile(delete=False) as tmp_file: + tmp_file.write(data) + return response.status, tmp_file.name + + except aiohttp.ClientError as e: + logger.error( + f"HTTP error getting node certificates on {self.node_url}: {str(e)}" + ) + return None, str(e) + + async def create_session( + self, vm_id: ItemHash, certificate_path: Path, policy: int + ) -> Path: + """ + Create new confidential session + """ + + current_path = Path().cwd() + args = [ + "session", + "--name", + str(vm_id), + str(certificate_path), + str(policy), + ] + try: + # TODO: Check command result + await self.sevctl_cmd(*args) + return current_path + except Exception as e: + raise ValueError(f"Session creation have failed, reason: {str(e)}") + + async def initialize(self, vm_id: ItemHash, session: Path, godh: Path) -> str: + """ + Initialize Confidential VM negociation passing the needed session files + """ + + session_file = session.read_bytes() + godh_file = godh.read_bytes() + params = { + "session": session_file, + "godh": godh_file, + } + return await self.perform_confidential_operation( + vm_id, "confidential/initialize", params=params + ) + + async def measurement(self, vm_id: ItemHash) -> SEVMeasurement: + """ + Fetch VM confidential measurement + """ + + if not self.pubkey_signature_header: + self.pubkey_signature_header = ( + await self._generate_pubkey_signature_header() + ) + + status, text = await self.perform_operation( + vm_id, "confidential/measurement", method="GET" + ) + sev_mesurement = SEVMeasurement.parse_raw(text) + return sev_mesurement + + async def validate_measure( + self, sev_data: SEVMeasurement, tik_path: Path, firmware_hash: str + ) -> bool: + """ + Validate VM confidential measurement + """ + + tik = tik_path.read_bytes() + vm_measure, nonce = get_vm_measure(sev_data) + + expected_measure = compute_confidential_measure( + sev_info=sev_data.sev_info, + tik=tik, + expected_hash=firmware_hash, + nonce=nonce, + ).digest() + return expected_measure == vm_measure + + async def build_secret( + self, tek_path: Path, tik_path: Path, sev_data: SEVMeasurement, secret: str + ) -> Tuple[str, str]: + """ + Build disk secret to be injected on the confidential VM + """ + + tek = tek_path.read_bytes() + tik = tik_path.read_bytes() + + vm_measure, _ = get_vm_measure(sev_data) + + iv = os.urandom(16) + secret_table = make_secret_table(secret) + encrypted_secret_table = encrypt_secret_table( + secret_table=secret_table, tek=tek, iv=iv + ) + + packet_header = make_packet_header( + vm_measure=vm_measure, + encrypted_secret_table=encrypted_secret_table, + secret_table_size=len(secret_table), + tik=tik, + iv=iv, + ) + + encoded_packet_header = base64.b64encode(packet_header).decode() + encoded_secret = base64.b64encode(encrypted_secret_table).decode() + + return encoded_packet_header, encoded_secret + + async def inject_secret( + self, vm_id: ItemHash, packet_header: str, secret: str + ) -> Dict: + """ + Send the secret by the encrypted channel to boot up the VM + """ + + params = { + "packet_header": packet_header, + "secret": secret, + } + text = await self.perform_confidential_operation( + vm_id, "confidential/inject_secret", json=params + ) + + return json.loads(text) + + async def perform_confidential_operation( + self, + vm_id: ItemHash, + operation: str, + params: Optional[Dict[str, Any]] = None, + json=None, + ) -> str: + """ + Send confidential operations to the CRN passing the auth headers on each request + """ + + if not self.pubkey_signature_header: + self.pubkey_signature_header = ( + await self._generate_pubkey_signature_header() + ) + + url, header = await self._generate_header( + vm_id=vm_id, operation=operation, method="post" + ) + + try: + async with self.session.post( + url, headers=header, data=params, json=json + ) as response: + response.raise_for_status() + response_text = await response.text() + return response_text + + except aiohttp.ClientError as e: + raise ValueError(f"HTTP error during operation {operation}: {str(e)}") + + async def sevctl_cmd(self, *args) -> bytes: + """ + Execute `sevctl` command with given arguments + """ + + return await run_in_subprocess( + [str(self.sevctl_path), *args], + check=True, + ) diff --git a/src/aleph/sdk/types.py b/src/aleph/sdk/types.py index 8d17f4d4..cf9e6fa8 100644 --- a/src/aleph/sdk/types.py +++ b/src/aleph/sdk/types.py @@ -2,6 +2,8 @@ from enum import Enum from typing import Dict, Protocol, TypeVar +from pydantic import BaseModel + __all__ = ("StorageEnum", "Account", "AccountFromPrivateKey", "GenericMessage") from aleph_message.models import AlephMessage @@ -20,6 +22,9 @@ class Account(Protocol): @abstractmethod async def sign_message(self, message: Dict) -> Dict: ... + @abstractmethod + async def sign_raw(self, buffer: bytes) -> bytes: ... + @abstractmethod def get_address(self) -> str: ... @@ -36,3 +41,26 @@ async def sign_raw(self, buffer: bytes) -> bytes: ... GenericMessage = TypeVar("GenericMessage", bound=AlephMessage) + + +class SEVInfo(BaseModel): + """ + An AMD SEV platform information. + """ + + enabled: bool + api_major: int + api_minor: int + build_id: int + policy: int + state: str + handle: int + + +class SEVMeasurement(BaseModel): + """ + A SEV measurement data get from Qemu measurement. + """ + + sev_info: SEVInfo + launch_measure: str diff --git a/src/aleph/sdk/utils.py b/src/aleph/sdk/utils.py index b1c04cdf..5c641d5c 100644 --- a/src/aleph/sdk/utils.py +++ b/src/aleph/sdk/utils.py @@ -1,14 +1,21 @@ +import asyncio +import base64 import errno import hashlib +import hmac +import json import logging import os +import subprocess from datetime import date, datetime, time from enum import Enum from pathlib import Path from shutil import make_archive from typing import ( Any, + Dict, Iterable, + List, Mapping, Optional, Protocol, @@ -18,15 +25,19 @@ Union, get_args, ) +from uuid import UUID from zipfile import BadZipFile, ZipFile -from aleph_message.models import MessageType +from aleph_message.models import ItemHash, MessageType from aleph_message.models.execution.program import Encoding from aleph_message.models.execution.volume import MachineVolume +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes +from jwcrypto.jwa import JWA from pydantic.json import pydantic_encoder from aleph.sdk.conf import settings -from aleph.sdk.types import GenericMessage +from aleph.sdk.types import GenericMessage, SEVInfo, SEVMeasurement logger = logging.getLogger(__name__) @@ -184,3 +195,186 @@ def parse_volume(volume_dict: Union[Mapping, MachineVolume]) -> MachineVolume: def compute_sha256(s: str) -> str: """Compute the SHA256 hash of a string.""" return hashlib.sha256(s.encode()).hexdigest() + + +def to_0x_hex(b: bytes) -> str: + return "0x" + bytes.hex(b) + + +def bytes_from_hex(hex_string: str) -> bytes: + if hex_string.startswith("0x"): + hex_string = hex_string[2:] + hex_string = bytes.fromhex(hex_string) + return hex_string + + +def create_vm_control_payload( + vm_id: ItemHash, operation: str, domain: str, method: str +) -> Dict[str, str]: + path = f"/control/machine/{vm_id}/{operation}" + payload = { + "time": datetime.utcnow().isoformat() + "Z", + "method": method.upper(), + "path": path, + "domain": domain, + } + return payload + + +def sign_vm_control_payload(payload: Dict[str, str], ephemeral_key) -> str: + payload_as_bytes = json.dumps(payload).encode("utf-8") + payload_signature = JWA.signing_alg("ES256").sign(ephemeral_key, payload_as_bytes) + signed_operation = json.dumps( + { + "payload": payload_as_bytes.hex(), + "signature": payload_signature.hex(), + } + ) + return signed_operation + + +async def run_in_subprocess( + command: List[str], check: bool = True, stdin_input: Optional[bytes] = None +) -> bytes: + """Run the specified command in a subprocess, returns the stdout of the process.""" + logger.debug(f"command: {' '.join(command)}") + + process = await asyncio.create_subprocess_exec( + *command, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, stderr = await process.communicate(input=stdin_input) + + if check and process.returncode: + logger.error( + f"Command failed with error code {process.returncode}:\n" + f" stdin = {stdin_input!r}\n" + f" command = {command}\n" + f" stdout = {stderr!r}" + ) + raise subprocess.CalledProcessError( + process.returncode, str(command), stderr.decode() + ) + + return stdout + + +def get_vm_measure(sev_data: SEVMeasurement) -> Tuple[bytes, bytes]: + launch_measure = base64.b64decode(sev_data.launch_measure) + vm_measure = launch_measure[0:32] + nonce = launch_measure[32:48] + return vm_measure, nonce + + +def compute_confidential_measure( + sev_info: SEVInfo, tik: bytes, expected_hash: str, nonce: bytes +) -> hmac.HMAC: + """ + Computes the SEV measurement using the CRN SEV data and local variables like the OVMF firmware hash, + and the session key generated. + """ + + h = hmac.new(tik, digestmod="sha256") + + ## + # calculated per section 6.5.2 + ## + h.update(bytes([0x04])) + h.update(sev_info.api_major.to_bytes(1, byteorder="little")) + h.update(sev_info.api_minor.to_bytes(1, byteorder="little")) + h.update(sev_info.build_id.to_bytes(1, byteorder="little")) + h.update(sev_info.policy.to_bytes(4, byteorder="little")) + + expected_hash_bytes = bytearray.fromhex(expected_hash) + h.update(expected_hash_bytes) + + h.update(nonce) + + return h + + +def make_secret_table(secret: str) -> bytearray: + """ + Makes the disk secret table to be sent to the Confidential CRN + """ + + ## + # Construct the secret table: two guids + 4 byte lengths plus string + # and zero terminator + # + # Secret layout is guid, len (4 bytes), data + # with len being the length from start of guid to end of data + # + # The table header covers the entire table then each entry covers + # only its local data + # + # our current table has the header guid with total table length + # followed by the secret guid with the zero terminated secret + ## + + # total length of table: header plus one entry with trailing \0 + length = 16 + 4 + 16 + 4 + len(secret) + 1 + # SEV-ES requires rounding to 16 + length = (length + 15) & ~15 + secret_table = bytearray(length) + + secret_table[0:16] = UUID("{1e74f542-71dd-4d66-963e-ef4287ff173b}").bytes_le + secret_table[16:20] = len(secret_table).to_bytes(4, byteorder="little") + secret_table[20:36] = UUID("{736869e5-84f0-4973-92ec-06879ce3da0b}").bytes_le + secret_table[36:40] = (16 + 4 + len(secret) + 1).to_bytes(4, byteorder="little") + secret_table[40 : 40 + len(secret)] = secret.encode() + + return secret_table + + +def encrypt_secret_table(secret_table: bytes, tek: bytes, iv: bytes) -> bytes: + """Encrypt the secret table with the TEK in CTR mode using a random IV""" + + # Initialize the cipher with AES algorithm and CTR mode + cipher = Cipher(algorithms.AES(tek), modes.CTR(iv), backend=default_backend()) + encryptor = cipher.encryptor() + + # Encrypt the secret table + encrypted_secret = encryptor.update(secret_table) + encryptor.finalize() + + return encrypted_secret + + +def make_packet_header( + vm_measure: bytes, + encrypted_secret_table: bytes, + secret_table_size: int, + tik: bytes, + iv: bytes, +) -> bytearray: + """ + Creates a packet header using the encrypted disk secret table to be sent to the Confidential CRN + """ + + ## + # ultimately needs to be an argument, but there's only + # compressed and no real use case + ## + flags = 0 + + ## + # Table 55. LAUNCH_SECRET Packet Header Buffer + ## + header = bytearray(52) + header[0:4] = flags.to_bytes(4, byteorder="little") + header[4:20] = iv + + h = hmac.new(tik, digestmod="sha256") + h.update(bytes([0x01])) + # FLAGS || IV + h.update(header[0:20]) + h.update(secret_table_size.to_bytes(4, byteorder="little")) + h.update(secret_table_size.to_bytes(4, byteorder="little")) + h.update(encrypted_secret_table) + h.update(vm_measure) + + header[20:52] = h.digest() + + return header diff --git a/src/aleph/sdk/wallets/ledger/ethereum.py b/src/aleph/sdk/wallets/ledger/ethereum.py index 2ecdc5d3..5dc40f03 100644 --- a/src/aleph/sdk/wallets/ledger/ethereum.py +++ b/src/aleph/sdk/wallets/ledger/ethereum.py @@ -9,7 +9,8 @@ from ledgereth.messages import sign_message from ledgereth.objects import LedgerAccount, SignedMessage -from ...chains.common import BaseAccount, bytes_from_hex, get_verification_buffer +from ...chains.common import BaseAccount, get_verification_buffer +from ...utils import bytes_from_hex class LedgerETHAccount(BaseAccount): diff --git a/tests/unit/aleph_vm_authentication.py b/tests/unit/aleph_vm_authentication.py new file mode 100644 index 00000000..491da51a --- /dev/null +++ b/tests/unit/aleph_vm_authentication.py @@ -0,0 +1,290 @@ +# Keep datetime import as is as it allow patching in test +import datetime +import functools +import json +import logging +from collections.abc import Awaitable, Coroutine +from typing import Any, Callable, Dict, Literal, Optional, Union + +import cryptography.exceptions +import pydantic +from aiohttp import web +from eth_account import Account +from eth_account.messages import encode_defunct +from jwcrypto import jwk +from jwcrypto.jwa import JWA +from pydantic import BaseModel, ValidationError, root_validator, validator + +from aleph.sdk.utils import bytes_from_hex + +logger = logging.getLogger(__name__) + +DOMAIN_NAME = "localhost" + + +def is_token_still_valid(datestr: str) -> bool: + """ + Checks if a token has expired based on its expiry timestamp + """ + current_datetime = datetime.datetime.now(tz=datetime.timezone.utc) + expiry_datetime = datetime.datetime.fromisoformat(datestr.replace("Z", "+00:00")) + + return expiry_datetime > current_datetime + + +def verify_wallet_signature(signature: bytes, message: str, address: str) -> bool: + """ + Verifies a signature issued by a wallet + """ + enc_msg = encode_defunct(hexstr=message) + computed_address = Account.recover_message(enc_msg, signature=signature) + + return computed_address.lower() == address.lower() + + +class SignedPubKeyPayload(BaseModel): + """This payload is signed by the wallet of the user to authorize an ephemeral key to act on his behalf.""" + + pubkey: Dict[str, Any] + # {'pubkey': {'alg': 'ES256', 'crv': 'P-256', 'ext': True, 'key_ops': ['verify'], 'kty': 'EC', + # 'x': '4blJBYpltvQLFgRvLE-2H7dsMr5O0ImHkgOnjUbG2AU', 'y': '5VHnq_hUSogZBbVgsXMs0CjrVfMy4Pa3Uv2BEBqfrN4'} + # alg: Literal["ECDSA"] + address: str + expires: str + + @property + def json_web_key(self) -> jwk.JWK: + """Return the ephemeral public key as Json Web Key""" + + return jwk.JWK(**self.pubkey) + + +class SignedPubKeyHeader(BaseModel): + signature: bytes + payload: bytes + + @validator("signature") + def signature_must_be_hex(cls, value: bytes) -> bytes: + """Convert the signature from hexadecimal to bytes""" + + return bytes_from_hex(value.decode()) + + @validator("payload") + def payload_must_be_hex(cls, value: bytes) -> bytes: + """Convert the payload from hexadecimal to bytes""" + + return bytes_from_hex(value.decode()) + + @root_validator(pre=False, skip_on_failure=True) + def check_expiry(cls, values) -> Dict[str, bytes]: + """Check that the token has not expired""" + payload: bytes = values["payload"] + content = SignedPubKeyPayload.parse_raw(payload) + + if not is_token_still_valid(content.expires): + msg = "Token expired" + raise ValueError(msg) + + return values + + @root_validator(pre=False, skip_on_failure=True) + def check_signature(cls, values: Dict[str, bytes]) -> Dict[str, bytes]: + """Check that the signature is valid""" + signature: bytes = values["signature"] + payload: bytes = values["payload"] + content = SignedPubKeyPayload.parse_raw(payload) + + if not verify_wallet_signature(signature, payload.hex(), content.address): + msg = "Invalid signature" + raise ValueError(msg) + + return values + + @property + def content(self) -> SignedPubKeyPayload: + """Return the content of the header""" + return SignedPubKeyPayload.parse_raw(self.payload) + + +class SignedOperationPayload(BaseModel): + time: datetime.datetime + method: Union[Literal["POST"], Literal["GET"]] + domain: str + path: str + # body_sha256: str # disabled since there is no body + + @validator("time") + def time_is_current(cls, v: datetime.datetime) -> datetime.datetime: + """Check that the time is current and the payload is not a replay attack.""" + max_past = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta( + minutes=2 + ) + max_future = datetime.datetime.now( + tz=datetime.timezone.utc + ) + datetime.timedelta(minutes=2) + if v < max_past: + raise ValueError("Time is too far in the past") + if v > max_future: + raise ValueError("Time is too far in the future") + return v + + +class SignedOperation(BaseModel): + """This payload is signed by the ephemeral key authorized above.""" + + signature: bytes + payload: bytes + + @validator("signature") + def signature_must_be_hex(cls, value: str) -> bytes: + """Convert the signature from hexadecimal to bytes""" + + try: + if isinstance(value, bytes): + value = value.decode() + return bytes_from_hex(value) + except pydantic.ValidationError as error: + logger.warning(value) + raise error + + @validator("payload") + def payload_must_be_hex(cls, v) -> bytes: + """Convert the payload from hexadecimal to bytes""" + v = bytes.fromhex(v.decode()) + _ = SignedOperationPayload.parse_raw(v) + return v + + @property + def content(self) -> SignedOperationPayload: + """Return the content of the header""" + return SignedOperationPayload.parse_raw(self.payload) + + +def get_signed_pubkey(request: web.Request) -> SignedPubKeyHeader: + """Get the ephemeral public key that is signed by the wallet from the request headers.""" + signed_pubkey_header = request.headers.get("X-SignedPubKey") + + if not signed_pubkey_header: + raise web.HTTPBadRequest(reason="Missing X-SignedPubKey header") + + try: + return SignedPubKeyHeader.parse_raw(signed_pubkey_header) + + except KeyError as error: + logger.debug(f"Missing X-SignedPubKey header: {error}") + raise web.HTTPBadRequest(reason="Invalid X-SignedPubKey fields") from error + + except json.JSONDecodeError as error: + raise web.HTTPBadRequest(reason="Invalid X-SignedPubKey format") from error + + except ValueError as errors: + logging.debug(errors) + + for err in errors.args[0]: + if isinstance(err.exc, json.JSONDecodeError): + raise web.HTTPBadRequest( + reason="Invalid X-SignedPubKey format" + ) from errors + + if str(err.exc) == "Token expired": + raise web.HTTPUnauthorized(reason="Token expired") from errors + + if str(err.exc) == "Invalid signature": + raise web.HTTPUnauthorized(reason="Invalid signature") from errors + else: + raise errors + + +def get_signed_operation(request: web.Request) -> SignedOperation: + """Get the signed operation public key that is signed by the ephemeral key from the request headers.""" + try: + signed_operation = request.headers["X-SignedOperation"] + return SignedOperation.parse_raw(signed_operation) + except KeyError as error: + raise web.HTTPBadRequest(reason="Missing X-SignedOperation header") from error + except json.JSONDecodeError as error: + raise web.HTTPBadRequest(reason="Invalid X-SignedOperation format") from error + except ValidationError as error: + logger.debug(f"Invalid X-SignedOperation fields: {error}") + raise web.HTTPBadRequest(reason="Invalid X-SignedOperation fields") from error + + +def verify_signed_operation( + signed_operation: SignedOperation, signed_pubkey: SignedPubKeyHeader +) -> str: + """Verify that the operation is signed by the ephemeral key authorized by the wallet.""" + pubkey = signed_pubkey.content.json_web_key + + try: + JWA.signing_alg("ES256").verify( + pubkey, signed_operation.payload, signed_operation.signature + ) + logger.debug("Signature verified") + + return signed_pubkey.content.address + + except cryptography.exceptions.InvalidSignature as e: + logger.debug("Failing to validate signature for operation", e) + + raise web.HTTPUnauthorized(reason="Signature could not verified") + + +async def authenticate_jwk( + request: web.Request, domain_name: Optional[str] = DOMAIN_NAME +) -> str: + """Authenticate a request using the X-SignedPubKey and X-SignedOperation headers.""" + signed_pubkey = get_signed_pubkey(request) + signed_operation = get_signed_operation(request) + + if signed_operation.content.domain != domain_name: + logger.debug( + f"Invalid domain '{signed_operation.content.domain}' != '{domain_name}'" + ) + raise web.HTTPUnauthorized(reason="Invalid domain") + + if signed_operation.content.path != request.path: + logger.debug( + f"Invalid path '{signed_operation.content.path}' != '{request.path}'" + ) + raise web.HTTPUnauthorized(reason="Invalid path") + if signed_operation.content.method != request.method: + logger.debug( + f"Invalid method '{signed_operation.content.method}' != '{request.method}'" + ) + raise web.HTTPUnauthorized(reason="Invalid method") + return verify_signed_operation(signed_operation, signed_pubkey) + + +async def authenticate_websocket_message( + message, domain_name: Optional[str] = DOMAIN_NAME +) -> str: + """Authenticate a websocket message since JS cannot configure headers on WebSockets.""" + signed_pubkey = SignedPubKeyHeader.parse_obj(message["X-SignedPubKey"]) + signed_operation = SignedOperation.parse_obj(message["X-SignedOperation"]) + if signed_operation.content.domain != domain_name: + logger.debug( + f"Invalid domain '{signed_pubkey.content.domain}' != '{domain_name}'" + ) + raise web.HTTPUnauthorized(reason="Invalid domain") + return verify_signed_operation(signed_operation, signed_pubkey) + + +def require_jwk_authentication( + handler: Callable[[web.Request, str], Coroutine[Any, Any, web.StreamResponse]] +) -> Callable[[web.Request], Awaitable[web.StreamResponse]]: + @functools.wraps(handler) + async def wrapper(request): + try: + authenticated_sender: str = await authenticate_jwk(request) + except web.HTTPException as e: + return web.json_response(data={"error": e.reason}, status=e.status) + except Exception as e: + # Unexpected make sure to log it + logging.exception(e) + raise + + # authenticated_sender is the authenticted wallet address of the requester (as a string) + response = await handler(request, authenticated_sender) + return response + + return wrapper diff --git a/tests/unit/test_asynchronous.py b/tests/unit/test_asynchronous.py index 0fa0df38..0f909408 100644 --- a/tests/unit/test_asynchronous.py +++ b/tests/unit/test_asynchronous.py @@ -157,7 +157,7 @@ async def test_create_instance_no_hypervisor(mock_session_with_post_success): hypervisor=None, ) - assert instance_message.content.environment.hypervisor == HypervisorType.firecracker + assert instance_message.content.environment.hypervisor == HypervisorType.qemu assert mock_session_with_post_success.http_session.post.assert_called_once assert isinstance(instance_message, InstanceMessage) diff --git a/tests/unit/test_vm_client.py b/tests/unit/test_vm_client.py new file mode 100644 index 00000000..7cc9a2c3 --- /dev/null +++ b/tests/unit/test_vm_client.py @@ -0,0 +1,297 @@ +from urllib.parse import urlparse + +import aiohttp +import pytest +from aiohttp import web +from aioresponses import aioresponses +from aleph_message.models import ItemHash +from yarl import URL + +from aleph.sdk.chains.ethereum import ETHAccount +from aleph.sdk.client.vm_client import VmClient + +from .aleph_vm_authentication import ( + SignedOperation, + SignedPubKeyHeader, + authenticate_jwk, + authenticate_websocket_message, + verify_signed_operation, +) + + +@pytest.mark.asyncio +async def test_notify_allocation(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + + with aioresponses() as m: + vm_client = VmClient( + account=account, + node_url="http://localhost", + session=aiohttp.ClientSession(), + ) + m.post("http://localhost/control/allocation/notify", status=200) + await vm_client.notify_allocation(vm_id=vm_id) + assert len(m.requests) == 1 + assert ("POST", URL("http://localhost/control/allocation/notify")) in m.requests + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_perform_operation(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + operation = "reboot" + + with aioresponses() as m: + vm_client = VmClient( + account=account, + node_url="http://localhost", + session=aiohttp.ClientSession(), + ) + m.post( + f"http://localhost/control/machine/{vm_id}/{operation}", + status=200, + payload="mock_response_text", + ) + + status, response_text = await vm_client.perform_operation(vm_id, operation) + assert status == 200 + assert response_text == '"mock_response_text"' # ' ' cause by aioresponses + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_stop_instance(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + + with aioresponses() as m: + vm_client = VmClient( + account=account, + node_url="http://localhost", + session=aiohttp.ClientSession(), + ) + m.post( + f"http://localhost/control/machine/{vm_id}/stop", + status=200, + payload="mock_response_text", + ) + + status, response_text = await vm_client.stop_instance(vm_id) + assert status == 200 + assert response_text == '"mock_response_text"' # ' ' cause by aioresponses + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_reboot_instance(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + + with aioresponses() as m: + vm_client = VmClient( + account=account, + node_url="http://localhost", + session=aiohttp.ClientSession(), + ) + m.post( + f"http://localhost/control/machine/{vm_id}/reboot", + status=200, + payload="mock_response_text", + ) + + status, response_text = await vm_client.reboot_instance(vm_id) + assert status == 200 + assert response_text == '"mock_response_text"' # ' ' cause by aioresponses + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_erase_instance(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + + with aioresponses() as m: + vm_client = VmClient( + account=account, + node_url="http://localhost", + session=aiohttp.ClientSession(), + ) + m.post( + f"http://localhost/control/machine/{vm_id}/erase", + status=200, + payload="mock_response_text", + ) + + status, response_text = await vm_client.erase_instance(vm_id) + assert status == 200 + assert response_text == '"mock_response_text"' # ' ' cause by aioresponses + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_expire_instance(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + + with aioresponses() as m: + vm_client = VmClient( + account=account, + node_url="http://localhost", + session=aiohttp.ClientSession(), + ) + m.post( + f"http://localhost/control/machine/{vm_id}/expire", + status=200, + payload="mock_response_text", + ) + + status, response_text = await vm_client.expire_instance(vm_id) + assert status == 200 + assert response_text == '"mock_response_text"' # ' ' cause by aioresponses + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_get_logs(aiohttp_client): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + + async def websocket_handler(request): + ws = web.WebSocketResponse() + await ws.prepare(request) + + async for msg in ws: + if msg.type == aiohttp.WSMsgType.TEXT: + await ws.send_str("mock_log_entry") + elif msg.type == aiohttp.WSMsgType.ERROR: + break + + return ws + + app = web.Application() + app.router.add_route( + "GET", "/control/machine/{vm_id}/stream_logs", websocket_handler + ) # Update route to match the URL + + client = await aiohttp_client(app) + + node_url = str(client.make_url("")).rstrip("/") + + vm_client = VmClient( + account=account, + node_url=node_url, + session=client.session, + ) + + logs = [] + async for log in vm_client.get_logs(vm_id): + logs.append(log) + if log == "mock_log_entry": + break + + assert logs == ["mock_log_entry"] + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_authenticate_jwk(aiohttp_client): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + + async def test_authenticate_route(request): + address = await authenticate_jwk( + request, domain_name=urlparse(node_url).hostname + ) + assert vm_client.account.get_address() == address + return web.Response(text="ok") + + app = web.Application() + app.router.add_route( + "POST", f"/control/machine/{vm_id}/stop", test_authenticate_route + ) # Update route to match the URL + + client = await aiohttp_client(app) + + node_url = str(client.make_url("")).rstrip("/") + + vm_client = VmClient( + account=account, + node_url=node_url, + session=client.session, + ) + + status_code, response_text = await vm_client.stop_instance(vm_id) + assert status_code == 200, response_text + assert response_text == "ok" + + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_websocket_authentication(aiohttp_client): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + + async def websocket_handler(request): + ws = web.WebSocketResponse() + await ws.prepare(request) + + first_message = await ws.receive_json() + credentials = first_message["auth"] + sender_address = await authenticate_websocket_message( + credentials, + domain_name=urlparse(node_url).hostname, + ) + + assert vm_client.account.get_address() == sender_address + await ws.send_str(sender_address) + + return ws + + app = web.Application() + app.router.add_route( + "GET", "/control/machine/{vm_id}/stream_logs", websocket_handler + ) # Update route to match the URL + + client = await aiohttp_client(app) + + node_url = str(client.make_url("")).rstrip("/") + + vm_client = VmClient( + account=account, + node_url=node_url, + session=client.session, + ) + + valid = False + + async for address in vm_client.get_logs(vm_id): + assert address == vm_client.account.get_address() + valid = True + + # this is done to ensure that the ws as runned at least once and avoid + # having silent errors + assert valid + + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_vm_client_generate_correct_authentication_headers(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + + vm_client = VmClient( + account=account, + node_url="http://localhost", + session=aiohttp.ClientSession(), + ) + + path, headers = await vm_client._generate_header(vm_id, "reboot", method="post") + signed_pubkey = SignedPubKeyHeader.parse_raw(headers["X-SignedPubKey"]) + signed_operation = SignedOperation.parse_raw(headers["X-SignedOperation"]) + address = verify_signed_operation(signed_operation, signed_pubkey) + + assert vm_client.account.get_address() == address diff --git a/tests/unit/test_vm_confidential_client.py b/tests/unit/test_vm_confidential_client.py new file mode 100644 index 00000000..832871ff --- /dev/null +++ b/tests/unit/test_vm_confidential_client.py @@ -0,0 +1,216 @@ +import tempfile +from pathlib import Path +from unittest import mock +from unittest.mock import patch + +import aiohttp +import pytest +from aioresponses import aioresponses +from aleph_message.models import ItemHash + +from aleph.sdk.chains.ethereum import ETHAccount +from aleph.sdk.client.vm_confidential_client import VmConfidentialClient + + +@pytest.mark.asyncio +async def test_perform_confidential_operation(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + operation = "confidential/test" + + with aioresponses() as m: + vm_client = VmConfidentialClient( + account=account, + sevctl_path=Path("/"), + node_url="http://localhost", + session=aiohttp.ClientSession(), + ) + m.post( + f"http://localhost/control/machine/{vm_id}/{operation}", + status=200, + payload="mock_response_text", + ) + + response_text = await vm_client.perform_confidential_operation(vm_id, operation) + assert response_text == '"mock_response_text"' # ' ' cause by aioresponses + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_confidential_initialize_instance(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + operation = "confidential/initialize" + node_url = "http://localhost" + url = f"{node_url}/control/machine/{vm_id}/{operation}" + headers = { + "X-SignedPubKey": "test_pubkey_token", + "X-SignedOperation": "test_operation_token", + } + + with tempfile.NamedTemporaryFile() as tmp_file: + tmp_file_bytes = Path(tmp_file.name).read_bytes() + with aioresponses() as m: + with patch( + "aleph.sdk.client.vm_confidential_client.VmConfidentialClient._generate_header", + return_value=(url, headers), + ): + vm_client = VmConfidentialClient( + account=account, + sevctl_path=Path("/"), + node_url=node_url, + session=aiohttp.ClientSession(), + ) + m.post( + url, + status=200, + payload="mock_response_text", + ) + tmp_file_path = Path(tmp_file.name) + response_text = await vm_client.initialize( + vm_id, session=tmp_file_path, godh=tmp_file_path + ) + assert ( + response_text == '"mock_response_text"' + ) # ' ' cause by aioresponses + m.assert_called_once_with( + url, + method="POST", + data={ + "session": tmp_file_bytes, + "godh": tmp_file_bytes, + }, + json=None, + headers=headers, + ) + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_confidential_measurement_instance(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + operation = "confidential/measurement" + node_url = "http://localhost" + url = f"{node_url}/control/machine/{vm_id}/{operation}" + headers = { + "X-SignedPubKey": "test_pubkey_token", + "X-SignedOperation": "test_operation_token", + } + + with aioresponses() as m: + with patch( + "aleph.sdk.client.vm_confidential_client.VmConfidentialClient._generate_header", + return_value=(url, headers), + ): + vm_client = VmConfidentialClient( + account=account, + sevctl_path=Path("/"), + node_url=node_url, + session=aiohttp.ClientSession(), + ) + m.get( + url, + status=200, + payload=dict( + { + "sev_info": { + "enabled": True, + "api_major": 0, + "api_minor": 0, + "build_id": 0, + "policy": 0, + "state": "", + "handle": 0, + }, + "launch_measure": "test_measure", + } + ), + ) + measurement = await vm_client.measurement(vm_id) + assert measurement.launch_measure == "test_measure" + m.assert_called_once_with( + url, + method="GET", + headers=headers, + ) + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_confidential_inject_secret_instance(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + operation = "confidential/inject_secret" + node_url = "http://localhost" + url = f"{node_url}/control/machine/{vm_id}/{operation}" + headers = { + "X-SignedPubKey": "test_pubkey_token", + "X-SignedOperation": "test_operation_token", + } + test_secret = "test_secret" + packet_header = "test_packet_header" + + with aioresponses() as m: + with patch( + "aleph.sdk.client.vm_confidential_client.VmConfidentialClient._generate_header", + return_value=(url, headers), + ): + vm_client = VmConfidentialClient( + account=account, + sevctl_path=Path("/"), + node_url=node_url, + session=aiohttp.ClientSession(), + ) + m.post( + url, + status=200, + payload="mock_response_text", + ) + response_text = await vm_client.inject_secret( + vm_id, secret=test_secret, packet_header=packet_header + ) + assert response_text == "mock_response_text" + m.assert_called_once_with( + url, + method="POST", + json={ + "secret": test_secret, + "packet_header": packet_header, + }, + headers=headers, + ) + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_create_session_command(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + node_url = "http://localhost" + sevctl_path = Path("/usr/bin/sevctl") + certificates_path = Path("/") + policy = 1 + + with mock.patch( + "aleph.sdk.client.vm_confidential_client.run_in_subprocess", + return_value=True, + ) as export_mock: + vm_client = VmConfidentialClient( + account=account, + sevctl_path=sevctl_path, + node_url=node_url, + session=aiohttp.ClientSession(), + ) + _ = await vm_client.create_session(vm_id, certificates_path, policy) + export_mock.assert_called_once_with( + [ + str(sevctl_path), + "session", + "--name", + str(vm_id), + str(certificates_path), + str(policy), + ], + check=True, + )