Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 50 additions & 22 deletions python/packages/core/agent_framework/_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,52 @@ def get_session(self, service_session_id: str, *, session_id: str | None = None)
"""
return AgentSession(session_id=session_id, service_session_id=service_session_id)

async def _run_before_providers(
self,
*,
session: AgentSession | None,
input_messages: list[Message] | None = None,
options: dict[str, Any] | None = None,
) -> SessionContext:
"""Run before_run on all context providers and return the session context.

Creates a SessionContext and invokes ``before_run`` on each provider in
forward order. ``BaseHistoryProvider`` instances with
``load_messages=False`` are skipped.

Keyword Args:
session: The conversation session (None for stateless invocation).
input_messages: The normalized input messages.
options: Runtime options dict.

Returns:
The SessionContext with provider context populated.
"""
provider_session = session
if provider_session is None and self.context_providers:
provider_session = AgentSession()

session_context = SessionContext(
session_id=provider_session.session_id if provider_session else None,
service_session_id=provider_session.service_session_id if provider_session else None,
input_messages=input_messages or [],
options=options or {},
)

for provider in self.context_providers:
if isinstance(provider, BaseHistoryProvider) and not provider.load_messages:
continue
if provider_session is None:
raise RuntimeError("Provider session must be available when context providers are configured.")
await provider.before_run(
agent=self, # type: ignore[arg-type]
session=provider_session,
context=session_context,
state=provider_session.state.setdefault(provider.source_id, {}),
)

return session_context

async def _run_after_providers(
self,
*,
Expand Down Expand Up @@ -1273,30 +1319,12 @@ async def _prepare_session_and_messages(
else:
chat_options = {}

provider_session = session
if provider_session is None and self.context_providers:
provider_session = AgentSession()

session_context = SessionContext(
session_id=provider_session.session_id if provider_session else None,
service_session_id=provider_session.service_session_id if provider_session else None,
input_messages=input_messages or [],
options=options or {},
session_context = await self._run_before_providers(
session=session,
input_messages=input_messages,
options=options,
)

# Run before_run providers (forward order, skip BaseHistoryProvider with load_messages=False)
for provider in self.context_providers:
if isinstance(provider, BaseHistoryProvider) and not provider.load_messages:
continue
if provider_session is None:
raise RuntimeError("Provider session must be available when context providers are configured.")
await provider.before_run(
agent=self, # type: ignore[arg-type]
session=provider_session,
context=session_context,
state=provider_session.state.setdefault(provider.source_id, {}),
)

# Merge provider-contributed tools into chat_options
if session_context.tools:
if chat_options.get("tools") is not None:
Expand Down
9 changes: 3 additions & 6 deletions python/packages/core/agent_framework/_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def split_before_memory(conversation):
# Fallback: split at last user message
return EvalItem._split_last_turn_static(conversation)


item.split_messages(split=split_before_memory)
"""

Expand Down Expand Up @@ -468,10 +469,7 @@ def raise_for_status(self, msg: str | None = None) -> None:
"""
if not self.all_passed:
errored = (self.result_counts or {}).get("errored", 0)
detail = msg or (
f"Eval run {self.run_id} {self.status}: "
f"{self.passed} passed, {self.failed} failed."
)
detail = msg or (f"Eval run {self.run_id} {self.status}: {self.passed} passed, {self.failed} failed.")
if errored:
detail += f" {errored} errored."
if self.report_url:
Expand Down Expand Up @@ -1188,8 +1186,7 @@ def _coerce_result(value: Any, check_name: str) -> CheckResult:
score = float(d["score"])
except (TypeError, ValueError) as exc:
raise TypeError(
f"Function evaluator '{check_name}' returned dict with non-numeric 'score' value:"
f" {d['score']!r}"
f"Function evaluator '{check_name}' returned dict with non-numeric 'score' value: {d['score']!r}"
) from exc
# Honour an explicit 'passed' override; otherwise threshold-based.
passed = bool(d["passed"]) if "passed" in d else score >= float(d.get("threshold", 0.5))
Expand Down
76 changes: 76 additions & 0 deletions python/packages/core/tests/core/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -2085,3 +2085,79 @@ async def test_as_tool_raises_on_user_input_request(client: SupportsChatGetRespo
assert len(exc_info.value.contents) == 1
assert exc_info.value.contents[0].type == "oauth_consent_request"
assert exc_info.value.contents[0].consent_link == "https://login.microsoftonline.com/consent"


async def test_base_agent_run_before_providers_creates_session_context(
client: SupportsChatGetResponse,
) -> None:
"""Test that BaseAgent._run_before_providers creates a SessionContext and calls providers."""
mock_provider = MockContextProvider(messages=[Message(role="system", text="Injected context")])
agent = Agent(client=client, context_providers=[mock_provider])
session = agent.create_session()

session_context = await agent._run_before_providers( # type: ignore[reportPrivateUsage]
session=session,
input_messages=[Message(role="user", text="Hello")],
options={"temperature": 0.5},
)

assert mock_provider.before_run_called
assert session_context.session_id == session.session_id
messages = session_context.get_messages(include_input=True)
assert len(messages) == 2
assert messages[0].text == "Injected context"
assert messages[1].text == "Hello"
assert session_context.options.get("temperature") == 0.5


async def test_base_agent_run_before_providers_creates_session_when_none(
client: SupportsChatGetResponse,
) -> None:
"""Test that _run_before_providers creates a session when None is passed with providers."""
mock_provider = MockContextProvider()
agent = Agent(client=client, context_providers=[mock_provider])

session_context = await agent._run_before_providers( # type: ignore[reportPrivateUsage]
session=None,
input_messages=[Message(role="user", text="Hello")],
)

assert mock_provider.before_run_called
assert session_context.session_id is not None


async def test_base_agent_run_before_providers_skips_history_provider_load_false(
client: SupportsChatGetResponse,
) -> None:
"""Test that _run_before_providers skips BaseHistoryProvider with load_messages=False."""
from agent_framework import BaseHistoryProvider

class StubHistoryProvider(BaseHistoryProvider):
def __init__(self, *, load_messages: bool = True) -> None:
super().__init__(source_id=f"stub-{load_messages}", load_messages=load_messages)
self.before_run_called = False

async def before_run(self, *, agent: Any, session: Any, context: Any, state: Any) -> None:
self.before_run_called = True

async def after_run(self, *, agent: Any, session: Any, context: Any, state: Any) -> None:
pass

async def get_messages(self, session_id: Any, **kwargs: Any) -> list[Message]:
return []

async def save_messages(self, session_id: Any, messages: Any, **kwargs: Any) -> None:
pass

skipped = StubHistoryProvider(load_messages=False)
active = StubHistoryProvider(load_messages=True)
agent = Agent(client=client, context_providers=[skipped, active])
session = agent.create_session()

await agent._run_before_providers( # type: ignore[reportPrivateUsage]
session=session,
input_messages=[Message(role="user", text="Hello")],
)

assert not skipped.before_run_called
assert active.before_run_called
Original file line number Diff line number Diff line change
Expand Up @@ -352,13 +352,22 @@ def run(
AgentException: If the request fails.
"""
if stream:
ctx_holder: dict[str, Any] = {}

