Skip to content

Commit

Permalink
feat: update resolvers to try multiple endpoints (#150)
Browse files Browse the repository at this point in the history
  • Loading branch information
jrriehl authored Sep 19, 2023
1 parent 3bef47f commit dbf6de4
Show file tree
Hide file tree
Showing 10 changed files with 517 additions and 92 deletions.
2 changes: 2 additions & 0 deletions python/.pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ disable=missing-module-docstring,
too-many-instance-attributes,
duplicate-code,
too-many-arguments,
too-many-locals,
too-many-return-statements,
logging-fstring-interpolation,
broad-exception-caught

17 changes: 16 additions & 1 deletion python/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ cosmpy = "^0.9.1"
websockets = "^10.4"

[tool.poetry.dev-dependencies]
aioresponses = "^0.7.4"
black = "^23.1.0"
pytest = "^7.1.3"
pylint = "^2.15.3"
Expand Down
12 changes: 9 additions & 3 deletions python/src/uagents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ class Agent(Sink):
_unsigned_message_handlers (Dict[str, MessageCallback]): Handlers for
unsigned messages.
_models (Dict[str, Type[Model]]): Dictionary mapping supported message digests to messages.
_replies (Dict[str, Set[Type[Model]]]): Dictionary of allowed reply digests for each type
_replies (Dict[str, Dict[str, Type[Model]]]): Dictionary of allowed replies for each type
of incoming message.
_queries (Dict[str, asyncio.Future]): Dictionary mapping query senders to their response
Futures.
Expand Down Expand Up @@ -149,6 +149,7 @@ def __init__(
agentverse: Optional[Union[str, Dict[str, str]]] = None,
mailbox: Optional[Union[str, Dict[str, str]]] = None,
resolve: Optional[Resolver] = None,
max_resolver_endpoints: Optional[int] = None,
version: Optional[str] = None,
):
"""
Expand All @@ -162,12 +163,17 @@ def __init__(
agentverse (Optional[Union[str, Dict[str, str]]]): The agentverse configuration.
mailbox (Optional[Union[str, Dict[str, str]]]): The mailbox configuration.
resolve (Optional[Resolver]): The resolver to use for agent communication.
max_resolver_endpoints (Optional[int]): The maximum number of endpoints to resolve.
version (Optional[str]): The version of the agent.
"""
self._name = name
self._port = port if port is not None else 8000
self._background_tasks: Set[asyncio.Task] = set()
self._resolver = resolve if resolve is not None else GlobalResolver()
self._resolver = (
resolve
if resolve is not None
else GlobalResolver(max_endpoints=max_resolver_endpoints)
)
self._loop = asyncio.get_event_loop_policy().get_event_loop()

self._initialize_wallet_and_identity(seed, name)
Expand Down Expand Up @@ -212,7 +218,7 @@ def __init__(
self._signed_message_handlers: Dict[str, MessageCallback] = {}
self._unsigned_message_handlers: Dict[str, MessageCallback] = {}
self._models: Dict[str, Type[Model]] = {}
self._replies: Dict[str, Set[Type[Model]]] = {}
self._replies: Dict[str, Dict[str, Type[Model]]] = {}
self._queries: Dict[str, asyncio.Future] = {}
self._dispatcher = dispatcher
self._message_queue = asyncio.Queue()
Expand Down
1 change: 1 addition & 0 deletions python/src/uagents/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class AgentNetwork(Enum):
MAILBOX_POLL_INTERVAL_SECONDS = 1.0

DEFAULT_ENVELOPE_TIMEOUT_SECONDS = 30
DEFAULT_MAX_ENDPOINTS = 10
DEFAULT_SEARCH_LIMIT = 100


Expand Down
150 changes: 116 additions & 34 deletions python/src/uagents/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging
import uuid
from dataclasses import dataclass
from enum import Enum
from time import time
from typing import (
Dict,
Expand Down Expand Up @@ -44,6 +45,14 @@
EventCallback = Callable[["Context"], Awaitable[None]]


class DeliveryStatus(str, Enum):
"""Delivery status of a message."""

SENT = "sent"
DELIVERED = "delivered"
FAILED = "failed"


@dataclass
class MsgDigest:
"""
Expand All @@ -58,6 +67,24 @@ class MsgDigest:
schema_digest: str


@dataclass
class MsgStatus:
"""
Represents the status of a sent message.
Attributes:
status (str): The delivery status of the message {'sent', 'delivered', 'failed'}.
detail (str): The details of the message delivery.
destination (str): The destination address of the message.
endpoint (str): The endpoint the message was sent to.
"""

status: DeliveryStatus
detail: str
destination: str
endpoint: str


ERROR_MESSAGE_DIGEST = Model.build_schema_digest(ErrorMessage)


Expand All @@ -76,7 +103,7 @@ class Context:
_queries (Dict[str, asyncio.Future]): Dictionary mapping query senders to their
response Futures.
_session (Optional[uuid.UUID]): The session UUID.
_replies (Optional[Dict[str, Set[Type[Model]]]]): Dictionary of allowed reply digests
_replies (Optional[Dict[str, Dict[str, Type[Model]]]]): Dictionary of allowed reply digests
for each type of incoming message.
_interval_messages (Optional[Set[str]]): Set of message digests that may be sent by
interval tasks.
Expand Down Expand Up @@ -115,7 +142,7 @@ def __init__(
ledger: LedgerClient,
queries: Dict[str, asyncio.Future],
session: Optional[uuid.UUID] = None,
replies: Optional[Dict[str, Set[Type[Model]]]] = None,
replies: Optional[Dict[str, Dict[str, Type[Model]]]] = None,
interval_messages: Optional[Set[str]] = None,
message_received: Optional[MsgDigest] = None,
protocols: Optional[Dict[str, Protocol]] = None,
Expand All @@ -135,7 +162,8 @@ def __init__(
queries (Dict[str, asyncio.Future]): Dictionary mapping query senders to their response
Futures.
session (Optional[uuid.UUID]): The optional session UUID.
replies (Optional[Dict[str, Set[Type[Model]]]]): Optional dictionary of reply models.
replies (Optional[Dict[str, Dict[str, Type[Model]]]]): Dictionary of allowed replies
for each type of incoming message.
interval_messages (Optional[Set[str]]): The optional set of interval messages.
message_received (Optional[MsgDigest]): The optional message digest received.
protocols (Optional[Dict[str, Protocol]]): The optional dictionary of protocols.
Expand Down Expand Up @@ -262,17 +290,20 @@ async def send(
destination: str,
message: Model,
timeout: Optional[int] = DEFAULT_ENVELOPE_TIMEOUT_SECONDS,
):
) -> MsgStatus:
"""
Send a message to the specified destination.
Args:
destination (str): The destination address to send the message to.
message (Model): The message to be sent.
timeout (Optional[int]): The optional timeout for sending the message, in seconds.
Returns:
MsgStatus: The delivery status of the message.
"""
schema_digest = Model.build_schema_digest(message)
await self.send_raw(
return await self.send_raw(
destination,
message.json(),
schema_digest,
Expand All @@ -286,7 +317,7 @@ async def experimental_broadcast(
message: Model,
limit: Optional[int] = DEFAULT_SEARCH_LIMIT,
timeout: Optional[int] = DEFAULT_ENVELOPE_TIMEOUT_SECONDS,
):
) -> List[MsgStatus]:
"""Broadcast a message to agents with a specific protocol.
This asynchronous method broadcasts a given message to agents associated
Expand All @@ -300,7 +331,7 @@ async def experimental_broadcast(
timeout (int, optional): The timeout for sending each message.
Returns:
None
List[MsgStatus]: A list of message delivery statuses.
"""
agents = self.get_agents_by_protocol(destination_protocol, limit=limit)
if not agents:
Expand All @@ -320,6 +351,7 @@ async def experimental_broadcast(
]
)
self.logger.debug(f"Sent {len(futures)} messages")
return futures

async def send_raw(
self,
Expand All @@ -328,16 +360,19 @@ async def send_raw(
schema_digest: str,
message_type: Optional[Type[Model]] = None,
timeout: Optional[int] = DEFAULT_ENVELOPE_TIMEOUT_SECONDS,
):
) -> MsgStatus:
"""
Send a raw message to the specified destination.
Args:
destination (str): The destination address to send the message to.
destination (str): The destination name or address to send the message to.
json_message (JsonStr): The JSON-encoded message to be sent.
schema_digest (str): The schema digest of the message.
message_type (Optional[Type[Model]]): The optional type of the message being sent.
timeout (Optional[int]): The optional timeout for sending the message, in seconds.
Returns:
MsgStatus: The delivery status of the message.
"""
# Check if this message is a reply
if (
Expand All @@ -353,36 +388,64 @@ async def send_raw(
f"Outgoing message {message_type or ''} "
f"is not a valid reply to {received.message}"
)
return

return MsgStatus(
status=DeliveryStatus.FAILED,
detail="Invalid reply",
destination=destination,
endpoint="",
)
# Check if this message is a valid interval message
if self._message_received is None and self._interval_messages:
if schema_digest not in self._interval_messages:
self._logger.exception(
f"Outgoing message {message_type} is not a valid interval message"
)
return
return MsgStatus(
status=DeliveryStatus.FAILED,
detail="Invalid interval message",
destination=destination,
endpoint="",
)

# Handle local dispatch of messages
if dispatcher.contains(destination):
await dispatcher.dispatch(
self.address, destination, schema_digest, json_message, self._session
self.address,
destination,
schema_digest,
json_message,
self._session,
)
return MsgStatus(
status=DeliveryStatus.DELIVERED,
detail="Message dispatched locally",
destination=destination,
endpoint="",
)
return

# Handle queries waiting for a response
if destination in self._queries:
self._queries[destination].set_result((json_message, schema_digest))
del self._queries[destination]
return
return MsgStatus(
status=DeliveryStatus.DELIVERED,
detail="Sync message resolved",
destination=destination,
endpoint="",
)

# Resolve the endpoint
destination_address, endpoint = await self._resolver.resolve(destination)
if endpoint is None:
# Resolve the destination address and endpoint ('destination' can be a name or address)
destination_address, endpoints = await self._resolver.resolve(destination)
if len(endpoints) == 0:
self._logger.exception(
f"Unable to resolve destination endpoint for address {destination}"
)
return
return MsgStatus(
status=DeliveryStatus.FAILED,
detail="Unable to resolve destination endpoint",
destination=destination,
endpoint="",
)

# Calculate when the envelope expires
expires = int(time()) + timeout
Expand All @@ -400,19 +463,38 @@ async def send_raw(
env.encode_payload(json_message)
env.sign(self._identity)

try:
async with aiohttp.ClientSession() as session:
async with session.post(
endpoint,
headers={"content-type": "application/json"},
data=env.json(),
) as resp:
success = resp.status == 200
if not success:
self._logger.exception(
f"Unable to send envelope to {destination_address} @ {endpoint}"
for endpoint in endpoints:
try:
async with aiohttp.ClientSession() as session:
async with session.post(
endpoint,
headers={"content-type": "application/json"},
data=env.json(),
) as resp:
success = resp.status == 200
if success:
return MsgStatus(
status=DeliveryStatus.DELIVERED,
detail="Message successfully delivered via HTTP",
destination=destination,
endpoint=endpoint,
)
self._logger.warning(
f"Failed to send message to {destination_address} @ {endpoint}: "
+ (await resp.text())
)
except aiohttp.ClientConnectorError as ex:
self._logger.exception(f"Failed to connect to {endpoint}: {ex}")
except Exception as ex:
self._logger.exception(f"Failed to send message to {destination}: {ex}")
except aiohttp.ClientConnectorError as ex:
self._logger.warning(f"Failed to connect to {endpoint}: {ex}")
except Exception as ex:
self._logger.warning(
f"Failed to send message to {destination} @ {endpoint}: {ex}"
)

self._logger.exception(f"Failed to deliver message to {destination}")

return MsgStatus(
status=DeliveryStatus.FAILED,
detail="Message delivery failed",
destination=destination,
endpoint="",
)
3 changes: 3 additions & 0 deletions python/src/uagents/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ def register(self, address: str, sink: Sink):
def unregister(self, address: str, sink: Sink):
destinations = self._sinks.get(address, set())
destinations.discard(sink)
if len(destinations) == 0:
del self._sinks[address]
return
self._sinks[address] = destinations

def contains(self, address: str) -> bool:
Expand Down
Loading

0 comments on commit dbf6de4

Please sign in to comment.