22import logging
33from collections .abc import Iterable
44from decimal import Decimal
5+ from typing import List
56
67import aiohttp
78from aleph_message .models import ItemHash , PaymentType
@@ -70,48 +71,9 @@ async def fetch_credit_balance_of_address(address: str) -> Decimal:
7071 return resp_data ["credits" ]
7172
7273
73- async def fetch_execution_flow_price (item_hash : ItemHash ) -> Decimal :
74- """Fetch the flow price of an execution from the reference API server."""
75- async with aiohttp .ClientSession () as session :
76- url = f"{ settings .API_SERVER } /api/v0/price/{ item_hash } "
77- resp = await session .get (url )
78- # Raise an error if the request failed
79- resp .raise_for_status ()
80-
81- resp_data = await resp .json ()
82- required_flow : float = resp_data ["required_tokens" ]
83- payment_type : str | None = resp_data ["payment_type" ]
84-
85- if payment_type is None :
86- msg = "Payment type must be specified in the message"
87- raise ValueError (msg )
88- elif payment_type != PaymentType .superfluid :
89- msg = f"Payment type { payment_type } is not supported"
90- raise ValueError (msg )
91-
92- return Decimal (required_flow )
93-
94-
95- async def fetch_execution_hold_price (item_hash : ItemHash ) -> Decimal :
96- """Fetch the hold price of an execution from the reference API server."""
97- async with aiohttp .ClientSession () as session :
98- url = f"{ settings .API_SERVER } /api/v0/price/{ item_hash } "
99- resp = await session .get (url )
100- # Raise an error if the request failed
101- resp .raise_for_status ()
102-
103- resp_data = await resp .json ()
104- required_hold : float = resp_data ["required_tokens" ]
105- payment_type : str | None = resp_data ["payment_type" ]
106-
107- if payment_type not in (None , PaymentType .hold ):
108- msg = f"Payment type { payment_type } is not supported"
109- raise ValueError (msg )
110-
111- return Decimal (required_hold )
112-
113-
114- async def fetch_execution_credit_price (item_hash : ItemHash ) -> Decimal :
74+ async def fetch_execution_price (
75+ item_hash : ItemHash , allowed_payments : List [PaymentType ], payment_type_required : bool = True
76+ ) -> Decimal :
11577 """Fetch the credit price of an execution from the reference API server."""
11678 async with aiohttp .ClientSession () as session :
11779 url = f"{ settings .API_SERVER } /api/v0/price/{ item_hash } "
@@ -123,10 +85,15 @@ async def fetch_execution_credit_price(item_hash: ItemHash) -> Decimal:
12385 required_credits : float = resp_data ["required_credits" ] # Field not defined yet on API side.
12486 payment_type : str | None = resp_data ["payment_type" ]
12587
126- if payment_type not in ( None , PaymentType . credit ) :
127- msg = f "Payment type { payment_type } is not supported "
88+ if payment_type_required and payment_type is None :
89+ msg = "Payment type must be specified in the message "
12890 raise ValueError (msg )
12991
92+ if payment_type :
93+ if payment_type not in allowed_payments :
94+ msg = f"Payment type { payment_type } is not supported"
95+ raise ValueError (msg )
96+
13097 return Decimal (required_credits )
13198
13299
@@ -178,17 +145,26 @@ async def get_stream(sender: str, receiver: str, chain: str) -> Decimal:
178145
179146async def compute_required_balance (executions : Iterable [VmExecution ]) -> Decimal :
180147 """Get the balance required for the resources of the user from the messages and the pricing aggregate."""
181- costs = await asyncio .gather (* (fetch_execution_hold_price (execution .vm_hash ) for execution in executions ))
148+ costs = await asyncio .gather (
149+ * (
150+ fetch_execution_price (execution .vm_hash , [PaymentType .hold ], payment_type_required = False )
151+ for execution in executions
152+ )
153+ )
182154 return sum (costs , Decimal (0 ))
183155
184156
185157async def compute_required_credit_balance (executions : Iterable [VmExecution ]) -> Decimal :
186158 """Get the balance required for the resources of the user from the messages and the pricing aggregate."""
187- costs = await asyncio .gather (* (fetch_execution_credit_price (execution .vm_hash ) for execution in executions ))
159+ costs = await asyncio .gather (
160+ * (fetch_execution_price (execution .vm_hash , [PaymentType .credit ]) for execution in executions )
161+ )
188162 return sum (costs , Decimal (0 ))
189163
190164
191165async def compute_required_flow (executions : Iterable [VmExecution ]) -> Decimal :
192166 """Compute the flow required for a collection of executions, typically all executions from a specific address"""
193- flows = await asyncio .gather (* (fetch_execution_flow_price (execution .vm_hash ) for execution in executions ))
167+ flows = await asyncio .gather (
168+ * (fetch_execution_price (execution .vm_hash , [PaymentType .superfluid ]) for execution in executions )
169+ )
194170 return sum (flows , Decimal (0 ))
0 commit comments