Skip to content

Commit

Permalink
Fix pyright
Browse files Browse the repository at this point in the history
  • Loading branch information
jackgerrits committed Jan 6, 2025
1 parent a418be7 commit 8447511
Showing 1 changed file with 18 additions and 13 deletions.
31 changes: 18 additions & 13 deletions python/packages/autogen-core/tests/test_intervention.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from typing import Any

import pytest
from autogen_core import AgentId, DefaultInterventionHandler, DropMessage, MessageContext, SingleThreadedAgentRuntime
from autogen_core._default_subscription import DefaultSubscription
from autogen_core._default_topic import DefaultTopicId
from autogen_core._type_subscription import TypeSubscription
from autogen_core import (
AgentId,
DefaultInterventionHandler,
DefaultSubscription,
DefaultTopicId,
DropMessage,
MessageContext,
SingleThreadedAgentRuntime,
)
from autogen_core.exceptions import MessageDroppedException
from autogen_test_utils import LoopbackAgent, MessageType

Expand All @@ -13,20 +18,20 @@
async def test_intervention_count_messages() -> None:
class DebugInterventionHandler(DefaultInterventionHandler):
def __init__(self) -> None:
self._num_send_messages = 0
self._num_publish_messages = 0
self._num_response_messages = 0
self.num_send_messages = 0
self.num_publish_messages = 0
self.num_response_messages = 0

async def on_send(self, message: Any, *, message_context: MessageContext, recipient: AgentId) -> Any:
self._num_send_messages += 1
self.num_send_messages += 1
return message

async def on_publish(self, message: Any, *, message_context: MessageContext) -> Any:
self._num_publish_messages += 1
self.num_publish_messages += 1
return message

async def on_response(self, message: Any, *, sender: AgentId, recipient: AgentId | None) -> Any:
self._num_response_messages += 1
self.num_response_messages += 1
return message

handler = DebugInterventionHandler()
Expand All @@ -39,8 +44,8 @@ async def on_response(self, message: Any, *, sender: AgentId, recipient: AgentId

await runtime.stop_when_idle()

assert handler._num_send_messages == 1
assert handler._num_response_messages == 1
assert handler.num_send_messages == 1
assert handler.num_response_messages == 1
loopback_agent = await runtime.try_get_underlying_agent_instance(loopback, type=LoopbackAgent)
assert loopback_agent.num_calls == 1

Expand All @@ -51,7 +56,7 @@ async def on_response(self, message: Any, *, sender: AgentId, recipient: AgentId

await runtime.stop_when_idle()
assert loopback_agent.num_calls == 2
assert handler._num_publish_messages == 1
assert handler.num_publish_messages == 1


@pytest.mark.asyncio
Expand Down

0 comments on commit 8447511

Please sign in to comment.