Skip to content
Open
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 55 additions & 13 deletions src/app/endpoints/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import logging
import os
from pathlib import Path
from typing import Annotated, Any
from typing import Annotated, Any, cast


from llama_stack_client import APIConnectionError
from llama_stack_client import AsyncLlamaStackClient # type: ignore
Expand All @@ -25,7 +26,12 @@
from app.database import get_session
import metrics
from models.database.conversations import UserConversation
from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse
from models.responses import (
QueryResponse,
UnauthorizedResponse,
ForbiddenResponse,
ReferencedDocument,
)
from models.requests import QueryRequest, Attachment
import constants
from utils.endpoints import (
Expand All @@ -36,15 +42,28 @@
)
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
from utils.suid import get_suid
from utils.metadata import (
extract_referenced_documents_from_steps,
)

logger = logging.getLogger("app.endpoints.handlers")
router = APIRouter(tags=["query"])
auth_dependency = get_auth_dependency()


query_response: dict[int | str, dict[str, Any]] = {
200: {
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
"response": "LLM answer",
"referenced_documents": [
{
"doc_url": (
"https://docs.openshift.com/container-platform/"
"4.15/operators/olm/index.html"
),
"doc_title": "Operator Lifecycle Manager (OLM)",
}
],
},
400: {
"description": "Missing or invalid credentials provided by client",
Expand All @@ -54,7 +73,7 @@
"description": "User is not authorized",
"model": ForbiddenResponse,
},
503: {
500: {
"detail": {
"response": "Unable to connect to Llama Stack",
"cause": "Connection error.",
Expand Down Expand Up @@ -203,7 +222,7 @@ async def query_endpoint_handler(
user_conversation=user_conversation, query_request=query_request
),
)
response, conversation_id = await retrieve_response(
response, conversation_id, referenced_documents = await retrieve_response(
client,
llama_stack_model_id,
query_request,
Expand Down Expand Up @@ -237,7 +256,11 @@ async def query_endpoint_handler(
provider_id=provider_id,
)

return QueryResponse(conversation_id=conversation_id, response=response)
return QueryResponse(
conversation_id=conversation_id,
response=response,
referenced_documents=referenced_documents,
)

# connection to Llama Stack server
except APIConnectionError as e:
Expand Down Expand Up @@ -375,13 +398,13 @@ def is_input_shield(shield: Shield) -> bool:
return _is_inout_shield(shield) or not is_output_shield(shield)


async def retrieve_response( # pylint: disable=too-many-locals
async def retrieve_response( # pylint: disable=too-many-locals,too-many-branches
client: AsyncLlamaStackClient,
model_id: str,
query_request: QueryRequest,
token: str,
mcp_headers: dict[str, dict[str, str]] | None = None,
) -> tuple[str, str]:
) -> tuple[str, str, list[ReferencedDocument]]:
"""
Retrieve response from LLMs and agents.

Expand All @@ -404,8 +427,9 @@ async def retrieve_response( # pylint: disable=too-many-locals
mcp_headers (dict[str, dict[str, str]], optional): Headers for multi-component processing.

Returns:
tuple[str, str]: A tuple containing the LLM or agent's response content
and the conversation ID.
tuple[str, str, list[ReferencedDocument]]: A tuple containing the response
content, the conversation ID, and the list of referenced documents parsed
from tool execution steps.
"""
available_input_shields = [
shield.identifier
Expand Down Expand Up @@ -485,15 +509,33 @@ async def retrieve_response( # pylint: disable=too-many-locals
toolgroups=toolgroups,
)

# Check for validation errors in the response
# Check for validation errors and extract referenced documents
steps = getattr(response, "steps", [])
for step in steps:
if step.step_type == "shield_call" and step.violation:
if getattr(step, "step_type", "") == "shield_call" and getattr(
step, "violation", False
):
# Metric for LLM validation errors
metrics.llm_calls_validation_errors_total.inc()
break

return str(response.output_message.content), conversation_id # type: ignore[union-attr]
# Extract referenced documents from tool execution steps
referenced_documents = extract_referenced_documents_from_steps(steps)

# When stream=False, response should have output_message attribute
response_obj = cast(Any, response)

# Safely guard access to output_message and content
output_message = getattr(response_obj, "output_message", None)
if output_message and getattr(output_message, "content", None) is not None:
response_text = str(output_message.content)
else:
response_text = ""

return (
response_text,
conversation_id,
referenced_documents,
)


def validate_attachments_metadata(attachments: list[Attachment]) -> None:
Expand Down
65 changes: 40 additions & 25 deletions src/app/endpoints/streaming_query.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""Handler for REST API call to provide answer to streaming query."""

import ast
import json
import re
import logging
from typing import Annotated, Any, AsyncIterator, Iterator

import pydantic

from llama_stack_client import APIConnectionError
from llama_stack_client import AsyncLlamaStackClient # type: ignore
from llama_stack_client.types import UserMessage # type: ignore
Expand All @@ -24,8 +24,10 @@
import metrics
from models.requests import QueryRequest
from models.database.conversations import UserConversation
from models.responses import ReferencedDocument
from utils.endpoints import check_configuration_loaded, get_agent, get_system_prompt
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
from utils.metadata import parse_knowledge_search_metadata

from app.endpoints.query import (
get_rag_toolgroups,
Expand All @@ -45,9 +47,6 @@
auth_dependency = get_auth_dependency()


METADATA_PATTERN = re.compile(r"\nMetadata: (\{.+})\n")


def format_stream_data(d: dict) -> str:
"""
Format a dictionary as a Server-Sent Events (SSE) data string.
Expand Down Expand Up @@ -102,20 +101,36 @@ def stream_end_event(metadata_map: dict) -> str:
str: A Server-Sent Events (SSE) formatted string
representing the end of the data stream.
"""
# Create ReferencedDocument objects and convert them to serializable dict format
referenced_documents = []
for v in filter(
lambda v: ("docs_url" in v) and ("title" in v),
metadata_map.values(),
):
try:
doc = ReferencedDocument(doc_url=v["docs_url"], doc_title=v["title"])
referenced_documents.append(
{
"doc_url": str(
doc.doc_url
), # Convert AnyUrl to string for JSON serialization
"doc_title": doc.doc_title,
}
)
except (pydantic.ValidationError, ValueError) as e:
logger.warning(
"Skipping invalid referenced document with docs_url='%s', title='%s': %s",
v.get("docs_url", "<missing>"),
v.get("title", "<missing>"),
str(e),
)
continue

return format_stream_data(
{
"event": "end",
"data": {
"referenced_documents": [
{
"doc_url": v["docs_url"],
"doc_title": v["title"],
}
for v in filter(
lambda v: ("docs_url" in v) and ("title" in v),
metadata_map.values(),
)
],
"referenced_documents": referenced_documents,
"truncated": None, # TODO(jboos): implement truncated
"input_tokens": 0, # TODO(jboos): implement input tokens
"output_tokens": 0, # TODO(jboos): implement output tokens
Expand Down Expand Up @@ -435,16 +450,16 @@ def _handle_tool_execution_event(
newline_pos = summary.find("\n")
if newline_pos > 0:
summary = summary[:newline_pos]
for match in METADATA_PATTERN.findall(text_content_item.text):
try:
meta = ast.literal_eval(match)
if "document_id" in meta:
metadata_map[meta["document_id"]] = meta
except Exception: # pylint: disable=broad-except
logger.debug(
"An exception was thrown in processing %s",
match,
)
try:
parsed_metadata = parse_knowledge_search_metadata(
text_content_item.text, strict=False
)
metadata_map.update(parsed_metadata)
except ValueError as e:
logger.exception(
"Error processing metadata from text; position=%s",
getattr(e, "position", "unknown"),
)

yield format_stream_data(
{
Expand Down
40 changes: 37 additions & 3 deletions src/models/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Any, Optional

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, AnyUrl


class ModelsResponse(BaseModel):
Expand Down Expand Up @@ -36,21 +36,30 @@ class ModelsResponse(BaseModel):

# TODO(lucasagomes): a lot of fields to add to QueryResponse. For now
# we are keeping it simple. The missing fields are:
# - referenced_documents: The optional URLs and titles for the documents used
# to generate the response.
# - truncated: Set to True if conversation history was truncated to be within context window.
# - input_tokens: Number of tokens sent to LLM
# - output_tokens: Number of tokens received from LLM
# - available_quotas: Quota available as measured by all configured quota limiters
# - tool_calls: List of tool requests.
# - tool_results: List of tool results.
# See LLMResponse in ols-service for more details.


class ReferencedDocument(BaseModel):
"""Model representing a document referenced in generating a response."""

doc_url: AnyUrl = Field(description="URL of the referenced document")
doc_title: str = Field(description="Title of the referenced document")


class QueryResponse(BaseModel):
"""Model representing LLM response to a query.

Attributes:
conversation_id: The optional conversation ID (UUID).
response: The response.
referenced_documents: The optional URLs and titles for the documents used
to generate the response.
"""

conversation_id: Optional[str] = Field(
Expand All @@ -66,13 +75,38 @@ class QueryResponse(BaseModel):
],
)

referenced_documents: list[ReferencedDocument] = Field(
default_factory=list,
description="List of documents referenced in generating the response",
examples=[
[
{
"doc_url": (
"https://docs.openshift.com/container-platform/"
"4.15/operators/olm/index.html"
),
"doc_title": "Operator Lifecycle Manager (OLM)",
}
]
],
)

# provides examples for /docs endpoint
model_config = {
"json_schema_extra": {
"examples": [
{
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
"response": "Operator Lifecycle Manager (OLM) helps users install...",
"referenced_documents": [
{
"doc_url": (
"https://docs.openshift.com/container-platform/"
"4.15/operators/olm/index.html"
),
"doc_title": "Operator Lifecycle Manager (OLM)",
}
],
}
]
}
Expand Down
Loading