Skip to content

Commit 8bf9c51

Browse files
Andres D. Molinsolethanh
authored andcommitted
Feature: Implement credits payment method monitoring, same as PAYG.
1 parent c79a8cc commit 8bf9c51

File tree

5 files changed

+104
-23
lines changed

5 files changed

+104
-23
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ dependencies = [
3636
"aioredis==1.3.1",
3737
"aiosqlite==0.19",
3838
"alembic==1.13.1",
39-
"aleph-message~=1.0.1",
39+
# "aleph-message~=1.0.1",
40+
"aleph-message @ git+https://github.com/aleph-im/aleph-message@andres-feature-implement_credits_payment",
4041
"aleph-superfluid~=0.2.1",
4142
"dbus-python==1.3.2",
4243
"eth-account~=0.10",

src/aleph/vm/orchestrator/payment.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,32 @@ async def fetch_balance_of_address(address: str) -> Decimal:
4444
return resp_data["balance"]
4545

4646

47+
async def fetch_credit_balance_of_address(address: str) -> Decimal:
48+
"""
49+
Get the balance of the user from the PyAleph API.
50+
51+
API Endpoint:
52+
GET /api/v0/addresses/{address}/balance
53+
54+
For more details, see the PyAleph API documentation:
55+
https://github.com/aleph-im/pyaleph/blob/master/src/aleph/web/controllers/routes.py#L62
56+
"""
57+
58+
async with aiohttp.ClientSession() as session:
59+
url = f"{settings.API_SERVER}/api/v0/addresses/{address}/credit_balance"
60+
resp = await session.get(url)
61+
62+
# Consider the balance as null if the address is not found
63+
if resp.status == 404:
64+
return Decimal(0)
65+
66+
# Raise an error if the request failed
67+
resp.raise_for_status()
68+
69+
resp_data = await resp.json()
70+
return resp_data["credits"]
71+
72+
4773
async def fetch_execution_flow_price(item_hash: ItemHash) -> Decimal:
4874
"""Fetch the flow price of an execution from the reference API server."""
4975
async with aiohttp.ClientSession() as session:
@@ -85,6 +111,25 @@ async def fetch_execution_hold_price(item_hash: ItemHash) -> Decimal:
85111
return Decimal(required_hold)
86112

87113

114+
async def fetch_execution_credit_price(item_hash: ItemHash) -> Decimal:
115+
"""Fetch the credit price of an execution from the reference API server."""
116+
async with aiohttp.ClientSession() as session:
117+
url = f"{settings.API_SERVER}/api/v0/price/{item_hash}"
118+
resp = await session.get(url)
119+
# Raise an error if the request failed
120+
resp.raise_for_status()
121+
122+
resp_data = await resp.json()
123+
required_credits: float = resp_data["required_credits"] # Field not defined yet on API side.
124+
payment_type: str | None = resp_data["payment_type"]
125+
126+
if payment_type not in (None, PaymentType.credit):
127+
msg = f"Payment type {payment_type} is not supported"
128+
raise ValueError(msg)
129+
130+
return Decimal(required_credits)
131+
132+
88133
class InvalidAddressError(ValueError):
89134
"""The blockchain address could not be parsed."""
90135

@@ -137,6 +182,12 @@ async def compute_required_balance(executions: Iterable[VmExecution]) -> Decimal
137182
return sum(costs, Decimal(0))
138183

139184

185+
async def compute_required_credit_balance(executions: Iterable[VmExecution]) -> Decimal:
186+
"""Get the balance required for the resources of the user from the messages and the pricing aggregate."""
187+
costs = await asyncio.gather(*(fetch_execution_credit_price(execution.vm_hash) for execution in executions))
188+
return sum(costs, Decimal(0))
189+
190+
140191
async def compute_required_flow(executions: Iterable[VmExecution]) -> Decimal:
141192
"""Compute the flow required for a collection of executions, typically all executions from a specific address"""
142193
flows = await asyncio.gather(*(fetch_execution_flow_price(execution.vm_hash) for execution in executions))

src/aleph/vm/orchestrator/tasks.py

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@
3131
from .messages import get_message_status
3232
from .payment import (
3333
compute_required_balance,
34+
compute_required_credit_balance,
3435
compute_required_flow,
3536
fetch_balance_of_address,
37+
fetch_credit_balance_of_address,
3638
get_stream,
3739
)
3840
from .pubsub import PubSub
@@ -187,44 +189,71 @@ async def check_payment(pool: VmPool):
187189
pool.forget_vm(vm_hash)
188190

189191
# Check if the balance held in the wallet is sufficient holder tier resources (Not do it yet)
190-
for sender, chains in pool.get_executions_by_sender(payment_type=PaymentType.hold).items():
192+
for execution_address, chains in pool.get_executions_by_address(payment_type=PaymentType.hold).items():
191193
for chain, executions in chains.items():
192194
executions = [execution for execution in executions if execution.is_confidential]
193195
if not executions:
194196
continue
195-
balance = await fetch_balance_of_address(sender)
197+
balance = await fetch_balance_of_address(execution_address)
196198

197199
# Stop executions until the required balance is reached
198200
required_balance = await compute_required_balance(executions)
199-
logger.debug(f"Required balance for Sender {sender} executions: {required_balance}, {executions}")
201+
logger.debug(
202+
f"Required balance for Sender {execution_address} executions: {required_balance}, {executions}"
203+
)
200204
# Stop executions until the required balance is reached
201205
while executions and balance < (required_balance + settings.PAYMENT_BUFFER):
202206
last_execution = executions.pop(-1)
203207
logger.debug(f"Stopping {last_execution} due to insufficient balance")
204208
await pool.stop_vm(last_execution.vm_hash)
205209
required_balance = await compute_required_balance(executions)
210+
206211
community_wallet = await get_community_wallet_address()
207212
if not community_wallet:
208213
logger.error("Monitor payment ERROR: No community wallet set. Cannot check community payment")
209214

215+
# Check if the credit balance held in the wallet is sufficient credit tier resources (Not do it yet)
216+
for execution_address, chains in pool.get_executions_by_address(payment_type=PaymentType.credit).items():
217+
for chain, executions in chains.items():
218+
executions = [execution for execution in executions]
219+
if not executions:
220+
continue
221+
balance = await fetch_credit_balance_of_address(execution_address)
222+
223+
# Stop executions until the required credits are reached
224+
required_credits = await compute_required_credit_balance(executions)
225+
logger.debug(
226+
f"Required credit balance for Address {execution_address} executions: {required_credits}, {executions}"
227+
)
228+
# Stop executions until the required credits are reached
229+
while executions and balance < (required_credits + settings.PAYMENT_BUFFER):
230+
last_execution = executions.pop(-1)
231+
logger.debug(f"Stopping {last_execution} due to insufficient credit balance")
232+
await pool.stop_vm(last_execution.vm_hash)
233+
required_credits = await compute_required_credit_balance(executions)
234+
210235
# Check if the balance held in the wallet is sufficient stream tier resources
211-
for sender, chains in pool.get_executions_by_sender(payment_type=PaymentType.superfluid).items():
236+
for execution_address, chains in pool.get_executions_by_address(payment_type=PaymentType.superfluid).items():
212237
for chain, executions in chains.items():
213238
try:
214-
stream = await get_stream(sender=sender, receiver=settings.PAYMENT_RECEIVER_ADDRESS, chain=chain)
239+
stream = await get_stream(
240+
sender=execution_address, receiver=settings.PAYMENT_RECEIVER_ADDRESS, chain=chain
241+
)
215242

216243
logger.debug(
217-
f"Stream flow from {sender} to {settings.PAYMENT_RECEIVER_ADDRESS} = {stream} {chain.value}"
244+
f"Stream flow from {execution_address} to {settings.PAYMENT_RECEIVER_ADDRESS} = {stream} {chain.value}"
218245
)
219246
except ValueError as error:
220-
logger.error(f"Error found getting stream for chain {chain} and sender {sender}: {error}")
247+
logger.error(f"Error found getting stream for chain {chain} and sender {execution_address}: {error}")
221248
continue
222249
try:
223-
community_stream = await get_stream(sender=sender, receiver=community_wallet, chain=chain)
224-
logger.debug(f"Stream flow from {sender} to {community_wallet} (community) : {stream} {chain}")
250+
community_stream = await get_stream(sender=execution_address, receiver=community_wallet, chain=chain)
251+
logger.debug(
252+
f"Stream flow from {execution_address} to {community_wallet} (community) : {stream} {chain}"
253+
)
225254

226255
except ValueError as error:
227-
logger.error(f"Error found getting stream for chain {chain} and sender {sender}: {error}")
256+
logger.error(f"Error found getting stream for chain {chain} and sender {execution_address}: {error}")
228257
continue
229258

230259
while executions:
@@ -249,7 +278,7 @@ async def check_payment(pool: VmPool):
249278
)
250279
required_community_stream = format_cost(required_stream * COMMUNITY_STREAM_RATIO)
251280
logger.debug(
252-
f"Stream for senders {sender} {len(executions)} executions. CRN : {stream} / {required_crn_stream}."
281+
f"Stream for senders {execution_address} {len(executions)} executions. CRN : {stream} / {required_crn_stream}."
253282
f"Community: {community_stream} / {required_community_stream}"
254283
)
255284
# Can pay all executions
@@ -259,7 +288,7 @@ async def check_payment(pool: VmPool):
259288
break
260289
# Stop executions until the required stream is reached
261290
last_execution = executions.pop(-1)
262-
logger.info(f"Stopping {last_execution} of {sender} due to insufficient stream")
291+
logger.info(f"Stopping {last_execution} of {execution_address} due to insufficient stream")
263292
await pool.stop_vm(last_execution.vm_hash)
264293

