Skip to content

Commit 7e8243a

Browse files
authored
feat: redact content from a message in a session (#446)
1 parent 1969142 commit 7e8243a

File tree

12 files changed

+210
-47
lines changed

12 files changed

+210
-47
lines changed

src/strands/agent/agent.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,19 @@ async def _run_loop(
543543
# Execute the event loop cycle with retry logic for context limits
544544
events = self._execute_event_loop_cycle(invocation_state)
545545
async for event in events:
546+
# Signal from the model provider that the message sent by the user should be redacted,
547+
# likely due to a guardrail.
548+
if (
549+
event.get("callback")
550+
and event["callback"].get("event")
551+
and event["callback"]["event"].get("redactContent")
552+
and event["callback"]["event"]["redactContent"].get("redactUserContentMessage")
553+
):
554+
self.messages[-1]["content"] = [
555+
{"text": event["callback"]["event"]["redactContent"]["redactUserContentMessage"]}
556+
]
557+
if self._session_manager:
558+
self._session_manager.redact_latest_message(self.messages[-1], self)
546559
yield event
547560

548561
finally:

src/strands/event_loop/streaming.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -221,17 +221,13 @@ def handle_message_stop(event: MessageStopEvent) -> StopReason:
221221
return event["stopReason"]
222222

223223

224-
def handle_redact_content(event: RedactContentEvent, messages: Messages, state: dict[str, Any]) -> None:
224+
def handle_redact_content(event: RedactContentEvent, state: dict[str, Any]) -> None:
225225
"""Handles redacting content from the input or output.
226226
227227
Args:
228228
event: Redact Content Event.
229-
messages: Agent messages.
230229
state: The current state of message processing.
231230
"""
232-
if event.get("redactUserContentMessage") is not None:
233-
messages[-1]["content"] = [{"text": event["redactUserContentMessage"]}] # type: ignore
234-
235231
if event.get("redactAssistantContentMessage") is not None:
236232
state["message"]["content"] = [{"text": event["redactAssistantContentMessage"]}]
237233

@@ -251,15 +247,11 @@ def extract_usage_metrics(event: MetadataEvent) -> tuple[Usage, Metrics]:
251247
return usage, metrics
252248

253249

254-
async def process_stream(
255-
chunks: AsyncIterable[StreamEvent],
256-
messages: Messages,
257-
) -> AsyncGenerator[dict[str, Any], None]:
250+
async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[dict[str, Any], None]:
258251
"""Processes the response stream from the API, constructing the final message and extracting usage metrics.
259252
260253
Args:
261254
chunks: The chunks of the response stream from the model.
262-
messages: The agents messages.
263255
264256
Returns:
265257
The reason for stopping, the constructed message, and the usage metrics.
@@ -295,7 +287,7 @@ async def process_stream(
295287
elif "metadata" in chunk:
296288
usage, metrics = extract_usage_metrics(chunk["metadata"])
297289
elif "redactContent" in chunk:
298-
handle_redact_content(chunk["redactContent"], messages, state)
290+
handle_redact_content(chunk["redactContent"], state)
299291

300292
yield {"stop": (stop_reason, state["message"], usage, metrics)}
301293

@@ -323,5 +315,5 @@ async def stream_messages(
323315

324316
chunks = model.stream(messages, tool_specs if tool_specs else None, system_prompt)
325317

326-
async for event in process_stream(chunks, messages):
318+
async for event in process_stream(chunks):
327319
yield event

src/strands/models/anthropic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ async def structured_output(
407407
tool_spec = convert_pydantic_to_tool_spec(output_model)
408408

409409
response = self.stream(messages=prompt, tool_specs=[tool_spec], **kwargs)
410-
async for event in process_stream(response, prompt):
410+
async for event in process_stream(response):
411411
yield event
412412

413413
stop_reason, messages, _, _ = event["stop"]

src/strands/models/bedrock.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -577,7 +577,7 @@ async def structured_output(
577577
tool_spec = convert_pydantic_to_tool_spec(output_model)
578578

579579
response = self.stream(messages=prompt, tool_specs=[tool_spec], **kwargs)
580-
async for event in streaming.process_stream(response, prompt):
580+
async for event in streaming.process_stream(response):
581581
yield event
582582

583583
stop_reason, messages, _, _ = event["stop"]

src/strands/session/repository_session_manager.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Repository session manager implementation."""
22

33
import logging
4+
from typing import Optional
45

56
from ..agent.agent import Agent
67
from ..agent.state import AgentState
@@ -50,15 +51,30 @@ def __init__(
5051
# Keep track of the initialized agent id's so that two agents in a session cannot share an id
5152
self._initialized_agent_ids: set[str] = set()
5253

54+
# Keep track of the latest message stored in the session in case we need to redact its content.
55+
self._latest_message: Optional[SessionMessage] = None
56+
5357
def append_message(self, message: Message, agent: Agent) -> None:
5458
"""Append a message to the agent's session.
5559
5660
Args:
5761
message: Message to add to the agent in the session
5862
agent: Agent to append the message to
5963
"""
60-
session_message = SessionMessage.from_message(message)
61-
self.session_repository.create_message(self.session_id, agent.agent_id, session_message)
64+
self._latest_message = SessionMessage.from_message(message)
65+
self.session_repository.create_message(self.session_id, agent.agent_id, self._latest_message)
66+
67+
def redact_latest_message(self, redact_message: Message, agent: Agent) -> None:
68+
"""Redact the latest message appended to the session.
69+
70+
Args:
71+
redact_message: New message to use that contains the redact content
72+
agent: Agent to apply the message redaction to
73+
"""
74+
if self._latest_message is None:
75+
raise SessionException("No message to redact.")
76+
self._latest_message.redact_message = redact_message
77+
return self.session_repository.update_message(self.session_id, agent.agent_id, self._latest_message)
6278

6379
def sync_agent(self, agent: Agent) -> None:
6480
"""Serialize and update the agent into the session repository.

src/strands/session/session_manager.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,15 @@ def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None:
2626
registry.add_callback(MessageAddedEvent, lambda event: self.append_message(event.message, event.agent))
2727
registry.add_callback(MessageAddedEvent, lambda event: self.sync_agent(event.agent))
2828

29+
@abstractmethod
30+
def redact_latest_message(self, redact_message: Message, agent: "Agent") -> None:
31+
"""Redact the message most recently appended to the agent in the session.
32+
33+
Args:
34+
redact_message: New message to use that contains the redact content
35+
agent: Agent to apply the message redaction to
36+
"""
37+
2938
@abstractmethod
3039
def append_message(self, message: Message, agent: "Agent") -> None:
3140
"""Append a message to the agent's session.

src/strands/types/session.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from dataclasses import dataclass, field
66
from datetime import datetime, timezone
77
from enum import Enum
8-
from typing import Any, Dict, cast
8+
from typing import Any, Dict, Optional, cast
99
from uuid import uuid4
1010

1111
from ..agent.agent import Agent
@@ -54,9 +54,18 @@ def decode_bytes_values(obj: Any) -> Any:
5454

5555
@dataclass
5656
class SessionMessage:
57-
"""Message within a SessionAgent."""
57+
"""Message within a SessionAgent.
58+
59+
Attributes:
60+
message: Message content
61+
redact_message: If the original message is redacted, this is the new content to use
62+
message_id: Unique id for a message
63+
created_at: ISO format timestamp for when this message was created
64+
updated_at: ISO format timestamp for when this message was last updated
65+
"""
5866

5967
message: Message
68+
redact_message: Optional[Message] = None
6069
message_id: str = field(default_factory=lambda: str(uuid4()))
6170
created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
6271
updated_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
@@ -73,8 +82,14 @@ def from_message(cls, message: Message) -> "SessionMessage":
7382
)
7483

7584
def to_message(self) -> Message:
76-
"""Convert SessionMessage back to a Message, decoding any bytes values."""
77-
return cast(Message, decode_bytes_values(self.message))
85+
"""Convert SessionMessage back to a Message, decoding any bytes values.
86+
87+
If the message was redacted, return the redact content instead.
88+
"""
89+
if self.redact_message is not None:
90+
return cast(Message, decode_bytes_values(self.redact_message))
91+
else:
92+
return cast(Message, decode_bytes_values(self.message))
7893

7994
@classmethod
8095
def from_dict(cls, env: dict[str, Any]) -> "SessionMessage":

tests/fixtures/mocked_model_provider.py

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from typing import Any, AsyncGenerator, Iterable, Optional, Type, TypeVar
2+
from typing import Any, AsyncGenerator, Iterable, Optional, Type, TypedDict, TypeVar, Union
33

44
from pydantic import BaseModel
55

@@ -12,6 +12,11 @@
1212
T = TypeVar("T", bound=BaseModel)
1313

1414

15+
class RedactionMessage(TypedDict):
16+
redactedUserContent: str
17+
redactedAssistantContent: str
18+
19+
1520
class MockedModelProvider(Model):
1621
"""A mock implementation of the Model interface for testing purposes.
1722
@@ -20,7 +25,7 @@ class MockedModelProvider(Model):
2025
to stream mock responses as events.
2126
"""
2227

23-
def __init__(self, agent_responses: Messages):
28+
def __init__(self, agent_responses: list[Union[Message, RedactionMessage]]):
2429
self.agent_responses = agent_responses
2530
self.index = 0
2631

@@ -54,27 +59,36 @@ async def stream(
5459

5560
self.index += 1
5661

57-
def map_agent_message_to_events(self, agent_message: Message) -> Iterable[dict[str, Any]]:
62+
def map_agent_message_to_events(self, agent_message: Union[Message, RedactionMessage]) -> Iterable[dict[str, Any]]:
5863
stop_reason: StopReason = "end_turn"
5964
yield {"messageStart": {"role": "assistant"}}
60-
for content in agent_message["content"]:
61-
if "text" in content:
62-
yield {"contentBlockStart": {"start": {}}}
63-
yield {"contentBlockDelta": {"delta": {"text": content["text"]}}}
64-
yield {"contentBlockStop": {}}
65-
if "toolUse" in content:
66-
stop_reason = "tool_use"
67-
yield {
68-
"contentBlockStart": {
69-
"start": {
70-
"toolUse": {
71-
"name": content["toolUse"]["name"],
72-
"toolUseId": content["toolUse"]["toolUseId"],
65+
if agent_message.get("redactedAssistantContent"):
66+
yield {"redactContent": {"redactUserContentMessage": agent_message["redactedUserContent"]}}
67+
yield {"contentBlockStart": {"start": {}}}
68+
yield {"contentBlockDelta": {"delta": {"text": agent_message["redactedAssistantContent"]}}}
69+
yield {"contentBlockStop": {}}
70+
stop_reason = "guardrail_intervened"
71+
else:
72+
for content in agent_message["content"]:
73+
if "text" in content:
74+
yield {"contentBlockStart": {"start": {}}}
75+
yield {"contentBlockDelta": {"delta": {"text": content["text"]}}}
76+
yield {"contentBlockStop": {}}
77+
if "toolUse" in content:
78+
stop_reason = "tool_use"
79+
yield {
80+
"contentBlockStart": {
81+
"start": {
82+
"toolUse": {
83+
"name": content["toolUse"]["name"],
84+
"toolUseId": content["toolUse"]["toolUseId"],
85+
}
7386
}
7487
}
7588
}
76-
}
77-
yield {"contentBlockDelta": {"delta": {"toolUse": {"input": json.dumps(content["toolUse"]["input"])}}}}
78-
yield {"contentBlockStop": {}}
89+
yield {
90+
"contentBlockDelta": {"delta": {"toolUse": {"input": json.dumps(content["toolUse"]["input"])}}}
91+
}
92+
yield {"contentBlockStop": {}}
7993

8094
yield {"messageStop": {"stopReason": stop_reason}}

tests/strands/agent/test_agent.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
import textwrap
66
import unittest.mock
7+
from uuid import uuid4
78

89
import pytest
910
from pydantic import BaseModel
@@ -1425,3 +1426,61 @@ def test_agent_restored_from_session_management():
14251426
agent = Agent(session_manager=session_manager)
14261427

14271428
assert agent.state.get("foo") == "bar"
1429+
1430+
1431+
def test_agent_redacts_input_on_triggered_guardrail():
1432+
mocked_model = MockedModelProvider(
1433+
[{"redactedUserContent": "BLOCKED!", "redactedAssistantContent": "INPUT BLOCKED!"}]
1434+
)
1435+
1436+
agent = Agent(
1437+
model=mocked_model,
1438+
system_prompt="You are a helpful assistant.",
1439+
callback_handler=None,
1440+
)
1441+
1442+
response1 = agent("CACTUS")
1443+
1444+
assert response1.stop_reason == "guardrail_intervened"
1445+
assert agent.messages[0]["content"][0]["text"] == "BLOCKED!"
1446+
1447+
1448+
def test_agent_restored_from_session_management_with_redacted_input():
1449+
mocked_model = MockedModelProvider(
1450+
[{"redactedUserContent": "BLOCKED!", "redactedAssistantContent": "INPUT BLOCKED!"}]
1451+
)
1452+
1453+
test_session_id = str(uuid4())
1454+
mocked_session_repository = MockedSessionRepository()
1455+
session_manager = RepositorySessionManager(session_id=test_session_id, session_repository=mocked_session_repository)
1456+
1457+
agent = Agent(
1458+
model=mocked_model,
1459+
system_prompt="You are a helpful assistant.",
1460+
callback_handler=None,
1461+
session_manager=session_manager,
1462+
)
1463+
1464+
assert mocked_session_repository.read_agent(test_session_id, agent.agent_id) is not None
1465+
1466+
response1 = agent("CACTUS")
1467+
1468+
assert response1.stop_reason == "guardrail_intervened"
1469+
assert agent.messages[0]["content"][0]["text"] == "BLOCKED!"
1470+
user_input_session_message = mocked_session_repository.list_messages(test_session_id, agent.agent_id)[0]
1471+
# Assert persisted message is equal to the redacted message in the agent
1472+
assert user_input_session_message.to_message() == agent.messages[0]
1473+
1474+
# Restore an agent from the session, confirm input is still redacted
1475+
session_manager_2 = RepositorySessionManager(
1476+
session_id=test_session_id, session_repository=mocked_session_repository
1477+
)
1478+
agent_2 = Agent(
1479+
model=mocked_model,
1480+
system_prompt="You are a helpful assistant.",
1481+
callback_handler=None,
1482+
session_manager=session_manager_2,
1483+
)
1484+
1485+
# Assert that the restored agent redacted message is equal to the original agent
1486+
assert agent.messages[0] == agent_2.messages[0]

tests/strands/event_loop/test_streaming.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -528,8 +528,7 @@ def test_extract_usage_metrics():
528528
)
529529
@pytest.mark.asyncio
530530
async def test_process_stream(response, exp_events, agenerator, alist):
531-
messages = [{"role": "user", "content": [{"text": "Some input!"}]}]
532-
stream = strands.event_loop.streaming.process_stream(agenerator(response), messages)
531+
stream = strands.event_loop.streaming.process_stream(agenerator(response))
533532

534533
tru_events = await alist(stream)
535534
assert tru_events == exp_events

0 commit comments

Comments
 (0)