Skip to content

Commit a41a85b

Browse files
feat: Summarize messages recursively
1 parent e77381e commit a41a85b

File tree

8 files changed

+155
-162
lines changed

8 files changed

+155
-162
lines changed

agents-api/agents_api/activities/summarization.py

+67-18
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
11
#!/usr/bin/env python3
22

3+
import asyncio
34
from uuid import UUID
45
from typing import Callable
56
from textwrap import dedent
67
from temporalio import activity
78
from litellm import acompletion
9+
from agents_api.models.entry.add_entries import add_entries_query
810
from agents_api.models.entry.entries_summarization import (
911
get_toplevel_entries_query,
1012
entries_summarization_query,
1113
)
1214
from agents_api.common.protocol.entries import Entry
13-
from ..env import summarization_model_name
15+
from agents_api.rec_sum.entities import get_entities
16+
from agents_api.rec_sum.summarize import summarize_messages
17+
from agents_api.rec_sum.trim import trim_messages
1418

1519

1620
example_previous_memory = """
@@ -148,31 +152,76 @@ async def run_prompt(
148152
return parser(content.strip() if content is not None else "")
149153

150154

155+
# @activity.defn
156+
# async def summarization(session_id: str) -> None:
157+
# session_id = UUID(session_id)
158+
# entries = [
159+
# Entry(**row)
160+
# for _, row in get_toplevel_entries_query(session_id=session_id).iterrows()
161+
# ]
162+
163+
# assert len(entries) > 0, "no need to summarize on empty entries list"
164+
165+
# response = await run_prompt(
166+
# dialog=entries, previous_memories=[], model=summarization_model_name
167+
# )
168+
169+
# new_entry = Entry(
170+
# session_id=session_id,
171+
# source="summarizer",
172+
# role="system",
173+
# name="information",
174+
# content=response,
175+
# timestamp=entries[-1].timestamp + 0.01,
176+
# )
177+
178+
# entries_summarization_query(
179+
# session_id=session_id,
180+
# new_entry=new_entry,
181+
# old_entry_ids=[e.id for e in entries],
182+
# )
183+
184+
151185
@activity.defn
152186
async def summarization(session_id: str) -> None:
153187
session_id = UUID(session_id)
154188
entries = [
155-
Entry(**row)
156-
for _, row in get_toplevel_entries_query(session_id=session_id).iterrows()
189+
row for _, row in get_toplevel_entries_query(session_id=session_id).iterrows()
157190
]
158191

159192
assert len(entries) > 0, "no need to summarize on empty entries list"
160193

161-
response = await run_prompt(
162-
dialog=entries, previous_memories=[], model=summarization_model_name
194+
trimmed_messages, entities = await asyncio.gather(
195+
trim_messages(entries),
196+
get_entities(entries),
163197
)
164-
165-
new_entry = Entry(
166-
session_id=session_id,
167-
source="summarizer",
168-
role="system",
169-
name="information",
170-
content=response,
171-
timestamp=entries[-1].timestamp + 0.01,
198+
summarized = await summarize_messages(trimmed_messages)
199+
200+
ts_delta = (entries[1]["timestamp"] - entries[0]["timestamp"]) / 2
201+
202+
add_entries_query(
203+
Entry(
204+
session_id=session_id,
205+
source="summarizer",
206+
role="system",
207+
name="entities",
208+
content=entities["content"],
209+
timestamp=entries[0]["timestamp"] + ts_delta,
210+
)
172211
)
173212

174-
entries_summarization_query(
175-
session_id=session_id,
176-
new_entry=new_entry,
177-
old_entry_ids=[e.id for e in entries],
178-
)
213+
for msg in summarized:
214+
new_entry = Entry(
215+
session_id=session_id,
216+
source="summarizer",
217+
role="system",
218+
name="information",
219+
content=msg["content"],
220+
timestamp=entries[-1]["timestamp"] + 0.01,
221+
)
222+
223+
entries_summarization_query(
224+
session_id=session_id,
225+
new_entry=new_entry,
226+
old_entry_ids=[entries[idx]["entry_id"] for idx in msg["summarizes"]],
227+
)
Original file line numberDiff line numberDiff line change
@@ -1,3 +0,0 @@
1-
from .entities import get_entities
2-
from .summarize import summarize_messages
3-
from .trim import trim_messages

agents-api/agents_api/rec_sum/data.py

+5-12
Original file line numberDiff line numberDiff line change
@@ -5,28 +5,21 @@
55
module_directory = Path(__file__).parent
66

77

8-
9-
with open(f"{module_directory}/entities_example_chat.json", 'r') as _f:
8+
with open(f"{module_directory}/entities_example_chat.json", "r") as _f:
109
entities_example_chat = json.load(_f)
11-
1210

1311

14-
with open(f"{module_directory}/trim_example_chat.json", 'r') as _f:
12+
with open(f"{module_directory}/trim_example_chat.json", "r") as _f:
1513
trim_example_chat = json.load(_f)
16-
1714

1815

19-
with open(f"{module_directory}/trim_example_result.json", 'r') as _f:
16+
with open(f"{module_directory}/trim_example_result.json", "r") as _f:
2017
trim_example_result = json.load(_f)
21-
2218

2319

24-
with open(f"{module_directory}/summarize_example_chat.json", 'r') as _f:
20+
with open(f"{module_directory}/summarize_example_chat.json", "r") as _f:
2521
summarize_example_chat = json.load(_f)
26-
2722

2823

29-
with open(f"{module_directory}/summarize_example_result.json", 'r') as _f:
24+
with open(f"{module_directory}/summarize_example_result.json", "r") as _f:
3025
summarize_example_result = json.load(_f)
31-
32-

agents-api/agents_api/rec_sum/entities.py

+10-30
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22

3-
from tenacity import retry, stop_after_attempt, wait_fixed
3+
from tenacity import retry, stop_after_attempt
44

55
from .data import entities_example_chat
66
from .generate import generate
@@ -41,38 +41,18 @@
4141
- See the example to get a better idea of the task."""
4242