async def _after_run_hook(response: AgentResponse) -> None:
session_context = ctx_holder.get("session_context")
sess = ctx_holder.get("session")
if session_context is not None and sess is not None:
session_context._response = response
await self._run_after_providers(session=sess, context=session_context)
Comment thread
giles17 marked this conversation as resolved.
Outdated

def _finalize(updates: Sequence[AgentResponseUpdate]) -> AgentResponse:
return AgentResponse.from_updates(updates)

return ResponseStream(
self._stream_updates(messages=messages, session=session, options=options),
self._stream_updates(messages=messages, session=session, options=options, _ctx_holder=ctx_holder),
finalizer=_finalize,
result_hooks=[_after_run_hook],
)
return self._run_impl(messages=messages, session=session, options=options)

Expand All @@ -377,11 +386,19 @@ async def _run_impl(
session = self.create_session()

opts: dict[str, Any] = dict(options) if options else {}
timeout = opts.pop("timeout", None) or self._settings.get("timeout") or DEFAULT_TIMEOUT_SECONDS
timeout = opts.get("timeout") or self._settings.get("timeout") or DEFAULT_TIMEOUT_SECONDS

Comment thread
giles17 marked this conversation as resolved.
copilot_session = await self._get_or_create_session(session, streaming=False, runtime_options=opts)
input_messages = normalize_messages(messages)
prompt = "\n".join([message.text for message in input_messages])

session_context = await self._run_before_providers(session=session, input_messages=input_messages, options=opts)

Comment thread
giles17 marked this conversation as resolved.
# Build the prompt from the full set of messages in the session context,
# so that any context/history provider-injected messages are included.
context_messages = session_context.get_messages(include_input=True)
prompt = "\n".join([message.text for message in context_messages])
if session_context.instructions:
prompt = "\n".join(session_context.instructions) + "\n" + prompt
message_options = cast(MessageOptions, {"prompt": prompt})

try:
Expand All @@ -408,14 +425,18 @@ async def _run_impl(
)
response_id = message_id

return AgentResponse(messages=response_messages, response_id=response_id)
response = AgentResponse(messages=response_messages, response_id=response_id)
session_context._response = response # type: ignore[assignment]
await self._run_after_providers(session=session, context=session_context)
return response

async def _stream_updates(
self,
messages: AgentRunInputs | None = None,
*,
session: AgentSession | None = None,
options: OptionsT | None = None,
_ctx_holder: dict[str, Any] | None = None,
) -> AsyncIterable[AgentResponseUpdate]:
"""Internal method to stream updates from GitHub Copilot.

Expand All @@ -425,6 +446,9 @@ async def _stream_updates(
Keyword Args:
session: The conversation session associated with the message(s).
options: Runtime options (model, timeout, etc.).
_ctx_holder: Internal dict populated with session_context and session
so that the caller (via a ResponseStream result_hook) can run
after_run providers without duplicating the updates buffer.

Yields:
AgentResponseUpdate items.
Expand All @@ -442,7 +466,18 @@ async def _stream_updates(

copilot_session = await self._get_or_create_session(session, streaming=True, runtime_options=opts)
input_messages = normalize_messages(messages)
prompt = "\n".join([message.text for message in input_messages])

session_context = await self._run_before_providers(session=session, input_messages=input_messages, options=opts)

if _ctx_holder is not None:
_ctx_holder["session_context"] = session_context
_ctx_holder["session"] = session

# Build the prompt from the full session context so provider-injected messages are included.
context_messages = session_context.get_messages(include_input=True)
prompt = "\n".join([message.text for message in context_messages])
if session_context.instructions:
prompt = "\n".join(session_context.instructions) + "\n" + prompt
message_options = cast(MessageOptions, {"prompt": prompt})

queue: asyncio.Queue[AgentResponseUpdate | Exception | None] = asyncio.Queue()
Expand Down
Loading
Loading