diff --git a/src/aleph/vm/conf.py b/src/aleph/vm/conf.py
index a5e2ae6ce..f33e02f6c 100644
--- a/src/aleph/vm/conf.py
+++ b/src/aleph/vm/conf.py
@@ -282,11 +282,8 @@ class Settings(BaseSettings):
 
     # Settings to get from the network aggregates
     SETTINGS_AGGREGATE_ADDRESS: str = "0xFba561a84A537fCaa567bb7A2257e7142701ae2A"
-    COMMUNITY_WALLET_ADDRESS: str | None = None
-    COMPATIBLE_GPUS: List[dict[str, str]] = []
 
     # Tests on programs
-
     FAKE_DATA_PROGRAM: Path | None = None
     BENCHMARK_FAKE_DATA_PROGRAM = Path(abspath(join(__file__, "../../../../examples/example_fastapi")))
 
diff --git a/src/aleph/vm/orchestrator/tasks.py b/src/aleph/vm/orchestrator/tasks.py
index 75fff2364..9819ffcf2 100644
--- a/src/aleph/vm/orchestrator/tasks.py
+++ b/src/aleph/vm/orchestrator/tasks.py
@@ -4,6 +4,7 @@
 import math
 import time
 from collections.abc import AsyncIterable
+from decimal import Decimal
 from typing import TypeVar
 
 import aiohttp
@@ -19,6 +20,10 @@
 from yarl import URL
 
 from aleph.vm.conf import settings
+from aleph.vm.orchestrator.utils import (
+    get_community_wallet_address,
+    is_after_community_wallet_start,
+)
 from aleph.vm.pool import VmPool
 from aleph.vm.utils import create_task_log_exceptions
 
@@ -35,6 +40,7 @@
 logger = logging.getLogger(__name__)
 
 Value = TypeVar("Value")
+COMMUNITY_STREAM_RATIO = Decimal("0.2")  # Built from a string: Decimal(0.2) would carry binary float error
 
 
 async def retry_generator(generator: AsyncIterable[Value], max_seconds: int = 8) -> AsyncIterable[Value]:
@@ -154,6 +160,7 @@ async def monitor_payments(app: web.Application):
         try:
             logger.debug("Monitoring balances task running")
             await check_payment(pool)
+            logger.debug("Monitoring balances task ended")
         except Exception as e:
             # Catch all exceptions as to never stop the task.
             logger.warning(f"check_payment failed {e}", exc_info=True)
@@ -191,31 +198,64 @@ async def check_payment(pool: VmPool):
                 logger.debug(f"Stopping {last_execution} due to insufficient balance")
                 await pool.stop_vm(last_execution.vm_hash)
                 required_balance = await compute_required_balance(executions)
+    community_wallet = await get_community_wallet_address()
+    if not community_wallet:
+        logger.error("Monitor payment ERROR: No community wallet set. Cannot check community payment")
 
     # Check if the balance held in the wallet is sufficient stream tier resources
     for sender, chains in pool.get_executions_by_sender(payment_type=PaymentType.superfluid).items():
         for chain, executions in chains.items():
             try:
                 stream = await get_stream(sender=sender, receiver=settings.PAYMENT_RECEIVER_ADDRESS, chain=chain)
+
                 logger.debug(
-                    f"Get stream flow from Sender {sender} to Receiver {settings.PAYMENT_RECEIVER_ADDRESS} of {stream}"
+                    f"Stream flow from {sender} to {settings.PAYMENT_RECEIVER_ADDRESS} = {stream} {chain.value}"
                 )
+            except ValueError as error:
+                logger.error(f"Error found getting stream for chain {chain} and sender {sender}: {error}")
+                continue
+            try:
+                community_stream = await get_stream(sender=sender, receiver=community_wallet, chain=chain)
+                logger.debug(
+                    f"Stream flow from {sender} to {community_wallet} (community) : {community_stream} {chain}"
+                )
+
             except ValueError as error:
                 logger.error(f"Error found getting stream for chain {chain} and sender {sender}: {error}")
                 continue
 
