|
7 | 7 | from pathlib import Path |
8 | 8 | from typing import Annotated, Any, cast |
9 | 9 |
|
10 | | -import pydantic |
11 | 10 |
|
12 | 11 | from llama_stack_client import APIConnectionError |
13 | 12 | from llama_stack_client import AsyncLlamaStackClient # type: ignore |
|
43 | 42 | ) |
44 | 43 | from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups |
45 | 44 | from utils.suid import get_suid |
46 | | -from utils.metadata import parse_knowledge_search_metadata |
| 45 | +from utils.metadata import ( |
| 46 | + extract_referenced_documents_from_steps, |
| 47 | +) |
47 | 48 |
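The parsing helpers move out of this module; only `extract_referenced_documents_from_steps` is imported back from `utils.metadata`. A minimal sketch of the assumed call site inside the handler (not shown in this hunk; `response_obj` and its `steps` attribute are assumptions based on the later hunks):

```python
# Hypothetical call site after the refactor: hand the agent turn's steps to
# the helper that now lives in utils.metadata.
steps = getattr(response_obj, "steps", []) or []
referenced_documents = extract_referenced_documents_from_steps(steps)
```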
|
48 | 49 | logger = logging.getLogger("app.endpoints.handlers") |
49 | 50 | router = APIRouter(tags=["query"]) |
50 | 51 | auth_dependency = get_auth_dependency() |
51 | 52 |
|
52 | 53 |
|
53 | | -def _process_knowledge_search_content(tool_response: Any) -> dict[str, dict[str, Any]]: |
54 | | - """Process knowledge search tool response content for metadata. |
55 | | -
|
56 | | - Args: |
57 | | - tool_response: Tool response object containing content to parse |
58 | | -
|
59 | | - Returns: |
60 | | - Dictionary mapping document_id to metadata dict |
61 | | - """ |
62 | | - metadata_map: dict[str, dict[str, Any]] = {} |
63 | | - |
64 | | - # Guard against missing tool_response or content |
65 | | - if not tool_response: |
66 | | - return metadata_map |
67 | | - |
68 | | - content = getattr(tool_response, "content", None) |
69 | | - if not content: |
70 | | - return metadata_map |
71 | | - |
72 | | - # Ensure content is iterable |
73 | | - try: |
74 | | - iter(content) |
75 | | - except TypeError: |
76 | | - return metadata_map |
77 | | - |
78 | | - for text_content_item in content: |
79 | | - # Skip items that lack a non-empty "text" attribute |
80 | | - text = getattr(text_content_item, "text", None) |
81 | | - if not text: |
82 | | - continue |
83 | | - |
84 | | - try: |
85 | | - parsed_metadata = parse_knowledge_search_metadata(text) |
86 | | - metadata_map.update(parsed_metadata) |
87 | | - except ValueError: |
88 | | - logger.exception( |
89 | | - "An exception was thrown in processing metadata from text: %s", |
90 | | - text[:200] + "..." if len(text) > 200 else text, |
91 | | - ) |
92 | | - |
93 | | - return metadata_map |
94 | | - |
95 | | - |
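For context on the helper being removed (and presumably relocated), a minimal stubbed exercise of it could look like the sketch below. The metadata text format is an assumption about what `parse_knowledge_search_metadata` accepts; only the attributes the helper actually reads are populated on the stubs.

```python
from types import SimpleNamespace

# Hypothetical knowledge_search tool response: the helper only reads
# .content and each item's .text attribute.
stub_response = SimpleNamespace(
    content=[
        SimpleNamespace(
            text='Metadata: {"document_id": "doc-1", '
            '"docs_url": "https://example.com/doc-1", "title": "Doc 1"}'
        ),
        SimpleNamespace(text=""),  # empty text is skipped by the guard
    ]
)

metadata_map = _process_knowledge_search_content(stub_response)
# Assumed result, keyed by document_id:
# {"doc-1": {"docs_url": "https://example.com/doc-1", "title": "Doc 1", ...}}
```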
96 | | -def extract_referenced_documents_from_steps( |
97 | | - steps: list[Any], |
98 | | -) -> list[ReferencedDocument]: |
99 | | - """Extract referenced documents from tool execution steps. |
100 | | -
|
101 | | - Args: |
102 | | - steps: List of response steps from the agent |
103 | | -
|
104 | | - Returns: |
105 | | - List of referenced documents with doc_url and doc_title |
106 | | - """ |
107 | | - metadata_map: dict[str, dict[str, Any]] = {} |
108 | | - |
109 | | - for step in steps: |
110 | | - if getattr(step, "step_type", "") != "tool_execution" or not hasattr( |
111 | | - step, "tool_responses" |
112 | | - ): |
113 | | - continue |
114 | | - |
115 | | - for tool_response in getattr(step, "tool_responses", []) or []: |
116 | | - if getattr( |
117 | | - tool_response, "tool_name", "" |
118 | | - ) != "knowledge_search" or not getattr(tool_response, "content", []): |
119 | | - continue |
120 | | - |
121 | | - response_metadata = _process_knowledge_search_content(tool_response) |
122 | | - metadata_map.update(response_metadata) |
123 | | - |
124 | | - # Extract referenced documents from metadata with error handling |
125 | | - referenced_documents = [] |
126 | | - for v in metadata_map.values(): |
127 | | - if "docs_url" in v and "title" in v: |
128 | | - try: |
129 | | - doc = ReferencedDocument(doc_url=v["docs_url"], doc_title=v["title"]) |
130 | | - referenced_documents.append(doc) |
131 | | - except (pydantic.ValidationError, ValueError) as e: |
132 | | - logger.warning( |
133 | | - "Skipping invalid referenced document with docs_url='%s', title='%s': %s", |
134 | | - v.get("docs_url", "<missing>"), |
135 | | - v.get("title", "<missing>"), |
136 | | - str(e), |
137 | | - ) |
138 | | - continue |
139 | | - |
140 | | - return referenced_documents |
141 | | - |
142 | | - |
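Exercising the full extractor with stubbed steps (again a hypothetical sketch; steps other than tool executions and tools other than `knowledge_search` are skipped, and the metadata text format is an assumption):

```python
from types import SimpleNamespace

# Hypothetical agent turn steps with one knowledge_search execution.
steps = [
    SimpleNamespace(step_type="inference", tool_responses=[]),
    SimpleNamespace(
        step_type="tool_execution",
        tool_responses=[
            SimpleNamespace(
                tool_name="knowledge_search",
                content=[
                    SimpleNamespace(
                        text='Metadata: {"document_id": "doc-1", '
                        '"docs_url": "https://example.com/doc-1", "title": "Doc 1"}'
                    )
                ],
            )
        ],
    ),
]

docs = extract_referenced_documents_from_steps(steps)
# Assumed result:
# [ReferencedDocument(doc_url="https://example.com/doc-1", doc_title="Doc 1")]
```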
143 | 54 | query_response: dict[int | str, dict[str, Any]] = { |
144 | 55 | 200: { |
145 | 56 | "conversation_id": "123e4567-e89b-12d3-a456-426614174000", |
@@ -516,8 +427,9 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche |
516 | 427 | mcp_headers (dict[str, dict[str, str]], optional): Headers forwarded to Model Context Protocol (MCP) tool servers. |
517 | 428 |
|
518 | 429 | Returns: |
519 | | - tuple[str, str]: A tuple containing the LLM or agent's response content |
520 | | - and the conversation ID. |
| 430 | + tuple[str, str, list[ReferencedDocument]]: A tuple containing the response |
| 431 | + content, the conversation ID, and the list of referenced documents parsed |
| 432 | + from tool execution steps. |
521 | 433 | """ |
522 | 434 | available_input_shields = [ |
523 | 435 | shield.identifier |
@@ -615,12 +527,12 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche |
615 | 527 | # Safely guard access to output_message and content |
616 | 528 | output_message = getattr(response_obj, "output_message", None) |
617 | 529 | if output_message and getattr(output_message, "content", None) is not None: |
618 | | - content_str = str(output_message.content) |
| 530 | + response_text = str(output_message.content) |
619 | 531 | else: |
620 | | - content_str = "" |
| 532 | + response_text = "" |
621 | 533 |
|
622 | 534 | return ( |
623 | | - content_str, |
| 535 | + response_text, |
624 | 536 | conversation_id, |
625 | 537 | referenced_documents, |
626 | 538 | ) |
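With the third tuple element, callers now unpack the referenced documents alongside the response text and conversation ID. A minimal sketch of the consuming side inside the async endpoint handler (the argument names and the response model fields are assumptions, not part of this diff):

```python
# Hypothetical caller of retrieve_response after this change.
response_text, conversation_id, referenced_documents = await retrieve_response(
    client,
    model_id,
    query_request,
    token,
    mcp_headers=mcp_headers,
)
# referenced_documents is a list[ReferencedDocument] parsed from the turn's
# tool-execution steps and can be attached to the endpoint's response as-is.
```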
|