33from datetime import datetime , UTC
44import json
55import logging
6- import os
7- from pathlib import Path
8- from typing import Annotated , Any
6+ from typing import Annotated , Any , cast
97
108from llama_stack_client import APIConnectionError
119from llama_stack_client import AsyncLlamaStackClient # type: ignore
10+ from llama_stack_client .lib .agents .event_logger import interleaved_content_as_str
1211from llama_stack_client .types import UserMessage , Shield # type: ignore
12+ from llama_stack_client .types .agents .turn import Turn
1313from llama_stack_client .types .agents .turn_create_params import (
1414 ToolgroupAgentToolGroupWithArgs ,
1515 Toolgroup ,
3535 validate_conversation_ownership ,
3636)
3737from utils .mcp_headers import mcp_headers_dependency , handle_mcp_headers_with_toolgroups
38- from utils .suid import get_suid
38+ from utils .transcripts import store_transcript
39+ from utils .types import TurnSummary
3940
4041logger = logging .getLogger ("app.endpoints.handlers" )
4142router = APIRouter (tags = ["query" ])
@@ -203,7 +204,7 @@ async def query_endpoint_handler(
203204 user_conversation = user_conversation , query_request = query_request
204205 ),
205206 )
206- response , conversation_id = await retrieve_response (
207+ summary , conversation_id = await retrieve_response (
207208 client ,
208209 llama_stack_model_id ,
209210 query_request ,
@@ -224,7 +225,7 @@ async def query_endpoint_handler(
224225 query_is_valid = True , # TODO(lucasagomes): implement as part of query validation
225226 query = query_request .query ,
226227 query_request = query_request ,
227- response = response ,
228+ summary = summary ,
228229 rag_chunks = [], # TODO(lucasagomes): implement rag_chunks
229230 truncated = False , # TODO(lucasagomes): implement truncation as part of quota work
230231 attachments = query_request .attachments or [],
@@ -237,7 +238,10 @@ async def query_endpoint_handler(
237238 provider_id = provider_id ,
238239 )
239240
240- return QueryResponse (conversation_id = conversation_id , response = response )
241+ return QueryResponse (
242+ conversation_id = conversation_id ,
243+ response = summary .llm_response ,
244+ )
241245
242246 # connection to Llama Stack server
243247 except APIConnectionError as e :
@@ -381,7 +385,7 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
381385 query_request : QueryRequest ,
382386 token : str ,
383387 mcp_headers : dict [str , dict [str , str ]] | None = None ,
384- ) -> tuple [str , str ]:
388+ ) -> tuple [TurnSummary , str ]:
385389 """
386390 Retrieve response from LLMs and agents.
387391
@@ -404,7 +408,7 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
404408 mcp_headers (dict[str, dict[str, str]], optional): Headers for multi-component processing.
405409
406410 Returns:
407- tuple[str , str]: A tuple containing the LLM or agent's response content
411+ tuple[TurnSummary , str]: A tuple containing a summary of the LLM or agent's response content
408412 and the conversation ID.
409413 """
410414 available_input_shields = [
@@ -484,27 +488,35 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
484488 stream = False ,
485489 toolgroups = toolgroups ,
486490 )
491+ response = cast (Turn , response )
492+
493+ summary = TurnSummary (
494+ llm_response = (
495+ interleaved_content_as_str (response .output_message .content )
496+ if (
497+ getattr (response , "output_message" , None ) is not None
498+ and getattr (response .output_message , "content" , None ) is not None
499+ )
500+ else ""
501+ ),
502+ tool_calls = [],
503+ )
487504
488505 # Check for validation errors in the response
489- steps = getattr ( response , " steps" , [])
506+ steps = response . steps or []
490507 for step in steps :
491508 if step .step_type == "shield_call" and step .violation :
492509 # Metric for LLM validation errors
493510 metrics .llm_calls_validation_errors_total .inc ()
494- break
495-
496- output_message = getattr (response , "output_message" , None )
497- if output_message is not None :
498- content = getattr (output_message , "content" , None )
499- if content is not None :
500- return str (content ), conversation_id
501-
502- # fallback
503- logger .warning (
504- "Response lacks output_message.content (conversation_id=%s)" ,
505- conversation_id ,
506- )
507- return "" , conversation_id
511+ if step .step_type == "tool_execution" :
512+ summary .append_tool_calls_from_llama (step )
513+
514+ if not summary .llm_response :
515+ logger .warning (
516+ "Response lacks output_message.content (conversation_id=%s)" ,
517+ conversation_id ,
518+ )
519+ return summary , conversation_id
508520
509521
510522def validate_attachments_metadata (attachments : list [Attachment ]) -> None :
@@ -539,92 +551,6 @@ def validate_attachments_metadata(attachments: list[Attachment]) -> None:
539551 )
540552
541553
542- def construct_transcripts_path (user_id : str , conversation_id : str ) -> Path :
543- """
544- Construct path to transcripts.
545-
546- Joins the configured transcripts storage root (an empty string when that
547- setting is falsy) with the normalized user ID and conversation ID.
548-
549- Parameters:
550- user_id (str): The user identifier; normalized with os.path.normpath, leading slashes stripped.
551- conversation_id (str): The conversation identifier; normalized the same way as user_id.
552-
553- Returns:
554- Path: Path(<transcripts_storage>, <normalized user_id>, <normalized conversation_id>).
555- """
556- # each ID is rooted at "/" before normpath and the leading "/" is stripped afterwards;
557- # both normalizations are required by Snyk, which detects this Path sanitization pattern
558- uid = os .path .normpath ("/" + user_id ).lstrip ("/" )
559- cid = os .path .normpath ("/" + conversation_id ).lstrip ("/" )
560- file_path = (
561- configuration .user_data_collection_configuration .transcripts_storage or ""
562- )
563- return Path (file_path , uid , cid )
564-
565-
566- def store_transcript ( # pylint: disable=too-many-arguments,too-many-positional-arguments
567- user_id : str ,
568- conversation_id : str ,
569- model_id : str ,
570- provider_id : str | None ,
571- query_is_valid : bool ,
572- query : str ,
573- query_request : QueryRequest ,
574- response : str ,
575- rag_chunks : list [str ],
576- truncated : bool ,
577- attachments : list [Attachment ],
578- ) -> None :
579- """
580- Store transcript in the local filesystem.
581-
582- Creates the directory given by construct_transcripts_path(user_id, conversation_id)
583- if it does not exist, then writes the transcript there as one JSON file whose
584- name comes from get_suid(). Returns nothing.
585-
586- Args:
587- user_id: The user ID (UUID).
588- conversation_id: The conversation ID (UUID).
589- model_id: The model ID, stored in the metadata as 'model'.
590- provider_id: The provider ID (may be None), stored in the metadata as 'provider'.
591- query_is_valid: The result of the query validation.
592- query: The query (without attachments), stored as 'redacted_query'.
593- query_request: The request containing a query; its provider and model are stored in the metadata.
594- response: The response to store, written as 'llm_response'.
595- rag_chunks: The list of RAG chunks (annotated as list[str]).
596- truncated: The flag indicating if the history was truncated.
597- attachments: The list of `Attachment` objects; each is serialized with model_dump().
598- """
599- transcripts_path = construct_transcripts_path (user_id , conversation_id )
600- transcripts_path .mkdir (parents = True , exist_ok = True )
601-
602- data_to_store = {
603- "metadata" : {
604- "provider" : provider_id ,
605- "model" : model_id ,
606- "query_provider" : query_request .provider ,
607- "query_model" : query_request .model ,
608- "user_id" : user_id ,
609- "conversation_id" : conversation_id ,
610- "timestamp" : datetime .now (UTC ).isoformat (),
611- },
612- "redacted_query" : query ,
613- "query_is_valid" : query_is_valid ,
614- "llm_response" : response ,
615- "rag_chunks" : rag_chunks ,
616- "truncated" : truncated ,
617- "attachments" : [attachment .model_dump () for attachment in attachments ],
618- }
619-
620- # stores the transcript in a file named by a fresh get_suid() value
621- transcript_file_path = transcripts_path / f"{ get_suid ()} .json"
622- with open (transcript_file_path , "w" , encoding = "utf-8" ) as transcript_file :
623- json .dump (data_to_store , transcript_file )
624-
625- logger .info ("Transcript successfully stored at: %s" , transcript_file_path )
626-
627-
628554def get_rag_toolgroups (
629555 vector_db_ids : list [str ],
630556) -> list [Toolgroup ] | None :
0 commit comments