Skip to content

Commit f9ad22d

Browse files
committed
Add tool calls to stored transcripts
Also moved all of the transcript handling into its own module (src/utils/transcripts.py), since it grew with this change.
1 parent 2cc494c commit f9ad22d

File tree

7 files changed

+397
-211
lines changed

7 files changed

+397
-211
lines changed

src/app/endpoints/query.py

Lines changed: 22 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
from datetime import datetime, UTC
44
import json
55
import logging
6-
import os
7-
from pathlib import Path
8-
from typing import Annotated, Any
6+
from typing import Annotated, Any, cast
97

108
from llama_stack_client import APIConnectionError
119
from llama_stack_client import AsyncLlamaStackClient # type: ignore
10+
from llama_stack_client.lib.agents.event_logger import interleaved_content_as_str
1211
from llama_stack_client.types import UserMessage, Shield # type: ignore
12+
from llama_stack_client.types.agents.turn import Turn
1313
from llama_stack_client.types.agents.turn_create_params import (
1414
ToolgroupAgentToolGroupWithArgs,
1515
Toolgroup,
@@ -35,7 +35,8 @@
3535
validate_conversation_ownership,
3636
)
3737
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
38-
from utils.suid import get_suid
38+
from utils.transcripts import store_transcript
39+
from utils.types import TurnSummary
3940

