Skip to content

Commit 60a4b41

Browse files
nesitorolethanh
authored andcommitted
Fix: Refactor to unify getting execution price method making it compatible with multiple payment methods.
1 parent 8bf9c51 commit 60a4b41

File tree

2 files changed

+26
-50
lines changed

2 files changed

+26
-50
lines changed

src/aleph/vm/orchestrator/payment.py

Lines changed: 23 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import logging
33
from collections.abc import Iterable
44
from decimal import Decimal
5+
from typing import List
56

67
import aiohttp
78
from aleph_message.models import ItemHash, PaymentType
@@ -70,48 +71,9 @@ async def fetch_credit_balance_of_address(address: str) -> Decimal:
7071
return resp_data["credits"]
7172

7273

73-
async def fetch_execution_flow_price(item_hash: ItemHash) -> Decimal:
74-
"""Fetch the flow price of an execution from the reference API server."""
75-
async with aiohttp.ClientSession() as session:
76-
url = f"{settings.API_SERVER}/api/v0/price/{item_hash}"
77-
resp = await session.get(url)
78-
# Raise an error if the request failed
79-
resp.raise_for_status()
80-
81-
resp_data = await resp.json()
82-
required_flow: float = resp_data["required_tokens"]
83-
payment_type: str | None = resp_data["payment_type"]
84-
85-
if payment_type is None:
86-
msg = "Payment type must be specified in the message"
87-
raise ValueError(msg)
88-
elif payment_type != PaymentType.superfluid:
89-
msg = f"Payment type {payment_type} is not supported"
90-
raise ValueError(msg)
91-
92-
return Decimal(required_flow)
93-
94-
95-
async def fetch_execution_hold_price(item_hash: ItemHash) -> Decimal:
96-
"""Fetch the hold price of an execution from the reference API server."""
97-
async with aiohttp.ClientSession() as session:
98-
url = f"{settings.API_SERVER}/api/v0/price/{item_hash}"
99-
resp = await session.get(url)
100-
# Raise an error if the request failed
101-
resp.raise_for_status()
102-
103-
resp_data = await resp.json()
104-
required_hold: float = resp_data["required_tokens"]
105-
payment_type: str | None = resp_data["payment_type"]
106-
107-
if payment_type not in (None, PaymentType.hold):
108-
msg = f"Payment type {payment_type} is not supported"
109-
raise ValueError(msg)
110-
111-
return Decimal(required_hold)
112-
113-
114-
async def fetch_execution_credit_price(item_hash: ItemHash) -> Decimal:
74+
async def fetch_execution_price(
75+
item_hash: ItemHash, allowed_payments: List[PaymentType], payment_type_required: bool = True
76+
) -> Decimal:
11577
"""Fetch the credit price of an execution from the reference API server."""
11678
async with aiohttp.ClientSession() as session:
11779
url = f"{settings.API_SERVER}/api/v0/price/{item_hash}"
@@ -123,10 +85,15 @@ async def fetch_execution_credit_price(item_hash: ItemHash) -> Decimal:
12385
required_credits: float = resp_data["required_credits"] # Field not defined yet on API side.
12486
payment_type: str | None = resp_data["payment_type"]
12587

126-
if payment_type not in (None, PaymentType.credit):
127-
msg = f"Payment type {payment_type} is not supported"
88+
if payment_type_required and payment_type is None:
89+
msg = "Payment type must be specified in the message"
12890
raise ValueError(msg)
12991

92+
if payment_type:
93+
if payment_type not in allowed_payments:
94+
msg = f"Payment type {payment_type} is not supported"
95+
raise ValueError(msg)
96+
13097
return Decimal(required_credits)
13198

13299

@@ -178,17 +145,26 @@ async def get_stream(sender: str, receiver: str, chain: str) -> Decimal:
178145

179146
async def compute_required_balance(executions: Iterable[VmExecution]) -> Decimal:
180147
"""Get the balance required for the resources of the user from the messages and the pricing aggregate."""
181-
costs = await asyncio.gather(*(fetch_execution_hold_price(execution.vm_hash) for execution in executions))
148+
costs = await asyncio.gather(
149+
*(
150+
fetch_execution_price(execution.vm_hash, [PaymentType.hold], payment_type_required=False)
151+
for execution in executions
152+
)
153+
)
182154
return sum(costs, Decimal(0))
183155

184156

185157
async def compute_required_credit_balance(executions: Iterable[VmExecution]) -> Decimal:
186158
"""Get the balance required for the resources of the user from the messages and the pricing aggregate."""
187-
costs = await asyncio.gather(*(fetch_execution_credit_price(execution.vm_hash) for execution in executions))
159+
costs = await asyncio.gather(
160+
*(fetch_execution_price(execution.vm_hash, [PaymentType.credit]) for execution in executions)
161+
)
188162
return sum(costs, Decimal(0))
189163

190164

191165
async def compute_required_flow(executions: Iterable[VmExecution]) -> Decimal:
192166
"""Compute the flow required for a collection of executions, typically all executions from a specific address"""
193-
flows = await asyncio.gather(*(fetch_execution_flow_price(execution.vm_hash) for execution in executions))
167+
flows = await asyncio.gather(
168+
*(fetch_execution_price(execution.vm_hash, [PaymentType.superfluid]) for execution in executions)
169+
)
194170
return sum(flows, Decimal(0))

src/aleph/vm/orchestrator/views/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from aleph.vm.orchestrator.payment import (
3636
InvalidAddressError,
3737
InvalidChainError,
38-
fetch_execution_flow_price,
38+
fetch_execution_price,
3939
get_stream,
4040
)
4141
from aleph.vm.orchestrator.pubsub import PubSub
@@ -577,7 +577,7 @@ async def notify_allocation(request: web.Request):
577577
if have_gpu:
578578
logger.debug(f"GPU Instance {item_hash} not using PAYG")
579579
user_balance = await payment.fetch_balance_of_address(message.sender)
580-
hold_price = await payment.fetch_execution_hold_price(item_hash)
580+
hold_price = await payment.fetch_execution_price(item_hash, [PaymentType.hold], False)
581581
logger.debug(f"Address {message.sender} Balance: {user_balance}, Price: {hold_price}")
582582
if hold_price > user_balance:
583583
return web.HTTPPaymentRequired(
@@ -606,7 +606,7 @@ async def notify_allocation(request: web.Request):
606606
if not active_flow:
607607
raise web.HTTPPaymentRequired(reason="Empty payment stream for this instance")
608608

609-
required_flow: Decimal = await fetch_execution_flow_price(item_hash)
609+
required_flow: Decimal = await fetch_execution_price(item_hash, [PaymentType.superfluid])
610610
community_wallet = await get_community_wallet_address()
611611
required_crn_stream: Decimal
612612
required_community_stream: Decimal

0 commit comments

Comments
 (0)