Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(core): add is active almanac update #564

Merged
merged 10 commits into from
Oct 22, 2024
15 changes: 12 additions & 3 deletions python/src/uagents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@
from uagents.protocol import Protocol
from uagents.registration import (
AgentRegistrationPolicy,
AgentStatusUpdate,
DefaultRegistrationPolicy,
update_agent_status,
)
from uagents.resolver import GlobalResolver, Resolver
from uagents.storage import KeyValueStore, get_or_create_private_keys
Expand Down Expand Up @@ -346,10 +348,10 @@ def __init__(
else:
self._mailbox_client = None

almanac_api_url = f"{self._agentverse['http_prefix']}://{self._agentverse['base_url']}/v1/almanac"
self._almanac_api_url = f"{self._agentverse['http_prefix']}://{self._agentverse['base_url']}/v1/almanac"
self._resolver = resolve or GlobalResolver(
max_endpoints=max_resolver_endpoints,
almanac_api_url=almanac_api_url,
almanac_api_url=self._almanac_api_url,
)

self._ledger = get_ledger(test)
Expand Down Expand Up @@ -378,7 +380,7 @@ def __init__(
self._almanac_contract,
self._test,
logger=self._logger,
almanac_api=almanac_api_url,
almanac_api=self._almanac_api_url,
)
self._metadata = self._initialize_metadata(metadata)

Expand Down Expand Up @@ -1095,6 +1097,13 @@ async def _shutdown(self):
Perform shutdown actions.

"""
try:
status = AgentStatusUpdate(agent_address=self.address, is_active=False)
status.sign(self._identity)
await update_agent_status(status, self._almanac_api_url)
except Exception as ex:
self._logger.exception(f"Failed to update agent registration status: {ex}")

for handler in self._on_shutdown:
try:
ctx = self._build_context()
Expand Down
140 changes: 83 additions & 57 deletions python/src/uagents/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
import hashlib
import json
import logging
import time
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union

import aiohttp
from cosmpy.aerial.client import LedgerClient
from cosmpy.aerial.wallet import LocalWallet
from cosmpy.crypto.address import Address
from pydantic import BaseModel
from pydantic import BaseModel, field_serializer

from uagents.config import (
ALMANAC_API_MAX_RETRIES,
Expand All @@ -23,6 +24,51 @@
from uagents.types import AgentEndpoint


class VerifiableModel(BaseModel):
agent_address: str
signature: Optional[str] = None
timestamp: Optional[int] = None

def sign(self, identity: Identity):
self.timestamp = int(time.time())
digest = self._build_digest()
self.signature = identity.sign_digest(digest)

def verify(self) -> bool:
return self.signature is not None and Identity.verify_digest(
self.agent_address, self._build_digest(), self.signature
)

def _build_digest(self) -> bytes:
sha256 = hashlib.sha256()
sha256.update(
json.dumps(
self.model_dump(exclude={"signature"}),
sort_keys=True,
separators=(",", ":"),
).encode("utf-8")
)
return sha256.digest()


class AgentRegistrationAttestation(VerifiableModel):
protocols: List[str]
endpoints: List[AgentEndpoint]
metadata: Optional[Dict[str, Union[str, Dict[str, str]]]] = None

@field_serializer("protocols")
Archento marked this conversation as resolved.
Show resolved Hide resolved
def sort_protocols(self, val: List[str]) -> List[str]:
return sorted(val)

@field_serializer("endpoints")
def sort_endpoints(self, val: List[AgentEndpoint]) -> List[AgentEndpoint]:
return sorted(val, key=lambda x: x.url)


class AgentStatusUpdate(VerifiableModel):
is_active: bool


def generate_backoff_time(retry: int) -> float:
"""
Generate a backoff time starting from 0.128 seconds and limited to ~131 seconds
Expand All @@ -49,6 +95,29 @@ def coerce_metadata_to_str(
return out


async def use_almanac_api(
Archento marked this conversation as resolved.
Show resolved Hide resolved
url: str, data: BaseModel, raise_from: bool = True, retries: int = 3
) -> bool:
async with aiohttp.ClientSession() as session:
for retry in range(retries):
try:
async with session.post(
url,
headers={"content-type": "application/json"},
data=data.model_dump_json(),
timeout=aiohttp.ClientTimeout(total=ALMANAC_API_TIMEOUT_SECONDS),
) as resp:
resp.raise_for_status()
return True
except (aiohttp.ClientError, asyncio.exceptions.TimeoutError) as e:
if retry == retries - 1:
if raise_from:
raise e
return False
await asyncio.sleep(generate_backoff_time(retry))
return False


class AgentRegistrationPolicy(ABC):
@abstractmethod
# pylint: disable=unnecessary-pass
Expand All @@ -62,43 +131,6 @@ async def register(
pass


class AgentRegistrationAttestation(BaseModel):
agent_address: str
protocols: List[str]
endpoints: List[AgentEndpoint]
metadata: Optional[Dict[str, Union[str, Dict[str, str]]]] = None
signature: Optional[str] = None

def sign(self, identity: Identity):
digest = self._build_digest()
self.signature = identity.sign_digest(digest)

def verify(self) -> bool:
if self.signature is None:
raise ValueError("Attestation signature is missing")
return Identity.verify_digest(
self.agent_address, self._build_digest(), self.signature
)

def _build_digest(self) -> bytes:
normalised_attestation = AgentRegistrationAttestation(
agent_address=self.agent_address,
protocols=sorted(self.protocols),
endpoints=sorted(self.endpoints, key=lambda x: x.url),
metadata=self.metadata,
)

sha256 = hashlib.sha256()
sha256.update(
json.dumps(
normalised_attestation.model_dump(exclude={"signature"}),
sort_keys=True,
separators=(",", ":"),
).encode("utf-8")
)
return sha256.digest()


class AlmanacApiRegistrationPolicy(AgentRegistrationPolicy):
def __init__(
self,
Expand Down Expand Up @@ -137,25 +169,11 @@ async def register(
# sign the attestation
attestation.sign(self._identity)

# submit the attestation to the API
async with aiohttp.ClientSession() as session: # noqa: SIM117
for retry in range(self._max_retries):
try:
async with session.post(
f"{self._almanac_api}/agents",
headers={"content-type": "application/json"},
data=attestation.model_dump_json(),
timeout=aiohttp.ClientTimeout(
total=ALMANAC_API_TIMEOUT_SECONDS
),
) as resp:
resp.raise_for_status()
self._logger.info("Registration on Almanac API successful")
return
except (aiohttp.ClientError, asyncio.exceptions.TimeoutError) as e:
if retry == self._max_retries - 1:
raise e
await asyncio.sleep(generate_backoff_time(retry))
success = await use_almanac_api(
f"{self._almanac_api}/agents", attestation, retries=self._max_retries
)
if success:
self._logger.info("Registration on Almanac API successful")


class LedgerBasedRegistrationPolicy(AgentRegistrationPolicy):
Expand Down Expand Up @@ -293,3 +311,11 @@ async def register(
except Exception as e:
self._logger.error(f"Failed to register on Almanac contract: {e}")
raise


async def update_agent_status(status: AgentStatusUpdate, almanac_api: str):
Archento marked this conversation as resolved.
Show resolved Hide resolved
await use_almanac_api(
f"{almanac_api}/agents/{status.agent_address}/status",
status,
raise_from=False,
)
18 changes: 16 additions & 2 deletions python/tests/test_registration.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import time
import unittest

import pytest
Expand All @@ -21,12 +22,14 @@

def test_attestation_signature():
identity = Identity.generate()
ts = int(time.time())

# create a dummy attestation
attestation = AgentRegistrationAttestation(
agent_address=identity.address,
protocols=TEST_PROTOCOLS,
endpoints=TEST_ENDPOINTS,
timestamp=ts,
)

# sign the attestation with the identity
Expand All @@ -39,6 +42,7 @@ def test_attestation_signature():

def test_attestation_signature_with_metadata():
identity = Identity.generate()
ts = int(time.time())

# create a dummy attestation
attestation = AgentRegistrationAttestation(
Expand All @@ -48,6 +52,7 @@ def test_attestation_signature_with_metadata():
metadata=coerce_metadata_to_str(
{"foo": "bar", "baz": 3.17, "qux": {"a": "b", "c": 4, "d": 5.6}}
),
timestamp=ts,
)

# sign the attestation with the identity
Expand All @@ -60,12 +65,14 @@ def test_attestation_signature_with_metadata():

def test_recovery_of_attestation():
identity = Identity.generate()
ts = int(time.time())

# create an attestation
original_attestation = AgentRegistrationAttestation(
agent_address=identity.address,
protocols=TEST_PROTOCOLS,
endpoints=TEST_ENDPOINTS,
timestamp=ts,
)
original_attestation.sign(identity)

Expand All @@ -75,27 +82,34 @@ def test_recovery_of_attestation():
protocols=TEST_PROTOCOLS,
endpoints=TEST_ENDPOINTS,
signature=original_attestation.signature,
timestamp=ts,
)
assert recovered.verify()


def test_order_of_protocols_or_endpoints_does_not_matter():
identity = Identity.generate()
ts = int(time.time())

reversed_protocols = list(reversed(TEST_PROTOCOLS))
reversed_endpoints = list(reversed(TEST_ENDPOINTS))

# create an attestation
original_attestation = AgentRegistrationAttestation(
agent_address=identity.address,
protocols=TEST_PROTOCOLS,
endpoints=TEST_ENDPOINTS,
timestamp=ts,
)
original_attestation.sign(identity)

# recover the attestation
recovered = AgentRegistrationAttestation(
agent_address=original_attestation.agent_address,
protocols=TEST_PROTOCOLS,
endpoints=TEST_ENDPOINTS,
protocols=reversed_protocols,
endpoints=reversed_endpoints,
signature=original_attestation.signature,
timestamp=ts,
)
assert recovered.verify()

Expand Down
Loading