diff --git a/examples/memory/advanced_sqlite_session_example.py b/examples/memory/advanced_sqlite_session_example.py index fe9d3aab4..7c2ce4793 100644 --- a/examples/memory/advanced_sqlite_session_example.py +++ b/examples/memory/advanced_sqlite_session_example.py @@ -132,7 +132,7 @@ async def main(): # Show current conversation print("Current conversation:") current_items = await session.get_items() - for i, item in enumerate(current_items, 1): + for i, item in enumerate(current_items, 1): # type: ignore[assignment] role = str(item.get("role", item.get("type", "unknown"))) if item.get("type") == "function_call": content = f"{item.get('name', 'unknown')}({item.get('arguments', '{}')})" @@ -151,8 +151,8 @@ async def main(): # Show available turns for branching print("\nAvailable turns for branching:") turns = await session.get_conversation_turns() - for turn in turns: - print(f" Turn {turn['turn']}: {turn['content']}") + for turn in turns: # type: ignore[assignment] + print(f" Turn {turn['turn']}: {turn['content']}") # type: ignore[index] # Create a branch from turn 2 print("\nCreating new branch from turn 2...") @@ -163,7 +163,7 @@ async def main(): branch_items = await session.get_items() print(f"Items copied to new branch: {len(branch_items)}") print("New branch contains:") - for i, item in enumerate(branch_items, 1): + for i, item in enumerate(branch_items, 1): # type: ignore[assignment] role = str(item.get("role", item.get("type", "unknown"))) if item.get("type") == "function_call": content = f"{item.get('name', 'unknown')}({item.get('arguments', '{}')})" @@ -198,7 +198,7 @@ async def main(): print("\n=== New Conversation Branch ===") new_conversation = await session.get_items() print("New conversation with branch:") - for i, item in enumerate(new_conversation, 1): + for i, item in enumerate(new_conversation, 1): # type: ignore[assignment] role = str(item.get("role", item.get("type", "unknown"))) if item.get("type") == "function_call": content = f"{item.get('name', 'unknown')}({item.get('arguments', '{}')})" @@ -224,8 +224,8 @@ async def main(): # Show conversation turns in current branch print("\nConversation turns in current branch:") current_turns = await session.get_conversation_turns() - for turn in current_turns: - print(f" Turn {turn['turn']}: {turn['content']}") + for turn in current_turns: # type: ignore[assignment] + print(f" Turn {turn['turn']}: {turn['content']}") # type: ignore[index] print("\n=== Branch Switching Demo ===") print("We can switch back to the main branch...") diff --git a/examples/memory/dapr_session_example.py b/examples/memory/dapr_session_example.py index e0cc34d63..3a5a777a4 100644 --- a/examples/memory/dapr_session_example.py +++ b/examples/memory/dapr_session_example.py @@ -417,8 +417,8 @@ async def demonstrate_multi_store(): r_items = await redis_session.get_items() p_items = await pg_session.get_items() - r_example = r_items[-1]["content"] if r_items else "empty" - p_example = p_items[-1]["content"] if p_items else "empty" + r_example = r_items[-1]["content"] if r_items else "empty" # type: ignore[typeddict-item] + p_example = p_items[-1]["content"] if p_items else "empty" # type: ignore[typeddict-item] print(f"{redis_store}: {len(r_items)} items; example: {r_example}") print(f"{pg_store}: {len(p_items)} items; example: {p_example}") diff --git a/src/agents/extensions/memory/__init__.py b/src/agents/extensions/memory/__init__.py index 1f3f1e379..68e21a05f 100644 --- a/src/agents/extensions/memory/__init__.py +++ b/src/agents/extensions/memory/__init__.py @@ -8,7 +8,18 @@ from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from .advanced_sqlite_session import AdvancedSQLiteSession + from .dapr_session import ( + DAPR_CONSISTENCY_EVENTUAL, + DAPR_CONSISTENCY_STRONG, + DaprSession, + ) + from .encrypt_session import EncryptedSession + from .redis_session import RedisSession + from .sqlalchemy_session import SQLAlchemySession __all__: list[str] = [ "AdvancedSQLiteSession", diff --git a/tests/extensions/memory/test_dapr_session.py b/tests/extensions/memory/test_dapr_session.py index fef3df782..26e8743b2 100644 --- a/tests/extensions/memory/test_dapr_session.py +++ b/tests/extensions/memory/test_dapr_session.py @@ -182,7 +182,7 @@ async def _create_test_session( session = DaprSession( session_id=session_id, state_store_name="statestore", - dapr_client=fake_dapr_client, + dapr_client=fake_dapr_client, # type: ignore[arg-type] ) # Clean up any existing data @@ -260,12 +260,12 @@ async def test_session_isolation(fake_dapr_client: FakeDaprClient): session1 = DaprSession( session_id="session_1", state_store_name="statestore", - dapr_client=fake_dapr_client, + dapr_client=fake_dapr_client, # type: ignore[arg-type] ) session2 = DaprSession( session_id="session_2", state_store_name="statestore", - dapr_client=fake_dapr_client, + dapr_client=fake_dapr_client, # type: ignore[arg-type] ) try: @@ -386,7 +386,7 @@ async def test_pop_from_empty_session(fake_dapr_client: FakeDaprClient): session = DaprSession( session_id="empty_session", state_store_name="statestore", - dapr_client=fake_dapr_client, + dapr_client=fake_dapr_client, # type: ignore[arg-type] ) try: await session.clear_session() @@ -540,7 +540,7 @@ async def test_dapr_connectivity(fake_dapr_client: FakeDaprClient): session = DaprSession( session_id="connectivity_test", state_store_name="statestore", - dapr_client=fake_dapr_client, + dapr_client=fake_dapr_client, # type: ignore[arg-type] ) try: # Test ping @@ -555,7 +555,7 @@ async def test_ttl_functionality(fake_dapr_client: FakeDaprClient): session = DaprSession( session_id="ttl_test", state_store_name="statestore", - dapr_client=fake_dapr_client, + dapr_client=fake_dapr_client, # type: ignore[arg-type] ttl=3600, # 1 hour TTL ) @@ -586,7 +586,7 @@ async def test_consistency_levels(fake_dapr_client: FakeDaprClient): session_eventual = DaprSession( session_id="eventual_test", state_store_name="statestore", - dapr_client=fake_dapr_client, + dapr_client=fake_dapr_client, # type: ignore[arg-type] consistency=DAPR_CONSISTENCY_EVENTUAL, ) @@ -594,7 +594,7 @@ async def test_consistency_levels(fake_dapr_client: FakeDaprClient): session_strong = DaprSession( session_id="strong_test", state_store_name="statestore", - dapr_client=fake_dapr_client, + dapr_client=fake_dapr_client, # type: ignore[arg-type] consistency=DAPR_CONSISTENCY_STRONG, ) @@ -621,7 +621,7 @@ async def test_external_client_not_closed(fake_dapr_client: FakeDaprClient): session = DaprSession( session_id="external_client_test", state_store_name="statestore", - dapr_client=fake_dapr_client, + dapr_client=fake_dapr_client, # type: ignore[arg-type] ) try: @@ -650,7 +650,7 @@ async def test_internal_client_ownership(fake_dapr_client: FakeDaprClient): session = DaprSession( session_id="internal_client_test", state_store_name="statestore", - dapr_client=fake_dapr_client, + dapr_client=fake_dapr_client, # type: ignore[arg-type] ) session._owns_client = True # Simulate ownership @@ -732,7 +732,7 @@ async def test_close_method_coverage(fake_dapr_client: FakeDaprClient): session1 = DaprSession( session_id="close_test_1", state_store_name="statestore", - dapr_client=fake_dapr_client, + dapr_client=fake_dapr_client, # type: ignore[arg-type] ) # Verify _owns_client is False for external client @@ -749,7 +749,7 @@ async def test_close_method_coverage(fake_dapr_client: FakeDaprClient): session2 = DaprSession( session_id="close_test_2", state_store_name="statestore", - dapr_client=fake_dapr_client2, + dapr_client=fake_dapr_client2, # type: ignore[arg-type] ) session2._owns_client = True # Simulate ownership @@ -788,8 +788,8 @@ async def test_already_deserialized_messages(fake_dapr_client: FakeDaprClient): # Should handle both string and dict messages items = await session.get_items() assert len(items) == 2 - assert items[0]["content"] == "First message" - assert items[1]["content"] == "Second message" + assert items[0]["content"] == "First message" # type: ignore[typeddict-item] + assert items[1]["content"] == "Second message" # type: ignore[typeddict-item] await session.close() @@ -800,7 +800,7 @@ async def test_context_manager(fake_dapr_client: FakeDaprClient): async with DaprSession( "test_cm_session", state_store_name="statestore", - dapr_client=fake_dapr_client, + dapr_client=fake_dapr_client, # type: ignore[arg-type] ) as session: # Verify we got the session object back assert session.session_id == "test_cm_session" @@ -809,7 +809,7 @@ async def test_context_manager(fake_dapr_client: FakeDaprClient): await session.add_items([{"role": "user", "content": "Test message"}]) items = await session.get_items() assert len(items) == 1 - assert items[0]["content"] == "Test message" + assert items[0]["content"] == "Test message" # type: ignore[typeddict-item] # After exiting context manager, close should have been called # Verify we can still check the state (fake client doesn't truly disconnect) @@ -819,7 +819,7 @@ async def test_context_manager(fake_dapr_client: FakeDaprClient): owned_session = DaprSession( "test_cm_owned", state_store_name="statestore", - dapr_client=fake_dapr_client, + dapr_client=fake_dapr_client, # type: ignore[arg-type] ) # Manually set ownership to simulate from_address behavior owned_session._owns_client = True