4343

44-
make_entities_prompt = lambda session, user="a user", assistant="gpt-4-turbo", **_: [f"""\
45-
You are given a session history of a chat between {user or "a user"} and {assistant or "gpt-4-turbo"}. The session is formatted in the ChatML JSON format (from OpenAI).
46-
47-
{entities_instructions}
48-
49-
<ct:example-session>
50-
{json.dumps(entities_example_chat, indent=2)}
51-
</ct:example-session>
52-
53-
<ct:example-plan>
54-
{entities_example_plan}
55-
</ct:example-plan>
56-
57-
<ct:example-entities>
58-
{entities_example_result}
59-
</ct:example-entities>""",
60-
61-
f"""\
62-
Begin! Write the entities as a Markdown formatted list. First write your plan inside <ct:plan></ct:plan> and then the extracted entities between <ct:entities></ct:entities>.
63-
64-
<ct:session>
65-
{json.dumps(session, indent=2)}
66-
67-
</ct:session>"""]
68-
44+
def make_entities_prompt(session, user="a user", assistant="gpt-4-turbo", **_):
45+
return [
46+
f"You are given a session history of a chat between {user or 'a user'} and {assistant or 'gpt-4-turbo'}. The session is formatted in the ChatML JSON format (from OpenAI).\n\n{entities_instructions}\n\n<ct:example-session>\n{json.dumps(entities_example_chat, indent=2)}\n</ct:example-session>\n\n<ct:example-plan>\n{entities_example_plan}\n</ct:example-plan>\n\n<ct:example-entities>\n{entities_example_result}\n</ct:example-entities>",
47+
f"Begin! Write the entities as a Markdown formatted list. First write your plan inside <ct:plan></ct:plan> and then the extracted entities between <ct:entities></ct:entities>.\n\n<ct:session>\n{json.dumps(session, indent=2)}\n\n</ct:session>",
48+
]
6949

7050

