
Commit 906e60a

Add tool calls to stored transcripts
Also moved all of the transcript handling into its own module, since it grew a bit with this change.
1 parent 601339b

File tree

7 files changed (+413, -245 lines changed)
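
The diffs below build on a new TurnSummary type from utils/types.py and on store_transcript relocated to utils/transcripts.py; both files are part of this commit but are not rendered in this view. For orientation, here is a minimal sketch of what utils/types.py plausibly contains, inferred purely from the call sites in the two diffs — the ToolCallSummary name, its fields, and the step attribute names are assumptions:

# Hypothetical sketch of utils/types.py; only TurnSummary, llm_response,
# tool_calls, and append_tool_calls_from_llama actually appear in the diffs.
from dataclasses import dataclass, field
from typing import Any


@dataclass
class ToolCallSummary:
    """One tool invocation captured from a llama-stack turn step (assumed shape)."""

    id: str
    name: str
    args: Any
    response: Any = None


@dataclass
class TurnSummary:
    """What the endpoints keep from a turn: the final LLM text plus tool calls."""

    llm_response: str
    tool_calls: list[ToolCallSummary] = field(default_factory=list)

    def append_tool_calls_from_llama(self, step: Any) -> None:
        """Collect tool calls (and their responses, when present) from a
        tool_execution step; the tool_calls/tool_responses attribute names
        follow llama-stack-client conventions but are unverified here."""
        responses = {
            r.call_id: r.content for r in getattr(step, "tool_responses", [])
        }
        for call in getattr(step, "tool_calls", []):
            self.tool_calls.append(
                ToolCallSummary(
                    id=call.call_id,
                    name=call.tool_name,
                    args=call.arguments,
                    response=responses.get(call.call_id),
                )
            )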

src/app/endpoints/query.py

Lines changed: 36 additions & 110 deletions
@@ -3,13 +3,13 @@
 from datetime import datetime, UTC
 import json
 import logging
-import os
-from pathlib import Path
-from typing import Annotated, Any
+from typing import Annotated, Any, cast

 from llama_stack_client import APIConnectionError
 from llama_stack_client import AsyncLlamaStackClient  # type: ignore
+from llama_stack_client.lib.agents.event_logger import interleaved_content_as_str
 from llama_stack_client.types import UserMessage, Shield  # type: ignore
+from llama_stack_client.types.agents.turn import Turn
 from llama_stack_client.types.agents.turn_create_params import (
     ToolgroupAgentToolGroupWithArgs,
     Toolgroup,
@@ -35,7 +35,8 @@
     validate_conversation_ownership,
 )
 from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
-from utils.suid import get_suid
+from utils.transcripts import store_transcript
+from utils.types import TurnSummary

 logger = logging.getLogger("app.endpoints.handlers")
 router = APIRouter(tags=["query"])
@@ -203,7 +204,7 @@ async def query_endpoint_handler(
             user_conversation=user_conversation, query_request=query_request
         ),
     )
-    response, conversation_id = await retrieve_response(
+    summary, conversation_id = await retrieve_response(
         client,
         llama_stack_model_id,
         query_request,
@@ -224,7 +225,7 @@
             query_is_valid=True,  # TODO(lucasagomes): implement as part of query validation
             query=query_request.query,
             query_request=query_request,
-            response=response,
+            summary=summary,
             rag_chunks=[],  # TODO(lucasagomes): implement rag_chunks
             truncated=False,  # TODO(lucasagomes): implement truncation as part of quota work
             attachments=query_request.attachments or [],
@@ -237,7 +238,10 @@
             provider_id=provider_id,
         )

-        return QueryResponse(conversation_id=conversation_id, response=response)
+        return QueryResponse(
+            conversation_id=conversation_id,
+            response=summary.llm_response,
+        )

     # connection to Llama Stack server
     except APIConnectionError as e:
@@ -381,7 +385,7 @@ async def retrieve_response(  # pylint: disable=too-many-locals,too-many-branches
     query_request: QueryRequest,
     token: str,
     mcp_headers: dict[str, dict[str, str]] | None = None,
-) -> tuple[str, str]:
+) -> tuple[TurnSummary, str]:
     """
     Retrieve response from LLMs and agents.

@@ -404,7 +408,7 @@ async def retrieve_response(  # pylint: disable=too-many-locals,too-many-branches
         mcp_headers (dict[str, dict[str, str]], optional): Headers for multi-component processing.

     Returns:
-        tuple[str, str]: A tuple containing the LLM or agent's response content
+        tuple[TurnSummary, str]: A tuple containing a summary of the LLM or agent's response content
         and the conversation ID.
     """
     available_input_shields = [
@@ -484,27 +488,35 @@ async def retrieve_response(  # pylint: disable=too-many-locals,too-many-branches
         stream=False,
         toolgroups=toolgroups,
     )
+    response = cast(Turn, response)
+
+    summary = TurnSummary(
+        llm_response=(
+            interleaved_content_as_str(response.output_message.content)
+            if (
+                getattr(response, "output_message", None) is not None
+                and getattr(response.output_message, "content", None) is not None
+            )
+            else ""
+        ),
+        tool_calls=[],
+    )

     # Check for validation errors in the response
-    steps = getattr(response, "steps", [])
+    steps = response.steps or []
     for step in steps:
         if step.step_type == "shield_call" and step.violation:
             # Metric for LLM validation errors
             metrics.llm_calls_validation_errors_total.inc()
-            break
-
-    output_message = getattr(response, "output_message", None)
-    if output_message is not None:
-        content = getattr(output_message, "content", None)
-        if content is not None:
-            return str(content), conversation_id
-
-    # fallback
-    logger.warning(
-        "Response lacks output_message.content (conversation_id=%s)",
-        conversation_id,
-    )
-    return "", conversation_id
+        if step.step_type == "tool_execution":
+            summary.append_tool_calls_from_llama(step)
+
+    if not summary.llm_response:
+        logger.warning(
+            "Response lacks output_message.content (conversation_id=%s)",
+            conversation_id,
+        )
+    return summary, conversation_id


 def validate_attachments_metadata(attachments: list[Attachment]) -> None:
@@ -539,92 +551,6 @@ def validate_attachments_metadata(attachments: list[Attachment]) -> None:
         )


-def construct_transcripts_path(user_id: str, conversation_id: str) -> Path:
-    """
-    Construct path to transcripts.
-
-    Constructs a sanitized filesystem path for storing transcripts
-    based on the user ID and conversation ID.
-
-    Parameters:
-        user_id (str): The user identifier, which will be normalized and sanitized.
-        conversation_id (str): The conversation identifier, which will be normalized and sanitized.
-
-    Returns:
-        Path: The constructed path for storing transcripts for the specified user and conversation.
-    """
-    # these two normalizations are required by Snyk as it detects
-    # this Path sanitization pattern
-    uid = os.path.normpath("/" + user_id).lstrip("/")
-    cid = os.path.normpath("/" + conversation_id).lstrip("/")
-    file_path = (
-        configuration.user_data_collection_configuration.transcripts_storage or ""
-    )
-    return Path(file_path, uid, cid)
-
-
-def store_transcript(  # pylint: disable=too-many-arguments,too-many-positional-arguments
-    user_id: str,
-    conversation_id: str,
-    model_id: str,
-    provider_id: str | None,
-    query_is_valid: bool,
-    query: str,
-    query_request: QueryRequest,
-    response: str,
-    rag_chunks: list[str],
-    truncated: bool,
-    attachments: list[Attachment],
-) -> None:
-    """
-    Store transcript in the local filesystem.
-
-    Constructs a sanitized filesystem path for storing transcripts
-    based on the user ID and conversation ID.
-
-    Returns:
-        Path: The constructed path for storing transcripts for the specified user and conversation.
-
-    Args:
-        user_id: The user ID (UUID).
-        conversation_id: The conversation ID (UUID).
-        query_is_valid: The result of the query validation.
-        query: The query (without attachments).
-        query_request: The request containing a query.
-        response: The response to store.
-        rag_chunks: The list of `RagChunk` objects.
-        truncated: The flag indicating if the history was truncated.
-        attachments: The list of `Attachment` objects.
-    """
-    transcripts_path = construct_transcripts_path(user_id, conversation_id)
-    transcripts_path.mkdir(parents=True, exist_ok=True)
-
-    data_to_store = {
-        "metadata": {
-            "provider": provider_id,
-            "model": model_id,
-            "query_provider": query_request.provider,
-            "query_model": query_request.model,
-            "user_id": user_id,
-            "conversation_id": conversation_id,
-            "timestamp": datetime.now(UTC).isoformat(),
-        },
-        "redacted_query": query,
-        "query_is_valid": query_is_valid,
-        "llm_response": response,
-        "rag_chunks": rag_chunks,
-        "truncated": truncated,
-        "attachments": [attachment.model_dump() for attachment in attachments],
-    }
-
-    # stores feedback in a file under unique uuid
-    transcript_file_path = transcripts_path / f"{get_suid()}.json"
-    with open(transcript_file_path, "w", encoding="utf-8") as transcript_file:
-        json.dump(data_to_store, transcript_file)
-
-    logger.info("Transcript successfully stored at: %s", transcript_file_path)
-
-
 def get_rag_toolgroups(
     vector_db_ids: list[str],
 ) -> list[Toolgroup] | None:
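
The construct_transcripts_path and store_transcript functions deleted above move to utils/transcripts.py (see the new import at the top of this file); that module is among the seven changed files but is not shown in this view. Judging by the call sites, store_transcript now receives summary: TurnSummary in place of response: str. A hedged sketch of how the stored document might gain tool calls — transcript_payload is a hypothetical helper, and the "tool_calls" key is an assumption; the other keys mirror the deleted data_to_store:

# Hypothetical illustration, not the actual utils/transcripts.py: only the
# llm_response and tool_calls entries differ from the deleted code above;
# the remaining data_to_store keys (rag_chunks, truncated, ...) are omitted.
from dataclasses import asdict
from datetime import datetime, UTC
from typing import Any


def transcript_payload(
    summary: "TurnSummary",  # the sketch type shown after the file tree
    query: str,
    metadata: dict[str, Any],
) -> dict[str, Any]:
    """Build the JSON document written under the per-user/conversation path."""
    return {
        "metadata": {**metadata, "timestamp": datetime.now(UTC).isoformat()},
        "redacted_query": query,
        "llm_response": summary.llm_response,  # was the plain `response` string
        "tool_calls": [asdict(tc) for tc in summary.tool_calls],  # new (assumed key)
    }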

src/app/endpoints/streaming_query.py

Lines changed: 24 additions & 13 deletions
@@ -4,13 +4,16 @@
 import json
 import re
 import logging
-from typing import Annotated, Any, AsyncIterator, Iterator
+from typing import Annotated, Any, AsyncIterator, Iterator, cast

 from llama_stack_client import APIConnectionError
 from llama_stack_client import AsyncLlamaStackClient  # type: ignore
 from llama_stack_client.types import UserMessage  # type: ignore

 from llama_stack_client.lib.agents.event_logger import interleaved_content_as_str
+from llama_stack_client.types.agents.agent_turn_response_stream_chunk import (
+    AgentTurnResponseStreamChunk,
+)
 from llama_stack_client.types.shared import ToolCall
 from llama_stack_client.types.shared.interleaved_content_item import TextContentItem

@@ -26,13 +29,14 @@
 from models.database.conversations import UserConversation
 from utils.endpoints import check_configuration_loaded, get_agent, get_system_prompt
 from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
+from utils.transcripts import store_transcript
+from utils.types import TurnSummary

 from app.endpoints.query import (
     get_rag_toolgroups,
     is_input_shield,
     is_output_shield,
     is_transcripts_enabled,
-    store_transcript,
     select_model_and_provider_id,
     validate_attachments_metadata,
     validate_conversation_ownership,
@@ -574,7 +578,9 @@ async def streaming_query_endpoint_handler(  # pylint: disable=too-many-locals
     )
     metadata_map: dict[str, dict[str, Any]] = {}

-    async def response_generator(turn_response: Any) -> AsyncIterator[str]:
+    async def response_generator(
+        turn_response: AsyncIterator[AgentTurnResponseStreamChunk],
+    ) -> AsyncIterator[str]:
         """
         Generate SSE formatted streaming response.

@@ -587,20 +593,24 @@ async def streaming_query_endpoint_handler(  # pylint: disable=too-many-locals
         complete response for transcript storage if enabled.
         """
         chunk_id = 0
-        complete_response = "No response from the model"
+        summary = TurnSummary(
+            llm_response="No response from the model", tool_calls=[]
+        )

         # Send start event
         yield stream_start_event(conversation_id)

         async for chunk in turn_response:
+            p = chunk.event.payload
+            if p.event_type == "turn_complete":
+                summary.llm_response = interleaved_content_as_str(
+                    p.turn.output_message.content
+                )
+            elif p.event_type == "step_complete":
+                if p.step_details.step_type == "tool_execution":
+                    summary.append_tool_calls_from_llama(p.step_details)
+
             for event in stream_build_event(chunk, chunk_id, metadata_map):
-                if (
-                    json.loads(event.replace("data: ", ""))["event"]
-                    == "turn_complete"
-                ):
-                    complete_response = json.loads(event.replace("data: ", ""))[
-                        "data"
-                    ]["token"]
                 chunk_id += 1
                 yield event

@@ -617,7 +627,7 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]:
                 query_is_valid=True,  # TODO(lucasagomes): implement as part of query validation
                 query=query_request.query,
                 query_request=query_request,
-                response=complete_response,
+                summary=summary,
                 rag_chunks=[],  # TODO(lucasagomes): implement rag_chunks
                 truncated=False,  # TODO(lucasagomes): implement truncation as part
                 # of quota work
@@ -655,7 +665,7 @@ async def retrieve_response(
     query_request: QueryRequest,
     token: str,
     mcp_headers: dict[str, dict[str, str]] | None = None,
-) -> tuple[Any, str]:
+) -> tuple[AsyncIterator[AgentTurnResponseStreamChunk], str]:
     """
     Retrieve response from LLMs and agents.

@@ -758,5 +768,6 @@ async def retrieve_response(
         stream=True,
         toolgroups=toolgroups,
     )
+    response = cast(AsyncIterator[AgentTurnResponseStreamChunk], response)

     return response, conversation_id
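
Note the design shift in response_generator: instead of round-tripping its own SSE output through json.loads to spot the turn_complete token, it now inspects the typed event payloads directly and feeds the same TurnSummary used by the non-streaming endpoint. A small, hypothetical smoke test of that branching, using SimpleNamespace stand-ins for the llama-stack chunks and the TurnSummary sketch from above:

# Hypothetical smoke test; the chunk shapes imitate only the payload fields
# the generator reads (event_type, step_details, turn.output_message.content).
from types import SimpleNamespace as NS

summary = TurnSummary(llm_response="No response from the model", tool_calls=[])

chunks = [
    NS(event=NS(payload=NS(
        event_type="step_complete",
        step_details=NS(
            step_type="tool_execution",
            tool_calls=[NS(call_id="c1", tool_name="search", arguments={"q": "hi"})],
            tool_responses=[NS(call_id="c1", content="three results")],
        ),
    ))),
    NS(event=NS(payload=NS(
        event_type="turn_complete",
        turn=NS(output_message=NS(content="final answer")),
    ))),
]

for chunk in chunks:
    p = chunk.event.payload
    if p.event_type == "turn_complete":
        # the real code wraps this in interleaved_content_as_str()
        summary.llm_response = p.turn.output_message.content
    elif p.event_type == "step_complete":
        if p.step_details.step_type == "tool_execution":
            summary.append_tool_calls_from_llama(p.step_details)

assert summary.llm_response == "final answer"
assert summary.tool_calls[0].name == "search"
assert summary.tool_calls[0].response == "three results"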
