33from datetime import datetime , UTC
44import json
55import logging
6- import os
7- from pathlib import Path
8- from typing import Annotated , Any
6+ from typing import Annotated , Any , cast
97
108from llama_stack_client import APIConnectionError
119from llama_stack_client import AsyncLlamaStackClient # type: ignore
10+ from llama_stack_client .lib .agents .event_logger import interleaved_content_as_str
1211from llama_stack_client .types import UserMessage , Shield # type: ignore
12+ from llama_stack_client .types .agents .turn import Turn
1313from llama_stack_client .types .agents .turn_create_params import (
1414 ToolgroupAgentToolGroupWithArgs ,
1515 Toolgroup ,
3535 validate_conversation_ownership ,
3636)
3737from utils .mcp_headers import mcp_headers_dependency , handle_mcp_headers_with_toolgroups
38- from utils .suid import get_suid
38+ from utils .transcripts import store_transcript
39+ from utils .types import TurnSummary
3940
4041logger = logging .getLogger ("app.endpoints.handlers" )
4142router = APIRouter (tags = ["query" ])
@@ -189,7 +190,7 @@ async def query_endpoint_handler(
189190 user_conversation = user_conversation , query_request = query_request
190191 ),
191192 )
192- response , conversation_id = await retrieve_response (
193+ summary , conversation_id = await retrieve_response (
193194 client ,
194195 llama_stack_model_id ,
195196 query_request ,
@@ -210,7 +211,7 @@ async def query_endpoint_handler(
210211 query_is_valid = True , # TODO(lucasagomes): implement as part of query validation
211212 query = query_request .query ,
212213 query_request = query_request ,
213- response = response ,
214+ summary = summary ,
214215 rag_chunks = [], # TODO(lucasagomes): implement rag_chunks
215216 truncated = False , # TODO(lucasagomes): implement truncation as part of quota work
216217 attachments = query_request .attachments or [],
@@ -223,7 +224,10 @@ async def query_endpoint_handler(
223224 provider_id = provider_id ,
224225 )
225226
226- return QueryResponse (conversation_id = conversation_id , response = response )
227+ return QueryResponse (
228+ conversation_id = conversation_id ,
229+ response = summary .llm_response ,
230+ )
227231
228232 # connection to Llama Stack server
229233 except APIConnectionError as e :
@@ -322,7 +326,7 @@ async def retrieve_response( # pylint: disable=too-many-locals
322326 query_request : QueryRequest ,
323327 token : str ,
324328 mcp_headers : dict [str , dict [str , str ]] | None = None ,
325- ) -> tuple [str , str ]:
329+ ) -> tuple [TurnSummary , str ]:
326330 """Retrieve response from LLMs and agents."""
327331 available_input_shields = [
328332 shield .identifier
@@ -401,16 +405,23 @@ async def retrieve_response( # pylint: disable=too-many-locals
401405 stream = False ,
402406 toolgroups = toolgroups ,
403407 )
408+ response = cast (Turn , response )
409+
410+ summary = TurnSummary (
411+ llm_response = interleaved_content_as_str (response .output_message .content ),
412+ tool_calls = [],
413+ )
404414
405415 # Check for validation errors in the response
406- steps = getattr ( response , " steps" , [])
416+ steps = response . steps or []
407417 for step in steps :
408418 if step .step_type == "shield_call" and step .violation :
409419 # Metric for LLM validation errors
410420 metrics .llm_calls_validation_errors_total .inc ()
411- break
421+ if step .step_type == "tool_execution" :
422+ summary .append_tool_calls_from_llama (step )
412423
413- return str ( response . output_message . content ) , conversation_id # type: ignore[union-attr]
424+ return summary , conversation_id
414425
415426
416427def validate_attachments_metadata (attachments : list [Attachment ]) -> None :
@@ -443,73 +454,6 @@ def validate_attachments_metadata(attachments: list[Attachment]) -> None:
443454 )
444455
445456
446- def construct_transcripts_path (user_id : str , conversation_id : str ) -> Path :
447- """Construct path to transcripts."""
448- # these two normalizations are required by Snyk as it detects
449- # this Path sanitization pattern
450- uid = os .path .normpath ("/" + user_id ).lstrip ("/" )
451- cid = os .path .normpath ("/" + conversation_id ).lstrip ("/" )
452- file_path = (
453- configuration .user_data_collection_configuration .transcripts_storage or ""
454- )
455- return Path (file_path , uid , cid )
456-
457-
458- def store_transcript ( # pylint: disable=too-many-arguments,too-many-positional-arguments
459- user_id : str ,
460- conversation_id : str ,
461- model_id : str ,
462- provider_id : str | None ,
463- query_is_valid : bool ,
464- query : str ,
465- query_request : QueryRequest ,
466- response : str ,
467- rag_chunks : list [str ],
468- truncated : bool ,
469- attachments : list [Attachment ],
470- ) -> None :
471- """Store transcript in the local filesystem.
472-
473- Args:
474- user_id: The user ID (UUID).
475- conversation_id: The conversation ID (UUID).
476- query_is_valid: The result of the query validation.
477- query: The query (without attachments).
478- query_request: The request containing a query.
479- response: The response to store.
480- rag_chunks: The list of `RagChunk` objects.
481- truncated: The flag indicating if the history was truncated.
482- attachments: The list of `Attachment` objects.
483- """
484- transcripts_path = construct_transcripts_path (user_id , conversation_id )
485- transcripts_path .mkdir (parents = True , exist_ok = True )
486-
487- data_to_store = {
488- "metadata" : {
489- "provider" : provider_id ,
490- "model" : model_id ,
491- "query_provider" : query_request .provider ,
492- "query_model" : query_request .model ,
493- "user_id" : user_id ,
494- "conversation_id" : conversation_id ,
495- "timestamp" : datetime .now (UTC ).isoformat (),
496- },
497- "redacted_query" : query ,
498- "query_is_valid" : query_is_valid ,
499- "llm_response" : response ,
500- "rag_chunks" : rag_chunks ,
501- "truncated" : truncated ,
502- "attachments" : [attachment .model_dump () for attachment in attachments ],
503- }
504-
505- # stores feedback in a file under unique uuid
506- transcript_file_path = transcripts_path / f"{ get_suid ()} .json"
507- with open (transcript_file_path , "w" , encoding = "utf-8" ) as transcript_file :
508- json .dump (data_to_store , transcript_file )
509-
510- logger .info ("Transcript successfully stored at: %s" , transcript_file_path )
511-
512-
513457def get_rag_toolgroups (
514458 vector_db_ids : list [str ],
515459) -> list [Toolgroup ] | None :
0 commit comments