|
1 | 1 | #!/usr/bin/env python3
|
2 | 2 |
|
| 3 | +import asyncio |
3 | 4 | from uuid import UUID
|
4 | 5 | from typing import Callable
|
5 | 6 | from textwrap import dedent
|
6 | 7 | from temporalio import activity
|
7 | 8 | from litellm import acompletion
|
| 9 | +from agents_api.models.entry.add_entries import add_entries_query |
8 | 10 | from agents_api.models.entry.entries_summarization import (
|
9 | 11 | get_toplevel_entries_query,
|
10 | 12 | entries_summarization_query,
|
11 | 13 | )
|
12 | 14 | from agents_api.common.protocol.entries import Entry
|
13 |
| -from ..env import summarization_model_name |
| 15 | +from agents_api.rec_sum.entities import get_entities |
| 16 | +from agents_api.rec_sum.summarize import summarize_messages |
| 17 | +from agents_api.rec_sum.trim import trim_messages |
14 | 18 |
|
15 | 19 |
|
16 | 20 | example_previous_memory = """
|
@@ -148,31 +152,76 @@ async def run_prompt(
|
148 | 152 | return parser(content.strip() if content is not None else "")
|
149 | 153 |
|
150 | 154 |
|
| 155 | +# @activity.defn |
| 156 | +# async def summarization(session_id: str) -> None: |
| 157 | +# session_id = UUID(session_id) |
| 158 | +# entries = [ |
| 159 | +# Entry(**row) |
| 160 | +# for _, row in get_toplevel_entries_query(session_id=session_id).iterrows() |
| 161 | +# ] |
| 162 | + |
| 163 | +# assert len(entries) > 0, "no need to summarize on empty entries list" |
| 164 | + |
| 165 | +# response = await run_prompt( |
| 166 | +# dialog=entries, previous_memories=[], model=summarization_model_name |
| 167 | +# ) |
| 168 | + |
| 169 | +# new_entry = Entry( |
| 170 | +# session_id=session_id, |
| 171 | +# source="summarizer", |
| 172 | +# role="system", |
| 173 | +# name="information", |
| 174 | +# content=response, |
| 175 | +# timestamp=entries[-1].timestamp + 0.01, |
| 176 | +# ) |
| 177 | + |
| 178 | +# entries_summarization_query( |
| 179 | +# session_id=session_id, |
| 180 | +# new_entry=new_entry, |
| 181 | +# old_entry_ids=[e.id for e in entries], |
| 182 | +# ) |
| 183 | + |
| 184 | + |
151 | 185 | @activity.defn
|
152 | 186 | async def summarization(session_id: str) -> None:
|
153 | 187 | session_id = UUID(session_id)
|
154 | 188 | entries = [
|
155 |
| - Entry(**row) |
156 |
| - for _, row in get_toplevel_entries_query(session_id=session_id).iterrows() |
| 189 | + row for _, row in get_toplevel_entries_query(session_id=session_id).iterrows() |
157 | 190 | ]
|
158 | 191 |
|
159 | 192 | assert len(entries) > 0, "no need to summarize on empty entries list"
|
160 | 193 |
|
161 |
| - response = await run_prompt( |
162 |
| - dialog=entries, previous_memories=[], model=summarization_model_name |
| 194 | + trimmed_messages, entities = await asyncio.gather( |
| 195 | + trim_messages(entries), |
| 196 | + get_entities(entries), |
163 | 197 | )
|
164 |
| - |
165 |
| - new_entry = Entry( |
166 |
| - session_id=session_id, |
167 |
| - source="summarizer", |
168 |
| - role="system", |
169 |
| - name="information", |
170 |
| - content=response, |
171 |
| - timestamp=entries[-1].timestamp + 0.01, |
| 198 | + summarized = await summarize_messages(trimmed_messages) |
| 199 | + |
| 200 | + ts_delta = (entries[1]["timestamp"] - entries[0]["timestamp"]) / 2 |
| 201 | + |
| 202 | + add_entries_query( |
| 203 | + Entry( |
| 204 | + session_id=session_id, |
| 205 | + source="summarizer", |
| 206 | + role="system", |
| 207 | + name="entities", |
| 208 | + content=entities["content"], |
| 209 | + timestamp=entries[0]["timestamp"] + ts_delta, |
| 210 | + ) |
172 | 211 | )
|
173 | 212 |
|
174 |
| - entries_summarization_query( |
175 |
| - session_id=session_id, |
176 |
| - new_entry=new_entry, |
177 |
| - old_entry_ids=[e.id for e in entries], |
178 |
| - ) |
| 213 | + for msg in summarized: |
| 214 | + new_entry = Entry( |
| 215 | + session_id=session_id, |
| 216 | + source="summarizer", |
| 217 | + role="system", |
| 218 | + name="information", |
| 219 | + content=msg["content"], |
| 220 | + timestamp=entries[-1]["timestamp"] + 0.01, |
| 221 | + ) |
| 222 | + |
| 223 | + entries_summarization_query( |
| 224 | + session_id=session_id, |
| 225 | + new_entry=new_entry, |
| 226 | + old_entry_ids=[entries[idx]["entry_id"] for idx in msg["summarizes"]], |
| 227 | + ) |
0 commit comments