diff --git a/python/packages/core/agent_framework/__init__.py b/python/packages/core/agent_framework/__init__.py index d981080696..7f0eed7203 100644 --- a/python/packages/core/agent_framework/__init__.py +++ b/python/packages/core/agent_framework/__init__.py @@ -87,6 +87,15 @@ MemoryStore, MemoryTopicRecord, ) +from ._harness._todo import ( + DEFAULT_TODO_SOURCE_ID, + TodoFileStore, + TodoInput, + TodoItem, + TodoProvider, + TodoSessionStore, + TodoStore, +) from ._mcp import MCPStdioTool, MCPStreamableHTTPTool, MCPWebsocketTool from ._middleware import ( AgentContext, @@ -270,6 +279,7 @@ "COMPACTION_STATE_KEY", "DEFAULT_MAX_ITERATIONS", "DEFAULT_MEMORY_SOURCE_ID", + "DEFAULT_TODO_SOURCE_ID", "EXCLUDED_KEY", "EXCLUDE_REASON_KEY", "GROUP_ANNOTATION_KEY", @@ -410,6 +420,12 @@ "SwitchCaseEdgeGroupCase", "SwitchCaseEdgeGroupDefault", "TextSpanRegion", + "TodoFileStore", + "TodoInput", + "TodoItem", + "TodoProvider", + "TodoSessionStore", + "TodoStore", "TokenBudgetComposedStrategy", "TokenizerProtocol", "ToolMode", diff --git a/python/packages/core/agent_framework/_harness/_todo.py b/python/packages/core/agent_framework/_harness/_todo.py new file mode 100644 index 0000000000..c0c9c89791 --- /dev/null +++ b/python/packages/core/agent_framework/_harness/_todo.py @@ -0,0 +1,549 @@ +# Copyright (c) Microsoft. All rights reserved. + +from __future__ import annotations + +import asyncio +import json +import os +import weakref +from abc import ABC, abstractmethod +from base64 import urlsafe_b64encode +from collections.abc import Mapping, MutableMapping +from pathlib import Path +from typing import Any, ClassVar, cast + +from .._feature_stage import ExperimentalFeature, experimental +from .._serialization import SerializationMixin +from .._sessions import AgentSession, ContextProvider, SessionContext +from .._tools import tool +from .._types import Message + +DEFAULT_TODO_SOURCE_ID = "todo" +DEFAULT_TODO_INSTRUCTIONS = ( + "## Todo Items\n\n" + "You have access to a todo list for tracking work items.\n" + "While planning, make sure that you break down complex tasks into manageable todo items " + "and add them to the list.\n" + "Ask questions from the user where clarification is needed to create effective todos.\n" + "If the user provides feedback on your plan, adjust your todos accordingly by adding new items " + "or removing irrelevant ones.\n" + "During execution, use the todo list to keep track of what needs to be done, " + "mark items as complete when finished, and remove any items that are no longer needed.\n" + "When a user changes the topic or changes their mind, ensure that you update the todo list accordingly " + "by removing irrelevant items or adding new ones as needed.\n\n" + "Use these tools to manage your tasks:\n" + "- Use add_todos to break down complex work into trackable items (supports adding one or many at once).\n" + "- Use complete_todos to mark items as done when finished (supports one or many at once).\n" + "- Use get_remaining_todos to check what work is still pending.\n" + "- Use get_all_todos to review the full list including completed items.\n" + "- Use remove_todos to remove items that are no longer needed (supports one or many at once)." +) + + +@experimental(feature_id=ExperimentalFeature.HARNESS) +class TodoItem(SerializationMixin): + """Represent one todo item tracked for the current session.""" + + id: int + title: str + description: str | None + is_complete: bool + __slots__ = ("description", "id", "is_complete", "title") + + def __init__(self, id: int, title: str, description: str | None = None, is_complete: bool = False) -> None: + """Initialize one todo item.""" + self.id = id + self.title = title + self.description = description + self.is_complete = is_complete + + def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]: + """Serialize the todo item for persistence.""" + del exclude + payload = { + "id": self.id, + "title": self.title, + "description": self.description, + "is_complete": self.is_complete, + } + return {key: value for key, value in payload.items() if value is not None or not exclude_none} + + @classmethod + def from_dict( + cls, raw_item: MutableMapping[str, Any], /, *, dependencies: MutableMapping[str, Any] | None = None + ) -> TodoItem: + """Parse one todo item loaded from storage.""" + del dependencies + item_id = raw_item.get("id") + title = raw_item.get("title") + description = raw_item.get("description") + is_complete = raw_item.get("is_complete", False) + if not isinstance(item_id, int): + raise ValueError("Todo item id must be an integer.") + if not isinstance(title, str) or not title.strip(): + raise ValueError("Todo item title must be a non-empty string.") + if description is not None and not isinstance(description, str): + raise ValueError("Todo item description must be a string or null.") + if not isinstance(is_complete, bool): + raise ValueError("Todo item is_complete must be a boolean.") + return cls(id=item_id, title=title, description=description, is_complete=is_complete) + + def __eq__(self, other: object) -> bool: + """Return whether two todo items have the same values.""" + return isinstance(other, TodoItem) and self.to_dict() == other.to_dict() + + def __repr__(self) -> str: + """Return a helpful debug representation.""" + return ( + "TodoItem(" + f"id={self.id!r}, title={self.title!r}, description={self.description!r}, is_complete={self.is_complete!r})" + ) + + +@experimental(feature_id=ExperimentalFeature.HARNESS) +class TodoInput(SerializationMixin): + """Describe one todo item to create.""" + + title: str + description: str | None + __slots__ = ("description", "title") + + def __init__(self, title: str, description: str | None = None) -> None: + """Initialize one todo input.""" + normalized_title = title.strip() + if not normalized_title: + raise ValueError("Todo input title must be a non-empty string.") + if description is not None and not isinstance(description, str): + raise ValueError("Todo input description must be a string or null.") + self.title = normalized_title + self.description = description + + def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]: + """Serialize the todo input.""" + del exclude + payload = {"title": self.title, "description": self.description} + return {key: value for key, value in payload.items() if value is not None or not exclude_none} + + @classmethod + def from_dict( + cls, raw_todo: MutableMapping[str, Any], /, *, dependencies: MutableMapping[str, Any] | None = None + ) -> TodoInput: + """Parse one todo input loaded from tool arguments.""" + del dependencies + title = raw_todo.get("title") + description = raw_todo.get("description") + if not isinstance(title, str): + raise ValueError("Todo input title must be a string.") + return cls(title=title, description=description) + + +def _parse_todo_items(items_payload: list[Any], *, source_description: str) -> list[TodoItem]: + """Parse persisted todo item payloads with clear corruption errors.""" + items: list[TodoItem] = [] + for index, item in enumerate(items_payload): + if not isinstance(item, Mapping): + raise ValueError( + f"Todo item at index {index} in {source_description} must be a mapping; got {type(item).__name__}." + ) + items.append(TodoItem.from_dict(dict(cast(Mapping[str, Any], item)))) + return items + + +def _coerce_todo_input(todo: TodoInput | dict[str, Any] | Any) -> TodoInput: + """Normalize tool-provided todo input into a TodoInput model.""" + if isinstance(todo, TodoInput): + return todo + if isinstance(todo, MutableMapping): + return TodoInput.from_dict(cast(MutableMapping[str, Any], todo)) + raise ValueError("Todo input must be a TodoInput instance or JSON object.") + + +def _safe_next_id(items: list[TodoItem], next_id: int) -> int: + """Clamp ``next_id`` so it cannot collide with any persisted item id.""" + return max(next_id, max((item.id for item in items), default=0) + 1) + + +@experimental(feature_id=ExperimentalFeature.HARNESS) +class TodoStore(ABC): + """Abstract backing store for session todo items.""" + + @abstractmethod + async def load_state(self, session: AgentSession, *, source_id: str) -> tuple[list[TodoItem], int]: + """Load persisted todo items and the next available ID.""" + + @abstractmethod + async def save_state(self, session: AgentSession, items: list[TodoItem], *, next_id: int, source_id: str) -> None: + """Persist todo items and the next available ID.""" + + async def load_items(self, session: AgentSession, *, source_id: str) -> list[TodoItem]: + """Load todo items for one session.""" + items, _ = await self.load_state(session, source_id=source_id) + return items + + +@experimental(feature_id=ExperimentalFeature.HARNESS) +class TodoSessionStore(TodoStore): + """Store todo state inside ``AgentSession.state``.""" + + async def load_state(self, session: AgentSession, *, source_id: str) -> tuple[list[TodoItem], int]: + """Load todo state from session state.""" + provider_state_value = session.state.get(source_id) + if provider_state_value is None: + provider_state: dict[str, Any] = {} + session.state[source_id] = provider_state + elif isinstance(provider_state_value, dict): + provider_state = cast(dict[str, Any], provider_state_value) + else: + raise ValueError( + f"Session state for source_id {source_id!r} must be a dict; got {type(provider_state_value).__name__}." + ) + + raw_items = provider_state.get("items", []) + if not isinstance(raw_items, list): + raise ValueError( + f"Session state for source_id {source_id!r} has a non-list 'items' field; " + f"got {type(raw_items).__name__}." + ) + raw_next_id = provider_state.get("next_id", 1) + if not isinstance(raw_next_id, int): + raise ValueError( + f"Session state for source_id {source_id!r} has a non-integer 'next_id' field; " + f"got {type(raw_next_id).__name__}." + ) + items_payload: list[Any] = cast(Any, raw_items) + items = _parse_todo_items(items_payload, source_description="session todo state") + return items, _safe_next_id(items, raw_next_id) + + async def save_state(self, session: AgentSession, items: list[TodoItem], *, next_id: int, source_id: str) -> None: + """Persist todo state back into session state.""" + provider_state_value = session.state.get(source_id) + provider_state = cast(dict[str, Any], provider_state_value) if isinstance(provider_state_value, dict) else {} + if not isinstance(provider_state_value, dict): + session.state[source_id] = provider_state + provider_state["items"] = [item.to_dict(exclude_none=False) for item in items] + provider_state["next_id"] = _safe_next_id(items, next_id) + + +@experimental(feature_id=ExperimentalFeature.HARNESS) +class TodoFileStore(TodoStore): + """Store todo state in one JSON file per session and source ID.""" + + def __init__( + self, + base_path: str | Path, + *, + kind: str = "todos", + owner_prefix: str = "", + owner_state_key: str | None = None, + state_filename: str = "todos.json", + ) -> None: + """Initialize the file-backed todo store. + + Args: + base_path: Root storage directory. + + Keyword Args: + kind: Storage bucket name under each owner directory. + owner_prefix: Optional prefix applied to the resolved owner ID. + owner_state_key: Session-state key holding the logical owner ID. + state_filename: File name used for the persisted todo state. + """ + self.base_path = Path(base_path) + self.kind = kind + self.owner_prefix = owner_prefix + self.owner_state_key = owner_state_key + self.state_filename = state_filename + self._base_root = self.base_path.resolve() + + _ENCODED_SEGMENT_PREFIX: ClassVar[str] = "~todo-" + _WINDOWS_RESERVED_FILE_STEMS: ClassVar[frozenset[str]] = frozenset({ + "CON", + "PRN", + "AUX", + "NUL", + "COM1", + "COM2", + "COM3", + "COM4", + "COM5", + "COM6", + "COM7", + "COM8", + "COM9", + "LPT1", + "LPT2", + "LPT3", + "LPT4", + "LPT5", + "LPT6", + "LPT7", + "LPT8", + "LPT9", + }) + + def _get_state_path(self, session: AgentSession, *, source_id: str) -> Path: + """Return the JSON file path for one session and source ID.""" + session_directory = self.base_path + if self.owner_state_key is not None: + owner_value = session.state.get(self.owner_state_key) + if owner_value is None: + raise RuntimeError( + f"TodoFileStore requires session.state[{self.owner_state_key!r}] to be set for file-backed storage." + ) + owner_segment = self._path_segment(owner_value, label="owner") + session_directory = session_directory / f"{self.owner_prefix}{owner_segment}" / self.kind + session_directory = session_directory / self._path_segment( + session.session_id, label="session_id", reject_path_separators=True + ) + state_path = (session_directory / self._state_filename(source_id)).resolve() + if not state_path.is_relative_to(self._base_root): + raise ValueError(f"Todo file path escaped base directory for session_id {session.session_id!r}.") + return state_path + + @classmethod + def _path_segment(cls, value: object, *, label: str, reject_path_separators: bool = False) -> str: + """Return a filesystem-safe path segment for user-controlled state values.""" + raw_value = str(value) + if reject_path_separators and ("/" in raw_value or "\\" in raw_value): + raise ValueError(f"TodoFileStore {label} must not contain path separators: {raw_value!r}") + if cls._is_literal_path_segment_safe(raw_value): + return raw_value + encoded_value = urlsafe_b64encode(raw_value.encode("utf-8")).decode("ascii").rstrip("=") + return f"{cls._ENCODED_SEGMENT_PREFIX}{encoded_value or label}" + + @classmethod + def _is_literal_path_segment_safe(cls, value: str) -> bool: + """Return whether a value can be used directly as one path segment.""" + if ( + not value + or value.startswith(".") + or value.endswith((" ", ".")) + or value.upper() in cls._WINDOWS_RESERVED_FILE_STEMS + ): + return False + if any(ord(character) < 32 for character in value): + return False + return all(character.isalnum() or character in "._-" for character in value) + + def _state_filename(self, source_id: str) -> str: + """Return a source-specific JSON state filename.""" + state_path = Path(self.state_filename) + source_segment = self._path_segment(source_id, label="source_id") + if state_path.suffix: + return f"{state_path.stem}.{source_segment}{state_path.suffix}" + return f"{state_path.name}.{source_segment}.json" + + async def load_state(self, session: AgentSession, *, source_id: str) -> tuple[list[TodoItem], int]: + """Load todo state from disk.""" + state_path = self._get_state_path(session, source_id=source_id) + return await asyncio.to_thread(self._load_state_sync, state_path) + + @staticmethod + def _load_state_sync(state_path: Path) -> tuple[list[TodoItem], int]: + """Synchronous helper that performs the disk I/O for ``load_state``.""" + if not state_path.exists(): + return [], 1 + payload = cast(dict[str, Any], json.loads(state_path.read_text(encoding="utf-8"))) + if not isinstance(payload, dict): + raise ValueError(f"Todo file {state_path} must contain a JSON object.") + raw_items = payload.get("items", []) + raw_next_id = payload.get("next_id", 1) + if not isinstance(raw_items, list): + raise ValueError(f"Todo file {state_path} has a non-list 'items' field.") + if not isinstance(raw_next_id, int): + raise ValueError(f"Todo file {state_path} has a non-integer 'next_id' field.") + items_payload: list[Any] = cast(Any, raw_items) + items = _parse_todo_items(items_payload, source_description=f"todo file {state_path}") + return items, _safe_next_id(items, raw_next_id) + + async def save_state(self, session: AgentSession, items: list[TodoItem], *, next_id: int, source_id: str) -> None: + """Persist todo state to disk.""" + state_path = self._get_state_path(session, source_id=source_id) + payload = ( + json.dumps({ + "items": [item.to_dict(exclude_none=False) for item in items], + "next_id": _safe_next_id(items, next_id), + }) + + "\n" + ) + await asyncio.to_thread(self._save_state_sync, state_path, payload) + + @staticmethod + def _save_state_sync(state_path: Path, payload: str) -> None: + """Synchronous helper that atomically writes the JSON state file.""" + state_path.parent.mkdir(parents=True, exist_ok=True) + # Write to a sibling temp file then atomically replace, so a crash mid-write cannot leave + # a truncated state file that breaks every subsequent tool call. + temp_path = state_path.with_name(f"{state_path.name}.tmp.{os.getpid()}") + try: + temp_path.write_text(payload, encoding="utf-8") + os.replace(temp_path, state_path) + finally: + if temp_path.exists(): + temp_path.unlink(missing_ok=True) + + +@experimental(feature_id=ExperimentalFeature.HARNESS) +class TodoProvider(ContextProvider): + """Provide todo management tools and instructions to an agent. + + The ``TodoProvider`` enables agents to create, complete, remove, and query todo items as part of their planning + and execution workflow. Todo state is stored in the configured ``TodoStore`` and persists across agent invocations + within the same session. By default, state is stored in ``AgentSession.state`` with ``TodoSessionStore``; callers + can provide ``TodoFileStore`` or another store implementation for file-backed or custom persistence. + + This provider exposes the following tools to the agent: + - ``add_todos``: Add one or more todo items, each with a title and optional description. + - ``complete_todos``: Mark one or more todo items as complete by their IDs. + - ``remove_todos``: Remove one or more todo items by their IDs. + - ``get_remaining_todos``: Retrieve only incomplete todo items. + - ``get_all_todos``: Retrieve all todo items, complete and incomplete. + """ + + def __init__( + self, + source_id: str = DEFAULT_TODO_SOURCE_ID, + *, + instructions: str | None = None, + store: TodoStore | None = None, + ) -> None: + """Initialize the todo provider. + + Args: + source_id: Unique source ID for the provider. + + Keyword Args: + instructions: Optional instruction override. + store: Optional todo store override. + """ + super().__init__(source_id) + self.instructions = instructions or DEFAULT_TODO_INSTRUCTIONS + self.store = store or TodoSessionStore() + # WeakKeyDictionary so per-session locks are evicted automatically when the session is GC'd + # rather than accumulating in long-running services that create many sessions. + self._mutation_locks: weakref.WeakKeyDictionary[AgentSession, asyncio.Lock] = weakref.WeakKeyDictionary() + + def _mutation_lock(self, session: AgentSession) -> asyncio.Lock: + """Return the per-session lock for read-modify-write todo operations.""" + lock = self._mutation_locks.get(session) + if lock is None: + lock = asyncio.Lock() + self._mutation_locks[session] = lock + return lock + + async def before_run( + self, + *, + agent: Any, + session: AgentSession, + context: SessionContext, + state: dict[str, Any], + ) -> None: + """Inject todo tools and instructions before the model runs.""" + del agent, state + + @tool(name="add_todos", approval_mode="never_require") + async def add_todos(todos: list[dict[str, Any]]) -> str: + """Add one or more todo items for the current session.""" + if not todos: + raise ValueError("todos must contain at least one item.") + + async with self._mutation_lock(session): + existing_items, next_id = await self.store.load_state(session, source_id=self.source_id) + created_items: list[TodoItem] = [] + for raw_todo in todos: + todo = _coerce_todo_input(raw_todo) + created_item = TodoItem( + id=next_id, + title=todo.title, + description=todo.description.strip() if todo.description is not None else None, + ) + existing_items.append(created_item) + created_items.append(created_item) + next_id += 1 + + await self.store.save_state(session, existing_items, next_id=next_id, source_id=self.source_id) + return json.dumps([item.to_dict(exclude_none=False) for item in created_items]) + + @tool(name="complete_todos", approval_mode="never_require") + async def complete_todos(ids: list[int]) -> str: + """Mark one or more todo items as complete by ID.""" + if not ids: + raise ValueError("ids must contain at least one todo ID.") + + async with self._mutation_lock(session): + items, next_id = await self.store.load_state(session, source_id=self.source_id) + id_set = set(ids) + completed_count = 0 + updated_items: list[TodoItem] = [] + for item in items: + if not item.is_complete and item.id in id_set: + updated_items.append( + TodoItem( + id=item.id, + title=item.title, + description=item.description, + is_complete=True, + ) + ) + completed_count += 1 + else: + updated_items.append(item) + + if completed_count: + await self.store.save_state(session, updated_items, next_id=next_id, source_id=self.source_id) + return json.dumps({"completed": completed_count}) + + @tool(name="remove_todos", approval_mode="never_require") + async def remove_todos(ids: list[int]) -> str: + """Remove one or more todo items by ID.""" + if not ids: + raise ValueError("ids must contain at least one todo ID.") + + async with self._mutation_lock(session): + items, next_id = await self.store.load_state(session, source_id=self.source_id) + remaining_items = [item for item in items if item.id not in set(ids)] + removed_count = len(items) - len(remaining_items) + if removed_count: + await self.store.save_state(session, remaining_items, next_id=next_id, source_id=self.source_id) + return json.dumps({"removed": removed_count}) + + @tool(name="get_remaining_todos", approval_mode="never_require") + async def get_remaining_todos() -> str: + """Retrieve only incomplete todo items for the current session.""" + items = [ + item for item in await self.store.load_items(session, source_id=self.source_id) if not item.is_complete + ] + return json.dumps([item.to_dict(exclude_none=False) for item in items]) + + @tool(name="get_all_todos", approval_mode="never_require") + async def get_all_todos() -> str: + """Retrieve all todo items for the current session.""" + items = await self.store.load_items(session, source_id=self.source_id) + return json.dumps([item.to_dict(exclude_none=False) for item in items]) + + context.extend_instructions(self.source_id, [self.instructions]) + context.extend_tools( + self.source_id, + [add_todos, complete_todos, remove_todos, get_remaining_todos, get_all_todos], + ) + current_items = await self.store.load_items(session, source_id=self.source_id) + context.extend_messages( + self.source_id, + [ + Message( + role="user", + contents=[ + "### Current todo list\n" + + ( + "\n".join( + f"- {item.id} [{'done' if item.is_complete else 'open'}] {item.title}" + + (f": {item.description}" if item.description else "") + for item in current_items + ) + or "- none yet" + ) + ], + ) + ], + ) diff --git a/python/packages/core/tests/core/test_harness_todo.py b/python/packages/core/tests/core/test_harness_todo.py new file mode 100644 index 0000000000..068c83142d --- /dev/null +++ b/python/packages/core/tests/core/test_harness_todo.py @@ -0,0 +1,377 @@ +# Copyright (c) Microsoft. All rights reserved. + +from __future__ import annotations + +import asyncio +import json +import os +from pathlib import Path + +import pytest + +from agent_framework import ( + Agent, + AgentSession, + ExperimentalFeature, + Message, + SupportsChatGetResponse, + TodoFileStore, + TodoInput, + TodoItem, + TodoProvider, + TodoSessionStore, + TodoStore, +) + + +def _tool_by_name(tools: list[object], name: str) -> object: + """Return the tool with the requested name from a prepared tool list.""" + for tool in tools: + if getattr(tool, "name", None) == name: + return tool + raise AssertionError(f"Tool {name!r} was not found.") + + +def test_todo_item_round_trips_with_value_equality() -> None: + """Todo items should support value equality and JSON serialization.""" + raw_item = { + "id": 1, + "title": "Write tests", + "description": "Cover the harness", + "is_complete": False, + } + + item = TodoItem.from_dict(raw_item) + + assert item == TodoItem(**raw_item) + assert item.to_dict() == raw_item + assert json.loads(item.to_json()) == raw_item + assert "TodoItem(" in repr(item) + + +def test_todo_input_round_trips_and_validates() -> None: + """Todo input should trim titles and reject invalid payloads.""" + todo_input = TodoInput.from_dict({"title": " Write tests ", "description": "Cover the harness"}) + + assert todo_input.title == "Write tests" + assert todo_input.to_dict() == {"title": "Write tests", "description": "Cover the harness"} + assert json.loads(todo_input.to_json()) == {"title": "Write tests", "description": "Cover the harness"} + + with pytest.raises(ValueError, match="non-empty string"): + TodoInput(title=" ") + + with pytest.raises(ValueError, match="description must be a string or null"): + TodoInput.from_dict({"title": "Write tests", "description": 123}) + + +async def test_todo_session_store_initializes_and_round_trips_state() -> None: + """Session-backed todo storage should initialize and persist todo state.""" + session = AgentSession(session_id="session-1") + store = TodoSessionStore() + + items, next_id = await store.load_state(session, source_id="todo") + assert items == [] + assert next_id == 1 + assert session.state["todo"] == {} + + todo_item = TodoItem(id=1, title="Ship feature", description="Use session storage") + await store.save_state(session, [todo_item], next_id=2, source_id="todo") + + loaded_items, loaded_next_id = await store.load_state(session, source_id="todo") + assert loaded_items == [todo_item] + assert loaded_next_id == 2 + assert await store.load_items(session, source_id="todo") == [todo_item] + + +async def test_todo_file_store_round_trips_state(tmp_path: Path) -> None: + """Todo file storage should persist one JSON state file per owner and session.""" + session = AgentSession(session_id="session-1") + session.state["owner_id"] = "alice" + store = TodoFileStore( + tmp_path, + kind="todos", + owner_prefix="user_", + owner_state_key="owner_id", + ) + + await store.save_state( + session, + [TodoItem(id=1, title="Ship feature", description="Use file storage")], + next_id=2, + source_id="todo", + ) + + items, next_id = await store.load_state(session, source_id="todo") + assert items == [TodoItem(id=1, title="Ship feature", description="Use file storage", is_complete=False)] + assert next_id == 2 + + state_path = tmp_path / "user_alice" / "todos" / "session-1" / "todos.todo.json" + assert state_path.exists() + assert json.loads(state_path.read_text(encoding="utf-8")) == { + "items": [{"id": 1, "title": "Ship feature", "description": "Use file storage", "is_complete": False}], + "next_id": 2, + } + + with pytest.raises(RuntimeError, match="owner_id"): + await store.load_state(AgentSession(session_id="missing-owner"), source_id="todo") + + +async def test_todo_file_store_load_does_not_create_directories(tmp_path: Path) -> None: + """Loading from a never-written session must not create empty directories on disk.""" + session = AgentSession(session_id="session-1") + store = TodoFileStore(tmp_path) + + items, next_id = await store.load_state(session, source_id="todo") + assert items == [] + assert next_id == 1 + assert list(tmp_path.iterdir()) == [] # noqa: ASYNC240 + + +async def test_todo_file_store_writes_state_atomically(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + """A crash between writing the temp file and renaming must not corrupt existing state.""" + session = AgentSession(session_id="session-1") + store = TodoFileStore(tmp_path) + + await store.save_state(session, [TodoItem(id=1, title="Initial")], next_id=2, source_id="todo") + state_path = tmp_path / "session-1" / "todos.todo.json" + original_contents = state_path.read_text(encoding="utf-8") + + def _boom(*args: object, **kwargs: object) -> None: + raise OSError("disk full") + + monkeypatch.setattr(os, "replace", _boom) + + with pytest.raises(OSError, match="disk full"): + await store.save_state(session, [TodoItem(id=2, title="Replacement")], next_id=3, source_id="todo") + + # Original file is untouched, no temp leftovers. + assert state_path.read_text(encoding="utf-8") == original_contents + assert sorted(p.name for p in state_path.parent.iterdir()) == [state_path.name] + + +async def test_todo_session_store_rejects_non_mapping_items() -> None: + """Session-backed todo storage should report malformed item entries clearly.""" + session = AgentSession(session_id="session-1") + session.state["todo"] = {"items": [{"id": 1, "title": "Good"}, "bad"], "next_id": 2} + store = TodoSessionStore() + + with pytest.raises(ValueError, match="index 1.*str"): + await store.load_state(session, source_id="todo") + + +async def test_todo_session_store_rejects_malformed_state_types() -> None: + """Session-backed todo storage should raise for malformed top-level state, mirroring TodoFileStore.""" + session = AgentSession(session_id="session-1") + session.state["todo"] = "not a dict" + store = TodoSessionStore() + + with pytest.raises(ValueError, match="must be a dict"): + await store.load_state(session, source_id="todo") + + session.state["todo"] = {"items": "not a list", "next_id": 1} + with pytest.raises(ValueError, match="non-list 'items'"): + await store.load_state(session, source_id="todo") + + session.state["todo"] = {"items": [], "next_id": "1"} + with pytest.raises(ValueError, match="non-integer 'next_id'"): + await store.load_state(session, source_id="todo") + + +async def test_todo_stores_clamp_next_id_to_avoid_collisions(tmp_path: Path) -> None: + """Both stores should clamp ``next_id`` to ``max(item.id) + 1`` to prevent ID collisions.""" + session_a = AgentSession(session_id="session-a") + session_a.state["todo"] = {"items": [{"id": 5, "title": "Seeded"}], "next_id": 1} + + session_store = TodoSessionStore() + items, next_id = await session_store.load_state(session_a, source_id="todo") + assert next_id == 6 # clamped over the stored next_id of 1 + assert items == [TodoItem(id=5, title="Seeded")] + + session_b = AgentSession(session_id="session-b") + file_store = TodoFileStore(tmp_path) + state_path = tmp_path / "session-b" / "todos.todo.json" + state_path.parent.mkdir(parents=True) + state_path.write_text(json.dumps({"items": [{"id": 7, "title": "Seeded"}], "next_id": 1}) + "\n", encoding="utf-8") + items, next_id = await file_store.load_state(session_b, source_id="todo") + assert next_id == 8 + assert items == [TodoItem(id=7, title="Seeded")] + + +async def test_todo_provider_evicts_locks_when_session_is_garbage_collected() -> None: + """The provider should not retain mutation locks for sessions that have been GC'd.""" + import gc + + provider = TodoProvider() + session = AgentSession(session_id="session-1") + provider._mutation_lock(session) # type: ignore[reportPrivateUsage] + assert len(provider._mutation_locks) == 1 # type: ignore[reportPrivateUsage] + + del session + gc.collect() + assert len(provider._mutation_locks) == 0 # type: ignore[reportPrivateUsage] + + +async def test_todo_file_store_rejects_session_path_traversal(tmp_path: Path) -> None: + """File-backed todo storage should not write outside its base path for malicious session IDs.""" + session = AgentSession(session_id="../escape") + store = TodoFileStore(tmp_path) + + with pytest.raises(ValueError, match="session_id.*path separators"): + await store.save_state(session, [TodoItem(id=1, title="Escape")], next_id=2, source_id="todo") + + assert list(tmp_path.rglob("*")) == [] # noqa: ASYNC240 + + +async def test_todo_file_store_namespaces_state_by_source_id(tmp_path: Path) -> None: + """File-backed todo storage should isolate providers that share a session.""" + session = AgentSession(session_id="session-1") + store = TodoFileStore(tmp_path) + + await store.save_state(session, [TodoItem(id=1, title="First source")], next_id=2, source_id="first") + await store.save_state(session, [TodoItem(id=1, title="Second source")], next_id=2, source_id="second") + + first_items, _ = await store.load_state(session, source_id="first") + second_items, _ = await store.load_state(session, source_id="second") + + assert first_items == [TodoItem(id=1, title="First source")] + assert second_items == [TodoItem(id=1, title="Second source")] + assert (tmp_path / "session-1" / "todos.first.json").exists() + assert (tmp_path / "session-1" / "todos.second.json").exists() + + +async def test_todo_provider_runs_with_file_store(tmp_path: Path, chat_client_base: SupportsChatGetResponse) -> None: + """The provider should drive the full add/list flow when backed by ``TodoFileStore``.""" + session = AgentSession(session_id="session-1") + provider = TodoProvider(store=TodoFileStore(tmp_path)) + agent = Agent(client=chat_client_base, context_providers=[provider]) + + _, options = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + session=session, + input_messages=[Message(role="user", contents=["Track this work"])], + ) + tools = options["tools"] + assert isinstance(tools, list) + + add_todos = _tool_by_name(tools, "add_todos") + get_all_todos = _tool_by_name(tools, "get_all_todos") + + await add_todos.invoke(arguments={"todos": [{"title": "Persist me"}]}) + state_path = tmp_path / "session-1" / "todos.todo.json" + assert state_path.exists() + persisted = json.loads(state_path.read_text(encoding="utf-8")) + assert persisted["items"] == [{"id": 1, "title": "Persist me", "description": None, "is_complete": False}] + assert persisted["next_id"] == 2 + + get_all_result = await get_all_todos.invoke() + assert json.loads(get_all_result[0].text) == [ + {"id": 1, "title": "Persist me", "description": None, "is_complete": False} + ] + + +async def test_todo_provider_tools_manage_session_state( + chat_client_base: SupportsChatGetResponse, +) -> None: + """Todo provider tools should add, complete, remove, and list session-backed todos.""" + session = AgentSession(session_id="session-1") + provider = TodoProvider() + agent = Agent(client=chat_client_base, context_providers=[provider]) + + _, options = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + session=session, + input_messages=[Message(role="user", contents=["Track this work"])], + ) + tools = options["tools"] + assert isinstance(tools, list) + + add_todos = _tool_by_name(tools, "add_todos") + complete_todos = _tool_by_name(tools, "complete_todos") + remove_todos = _tool_by_name(tools, "remove_todos") + get_remaining_todos = _tool_by_name(tools, "get_remaining_todos") + get_all_todos = _tool_by_name(tools, "get_all_todos") + + add_result = await add_todos.invoke( + arguments={ + "todos": [ + {"title": " Write tests ", "description": " Cover stores "}, + {"title": "Ship feature"}, + ] + } + ) + assert json.loads(add_result[0].text) == [ + {"id": 1, "title": "Write tests", "description": "Cover stores", "is_complete": False}, + {"id": 2, "title": "Ship feature", "description": None, "is_complete": False}, + ] + + complete_result = await complete_todos.invoke(arguments={"ids": [1]}) + assert json.loads(complete_result[0].text) == {"completed": 1} + + remaining_result = await get_remaining_todos.invoke() + assert json.loads(remaining_result[0].text) == [ + {"id": 2, "title": "Ship feature", "description": None, "is_complete": False} + ] + + remove_result = await remove_todos.invoke(arguments={"ids": [2]}) + assert json.loads(remove_result[0].text) == {"removed": 1} + + get_all_result = await get_all_todos.invoke() + assert json.loads(get_all_result[0].text) == [ + {"id": 1, "title": "Write tests", "description": "Cover stores", "is_complete": True} + ] + + +async def test_todo_provider_serializes_concurrent_mutations( + chat_client_base: SupportsChatGetResponse, +) -> None: + """Concurrent todo mutations should not duplicate IDs or lose updates.""" + session = AgentSession(session_id="session-1") + provider = TodoProvider() + agent = Agent(client=chat_client_base, context_providers=[provider]) + + _, options = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + session=session, + input_messages=[Message(role="user", contents=["Track this work"])], + ) + tools = options["tools"] + assert isinstance(tools, list) + + add_todos = _tool_by_name(tools, "add_todos") + complete_todos = _tool_by_name(tools, "complete_todos") + get_all_todos = _tool_by_name(tools, "get_all_todos") + + await add_todos.invoke(arguments={"todos": [{"title": f"Existing {index}"} for index in range(1, 6)]}) + + await asyncio.gather( + add_todos.invoke(arguments={"todos": [{"title": "Add A1"}, {"title": "Add A2"}]}), + add_todos.invoke(arguments={"todos": [{"title": "Add B1"}, {"title": "Add B2"}]}), + complete_todos.invoke(arguments={"ids": [1, 2, 3, 4, 5]}), + ) + + get_all_result = await get_all_todos.invoke() + payload = json.loads(get_all_result[0].text) + ids = [item["id"] for item in payload] + + assert sorted(ids) == list(range(1, 10)) + assert len(ids) == len(set(ids)) + assert {item["title"] for item in payload} == { + "Existing 1", + "Existing 2", + "Existing 3", + "Existing 4", + "Existing 5", + "Add A1", + "Add A2", + "Add B1", + "Add B2", + } + assert {item["id"] for item in payload if item["is_complete"]} == {1, 2, 3, 4, 5} + + +def test_todo_harness_classes_are_marked_experimental() -> None: + """Todo harness public classes should expose HARNESS experimental metadata.""" + assert TodoStore.__feature_id__ == ExperimentalFeature.HARNESS.value + assert TodoItem.__feature_id__ == ExperimentalFeature.HARNESS.value + assert TodoInput.__feature_id__ == ExperimentalFeature.HARNESS.value + assert TodoSessionStore.__feature_id__ == ExperimentalFeature.HARNESS.value + assert TodoFileStore.__feature_id__ == ExperimentalFeature.HARNESS.value + assert TodoProvider.__feature_id__ == ExperimentalFeature.HARNESS.value + assert ".. warning:: Experimental" in TodoProvider.__doc__