265294

src/aleph/vm/pool.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -379,9 +379,9 @@ def get_available_gpus(self) -> list[GpuDevice]:
379379
available_gpus.append(gpu)
380380
return available_gpus
381381

382-
def get_executions_by_sender(self, payment_type: PaymentType) -> dict[str, dict[str, list[VmExecution]]]:
382+
def get_executions_by_address(self, payment_type: PaymentType) -> dict[str, dict[str, list[VmExecution]]]:
383383
"""Return all executions of the given type, grouped by sender and by chain."""
384-
executions_by_sender: dict[str, dict[str, list[VmExecution]]] = {}
384+
executions_by_address: dict[str, dict[str, list[VmExecution]]] = {}
385385
for vm_hash, execution in self.executions.items():
386386
if execution.vm_hash in (settings.CHECK_FASTAPI_VM_ID, settings.LEGACY_CHECK_FASTAPI_VM_ID):
387387
# Ignore Diagnostic VM execution
@@ -399,11 +399,11 @@ def get_executions_by_sender(self, payment_type: PaymentType) -> dict[str, dict[
399399
else Payment(chain=Chain.ETH, type=PaymentType.hold)
400400
)
401401
if execution_payment.type == payment_type:
402-
sender = execution.message.address
402+
address = execution.message.address
403403
chain = execution_payment.chain
404-
executions_by_sender.setdefault(sender, {})
405-
executions_by_sender[sender].setdefault(chain, []).append(execution)
406-
return executions_by_sender
404+
executions_by_address.setdefault(address, {})
405+
executions_by_address[address].setdefault(chain, []).append(execution)
406+
return executions_by_address
407407

408408
def get_valid_reservation(self, resource) -> Reservation | None:
409409
if resource in self.reservations and self.reservations[resource].is_expired():

tests/supervisor/test_checkpayment.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ async def compute_required_flow(executions):
8282

8383
pool.executions = {hash: execution}
8484

85-
executions_by_sender = pool.get_executions_by_sender(payment_type=PaymentType.superfluid)
85+
executions_by_sender = pool.get_executions_by_address(payment_type=PaymentType.superfluid)
8686
assert len(executions_by_sender) == 1
8787
assert executions_by_sender == {"0x101d8D16372dBf5f1614adaE95Ee5CCE61998Fc9": {Chain.BASE: [execution]}}
8888

@@ -136,7 +136,7 @@ async def compute_required_flow(executions):
136136

137137
pool.executions = {hash: execution}
138138

139-
executions_by_sender = pool.get_executions_by_sender(payment_type=PaymentType.superfluid)
139+
executions_by_sender = pool.get_executions_by_address(payment_type=PaymentType.superfluid)
140140
assert len(executions_by_sender) == 1
141141
assert executions_by_sender == {"0x101d8D16372dBf5f1614adaE95Ee5CCE61998Fc9": {Chain.BASE: [execution]}}
142142

@@ -173,7 +173,7 @@ async def test_not_enough_flow(mocker, fake_instance_content):
173173

174174
pool.executions = {hash: execution}
175175

176-
executions_by_sender = pool.get_executions_by_sender(payment_type=PaymentType.superfluid)
176+
executions_by_sender = pool.get_executions_by_address(payment_type=PaymentType.superfluid)
177177
assert len(executions_by_sender) == 1
178178
assert executions_by_sender == {"0x101d8D16372dBf5f1614adaE95Ee5CCE61998Fc9": {Chain.BASE: [execution]}}
179179

@@ -217,7 +217,7 @@ async def get_stream(sender, receiver, chain):
217217

218218
pool.executions = {hash: execution}
219219

220-
executions_by_sender = pool.get_executions_by_sender(payment_type=PaymentType.superfluid)
220+
executions_by_sender = pool.get_executions_by_address(payment_type=PaymentType.superfluid)
221221
assert len(executions_by_sender) == 1
222222
assert executions_by_sender == {"0x101d8D16372dBf5f1614adaE95Ee5CCE61998Fc9": {Chain.BASE: [execution]}}
223223

0 commit comments

Comments
 (0)