Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 90 additions & 6 deletions src/app/endpoints/query.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""Handler for REST API call to provide answer to query."""

import ast
from datetime import datetime, UTC
import json
import logging
import os
from pathlib import Path
from typing import Annotated, Any
import re
from typing import Annotated, Any, cast

from llama_stack_client import APIConnectionError
from llama_stack_client import AsyncLlamaStackClient # type: ignore
Expand Down Expand Up @@ -41,10 +43,79 @@
router = APIRouter(tags=["query"])
auth_dependency = get_auth_dependency()

METADATA_PATTERN = re.compile(r"\nMetadata: (\{.+})\n")


def _process_knowledge_search_content(
tool_response: Any, metadata_map: dict[str, dict[str, Any]]
) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

metadata_map seems to be the return value, not a real parameter. Please refactor to return new metadata_map

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed with 06cba91

"""Process knowledge search tool response content for metadata."""
for text_content_item in tool_response.content:
if not hasattr(text_content_item, "text"):
continue

for match in METADATA_PATTERN.findall(text_content_item.text):
try:
meta = ast.literal_eval(match)
if "document_id" in meta:
metadata_map[meta["document_id"]] = meta
except Exception: # pylint: disable=broad-except
logger.debug(
"An exception was thrown in processing %s",
match,
)


def extract_referenced_documents_from_steps(steps: list) -> list[dict[str, str]]:
"""Extract referenced documents from tool execution steps.

Args:
steps: List of response steps from the agent

Returns:
List of referenced documents with doc_url and doc_title
"""
metadata_map: dict[str, dict[str, Any]] = {}

for step in steps:
if step.step_type != "tool_execution" or not hasattr(step, "tool_responses"):
continue

for tool_response in step.tool_responses:
if (
tool_response.tool_name != "knowledge_search"
or not tool_response.content
):
continue

_process_knowledge_search_content(tool_response, metadata_map)

# Extract referenced documents from metadata
return [
{
"doc_url": v["docs_url"],
"doc_title": v["title"],
}
for v in filter(
lambda v: ("docs_url" in v) and ("title" in v),
metadata_map.values(),
)
]


query_response: dict[int | str, dict[str, Any]] = {
200: {
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
"response": "LLM answer",
"referenced_documents": [
{
"doc_url": (
"https://docs.openshift.com/container-platform/"
"4.15/operators/olm/index.html"
),
"doc_title": "Operator Lifecycle Manager (OLM)",
}
],
},
400: {
"description": "Missing or invalid credentials provided by client",
Expand Down Expand Up @@ -189,7 +260,7 @@ async def query_endpoint_handler(
user_conversation=user_conversation, query_request=query_request
),
)
response, conversation_id = await retrieve_response(
response, conversation_id, referenced_documents = await retrieve_response(
client,
llama_stack_model_id,
query_request,
Expand Down Expand Up @@ -223,7 +294,11 @@ async def query_endpoint_handler(
provider_id=provider_id,
)

return QueryResponse(conversation_id=conversation_id, response=response)
return QueryResponse(
conversation_id=conversation_id,
response=response,
referenced_documents=referenced_documents,
)

# connection to Llama Stack server
except APIConnectionError as e:
Expand Down Expand Up @@ -322,7 +397,7 @@ async def retrieve_response( # pylint: disable=too-many-locals
query_request: QueryRequest,
token: str,
mcp_headers: dict[str, dict[str, str]] | None = None,
) -> tuple[str, str]:
) -> tuple[str, str, list[dict[str, str]]]:
"""Retrieve response from LLMs and agents."""
available_input_shields = [
shield.identifier
Expand Down Expand Up @@ -402,15 +477,24 @@ async def retrieve_response( # pylint: disable=too-many-locals
toolgroups=toolgroups,
)

# Check for validation errors in the response
# Check for validation errors and extract referenced documents
steps = getattr(response, "steps", [])
for step in steps:
if step.step_type == "shield_call" and step.violation:
# Metric for LLM validation errors
metrics.llm_calls_validation_errors_total.inc()
break

return str(response.output_message.content), conversation_id # type: ignore[union-attr]
# Extract referenced documents from tool execution steps
referenced_documents = extract_referenced_documents_from_steps(steps)

# When stream=False, response should have output_message attribute
response_obj = cast(Any, response)
return (
str(response_obj.output_message.content),
conversation_id,
referenced_documents,
)


def validate_attachments_metadata(attachments: list[Attachment]) -> None:
Expand Down
29 changes: 27 additions & 2 deletions src/models/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ class ModelsResponse(BaseModel):

# TODO(lucasagomes): a lot of fields to add to QueryResponse. For now
# we are keeping it simple. The missing fields are:
# - referenced_documents: The optional URLs and titles for the documents used
# to generate the response.
# - truncated: Set to True if conversation history was truncated to be within context window.
# - input_tokens: Number of tokens sent to LLM
# - output_tokens: Number of tokens received from LLM
Expand All @@ -51,6 +49,8 @@ class QueryResponse(BaseModel):
Attributes:
conversation_id: The optional conversation ID (UUID).
response: The response.
referenced_documents: The optional URLs and titles for the documents used
to generate the response.
"""

conversation_id: Optional[str] = Field(
Expand All @@ -66,13 +66,38 @@ class QueryResponse(BaseModel):
],
)

referenced_documents: list[dict[str, str]] = Field(
default_factory=list,
description="List of documents referenced in generating the response",
examples=[
[
{
"doc_url": (
"https://docs.openshift.com/container-platform/"
"4.15/operators/olm/index.html"
),
"doc_title": "Operator Lifecycle Manager (OLM)",
}
]
],
)

# provides examples for /docs endpoint
model_config = {
"json_schema_extra": {
"examples": [
{
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
"response": "Operator Lifecycle Manager (OLM) helps users install...",
"referenced_documents": [
{
"doc_url": (
"https://docs.openshift.com/container-platform/"
"4.15/operators/olm/index.html"
),
"doc_title": "Operator Lifecycle Manager (OLM)",
}
],
}
]
}
Expand Down
Loading