-            required_stream = await compute_required_flow(executions)
-            logger.debug(f"Required stream for Sender {sender} executions: {required_stream}")
-            # Stop executions until the required stream is reached
-            while (stream + settings.PAYMENT_BUFFER) < required_stream:
-                try:
-                    last_execution = executions.pop(-1)
-                except IndexError:  # Empty list
-                    logger.debug("No execution can be maintained due to insufficient stream")
+            while executions:
+                executions_with_community = [
+                    execution
+                    for execution in executions
+                    if await is_after_community_wallet_start(execution.times.started_at)
+                ]
+
+                required_stream = await compute_required_flow(executions_with_community)
+                executions_without_community = [
+                    execution
+                    for execution in executions
+                    if not await is_after_community_wallet_start(execution.times.started_at)
+                ]
+                logger.info("flow community %s", executions_with_community)
+                logger.info("flow without community %s", executions_without_community)
+                required_stream_without_community = await compute_required_flow(executions_without_community)
+
+                required_crn_stream = required_stream * (1 - COMMUNITY_STREAM_RATIO) + required_stream_without_community
+                required_community_stream = required_stream * COMMUNITY_STREAM_RATIO
+                logger.debug(
+                    f"Stream for senders {sender} {len(executions)} executions. CRN : {stream} / {required_crn_stream}."
+                    f"Community: {community_stream} / {required_community_stream}"
+                )
+                # Can pay all executions
+                if (stream + settings.PAYMENT_BUFFER) > required_crn_stream and (
+                    community_stream + settings.PAYMENT_BUFFER
+                ) > required_community_stream:
                     break
-                logger.debug(f"Stopping {last_execution} due to insufficient stream")
+                # Stop executions until the required stream is reached
+                last_execution = executions.pop(-1)
+                logger.info(f"Stopping {last_execution} of {sender} due to insufficient stream")
                 await pool.stop_vm(last_execution.vm_hash)
-                required_stream = await compute_required_flow(executions)
 
 
 async def start_payment_monitoring_task(app: web.Application):
diff --git a/src/aleph/vm/orchestrator/utils.py b/src/aleph/vm/orchestrator/utils.py
index 99350946b..17dcbca03 100644
--- a/src/aleph/vm/orchestrator/utils.py
+++ b/src/aleph/vm/orchestrator/utils.py
@@ -1,11 +1,25 @@
-from typing import Any
+from datetime import datetime, timedelta, timezone
+from logging import getLogger
+from typing import Any, TypedDict
 
 import aiohttp
 
 from aleph.vm.conf import settings
 
+logger = getLogger(__name__)
 
-async def fetch_aggregate_settings() -> dict[str, Any] | None:
+
+class AggregateSettingsDict(TypedDict):
+    compatible_gpus: list[Any]
+    community_wallet_address: str
+    community_wallet_timestamp: int
+
+
+LAST_AGGREGATE_SETTINGS: AggregateSettingsDict | None = None
+LAST_AGGREGATE_SETTINGS_FETCHED_AT: datetime | None = None
+
+
+async def fetch_aggregate_settings() -> AggregateSettingsDict | None:
     """
     Get the settings Aggregate dict from the PyAleph API Aggregate.
 
@@ -17,6 +31,7 @@ async def fetch_aggregate_settings() -> dict[str, Any] | None:
     """
     async with aiohttp.ClientSession() as session:
         url = f"{settings.API_SERVER}/api/v0/aggregates/{settings.SETTINGS_AGGREGATE_ADDRESS}.json?keys=settings"
+        logger.info(f"Fetching settings aggregate from {url}")
         resp = await session.get(url)
 
         # Raise an error if the request failed
@@ -27,7 +42,61 @@ async def fetch_aggregate_settings() -> dict[str, Any] | None:
 
 
 async def update_aggregate_settings():