4041
logger = logging.getLogger("app.endpoints.handlers")
4142
router = APIRouter(tags=["query"])
@@ -189,7 +190,7 @@ async def query_endpoint_handler(
189190
user_conversation=user_conversation, query_request=query_request
190191
),
191192
)
192-
response, conversation_id = await retrieve_response(
193+
summary, conversation_id = await retrieve_response(
193194
client,
194195
llama_stack_model_id,
195196
query_request,
@@ -210,7 +211,7 @@ async def query_endpoint_handler(
210211
query_is_valid=True, # TODO(lucasagomes): implement as part of query validation
211212
query=query_request.query,
212213
query_request=query_request,
213-
response=response,
214+
summary=summary,
214215
rag_chunks=[], # TODO(lucasagomes): implement rag_chunks
215216
truncated=False, # TODO(lucasagomes): implement truncation as part of quota work
216217
attachments=query_request.attachments or [],
@@ -223,7 +224,10 @@ async def query_endpoint_handler(
223224
provider_id=provider_id,
224225
)
225226

226-
return QueryResponse(conversation_id=conversation_id, response=response)
227+
return QueryResponse(
228+
conversation_id=conversation_id,
229+
response=summary.llm_response,
230+
)
227231

228232
# connection to Llama Stack server
229233
except APIConnectionError as e:
@@ -322,7 +326,7 @@ async def retrieve_response( # pylint: disable=too-many-locals
322326
query_request: QueryRequest,
323327
token: str,
324328
mcp_headers: dict[str, dict[str, str]] | None = None,
325-
) -> tuple[str, str]:
329+
) -> tuple[TurnSummary, str]:
326330
"""Retrieve response from LLMs and agents."""
327331
available_input_shields = [
328332
shield.identifier
@@ -401,16 +405,23 @@ async def retrieve_response( # pylint: disable=too-many-locals
401405
stream=False,
402406
toolgroups=toolgroups,
403407
)
408+
response = cast(Turn, response)
409+
410+
summary = TurnSummary(
411+
llm_response=interleaved_content_as_str(response.output_message.content),
412+
tool_calls=[],
413+
)
404414

405415
# Check for validation errors in the response
406-
steps = getattr(response, "steps", [])
416+
steps = response.steps or []
407417
for step in steps:
408418
if step.step_type == "shield_call" and step.violation:
409419
# Metric for LLM validation errors
410420
metrics.llm_calls_validation_errors_total.inc()
411-
break
421+
if step.step_type == "tool_execution":
422+
summary.append_tool_calls_from_llama(step)
412423

413-
return str(response.output_message.content), conversation_id # type: ignore[union-attr]
424+
return summary, conversation_id
414425

415426

416427
def validate_attachments_metadata(attachments: list[Attachment]) -> None:
@@ -443,73 +454,6 @@ def validate_attachments_metadata(attachments: list[Attachment]) -> None:
443454
)
444455

445456

446-
def construct_transcripts_path(user_id: str, conversation_id: str) -> Path:
447-
"""Construct path to transcripts."""
448-
# these two normalizations are required by Snyk as it detects
449-
# this Path sanitization pattern
450-
uid = os.path.normpath("/" + user_id).lstrip("/")
451-
cid = os.path.normpath("/" + conversation_id).lstrip("/")
452-
file_path = (
453-
configuration.user_data_collection_configuration.transcripts_storage or ""
454-
)
455-
return Path(file_path, uid, cid)
456-
457-
458-
def store_transcript( # pylint: disable=too-many-arguments,too-many-positional-arguments
459-
user_id: str,
460-
conversation_id: str,
461-
model_id: str,
462-
provider_id: str | None,
463-
query_is_valid: bool,
464-
query: str,
465-
query_request: QueryRequest,
466-
response: str,
467-
rag_chunks: list[str],
468-
truncated: bool,
469-
attachments: list[Attachment],
470-
) -> None:
471-
"""Store transcript in the local filesystem.
472-
473-
Args:
474-
user_id: The user ID (UUID).
475-
conversation_id: The conversation ID (UUID).
476-
query_is_valid: The result of the query validation.
477-
query: The query (without attachments).
478-
query_request: The request containing a query.
479-
response: The response to store.
480-
rag_chunks: The list of `RagChunk` objects.
481-
truncated: The flag indicating if the history was truncated.
482-
attachments: The list of `Attachment` objects.
483-
"""
484-
transcripts_path = construct_transcripts_path(user_id, conversation_id)
485-
transcripts_path.mkdir(parents=True, exist_ok=True)
486-
487-
data_to_store = {
488-
"metadata": {
489-
"provider": provider_id,
490-
"model": model_id,
491-
"query_provider": query_request.provider,
492-
"query_model": query_request.model,
493-
"user_id": user_id,
494-
"conversation_id": conversation_id,
495-
"timestamp": datetime.now(UTC).isoformat(),
496-
},
497-
"redacted_query": query,
498-
"query_is_valid": query_is_valid,
499-
"llm_response": response,
500-
"rag_chunks": rag_chunks,
501-
"truncated": truncated,
502-
"attachments": [attachment.model_dump() for attachment in attachments],
503-
}
504-
505-
# stores feedback in a file under unique uuid
506-
transcript_file_path = transcripts_path / f"{get_suid()}.json"
507-
with open(transcript_file_path, "w", encoding="utf-8") as transcript_file:
508-
json.dump(data_to_store, transcript_file)
509-
510-
logger.info("Transcript successfully stored at: %s", transcript_file_path)
511-
512-
513457
def get_rag_toolgroups(
514458
vector_db_ids: list[str],
515459
) -> list[Toolgroup] | None:

src/app/endpoints/streaming_query.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,16 @@
44
import json
55
import re
66
import logging
7-
from typing import Annotated, Any, AsyncIterator, Iterator
7+
from typing import Annotated, Any, AsyncIterator, Iterator, cast
88

99
from llama_stack_client import APIConnectionError
1010
from llama_stack_client import AsyncLlamaStackClient # type: ignore
1111
from llama_stack_client.types import UserMessage # type: ignore
1212

1313
from llama_stack_client.lib.agents.event_logger import interleaved_content_as_str
14+
from llama_stack_client.types.agents.agent_turn_response_stream_chunk import (
15+
AgentTurnResponseStreamChunk,
16+
)
1417
from llama_stack_client.types.shared import ToolCall
1518
from llama_stack_client.types.shared.interleaved_content_item import TextContentItem
1619

@@ -26,13 +29,14 @@
2629
from models.database.conversations import UserConversation
2730
from utils.endpoints import check_configuration_loaded, get_agent, get_system_prompt
2831
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
32+
from utils.transcripts import store_transcript
33+
from utils.types import TurnSummary
2934

3035
from app.endpoints.query import (
3136
get_rag_toolgroups,
3237
is_input_shield,
3338
is_output_shield,
3439
is_transcripts_enabled,
35-
store_transcript,
3640
select_model_and_provider_id,
3741
validate_attachments_metadata,
3842
validate_conversation_ownership,
@@ -436,23 +440,29 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
436440
)
437441
metadata_map: dict[str, dict[str, Any]] = {}
438442

439-
async def response_generator(turn_response: Any) -> AsyncIterator[str]:
443+
async def response_generator(
444+
turn_response: AsyncIterator[AgentTurnResponseStreamChunk],
445+
) -> AsyncIterator[str]:
440446
"""Generate SSE formatted streaming response."""
441447
chunk_id = 0
442-
complete_response = "No response from the model"
448+
summary = TurnSummary(
449+
llm_response="No response from the model", tool_calls=[]
450+
)
443451

444452
# Send start event
445453
yield stream_start_event(conversation_id)
446454

447455
async for chunk in turn_response:
456+
p = chunk.event.payload
457+
if p.event_type == "turn_complete":
458+
summary.llm_response = interleaved_content_as_str(
459+
p.turn.output_message.content
460+
)
461+
elif p.event_type == "step_complete":
462+
if p.step_details.step_type == "tool_execution":
463+
summary.append_tool_calls_from_llama(p.step_details)
464+
448465
for event in stream_build_event(chunk, chunk_id, metadata_map):
449-
if (
450-
json.loads(event.replace("data: ", ""))["event"]
451-
== "turn_complete"
452-
):
453-
complete_response = json.loads(event.replace("data: ", ""))[
454-
"data"
455-
]["token"]
456466
chunk_id += 1
457467
yield event
458468

@@ -469,7 +479,7 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]:
469479
query_is_valid=True, # TODO(lucasagomes): implement as part of query validation
470480
query=query_request.query,
471481
query_request=query_request,
472-
response=complete_response,
482+
summary=summary,
473483
rag_chunks=[], # TODO(lucasagomes): implement rag_chunks
474484
truncated=False, # TODO(lucasagomes): implement truncation as part
475485
# of quota work
@@ -507,7 +517,7 @@ async def retrieve_response(
507517
query_request: QueryRequest,
508518
token: str,
509519
mcp_headers: dict[str, dict[str, str]] | None = None,
510-
) -> tuple[Any, str]:
520+
) -> tuple[AsyncIterator[AgentTurnResponseStreamChunk], str]:
511521
"""Retrieve response from LLMs and agents."""
512522
available_input_shields = [
513523
shield.identifier
@@ -588,5 +598,6 @@ async def retrieve_response(
588598
stream=True,
589599
toolgroups=toolgroups,
590600
)
601+
response = cast(AsyncIterator[AgentTurnResponseStreamChunk], response)
591602

592603
return response, conversation_id

src/utils/transcripts.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
"""Transcript handling.
2+
3+
Transcripts are a log of individual query/response pairs that get
4+
stored on disk for later analysis.
5+
"""
6+
7+
from datetime import UTC, datetime
8+
import json
9+
import logging
10+
import os
11+
from pathlib import Path
12+
13+
from configuration import configuration
14+
from models.requests import Attachment, QueryRequest
15+
from utils.suid import get_suid
16+
from utils.types import TurnSummary
17+
18+
logger = logging.getLogger("utils.transcripts")
19+
20+
21+
def construct_transcripts_path(user_id: str, conversation_id: str) -> Path:
22+
"""Construct path to transcripts."""
23+
# these two normalizations are required by Snyk as it detects
24+
# this Path sanitization pattern
25+
uid = os.path.normpath("/" + user_id).lstrip("/")
26+
cid = os.path.normpath("/" + conversation_id).lstrip("/")
27+
file_path = (
28+
configuration.user_data_collection_configuration.transcripts_storage or ""
29+
)
30+
return Path(file_path, uid, cid)
31+
32+
33+
def store_transcript( # pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals
34+
user_id: str,
35+
conversation_id: str,
36+
model_id: str,
37+
provider_id: str | None,
38+
query_is_valid: bool,
39+
query: str,
40+
query_request: QueryRequest,
41+
summary: TurnSummary,
42+
rag_chunks: list[str],
43+
truncated: bool,
44+
attachments: list[Attachment],
45+
) -> None:
46+
"""Store transcript in the local filesystem.
47+
48+
Args:
49+
user_id: The user ID (UUID).
50+
conversation_id: The conversation ID (UUID).
51+
query_is_valid: The result of the query validation.
52+
query: The query (without attachments).
53+
query_request: The request containing a query.
54+
summary: Summary of the query/response turn.
55+
rag_chunks: The list of `RagChunk` objects.
56+
truncated: The flag indicating if the history was truncated.
57+
attachments: The list of `Attachment` objects.
58+
"""
59+
transcripts_path = construct_transcripts_path(user_id, conversation_id)
60+
transcripts_path.mkdir(parents=True, exist_ok=True)
61+
62+
data_to_store = {
63+
"metadata": {
64+
"provider": provider_id,
65+
"model": model_id,
66+
"query_provider": query_request.provider,
67+
"query_model": query_request.model,
68+
"user_id": user_id,
69+
"conversation_id": conversation_id,
70+
"timestamp": datetime.now(UTC).isoformat(),
71+
},
72+
"redacted_query": query,
73+
"query_is_valid": query_is_valid,
74+
"llm_response": summary.llm_response,
75+
"rag_chunks": rag_chunks,
76+
"truncated": truncated,
77+
"attachments": [attachment.model_dump() for attachment in attachments],
78+
"tool_calls": [tc.model_dump() for tc in summary.tool_calls],
79+
}
80+
81+
# stores the transcript in a file named with a unique suid
82+
transcript_file_path = transcripts_path / f"{get_suid()}.json"
83+
with open(transcript_file_path, "w", encoding="utf-8") as transcript_file:
84+
json.dump(data_to_store, transcript_file)
85+
86+
logger.info("Transcript successfully stored at: %s", transcript_file_path)

0 commit comments

Comments
 (0)