7151
@retry(stop=stop_after_attempt(2))
7252
async def get_entities(
7353
chat_session,
74-
model="gpt-4-turbo",
75-
stop=["</ct:entities"],
54+
model="gpt-4-turbo",
55+
stop=["</ct:entities"],
7656
temperature=0.7,
7757
**kwargs,
7858
):
@@ -84,7 +64,7 @@ async def get_entities(
8464
and chat_session[0].get("name") != "entities"
8565
):
8666
chat_session = chat_session[1:]
87-
67+
8868
names = get_names_from_session(chat_session)
8969
system_prompt, user_message = make_entities_prompt(chat_session, **names)
9070
messages = [chatml.system(system_prompt), chatml.user(user_message)]
@@ -100,5 +80,5 @@ async def get_entities(
10080
result["content"] = result["content"].split("<ct:entities>")[-1].strip()
10181
result["role"] = "system"
10282
result["name"] = "entities"
103-
83+
10484
return chatml.make(**result)

agents-api/agents_api/rec_sum/generate.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,15 @@
77

88
@retry(wait=wait_fixed(2), stop=stop_after_attempt(5))
99
async def generate(
10-
messages: list[dict],
11-
client: AsyncClient=client,
12-
model: str="gpt-4-turbo",
10+
messages: list[dict],
11+
client: AsyncClient = client,
12+
model: str = "gpt-4-turbo",
1313
**kwargs
1414
) -> dict:
1515
result = await client.chat.completions.create(
1616
model=model, messages=messages, **kwargs
1717
)
18-
18+
1919
result = result.choices[0].message.__dict__
2020

2121
return result
22-
+14-34
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22

3-
from tenacity import retry, stop_after_attempt, wait_fixed
3+
from tenacity import retry, stop_after_attempt
44

55
from .data import summarize_example_chat, summarize_example_result
66
from .generate import generate
@@ -24,7 +24,6 @@
2424
- We can safely summarize message 34's essay into just the salient points only."""
2525

2626

27-
2827
summarize_instructions = """\
2928
Your goal is to compactify the history by coalescing redundant information in messages into their summary in order to reduce its size and save costs.
3029
@@ -36,53 +35,32 @@
3635
- VERY IMPORTANT: Add the indices of messages that are being summarized so that those messages can then be removed from the session otherwise, there'll be no way to identify which messages to remove. See example for more details."""
3736

3837

39-
40-
make_summarize_prompt = lambda session, user="a user", assistant="gpt-4-turbo", **_: [f"""\
41-
You are given a session history of a chat between {user or "a user"} and {assistant or "gpt-4-turbo"}. The session is formatted in the ChatML JSON format (from OpenAI).
42-
43-
{summarize_instructions}
44-
45-
<ct:example-session>
46-
{json.dumps(add_indices(summarize_example_chat), indent=2)}
47-
</ct:example-session>
48-
49-
<ct:example-plan>
50-
{summarize_example_plan}
51-
</ct:example-plan>
52-
53-
<ct:example-summarized-messages>
54-
{json.dumps(summarize_example_result, indent=2)}
55-
</ct:example-summarized-messages>""",
56-
57-
f"""\
58-
Begin! Write the summarized messages as a json list just like the example above. First write your plan inside <ct:plan></ct:plan> and then your answer between <ct:summarized-messages></ct:summarized-messages>. Don't forget to add the indices of the messages being summarized alongside each summary.
59-
60-
<ct:session>
61-
{json.dumps(add_indices(session), indent=2)}
62-
63-
</ct:session>"""]
64-
38+
def make_summarize_prompt(session, user="a user", assistant="gpt-4-turbo", **_):
39+
return [
40+
f"You are given a session history of a chat between {user or 'a user'} and {assistant or 'gpt-4-turbo'}. The session is formatted in the ChatML JSON format (from OpenAI).\n\n{summarize_instructions}\n\n<ct:example-session>\n{json.dumps(add_indices(summarize_example_chat), indent=2)}\n</ct:example-session>\n\n<ct:example-plan>\n{summarize_example_plan}\n</ct:example-plan>\n\n<ct:example-summarized-messages>\n{json.dumps(summarize_example_result, indent=2)}\n</ct:example-summarized-messages>",
41+
f"Begin! Write the summarized messages as a json list just like the example above. First write your plan inside <ct:plan></ct:plan> and then your answer between <ct:summarized-messages></ct:summarized-messages>. Don't forget to add the indices of the messages being summarized alongside each summary.\n\n<ct:session>\n{json.dumps(add_indices(session), indent=2)}\n\n</ct:session>",
42+
]
6543

6644

6745
@retry(stop=stop_after_attempt(2))
6846
async def summarize_messages(
6947
chat_session,
70-
model="gpt-4-turbo",
71-
stop=["</ct:summarized"],
48+
model="gpt-4-turbo",
49+
stop=["</ct:summarized"],
7250
temperature=0.8,
7351
**kwargs,
7452
):
7553
assert len(chat_session) > 2, "Session is too short"
7654

7755
offset = 0
78-
56+
7957
# Remove the system prompt if present
8058
if (
8159
chat_session[0]["role"] == "system"
8260
and chat_session[0].get("name") != "entities"
8361
):
8462
chat_session = chat_session[1:]
85-
63+
8664
# The indices are not matched up correctly
8765
offset = 1
8866

@@ -98,7 +76,9 @@ async def summarize_messages(
9876
)
9977

10078
assert "<ct:summarized-messages>" in result["content"]
101-
summarized_messages = json.loads(result["content"].split("<ct:summarized-messages>")[-1].strip())
79+
summarized_messages = json.loads(
80+
result["content"].split("<ct:summarized-messages>")[-1].strip()
81+
)
10282

10383
assert all((msg.get("summarizes") is not None for msg in summarized_messages))
10484

@@ -107,5 +87,5 @@ async def summarize_messages(
10787
{**msg, "summarizes": [i + offset for i in msg["summarizes"]]}
10888
for msg in summarized_messages
10989
]
110-
90+
11191
return summarized_messages

0 commit comments

Comments
 (0)