diff --git a/python/packages/a2a/agent_framework_a2a/_agent.py b/python/packages/a2a/agent_framework_a2a/_agent.py index 4b6d9cc19b..7dae6ddcd2 100644 --- a/python/packages/a2a/agent_framework_a2a/_agent.py +++ b/python/packages/a2a/agent_framework_a2a/_agent.py @@ -35,9 +35,9 @@ AgentResponseUpdate, AgentSession, BaseAgent, - BaseHistoryProvider, Content, ContinuationToken, + HistoryProvider, Message, ResponseStream, SessionContext, @@ -353,7 +353,7 @@ async def _map_a2a_stream( # Run before_run providers (forward order) for provider in self.context_providers: - if isinstance(provider, BaseHistoryProvider) and not provider.load_messages: + if isinstance(provider, HistoryProvider) and not provider.load_messages: continue if session is None: raise RuntimeError("Provider session must be available when context providers are configured.") diff --git a/python/packages/a2a/tests/test_a2a_agent.py b/python/packages/a2a/tests/test_a2a_agent.py index 58a82f6d8c..136164c79a 100644 --- a/python/packages/a2a/tests/test_a2a_agent.py +++ b/python/packages/a2a/tests/test_a2a_agent.py @@ -24,8 +24,8 @@ AgentResponse, AgentResponseUpdate, AgentSession, - BaseContextProvider, Content, + ContextProvider, Message, SessionContext, ) @@ -869,7 +869,7 @@ async def test_poll_task_completed(a2a_agent: A2AAgent, mock_a2a_client: MockA2A # region Context Provider Tests -class TrackingContextProvider(BaseContextProvider): +class TrackingContextProvider(ContextProvider): """A context provider that records when before_run and after_run are called.""" def __init__(self) -> None: diff --git a/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py b/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py index adda863c5a..c10e7cb9b8 100644 --- a/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py +++ b/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py @@ -1,9 +1,9 @@ # Copyright (c) Microsoft. All rights reserved. -"""New-pattern Azure AI Search context provider using BaseContextProvider. +"""New-pattern Azure AI Search context provider using ContextProvider. This module provides ``AzureAISearchContextProvider``, built on the new -:class:`BaseContextProvider` hooks pattern. +:class:`ContextProvider` hooks pattern. """ from __future__ import annotations @@ -17,8 +17,8 @@ AGENT_FRAMEWORK_USER_AGENT, AgentSession, Annotation, - BaseContextProvider, Content, + ContextProvider, Message, SecretString, SessionContext, @@ -154,8 +154,8 @@ class AzureAISearchSettings(TypedDict, total=False): api_key: SecretString | None -class AzureAISearchContextProvider(BaseContextProvider): - """Azure AI Search context provider using the new BaseContextProvider hooks pattern. +class AzureAISearchContextProvider(ContextProvider): + """Azure AI Search context provider using the new ContextProvider hooks pattern. Retrieves relevant context from Azure AI Search using semantic or agentic search modes. 
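The hunks above switch callers from `BaseContextProvider` to the renamed `ContextProvider` hooks base. As orientation, here is an illustrative sketch (not part of the diff) of a minimal provider built on those hooks; the keyword-only hook signature matches how the framework invokes `before_run`/`after_run` elsewhere in this PR, while the class-level `source_id` wiring is an assumption.

```python
from typing import Any

from agent_framework import AgentSession, ContextProvider, SessionContext


class ConcisenessProvider(ContextProvider):
    """Toy provider: nudges the model before each run and counts completed runs."""

    source_id = "conciseness_provider"  # documented as required; exact wiring is assumed

    async def before_run(
        self, *, agent: Any, session: AgentSession, context: SessionContext, state: dict[str, Any]
    ) -> None:
        # SessionContext.instructions is the list of additional instructions providers may add.
        context.instructions.append("Prefer short, direct answers.")

    async def after_run(
        self, *, agent: Any, session: AgentSession, context: SessionContext, state: dict[str, Any]
    ) -> None:
        # `state` is this provider's scoped dict, persisted under session.state.
        state["runs"] = state.get("runs", 0) + 1
```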
diff --git a/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py index 6a61350a9c..d13f285249 100644 --- a/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py +++ b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py @@ -11,7 +11,7 @@ from typing import Any, ClassVar, TypedDict from agent_framework import AGENT_FRAMEWORK_USER_AGENT, Message -from agent_framework._sessions import BaseHistoryProvider +from agent_framework._sessions import HistoryProvider from agent_framework._settings import SecretString, load_settings from azure.core.credentials import TokenCredential from azure.core.credentials_async import AsyncTokenCredential @@ -32,8 +32,8 @@ class AzureCosmosHistorySettings(TypedDict, total=False): key: SecretString | None -class CosmosHistoryProvider(BaseHistoryProvider): - """Azure Cosmos DB-backed history provider using BaseHistoryProvider hooks.""" +class CosmosHistoryProvider(HistoryProvider): + """Azure Cosmos DB-backed history provider using HistoryProvider hooks.""" DEFAULT_SOURCE_ID: ClassVar[str] = "azure_cosmos_history" _BATCH_OPERATION_LIMIT: ClassVar[int] = 100 diff --git a/python/packages/claude/agent_framework_claude/_agent.py b/python/packages/claude/agent_framework_claude/_agent.py index dd30a3b2d2..fcc7151342 100644 --- a/python/packages/claude/agent_framework_claude/_agent.py +++ b/python/packages/claude/agent_framework_claude/_agent.py @@ -16,8 +16,8 @@ AgentRunInputs, AgentSession, BaseAgent, - BaseContextProvider, Content, + ContextProvider, FunctionTool, Message, ResponseStream, @@ -223,7 +223,7 @@ def __init__( id: str | None = None, name: str | None = None, description: str | None = None, - context_providers: Sequence[BaseContextProvider] | None = None, + context_providers: Sequence[ContextProvider] | None = None, middleware: Sequence[AgentMiddlewareTypes] | None = None, tools: ToolTypes | Callable[..., Any] | str | Sequence[ToolTypes | Callable[..., Any] | str] | None = None, default_options: OptionsT | MutableMapping[str, Any] | None = None, diff --git a/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py b/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py index fc2a35c72b..56a9c89081 100644 --- a/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py +++ b/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py @@ -11,8 +11,8 @@ AgentResponseUpdate, AgentSession, BaseAgent, - BaseContextProvider, Content, + ContextProvider, Message, ResponseStream, normalize_messages, @@ -60,7 +60,7 @@ def __init__( id: str | None = None, name: str | None = None, description: str | None = None, - context_providers: Sequence[BaseContextProvider] | None = None, + context_providers: Sequence[ContextProvider] | None = None, middleware: list[AgentMiddlewareTypes] | None = None, environment_id: str | None = None, agent_identifier: str | None = None, diff --git a/python/packages/core/AGENTS.md b/python/packages/core/AGENTS.md index 4a9317aa54..29fe978aef 100644 --- a/python/packages/core/AGENTS.md +++ b/python/packages/core/AGENTS.md @@ -61,8 +61,8 @@ agent_framework/ - **`AgentSession`** - Manages conversation state and session metadata - **`SessionContext`** - Context object for session-scoped data during agent runs -- **`BaseContextProvider`** - Base class for context providers (RAG, memory systems) -- **`BaseHistoryProvider`** - Base class for 
conversation history storage +- **`ContextProvider`** - Base class for context providers (RAG, memory systems) +- **`HistoryProvider`** - Base class for conversation history storage ### Skills (`_skills.py`) @@ -70,7 +70,7 @@ agent_framework/ - **`SkillResource`** - Named supplementary content attached to a skill; holds either static `content` or a dynamic `function` (sync or async). Exactly one must be provided. - **`SkillScript`** - An executable script attached to a skill; holds either an inline `function` (code-defined, runs in-process) or a `path` to a file on disk (file-based, delegated to a runner). Exactly one must be provided. - **`SkillScriptRunner`** - Protocol for file-based script execution. Any callable matching `(skill, script, args) -> Any` satisfies it. Code-defined scripts do not use a runner. -- **`SkillsProvider`** - Context provider (extends `BaseContextProvider`) that discovers file-based skills from `SKILL.md` files and/or accepts code-defined `Skill` instances. Follows progressive disclosure: advertise → load → read resources / run scripts. +- **`SkillsProvider`** - Context provider (extends `ContextProvider`) that discovers file-based skills from `SKILL.md` files and/or accepts code-defined `Skill` instances. Follows progressive disclosure: advertise → load → read resources / run scripts. ### Workflows (`_workflows/`) diff --git a/python/packages/core/agent_framework/__init__.py b/python/packages/core/agent_framework/__init__.py index d30d5159ea..7c39e1ffea 100644 --- a/python/packages/core/agent_framework/__init__.py +++ b/python/packages/core/agent_framework/__init__.py @@ -102,8 +102,10 @@ ) from ._sessions import ( AgentSession, - BaseContextProvider, - BaseHistoryProvider, + BaseContextProvider, # type: ignore[reportDeprecated] + BaseHistoryProvider, # type: ignore[reportDeprecated] + ContextProvider, + HistoryProvider, InMemoryHistoryProvider, SessionContext, register_state_type, @@ -296,6 +298,7 @@ "CompactionProvider", "CompactionStrategy", "Content", + "ContextProvider", "ContinuationToken", "ConversationSplit", "ConversationSplitter", @@ -331,6 +334,7 @@ "FunctionTool", "GeneratedEmbeddings", "GraphConnectivityError", + "HistoryProvider", "InMemoryCheckpointStorage", "InMemoryHistoryProvider", "InProcRunnerContext", diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index 1868742111..b857a6377f 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -29,14 +29,16 @@ from ._clients import BaseChatClient, SupportsChatGetResponse from ._docstrings import apply_layered_docstring from ._mcp import LOG_LEVEL_MAPPING, MCPTool -from ._middleware import AgentMiddlewareLayer, FunctionInvocationContext, MiddlewareTypes +from ._middleware import AgentMiddlewareLayer, FunctionInvocationContext, MiddlewareTypes, categorize_middleware from ._serialization import SerializationMixin from ._sessions import ( AgentSession, - BaseContextProvider, - BaseHistoryProvider, + ContextProvider, + HistoryProvider, InMemoryHistoryProvider, + PerServiceCallHistoryPersistingMiddleware, SessionContext, + is_local_history_conversation_id, ) from ._tools import FunctionInvocationLayer, FunctionTool, ToolTypes, normalize_tools from ._types import ( @@ -50,7 +52,7 @@ map_chat_to_agent_update, normalize_messages, ) -from .exceptions import AgentInvalidResponseException, UserInputRequiredException +from .exceptions import AgentInvalidRequestException, 
AgentInvalidResponseException, UserInputRequiredException from .observability import AgentTelemetryLayer if sys.version_info >= (3, 13): @@ -166,6 +168,7 @@ class _RunContext(TypedDict): input_messages: Sequence[Message] session_messages: Sequence[Message] agent_name: str + suppress_response_id: bool chat_options: MutableMapping[str, Any] compaction_strategy: CompactionStrategy | None tokenizer: TokenizerProtocol | None @@ -366,6 +369,7 @@ async def _stream(): """ DEFAULT_EXCLUDE: ClassVar[set[str]] = {"additional_properties"} + require_per_service_call_history_persistence: bool = False def __init__( self, @@ -373,7 +377,7 @@ def __init__( id: str | None = None, name: str | None = None, description: str | None = None, - context_providers: Sequence[BaseContextProvider] | None = None, + context_providers: Sequence[ContextProvider] | None = None, middleware: Sequence[MiddlewareTypes] | None = None, additional_properties: MutableMapping[str, Any] | None = None, ) -> None: @@ -393,7 +397,7 @@ def __init__( self.id = id self.name = name self.description = description - self.context_providers: list[BaseContextProvider] = list(context_providers or []) + self.context_providers: list[ContextProvider] = list(context_providers or []) self.middleware: list[MiddlewareTypes] | None = ( cast(list[MiddlewareTypes], middleware) if middleware is not None else None ) @@ -455,7 +459,12 @@ async def _run_after_providers( if provider_session is None and self.context_providers: provider_session = AgentSession() + per_service_call_history_required = self.require_per_service_call_history_persistence and any( + isinstance(provider, HistoryProvider) for provider in self.context_providers + ) for provider in reversed(self.context_providers): + if per_service_call_history_required and isinstance(provider, HistoryProvider): + continue if provider_session is None: raise RuntimeError("Provider session must be available when context providers are configured.") await provider.after_run( @@ -656,8 +665,9 @@ def __init__( description: str | None = None, tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None, default_options: OptionsCoT | None = None, - context_providers: Sequence[BaseContextProvider] | None = None, + context_providers: Sequence[ContextProvider] | None = None, middleware: Sequence[MiddlewareTypes] | None = None, + require_per_service_call_history_persistence: bool = False, compaction_strategy: CompactionStrategy | None = None, tokenizer: TokenizerProtocol | None = None, additional_properties: MutableMapping[str, Any] | None = None, @@ -675,6 +685,11 @@ def __init__( description: A brief description of the agent's purpose. context_providers: Context providers to include during agent invocation. middleware: List of middleware to intercept agent and function invocations. + require_per_service_call_history_persistence: When True, history providers are invoked + around each model call instead of once per ``run()`` when the service + is not already storing history. If service-side storage is active for + the run, the agent skips local history providers and relies on the + service-managed conversation instead. default_options: A TypedDict containing chat options. 
When using a typed agent like ``Agent[OpenAIChatOptions]``, this enables IDE autocomplete for provider-specific options including temperature, max_tokens, model_id, @@ -706,6 +721,7 @@ def __init__( ) self.client = client self.compaction_strategy = compaction_strategy + self.require_per_service_call_history_persistence = require_per_service_call_history_persistence self.tokenizer = tokenizer # Get tools from options or named parameter (named param takes precedence) @@ -764,6 +780,35 @@ async def __aenter__(self) -> Self: await self._async_exit_stack.enter_async_context(context_manager) return self + def _get_history_providers(self) -> list[HistoryProvider]: + return [provider for provider in self.context_providers if isinstance(provider, HistoryProvider)] + + def _resolve_per_service_call_history_providers( + self, + *, + session: AgentSession | None, + options: Mapping[str, Any] | None, + service_stores_history: bool, + ) -> list[HistoryProvider]: + history_providers = self._get_history_providers() + if not self.require_per_service_call_history_persistence or not history_providers: + return [] + + conversation_id = ( + session.service_session_id + if session and session.service_session_id + else cast(str | None, (options or {}).get("conversation_id") or self.default_options.get("conversation_id")) + ) + if service_stores_history: + return [] + + if conversation_id is not None: + raise AgentInvalidRequestException( + "require_per_service_call_history_persistence cannot be used " + "with an existing service-managed conversation." + ) + return history_providers + async def __aexit__( self, exc_type: type[BaseException] | None, @@ -885,155 +930,189 @@ def run( When stream=True: A ResponseStream of AgentResponseUpdate items with ``get_final_response()`` for the final AgentResponse. 
""" + + async def _prepare_run_context() -> _RunContext: + return await self._prepare_run_context( + messages=messages, + session=session, + tools=tools, + options=options, + compaction_strategy=compaction_strategy, + tokenizer=tokenizer, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, + ) + if not stream: async def _run_non_streaming() -> AgentResponse[Any]: - ctx = await self._prepare_run_context( - messages=messages, - session=session, - tools=tools, - options=options, - compaction_strategy=compaction_strategy, - tokenizer=tokenizer, - function_invocation_kwargs=function_invocation_kwargs, - client_kwargs=client_kwargs, - ) - response = cast( - ChatResponse[Any], - await self.client.get_response( # type: ignore - messages=ctx["session_messages"], - stream=False, - options=ctx["chat_options"], # type: ignore[reportArgumentType] - compaction_strategy=ctx["compaction_strategy"], - tokenizer=ctx["tokenizer"], - function_invocation_kwargs=ctx["function_invocation_kwargs"], - client_kwargs=ctx["client_kwargs"], - ), - ) + ctx = await _prepare_run_context() + response = await self._call_chat_client(ctx, stream=False) + return await self._parse_non_streaming_response(ctx, response) - if not response: - raise AgentInvalidResponseException("Chat client did not return a response.") + return _run_non_streaming() - await self._finalize_response( - response=response, - agent_name=ctx["agent_name"], - session=ctx["session"], - session_context=ctx["session_context"], - ) - response_format = ctx["chat_options"].get("response_format") - if not ( - response_format is not None - and isinstance(response_format, type) - and issubclass(response_format, BaseModel) - ): - response_format = None - - return AgentResponse( - messages=response.messages, - response_id=response.response_id, - created_at=response.created_at, - usage_details=response.usage_details, - value=response.value, - response_format=response_format, - continuation_token=response.continuation_token, - raw_representation=response, - additional_properties=response.additional_properties, - ) + async def _run_streaming() -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: + ctx = await _prepare_run_context() + stream_response = self._call_chat_client(ctx, stream=True) + return self._parse_streaming_response(ctx, stream_response) - return _run_non_streaming() + return cast( + ResponseStream[AgentResponseUpdate, AgentResponse[Any]], + cast(Any, ResponseStream).from_awaitable(_run_streaming()), + ) - # Use a holder to capture the context created during stream initialization - ctx_holder: dict[str, _RunContext | None] = {"ctx": None} + @overload + def _call_chat_client( + self, + context: _RunContext, + *, + stream: Literal[False], + ) -> Awaitable[ChatResponse[Any]]: ... - async def _post_hook(response: AgentResponse) -> None: - ctx = ctx_holder["ctx"] - if ctx is None: - return # No context available (shouldn't happen in normal flow) + @overload + def _call_chat_client( + self, + context: _RunContext, + *, + stream: Literal[True], + ) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ... 
+ + def _call_chat_client( + self, + context: _RunContext, + *, + stream: bool, + ) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: + """Invoke the downstream chat client for a prepared run context.""" + if stream: + return self.client.get_response( # type: ignore[call-overload, no-any-return] + messages=context["session_messages"], + stream=True, + options=context["chat_options"], # type: ignore[reportArgumentType] + compaction_strategy=context["compaction_strategy"], + tokenizer=context["tokenizer"], + function_invocation_kwargs=context["function_invocation_kwargs"], + client_kwargs=context["client_kwargs"], + ) + return self.client.get_response( # type: ignore[call-overload, no-any-return] + messages=context["session_messages"], + stream=False, + options=context["chat_options"], # type: ignore[reportArgumentType] + compaction_strategy=context["compaction_strategy"], + tokenizer=context["tokenizer"], + function_invocation_kwargs=context["function_invocation_kwargs"], + client_kwargs=context["client_kwargs"], + ) + + async def _parse_non_streaming_response( + self, + context: _RunContext, + response: ChatResponse[Any], + ) -> AgentResponse[Any]: + """Finalize a non-streaming chat response into an AgentResponse.""" + if not response: + raise AgentInvalidResponseException("Chat client did not return a response.") + + await self._finalize_response( + response=response, + agent_name=context["agent_name"], + session=context["session"], + session_context=context["session_context"], + suppress_response_id=context["suppress_response_id"], + ) + + response_format = context["chat_options"].get("response_format") + if not ( + response_format is not None and isinstance(response_format, type) and issubclass(response_format, BaseModel) + ): + response_format = None + + return AgentResponse( + messages=response.messages, + response_id=None if context["suppress_response_id"] else response.response_id, + created_at=response.created_at, + usage_details=response.usage_details, + value=response.value, + response_format=response_format, + continuation_token=response.continuation_token, + raw_representation=response, + additional_properties=response.additional_properties, + ) + + def _parse_streaming_response( + self, + context: _RunContext, + stream_response: ResponseStream[ChatResponseUpdate, ChatResponse[Any]], + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: + """Finalize a streaming chat response into an agent response stream.""" + + async def _post_hook(response: AgentResponse) -> None: # Update thread with conversation_id derived from streaming raw updates. # Using response_id here can break function-call continuation for APIs # where response IDs are not valid conversation handles. conversation_id = self._extract_conversation_id_from_streaming_response(response) - # Ensure author names are set for all messages + for message in response.messages: if message.author_name is None: - message.author_name = ctx["agent_name"] - - # Propagate conversation_id back to session from streaming updates. - # For Responses-style APIs this can rotate every turn (response_id-based continuation), - # so refresh when a newer value is returned. 
- sess = ctx["session"] - if sess and conversation_id and sess.service_session_id != conversation_id: - sess.service_session_id = conversation_id - - # Run after_run providers (reverse order) - session_context = ctx["session_context"] + message.author_name = context["agent_name"] + + session = context["session"] + if ( + session + and conversation_id + and not is_local_history_conversation_id(conversation_id) + and session.service_session_id != conversation_id + ): + session.service_session_id = conversation_id + + suppress_response_id = context["suppress_response_id"] + session_context = context["session_context"] session_context._response = AgentResponse( # type: ignore[assignment] messages=response.messages, - response_id=response.response_id, + response_id=None if suppress_response_id else response.response_id, ) - await self._run_after_providers(session=ctx["session"], context=session_context) + await self._run_after_providers(session=session, context=session_context) - async def _get_stream() -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: - ctx_holder["ctx"] = await self._prepare_run_context( - messages=messages, - session=session, - tools=tools, - options=options, - compaction_strategy=compaction_strategy, - tokenizer=tokenizer, - function_invocation_kwargs=function_invocation_kwargs, - client_kwargs=client_kwargs, - ) - ctx: _RunContext = ctx_holder["ctx"] # type: ignore[assignment] # Safe: we just assigned it - return self.client.get_response( # type: ignore[call-overload, no-any-return] - messages=ctx["session_messages"], - stream=True, - options=ctx["chat_options"], # type: ignore[reportArgumentType] - compaction_strategy=ctx["compaction_strategy"], - tokenizer=ctx["tokenizer"], - function_invocation_kwargs=ctx["function_invocation_kwargs"], - client_kwargs=ctx["client_kwargs"], - ) - - def _propagate_conversation_id( - update: AgentResponseUpdate, - ) -> AgentResponseUpdate: - """Eagerly propagate conversation_id to session as updates arrive. - - This ensures session.service_session_id is set even when the user - only iterates the stream without calling get_final_response(). 
- """ + def _propagate_conversation_id(update: AgentResponseUpdate) -> AgentResponseUpdate: + """Eagerly propagate conversation_id to session as updates arrive.""" + session = context["session"] if session is None: return update raw = update.raw_representation - conv_id = getattr(raw, "conversation_id", None) if raw else None - if isinstance(conv_id, str) and conv_id and session.service_session_id != conv_id: - session.service_session_id = conv_id + conversation_id = getattr(raw, "conversation_id", None) if raw else None + if ( + isinstance(conversation_id, str) + and conversation_id + and not is_local_history_conversation_id(conversation_id) + and session.service_session_id != conversation_id + ): + session.service_session_id = conversation_id + return update + + def _suppress_response_id(update: AgentResponseUpdate) -> AgentResponseUpdate: + """Hide raw service response ids when local per-service-call persistence owns continuation.""" + update.response_id = None return update def _finalizer(updates: Sequence[AgentResponseUpdate]) -> AgentResponse[Any]: - ctx = ctx_holder["ctx"] - rf = ( - ctx.get("chat_options", {}).get("response_format") - if ctx - else (options.get("response_format") if options else None) # type: ignore[union-attr] - ) - return self._finalize_response_updates(updates, response_format=rf) - - return ( - ResponseStream - .from_awaitable(_get_stream()) # type: ignore[reportUnknownMemberType] - .map( - transform=partial( - map_chat_to_agent_update, - agent_name=self.name, - ), - finalizer=_finalizer, + return self._finalize_response_updates( + updates, + response_format=context["chat_options"].get("response_format"), ) - .with_transform_hook(_propagate_conversation_id) - .with_result_hook(_post_hook) + + stream = stream_response.map( + transform=partial( + map_chat_to_agent_update, + agent_name=self.name, + ), + finalizer=_finalizer, ) + if context["suppress_response_id"]: + stream = stream.with_transform_hook(_suppress_response_id) + + return stream.with_transform_hook(_propagate_conversation_id).with_result_hook(_post_hook) def _finalize_response_updates( self, @@ -1111,6 +1190,12 @@ async def _prepare_run_context( if active_session is None and self.context_providers: active_session = AgentSession() + per_service_call_history_providers = self._resolve_per_service_call_history_providers( + session=active_session, + options=opts, + service_stores_history=bool(store_), + ) + session_context, chat_options = await self._prepare_session_and_messages( session=active_session, input_messages=input_messages, @@ -1191,6 +1276,43 @@ async def _prepare_run_context( effective_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {} if active_session is not None: effective_client_kwargs["session"] = active_session + if per_service_call_history_providers and active_session is not None: + per_service_call_history_middleware = PerServiceCallHistoryPersistingMiddleware( + agent=self, + session=active_session, + providers=per_service_call_history_providers, + ) + existing_middleware = effective_client_kwargs.get("middleware") + if isinstance(existing_middleware, Sequence) and not isinstance(existing_middleware, (str, bytes)): + effective_client_kwargs["middleware"] = [per_service_call_history_middleware, *existing_middleware] + elif existing_middleware is not None: + effective_client_kwargs["middleware"] = [ + per_service_call_history_middleware, + cast(MiddlewareTypes, existing_middleware), + ] + else: + effective_client_kwargs["middleware"] = 
[per_service_call_history_middleware] + provider_middleware = session_context.get_middleware() + if provider_middleware: + middleware_list = categorize_middleware(provider_middleware) + provider_function_chat_middleware = [ + *middleware_list["function"], + *middleware_list["chat"], + ] + if provider_function_chat_middleware: + existing_middleware = effective_client_kwargs.get("middleware") + if isinstance(existing_middleware, Sequence) and not isinstance(existing_middleware, (str, bytes)): + effective_client_kwargs["middleware"] = [ + *existing_middleware, + *provider_function_chat_middleware, + ] + elif existing_middleware is not None: + effective_client_kwargs["middleware"] = [ + cast(MiddlewareTypes, existing_middleware), + *provider_function_chat_middleware, + ] + else: + effective_client_kwargs["middleware"] = provider_function_chat_middleware return { "session": active_session, @@ -1198,6 +1320,7 @@ async def _prepare_run_context( "input_messages": input_messages, "session_messages": session_messages, "agent_name": agent_name, + "suppress_response_id": bool(per_service_call_history_providers), "chat_options": co, "compaction_strategy": compaction_strategy or self.compaction_strategy, "tokenizer": tokenizer or self.tokenizer, @@ -1211,6 +1334,7 @@ async def _finalize_response( agent_name: str, session: AgentSession | None, session_context: SessionContext, + suppress_response_id: bool = False, ) -> None: """Finalize response by setting author names and running after_run providers. @@ -1219,6 +1343,7 @@ async def _finalize_response( agent_name: The name of the agent to set as author. session: The conversation session. session_context: The invocation context. + suppress_response_id: When True, omit the raw service response ID from the public response. """ # Ensure that the author name is set for each message in the response. for message in response.messages: @@ -1228,13 +1353,18 @@ async def _finalize_response( # Propagate conversation_id back to session (e.g. thread ID from Assistants API). # For Responses-style APIs this can rotate every turn (response_id-based continuation), # so refresh when a newer value is returned. 
- if session and response.conversation_id and session.service_session_id != response.conversation_id: + if ( + session + and response.conversation_id + and not is_local_history_conversation_id(response.conversation_id) + and session.service_session_id != response.conversation_id + ): session.service_session_id = response.conversation_id # Set the response on the context for after_run providers session_context._response = AgentResponse( # type: ignore[assignment] messages=response.messages, - response_id=response.response_id, + response_id=None if suppress_response_id else response.response_id, ) # Run after_run providers (reverse order) @@ -1284,9 +1414,15 @@ async def _prepare_session_and_messages( options=options or {}, ) - # Run before_run providers (forward order, skip BaseHistoryProvider with load_messages=False) + per_service_call_history_required = self.require_per_service_call_history_persistence and bool( + self._get_history_providers() + ) + + # Run before_run providers (forward order, skip HistoryProvider when per-service-call persistence owns history) for provider in self.context_providers: - if isinstance(provider, BaseHistoryProvider) and not provider.load_messages: + if per_service_call_history_required and isinstance(provider, HistoryProvider): + continue + if isinstance(provider, HistoryProvider) and not provider.load_messages: continue if provider_session is None: raise RuntimeError("Provider session must be available when context providers are configured.") @@ -1551,8 +1687,9 @@ def __init__( description: str | None = None, tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None, default_options: OptionsCoT | None = None, - context_providers: Sequence[BaseContextProvider] | None = None, + context_providers: Sequence[ContextProvider] | None = None, middleware: Sequence[MiddlewareTypes] | None = None, + require_per_service_call_history_persistence: bool = False, compaction_strategy: CompactionStrategy | None = None, tokenizer: TokenizerProtocol | None = None, additional_properties: MutableMapping[str, Any] | None = None, @@ -1568,6 +1705,7 @@ def __init__( default_options=default_options, context_providers=context_providers, middleware=middleware, + require_per_service_call_history_persistence=require_per_service_call_history_persistence, compaction_strategy=compaction_strategy, tokenizer=tokenizer, additional_properties=additional_properties, diff --git a/python/packages/core/agent_framework/_clients.py b/python/packages/core/agent_framework/_clients.py index d1394a6733..41bcf25883 100644 --- a/python/packages/core/agent_framework/_clients.py +++ b/python/packages/core/agent_framework/_clients.py @@ -572,6 +572,7 @@ def as_agent( default_options: OptionsCoT | Mapping[str, Any] | None = None, context_providers: Sequence[Any] | None = None, middleware: Sequence[MiddlewareTypes] | None = None, + require_per_service_call_history_persistence: bool = False, function_invocation_configuration: FunctionInvocationConfiguration | None = None, compaction_strategy: CompactionStrategy | None = None, tokenizer: TokenizerProtocol | None = None, @@ -596,6 +597,10 @@ def as_agent( and dict literals are accepted without specialized option typing. context_providers: Context providers to include during agent invocation. middleware: List of middleware to intercept agent and function invocations. + require_per_service_call_history_persistence: Whether to require per-service-call + chat history persistence. 
When enabled, history providers are invoked around + each model call instead of once per ``run()`` when the service is not already + storing history. function_invocation_configuration: Optional function invocation configuration override. compaction_strategy: Optional agent-level compaction override. When omitted, client-level compaction defaults remain in effect for each call. @@ -636,6 +641,7 @@ def as_agent( "default_options": cast(Any, default_options), "context_providers": context_providers, "middleware": middleware, + "require_per_service_call_history_persistence": require_per_service_call_history_persistence, "compaction_strategy": compaction_strategy, "tokenizer": tokenizer, "additional_properties": dict(additional_properties) if additional_properties is not None else None, diff --git a/python/packages/core/agent_framework/_compaction.py b/python/packages/core/agent_framework/_compaction.py index 8a15a6438c..06879e3a16 100644 --- a/python/packages/core/agent_framework/_compaction.py +++ b/python/packages/core/agent_framework/_compaction.py @@ -15,7 +15,7 @@ runtime_checkable, ) -from ._sessions import BaseContextProvider +from ._sessions import ContextProvider from ._types import ChatResponse, Content, Message if TYPE_CHECKING: @@ -1152,7 +1152,7 @@ async def apply_compaction( COMPACTION_STATE_KEY: Final[str] = "_compaction_messages" -class CompactionProvider(BaseContextProvider): +class CompactionProvider(ContextProvider): """Context provider that compacts messages before and after agent runs. This provider accepts two separate strategies: diff --git a/python/packages/core/agent_framework/_sessions.py b/python/packages/core/agent_framework/_sessions.py index 84656824aa..11aa2419db 100644 --- a/python/packages/core/agent_framework/_sessions.py +++ b/python/packages/core/agent_framework/_sessions.py @@ -4,8 +4,8 @@ This module provides the core types for the context provider pipeline: - SessionContext: Per-invocation state passed through providers -- BaseContextProvider: Base class for context providers (renamed to ContextProvider in PR2) -- BaseHistoryProvider: Base class for history storage providers (renamed to HistoryProvider in PR2) +- ContextProvider: Base class for context providers +- HistoryProvider: Base class for history storage providers - AgentSession: Lightweight session state container - InMemoryHistoryProvider: Built-in in-memory history provider """ @@ -13,21 +13,42 @@ from __future__ import annotations import copy +import sys import uuid from abc import abstractmethod -from collections.abc import Sequence -from typing import TYPE_CHECKING, Any, ClassVar, cast +from collections.abc import Awaitable, Callable, Mapping, Sequence +from typing import TYPE_CHECKING, Any, ClassVar, TypeGuard, cast -from ._types import AgentResponse, Message +if sys.version_info >= (3, 13): + from warnings import deprecated # type: ignore # pragma: no cover +else: + from typing_extensions import deprecated # type: ignore # pragma: no cover + +from ._middleware import ChatContext, ChatMiddleware +from ._types import AgentResponse, ChatResponse, Message, ResponseStream +from .exceptions import ChatClientInvalidResponseException if TYPE_CHECKING: from ._agents import SupportsAgentRun + from ._middleware import MiddlewareTypes # Registry of known types for state deserialization _STATE_TYPE_REGISTRY: dict[str, type] = {} +def _is_middleware_sequence( + middleware: MiddlewareTypes | Sequence[MiddlewareTypes], +) -> TypeGuard[Sequence[MiddlewareTypes]]: + return isinstance(middleware, Sequence) 
and not isinstance(middleware, (str, bytes)) + + +def _is_single_middleware( + middleware: MiddlewareTypes | Sequence[MiddlewareTypes], +) -> TypeGuard[MiddlewareTypes]: + return not _is_middleware_sequence(middleware) + + def register_state_type(cls: type) -> None: """Register a type for automatic deserialization in session state. @@ -131,6 +152,8 @@ class SessionContext: Maintains insertion order (provider execution order). instructions: Additional instructions added by providers. tools: Additional tools added by providers. + middleware: Dict mapping source_id -> chat/function middleware added by that provider. + Maintains insertion order (provider execution order). response: After invocation, contains the full AgentResponse, should not be changed. options: Options passed to agent.run() - read-only, for reflection only. metadata: Shared metadata dictionary for cross-provider communication. @@ -145,6 +168,7 @@ def __init__( context_messages: dict[str, list[Message]] | None = None, instructions: list[str] | None = None, tools: list[Any] | None = None, + middleware: dict[str, list[MiddlewareTypes]] | None = None, options: dict[str, Any] | None = None, metadata: dict[str, Any] | None = None, ): @@ -157,6 +181,7 @@ def __init__( context_messages: Pre-populated context messages by source. instructions: Pre-populated instructions. tools: Pre-populated tools. + middleware: Pre-populated chat/function middleware by source. options: Options from agent.run() - read-only for providers. metadata: Shared metadata for cross-provider communication. """ @@ -166,6 +191,10 @@ def __init__( self.context_messages: dict[str, list[Message]] = context_messages or {} self.instructions: list[str] = instructions or [] self.tools: list[Any] = tools or [] + self.middleware: dict[str, list[MiddlewareTypes]] = {} + if middleware: + for source_id, provider_middleware in middleware.items(): + self.extend_middleware(source_id, provider_middleware) self._response: AgentResponse | None = None self.options: dict[str, Any] = options or {} self.metadata: dict[str, Any] = metadata or {} @@ -236,6 +265,40 @@ def extend_tools(self, source_id: str, tools: Sequence[Any]) -> None: additional_properties["context_source"] = source_id self.tools.extend(tools) + def extend_middleware( + self, + source_id: str, + middleware: MiddlewareTypes | Sequence[MiddlewareTypes], + ) -> None: + """Add middleware to be applied for this invocation. + + Args: + source_id: The provider source_id adding this middleware. + middleware: A single chat/function middleware object/callable or sequence of middleware. 
+ """ + from ._middleware import categorize_middleware + from .exceptions import MiddlewareException + + if _is_middleware_sequence(middleware): + middleware_items = list(middleware) + elif _is_single_middleware(middleware): + middleware_items = [middleware] + else: + raise TypeError("middleware must be a middleware object or a sequence of middleware objects.") + middleware_list = categorize_middleware(middleware_items) + if middleware_list["agent"]: + raise MiddlewareException("Context providers may only add chat or function middleware.") + if source_id not in self.middleware: + self.middleware[source_id] = [] + self.middleware[source_id].extend(middleware_items) + + def get_middleware(self) -> list[MiddlewareTypes]: + """Get provider-added chat/function middleware in provider execution order.""" + result: list[MiddlewareTypes] = [] + for middleware_items in self.middleware.values(): + result.extend(middleware_items) + return result + def get_messages( self, *, @@ -272,17 +335,12 @@ def get_messages( return result -class BaseContextProvider: - """Base class for context providers (hooks pattern). +class ContextProvider: + """Base class for context providers. Context providers participate in the context engineering pipeline, adding context before model invocation and processing responses after. - Note: - This class uses a temporary name prefixed with ``_`` to avoid collision - with the existing ``ContextProvider`` in ``_memory.py``. It will be - renamed to ``ContextProvider`` in PR2 when the old class is removed. - Attributes: source_id: Unique identifier for this provider instance (required). Used for message/tool attribution so other providers can filter. @@ -312,7 +370,7 @@ async def before_run( Args: agent: The agent running this invocation. session: The current session. - context: The invocation context - add messages/instructions/tools here. + context: The invocation context - add messages/instructions/tools/chat/function middleware here. state: The provider-scoped mutable state dict for this provider. Full cross-provider state remains available at ``session.state``. """ @@ -339,7 +397,7 @@ async def after_run( """ -class BaseHistoryProvider(BaseContextProvider): +class HistoryProvider(ContextProvider): """Base class for conversation history storage providers. A single class configurable for different use cases: @@ -347,10 +405,6 @@ class BaseHistoryProvider(BaseContextProvider): - Audit/logging storage (stores only, doesn't load) - Evaluation storage (stores only for later analysis) - Note: - This class uses a temporary name prefixed with ``_`` to avoid collision - with existing types. It will be renamed to ``HistoryProvider`` in PR2. - Subclasses only need to implement ``get_messages()`` and ``save_messages()``. The default ``before_run``/``after_run`` handle loading and storing based on configuration flags. Override them for custom behavior. 
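To make the `HistoryProvider` contract above concrete, here is a sketch (not part of the diff) of a trivially small subclass. The exact `get_messages`/`save_messages` signatures are assumptions inferred from how the default `after_run` calls `save_messages(context.session_id, messages, state=state)` in this file; the real base class may differ.

```python
from typing import Any

from agent_framework import HistoryProvider, Message


class DictHistoryProvider(HistoryProvider):
    """Toy provider keeping conversation history in an in-process dict, keyed by session id."""

    source_id = "dict_history"  # attribution id; exact wiring is assumed

    def __init__(self) -> None:
        super().__init__()  # base-class arguments (e.g. load_messages) are assumed optional
        self._store: dict[str, list[Message]] = {}

    async def get_messages(self, session_id: str, *, state: dict[str, Any]) -> list[Message]:
        # Signature assumed; returns previously stored messages for the session.
        return list(self._store.get(session_id, []))

    async def save_messages(self, session_id: str, messages: list[Message], *, state: dict[str, Any]) -> None:
        # Signature assumed from the default after_run call shown in this file.
        self._store.setdefault(session_id, []).extend(messages)
```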
@@ -467,6 +521,207 @@ async def after_run( await self.save_messages(context.session_id, messages_to_store, state=state) +LOCAL_HISTORY_CONVERSATION_ID = "agent_framework_local_history_persistence" + + +def is_local_history_conversation_id(conversation_id: str | None) -> bool: + """Return whether a conversation id is the local history-persistence sentinel.""" + return conversation_id == LOCAL_HISTORY_CONVERSATION_ID + + +def _response_contains_follow_up_request(response: ChatResponse) -> bool: + """Return whether a response requires another model call in the current run.""" + return any( + item.type in {"function_call", "function_approval_request"} + for message in response.messages + for item in message.contents + ) + + +def _split_service_call_messages(messages: Sequence[Message]) -> tuple[list[Message], dict[str, list[Message]]]: + """Split service-call messages into input messages and attributed context messages.""" + input_messages: list[Message] = [] + context_messages: dict[str, list[Message]] = {} + for message in messages: + attribution = message.additional_properties.get("_attribution") + if isinstance(attribution, Mapping): + attribution_mapping = cast(Mapping[str, Any], attribution) + source_id = attribution_mapping.get("source_id") + if isinstance(source_id, str): + context_messages.setdefault(source_id, []).append(message) + continue + input_messages.append(message) + return input_messages, context_messages + + +class PerServiceCallHistoryPersistingMiddleware(ChatMiddleware): + """Persist local chat history after each service call when history is framework-managed. + + This middleware runs around each model call when + ``require_per_service_call_history_persistence`` is enabled. It loads history providers + before the model call, persists them after the model call, and uses a local + sentinel conversation id so the function loop follows the existing + service-managed branch without forwarding that sentinel to the leaf client. + """ + + def __init__( + self, + *, + agent: SupportsAgentRun, + session: AgentSession, + providers: Sequence[HistoryProvider], + ) -> None: + """Initialize the middleware. + + Args: + agent: The agent that owns the history providers. + session: The active session for the current run. + providers: The history providers participating in per-service-call persistence. 
+ """ + self._agent = agent + self._session = session + self._providers = list(providers) + + async def _prepare_service_call_context(self, messages: Sequence[Message]) -> SessionContext: + """Create a per-call SessionContext and load history providers into it.""" + input_messages, context_messages = _split_service_call_messages(messages) + service_call_context = SessionContext( + session_id=self._session.session_id, + service_session_id=None, + input_messages=list(input_messages), + ) + for source_id, source_messages in context_messages.items(): + service_call_context.extend_messages(source_id, source_messages) + for provider in self._providers: + if not provider.load_messages: + continue + await provider.before_run( + agent=self._agent, + session=self._session, + context=service_call_context, + state=self._session.state.setdefault(provider.source_id, {}), + ) + return service_call_context + + async def _persist_service_call_response( + self, + *, + service_call_context: SessionContext, + response: ChatResponse, + ) -> None: + """Persist a single model-call response through the configured history providers.""" + service_call_context._response = AgentResponse( # type: ignore[assignment] + messages=response.messages, + response_id=None, + ) + for provider in reversed(self._providers): + await provider.after_run( + agent=self._agent, + session=self._session, + context=service_call_context, + state=self._session.state.setdefault(provider.source_id, {}), + ) + + def _strip_local_conversation_id(self, context: ChatContext) -> None: + """Remove the local sentinel before the leaf chat client is invoked.""" + if is_local_history_conversation_id(cast(str | None, context.kwargs.get("conversation_id"))): + context.kwargs.pop("conversation_id", None) + + if context.options is None: + return + + mutable_options = dict(context.options) + if is_local_history_conversation_id(cast(str | None, mutable_options.get("conversation_id"))): + mutable_options.pop("conversation_id", None) + context.options = mutable_options + + async def _finalize_response( + self, + *, + service_call_context: SessionContext, + response: ChatResponse, + ) -> ChatResponse: + """Persist a model response and apply the local follow-up sentinel when needed.""" + if response.conversation_id is not None and not is_local_history_conversation_id(response.conversation_id): + raise ChatClientInvalidResponseException( + "require_per_service_call_history_persistence cannot be used " + "when the chat client returns a real conversation_id." + ) + + await self._persist_service_call_response( + service_call_context=service_call_context, + response=response, + ) + if _response_contains_follow_up_request(response): + response.mark_internal_conversation_id() + response.conversation_id = LOCAL_HISTORY_CONVERSATION_ID + return response + + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: + """Load and persist history providers around a single model call. + + Args: + context: The chat invocation context for the current model call. + call_next: The next middleware or the leaf chat client. + + Raises: + ChatClientInvalidResponseException: If the leaf client returns a real + service-managed conversation id while local per-service-call persistence is enabled. + ValueError: If the downstream middleware contract returns the wrong + result type for streaming or non-streaming execution. 
+ """ + service_call_context = await self._prepare_service_call_context(context.messages) + context.messages = service_call_context.get_messages(include_input=True) + self._strip_local_conversation_id(context) + + await call_next() + + if context.result is None: + return + + if context.stream: + if not isinstance(context.result, ResponseStream): + raise ValueError("Streaming chat middleware requires a ResponseStream result.") + context.result = context.result.with_result_hook( + lambda response: self._finalize_response( + service_call_context=service_call_context, + response=response, + ) + ) + return + + if isinstance(context.result, ResponseStream): + raise ValueError("Non-streaming chat middleware requires a ChatResponse result.") + context.result = await self._finalize_response( + service_call_context=service_call_context, + response=context.result, + ) + + +@deprecated( + "BaseContextProvider is deprecated. Use ContextProvider instead.", + category=DeprecationWarning, +) +class BaseContextProvider(ContextProvider): + """Deprecated alias for :class:`ContextProvider`. + + .. deprecated:: + BaseContextProvider is deprecated. Use :class:`ContextProvider` instead. + """ + + +@deprecated( + "BaseHistoryProvider is deprecated. Use HistoryProvider instead.", + category=DeprecationWarning, +) +class BaseHistoryProvider(HistoryProvider): + """Deprecated alias for :class:`HistoryProvider`. + + .. deprecated:: + BaseHistoryProvider is deprecated. Use :class:`HistoryProvider` instead. + """ + + class AgentSession: """A conversation session with an agent. @@ -535,7 +790,7 @@ def from_dict(cls, data: dict[str, Any]) -> AgentSession: return session -class InMemoryHistoryProvider(BaseHistoryProvider): +class InMemoryHistoryProvider(HistoryProvider): """Built-in history provider that stores messages in session.state. Messages are stored in ``state["messages"]`` as a list of diff --git a/python/packages/core/agent_framework/_skills.py b/python/packages/core/agent_framework/_skills.py index dc111750af..5c99dbaa60 100644 --- a/python/packages/core/agent_framework/_skills.py +++ b/python/packages/core/agent_framework/_skills.py @@ -36,7 +36,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Final, Protocol, runtime_checkable from ._feature_stage import ExperimentalFeature, experimental -from ._sessions import BaseContextProvider +from ._sessions import ContextProvider from ._tools import FunctionTool if TYPE_CHECKING: @@ -519,7 +519,7 @@ def __call__(self, skill: Skill, script: SkillScript, args: dict[str, Any] | Non @experimental(feature_id=ExperimentalFeature.SKILLS) -class SkillsProvider(BaseContextProvider): +class SkillsProvider(ContextProvider): """Context provider that advertises skills and exposes skill tools. 
Supports both **file-based** skills (discovered from ``SKILL.md`` files) diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 521a0c4d96..043187caa4 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -1688,6 +1688,34 @@ def _update_conversation_id( options["conversation_id"] = conversation_id +def _update_continuation_state( + kwargs: dict[str, Any], + response: ChatResponse[Any], + *, + session: AgentSession | None, + options: dict[str, Any] | None = None, +) -> None: + """Update in-flight and persisted continuation state from a response.""" + conversation_id = response.conversation_id + if conversation_id is None: + return + + _update_conversation_id(kwargs, conversation_id, options) + if ( + session is not None + and not response.has_internal_conversation_id() + and session.service_session_id != conversation_id + ): + session.service_session_id = conversation_id + + +def _clear_internal_conversation_id(response: ChatResponse[Any]) -> ChatResponse[Any]: + if response.has_internal_conversation_id(): + response.conversation_id = None + response.clear_internal_conversation_id() + return response + + def _extract_tools( options: dict[str, Any] | None, ) -> ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None: @@ -2206,9 +2234,14 @@ async def _get_response() -> ChatResponse[Any]: ), ) aggregated_usage = add_usage_details(aggregated_usage, response.usage_details) + _update_continuation_state( + filtered_kwargs, + response, + session=invocation_session, + options=mutable_options, + ) if response.conversation_id is not None: - _update_conversation_id(filtered_kwargs, response.conversation_id, mutable_options) prepped_messages = [] result = await _process_function_requests( @@ -2223,7 +2256,7 @@ async def _get_response() -> ChatResponse[Any]: ) if result.get("action") == "return": response.usage_details = aggregated_usage - return response + return _clear_internal_conversation_id(response) total_function_calls += result.get("function_call_count", 0) if result.get("action") == "stop": # Error threshold reached: force a final non-tool turn so @@ -2279,11 +2312,17 @@ async def _get_response() -> ChatResponse[Any]: ), ) aggregated_usage = add_usage_details(aggregated_usage, response.usage_details) + _update_continuation_state( + filtered_kwargs, + response, + session=invocation_session, + options=mutable_options, + ) response.usage_details = aggregated_usage if fcc_messages: for msg in reversed(fcc_messages): response.messages.insert(0, msg) - return response + return _clear_internal_conversation_id(response) return _get_response() @@ -2343,6 +2382,12 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]: # Get the finalized response from the inner stream # This triggers the inner stream's finalizer and result hooks response = await inner_stream.get_final_response() + _update_continuation_state( + filtered_kwargs, + response, + session=invocation_session, + options=mutable_options, + ) if not any( item.type in ("function_call", "function_approval_request") @@ -2352,7 +2397,6 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]: return if response.conversation_id is not None: - _update_conversation_id(filtered_kwargs, response.conversation_id, mutable_options) prepped_messages = [] result = await _process_function_requests( @@ -2430,7 +2474,13 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]: async for update in 
final_inner_stream: yield update # Finalize the inner stream to trigger its hooks - await final_inner_stream.get_final_response() + final_response = await final_inner_stream.get_final_response() + _update_continuation_state( + filtered_kwargs, + final_response, + session=invocation_session, + options=mutable_options, + ) def _finalize(updates: Sequence[ChatResponseUpdate]) -> ChatResponse[Any]: # Note: stream_result_hooks are already run via inner stream's get_final_response() diff --git a/python/packages/core/agent_framework/_types.py b/python/packages/core/agent_framework/_types.py index 6d6bc58068..bd468b8450 100644 --- a/python/packages/core/agent_framework/_types.py +++ b/python/packages/core/agent_framework/_types.py @@ -2001,6 +2001,7 @@ class ChatResponse(SerializationMixin, Generic[ResponseModelT]): """ DEFAULT_EXCLUDE: ClassVar[set[str]] = {"raw_representation", "additional_properties"} + _INTERNAL_CONVERSATION_ID_KEY: ClassVar[str] = "_agent_framework_internal_conversation_id" def __init__( self, @@ -2069,6 +2070,18 @@ def __init__( self.continuation_token = continuation_token self.raw_representation: Any | list[Any] | None = raw_representation + def mark_internal_conversation_id(self) -> None: + """Mark the current conversation_id as internal control-flow state.""" + self.additional_properties[self._INTERNAL_CONVERSATION_ID_KEY] = True + + def clear_internal_conversation_id(self) -> None: + """Remove the internal conversation-id marker.""" + self.additional_properties.pop(self._INTERNAL_CONVERSATION_ID_KEY, None) + + def has_internal_conversation_id(self) -> bool: + """Return whether conversation_id is internal control-flow state.""" + return bool(self.additional_properties.get(self._INTERNAL_CONVERSATION_ID_KEY, False)) + @property def model_id(self) -> str | None: """Deprecated alias for :attr:`model`.""" diff --git a/python/packages/core/agent_framework/_workflows/_agent.py b/python/packages/core/agent_framework/_workflows/_agent.py index bf615814b3..53df314a24 100644 --- a/python/packages/core/agent_framework/_workflows/_agent.py +++ b/python/packages/core/agent_framework/_workflows/_agent.py @@ -14,8 +14,8 @@ from .._agents import BaseAgent from .._sessions import ( AgentSession, - BaseContextProvider, - BaseHistoryProvider, + ContextProvider, + HistoryProvider, InMemoryHistoryProvider, SessionContext, ) @@ -86,7 +86,7 @@ def __init__( id: str | None = None, name: str | None = None, description: str | None = None, - context_providers: Sequence[BaseContextProvider] | None = None, + context_providers: Sequence[ContextProvider] | None = None, **kwargs: Any, ) -> None: """Initialize the WorkflowAgent. 
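An illustrative sketch (not part of the diff) of the sentinel lifecycle that the `mark_internal_conversation_id`/`has_internal_conversation_id`/`clear_internal_conversation_id` helpers added above implement. The `ChatResponse(messages=[])` construction and the inlined sentinel string (mirroring `LOCAL_HISTORY_CONVERSATION_ID`) are assumptions for demonstration only.

```python
from agent_framework import ChatResponse

# After a tool-call turn, the per-service-call middleware marks the sentinel so the
# function loop follows its service-managed branch without a real conversation id.
response = ChatResponse(messages=[])  # constructor shape assumed for this sketch
response.conversation_id = "agent_framework_local_history_persistence"  # mirrors LOCAL_HISTORY_CONVERSATION_ID
response.mark_internal_conversation_id()
assert response.has_internal_conversation_id()

# The function-calling loop strips the marker before returning, so callers never
# observe the internal sentinel on the final response.
response.conversation_id = None
response.clear_internal_conversation_id()
assert not response.has_internal_conversation_id()
```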
@@ -249,7 +249,7 @@ async def _run_impl( options={}, ) for provider in self.context_providers: - if isinstance(provider, BaseHistoryProvider) and not provider.load_messages: + if isinstance(provider, HistoryProvider) and not provider.load_messages: continue if provider_session is None: raise RuntimeError("Provider session must be available when context providers are configured.") @@ -314,7 +314,7 @@ async def _run_stream_impl( options={}, ) for provider in self.context_providers: - if isinstance(provider, BaseHistoryProvider) and not provider.load_messages: + if isinstance(provider, HistoryProvider) and not provider.load_messages: continue if provider_session is None: raise RuntimeError("Provider session must be available when context providers are configured.") diff --git a/python/packages/core/agent_framework/observability.py b/python/packages/core/agent_framework/observability.py index 236daa29a0..cdc6179602 100644 --- a/python/packages/core/agent_framework/observability.py +++ b/python/packages/core/agent_framework/observability.py @@ -1502,98 +1502,25 @@ def __init__( self.token_usage_histogram = _get_token_usage_histogram() self.duration_histogram = _get_duration_histogram() - @overload - def run( + def _trace_agent_invocation( self, - messages: AgentRunInputs | None = None, *, - stream: Literal[False] = ..., - session: AgentSession | None = None, - middleware: Sequence[MiddlewareTypes] | None = None, - tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None, - options: ChatOptions[ResponseModelBoundT], - compaction_strategy: CompactionStrategy | None = None, - tokenizer: TokenizerProtocol | None = None, - function_invocation_kwargs: Mapping[str, Any] | None = None, - client_kwargs: Mapping[str, Any] | None = None, - ) -> Awaitable[AgentResponse[ResponseModelBoundT]]: ... - - @overload - def run( - self, - messages: AgentRunInputs | None = None, - *, - stream: Literal[False] = ..., - session: AgentSession | None = None, - middleware: Sequence[MiddlewareTypes] | None = None, - tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None, - options: ChatOptions[None] | None = None, - compaction_strategy: CompactionStrategy | None = None, - tokenizer: TokenizerProtocol | None = None, - function_invocation_kwargs: Mapping[str, Any] | None = None, - client_kwargs: Mapping[str, Any] | None = None, - ) -> Awaitable[AgentResponse[Any]]: ... - - @overload - def run( - self, - messages: AgentRunInputs | None = None, - *, - stream: Literal[True], - session: AgentSession | None = None, - middleware: Sequence[MiddlewareTypes] | None = None, - tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None, - options: ChatOptions[Any] | None = None, - compaction_strategy: CompactionStrategy | None = None, - tokenizer: TokenizerProtocol | None = None, - function_invocation_kwargs: Mapping[str, Any] | None = None, - client_kwargs: Mapping[str, Any] | None = None, - ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... 
- - def run( - self, - messages: AgentRunInputs | None = None, - *, - stream: bool = False, - session: AgentSession | None = None, - middleware: Sequence[MiddlewareTypes] | None = None, - tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None, - options: ChatOptions[Any] | None = None, - compaction_strategy: CompactionStrategy | None = None, - tokenizer: TokenizerProtocol | None = None, - function_invocation_kwargs: Mapping[str, Any] | None = None, - client_kwargs: Mapping[str, Any] | None = None, + messages: AgentRunInputs | None, + session: AgentSession | None, + merged_options: Mapping[str, Any], + client_kwargs: Mapping[str, Any] | None, + stream: bool, + execute: Callable[[], Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]], ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: - """Trace agent runs with OpenTelemetry spans and metrics.""" + """Trace an agent invocation while delegating execution to ``execute``.""" global OBSERVABILITY_SETTINGS - from ._types import ResponseStream, merge_chat_options + from ._types import ResponseStream - super_run = cast( - "Callable[..., Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]]", - super().run, # type: ignore[misc] - ) - provider_name = str(self.otel_provider_name) - super_run_kwargs: dict[str, Any] = { - "messages": messages, - "stream": stream, - "session": session, - "tools": tools, - "options": options, - "compaction_strategy": compaction_strategy, - "tokenizer": tokenizer, - "function_invocation_kwargs": function_invocation_kwargs, - "client_kwargs": client_kwargs, - } - if middleware is not None: - super_run_kwargs["middleware"] = middleware if not OBSERVABILITY_SETTINGS.ENABLED: - return super_run(**super_run_kwargs) # type: ignore[no-any-return] + return execute() - default_options = dict(getattr(self, "default_options", {})) + provider_name = str(self.otel_provider_name) merged_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {} - merged_options: dict[str, Any] = merge_chat_options( - default_options, dict(options) if options is not None else {} - ) attributes = _get_span_attributes( operation_name=OtelAttr.AGENT_INVOKE_OPERATION, provider_name=provider_name, @@ -1601,7 +1528,7 @@ def run( agent_name=getattr(self, "name", None) or getattr(self, "id", "unknown"), agent_description=getattr(self, "description", None), thread_id=session.service_session_id if session else None, - all_options=merged_options, + all_options=dict(merged_options), **merged_client_kwargs, ) @@ -1613,7 +1540,7 @@ def run( if stream: try: - run_result: object = super_run(**super_run_kwargs) + run_result: object = execute() if isinstance(run_result, ResponseStream): result_stream: ResponseStream[AgentResponseUpdate, AgentResponse[Any]] = run_result # pyright: ignore[reportUnknownVariableType] elif isinstance(run_result, Awaitable): @@ -1625,10 +1552,6 @@ def run( INNER_ACCUMULATED_USAGE.reset(inner_accumulated_usage_token) raise - # Create span directly without trace.use_span() context attachment. - # Streaming spans are closed asynchronously in cleanup hooks, which run - # in a different async context than creation — using use_span() would - # cause "Failed to detach context" errors from OpenTelemetry. 
operation = attributes.get(OtelAttr.OPERATION, "operation") span_name = attributes.get(OtelAttr.AGENT_NAME, "unknown") span = get_tracer().start_span(f"{operation} {span_name}") @@ -1638,7 +1561,7 @@ def run( span=span, provider_name=provider_name, messages=messages, - system_instructions=_get_instructions_from_options(merged_options), + system_instructions=_get_instructions_from_options(dict(merged_options)), ) span_state = {"closed": False} @@ -1687,15 +1610,13 @@ async def _finalize_stream() -> None: INNER_ACCUMULATED_USAGE.reset(inner_accumulated_usage_token) _close_span() - # Register a weak reference callback to close the span if stream is garbage collected - # without being consumed. This ensures spans don't leak if users don't consume streams. wrapped_stream: ResponseStream[AgentResponseUpdate, AgentResponse[Any]] = result_stream.with_cleanup_hook( _record_duration ).with_cleanup_hook(_finalize_stream) weakref.finalize(wrapped_stream, _close_span) return wrapped_stream - async def _run() -> AgentResponse: + async def _run() -> AgentResponse[Any]: try: with _get_span(attributes=attributes, span_name_attribute=OtelAttr.AGENT_NAME) as span: if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED and messages: @@ -1703,11 +1624,11 @@ async def _run() -> AgentResponse: span=span, provider_name=provider_name, messages=messages, - system_instructions=_get_instructions_from_options(merged_options), + system_instructions=_get_instructions_from_options(dict(merged_options)), ) start_time_stamp = perf_counter() try: - response: AgentResponse[Any] = await super_run(**super_run_kwargs) + response: AgentResponse[Any] = await execute() except Exception as exception: capture_exception(span=span, exception=exception, timestamp=time_ns()) raise @@ -1736,6 +1657,103 @@ async def _run() -> AgentResponse: return _run() + @overload + def run( + self, + messages: AgentRunInputs | None = None, + *, + stream: Literal[False] = ..., + session: AgentSession | None = None, + middleware: Sequence[MiddlewareTypes] | None = None, + tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None, + options: ChatOptions[ResponseModelBoundT], + compaction_strategy: CompactionStrategy | None = None, + tokenizer: TokenizerProtocol | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, + ) -> Awaitable[AgentResponse[ResponseModelBoundT]]: ... + + @overload + def run( + self, + messages: AgentRunInputs | None = None, + *, + stream: Literal[False] = ..., + session: AgentSession | None = None, + middleware: Sequence[MiddlewareTypes] | None = None, + tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None, + options: ChatOptions[None] | None = None, + compaction_strategy: CompactionStrategy | None = None, + tokenizer: TokenizerProtocol | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, + ) -> Awaitable[AgentResponse[Any]]: ... 
+ + @overload + def run( + self, + messages: AgentRunInputs | None = None, + *, + stream: Literal[True], + session: AgentSession | None = None, + middleware: Sequence[MiddlewareTypes] | None = None, + tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None, + options: ChatOptions[Any] | None = None, + compaction_strategy: CompactionStrategy | None = None, + tokenizer: TokenizerProtocol | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + + def run( + self, + messages: AgentRunInputs | None = None, + *, + stream: bool = False, + session: AgentSession | None = None, + middleware: Sequence[MiddlewareTypes] | None = None, + tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None, + options: ChatOptions[Any] | None = None, + compaction_strategy: CompactionStrategy | None = None, + tokenizer: TokenizerProtocol | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, + ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: + """Trace agent runs with OpenTelemetry spans and metrics.""" + from ._types import merge_chat_options + + super_run = cast( + "Callable[..., Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]]", + super().run, # type: ignore[misc] + ) + super_run_kwargs: dict[str, Any] = { + "messages": messages, + "stream": stream, + "session": session, + "tools": tools, + "options": options, + "compaction_strategy": compaction_strategy, + "tokenizer": tokenizer, + "function_invocation_kwargs": function_invocation_kwargs, + "client_kwargs": client_kwargs, + } + if middleware is not None: + super_run_kwargs["middleware"] = middleware + + default_options = dict(getattr(self, "default_options", {})) + merged_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {} + merged_options: dict[str, Any] = merge_chat_options( + default_options, dict(options) if options is not None else {} + ) + return self._trace_agent_invocation( + messages=messages, + session=session, + merged_options=merged_options, + client_kwargs=merged_client_kwargs, + stream=stream, + execute=lambda: super_run(**super_run_kwargs), + ) + # region Otel Helpers diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index 94253b3c34..015de2a2e3 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -3,8 +3,8 @@ import contextlib import inspect import json -from collections.abc import AsyncIterable, MutableSequence -from typing import Any +from collections.abc import AsyncIterable, Awaitable, Callable, MutableSequence, Sequence +from typing import Any, cast from unittest.mock import AsyncMock, MagicMock, patch from uuid import uuid4 @@ -18,22 +18,29 @@ AgentResponse, AgentResponseUpdate, AgentSession, - BaseContextProvider, + ChatContext, ChatOptions, ChatResponse, ChatResponseUpdate, Content, + ContextProvider, FunctionTool, + HistoryProvider, + InMemoryHistoryProvider, Message, + ResponseStream, + SessionContext, SlidingWindowStrategy, SupportsAgentRun, SupportsChatGetResponse, TruncationStrategy, + chat_middleware, tool, ) from agent_framework._agents import _get_tool_name, _merge_options, _sanitize_agent_name from agent_framework._mcp 
import MCPTool, _build_prefixed_mcp_name, _normalize_mcp_name from agent_framework._middleware import FunctionInvocationContext +from agent_framework.exceptions import AgentInvalidRequestException, ChatClientInvalidResponseException class _FixedTokenizer: @@ -68,6 +75,49 @@ def get_mcp_client(self) -> contextlib.AbstractAsyncContextManager[Any]: raise NotImplementedError +class _RecordingHistoryProvider(HistoryProvider): + def __init__(self, source_id: str = "recording_history") -> None: + super().__init__(source_id=source_id) + + async def get_messages( + self, + session_id: str | None, + *, + state: dict[str, Any] | None = None, + **kwargs: Any, + ) -> list[Message]: + if state is None: + return [] + state["get_call_count"] = state.get("get_call_count", 0) + 1 + return list(cast(list[Message], state.get("messages", []))) + + async def save_messages( + self, + session_id: str | None, + messages: Sequence[Message], + *, + state: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: + if state is None: + return + state["save_call_count"] = state.get("save_call_count", 0) + 1 + state.setdefault("messages", []).extend(messages) + + +class _ResponseIdRecordingHistoryProvider(_RecordingHistoryProvider): + async def after_run( + self, + *, + agent: SupportsAgentRun, + session: AgentSession, + context: SessionContext, + state: dict[str, Any], + ) -> None: + state.setdefault("response_ids", []).append(context.response.response_id if context.response else None) + await super().after_run(agent=agent, session=session, context=context, state=state) + + def test_agent_session_type(agent_session: AgentSession) -> None: assert isinstance(agent_session, AgentSession) @@ -314,6 +364,413 @@ async def test_prepare_run_context_handles_function_kwargs( assert ctx["client_kwargs"]["session"] is session +async def test_chat_agent_persists_history_per_service_call( + chat_client_base: SupportsChatGetResponse, +) -> None: + provider = _RecordingHistoryProvider() + + @tool(name="lookup_weather", approval_mode="never_require") + def lookup_weather(location: str) -> str: + return f"Weather in {location}: sunny" + + session = AgentSession() + session.state[provider.source_id] = { + "messages": [ + Message(role="user", text="Earlier question"), + Message(role="assistant", text="Earlier answer"), + ] + } + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call( + call_id="call_1", + name="lookup_weather", + arguments='{"location": "Seattle"}', + ) + ], + ), + response_id="resp_call_1", + ), + ChatResponse(messages=Message(role="assistant", text="It is sunny in Seattle."), response_id="resp_call_2"), + ] + + agent = Agent( + client=chat_client_base, + tools=[lookup_weather], + context_providers=[provider], + require_per_service_call_history_persistence=True, + ) + + result = await agent.run("What's the weather in Seattle?", session=session) + + provider_state = session.state[provider.source_id] + stored_messages = cast(list[Message], provider_state["messages"]) + + assert result.text == "It is sunny in Seattle." + assert result.response_id is None + assert chat_client_base.call_count == 2 + assert provider_state["get_call_count"] == 2 + assert provider_state["save_call_count"] == 2 + assert stored_messages[-1].text == "It is sunny in Seattle." 
+ assert session.service_session_id is None + + +async def test_chat_agent_persists_history_per_service_call_streaming( + chat_client_base: SupportsChatGetResponse, +) -> None: + provider = _RecordingHistoryProvider() + + @tool(name="lookup_weather", approval_mode="never_require") + def lookup_weather(location: str) -> str: + return f"Weather in {location}: sunny" + + session = AgentSession() + session.state[provider.source_id] = { + "messages": [ + Message(role="user", text="Earlier question"), + Message(role="assistant", text="Earlier answer"), + ] + } + chat_client_base.streaming_responses = [ + [ + ChatResponseUpdate( + contents=[ + Content.from_function_call( + call_id="call_1", + name="lookup_weather", + arguments='{"location": "Seattle"}', + ) + ], + role="assistant", + finish_reason="stop", + response_id="resp_call_1", + ) + ], + [ + ChatResponseUpdate( + contents=[Content.from_text("It is sunny in Seattle.")], + role="assistant", + finish_reason="stop", + response_id="resp_call_2", + ) + ], + ] + + agent = Agent( + client=chat_client_base, + tools=[lookup_weather], + context_providers=[provider], + require_per_service_call_history_persistence=True, + ) + + stream = agent.run("What's the weather in Seattle?", session=session, stream=True) + async for _ in stream: + pass + result = await stream.get_final_response() + + provider_state = session.state[provider.source_id] + stored_messages = cast(list[Message], provider_state["messages"]) + + assert result.text == "It is sunny in Seattle." + assert result.response_id is None + assert chat_client_base.call_count == 2 + assert provider_state["get_call_count"] == 2 + assert provider_state["save_call_count"] == 2 + assert stored_messages[-1].text == "It is sunny in Seattle." + assert session.service_session_id is None + + +async def test_streaming_per_service_call_persistence_hides_response_id_from_after_run( + chat_client_base: SupportsChatGetResponse, +) -> None: + provider = _ResponseIdRecordingHistoryProvider() + + @tool(name="lookup_weather", approval_mode="never_require") + def lookup_weather(location: str) -> str: + return f"Weather in {location}: sunny" + + session = AgentSession() + session.state[provider.source_id] = {"messages": []} + chat_client_base.streaming_responses = [ + [ + ChatResponseUpdate( + contents=[ + Content.from_function_call( + call_id="call_1", + name="lookup_weather", + arguments='{"location": "Seattle"}', + ) + ], + role="assistant", + finish_reason="stop", + response_id="resp_call_1", + ) + ], + [ + ChatResponseUpdate( + contents=[Content.from_text("It is sunny in Seattle.")], + role="assistant", + finish_reason="stop", + response_id="resp_call_2", + ) + ], + ] + + agent = Agent( + client=chat_client_base, + tools=[lookup_weather], + context_providers=[provider], + require_per_service_call_history_persistence=True, + ) + + stream = agent.run("What's the weather in Seattle?", session=session, stream=True) + async for _ in stream: + pass + result = await stream.get_final_response() + + provider_state = session.state[provider.source_id] + + assert result.response_id is None + assert provider_state["response_ids"] == [None, None] + + +async def test_per_service_call_persistence_uses_real_service_storage_when_client_stores_by_default( + chat_client_base: SupportsChatGetResponse, +) -> None: + provider = _RecordingHistoryProvider() + + @tool(name="lookup_weather", approval_mode="never_require") + def lookup_weather(location: str) -> str: + return f"Weather in {location}: sunny" + + 
chat_client_base.STORES_BY_DEFAULT = True # type: ignore[attr-defined] + + session = AgentSession() + session.state[provider.source_id] = {"messages": []} + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call( + call_id="call_1", + name="lookup_weather", + arguments='{"location": "Seattle"}', + ) + ], + ), + conversation_id="resp_service_managed", + response_id="resp_call_1", + ), + ChatResponse( + messages=Message(role="assistant", text="It is sunny in Seattle."), + conversation_id="resp_service_managed", + response_id="resp_call_2", + ), + ] + + agent = Agent( + client=chat_client_base, + tools=[lookup_weather], + context_providers=[provider], + require_per_service_call_history_persistence=True, + ) + + result = await agent.run("What's the weather in Seattle?", session=session) + + provider_state = session.state[provider.source_id] + + assert result.text == "It is sunny in Seattle." + assert result.response_id == "resp_call_2" + assert chat_client_base.call_count == 2 + assert "get_call_count" not in provider_state + assert "save_call_count" not in provider_state + assert session.service_session_id == "resp_service_managed" + + +async def test_service_storage_updates_session_handle_per_service_call_before_non_streaming_failure( + chat_client_base: SupportsChatGetResponse, +) -> None: + provider = _RecordingHistoryProvider() + + @tool(name="lookup_weather", approval_mode="never_require") + def lookup_weather(location: str) -> str: + return f"Weather in {location}: sunny" + + chat_client_base.STORES_BY_DEFAULT = True # type: ignore[attr-defined] + + session = AgentSession() + session.state[provider.source_id] = {"messages": []} + first_response = ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call( + call_id="call_1", + name="lookup_weather", + arguments='{"location": "Seattle"}', + ) + ], + ), + conversation_id="resp_call_1", + response_id="resp_call_1", + ) + mock_get_non_streaming_response = AsyncMock( + side_effect=[first_response, RuntimeError("service down")], + ) + + agent = Agent( + client=chat_client_base, + tools=[lookup_weather], + context_providers=[provider], + require_per_service_call_history_persistence=True, + ) + + with ( + patch.object(chat_client_base, "_get_non_streaming_response", new=mock_get_non_streaming_response), + pytest.raises(RuntimeError, match="service down"), + ): + await agent.run("What's the weather in Seattle?", session=session) + + assert mock_get_non_streaming_response.await_count == 2 + assert session.service_session_id == "resp_call_1" + + +async def test_service_storage_updates_session_handle_per_service_call_before_streaming_failure( + chat_client_base: SupportsChatGetResponse, +) -> None: + provider = _RecordingHistoryProvider() + + @tool(name="lookup_weather", approval_mode="never_require") + def lookup_weather(location: str) -> str: + return f"Weather in {location}: sunny" + + chat_client_base.STORES_BY_DEFAULT = True # type: ignore[attr-defined] + + session = AgentSession() + session.state[provider.source_id] = {"messages": []} + + async def _first_stream_updates() -> AsyncIterable[ChatResponseUpdate]: + yield ChatResponseUpdate( + contents=[ + Content.from_function_call( + call_id="call_1", + name="lookup_weather", + arguments='{"location": "Seattle"}', + ) + ], + role="assistant", + finish_reason="stop", + ) + + def _finalize_first_stream(_updates: Sequence[ChatResponseUpdate]) -> ChatResponse[Any]: + return 
ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call( + call_id="call_1", + name="lookup_weather", + arguments='{"location": "Seattle"}', + ) + ], + ), + conversation_id="resp_call_1", + response_id="resp_call_1", + ) + + first_stream = ResponseStream(_first_stream_updates(), finalizer=_finalize_first_stream) + mock_get_streaming_response = MagicMock(side_effect=[first_stream, RuntimeError("service down")]) + + agent = Agent( + client=chat_client_base, + tools=[lookup_weather], + context_providers=[provider], + require_per_service_call_history_persistence=True, + ) + + with ( + patch.object(chat_client_base, "_get_streaming_response", new=mock_get_streaming_response), + pytest.raises(RuntimeError, match="service down"), + ): + stream = agent.run("What's the weather in Seattle?", session=session, stream=True) + async for _ in stream: + pass + + assert mock_get_streaming_response.call_count == 2 + assert session.service_session_id == "resp_call_1" + + +async def test_chat_agent_without_per_service_call_persistence_preserves_response_id( + chat_client_base: SupportsChatGetResponse, +) -> None: + chat_client_base.run_responses = [ + ChatResponse( + messages=Message(role="assistant", text="Hello"), + response_id="resp_call_1", + ) + ] + + agent = Agent( + client=chat_client_base, + context_providers=[InMemoryHistoryProvider()], + ) + + result = await agent.run("Hello", session=AgentSession(), options={"store": False}) + + assert result.response_id == "resp_call_1" + + +async def test_per_service_call_persistence_rejects_real_service_conversation_id( + chat_client_base: SupportsChatGetResponse, +) -> None: + provider = _RecordingHistoryProvider() + chat_client_base.STORES_BY_DEFAULT = True # type: ignore[attr-defined] + session = AgentSession() + session.state[provider.source_id] = {"messages": []} + chat_client_base.run_responses = [ + ChatResponse( + messages=Message(role="assistant", text="Hello"), + conversation_id="resp_service_managed", + ) + ] + + agent = Agent( + client=chat_client_base, + context_providers=[provider], + require_per_service_call_history_persistence=True, + ) + + with pytest.raises( + ChatClientInvalidResponseException, + match="require_per_service_call_history_persistence cannot be used", + ): + await agent.run("Hello", session=session, options={"store": False}) + + +async def test_per_service_call_persistence_rejects_existing_conversation_id_when_service_not_storing_history( + chat_client_base: SupportsChatGetResponse, +) -> None: + provider = _RecordingHistoryProvider() + session = AgentSession() + session.state[provider.source_id] = {"messages": []} + + agent = Agent( + client=chat_client_base, + context_providers=[provider], + require_per_service_call_history_persistence=True, + ) + + with pytest.raises( + AgentInvalidRequestException, + match="require_per_service_call_history_persistence cannot be used", + ): + await agent.run("Hello", session=session, options={"store": False, "conversation_id": "existing_conversation"}) + + async def test_chat_client_agent_run_with_session(chat_client_base: SupportsChatGetResponse) -> None: mock_response = ChatResponse( messages=[Message(role="assistant", contents=[Content.from_text("test response")])], @@ -586,7 +1043,7 @@ async def test_chat_client_agent_author_name_is_used_from_response( # Mock context provider for testing -class MockContextProvider(BaseContextProvider): +class MockContextProvider(ContextProvider): def __init__(self, messages: list[Message] | None = None) -> None: 
super().__init__(source_id="mock") self.context_messages = messages @@ -1723,7 +2180,7 @@ async def test_agent_create_session_with_context_providers( ): """Test that create_session works when context_providers are set on the agent.""" - class TestContextProvider(BaseContextProvider): + class TestContextProvider(ContextProvider): def __init__(self): super().__init__(source_id="test") @@ -1798,7 +2255,7 @@ def context_tool(text: str) -> str: """A tool provided by context.""" return text - class ToolContextProvider(BaseContextProvider): + class ToolContextProvider(ContextProvider): def __init__(self): super().__init__(source_id="tool-context") @@ -1827,7 +2284,7 @@ async def test_chat_agent_context_provider_adds_instructions_when_agent_has_none ): """Test that context provider instructions are used when agent has no default instructions.""" - class InstructionContextProvider(BaseContextProvider): + class InstructionContextProvider(ContextProvider): def __init__(self): super().__init__(source_id="instruction-context") @@ -1849,6 +2306,33 @@ async def before_run(self, *, agent, session, context, state): assert options.get("instructions") == "Context-provided instructions" +async def test_chat_agent_context_provider_adds_middleware_when_agent_has_none( + chat_client_base: SupportsChatGetResponse, +) -> None: + """Test that context provider middleware is collected during preparation.""" + + @chat_middleware + async def context_chat_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: + await call_next() + + class MiddlewareContextProvider(ContextProvider): + def __init__(self) -> None: + super().__init__(source_id="middleware-context") + + async def before_run(self, *, agent, session, context, state) -> None: + context.extend_middleware("middleware-context", context_chat_middleware) + + agent = Agent(client=chat_client_base, context_providers=[MiddlewareContextProvider()]) + + session_context, _ = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + session=None, + input_messages=[Message(role="user", text="Hello")], + ) + + assert session_context.middleware["middleware-context"] == [context_chat_middleware] + assert session_context.get_middleware() == [context_chat_middleware] + + # region STORES_BY_DEFAULT tests diff --git a/python/packages/core/tests/core/test_middleware_with_agent.py b/python/packages/core/tests/core/test_middleware_with_agent.py index 69d08482d3..508bdee075 100644 --- a/python/packages/core/tests/core/test_middleware_with_agent.py +++ b/python/packages/core/tests/core/test_middleware_with_agent.py @@ -15,6 +15,7 @@ ChatResponse, ChatResponseUpdate, Content, + ContextProvider, FunctionInvocationContext, FunctionMiddleware, FunctionTool, @@ -464,6 +465,31 @@ async def function_function_middleware( expected_order = ["class_agent_before", "function_agent_before", "function_agent_after", "class_agent_after"] assert execution_order == expected_order + async def test_provider_added_agent_middleware_is_rejected(self, chat_client_base: "MockBaseChatClient") -> None: + """Test provider-added agent middleware is rejected explicitly.""" + + @agent_middleware + async def provider_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: + await call_next() + + class ProviderMiddlewareContextProvider(ContextProvider): + def __init__(self) -> None: + super().__init__(source_id="provider-middleware") + + async def before_run(self, *, agent, session, context, state) -> None: + 
context.extend_middleware(self.source_id, provider_middleware) + + agent = Agent( + client=chat_client_base, + context_providers=[ProviderMiddlewareContextProvider()], + ) + + with pytest.raises( + MiddlewareException, + match="Context providers may only add chat or function middleware", + ): + await agent.run([Message(role="user", text="test message")]) + # region Tool Functions for Testing @@ -2066,6 +2092,121 @@ async def tracking_function_middleware( "agent_middleware_after", ] + async def test_provider_added_chat_and_function_middleware_are_forwarded( + self, chat_client_base: "MockBaseChatClient" + ) -> None: + """Test provider-added chat and function middleware forwarding and ordering.""" + execution_order: list[str] = [] + + @chat_middleware + async def constructor_chat_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: + execution_order.append("constructor_chat_before") + await call_next() + execution_order.append("constructor_chat_after") + + @chat_middleware + async def provider_chat_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: + execution_order.append("provider_chat_before") + await call_next() + execution_order.append("provider_chat_after") + + @chat_middleware + async def run_chat_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: + execution_order.append("run_chat_before") + await call_next() + execution_order.append("run_chat_after") + + @function_middleware + async def constructor_function_middleware( + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] + ) -> None: + execution_order.append("constructor_function_before") + await call_next() + execution_order.append("constructor_function_after") + + @function_middleware + async def provider_function_middleware( + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] + ) -> None: + execution_order.append("provider_function_before") + await call_next() + execution_order.append("provider_function_after") + + @function_middleware + async def run_function_middleware( + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] + ) -> None: + execution_order.append("run_function_before") + await call_next() + execution_order.append("run_function_after") + + class ProviderMiddlewareContextProvider(ContextProvider): + def __init__(self) -> None: + super().__init__(source_id="provider-middleware") + + async def before_run(self, *, agent, session, context, state) -> None: + context.extend_middleware( + self.source_id, + [ + provider_chat_middleware, + provider_function_middleware, + ], + ) + + chat_client_base.run_responses = [ + ChatResponse( + messages=[ + Message( + role="assistant", + contents=[ + Content.from_function_call( + call_id="call_provider", + name="sample_tool_function", + arguments='{"location": "Seattle"}', + ) + ], + ) + ] + ), + ChatResponse(messages=[Message(role="assistant", text="Final response")]), + ] + + agent = Agent( + client=chat_client_base, + middleware=[constructor_chat_middleware, constructor_function_middleware], + context_providers=[ProviderMiddlewareContextProvider()], + tools=[sample_tool_function], + ) + + response = await agent.run( + [Message(role="user", text="Get weather for Seattle")], + middleware=[run_chat_middleware, run_function_middleware], + ) + + assert response is not None + assert chat_client_base.call_count == 2 + assert response.messages[-1].text == "Final response" + assert execution_order == [ + 
"constructor_chat_before", + "run_chat_before", + "provider_chat_before", + "provider_chat_after", + "run_chat_after", + "constructor_chat_after", + "constructor_function_before", + "run_function_before", + "provider_function_before", + "provider_function_after", + "run_function_after", + "constructor_function_after", + "constructor_chat_before", + "run_chat_before", + "provider_chat_before", + "provider_chat_after", + "run_chat_after", + "constructor_chat_after", + ] + async def test_agent_middleware_can_access_and_override_options(self) -> None: """Test that agent middleware can access and override runtime options.""" captured_options: dict[str, Any] = {} diff --git a/python/packages/core/tests/core/test_sessions.py b/python/packages/core/tests/core/test_sessions.py index bd2cb8155e..3d4d75afe5 100644 --- a/python/packages/core/tests/core/test_sessions.py +++ b/python/packages/core/tests/core/test_sessions.py @@ -1,16 +1,26 @@ # Copyright (c) Microsoft. All rights reserved. import json -from collections.abc import Sequence +from collections.abc import Awaitable, Callable, Sequence -from agent_framework import Message -from agent_framework._sessions import ( +import pytest + +from agent_framework import ( + AgentContext, AgentSession, BaseContextProvider, BaseHistoryProvider, + ChatContext, + ContextProvider, + HistoryProvider, InMemoryHistoryProvider, + Message, SessionContext, + agent_middleware, + chat_middleware, ) +from agent_framework._sessions import LOCAL_HISTORY_CONVERSATION_ID, is_local_history_conversation_id +from agent_framework.exceptions import MiddlewareException # --------------------------------------------------------------------------- # SessionContext tests @@ -102,6 +112,50 @@ def test_extend_instructions_sequence(self) -> None: ctx.extend_instructions("sys", ["Be helpful", "Be concise"]) assert ctx.instructions == ["Be helpful", "Be concise"] + def test_extend_middleware_creates_key_and_appends(self) -> None: + ctx = SessionContext(input_messages=[]) + + @chat_middleware + async def first_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: + await call_next() + + @chat_middleware + async def second_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: + await call_next() + + ctx.extend_middleware("rag", first_middleware) + ctx.extend_middleware("rag", [second_middleware]) + + assert ctx.middleware["rag"] == [first_middleware, second_middleware] + assert ctx.get_middleware() == [first_middleware, second_middleware] + + def test_extend_middleware_preserves_source_order(self) -> None: + ctx = SessionContext(input_messages=[]) + + @chat_middleware + async def first_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: + await call_next() + + @chat_middleware + async def second_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: + await call_next() + + ctx.extend_middleware("a", first_middleware) + ctx.extend_middleware("b", second_middleware) + + assert list(ctx.middleware.keys()) == ["a", "b"] + assert ctx.get_middleware() == [first_middleware, second_middleware] + + def test_extend_middleware_rejects_agent_middleware(self) -> None: + ctx = SessionContext(input_messages=[]) + + @agent_middleware + async def provider_agent_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: + await call_next() + + with pytest.raises(MiddlewareException, match="Context providers may only add chat or function middleware"): + 
ctx.extend_middleware("rag", provider_agent_middleware) + def test_get_messages_all(self) -> None: ctx = SessionContext(input_messages=[]) ctx.extend_messages("a", [Message(role="user", contents=["a"])]) @@ -154,37 +208,58 @@ def test_response_readonly(self) -> None: ctx._response = resp assert ctx.response is resp + def test_local_history_conversation_id_sentinel(self) -> None: + assert is_local_history_conversation_id(LOCAL_HISTORY_CONVERSATION_ID) is True + assert is_local_history_conversation_id("some_other_id") is False + # --------------------------------------------------------------------------- -# BaseContextProvider tests +# ContextProvider tests # --------------------------------------------------------------------------- -class TestContextProviderBase: +class TestContextProvider: def test_source_id_required(self) -> None: - provider = BaseContextProvider(source_id="test") + provider = ContextProvider(source_id="test") assert provider.source_id == "test" async def test_before_run_is_noop(self) -> None: - provider = BaseContextProvider(source_id="test") + provider = ContextProvider(source_id="test") session = AgentSession() ctx = SessionContext(input_messages=[]) # Should not raise await provider.before_run(agent=None, session=session, context=ctx, state={}) # type: ignore[arg-type] async def test_after_run_is_noop(self) -> None: - provider = BaseContextProvider(source_id="test") + provider = ContextProvider(source_id="test") session = AgentSession() ctx = SessionContext(input_messages=[]) await provider.after_run(agent=None, session=session, context=ctx, state={}) # type: ignore[arg-type] # --------------------------------------------------------------------------- -# BaseHistoryProvider tests +# Deprecated provider alias tests +# --------------------------------------------------------------------------- + + +class TestDeprecatedProviderAliases: + def test_base_context_provider_warns_and_is_compatible(self) -> None: + with pytest.warns(DeprecationWarning, match="BaseContextProvider is deprecated. 
Use ContextProvider instead."): + provider = BaseContextProvider(source_id="test") + + assert isinstance(provider, ContextProvider) + + def test_base_provider_aliases_preserve_subtyping(self) -> None: + assert issubclass(BaseContextProvider, ContextProvider) + assert issubclass(BaseHistoryProvider, HistoryProvider) + + +# --------------------------------------------------------------------------- +# HistoryProvider tests # --------------------------------------------------------------------------- -class ConcreteHistoryProvider(BaseHistoryProvider): +class ConcreteHistoryProvider(HistoryProvider): """Concrete test implementation.""" def __init__(self, source_id: str, stored_messages: list[Message] | None = None, **kwargs) -> None: diff --git a/python/packages/foundry/agent_framework_foundry/_agent.py b/python/packages/foundry/agent_framework_foundry/_agent.py index 6f548b4012..a499abf6f4 100644 --- a/python/packages/foundry/agent_framework_foundry/_agent.py +++ b/python/packages/foundry/agent_framework_foundry/_agent.py @@ -17,9 +17,9 @@ from agent_framework import ( AGENT_FRAMEWORK_USER_AGENT, AgentMiddlewareLayer, - BaseContextProvider, ChatAndFunctionMiddlewareTypes, ChatMiddlewareLayer, + ContextProvider, FunctionInvocationConfiguration, FunctionInvocationLayer, FunctionTool, @@ -50,8 +50,8 @@ if TYPE_CHECKING: from agent_framework import ( Agent, - BaseContextProvider, ChatAndFunctionMiddlewareTypes, + ContextProvider, MiddlewareTypes, ToolTypes, ) @@ -224,8 +224,9 @@ def as_agent( instructions: str | None = None, tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None, default_options: FoundryAgentOptionsT | Mapping[str, Any] | None = None, - context_providers: Sequence[BaseContextProvider] | None = None, + context_providers: Sequence[ContextProvider] | None = None, middleware: Sequence[MiddlewareTypes] | None = None, + require_per_service_call_history_persistence: bool = False, function_invocation_configuration: FunctionInvocationConfiguration | None = None, compaction_strategy: CompactionStrategy | None = None, tokenizer: TokenizerProtocol | None = None, @@ -246,6 +247,7 @@ def as_agent( tools=function_tools, context_providers=context_providers, middleware=middleware, + require_per_service_call_history_persistence=require_per_service_call_history_persistence, client_type=cast(type[RawFoundryAgentChatClient], self.__class__), id=id, name=self.agent_name if name is None else name, @@ -468,7 +470,7 @@ def __init__( project_client: AIProjectClient | None = None, allow_preview: bool | None = None, tools: FunctionTool | Callable[..., Any] | Sequence[FunctionTool | Callable[..., Any]] | None = None, - context_providers: Sequence[BaseContextProvider] | None = None, + context_providers: Sequence[ContextProvider] | None = None, middleware: Sequence[MiddlewareTypes] | None = None, client_type: type[RawFoundryAgentChatClient] | None = None, env_file_path: str | None = None, @@ -478,6 +480,7 @@ def __init__( description: str | None = None, instructions: str | None = None, default_options: FoundryAgentOptionsT | Mapping[str, Any] | None = None, + require_per_service_call_history_persistence: bool = False, function_invocation_configuration: FunctionInvocationConfiguration | None = None, compaction_strategy: CompactionStrategy | None = None, tokenizer: TokenizerProtocol | None = None, @@ -507,6 +510,8 @@ def __init__( description: Optional local description for the local agent wrapper. instructions: Optional instructions for the local agent wrapper. 
default_options: Default chat options for the local agent wrapper. + require_per_service_call_history_persistence: Whether to require per-service-call + chat history persistence when using local history providers. function_invocation_configuration: Optional function invocation configuration override. compaction_strategy: Optional agent-level in-run compaction override. tokenizer: Optional agent-level tokenizer override. @@ -548,6 +553,7 @@ def __init__( default_options=cast(FoundryAgentOptionsT | None, default_options), context_providers=context_providers, middleware=middleware, + require_per_service_call_history_persistence=require_per_service_call_history_persistence, compaction_strategy=compaction_strategy, tokenizer=tokenizer, additional_properties=dict(additional_properties) if additional_properties is not None else None, @@ -661,7 +667,7 @@ def __init__( project_client: AIProjectClient | None = None, allow_preview: bool | None = None, tools: FunctionTool | Callable[..., Any] | Sequence[FunctionTool | Callable[..., Any]] | None = None, - context_providers: Sequence[BaseContextProvider] | None = None, + context_providers: Sequence[ContextProvider] | None = None, middleware: Sequence[MiddlewareTypes] | None = None, client_type: type[RawFoundryAgentChatClient] | None = None, env_file_path: str | None = None, @@ -671,6 +677,7 @@ def __init__( description: str | None = None, instructions: str | None = None, default_options: FoundryAgentOptionsT | Mapping[str, Any] | None = None, + require_per_service_call_history_persistence: bool = False, function_invocation_configuration: FunctionInvocationConfiguration | None = None, compaction_strategy: CompactionStrategy | None = None, tokenizer: TokenizerProtocol | None = None, @@ -696,6 +703,8 @@ def __init__( description: Optional local description for the local agent wrapper. instructions: Optional instructions for the local agent wrapper. default_options: Default chat options for the local agent wrapper. + require_per_service_call_history_persistence: Whether to require per-service-call + chat history persistence when using local history providers. function_invocation_configuration: Optional function invocation configuration override. compaction_strategy: Optional agent-level in-run compaction override. tokenizer: Optional agent-level tokenizer override. @@ -719,6 +728,7 @@ def __init__( description=description, instructions=instructions, default_options=default_options, + require_per_service_call_history_persistence=require_per_service_call_history_persistence, function_invocation_configuration=function_invocation_configuration, compaction_strategy=compaction_strategy, tokenizer=tokenizer, diff --git a/python/packages/foundry/agent_framework_foundry/_memory_provider.py b/python/packages/foundry/agent_framework_foundry/_memory_provider.py index 36d4a27a43..742b2e4753 100644 --- a/python/packages/foundry/agent_framework_foundry/_memory_provider.py +++ b/python/packages/foundry/agent_framework_foundry/_memory_provider.py @@ -1,9 +1,9 @@ # Copyright (c) Microsoft. All rights reserved. -"""Foundry Memory Context Provider using BaseContextProvider. +"""Foundry Memory Context Provider using ContextProvider. This module provides ``FoundryMemoryProvider``, built on -:class:`BaseContextProvider`. +:class:`ContextProvider`. 
""" from __future__ import annotations @@ -16,7 +16,7 @@ from agent_framework import ( AGENT_FRAMEWORK_USER_AGENT, AgentSession, - BaseContextProvider, + ContextProvider, Message, SessionContext, load_settings, @@ -46,8 +46,8 @@ class FoundryProjectSettings(TypedDict, total=False): project_endpoint: str | None -class FoundryMemoryProvider(BaseContextProvider): - """Foundry Memory context provider using the new BaseContextProvider hooks pattern. +class FoundryMemoryProvider(ContextProvider): + """Foundry Memory context provider using the new ContextProvider hooks pattern. Integrates Azure AI Foundry Memory Store for persistent semantic memory, searching and storing memories via the Azure AI Projects SDK. diff --git a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py index 69f3bf20d4..8bec66737a 100644 --- a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py +++ b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py @@ -15,8 +15,8 @@ AgentResponseUpdate, AgentSession, BaseAgent, - BaseContextProvider, Content, + ContextProvider, Message, ResponseStream, normalize_messages, @@ -178,7 +178,7 @@ def __init__( id: str | None = None, name: str | None = None, description: str | None = None, - context_providers: Sequence[BaseContextProvider] | None = None, + context_providers: Sequence[ContextProvider] | None = None, middleware: Sequence[AgentMiddlewareTypes] | None = None, tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None, default_options: OptionsT | None = None, diff --git a/python/packages/mem0/agent_framework_mem0/_context_provider.py b/python/packages/mem0/agent_framework_mem0/_context_provider.py index 36b878e411..6aef321b15 100644 --- a/python/packages/mem0/agent_framework_mem0/_context_provider.py +++ b/python/packages/mem0/agent_framework_mem0/_context_provider.py @@ -1,9 +1,9 @@ # Copyright (c) Microsoft. All rights reserved. -"""New-pattern Mem0 context provider using BaseContextProvider. +"""New-pattern Mem0 context provider using ContextProvider. This module provides ``Mem0ContextProvider``, built on the new -:class:`BaseContextProvider` hooks pattern. +:class:`ContextProvider` hooks pattern. """ from __future__ import annotations @@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, Any, ClassVar from agent_framework import Message -from agent_framework._sessions import AgentSession, BaseContextProvider, SessionContext +from agent_framework._sessions import AgentSession, ContextProvider, SessionContext from mem0 import AsyncMemory, AsyncMemoryClient if sys.version_info >= (3, 11): @@ -33,8 +33,8 @@ class _MemorySearchResponse_v1_1(TypedDict): _MemorySearchResponse_v2 = list[dict[str, Any]] -class Mem0ContextProvider(BaseContextProvider): - """Mem0 context provider using the new BaseContextProvider hooks pattern. +class Mem0ContextProvider(ContextProvider): + """Mem0 context provider using the new ContextProvider hooks pattern. Integrates Mem0 for persistent semantic memory, searching and storing memories via the Mem0 API. 
diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py b/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py index 4352a8af47..baf91e0134 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py @@ -39,7 +39,7 @@ from typing import Any from agent_framework import Agent, SupportsAgentRun -from agent_framework._middleware import FunctionInvocationContext, FunctionMiddleware +from agent_framework._middleware import FunctionInvocationContext, FunctionMiddleware, MiddlewareTermination from agent_framework._sessions import AgentSession from agent_framework._tools import FunctionTool, tool from agent_framework._types import AgentResponse, Content, Message @@ -138,8 +138,6 @@ async def process( await call_next() return - from agent_framework._middleware import MiddlewareTermination - # Short-circuit execution and provide deterministic response payload for the tool call. # Parse the result using the default parser to ensure in a form that can be passed directly to LLM APIs. context.result = FunctionTool.parse_result({ @@ -375,6 +373,7 @@ def _clone_chat_agent(self, agent: Agent[Any]) -> Agent[Any]: description=agent.description, context_providers=agent.context_providers, middleware=agent.agent_middleware, + require_per_service_call_history_persistence=agent.require_per_service_call_history_persistence, default_options=cloned_options, # type: ignore[assignment] ) diff --git a/python/packages/orchestrations/tests/test_handoff.py b/python/packages/orchestrations/tests/test_handoff.py index 5c594ed537..b1524cce85 100644 --- a/python/packages/orchestrations/tests/test_handoff.py +++ b/python/packages/orchestrations/tests/test_handoff.py @@ -8,10 +8,11 @@ import pytest from agent_framework import ( Agent, - BaseContextProvider, ChatResponse, ChatResponseUpdate, Content, + ContextProvider, + InMemoryHistoryProvider, Message, ResponseStream, WorkflowEvent, @@ -695,6 +696,48 @@ def test_handoff_clone_disables_provider_side_storage() -> None: assert executor._agent.default_options.get("store") is False +async def test_handoff_clone_preserves_per_service_call_history_persistence() -> None: + """Handoff clones should keep per-service-call history persistence active for auto-handoff termination.""" + triage_history = InMemoryHistoryProvider() + triage = Agent( + id="triage", + name="triage", + client=MockChatClient(name="triage", handoff_to="specialist"), + context_providers=[triage_history], + require_per_service_call_history_persistence=True, + ) + specialist = Agent( + id="specialist", + name="specialist", + client=MockChatClient(name="specialist"), + default_options={"tool_choice": "none"}, + ) + + workflow = ( + HandoffBuilder(participants=[triage, specialist], termination_condition=lambda _: False) + .with_start_agent(triage) + .add_handoff(triage, [specialist]) + .add_handoff(specialist, [triage]) + .build() + ) + + await _drain(workflow.run("start", stream=True)) + + executor = workflow.executors[resolve_agent_id(triage)] + assert isinstance(executor, HandoffAgentExecutor) + assert executor._agent.require_per_service_call_history_persistence is True + + provider_state = executor._session.state[triage_history.source_id] + stored_messages = await triage_history.get_messages( + executor._session.session_id, + state=provider_state, + ) + + assert [message.role for message in stored_messages] == ["user", "assistant"] + assert any(content.type == 
"function_call" for content in stored_messages[-1].contents) + assert all(message.role != "tool" for message in stored_messages) + + async def test_handoff_clears_stale_service_session_id_before_run() -> None: """Stale service session IDs must be dropped before each handoff agent turn.""" triage = MockHandoffAgent(name="triage", handoff_to="specialist") @@ -997,7 +1040,7 @@ async def test_context_provider_preserved_during_handoff(): # Track whether context provider methods were called provider_calls: list[str] = [] - class TestContextProvider(BaseContextProvider): + class TestContextProvider(ContextProvider): """A test context provider that tracks its invocations.""" def __init__(self) -> None: diff --git a/python/packages/redis/agent_framework_redis/_context_provider.py b/python/packages/redis/agent_framework_redis/_context_provider.py index 98d5d9917f..3753ab7be1 100644 --- a/python/packages/redis/agent_framework_redis/_context_provider.py +++ b/python/packages/redis/agent_framework_redis/_context_provider.py @@ -1,9 +1,9 @@ # Copyright (c) Microsoft. All rights reserved. -"""New-pattern Redis context provider using BaseContextProvider. +"""New-pattern Redis context provider using ContextProvider. This module provides ``RedisContextProvider``, built on the new -:class:`BaseContextProvider` hooks pattern. +:class:`ContextProvider` hooks pattern. """ from __future__ import annotations @@ -16,7 +16,7 @@ import numpy as np from agent_framework import Message -from agent_framework._sessions import AgentSession, BaseContextProvider, SessionContext +from agent_framework._sessions import AgentSession, ContextProvider, SessionContext from agent_framework.exceptions import ( AgentException, IntegrationInvalidRequestException, @@ -41,8 +41,8 @@ from agent_framework._agents import SupportsAgentRun -class RedisContextProvider(BaseContextProvider): - """Redis context provider using the new BaseContextProvider hooks pattern. +class RedisContextProvider(ContextProvider): + """Redis context provider using the new ContextProvider hooks pattern. Stores context in Redis and retrieves scoped context via full-text or optional hybrid vector search. diff --git a/python/packages/redis/agent_framework_redis/_history_provider.py b/python/packages/redis/agent_framework_redis/_history_provider.py index be2db098b8..dbdc358a93 100644 --- a/python/packages/redis/agent_framework_redis/_history_provider.py +++ b/python/packages/redis/agent_framework_redis/_history_provider.py @@ -1,9 +1,9 @@ # Copyright (c) Microsoft. All rights reserved. -"""New-pattern Redis history provider using BaseHistoryProvider. +"""New-pattern Redis history provider using HistoryProvider. This module provides ``RedisHistoryProvider``, built on the new -:class:`BaseHistoryProvider` hooks pattern. +:class:`HistoryProvider` hooks pattern. """ from __future__ import annotations @@ -13,12 +13,12 @@ import redis.asyncio as redis from agent_framework import Message -from agent_framework._sessions import BaseHistoryProvider +from agent_framework._sessions import HistoryProvider from redis.credentials import CredentialProvider -class RedisHistoryProvider(BaseHistoryProvider): - """Redis-backed history provider using the new BaseHistoryProvider hooks pattern. +class RedisHistoryProvider(HistoryProvider): + """Redis-backed history provider using the new HistoryProvider hooks pattern. Stores conversation history in Redis Lists, with each session isolated by a unique Redis key. 
diff --git a/python/packages/redis/tests/test_providers.py b/python/packages/redis/tests/test_providers.py index dd0ff51cd8..54587a55e1 100644 --- a/python/packages/redis/tests/test_providers.py +++ b/python/packages/redis/tests/test_providers.py @@ -475,7 +475,7 @@ async def test_clear_calls_delete(self, mock_redis_client: MagicMock): class TestRedisHistoryProviderBeforeAfterRun: - """Test before_run/after_run integration via BaseHistoryProvider defaults.""" + """Test before_run/after_run integration via HistoryProvider defaults.""" async def test_before_run_loads_history(self, mock_redis_client: MagicMock): msg = Message(role="user", contents=["old msg"]) diff --git a/python/samples/01-get-started/04_memory.py b/python/samples/01-get-started/04_memory.py index 763a872ca7..7e0b1e2d5f 100644 --- a/python/samples/01-get-started/04_memory.py +++ b/python/samples/01-get-started/04_memory.py @@ -3,7 +3,7 @@ import asyncio from typing import Any -from agent_framework import Agent, AgentSession, BaseContextProvider, SessionContext +from agent_framework import Agent, AgentSession, ContextProvider, SessionContext from agent_framework.foundry import FoundryChatClient from azure.identity import AzureCliCredential @@ -17,7 +17,7 @@ # -class UserMemoryProvider(BaseContextProvider): +class UserMemoryProvider(ContextProvider): """A context provider that remembers user info in session state.""" DEFAULT_SOURCE_ID = "user_memory" diff --git a/python/samples/02-agents/chat_client/README.md b/python/samples/02-agents/chat_client/README.md index e037877291..80b1e0ea1a 100644 --- a/python/samples/02-agents/chat_client/README.md +++ b/python/samples/02-agents/chat_client/README.md @@ -9,6 +9,7 @@ This folder contains examples for direct chat client usage patterns. | [`built_in_chat_clients.py`](built_in_chat_clients.py) | Consolidated sample for built-in chat clients. Uses `get_client()` to create the selected client and pass it to `main()`. | | [`chat_response_cancellation.py`](chat_response_cancellation.py) | Demonstrates how to cancel chat responses during streaming, showing proper cancellation handling and cleanup. | | [`custom_chat_client.py`](custom_chat_client.py) | Demonstrates how to create custom chat clients by extending the `BaseChatClient` class. Shows a `EchoingChatClient` implementation and how to integrate it with `Agent` using the `as_agent()` method. | +| [`require_per_service_call_history_persistence.py`](require_per_service_call_history_persistence.py) | Compares two otherwise identical `FoundryChatClient` agents with `store=False`; the only difference is whether `require_per_service_call_history_persistence` is enabled, and only the run without it stores the synthesized tool result when middleware terminates the loop early. 
| ## Selecting a built-in client @@ -35,6 +36,15 @@ Example: uv run samples/02-agents/chat_client/built_in_chat_clients.py ``` +The `require_per_service_call_history_persistence.py` sample uses `FoundryChatClient`, so set the usual Foundry settings first and sign in with the Azure CLI: + +```bash +export FOUNDRY_PROJECT_ENDPOINT="https://.services.ai.azure.com/api/projects/" +export FOUNDRY_MODEL="" +az login +uv run samples/02-agents/chat_client/require_per_service_call_history_persistence.py +``` + ## Environment Variables Depending on the selected client, set the appropriate environment variables: diff --git a/python/samples/02-agents/chat_client/require_per_service_call_history_persistence.py b/python/samples/02-agents/chat_client/require_per_service_call_history_persistence.py new file mode 100644 index 0000000000..f3a9a9ddde --- /dev/null +++ b/python/samples/02-agents/chat_client/require_per_service_call_history_persistence.py @@ -0,0 +1,194 @@ +# Copyright (c) Microsoft. All rights reserved. + +from __future__ import annotations + +import asyncio +from collections.abc import Awaitable, Callable +from typing import Annotated + +from agent_framework import ( + Agent, + FunctionInvocationContext, + FunctionMiddleware, + InMemoryHistoryProvider, + Message, + MiddlewareTermination, +) +from agent_framework.foundry import FoundryChatClient +from azure.identity import AzureCliCredential +from dotenv import load_dotenv +from pydantic import Field + +""" +Compare Foundry agents with and without per-service-call chat history persistence. + +This sample runs two otherwise identical Foundry agents with ``store=False`` so +history stays local for both runs. + +The sample adds a function middleware that raises ``MiddlewareTermination`` +immediately after the tool runs, so the request stops before a second model +call. + +That early termination is the important difference: + +- Without per-service-call chat history persistence, the synthesized tool result is + still written to local history. +- With ``require_per_service_call_history_persistence=True``, that synthesized tool result is + not written to local history. + +The per-service-call persistence case matches service-side storage behavior. When a terminated +request never sends the tool result back to the service, that result also never +becomes part of the service-managed history. +""" + +# Load environment variables from .env file +load_dotenv() + + +def lookup_weather( + location: Annotated[str, Field(description="The location to get the weather for.")], +) -> str: + """Return a deterministic weather result for the requested location.""" + return f"The weather in {location} is sunny." 
+
+
+class TerminateAfterToolMiddleware(FunctionMiddleware):
+    """Stop the tool loop after the first tool finishes."""
+
+    async def process(
+        self,
+        context: FunctionInvocationContext,
+        call_next: Callable[[], Awaitable[None]],
+    ) -> None:
+        """Run the tool, then terminate the loop with that tool result."""
+        await call_next()
+        raise MiddlewareTermination(result=context.result)
+
+
+def _describe_message(message: Message) -> str:
+    """Render one stored message in a compact, readable format."""
+    parts: list[str] = []
+    for content in message.contents:
+        if content.type == "text" and content.text:
+            parts.append(content.text)
+        elif content.type == "function_call":
+            parts.append(f"function_call -> {content.name}({content.arguments})")
+        elif content.type == "function_result":
+            parts.append(f"function_result -> {content.result}")
+        else:
+            parts.append(content.type)
+
+    return f"{message.role}: {' | '.join(parts)}"
+
+
+def _includes_tool_result(messages: list[Message]) -> bool:
+    """Return whether any stored message contains a tool result."""
+    return any(content.type == "function_result" for message in messages for content in message.contents)
+
+
+async def main() -> None:
+    """Run both comparison scenarios."""
+    print("=== require_per_service_call_history_persistence when middleware terminates the tool loop ===\n")
+
+    # 1. Create one Foundry chat client that both agents will share.
+    client = FoundryChatClient(credential=AzureCliCredential())
+    query = "What is the weather in Seattle, and should I bring sunglasses?"
+
+    # 2. Create and run the agent without per-service-call persistence.
+    agent_without_persistence = Agent(
+        client=client,
+        instructions=(
+            "You are a weather assistant. Call lookup_weather exactly once before answering "
+            "any weather question, then summarize the tool result in one short paragraph."
+        ),
+        tools=[lookup_weather],
+        context_providers=[InMemoryHistoryProvider()],
+        middleware=[TerminateAfterToolMiddleware()],
+        default_options={"tool_choice": "required", "store": False},
+    )
+    session_without_persistence = agent_without_persistence.create_session()
+    await agent_without_persistence.run(
+        query,
+        session=session_without_persistence,
+    )
+    stored_messages_without_persistence = session_without_persistence.state[InMemoryHistoryProvider.DEFAULT_SOURCE_ID][
+        "messages"
+    ]
+
+    print("=== Without per-service-call persistence ===")
+    print("Loop terminated immediately after the tool finished.")
+    print(f"Stored synthesized tool result: {_includes_tool_result(stored_messages_without_persistence)}")
+    print("Stored history:")
+    for index, message in enumerate(stored_messages_without_persistence, start=1):
+        print(f" {index}. {_describe_message(message)}")
+    print()
+
+    # 3. Create and run the agent with per-service-call persistence enabled.
+    agent_with_persistence = Agent(
+        client=client,
+        instructions=(
+            "You are a weather assistant. Call lookup_weather exactly once before answering "
+            "any weather question, then summarize the tool result in one short paragraph."
+        ),
+        tools=[lookup_weather],
+        context_providers=[InMemoryHistoryProvider()],
+        middleware=[TerminateAfterToolMiddleware()],
+        require_per_service_call_history_persistence=True,
+        default_options={"tool_choice": "required", "store": False},
+    )
+    session_with_persistence = agent_with_persistence.create_session()
+    await agent_with_persistence.run(
+        query,
+        session=session_with_persistence,
+    )
+    stored_messages_with_persistence = session_with_persistence.state[InMemoryHistoryProvider.DEFAULT_SOURCE_ID][
+        "messages"
+    ]
+
+    print("=== With per-service-call persistence ===")
+    print("Loop terminated immediately after the tool finished.")
+    print(f"Stored synthesized tool result: {_includes_tool_result(stored_messages_with_persistence)}")
+    print("Stored history:")
+    for index, message in enumerate(stored_messages_with_persistence, start=1):
+        print(f" {index}. {_describe_message(message)}")
+    print()
+
+    # 4. Summarize the effect of the flag.
+    print(
+        "Both runs used FoundryChatClient with store=False and terminated right after the tool. "
+        "Without per-service-call persistence, local history still stored the synthesized tool result. "
+        "With per-service-call persistence, local history stopped at the assistant function-call message instead, "
+        "which matches service-side storage because the terminated tool result is never sent back to the service."
+    )
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
+
+
+"""
+Sample output:
+=== require_per_service_call_history_persistence when middleware terminates the tool loop ===
+
+=== Without per-service-call persistence ===
+Loop terminated immediately after the tool finished.
+Stored synthesized tool result: True
+Stored history:
+ 1. user: What is the weather in Seattle, and should I bring sunglasses?
+ 2. assistant: function_call -> lookup_weather({"location":"Seattle"})
+ 3. tool: function_result -> The weather in Seattle is sunny.
+
+=== With per-service-call persistence ===
+Loop terminated immediately after the tool finished.
+Stored synthesized tool result: False
+Stored history:
+ 1. user: What is the weather in Seattle, and should I bring sunglasses?
+ 2. assistant: function_call -> lookup_weather({"location":"Seattle"})
+
+Both runs used FoundryChatClient with store=False and terminated right after
+the tool. Without per-service-call persistence, local history still stored the
+synthesized tool result. With per-service-call persistence, local history
+stopped at the assistant function-call message instead, which matches
+service-side storage because the terminated tool result is never sent back to
+the service.
+""" diff --git a/python/samples/02-agents/context_providers/simple_context_provider.py b/python/samples/02-agents/context_providers/simple_context_provider.py index 265a93a73e..5f2a0f409a 100644 --- a/python/samples/02-agents/context_providers/simple_context_provider.py +++ b/python/samples/02-agents/context_providers/simple_context_provider.py @@ -5,7 +5,7 @@ from contextlib import suppress from typing import Any -from agent_framework import Agent, AgentSession, BaseContextProvider, SessionContext, SupportsChatGetResponse +from agent_framework import Agent, AgentSession, ContextProvider, SessionContext, SupportsChatGetResponse from agent_framework.foundry import FoundryChatClient from azure.identity import AzureCliCredential from dotenv import load_dotenv @@ -20,7 +20,7 @@ class UserInfo(BaseModel): age: int | None = None -class UserInfoMemory(BaseContextProvider): +class UserInfoMemory(ContextProvider): DEFAULT_SOURCE_ID = "user_info_memory" def __init__(self, source_id: str = DEFAULT_SOURCE_ID, *, client: SupportsChatGetResponse, **kwargs: Any): diff --git a/python/samples/02-agents/conversations/custom_history_provider.py b/python/samples/02-agents/conversations/custom_history_provider.py index 59d63b70c0..674e056da2 100644 --- a/python/samples/02-agents/conversations/custom_history_provider.py +++ b/python/samples/02-agents/conversations/custom_history_provider.py @@ -4,7 +4,7 @@ from collections.abc import Sequence from typing import Any -from agent_framework import Agent, AgentSession, BaseHistoryProvider, Message +from agent_framework import Agent, AgentSession, HistoryProvider, Message from agent_framework.openai import OpenAIChatClient from dotenv import load_dotenv @@ -20,7 +20,7 @@ """ -class CustomHistoryProvider(BaseHistoryProvider): +class CustomHistoryProvider(HistoryProvider): """Implementation of custom history provider. 
In real applications, this can be an implementation of relational database or vector store.""" diff --git a/python/samples/05-end-to-end/hosted_agents/README.md b/python/samples/05-end-to-end/hosted_agents/README.md index c343fa66b2..5ee8a7b1b2 100644 --- a/python/samples/05-end-to-end/hosted_agents/README.md +++ b/python/samples/05-end-to-end/hosted_agents/README.md @@ -7,7 +7,7 @@ These samples demonstrate how to build and host AI agents in Python using the [A | Sample | Description | | ----------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------- | | [`agent_with_hosted_mcp`](./agent_with_hosted_mcp/) | Hosted MCP tool that connects to Microsoft Learn via `https://learn.microsoft.com/api/mcp` | -| [`agent_with_text_search_rag`](./agent_with_text_search_rag/) | Retrieval-augmented generation using a custom `BaseContextProvider` with Contoso Outdoors sample data | +| [`agent_with_text_search_rag`](./agent_with_text_search_rag/) | Retrieval-augmented generation using a custom `ContextProvider` with Contoso Outdoors sample data | | [`agents_in_workflow`](./agents_in_workflow/) | Concurrent workflow that combines researcher, marketer, and legal specialist agents | | [`agent_with_local_tools`](./agent_with_local_tools/) | Local Python tool execution for Seattle hotel search | | [`writer_reviewer_agents_in_workflow`](./writer_reviewer_agents_in_workflow/) | Writer/Reviewer workflow using `FoundryChatClient` | diff --git a/python/samples/05-end-to-end/hosted_agents/agent_with_text_search_rag/main.py b/python/samples/05-end-to-end/hosted_agents/agent_with_text_search_rag/main.py index ef91d227f4..d7a8bfbf73 100644 --- a/python/samples/05-end-to-end/hosted_agents/agent_with_text_search_rag/main.py +++ b/python/samples/05-end-to-end/hosted_agents/agent_with_text_search_rag/main.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from typing import Any -from agent_framework import Agent, AgentSession, BaseContextProvider, Message, SessionContext +from agent_framework import Agent, AgentSession, ContextProvider, Message, SessionContext from agent_framework.foundry import FoundryChatClient from azure.ai.agentserver.agentframework import from_agent_framework # pyright: ignore[reportUnknownVariableType] from azure.identity import DefaultAzureCredential @@ -28,7 +28,7 @@ class TextSearchResult: text: str -class TextSearchContextProvider(BaseContextProvider): +class TextSearchContextProvider(ContextProvider): """A simple context provider that simulates text search results based on keywords in the user's message.""" def __init__(self): diff --git a/python/samples/05-end-to-end/m365-agent/README.md b/python/samples/05-end-to-end/m365-agent/README.md index 6962a53229..ded1c63f09 100644 --- a/python/samples/05-end-to-end/m365-agent/README.md +++ b/python/samples/05-end-to-end/m365-agent/README.md @@ -7,7 +7,7 @@ This sample demonstrates a simple Weather Forecast Agent built with the Python M - Python 3.11+ - [uv](https://github.com/astral-sh/uv) for fast dependency management - [devtunnel](https://learn.microsoft.com/azure/developer/dev-tunnels/get-started?tabs=windows) -- [Microsoft 365 Agents Toolkit](https://github.com/OfficeDev/microsoft-365-agents-toolkit) for playground/testing +- `agentsplayground` for playground/testing - Access to OpenAI or Azure OpenAI with a model like `gpt-4o-mini` ## Configuration