-    aggregate_settings = await fetch_aggregate_settings()
-    if aggregate_settings:
-        settings.COMPATIBLE_GPUS = aggregate_settings["compatible_gpus"]
-        settings.COMMUNITY_WALLET_ADDRESS = aggregate_settings["community_wallet_address"]
+    global LAST_AGGREGATE_SETTINGS  # noqa: PLW0603
+    global LAST_AGGREGATE_SETTINGS_FETCHED_AT  # noqa: PLW0603
+
+    # Refetch only when the cache is empty or older than one minute; a failed fetch keeps the cached value.
+    if (
+        not LAST_AGGREGATE_SETTINGS
+        or LAST_AGGREGATE_SETTINGS_FETCHED_AT
+        and datetime.now(tz=timezone.utc) - LAST_AGGREGATE_SETTINGS_FETCHED_AT > timedelta(minutes=1)
+    ):
+        try:
+            aggregate = await fetch_aggregate_settings()
+            LAST_AGGREGATE_SETTINGS = aggregate
+            LAST_AGGREGATE_SETTINGS_FETCHED_AT = datetime.now(tz=timezone.utc)
+
+        except Exception:
+            logger.exception("Failed to fetch aggregate settings")
+
+
+async def get_aggregate_settings() -> AggregateSettingsDict | None:
+    """The settings aggregate is a special aggregate used to share some common settings for VM setup
+
+    Ensure the cached version is up to date and return it"""
+    await update_aggregate_settings()
+
+    if not LAST_AGGREGATE_SETTINGS:
+        logger.error("No setting aggregate")
+    return LAST_AGGREGATE_SETTINGS
+
+
+async def get_community_wallet_address() -> str | None:
+    setting_aggr = await get_aggregate_settings()
+    return setting_aggr and setting_aggr.get("community_wallet_address")
+
+
+async def get_community_wallet_start() -> datetime:
+    """Community wallet start time.
+
+    After this timestamp. New PAYG must include a payment to the community wallet"""
+    setting_aggr = await get_aggregate_settings()
+    if setting_aggr is None or "community_wallet_timestamp" not in setting_aggr:
+        return datetime.now(tz=timezone.utc)
+    timestamp = setting_aggr["community_wallet_timestamp"]
+    start_datetime = datetime.fromtimestamp(timestamp, tz=timezone.utc)
+    return start_datetime
+
+
+async def is_after_community_wallet_start(dt: datetime | None = None) -> bool:
+    """Community wallet start time"""
+    if not dt:
+        dt = datetime.now(tz=timezone.utc)
+    start_dt = await get_community_wallet_start()
+    return dt > start_dt
+
+
+def get_compatible_gpus() -> list[Any]:
+    if not LAST_AGGREGATE_SETTINGS:
+        return []
+    return LAST_AGGREGATE_SETTINGS["compatible_gpus"]
diff --git a/src/aleph/vm/orchestrator/views/__init__.py b/src/aleph/vm/orchestrator/views/__init__.py
index 899a038f8..6e9460d1c 100644
--- a/src/aleph/vm/orchestrator/views/__init__.py
+++ b/src/aleph/vm/orchestrator/views/__init__.py
@@ -1,5 +1,4 @@
 import binascii
-import contextlib
 import logging
 from decimal import Decimal
 from hashlib import sha256
@@ -8,7 +7,6 @@
 from pathlib import Path
 from secrets import compare_digest
 from string import Template
-from typing import Optional
 
 import aiodns
 import aiohttp
@@ -26,7 +24,7 @@
 from aleph.vm.controllers.firecracker.program import FileTooLargeError
 from aleph.vm.hypervisors.firecracker.microvm import MicroVMFailedInitError
 from aleph.vm.orchestrator import payment, status
-from aleph.vm.orchestrator.chain import STREAM_CHAINS, ChainInfo
+from aleph.vm.orchestrator.chain import
STREAM_CHAINS
 from aleph.vm.orchestrator.custom_logs import set_vm_for_logging
 from aleph.vm.orchestrator.messages import try_get_message
 from aleph.vm.orchestrator.metrics import get_execution_records
@@ -39,6 +37,12 @@
 from aleph.vm.orchestrator.pubsub import PubSub
 from aleph.vm.orchestrator.resources import Allocation, VMNotification
 from aleph.vm.orchestrator.run import run_code_on_request, start_persistent_vm
