From dbf6de4580e4c9cfc32735565f04eb6eb9ec8709 Mon Sep 17 00:00:00 2001 From: James Riehl <33920192+jrriehl@users.noreply.github.com> Date: Tue, 19 Sep 2023 08:56:54 +0100 Subject: [PATCH] feat: update resolvers to try multiple endpoints (#150) --- python/.pylintrc | 2 + python/poetry.lock | 17 ++- python/pyproject.toml | 1 + python/src/uagents/agent.py | 12 +- python/src/uagents/config.py | 1 + python/src/uagents/context.py | 150 ++++++++++++++----- python/src/uagents/dispatch.py | 3 + python/src/uagents/query.py | 44 +++--- python/src/uagents/resolver.py | 116 ++++++++++----- python/tests/test_context.py | 263 +++++++++++++++++++++++++++++++++ 10 files changed, 517 insertions(+), 92 deletions(-) create mode 100644 python/tests/test_context.py diff --git a/python/.pylintrc b/python/.pylintrc index bd151bd4..cbbc7976 100644 --- a/python/.pylintrc +++ b/python/.pylintrc @@ -8,6 +8,8 @@ disable=missing-module-docstring, too-many-instance-attributes, duplicate-code, too-many-arguments, + too-many-locals, + too-many-return-statements, logging-fstring-interpolation, broad-exception-caught diff --git a/python/poetry.lock b/python/poetry.lock index adc6480b..2200e6be 100644 --- a/python/poetry.lock +++ b/python/poetry.lock @@ -109,6 +109,21 @@ yarl = ">=1.0,<2.0" [package.extras] speedups = ["Brotli", "aiodns", "cchardet"] +[[package]] +name = "aioresponses" +version = "0.7.4" +description = "Mock out requests made by ClientSession from aiohttp package" +category = "dev" +optional = false +python-versions = "*" +files = [ + {file = "aioresponses-0.7.4-py2.py3-none-any.whl", hash = "sha256:1160486b5ea96fcae6170cf2bdef029b9d3a283b7dbeabb3d7f1182769bfb6b7"}, + {file = "aioresponses-0.7.4.tar.gz", hash = "sha256:9b8c108b36354c04633bad0ea752b55d956a7602fe3e3234b939fc44af96f1d8"}, +] + +[package.dependencies] +aiohttp = ">=2.0.0,<4.0.0" + [[package]] name = "aiosignal" version = "1.3.1" @@ -2179,4 +2194,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more [metadata] lock-version = "2.0" python-versions = ">=3.8,<3.12" -content-hash = "68dcedf263dd15c5fdc1d83112a4ff106aeb22d44e9afdee512ed4058418c38c" +content-hash = "98c7ca5a73a64ae7a3bc6ed181aa12ab7b984a514af22073ec9ea1ecf7fd6a17" diff --git a/python/pyproject.toml b/python/pyproject.toml index 349cfa91..53677eaa 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -19,6 +19,7 @@ cosmpy = "^0.9.1" websockets = "^10.4" [tool.poetry.dev-dependencies] +aioresponses = "^0.7.4" black = "^23.1.0" pytest = "^7.1.3" pylint = "^2.15.3" diff --git a/python/src/uagents/agent.py b/python/src/uagents/agent.py index 7ab179c3..5afa6bc9 100644 --- a/python/src/uagents/agent.py +++ b/python/src/uagents/agent.py @@ -111,7 +111,7 @@ class Agent(Sink): _unsigned_message_handlers (Dict[str, MessageCallback]): Handlers for unsigned messages. _models (Dict[str, Type[Model]]): Dictionary mapping supported message digests to messages. - _replies (Dict[str, Set[Type[Model]]]): Dictionary of allowed reply digests for each type + _replies (Dict[str, Dict[str, Type[Model]]]): Dictionary of allowed replies for each type of incoming message. _queries (Dict[str, asyncio.Future]): Dictionary mapping query senders to their response Futures. @@ -149,6 +149,7 @@ def __init__( agentverse: Optional[Union[str, Dict[str, str]]] = None, mailbox: Optional[Union[str, Dict[str, str]]] = None, resolve: Optional[Resolver] = None, + max_resolver_endpoints: Optional[int] = None, version: Optional[str] = None, ): """ @@ -162,12 +163,17 @@ def __init__( agentverse (Optional[Union[str, Dict[str, str]]]): The agentverse configuration. mailbox (Optional[Union[str, Dict[str, str]]]): The mailbox configuration. resolve (Optional[Resolver]): The resolver to use for agent communication. + max_resolver_endpoints (Optional[int]): The maximum number of endpoints to resolve. version (Optional[str]): The version of the agent. """ self._name = name self._port = port if port is not None else 8000 self._background_tasks: Set[asyncio.Task] = set() - self._resolver = resolve if resolve is not None else GlobalResolver() + self._resolver = ( + resolve + if resolve is not None + else GlobalResolver(max_endpoints=max_resolver_endpoints) + ) self._loop = asyncio.get_event_loop_policy().get_event_loop() self._initialize_wallet_and_identity(seed, name) @@ -212,7 +218,7 @@ def __init__( self._signed_message_handlers: Dict[str, MessageCallback] = {} self._unsigned_message_handlers: Dict[str, MessageCallback] = {} self._models: Dict[str, Type[Model]] = {} - self._replies: Dict[str, Set[Type[Model]]] = {} + self._replies: Dict[str, Dict[str, Type[Model]]] = {} self._queries: Dict[str, asyncio.Future] = {} self._dispatcher = dispatcher self._message_queue = asyncio.Queue() diff --git a/python/src/uagents/config.py b/python/src/uagents/config.py index 6e8da0c2..f22fbfdc 100644 --- a/python/src/uagents/config.py +++ b/python/src/uagents/config.py @@ -32,6 +32,7 @@ class AgentNetwork(Enum): MAILBOX_POLL_INTERVAL_SECONDS = 1.0 DEFAULT_ENVELOPE_TIMEOUT_SECONDS = 30 +DEFAULT_MAX_ENDPOINTS = 10 DEFAULT_SEARCH_LIMIT = 100 diff --git a/python/src/uagents/context.py b/python/src/uagents/context.py index 5acfa796..d10203e3 100644 --- a/python/src/uagents/context.py +++ b/python/src/uagents/context.py @@ -6,6 +6,7 @@ import logging import uuid from dataclasses import dataclass +from enum import Enum from time import time from typing import ( Dict, @@ -44,6 +45,14 @@ EventCallback = Callable[["Context"], Awaitable[None]] +class DeliveryStatus(str, Enum): + """Delivery status of a message.""" + + SENT = "sent" + DELIVERED = "delivered" + FAILED = "failed" + + @dataclass class MsgDigest: """ @@ -58,6 +67,24 @@ class MsgDigest: schema_digest: str +@dataclass +class MsgStatus: + """ + Represents the status of a sent message. + + Attributes: + status (str): The delivery status of the message {'sent', 'delivered', 'failed'}. + detail (str): The details of the message delivery. + destination (str): The destination address of the message. + endpoint (str): The endpoint the message was sent to. + """ + + status: DeliveryStatus + detail: str + destination: str + endpoint: str + + ERROR_MESSAGE_DIGEST = Model.build_schema_digest(ErrorMessage) @@ -76,7 +103,7 @@ class Context: _queries (Dict[str, asyncio.Future]): Dictionary mapping query senders to their response Futures. _session (Optional[uuid.UUID]): The session UUID. - _replies (Optional[Dict[str, Set[Type[Model]]]]): Dictionary of allowed reply digests + _replies (Optional[Dict[str, Dict[str, Type[Model]]]]): Dictionary of allowed reply digests for each type of incoming message. _interval_messages (Optional[Set[str]]): Set of message digests that may be sent by interval tasks. @@ -115,7 +142,7 @@ def __init__( ledger: LedgerClient, queries: Dict[str, asyncio.Future], session: Optional[uuid.UUID] = None, - replies: Optional[Dict[str, Set[Type[Model]]]] = None, + replies: Optional[Dict[str, Dict[str, Type[Model]]]] = None, interval_messages: Optional[Set[str]] = None, message_received: Optional[MsgDigest] = None, protocols: Optional[Dict[str, Protocol]] = None, @@ -135,7 +162,8 @@ def __init__( queries (Dict[str, asyncio.Future]): Dictionary mapping query senders to their response Futures. session (Optional[uuid.UUID]): The optional session UUID. - replies (Optional[Dict[str, Set[Type[Model]]]]): Optional dictionary of reply models. + replies (Optional[Dict[str, Dict[str, Type[Model]]]]): Dictionary of allowed replies + for each type of incoming message. interval_messages (Optional[Set[str]]): The optional set of interval messages. message_received (Optional[MsgDigest]): The optional message digest received. protocols (Optional[Dict[str, Protocol]]): The optional dictionary of protocols. @@ -262,7 +290,7 @@ async def send( destination: str, message: Model, timeout: Optional[int] = DEFAULT_ENVELOPE_TIMEOUT_SECONDS, - ): + ) -> MsgStatus: """ Send a message to the specified destination. @@ -270,9 +298,12 @@ async def send( destination (str): The destination address to send the message to. message (Model): The message to be sent. timeout (Optional[int]): The optional timeout for sending the message, in seconds. + + Returns: + MsgStatus: The delivery status of the message. """ schema_digest = Model.build_schema_digest(message) - await self.send_raw( + return await self.send_raw( destination, message.json(), schema_digest, @@ -286,7 +317,7 @@ async def experimental_broadcast( message: Model, limit: Optional[int] = DEFAULT_SEARCH_LIMIT, timeout: Optional[int] = DEFAULT_ENVELOPE_TIMEOUT_SECONDS, - ): + ) -> List[MsgStatus]: """Broadcast a message to agents with a specific protocol. This asynchronous method broadcasts a given message to agents associated @@ -300,7 +331,7 @@ async def experimental_broadcast( timeout (int, optional): The timeout for sending each message. Returns: - None + List[MsgStatus]: A list of message delivery statuses. """ agents = self.get_agents_by_protocol(destination_protocol, limit=limit) if not agents: @@ -320,6 +351,7 @@ async def experimental_broadcast( ] ) self.logger.debug(f"Sent {len(futures)} messages") + return futures async def send_raw( self, @@ -328,16 +360,19 @@ async def send_raw( schema_digest: str, message_type: Optional[Type[Model]] = None, timeout: Optional[int] = DEFAULT_ENVELOPE_TIMEOUT_SECONDS, - ): + ) -> MsgStatus: """ Send a raw message to the specified destination. Args: - destination (str): The destination address to send the message to. + destination (str): The destination name or address to send the message to. json_message (JsonStr): The JSON-encoded message to be sent. schema_digest (str): The schema digest of the message. message_type (Optional[Type[Model]]): The optional type of the message being sent. timeout (Optional[int]): The optional timeout for sending the message, in seconds. + + Returns: + MsgStatus: The delivery status of the message. """ # Check if this message is a reply if ( @@ -353,36 +388,64 @@ async def send_raw( f"Outgoing message {message_type or ''} " f"is not a valid reply to {received.message}" ) - return - + return MsgStatus( + status=DeliveryStatus.FAILED, + detail="Invalid reply", + destination=destination, + endpoint="", + ) # Check if this message is a valid interval message if self._message_received is None and self._interval_messages: if schema_digest not in self._interval_messages: self._logger.exception( f"Outgoing message {message_type} is not a valid interval message" ) - return + return MsgStatus( + status=DeliveryStatus.FAILED, + detail="Invalid interval message", + destination=destination, + endpoint="", + ) # Handle local dispatch of messages if dispatcher.contains(destination): await dispatcher.dispatch( - self.address, destination, schema_digest, json_message, self._session + self.address, + destination, + schema_digest, + json_message, + self._session, + ) + return MsgStatus( + status=DeliveryStatus.DELIVERED, + detail="Message dispatched locally", + destination=destination, + endpoint="", ) - return # Handle queries waiting for a response if destination in self._queries: self._queries[destination].set_result((json_message, schema_digest)) del self._queries[destination] - return + return MsgStatus( + status=DeliveryStatus.DELIVERED, + detail="Sync message resolved", + destination=destination, + endpoint="", + ) - # Resolve the endpoint - destination_address, endpoint = await self._resolver.resolve(destination) - if endpoint is None: + # Resolve the destination address and endpoint ('destination' can be a name or address) + destination_address, endpoints = await self._resolver.resolve(destination) + if len(endpoints) == 0: self._logger.exception( f"Unable to resolve destination endpoint for address {destination}" ) - return + return MsgStatus( + status=DeliveryStatus.FAILED, + detail="Unable to resolve destination endpoint", + destination=destination, + endpoint="", + ) # Calculate when the envelope expires expires = int(time()) + timeout @@ -400,19 +463,38 @@ async def send_raw( env.encode_payload(json_message) env.sign(self._identity) - try: - async with aiohttp.ClientSession() as session: - async with session.post( - endpoint, - headers={"content-type": "application/json"}, - data=env.json(), - ) as resp: - success = resp.status == 200 - if not success: - self._logger.exception( - f"Unable to send envelope to {destination_address} @ {endpoint}" + for endpoint in endpoints: + try: + async with aiohttp.ClientSession() as session: + async with session.post( + endpoint, + headers={"content-type": "application/json"}, + data=env.json(), + ) as resp: + success = resp.status == 200 + if success: + return MsgStatus( + status=DeliveryStatus.DELIVERED, + detail="Message successfully delivered via HTTP", + destination=destination, + endpoint=endpoint, + ) + self._logger.warning( + f"Failed to send message to {destination_address} @ {endpoint}: " + + (await resp.text()) ) - except aiohttp.ClientConnectorError as ex: - self._logger.exception(f"Failed to connect to {endpoint}: {ex}") - except Exception as ex: - self._logger.exception(f"Failed to send message to {destination}: {ex}") + except aiohttp.ClientConnectorError as ex: + self._logger.warning(f"Failed to connect to {endpoint}: {ex}") + except Exception as ex: + self._logger.warning( + f"Failed to send message to {destination} @ {endpoint}: {ex}" + ) + + self._logger.exception(f"Failed to deliver message to {destination}") + + return MsgStatus( + status=DeliveryStatus.FAILED, + detail="Message delivery failed", + destination=destination, + endpoint="", + ) diff --git a/python/src/uagents/dispatch.py b/python/src/uagents/dispatch.py index 5aedfb3d..80079678 100644 --- a/python/src/uagents/dispatch.py +++ b/python/src/uagents/dispatch.py @@ -33,6 +33,9 @@ def register(self, address: str, sink: Sink): def unregister(self, address: str, sink: Sink): destinations = self._sinks.get(address, set()) destinations.discard(sink) + if len(destinations) == 0: + del self._sinks[address] + return self._sinks[address] = destinations def contains(self, address: str) -> bool: diff --git a/python/src/uagents/query.py b/python/src/uagents/query.py index bbacb6f6..f5bf6db5 100644 --- a/python/src/uagents/query.py +++ b/python/src/uagents/query.py @@ -44,8 +44,8 @@ async def query( schema_digest = Model.build_schema_digest(message) # resolve the endpoint - destination_address, endpoint = await resolver.resolve(destination) - if endpoint is None: + destination_address, endpoints = await resolver.resolve(destination) + if len(endpoints) == 0: LOGGER.exception( f"Unable to resolve destination endpoint for address {destination}" ) @@ -65,22 +65,30 @@ async def query( ) env.encode_payload(json_message) - async with aiohttp.ClientSession() as session: - async with session.post( - endpoint, - headers={ - "content-type": "application/json", - "x-uagents-connection": "sync", - }, - data=env.json(), - timeout=timeout, - ) as resp: - success = resp.status == 200 - - if success: - return Envelope.parse_obj(await resp.json()) - - LOGGER.exception(f"Unable to query {destination} @ {endpoint}") + for endpoint in endpoints: + try: + async with aiohttp.ClientSession() as session: + async with session.post( + endpoints[0], + headers={ + "content-type": "application/json", + "x-uagents-connection": "sync", + }, + data=env.json(), + timeout=timeout, + ) as resp: + success = resp.status == 200 + + if success: + return Envelope.parse_obj(await resp.json()) + except aiohttp.ClientConnectorError as ex: + LOGGER.warning(f"Failed to connect to {endpoint}: {ex}") + except Exception as ex: + LOGGER.warning( + f"Failed to send sync message to {destination} @ {endpoint}: {ex}" + ) + + LOGGER.exception(f"Failed to send sync message to {destination}") def enclose_response(message: Model, sender: str, session: str) -> str: diff --git a/python/src/uagents/resolver.py b/python/src/uagents/resolver.py index ec3b2b0c..63a35db8 100644 --- a/python/src/uagents/resolver.py +++ b/python/src/uagents/resolver.py @@ -1,9 +1,10 @@ """Endpoint Resolver.""" from abc import ABC, abstractmethod -from typing import Dict, Optional +from typing import Dict, List, Optional, Tuple import random +from uagents.config import DEFAULT_MAX_ENDPOINTS from uagents.network import get_almanac_contract, get_name_service_contract @@ -34,7 +35,7 @@ def get_agent_address(name: str) -> str: name (str): The name to query. Returns: - str: The associated agent address. + Optional[str]: The associated agent address if found. """ query_msg = {"domain_record": {"domain": f"{name}"}} result = get_name_service_contract().query(query_msg) @@ -42,8 +43,7 @@ def get_agent_address(name: str) -> str: registered_address = result["record"]["records"][0]["agent_address"]["records"] if len(registered_address) > 0: return registered_address[0]["address"] - return 0 - return 1 + return None def is_agent_address(address): @@ -68,53 +68,67 @@ def is_agent_address(address): class Resolver(ABC): @abstractmethod # pylint: disable=unnecessary-pass - async def resolve(self, destination: str) -> Optional[str]: + async def resolve(self, destination: str) -> Tuple[Optional[str], List[str]]: """ - Resolve the destination to an endpoint. + Resolve the destination to an address and endpoint. Args: - destination (str): The destination to resolve. + destination (str): The destination name or address to resolve. Returns: - Optional[str]: The resolved endpoint or None. + Tuple[Optional[str], List[str]]: The address (if available) and resolved endpoints. """ pass class GlobalResolver(Resolver): - async def resolve(self, destination: str) -> Optional[str]: + def __init__(self, max_endpoints: Optional[int] = None): """ - Resolve the destination using a combination of Almanac and NameService resolvers. + Initialize the GlobalResolver. Args: - destination (str): The destination to resolve. - - Returns: - Optional[str]: The resolved endpoint or None. - """ - almanac_resolver = AlmanacResolver() - name_service_resolver = NameServiceResolver() - address = ( - destination - if is_agent_address(destination) - else await name_service_resolver.resolve(destination) + max_endpoints (Optional[int]): The maximum number of endpoints to return. + """ + self._max_endpoints = max_endpoints or DEFAULT_MAX_ENDPOINTS + self._almanc_resolver = AlmanacResolver(max_endpoints=self._max_endpoints) + self._name_service_resolver = NameServiceResolver( + max_endpoints=self._max_endpoints ) - if is_agent_address(address): - return await almanac_resolver.resolve(address) - return None, None + async def resolve(self, destination: str) -> Tuple[Optional[str], List[str]]: + """ + Resolve the destination using the appropriate resolver. + + Args: + destination (str): The destination name or address to resolve. + + Returns: + Tuple[Optional[str], List[str]]: The address (if available) and resolved endpoints. + """ + if is_agent_address(destination): + return await self._almanc_resolver.resolve(destination) + return await self._name_service_resolver.resolve(destination) class AlmanacResolver(Resolver): - async def resolve(self, destination: str) -> Optional[str]: + def __init__(self, max_endpoints: Optional[int] = None): + """ + Initialize the AlmanacResolver. + + Args: + max_endpoints (Optional[int]): The maximum number of endpoints to return. + """ + self._max_endpoints = max_endpoints or DEFAULT_MAX_ENDPOINTS + + async def resolve(self, destination: str) -> Tuple[Optional[str], List[str]]: """ Resolve the destination using the Almanac contract. Args: - destination (str): The destination to resolve. + destination (str): The destination address to resolve. Returns: - Optional[str]: The resolved endpoint or None. + Tuple[str, List[str]]: The address and resolved endpoints. """ result = query_record(destination, "service") if result is not None: @@ -126,34 +140,55 @@ async def resolve(self, destination: str) -> Optional[str]: if len(endpoint_list) > 0: endpoints = [val.get("url") for val in endpoint_list] weights = [val.get("weight") for val in endpoint_list] - return destination, random.choices(endpoints, weights=weights)[0] + return destination, random.choices( + endpoints, + weights=weights, + k=min(self._max_endpoints, len(endpoints)), + ) - return None, None + return None, [] class NameServiceResolver(Resolver): - async def resolve(self, destination: str) -> Optional[str]: + def __init__(self, max_endpoints: Optional[int] = None): + """ + Initialize the NameServiceResolver. + + Args: + max_endpoints (Optional[int]): The maximum number of endpoints to return. + """ + self._max_endpoints = max_endpoints or DEFAULT_MAX_ENDPOINTS + self._almanac_resolver = AlmanacResolver(max_endpoints=self._max_endpoints) + + async def resolve(self, destination: str) -> Tuple[Optional[str], List[str]]: """ Resolve the destination using the NameService contract. Args: - destination (str): The destination to resolve. + destination (str): The destination name to resolve. Returns: - Optional[str]: The resolved endpoint or None. + Tuple[Optional[str], List[str]]: The address (if available) and resolved endpoints. """ - return get_agent_address(destination) + address = get_agent_address(destination) + if address is not None: + return await self._almanac_resolver.resolve(address) + return None, [] class RulesBasedResolver(Resolver): - def __init__(self, rules: Dict[str, str]): + def __init__( + self, rules: Dict[str, str], max_endpoints: Optional[int] = None + ) -> Tuple[Optional[str], List[str]]: """ Initialize the RulesBasedResolver with the provided rules. Args: rules (Dict[str, str]): A dictionary of rules mapping destinations to endpoints. + max_endpoints (Optional[int]): The maximum number of endpoints to return. """ self._rules = rules + self._max_endpoints = max_endpoints or DEFAULT_MAX_ENDPOINTS async def resolve(self, destination: str) -> Optional[str]: """ @@ -163,6 +198,15 @@ async def resolve(self, destination: str) -> Optional[str]: destination (str): The destination to resolve. Returns: - Optional[str]: The resolved endpoint or None. + Tuple[str, List[str]]: The address and resolved endpoints. """ - return self._rules.get(destination) + endpoints = self._rules.get(destination) + if isinstance(endpoints, str): + endpoints = [endpoints] + elif endpoints is None: + endpoints = [] + if len(endpoints) > self._max_endpoints: + endpoints = random.choices( + endpoints, k=min(self._max_endpoints, len(endpoints)) + ) + return destination, endpoints diff --git a/python/tests/test_context.py b/python/tests/test_context.py new file mode 100644 index 00000000..c57895a6 --- /dev/null +++ b/python/tests/test_context.py @@ -0,0 +1,263 @@ +# pylint: disable=protected-access +import asyncio +import unittest + +from aioresponses import aioresponses +from uagents import Agent +from uagents.context import ( + DeliveryStatus, + MsgDigest, + MsgStatus, + Identity, + Model, +) +from uagents.dispatch import dispatcher +from uagents.resolver import RulesBasedResolver + + +class Incoming(Model): + text: str + + +class Message(Model): + message: str + + +endpoints = ["http://localhost:8000"] +clyde = Agent(name="clyde", seed="clyde recovery phrase", endpoint=endpoints) +dispatcher.unregister(clyde.address, clyde) +resolver = RulesBasedResolver( + rules={ + clyde.address: endpoints, + } +) +alice = Agent(name="alice", seed="alice recovery phrase", resolve=resolver) +bob = Agent(name="bob", seed="bob recovery phrase") +incoming = Incoming(text="hello") +incoming_digest = Model.build_schema_digest(incoming) +msg = Message(message="hey") +msg_digest = Model.build_schema_digest(msg) + + +class TestContextSendMethods(unittest.IsolatedAsyncioTestCase): + def setUp(self): + self.agent = alice + self.context = self.agent._ctx + + async def test_send_local_dispatch(self): + result = await self.context.send(bob.address, msg) + exp_msg_status = MsgStatus( + status=DeliveryStatus.DELIVERED, + detail="Message dispatched locally", + destination=bob.address, + endpoint="", + ) + + self.assertEqual(result, exp_msg_status) + + async def test_send_local_dispatch_valid_reply(self): + self.context._message_received = MsgDigest( + message=incoming, schema_digest=incoming_digest + ) + self.context._replies[incoming_digest] = {msg_digest: Message} + result = await self.context.send(bob.address, msg) + exp_msg_status = MsgStatus( + status=DeliveryStatus.DELIVERED, + detail="Message dispatched locally", + destination=bob.address, + endpoint="", + ) + + self.assertEqual(result, exp_msg_status) + self.context._message_received = None + self.context._replies = {} + + async def test_send_local_dispatch_invalid_reply(self): + self.context._message_received = MsgDigest( + message=incoming, schema_digest=incoming_digest + ) + self.context._replies[incoming_digest] = {msg_digest: Message} + result = await self.context.send(bob.address, incoming) + exp_msg_status = MsgStatus( + status=DeliveryStatus.FAILED, + detail="Invalid reply", + destination=bob.address, + endpoint="", + ) + + self.assertEqual(result, exp_msg_status) + self.context._message_received = None + self.context._replies = {} + + async def test_send_local_dispatch_valid_interval_msg(self): + self.context._interval_messages = {msg_digest} + result = await self.context.send(bob.address, msg) + exp_msg_status = MsgStatus( + status=DeliveryStatus.DELIVERED, + detail="Message dispatched locally", + destination=bob.address, + endpoint="", + ) + + self.assertEqual(result, exp_msg_status) + self.context._interval_messages = set() + + async def test_send_local_dispatch_invalid_interval_msg(self): + self.context._interval_messages = {msg_digest} + result = await self.context.send(bob.address, incoming) + exp_msg_status = MsgStatus( + status=DeliveryStatus.FAILED, + detail="Invalid interval message", + destination=bob.address, + endpoint="", + ) + + self.assertEqual(result, exp_msg_status) + + async def test_send_resolve_sync_query(self): + future = asyncio.Future() + self.context._queries[clyde.address] = future + result = await self.context.send(clyde.address, msg) + exp_msg_status = MsgStatus( + status=DeliveryStatus.DELIVERED, + detail="Sync message resolved", + destination=clyde.address, + endpoint="", + ) + + self.assertEqual(future.result(), (msg.json(), msg_digest)) + self.assertEqual(result, exp_msg_status) + self.assertEqual( + len(self.context._queries), 0, "Query not removed from context" + ) + + async def test_send_external_dispatch_resolve_failure(self): + destination = Identity.generate().address + result = await self.context.send(destination, msg) + exp_msg_status = MsgStatus( + status=DeliveryStatus.FAILED, + detail="Unable to resolve destination endpoint", + destination=destination, + endpoint="", + ) + + self.assertEqual(result, exp_msg_status) + + @aioresponses() + async def test_send_external_dispatch_success(self, mocked_responses): + # Mock the HTTP POST request with a status code and response content + mocked_responses.post(endpoints[0], status=200) + + # Perform the actual operation + result = await self.context.send(clyde.address, msg) + + # Define the expected message status + exp_msg_status = MsgStatus( + status=DeliveryStatus.DELIVERED, + detail="Message successfully delivered via HTTP", + destination=clyde.address, + endpoint=endpoints[0], + ) + + # Assertions + self.assertEqual(result, exp_msg_status) + + @aioresponses() + async def test_send_external_dispatch_failure(self, mocked_responses): + # Mock the HTTP POST request with a status code and response content + mocked_responses.post(endpoints[0], status=404) + + # Perform the actual operation + result = await self.context.send(clyde.address, msg) + + # Define the expected message status + exp_msg_status = MsgStatus( + status=DeliveryStatus.FAILED, + detail="Message delivery failed", + destination=clyde.address, + endpoint="", + ) + + # Assertions + self.assertEqual(result, exp_msg_status) + + @aioresponses() + async def test_send_external_dispatch_multiple_endpoints_first_success( + self, mocked_responses + ): + endpoints.append("http://localhost:8001") + + # Mock the HTTP POST request with a status code and response content + mocked_responses.post(endpoints[0], status=200) + mocked_responses.post(endpoints[1], status=404) + + # Perform the actual operation + result = await self.context.send(clyde.address, msg) + + # Define the expected message status + exp_msg_status = MsgStatus( + status=DeliveryStatus.DELIVERED, + detail="Message successfully delivered via HTTP", + destination=clyde.address, + endpoint=endpoints[0], + ) + + # Assertions + self.assertEqual(result, exp_msg_status) + + # Ensure that only one request was sent + mocked_responses.assert_called_once() + + endpoints.pop() + + @aioresponses() + async def test_send_external_dispatch_multiple_endpoints_second_success( + self, mocked_responses + ): + endpoints.append("http://localhost:8001") + + # Mock the HTTP POST request with a status code and response content + mocked_responses.post(endpoints[0], status=404) + mocked_responses.post(endpoints[1], status=200) + + # Perform the actual operation + result = await self.context.send(clyde.address, msg) + + # Define the expected message status + exp_msg_status = MsgStatus( + status=DeliveryStatus.DELIVERED, + detail="Message successfully delivered via HTTP", + destination=clyde.address, + endpoint=endpoints[1], + ) + + # Assertions + self.assertEqual(result, exp_msg_status) + + endpoints.pop() + + @aioresponses() + async def test_send_external_dispatch_multiple_endpoints_failure( + self, mocked_responses + ): + endpoints.append("http://localhost:8001") + + # Mock the HTTP POST request with a status code and response content + mocked_responses.post(endpoints[0], status=404) + mocked_responses.post(endpoints[1], status=404) + + # Perform the actual operation + result = await self.context.send(clyde.address, msg) + + # Define the expected message status + exp_msg_status = MsgStatus( + status=DeliveryStatus.FAILED, + detail="Message delivery failed", + destination=clyde.address, + endpoint="", + ) + + # Assertions + self.assertEqual(result, exp_msg_status) + + endpoints.pop()