6 changes: 2 additions & 4 deletions src/aleph/vm/controllers/__main__.py
@@ -61,6 +61,7 @@ async def execute_persistent_vm(config: Configuration):
        assert isinstance(config.vm_configuration, VMConfiguration)
        execution = MicroVM(
            vm_id=config.vm_id,
            vm_hash=config.vm_hash,
            firecracker_bin_path=config.vm_configuration.firecracker_bin_path,
            jailer_base_directory=config.settings.JAILER_BASE_DIR,
            use_jailer=config.vm_configuration.use_jailer,
@@ -72,7 +73,7 @@ async def execute_persistent_vm(config: Configuration):
        process = await execution.start(config.vm_configuration.config_file_path)
    else:
        assert isinstance(config.vm_configuration, QemuVMConfiguration)
        execution = QemuVM(config.vm_configuration)
        execution = QemuVM(config.vm_hash, config.vm_configuration)
        process = await execution.start()

    return execution, process
@@ -83,9 +84,6 @@ async def handle_persistent_vm(config: Configuration, execution: Union[MicroVM,
    loop = asyncio.get_event_loop()
    loop.add_signal_handler(signal.SIGTERM, execution.send_shutdown_message)

    if config.settings.PRINT_SYSTEM_LOGS:
        execution.start_printing_logs()

    await process.wait()
    logger.info(f"Process terminated with {process.returncode}")
1 change: 1 addition & 0 deletions src/aleph/vm/controllers/configuration.py
@@ -36,6 +36,7 @@ class HypervisorType(str, Enum):

class Configuration(BaseModel):
    vm_id: int
    vm_hash: str
    settings: Settings
    vm_configuration: Union[QemuVMConfiguration, VMConfiguration]
    hypervisor: HypervisorType = HypervisorType.firecracker
21 changes: 3 additions & 18 deletions src/aleph/vm/controllers/firecracker/executable.py
@@ -188,11 +188,13 @@ def __init__(

        self.fvm = MicroVM(
            vm_id=self.vm_id,
            vm_hash=vm_hash,
            firecracker_bin_path=settings.FIRECRACKER_PATH,
            jailer_base_directory=settings.JAILER_BASE_DIR,
            use_jailer=settings.USE_JAILER,
            jailer_bin_path=settings.JAILER_PATH,
            init_timeout=settings.INIT_TIMEOUT,
            enable_log=enable_console,
        )
        if prepare_jailer:
            self.fvm.prepare_jailer()
@@ -259,9 +261,6 @@ async def start(self):
            await self.tap_interface.delete()
            raise

        if self.enable_console:
            self.fvm.start_printing_logs()

        await self.wait_for_init()
        logger.debug(f"started fvm {self.vm_id}")
        await self.load_configuration()
@@ -285,6 +284,7 @@ async def configure(self):

        configuration = Configuration(
            vm_id=self.vm_id,
            vm_hash=self.vm_hash,
            settings=settings,
            vm_configuration=vm_configuration,
        )
@@ -330,18 +330,3 @@ async def teardown(self):

    async def create_snapshot(self) -> CompressedDiskVolumeSnapshot:
        raise NotImplementedError()

    def get_log_queue(self) -> asyncio.Queue:
        queue: asyncio.Queue = asyncio.Queue(maxsize=1000)
        # Limit the number of queues per VM

        if len(self.fvm.log_queues) > 20:
            logger.warning("Too many log queues, dropping the oldest one")
            self.fvm.log_queues.pop(0)
        self.fvm.log_queues.append(queue)
        return queue

    def unregister_queue(self, queue: asyncio.Queue):
        if queue in self.fvm.log_queues:
            self.fvm.log_queues.remove(queue)
        queue.empty()
36 changes: 30 additions & 6 deletions src/aleph/vm/controllers/interface.py
@@ -3,18 +3,22 @@
from abc import ABC
from asyncio.subprocess import Process
from collections.abc import Coroutine
from typing import Any, Optional
from typing import Any, Callable, Optional

from aleph_message.models import ItemHash
from aleph_message.models.execution.environment import MachineResources

from aleph.vm.controllers.firecracker.snapshots import CompressedDiskVolumeSnapshot
from aleph.vm.network.interfaces import TapInterface
from aleph.vm.utils.logs import make_logs_queue

logger = logging.getLogger(__name__)


class AlephVmControllerInterface(ABC):
    log_queues: list[asyncio.Queue] = []
    _queue_cancellers: dict[asyncio.Queue, Callable] = {}

    vm_id: int
    """id in the VMPool, attributed at execution"""
    vm_hash: ItemHash
@@ -89,8 +93,28 @@ async def create_snapshot(self) -> CompressedDiskVolumeSnapshot:
        """Must be implemented if self.support_snapshot is True"""
        raise NotImplementedError()

    async def get_log_queue(self) -> asyncio.Queue:
        raise NotImplementedError()

    async def unregister_queue(self, queue: asyncio.Queue):
        raise NotImplementedError()
    def get_log_queue(self) -> asyncio.Queue:
        queue, canceller = make_logs_queue(self._journal_stdout_name, self._journal_stderr_name)
        self._queue_cancellers[queue] = canceller
        # Limit the number of queues per VM.
        # TODO: log_queues is a class-level list shared by every controller,
        # so this limit is currently global rather than per-VM.
        if len(self.log_queues) > 20:
            logger.warning("Too many log queues, dropping the oldest one")
            self.unregister_queue(self.log_queues[0])
        self.log_queues.append(queue)
        return queue

    def unregister_queue(self, queue: asyncio.Queue) -> None:
        if queue in self.log_queues:
            self._queue_cancellers[queue]()
            del self._queue_cancellers[queue]
            self.log_queues.remove(queue)
        # Drain pending entries: asyncio.Queue has no clear(), and .empty()
        # only reports whether the queue is empty.
        while not queue.empty():
            queue.get_nowait()

    @property
    def _journal_stdout_name(self) -> str:
        return f"vm-{self.vm_hash}-stdout"

    @property
    def _journal_stderr_name(self) -> str:
        return f"vm-{self.vm_hash}-stderr"
8 changes: 5 additions & 3 deletions src/aleph/vm/controllers/qemu/QEMU.md
@@ -107,8 +107,11 @@ import aiohttp
def on_message(content):
    try:
        msg = json.loads(content)
        fd = sys.stderr if msg["type"] == "stderr" else sys.stdout
        print("<", msg["message"], file=fd, end="")
        if msg.get('status'):
            print(msg)
        else:
            fd = sys.stderr if msg["type"] == "stderr" else sys.stdout
            print("<", msg["message"], file=fd, end="")
    except:
        print("unable to parse", content)
@@ -125,7 +128,6 @@ async def tail_websocket(url):
                break
            elif msg.type == aiohttp.WSMsgType.ERROR:
                print("Error", msg)
                break


vm_hash = "decadecadecadecadecadecadecadecadecadecadecadecadecadecadecadeca"
109 changes: 8 additions & 101 deletions src/aleph/vm/controllers/qemu/instance.py
@@ -1,19 +1,16 @@
import asyncio
import json
import logging
import shutil
import sys
from asyncio import Task
from asyncio.subprocess import Process
from pathlib import Path
from typing import Callable, Generic, Optional, TypedDict, TypeVar, Union
from typing import Generic, Optional, TypeVar, Union

import psutil
from aleph_message.models import ItemHash
from aleph_message.models.execution.environment import MachineResources
from aleph_message.models.execution.instance import RootfsVolume
from aleph_message.models.execution.volume import PersistentVolume, VolumePersistence
from systemd import journal

from aleph.vm.conf import settings
from aleph.vm.controllers.configuration import (
@@ -87,61 +84,10 @@ async def make_writable_volume(self, parent_image_path, volume: Union[Persistent
ConfigurationType = TypeVar("ConfigurationType")


class EntryDict(TypedDict):
    SYSLOG_IDENTIFIER: str
    MESSAGE: str


def make_logs_queue(stdout_identifier, stderr_identifier, skip_past=True) -> tuple[asyncio.Queue, Callable[[], None]]:
    """Create a queue which streams the logs for the process.

    @param stdout_identifier: journald identifier for process stdout
    @param stderr_identifier: journald identifier for process stderr
    @param skip_past: Skip past history.
    @return: queue and function to cancel the queue.

    The consumer is required to call the queue cancel function when it's done consuming the queue.

    Works by creating a journald reader, and using `add_reader` to call a callback when
    data is available for reading.
    In the callback we check the message type and fill the queue accordingly.

    For more information refer to the sd-journal(3) manpage
    and the systemd.journal module documentation.
    """
    r = journal.Reader()
    r.add_match(SYSLOG_IDENTIFIER=stdout_identifier)
    r.add_match(SYSLOG_IDENTIFIER=stderr_identifier)
    queue: asyncio.Queue = asyncio.Queue(maxsize=1000)

    def _ready_for_read() -> None:
        change_type = r.process()  # reset fd status
        if change_type != journal.APPEND:
            return
        entry: EntryDict
        for entry in r:
            log_type = "stdout" if entry["SYSLOG_IDENTIFIER"] == stdout_identifier else "stderr"
            msg = entry["MESSAGE"]
            asyncio.create_task(queue.put((log_type, msg)))

    if skip_past:
        r.seek_tail()

    loop = asyncio.get_event_loop()
    loop.add_reader(r.fileno(), _ready_for_read)

    def do_cancel():
        loop.remove_reader(r.fileno())
        r.close()

    return queue, do_cancel
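This helper is not deleted outright: the new `from aleph.vm.utils.logs import make_logs_queue` in interface.py shows it has moved there. For reference, direct usage looks roughly like this (a sketch: it needs a running event loop, a Linux host with journald, and entries logged under matching identifiers):

```python
import asyncio

from aleph.vm.utils.logs import make_logs_queue  # new location per this diff


async def print_next_entry(vm_hash: str) -> None:
    queue, cancel = make_logs_queue(
        stdout_identifier=f"vm-{vm_hash}-stdout",
        stderr_identifier=f"vm-{vm_hash}-stderr",
    )
    try:
        # skip_past defaults to True, so only entries logged from now on arrive.
        log_type, message = await asyncio.wait_for(queue.get(), timeout=30)
        print(log_type, message)
    finally:
        cancel()  # consumers must call the canceller when done
```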


class AlephQemuInstance(Generic[ConfigurationType], CloudInitMixin, AlephVmControllerInterface):
    vm_id: int
    vm_hash: ItemHash
    resources: AlephQemuResources
    enable_console: bool
    enable_networking: bool
    hardware_resources: MachineResources
    tap_interface: Optional[TapInterface] = None
@@ -151,7 +97,6 @@ class AlephQemuInstance(Generic[ConfigurationType], CloudInitMixin, AlephVmContr
    support_snapshot = False
    qmp_socket_path = None
    persistent = True
    _queue_cancellers: dict[asyncio.Queue, Callable] = {}
    controller_configuration: Configuration

    def __repr__(self):
@@ -166,21 +111,18 @@ def __init__(
        vm_hash: ItemHash,
        resources: AlephQemuResources,
        enable_networking: bool = False,
        enable_console: Optional[bool] = None,
        hardware_resources: MachineResources = MachineResources(),
        tap_interface: Optional[TapInterface] = None,
    ):
        self.vm_id = vm_id
        self.vm_hash = vm_hash
        self.resources = resources
        if enable_console is None:
            enable_console = settings.PRINT_SYSTEM_LOGS
        self.enable_console = enable_console
        self.enable_networking = enable_networking and settings.ALLOW_VM_NETWORKING
        self.hardware_resources = hardware_resources
        self.tap_interface = tap_interface
        self.qemu_process = None

        self.vm_hash = vm_hash

    # TODO: wait for Andres' solution for PID handling
    def to_dict(self):
        """Dict representation of the virtual machine. Used to record resource usage and for JSON serialization."""
@@ -244,7 +186,11 @@ async def configure(self):
        )

        configuration = Configuration(
            vm_id=self.vm_id, settings=settings, vm_configuration=vm_configuration, hypervisor=HypervisorType.qemu
            vm_id=self.vm_id,
            vm_hash=self.vm_hash,
            settings=settings,
            vm_configuration=vm_configuration,
            hypervisor=HypervisorType.qemu,
        )

        save_controller_configuration(self.vm_hash, configuration)
@@ -256,14 +202,6 @@ def save_controller_configuration(self):
        path.chmod(0o644)
        return path

    @property
    def _journal_stdout_name(self) -> str:
        return f"vm-{self.vm_hash}-stdout"

    @property
    def _journal_stderr_name(self) -> str:
        return f"vm-{self.vm_hash}-stderr"
    async def start(self):
        # Started via systemd, not from here
        raise NotImplementedError()
@@ -299,7 +237,6 @@ async def stop_guest_api(self):
        pass

    print_task: Optional[Task] = None
    log_queues: list[asyncio.Queue] = []

    async def teardown(self):
        if self.print_task:
@@ -314,33 +251,3 @@ async def teardown(self):
    def print_logs(self) -> None:
        """Print logs to our output for debugging"""
        queue = self.get_log_queue()

        async def print_logs():
            try:
                while True:
                    log_type, message = await queue.get()
                    fd = sys.stderr if log_type == "stderr" else sys.stdout
                    print(self, message, file=fd)
            finally:
                self.unregister_queue(queue)

        loop = asyncio.get_running_loop()
        self.print_task = loop.create_task(print_logs(), name=f"{self}-print-logs")

    def get_log_queue(self) -> asyncio.Queue:
        queue, canceller = make_logs_queue(self._journal_stdout_name, self._journal_stderr_name)
        self._queue_cancellers[queue] = canceller
        # Limit the number of queues per VM
        # TODO : fix
        if len(self.log_queues) > 20:
            logger.warning("Too many log queues, dropping the oldest one")
            self.unregister_queue(self.log_queues[1])
        self.log_queues.append(queue)
        return queue

    def unregister_queue(self, queue: asyncio.Queue) -> None:
        if queue in self.log_queues:
            self._queue_cancellers[queue]()
            del self._queue_cancellers[queue]
            self.log_queues.remove(queue)
        queue.empty()