+from aleph.vm.orchestrator.tasks import COMMUNITY_STREAM_RATIO
+from aleph.vm.orchestrator.utils import (
+    get_community_wallet_address,
+    is_after_community_wallet_start,
+    update_aggregate_settings,
+)
 from aleph.vm.orchestrator.views.host_status import (
     check_dns_ipv4,
     check_dns_ipv6,
@@ -468,6 +472,7 @@ async def update_allocations(request: web.Request):
 @cors_allow_all
 async def notify_allocation(request: web.Request):
     """Notify instance allocation, only used for Pay as you Go feature"""
+    await update_aggregate_settings()
     try:
         data = await request.json()
         vm_notification = VMNotification.parse_obj(data)
@@ -526,16 +531,44 @@ async def notify_allocation(request: web.Request):
             raise web.HTTPPaymentRequired(reason="Empty payment stream for this instance")
 
         required_flow: Decimal = await fetch_execution_flow_price(item_hash)
-
-        if active_flow < required_flow:
+        community_wallet = await get_community_wallet_address()
+        required_crn_stream: Decimal
+        required_community_stream: Decimal
+        if await is_after_community_wallet_start() and community_wallet:
+            required_crn_stream = required_flow * (1 - COMMUNITY_STREAM_RATIO)
+            required_community_stream = required_flow * COMMUNITY_STREAM_RATIO
+        else:  # No community wallet payment
+            required_crn_stream = required_flow
+            required_community_stream = Decimal(0)
+
+        if active_flow < (required_crn_stream - settings.PAYMENT_BUFFER):
             active_flow_per_month = active_flow * 60 * 60 * 24 * (Decimal("30.41666666666923904761904784"))
-            required_flow_per_month = required_flow * 60 * 60 * 24 * Decimal("30.41666666666923904761904784")
+            required_flow_per_month = required_crn_stream * 60 * 60 * 24 * Decimal("30.41666666666923904761904784")
             return web.HTTPPaymentRequired(
                 reason="Insufficient payment stream",
                 text="Insufficient payment stream for this instance\n\n"
-                f"Required: {required_flow_per_month} / month (flow = {required_flow})\n"
+                f"Required: {required_flow_per_month} / month (flow = {required_crn_stream})\n"
                 f"Present: {active_flow_per_month} / month (flow = {active_flow})",
             )
+
+        if community_wallet and required_community_stream:
+            community_flow: Decimal = await get_stream(
+                sender=message.sender,
+                receiver=community_wallet,
+                chain=message.content.payment.chain,
+            )
+            if community_flow < (required_community_stream - settings.PAYMENT_BUFFER):
+                active_flow_per_month = community_flow * 60 * 60 * 24 * (Decimal("30.41666666666923904761904784"))
+                required_flow_per_month = (
+                    required_community_stream * 60 * 60 * 24 * Decimal("30.41666666666923904761904784")
+                )
+                return web.HTTPPaymentRequired(
+                    reason="Insufficient payment stream to community",
+                    text="Insufficient payment stream for community \n\n"
+                    f"Required: {required_flow_per_month} / month (flow = {required_community_stream})\n"
+                    f"Present: {active_flow_per_month} / month (flow = {community_flow})\n"
+                    f"Address: {community_wallet}",
+                )
     else:
         return web.HTTPBadRequest(reason="Invalid payment method")
 
diff --git a/src/aleph/vm/resources.py b/src/aleph/vm/resources.py
index ea03aac9a..4776c254c 100644
--- a/src/aleph/vm/resources.py
+++ b/src/aleph/vm/resources.py
@@ -6,6 +6,7 @@
 from pydantic import BaseModel, Extra, Field
 
 from aleph.vm.conf import settings
+from aleph.vm.orchestrator.utils import get_compatible_gpus
 
 
 class HostGPU(BaseModel):
@@ -60,7 +61,7 @@ def is_gpu_device_class(device_class: str) -> bool:
 
 def get_gpu_model(device_id: str) -> bool | None:
     """Returns a GPU model name if it's found from the compatible ones."""
-    model_gpu_set = {gpu["device_id"]: gpu["model"] for gpu in settings.COMPATIBLE_GPUS}
+    model_gpu_set = {gpu["device_id"]: gpu["model"] for gpu in get_compatible_gpus()}
     try:
         return model_gpu_set[device_id]
     except KeyError:
@@ -69,7 +70,7 @@ def get_gpu_model(device_id: str) -> bool | None:
 
 def is_gpu_compatible(device_id: str) -> bool:
     """Checks if a GPU is compatible based on vendor and model IDs."""
-    compatible_gpu_set = {gpu["device_id"] for gpu in settings.COMPATIBLE_GPUS}
+    compatible_gpu_set = {gpu["device_id"] for gpu in get_compatible_gpus()}
     return device_id in compatible_gpu_set
 
 
diff --git a/tests/supervisor/test_checkpayment.py b/tests/supervisor/test_checkpayment.py
new file mode 100644
index 000000000..3671114de
--- /dev/null
+++ b/tests/supervisor/test_checkpayment.py
@@ -0,0 +1,230 @@
+import asyncio
+
+import pytest
+from aleph_message.models import Chain, InstanceContent, PaymentType
+from aleph_message.status import MessageStatus
+
+from aleph.vm.conf import settings
+from aleph.vm.models import VmExecution
+from aleph.vm.orchestrator.tasks import check_payment
+from aleph.vm.pool import VmPool
+
+
+@pytest.fixture()
+def fake_instance_content():
+    fake = {
+        "address": "0x101d8D16372dBf5f1614adaE95Ee5CCE61998Fc9",
+        "time": 1713874241.800818,
+        "allow_amend": False,
+        "metadata": None,
+        "authorized_keys": None,
+        "variables": None,
+        "environment": {"reproducible": False, "internet": True, "aleph_api": True, "shared_cache": False},
+        "resources": {"vcpus": 1, "memory": 256, "seconds": 30, "published_ports": None},
+        "payment": {"type": "superfluid", "chain": "BASE"},
+        "requirements": None,
+        "replaces": None,
+        "rootfs": {
+            "parent": {"ref": "63f07193e6ee9d207b7d1fcf8286f9aee34e6f12f101d2ec77c1229f92964696"},
+            "ref": "63f07193e6ee9d207b7d1fcf8286f9aee34e6f12f101d2ec77c1229f92964696",
+            "use_latest": True,
+            "comment": "",
+            "persistence": "host",
+            "size_mib": 1000,
+        },
+    }
+
+    return fake
+
+
+@pytest.mark.asyncio
+async def test_enough_flow(mocker, fake_instance_content):
+    """Execution with community flow
+
+    Cost 500
+    Community 100
+    CRN 400
+    Both flows are 400.
+    Should not stop
+
+    """
+    mocker.patch.object(settings, "ALLOW_VM_NETWORKING", False)
+    mocker.patch.object(settings, "PAYMENT_RECEIVER_ADDRESS", "0xD39C335404a78E0BDCf6D50F29B86EFd57924288")
+    mock_community_wallet_address = "0x23C7A99d7AbebeD245d044685F1893aeA4b5Da90"
+    mocker.patch("aleph.vm.orchestrator.tasks.get_community_wallet_address", return_value=mock_community_wallet_address)
+    mocker.patch("aleph.vm.orchestrator.tasks.is_after_community_wallet_start", return_value=True)
+
+    loop = asyncio.get_event_loop()
+    pool = VmPool(loop=loop)
+    mocker.patch("aleph.vm.orchestrator.tasks.get_stream", return_value=400, autospec=True)
+    mocker.patch("aleph.vm.orchestrator.tasks.get_message_status", return_value=MessageStatus.PROCESSED)
+
+    async def compute_required_flow(executions):
+        return 500 * len(executions)
+
+    mocker.patch("aleph.vm.orchestrator.tasks.compute_required_flow", compute_required_flow)
+    message = InstanceContent.parse_obj(fake_instance_content)
+
+    hash = "decadecadecadecadecadecadecadecadecadecadecadecadecadecadecadeca"
+
+    mocker.patch.object(VmExecution, "is_running", new=True)
+    mocker.patch.object(VmExecution, "stop", new=mocker.AsyncMock(return_value=False))
+
+    execution = VmExecution(
+        vm_hash=hash,
+        message=message,
+        original=message,
+        persistent=False,
+        snapshot_manager=None,
+        systemd_manager=None,
+    )
+    assert execution.times.started_at is None
+
+    pool.executions = {hash: execution}
+
+    executions_by_sender = pool.get_executions_by_sender(payment_type=PaymentType.superfluid)
+    assert len(executions_by_sender) == 1
+    assert executions_by_sender == {"0x101d8D16372dBf5f1614adaE95Ee5CCE61998Fc9": {Chain.BASE: [execution]}}
+
+    await check_payment(pool=pool)
+    assert pool.executions == {hash: execution}
+    execution.stop.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_enough_flow_not_community(mocker, fake_instance_content):
+    """Execution without community flow
+
+    Cost 500
+    Community 0
+    CRN 500
+    Both Flow are 500.
+    Should not stop
+
+    """
+    mocker.patch.object(settings, "ALLOW_VM_NETWORKING", False)
+    mocker.patch.object(settings, "PAYMENT_RECEIVER_ADDRESS", "0xD39C335404a78E0BDCf6D50F29B86EFd57924288")
+    mock_community_wallet_address = "0x23C7A99d7AbebeD245d044685F1893aeA4b5Da90"
+    mocker.patch("aleph.vm.orchestrator.tasks.get_community_wallet_address", return_value=mock_community_wallet_address)
+    mocker.patch("aleph.vm.orchestrator.tasks.is_after_community_wallet_start", return_value=False)
+
+    loop = asyncio.get_event_loop()
+    pool = VmPool(loop=loop)
+    mocker.patch("aleph.vm.orchestrator.tasks.get_stream", return_value=500, autospec=True)
+    mocker.patch("aleph.vm.orchestrator.tasks.get_message_status", return_value=MessageStatus.PROCESSED)
+
+    async def compute_required_flow(executions):
+        return 500 * len(executions)
+
+    mocker.patch("aleph.vm.orchestrator.tasks.compute_required_flow", compute_required_flow)
+    message = InstanceContent.parse_obj(fake_instance_content)
+
+    hash = "decadecadecadecadecadecadecadecadecadecadecadecadecadecadecadeca"
+
+    mocker.patch.object(VmExecution, "is_running", new=True)
+    mocker.patch.object(VmExecution, "stop", new=mocker.AsyncMock(return_value=False))
+
+    execution = VmExecution(
+        vm_hash=hash,
+        message=message,
+        original=message,
+        persistent=False,
+        snapshot_manager=None,
+        systemd_manager=None,
+    )
+    assert execution.times.started_at is None
+
+    pool.executions = {hash: execution}
+
+    executions_by_sender = pool.get_executions_by_sender(payment_type=PaymentType.superfluid)
+    assert len(executions_by_sender) == 1
+    assert executions_by_sender == {"0x101d8D16372dBf5f1614adaE95Ee5CCE61998Fc9": {Chain.BASE: [execution]}}
+
+    await check_payment(pool=pool)
+    assert pool.executions == {hash: execution}
+    execution.stop.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_not_enough_flow(mocker, fake_instance_content):
+    """Both streams are 2 while every required flow is mocked to 5: the execution must be stopped."""
+    mocker.patch.object(settings, "ALLOW_VM_NETWORKING", False)
+    mocker.patch.object(settings, "PAYMENT_RECEIVER_ADDRESS", "0xD39C335404a78E0BDCf6D50F29B86EFd57924288")
+    mock_community_wallet_address = "0x23C7A99d7AbebeD245d044685F1893aeA4b5Da90"
+    mocker.patch("aleph.vm.orchestrator.tasks.get_community_wallet_address", return_value=mock_community_wallet_address)
+    mocker.patch("aleph.vm.orchestrator.tasks.is_after_community_wallet_start", return_value=True)
+
+    loop = asyncio.get_event_loop()
+    pool = VmPool(loop=loop)
+    mocker.patch("aleph.vm.orchestrator.tasks.get_stream", return_value=2, autospec=True)
+    mocker.patch("aleph.vm.orchestrator.tasks.get_message_status", return_value=MessageStatus.PROCESSED)
+    mocker.patch("aleph.vm.orchestrator.tasks.compute_required_flow", return_value=5)
+    message = InstanceContent.parse_obj(fake_instance_content)
+
+    mocker.patch.object(VmExecution, "is_running", new=True)
+    mocker.patch.object(VmExecution, "stop", new=mocker.AsyncMock(return_value=False))
+    hash = "decadecadecadecadecadecadecadecadecadecadecadecadecadecadecadeca"
+    execution = VmExecution(
+        vm_hash=hash,
+        message=message,
+        original=message,
+        persistent=False,
+        snapshot_manager=None,
+        systemd_manager=None,
+    )
+
+    pool.executions = {hash: execution}
+
+    executions_by_sender = pool.get_executions_by_sender(payment_type=PaymentType.superfluid)
+    assert len(executions_by_sender) == 1
+    assert executions_by_sender == {"0x101d8D16372dBf5f1614adaE95Ee5CCE61998Fc9": {Chain.BASE: [execution]}}
+
+    await check_payment(pool=pool)
+
+    execution.stop.assert_called_with()
+
+
+@pytest.mark.asyncio
+async def test_not_enough_community_flow(mocker, fake_instance_content):
+    """The CRN stream (10) is sufficient but the community stream is 0: the execution must be stopped."""
+    mocker.patch.object(settings, "ALLOW_VM_NETWORKING", False)
+    mocker.patch.object(settings, "PAYMENT_RECEIVER_ADDRESS", "0xD39C335404a78E0BDCf6D50F29B86EFd57924288")
+
+    loop = asyncio.get_event_loop()
+    pool = VmPool(loop=loop)
+    mock_community_wallet_address = "0x23C7A99d7AbebeD245d044685F1893aeA4b5Da90"
+
+    async def get_stream(sender, receiver, chain):
+        if receiver == mock_community_wallet_address:
+            return 0
+        elif receiver == settings.PAYMENT_RECEIVER_ADDRESS:
+            return 10
+
+    mocker.patch("aleph.vm.orchestrator.tasks.get_stream", new=get_stream)
+    mocker.patch("aleph.vm.orchestrator.tasks.get_community_wallet_address", return_value=mock_community_wallet_address)
+    mocker.patch("aleph.vm.orchestrator.tasks.is_after_community_wallet_start", return_value=True)
+    mocker.patch("aleph.vm.orchestrator.tasks.get_message_status", return_value=MessageStatus.PROCESSED)
+    mocker.patch("aleph.vm.orchestrator.tasks.compute_required_flow", return_value=5)
+    message = InstanceContent.parse_obj(fake_instance_content)
+
+    mocker.patch.object(VmExecution, "is_running", new=True)
+    mocker.patch.object(VmExecution, "stop", new=mocker.AsyncMock(return_value=False))
+    hash = "decadecadecadecadecadecadecadecadecadecadecadecadecadecadecadeca"
+    execution = VmExecution(
+        vm_hash=hash,
+        message=message,
+        original=message,
+        persistent=False,
+        snapshot_manager=None,
+        systemd_manager=None,
+    )
+
+    pool.executions = {hash: execution}
+
+    executions_by_sender = pool.get_executions_by_sender(payment_type=PaymentType.superfluid)
+    assert len(executions_by_sender) == 1
+    assert executions_by_sender == {"0x101d8D16372dBf5f1614adaE95Ee5CCE61998Fc9": {Chain.BASE: [execution]}}
+
+    await check_payment(pool=pool)
+
+    execution.stop.assert_called_with()