13 changes: 13 additions & 0 deletions src/strands/agent/agent.py
@@ -543,6 +543,19 @@ async def _run_loop(
# Execute the event loop cycle with retry logic for context limits
events = self._execute_event_loop_cycle(invocation_state)
async for event in events:
# Signal from the model provider that the message sent by the user should be redacted,
# likely due to a guardrail.
if (
event.get("callback")
and event["callback"].get("event")
and event["callback"]["event"].get("redactContent")
and event["callback"]["event"]["redactContent"].get("redactUserContentMessage")
):
self.messages[-1]["content"] = [
{"text": event["callback"]["event"]["redactContent"]["redactUserContentMessage"]}
]
if self._session_manager:
self._session_manager.redact_latest_message(self.messages[-1], self)
yield event

finally:
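For reference, the guard added in _run_loop matches a nested callback event shaped roughly as follows (a minimal sketch; the real envelope emitted by the event loop may carry additional keys):

    # Hypothetical event that would trigger the redaction branch above.
    event = {
        "callback": {
            "event": {
                "redactContent": {"redactUserContentMessage": "BLOCKED!"}
            }
        }
    }
    # When matched, the agent rewrites its latest (user) message in place to
    # [{"text": "BLOCKED!"}] and, if a session manager is attached, calls
    # redact_latest_message so the redaction is also persisted.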
16 changes: 4 additions & 12 deletions src/strands/event_loop/streaming.py
@@ -221,17 +221,13 @@ def handle_message_stop(event: MessageStopEvent) -> StopReason:
return event["stopReason"]


def handle_redact_content(event: RedactContentEvent, messages: Messages, state: dict[str, Any]) -> None:
def handle_redact_content(event: RedactContentEvent, state: dict[str, Any]) -> None:
"""Handles redacting content from the input or output.

Args:
event: Redact Content Event.
messages: Agent messages.
state: The current state of message processing.
"""
if event.get("redactUserContentMessage") is not None:
messages[-1]["content"] = [{"text": event["redactUserContentMessage"]}] # type: ignore

if event.get("redactAssistantContentMessage") is not None:
state["message"]["content"] = [{"text": event["redactAssistantContentMessage"]}]

@@ -251,15 +247,11 @@ def extract_usage_metrics(event: MetadataEvent) -> tuple[Usage, Metrics]:
return usage, metrics


async def process_stream(
chunks: AsyncIterable[StreamEvent],
messages: Messages,
) -> AsyncGenerator[dict[str, Any], None]:
async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[dict[str, Any], None]:
"""Processes the response stream from the API, constructing the final message and extracting usage metrics.

Args:
chunks: The chunks of the response stream from the model.
messages: The agents messages.

Returns:
The reason for stopping, the constructed message, and the usage metrics.
@@ -295,7 +287,7 @@ async def process_stream(
elif "metadata" in chunk:
usage, metrics = extract_usage_metrics(chunk["metadata"])
elif "redactContent" in chunk:
handle_redact_content(chunk["redactContent"], messages, state)
handle_redact_content(chunk["redactContent"], state)

yield {"stop": (stop_reason, state["message"], usage, metrics)}

@@ -323,5 +315,5 @@ async def stream_messages(

chunks = model.stream(messages, tool_specs if tool_specs else None, system_prompt)

async for event in process_stream(chunks, messages):
async for event in process_stream(chunks):
yield event
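A minimal sketch of the updated call pattern, mirroring stream_messages above (model, messages, tool_specs, and system_prompt are assumed to be in scope):

    # process_stream no longer receives the conversation history: user-message
    # redaction now happens in Agent._run_loop, and handle_redact_content only
    # rewrites the assistant message being assembled in `state`.
    chunks = model.stream(messages, tool_specs if tool_specs else None, system_prompt)
    async for event in process_stream(chunks):
        if "stop" in event:
            stop_reason, message, usage, metrics = event["stop"]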
2 changes: 1 addition & 1 deletion src/strands/models/anthropic.py
@@ -407,7 +407,7 @@ async def structured_output(
tool_spec = convert_pydantic_to_tool_spec(output_model)

response = self.stream(messages=prompt, tool_specs=[tool_spec], **kwargs)
async for event in process_stream(response, prompt):
async for event in process_stream(response):
yield event

stop_reason, messages, _, _ = event["stop"]
2 changes: 1 addition & 1 deletion src/strands/models/bedrock.py
@@ -577,7 +577,7 @@ async def structured_output(
tool_spec = convert_pydantic_to_tool_spec(output_model)

response = self.stream(messages=prompt, tool_specs=[tool_spec], **kwargs)
async for event in streaming.process_stream(response, prompt):
async for event in streaming.process_stream(response):
yield event

stop_reason, messages, _, _ = event["stop"]
20 changes: 18 additions & 2 deletions src/strands/session/repository_session_manager.py
@@ -1,6 +1,7 @@
"""Repository session manager implementation."""

import logging
from typing import Optional

from ..agent.agent import Agent
from ..agent.state import AgentState
@@ -50,15 +51,30 @@ def __init__(
# Keep track of the initialized agent id's so that two agents in a session cannot share an id
self._initialized_agent_ids: set[str] = set()

# Keep track of the latest message stored in the session in case we need to redact its content.
self._latest_message: Optional[SessionMessage] = None

def append_message(self, message: Message, agent: Agent) -> None:
"""Append a message to the agent's session.

Args:
message: Message to add to the agent in the session
agent: Agent to append the message to
"""
session_message = SessionMessage.from_message(message)
self.session_repository.create_message(self.session_id, agent.agent_id, session_message)
self._latest_message = SessionMessage.from_message(message)
self.session_repository.create_message(self.session_id, agent.agent_id, self._latest_message)

def redact_latest_message(self, redact_message: Message, agent: Agent) -> None:
"""Redact the latest message appended to the session.

Args:
redact_message: New message to use that contains the redact content
agent: Agent to apply the message redaction to
"""
if self._latest_message is None:
raise SessionException("No message to redact.")
self._latest_message.redact_message = redact_message
return self.session_repository.update_message(self.session_id, agent.agent_id, self._latest_message)

def sync_agent(self, agent: Agent) -> None:
"""Serialize and update the agent into the session repository.
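A hedged usage sketch of the new redaction flow (session_manager is a RepositorySessionManager; the agent and message payloads are illustrative):

    # append_message stores the SessionMessage and remembers it as the latest one;
    # redact_latest_message then updates that same record in the repository.
    session_manager.append_message({"role": "user", "content": [{"text": "CACTUS"}]}, agent)
    session_manager.redact_latest_message({"role": "user", "content": [{"text": "BLOCKED!"}]}, agent)
    # Calling redact_latest_message before any append_message raises SessionException.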
9 changes: 9 additions & 0 deletions src/strands/session/session_manager.py
@@ -26,6 +26,15 @@ def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None:
registry.add_callback(MessageAddedEvent, lambda event: self.append_message(event.message, event.agent))
registry.add_callback(MessageAddedEvent, lambda event: self.sync_agent(event.agent))

@abstractmethod
def redact_latest_message(self, redact_message: Message, agent: "Agent") -> None:
"""Redact the message most recently appended to the agent in the session.

Args:
redact_message: New message to use that contains the redact content
agent: Agent to apply the message redaction to
"""

@abstractmethod
def append_message(self, message: Message, agent: "Agent") -> None:
"""Append a message to the agent's session.
23 changes: 19 additions & 4 deletions src/strands/types/session.py
@@ -5,7 +5,7 @@
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, cast
from typing import Any, Dict, Optional, cast
from uuid import uuid4

from ..agent.agent import Agent
@@ -54,9 +54,18 @@ def decode_bytes_values(obj: Any) -> Any:

@dataclass
class SessionMessage:
"""Message within a SessionAgent."""
"""Message within a SessionAgent.

Attributes:
message: Message content
redact_message: If the original message is redacted, this is the new content to use
message_id: Unique id for a message
created_at: ISO format timestamp for when this message was created
updated_at: ISO format timestamp for when this message was last updated
"""

message: Message
redact_message: Optional[Message] = None
message_id: str = field(default_factory=lambda: str(uuid4()))
created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
updated_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
@@ -73,8 +82,14 @@ def from_message(cls, message: Message) -> "SessionMessage":
)

def to_message(self) -> Message:
"""Convert SessionMessage back to a Message, decoding any bytes values."""
return cast(Message, decode_bytes_values(self.message))
"""Convert SessionMessage back to a Message, decoding any bytes values.

If the message was redacted, return the redact content instead.
"""
if self.redact_message is not None:
return cast(Message, decode_bytes_values(self.redact_message))
else:
return cast(Message, decode_bytes_values(self.message))

@classmethod
def from_dict(cls, env: dict[str, Any]) -> "SessionMessage":
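A minimal sketch of the resulting round-trip behaviour (content values are illustrative):

    # Once redact_message is set, to_message() returns the redacted content
    # instead of the original message.
    session_message = SessionMessage.from_message({"role": "user", "content": [{"text": "CACTUS"}]})
    session_message.to_message()  # -> {"role": "user", "content": [{"text": "CACTUS"}]}

    session_message.redact_message = {"role": "user", "content": [{"text": "BLOCKED!"}]}
    session_message.to_message()  # -> {"role": "user", "content": [{"text": "BLOCKED!"}]}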
52 changes: 33 additions & 19 deletions tests/fixtures/mocked_model_provider.py
@@ -1,5 +1,5 @@
import json
from typing import Any, AsyncGenerator, Iterable, Optional, Type, TypeVar
from typing import Any, AsyncGenerator, Iterable, Optional, Type, TypedDict, TypeVar, Union

from pydantic import BaseModel

@@ -12,6 +12,11 @@
T = TypeVar("T", bound=BaseModel)


class RedactionMessage(TypedDict):
redactedUserContent: str
redactedAssistantContent: str


class MockedModelProvider(Model):
"""A mock implementation of the Model interface for testing purposes.

@@ -20,7 +25,7 @@ class MockedModelProvider(Model):
to stream mock responses as events.
"""

def __init__(self, agent_responses: Messages):
def __init__(self, agent_responses: list[Union[Message, RedactionMessage]]):
self.agent_responses = agent_responses
self.index = 0

@@ -54,27 +59,36 @@ async def stream(

self.index += 1

def map_agent_message_to_events(self, agent_message: Message) -> Iterable[dict[str, Any]]:
def map_agent_message_to_events(self, agent_message: Union[Message, RedactionMessage]) -> Iterable[dict[str, Any]]:
stop_reason: StopReason = "end_turn"
yield {"messageStart": {"role": "assistant"}}
for content in agent_message["content"]:
if "text" in content:
yield {"contentBlockStart": {"start": {}}}
yield {"contentBlockDelta": {"delta": {"text": content["text"]}}}
yield {"contentBlockStop": {}}
if "toolUse" in content:
stop_reason = "tool_use"
yield {
"contentBlockStart": {
"start": {
"toolUse": {
"name": content["toolUse"]["name"],
"toolUseId": content["toolUse"]["toolUseId"],
if agent_message.get("redactedAssistantContent"):
yield {"redactContent": {"redactUserContentMessage": agent_message["redactedUserContent"]}}
yield {"contentBlockStart": {"start": {}}}
yield {"contentBlockDelta": {"delta": {"text": agent_message["redactedAssistantContent"]}}}
yield {"contentBlockStop": {}}
stop_reason = "guardrail_intervened"
else:
for content in agent_message["content"]:
if "text" in content:
yield {"contentBlockStart": {"start": {}}}
yield {"contentBlockDelta": {"delta": {"text": content["text"]}}}
yield {"contentBlockStop": {}}
if "toolUse" in content:
stop_reason = "tool_use"
yield {
"contentBlockStart": {
"start": {
"toolUse": {
"name": content["toolUse"]["name"],
"toolUseId": content["toolUse"]["toolUseId"],
}
}
}
}
}
yield {"contentBlockDelta": {"delta": {"toolUse": {"input": json.dumps(content["toolUse"]["input"])}}}}
yield {"contentBlockStop": {}}
yield {
"contentBlockDelta": {"delta": {"toolUse": {"input": json.dumps(content["toolUse"]["input"])}}}
}
yield {"contentBlockStop": {}}

yield {"messageStop": {"stopReason": stop_reason}}
59 changes: 59 additions & 0 deletions tests/strands/agent/test_agent.py
@@ -4,6 +4,7 @@
import os
import textwrap
import unittest.mock
from uuid import uuid4

import pytest
from pydantic import BaseModel
@@ -1425,3 +1426,61 @@ def test_agent_restored_from_session_management():
agent = Agent(session_manager=session_manager)

assert agent.state.get("foo") == "bar"


def test_agent_redacts_input_on_triggered_guardrail():
mocked_model = MockedModelProvider(
[{"redactedUserContent": "BLOCKED!", "redactedAssistantContent": "INPUT BLOCKED!"}]
)

agent = Agent(
model=mocked_model,
system_prompt="You are a helpful assistant.",
callback_handler=None,
)

response1 = agent("CACTUS")

assert response1.stop_reason == "guardrail_intervened"
assert agent.messages[0]["content"][0]["text"] == "BLOCKED!"


def test_agent_restored_from_session_management_with_redacted_input():
mocked_model = MockedModelProvider(
[{"redactedUserContent": "BLOCKED!", "redactedAssistantContent": "INPUT BLOCKED!"}]
)

test_session_id = str(uuid4())
mocked_session_repository = MockedSessionRepository()
session_manager = RepositorySessionManager(session_id=test_session_id, session_repository=mocked_session_repository)

agent = Agent(
model=mocked_model,
system_prompt="You are a helpful assistant.",
callback_handler=None,
session_manager=session_manager,
)

assert mocked_session_repository.read_agent(test_session_id, agent.agent_id) is not None

response1 = agent("CACTUS")

assert response1.stop_reason == "guardrail_intervened"
assert agent.messages[0]["content"][0]["text"] == "BLOCKED!"
user_input_session_message = mocked_session_repository.list_messages(test_session_id, agent.agent_id)[0]
# Assert persisted message is equal to the redacted message in the agent
assert user_input_session_message.to_message() == agent.messages[0]

# Restore an agent from the session, confirm input is still redacted
session_manager_2 = RepositorySessionManager(
session_id=test_session_id, session_repository=mocked_session_repository
)
agent_2 = Agent(
model=mocked_model,
system_prompt="You are a helpful assistant.",
callback_handler=None,
session_manager=session_manager_2,
)

# Assert that the restored agent's redacted message is equal to the original agent's
assert agent.messages[0] == agent_2.messages[0]
3 changes: 1 addition & 2 deletions tests/strands/event_loop/test_streaming.py
@@ -528,8 +528,7 @@ def test_extract_usage_metrics():
)
@pytest.mark.asyncio
async def test_process_stream(response, exp_events, agenerator, alist):
messages = [{"role": "user", "content": [{"text": "Some input!"}]}]
stream = strands.event_loop.streaming.process_stream(agenerator(response), messages)
stream = strands.event_loop.streaming.process_stream(agenerator(response))

tru_events = await alist(stream)
assert tru_events == exp_events