146 changes: 36 additions & 110 deletions src/app/endpoints/query.py
@@ -3,13 +3,13 @@
from datetime import datetime, UTC
import json
import logging
import os
from pathlib import Path
from typing import Annotated, Any
from typing import Annotated, Any, cast

from llama_stack_client import APIConnectionError
from llama_stack_client import AsyncLlamaStackClient # type: ignore
from llama_stack_client.lib.agents.event_logger import interleaved_content_as_str
from llama_stack_client.types import UserMessage, Shield # type: ignore
from llama_stack_client.types.agents.turn import Turn
from llama_stack_client.types.agents.turn_create_params import (
ToolgroupAgentToolGroupWithArgs,
Toolgroup,
@@ -35,7 +35,8 @@
validate_conversation_ownership,
)
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
from utils.suid import get_suid
from utils.transcripts import store_transcript
from utils.types import TurnSummary

logger = logging.getLogger("app.endpoints.handlers")
router = APIRouter(tags=["query"])
@@ -203,7 +204,7 @@ async def query_endpoint_handler(
user_conversation=user_conversation, query_request=query_request
),
)
response, conversation_id = await retrieve_response(
summary, conversation_id = await retrieve_response(
client,
llama_stack_model_id,
query_request,
@@ -224,7 +225,7 @@
query_is_valid=True, # TODO(lucasagomes): implement as part of query validation
query=query_request.query,
query_request=query_request,
response=response,
summary=summary,
rag_chunks=[], # TODO(lucasagomes): implement rag_chunks
truncated=False, # TODO(lucasagomes): implement truncation as part of quota work
attachments=query_request.attachments or [],
@@ -237,7 +238,10 @@
provider_id=provider_id,
)

return QueryResponse(conversation_id=conversation_id, response=response)
return QueryResponse(
conversation_id=conversation_id,
response=summary.llm_response,
)

# connection to Llama Stack server
except APIConnectionError as e:
@@ -381,7 +385,7 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
query_request: QueryRequest,
token: str,
mcp_headers: dict[str, dict[str, str]] | None = None,
) -> tuple[str, str]:
) -> tuple[TurnSummary, str]:
"""
Retrieve response from LLMs and agents.

@@ -404,7 +408,7 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
mcp_headers (dict[str, dict[str, str]], optional): Per-server headers for Model Context Protocol (MCP) tool calls.

Returns:
tuple[str, str]: A tuple containing the LLM or agent's response content
tuple[TurnSummary, str]: A tuple containing a summary of the LLM or agent's response content
and the conversation ID.
"""
available_input_shields = [
@@ -484,27 +488,35 @@
stream=False,
toolgroups=toolgroups,
)
response = cast(Turn, response)

summary = TurnSummary(
llm_response=(
interleaved_content_as_str(response.output_message.content)
if (
getattr(response, "output_message", None) is not None
and getattr(response.output_message, "content", None) is not None
)
else ""
),
tool_calls=[],
)

# Check for validation errors in the response
steps = getattr(response, "steps", [])
steps = response.steps or []
for step in steps:
if step.step_type == "shield_call" and step.violation:
# Metric for LLM validation errors
metrics.llm_calls_validation_errors_total.inc()
break

output_message = getattr(response, "output_message", None)
if output_message is not None:
content = getattr(output_message, "content", None)
if content is not None:
return str(content), conversation_id

# fallback
logger.warning(
"Response lacks output_message.content (conversation_id=%s)",
conversation_id,
)
return "", conversation_id
if step.step_type == "tool_execution":
summary.append_tool_calls_from_llama(step)

if not summary.llm_response:
logger.warning(
"Response lacks output_message.content (conversation_id=%s)",
conversation_id,
)
return summary, conversation_id


def validate_attachments_metadata(attachments: list[Attachment]) -> None:
@@ -539,92 +551,6 @@ def validate_attachments_metadata(attachments: list[Attachment]) -> None:
)


def construct_transcripts_path(user_id: str, conversation_id: str) -> Path:
"""
Construct path to transcripts.

Constructs a sanitized filesystem path for storing transcripts
based on the user ID and conversation ID.

Parameters:
user_id (str): The user identifier, which will be normalized and sanitized.
conversation_id (str): The conversation identifier, which will be normalized and sanitized.

Returns:
Path: The constructed path for storing transcripts for the specified user and conversation.
"""
# these two normalizations are required so that Snyk recognizes
# this Path sanitization pattern and does not flag path traversal
uid = os.path.normpath("/" + user_id).lstrip("/")
cid = os.path.normpath("/" + conversation_id).lstrip("/")
file_path = (
configuration.user_data_collection_configuration.transcripts_storage or ""
)
return Path(file_path, uid, cid)

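Since this helper relies on a non-obvious trick, a short worked example may help: prefixing the input with "/" pins it to a fake root so os.path.normpath collapses any ".." traversal segments, and lstrip("/") turns the result back into a safe relative path. A minimal sketch of the behavior (the sample inputs are illustrative):

import os

def sanitize(segment: str) -> str:
    # Pin to a fake root, collapse ".." segments, then re-relativize.
    return os.path.normpath("/" + segment).lstrip("/")

print(sanitize("../../etc/passwd"))  # -> "etc/passwd" (traversal neutralized)
print(sanitize("user-123"))          # -> "user-123" (benign input unchanged)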

def store_transcript( # pylint: disable=too-many-arguments,too-many-positional-arguments
user_id: str,
conversation_id: str,
model_id: str,
provider_id: str | None,
query_is_valid: bool,
query: str,
query_request: QueryRequest,
response: str,
rag_chunks: list[str],
truncated: bool,
attachments: list[Attachment],
) -> None:
"""
Store transcript in the local filesystem.

Constructs a sanitized filesystem path for storing transcripts
based on the user ID and conversation ID.

Args:
user_id: The user ID (UUID).
conversation_id: The conversation ID (UUID).
query_is_valid: The result of the query validation.
query: The query (without attachments).
query_request: The request containing a query.
response: The response to store.
rag_chunks: The list of `RagChunk` objects.
truncated: The flag indicating if the history was truncated.
attachments: The list of `Attachment` objects.
"""
transcripts_path = construct_transcripts_path(user_id, conversation_id)
transcripts_path.mkdir(parents=True, exist_ok=True)

data_to_store = {
"metadata": {
"provider": provider_id,
"model": model_id,
"query_provider": query_request.provider,
"query_model": query_request.model,
"user_id": user_id,
"conversation_id": conversation_id,
"timestamp": datetime.now(UTC).isoformat(),
},
"redacted_query": query,
"query_is_valid": query_is_valid,
"llm_response": response,
"rag_chunks": rag_chunks,
"truncated": truncated,
"attachments": [attachment.model_dump() for attachment in attachments],
}

# stores the transcript in a file under a unique uuid
transcript_file_path = transcripts_path / f"{get_suid()}.json"
with open(transcript_file_path, "w", encoding="utf-8") as transcript_file:
json.dump(data_to_store, transcript_file)

logger.info("Transcript successfully stored at: %s", transcript_file_path)


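Both deleted helpers now live in utils/transcripts.py (hence the new from utils.transcripts import store_transcript at the top of this file). The relocated code is not part of this diff; judging from the call site above, a sketch of the new signature would match the deleted version with response: str swapped for summary: TurnSummary, storing summary.llm_response in the JSON payload:

# Hypothetical sketch of the relocated helper in utils/transcripts.py; the
# import paths and the summary parameter are inferred from this diff.
from models.requests import Attachment, QueryRequest  # assumed import path
from utils.types import TurnSummary

def store_transcript(  # pylint: disable=too-many-arguments,too-many-positional-arguments
    user_id: str,
    conversation_id: str,
    model_id: str,
    provider_id: str | None,
    query_is_valid: bool,
    query: str,
    query_request: QueryRequest,
    summary: TurnSummary,
    rag_chunks: list[str],
    truncated: bool,
    attachments: list[Attachment],
) -> None:
    """Store the transcript as JSON, recording summary.llm_response as the response."""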
def get_rag_toolgroups(
vector_db_ids: list[str],
) -> list[Toolgroup] | None:
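utils/types.py itself is not shown in this pull request. From the way both endpoints use it (an llm_response string, a tool_calls list, and append_tool_calls_from_llama fed a tool_execution step), a minimal sketch of TurnSummary might look like the following; the ToolCallSummary shape and all names beyond the observed API are assumptions:

# Hypothetical sketch of utils/types.py — not part of this diff.
from dataclasses import dataclass, field
from typing import Any


@dataclass
class ToolCallSummary:
    """One tool invocation extracted from a tool_execution step."""

    id: str
    name: str
    args: Any


@dataclass
class TurnSummary:
    """Pairs the final LLM response with the tool calls made during the turn."""

    llm_response: str
    tool_calls: list[ToolCallSummary] = field(default_factory=list)

    def append_tool_calls_from_llama(self, step: Any) -> None:
        # A llama-stack tool_execution step carries a list of ToolCall objects.
        for call in getattr(step, "tool_calls", None) or []:
            self.tool_calls.append(
                ToolCallSummary(id=call.call_id, name=call.tool_name, args=call.arguments)
            )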
37 changes: 24 additions & 13 deletions src/app/endpoints/streaming_query.py
@@ -4,13 +4,16 @@
import json
import re
import logging
from typing import Annotated, Any, AsyncIterator, Iterator
from typing import Annotated, Any, AsyncIterator, Iterator, cast

from llama_stack_client import APIConnectionError
from llama_stack_client import AsyncLlamaStackClient # type: ignore
from llama_stack_client.types import UserMessage # type: ignore

from llama_stack_client.lib.agents.event_logger import interleaved_content_as_str
from llama_stack_client.types.agents.agent_turn_response_stream_chunk import (
AgentTurnResponseStreamChunk,
)
from llama_stack_client.types.shared import ToolCall
from llama_stack_client.types.shared.interleaved_content_item import TextContentItem

@@ -26,13 +29,14 @@
from models.database.conversations import UserConversation
from utils.endpoints import check_configuration_loaded, get_agent, get_system_prompt
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
from utils.transcripts import store_transcript
from utils.types import TurnSummary

from app.endpoints.query import (
get_rag_toolgroups,
is_input_shield,
is_output_shield,
is_transcripts_enabled,
store_transcript,
select_model_and_provider_id,
validate_attachments_metadata,
validate_conversation_ownership,
@@ -574,7 +578,9 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
)
metadata_map: dict[str, dict[str, Any]] = {}

async def response_generator(turn_response: Any) -> AsyncIterator[str]:
async def response_generator(
turn_response: AsyncIterator[AgentTurnResponseStreamChunk],
) -> AsyncIterator[str]:
"""
Generate SSE formatted streaming response.

@@ -587,20 +593,24 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]:
complete response for transcript storage if enabled.
"""
chunk_id = 0
complete_response = "No response from the model"
summary = TurnSummary(
llm_response="No response from the model", tool_calls=[]
)

# Send start event
yield stream_start_event(conversation_id)

async for chunk in turn_response:
p = chunk.event.payload
if p.event_type == "turn_complete":
summary.llm_response = interleaved_content_as_str(
p.turn.output_message.content
)
elif p.event_type == "step_complete":
if p.step_details.step_type == "tool_execution":
summary.append_tool_calls_from_llama(p.step_details)

for event in stream_build_event(chunk, chunk_id, metadata_map):
if (
json.loads(event.replace("data: ", ""))["event"]
== "turn_complete"
):
complete_response = json.loads(event.replace("data: ", ""))[
"data"
]["token"]
chunk_id += 1
yield event

@@ -617,7 +627,7 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]:
query_is_valid=True, # TODO(lucasagomes): implement as part of query validation
query=query_request.query,
query_request=query_request,
response=complete_response,
summary=summary,
rag_chunks=[], # TODO(lucasagomes): implement rag_chunks
truncated=False, # TODO(lucasagomes): implement truncation as part
# of quota work
@@ -655,7 +665,7 @@ async def retrieve_response(
query_request: QueryRequest,
token: str,
mcp_headers: dict[str, dict[str, str]] | None = None,
) -> tuple[Any, str]:
) -> tuple[AsyncIterator[AgentTurnResponseStreamChunk], str]:
"""
Retrieve response from LLMs and agents.

@@ -758,5 +768,6 @@ async def retrieve_response(
stream=True,
toolgroups=toolgroups,
)
response = cast(AsyncIterator[AgentTurnResponseStreamChunk], response)

return response, conversation_id
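As the deleted re-parsing logic in response_generator implies, every frame it yields is an SSE data: line wrapping a JSON object with event and data keys, and the final token travels in the turn_complete event. A hedged client-side sketch for draining such a stream (the event name and payload shape come from the deleted code; the helper itself is illustrative):

import json
from typing import AsyncIterator


async def collect_final_response(stream: AsyncIterator[str]) -> str:
    """Drain an SSE stream and return the turn_complete token, if any."""
    token = ""
    async for frame in stream:
        payload = json.loads(frame.removeprefix("data: "))
        if payload.get("event") == "turn_complete":
            token = payload["data"]["token"]
    return token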