Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 122 additions & 11 deletions src/app/endpoints/query.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""Handler for REST API call to provide answer to query."""

import ast
from datetime import datetime, UTC
import json
import logging
import os
from pathlib import Path
from typing import Annotated, Any
import re
from typing import Annotated, Any, cast

from llama_stack_client import APIConnectionError
from llama_stack_client import AsyncLlamaStackClient # type: ignore
Expand All @@ -25,7 +27,12 @@
from app.database import get_session
import metrics
from models.database.conversations import UserConversation
from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse
from models.responses import (
QueryResponse,
UnauthorizedResponse,
ForbiddenResponse,
ReferencedDocument,
)
from models.requests import QueryRequest, Attachment
import constants
from utils.endpoints import (
Expand All @@ -41,10 +48,92 @@
router = APIRouter(tags=["query"])
auth_dependency = get_auth_dependency()

# Matches one whole "Metadata: {...}" line emitted by the knowledge_search
# tool; the captured group is a Python-dict literal parsed via ast.literal_eval.
METADATA_PATTERN = re.compile(r"^\s*Metadata:\s*(\{.*?\})\s*$", re.MULTILINE)


def _process_knowledge_search_content(
    tool_response: Any, metadata_map: dict[str, dict[str, Any]]
) -> None:
    """Accumulate document metadata from a knowledge_search tool response.

    Scans every text item in the response content for "Metadata: {...}"
    lines and records each parsed dict in ``metadata_map`` keyed by its
    ``document_id``.  ``metadata_map`` is the output of this helper and is
    mutated in place; nothing is returned.

    Args:
        tool_response: Tool response object; expected to expose a
            ``content`` attribute holding an iterable of items, each with
            a ``text`` attribute (shape assumed from llama-stack responses
            — confirm against the client library).
        metadata_map: Mapping of document_id -> metadata dict, updated
            in place.
    """
    # Guard against missing tool_response or content
    if not tool_response:
        return

    content = getattr(tool_response, "content", None)
    if not content:
        return

    # Ensure content is iterable; bail out quietly on unexpected shapes
    try:
        iter(content)
    except TypeError:
        return

    for text_content_item in content:
        # Skip items that lack a non-empty "text" attribute
        text = getattr(text_content_item, "text", None)
        if not text:
            continue

        for match in METADATA_PATTERN.findall(text):
            try:
                meta = ast.literal_eval(match)
                # Verify the result is a dict before accessing keys
                if isinstance(meta, dict) and "document_id" in meta:
                    metadata_map[meta["document_id"]] = meta
            except (SyntaxError, ValueError):  # only expected from literal_eval
                # Malformed metadata lines are logged (with traceback) and
                # skipped rather than aborting the whole response.
                logger.exception(
                    "An exception was thrown in processing %s",
                    match,
                )


def extract_referenced_documents_from_steps(steps: list) -> list[ReferencedDocument]:
    """Extract referenced documents from tool execution steps.

    Args:
        steps: List of response steps from the agent

    Returns:
        List of referenced documents with doc_url and doc_title
    """
    collected: dict[str, dict[str, Any]] = {}

    # Only tool-execution steps that actually carry tool_responses matter.
    relevant_steps = (
        candidate
        for candidate in steps
        if getattr(candidate, "step_type", "") == "tool_execution"
        and hasattr(candidate, "tool_responses")
    )

    for step in relevant_steps:
        for response in getattr(step, "tool_responses", []) or []:
            is_knowledge_search = (
                getattr(response, "tool_name", "") == "knowledge_search"
            )
            if not is_knowledge_search or not getattr(response, "content", []):
                continue
            _process_knowledge_search_content(response, collected)

    # Keep only entries that provide both a URL and a title.
    documents: list[ReferencedDocument] = []
    for meta in collected.values():
        if "docs_url" in meta and "title" in meta:
            documents.append(
                ReferencedDocument(doc_url=meta["docs_url"], doc_title=meta["title"])
            )
    return documents


query_response: dict[int | str, dict[str, Any]] = {
200: {
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
"response": "LLM answer",
"referenced_documents": [
{
"doc_url": (
"https://docs.openshift.com/container-platform/"
"4.15/operators/olm/index.html"
),
"doc_title": "Operator Lifecycle Manager (OLM)",
}
],
},
400: {
"description": "Missing or invalid credentials provided by client",
Expand All @@ -54,7 +143,7 @@
"description": "User is not authorized",
"model": ForbiddenResponse,
},
503: {
500: {
"detail": {
"response": "Unable to connect to Llama Stack",
"cause": "Connection error.",
Expand Down Expand Up @@ -189,7 +278,7 @@ async def query_endpoint_handler(
user_conversation=user_conversation, query_request=query_request
),
)
response, conversation_id = await retrieve_response(
response, conversation_id, referenced_documents = await retrieve_response(
client,
llama_stack_model_id,
query_request,
Expand Down Expand Up @@ -223,7 +312,11 @@ async def query_endpoint_handler(
provider_id=provider_id,
)

return QueryResponse(conversation_id=conversation_id, response=response)
return QueryResponse(
conversation_id=conversation_id,
response=response,
referenced_documents=referenced_documents,
)

# connection to Llama Stack server
except APIConnectionError as e:
Expand Down Expand Up @@ -316,13 +409,13 @@ def is_input_shield(shield: Shield) -> bool:
return _is_inout_shield(shield) or not is_output_shield(shield)


async def retrieve_response( # pylint: disable=too-many-locals
async def retrieve_response( # pylint: disable=too-many-locals,too-many-branches
client: AsyncLlamaStackClient,
model_id: str,
query_request: QueryRequest,
token: str,
mcp_headers: dict[str, dict[str, str]] | None = None,
) -> tuple[str, str]:
) -> tuple[str, str, list[ReferencedDocument]]:
"""Retrieve response from LLMs and agents."""
available_input_shields = [
shield.identifier
Expand Down Expand Up @@ -402,15 +495,33 @@ async def retrieve_response( # pylint: disable=too-many-locals
toolgroups=toolgroups,
)

# Check for validation errors in the response
# Check for validation errors and extract referenced documents
steps = getattr(response, "steps", [])
for step in steps:
if step.step_type == "shield_call" and step.violation:
if getattr(step, "step_type", "") == "shield_call" and getattr(
step, "violation", False
):
# Metric for LLM validation errors
metrics.llm_calls_validation_errors_total.inc()
break

return str(response.output_message.content), conversation_id # type: ignore[union-attr]
# Extract referenced documents from tool execution steps
referenced_documents = extract_referenced_documents_from_steps(steps)

# When stream=False, response should have output_message attribute
response_obj = cast(Any, response)

# Safely guard access to output_message and content
output_message = getattr(response_obj, "output_message", None)
if output_message and getattr(output_message, "content", None) is not None:
content_str = str(output_message.content)
else:
content_str = ""

return (
content_str,
conversation_id,
referenced_documents,
)


def validate_attachments_metadata(attachments: list[Attachment]) -> None:
Expand Down
38 changes: 25 additions & 13 deletions src/app/endpoints/streaming_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import metrics
from models.requests import QueryRequest
from models.database.conversations import UserConversation
from models.responses import ReferencedDocument
from utils.endpoints import check_configuration_loaded, get_agent, get_system_prompt
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups

Expand Down Expand Up @@ -72,20 +73,27 @@ def stream_start_event(conversation_id: str) -> str:

def stream_end_event(metadata_map: dict) -> str:
"""Yield the end of the data stream."""
# Create ReferencedDocument objects and convert them to serializable dict format
referenced_documents = []
for v in filter(
lambda v: ("docs_url" in v) and ("title" in v),
metadata_map.values(),
):
doc = ReferencedDocument(doc_url=v["docs_url"], doc_title=v["title"])
referenced_documents.append(
{
"doc_url": str(
doc.doc_url
), # Convert AnyUrl to string for JSON serialization
"doc_title": doc.doc_title,
}
)

return format_stream_data(
{
"event": "end",
"data": {
"referenced_documents": [
{
"doc_url": v["docs_url"],
"doc_title": v["title"],
}
for v in filter(
lambda v: ("docs_url" in v) and ("title" in v),
metadata_map.values(),
)
],
"referenced_documents": referenced_documents,
"truncated": None, # TODO(jboos): implement truncated
"input_tokens": 0, # TODO(jboos): implement input tokens
"output_tokens": 0, # TODO(jboos): implement output tokens
Expand Down Expand Up @@ -330,10 +338,14 @@ def _handle_tool_execution_event(
for match in METADATA_PATTERN.findall(text_content_item.text):
try:
meta = ast.literal_eval(match)
if "document_id" in meta:
# Verify the result is a dict before accessing keys
if isinstance(meta, dict) and "document_id" in meta:
metadata_map[meta["document_id"]] = meta
except Exception: # pylint: disable=broad-except
logger.debug(
except (
SyntaxError,
ValueError,
): # only expected from literal_eval
logger.exception(
"An exception was thrown in processing %s",
match,
)
Expand Down
40 changes: 37 additions & 3 deletions src/models/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Any, Optional

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, AnyUrl


class ModelsResponse(BaseModel):
Expand Down Expand Up @@ -36,21 +36,30 @@ class ModelsResponse(BaseModel):

# TODO(lucasagomes): a lot of fields to add to QueryResponse. For now
# we are keeping it simple. The missing fields are:
# - referenced_documents: The optional URLs and titles for the documents used
# to generate the response.
# - truncated: Set to True if conversation history was truncated to be within context window.
# - input_tokens: Number of tokens sent to LLM
# - output_tokens: Number of tokens received from LLM
# - available_quotas: Quota available as measured by all configured quota limiters
# - tool_calls: List of tool requests.
# - tool_results: List of tool results.
# See LLMResponse in ols-service for more details.


class ReferencedDocument(BaseModel):
    """Model representing a document referenced in generating a response.

    Attributes:
        doc_url: URL of the referenced document; pydantic's ``AnyUrl``
            validates the value and stores it as a URL object (not a str),
            so callers serializing to JSON must convert it explicitly.
        doc_title: Title of the referenced document.
    """

    doc_url: AnyUrl = Field(description="URL of the referenced document")
    doc_title: str = Field(description="Title of the referenced document")


class QueryResponse(BaseModel):
"""Model representing LLM response to a query.

Attributes:
conversation_id: The optional conversation ID (UUID).
response: The response.
referenced_documents: The optional URLs and titles for the documents used
to generate the response.
"""

conversation_id: Optional[str] = Field(
Expand All @@ -66,13 +75,38 @@ class QueryResponse(BaseModel):
],
)

referenced_documents: list[ReferencedDocument] = Field(
default_factory=list,
description="List of documents referenced in generating the response",
examples=[
[
{
"doc_url": (
"https://docs.openshift.com/container-platform/"
"4.15/operators/olm/index.html"
),
"doc_title": "Operator Lifecycle Manager (OLM)",
}
]
],
)

# provides examples for /docs endpoint
model_config = {
"json_schema_extra": {
"examples": [
{
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
"response": "Operator Lifecycle Manager (OLM) helps users install...",
"referenced_documents": [
{
"doc_url": (
"https://docs.openshift.com/container-platform/"
"4.15/operators/olm/index.html"
),
"doc_title": "Operator Lifecycle Manager (OLM)",
}
],
}
]
}
Expand Down
Loading