Skip to content

Commit 4bb351d

Browse files
committed
Make query endpoints compatible with OLS
1 parent 780bac2 commit 4bb351d

File tree

13 files changed

+254
-114
lines changed

13 files changed

+254
-114
lines changed

src/app/endpoints/query.py

Lines changed: 15 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,13 @@
88
from typing import Annotated, Any, Optional, cast
99

1010
from fastapi import APIRouter, Depends, HTTPException, Request
11-
from litellm.exceptions import RateLimitError
1211
from llama_stack_client import (
1312
APIConnectionError,
1413
AsyncLlamaStackClient, # type: ignore
1514
)
1615
from llama_stack_client.types import Shield, UserMessage # type: ignore
1716
from llama_stack_client.types.alpha.agents.turn import Turn
1817
from llama_stack_client.types.alpha.agents.turn_create_params import (
19-
Document,
2018
Toolgroup,
2119
ToolgroupAgentToolGroupWithArgs,
2220
)
@@ -42,10 +40,10 @@
4240
InternalServerErrorResponse,
4341
NotFoundResponse,
4442
QueryResponse,
43+
PromptTooLongResponse,
4544
QuotaExceededResponse,
4645
ReferencedDocument,
4746
ServiceUnavailableResponse,
48-
ToolCall,
4947
UnauthorizedResponse,
5048
UnprocessableEntityResponse,
5149
)
@@ -84,6 +82,7 @@
8482
404: NotFoundResponse.openapi_response(
8583
examples=["model", "conversation", "provider"]
8684
),
85+
413: PromptTooLongResponse.openapi_response(),
8786
422: UnprocessableEntityResponse.openapi_response(),
8887
429: QuotaExceededResponse.openapi_response(),
8988
500: InternalServerErrorResponse.openapi_response(examples=["configuration"]),
@@ -379,20 +378,6 @@ async def query_endpoint_handler_base( # pylint: disable=R0914
379378

380379
# Convert tool calls to response format
381380
logger.info("Processing tool calls...")
382-
tool_calls = [
383-
ToolCall(
384-
tool_name=tc.name,
385-
arguments=(
386-
tc.args if isinstance(tc.args, dict) else {"query": str(tc.args)}
387-
),
388-
result=(
389-
{"response": tc.response}
390-
if tc.response and tc.name != constants.DEFAULT_RAG_TOOL
391-
else None
392-
),
393-
)
394-
for tc in summary.tool_calls
395-
]
396381

397382
logger.info("Using referenced documents from response...")
398383

@@ -403,7 +388,8 @@ async def query_endpoint_handler_base( # pylint: disable=R0914
403388
conversation_id=conversation_id,
404389
response=summary.llm_response,
405390
rag_chunks=summary.rag_chunks if summary.rag_chunks else [],
406-
tool_calls=tool_calls if tool_calls else None,
391+
tool_calls=summary.tool_calls if summary.tool_calls else None,
392+
tool_results=summary.tool_results if summary.tool_results else None,
407393
referenced_documents=referenced_documents,
408394
truncated=False, # TODO: implement truncation detection
409395
input_tokens=token_usage.input_tokens,
@@ -427,7 +413,7 @@ async def query_endpoint_handler_base( # pylint: disable=R0914
427413
logger.exception("Error persisting conversation details: %s", e)
428414
response = InternalServerErrorResponse.database_error()
429415
raise HTTPException(**response.model_dump()) from e
430-
except RateLimitError as e:
416+
except Exception as e:
431417
used_model = getattr(e, "model", "")
432418
response = QuotaExceededResponse.model(used_model)
433419
raise HTTPException(**response.model_dump()) from e
@@ -743,14 +729,14 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
743729
toolgroups = None
744730

745731
# TODO: LCORE-881 - Remove if Llama Stack starts to support these mime types
746-
documents: list[Document] = [
747-
(
748-
{"content": doc["content"], "mime_type": "text/plain"}
749-
if doc["mime_type"].lower() in ("application/json", "application/xml")
750-
else doc
751-
)
752-
for doc in query_request.get_documents()
753-
]
732+
# documents: list[Document] = [
733+
# (
734+
# {"content": doc["content"], "mime_type": "text/plain"}
735+
# if doc["mime_type"].lower() in ("application/json", "application/xml")
736+
# else doc
737+
# )
738+
# for doc in query_request.get_documents()
739+
# ]
754740

755741
response = await agent.create_turn(
756742
messages=[UserMessage(role="user", content=query_request.query).model_dump()],
@@ -771,6 +757,8 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
771757
else ""
772758
),
773759
tool_calls=[],
760+
tool_results=[],
761+
rag_chunks=[],
774762
)
775763

776764
referenced_documents = parse_referenced_documents(response)

src/app/endpoints/query_v2.py

Lines changed: 40 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
# pylint: disable=too-many-locals,too-many-branches,too-many-nested-blocks
2+
13
"""Handler for REST API call to provide answer to query using Response API."""
24

3-
import json
45
import logging
56
from typing import Annotated, Any, cast
67

@@ -24,6 +25,7 @@
2425
from models.requests import QueryRequest
2526
from models.responses import (
2627
ForbiddenResponse,
28+
PromptTooLongResponse,
2729
InternalServerErrorResponse,
2830
NotFoundResponse,
2931
QueryResponse,
@@ -59,6 +61,7 @@
5961
404: NotFoundResponse.openapi_response(
6062
examples=["conversation", "model", "provider"]
6163
),
64+
413: PromptTooLongResponse.openapi_response(),
6265
422: UnprocessableEntityResponse.openapi_response(),
6366
429: QuotaExceededResponse.openapi_response(),
6467
500: InternalServerErrorResponse.openapi_response(examples=["configuration"]),
@@ -96,7 +99,7 @@ def _build_tool_call_summary( # pylint: disable=too-many-return-statements,too-
9699
id=str(call_id),
97100
name=getattr(output_item, "name", "function_call"),
98101
args=args,
99-
response=None,
102+
type="tool_call",
100103
)
101104

102105
if item_type == "file_search_call":
@@ -105,36 +108,38 @@ def _build_tool_call_summary( # pylint: disable=too-many-return-statements,too-
105108
"status": getattr(output_item, "status", None),
106109
}
107110
results = getattr(output_item, "results", None)
108-
response_payload: Any | None = None
111+
# response_payload: Any | None = None
109112
if results is not None:
110113
# Store only the essential result metadata to avoid large payloads
111-
response_payload = {
112-
"results": [
113-
{
114-
"file_id": (
115-
getattr(result, "file_id", None)
116-
if not isinstance(result, dict)
117-
else result.get("file_id")
118-
),
119-
"filename": (
120-
getattr(result, "filename", None)
121-
if not isinstance(result, dict)
122-
else result.get("filename")
123-
),
124-
"score": (
125-
getattr(result, "score", None)
126-
if not isinstance(result, dict)
127-
else result.get("score")
128-
),
129-
}
130-
for result in results
131-
]
132-
}
114+
# response_payload = {
115+
# "results": [
116+
# {
117+
# "file_id": (
118+
# getattr(result, "file_id", None)
119+
# if not isinstance(result, dict)
120+
# else result.get("file_id")
121+
# ),
122+
# "filename": (
123+
# getattr(result, "filename", None)
124+
# if not isinstance(result, dict)
125+
# else result.get("filename")
126+
# ),
127+
# "score": (
128+
# getattr(result, "score", None)
129+
# if not isinstance(result, dict)
130+
# else result.get("score")
131+
# ),
132+
# }
133+
# for result in results
134+
# ]
135+
# }
136+
... # Handle response_payload
133137
return ToolCallSummary(
134138
id=str(getattr(output_item, "id")),
135139
name=DEFAULT_RAG_TOOL,
136140
args=args,
137-
response=json.dumps(response_payload) if response_payload else None,
141+
# response=json.dumps(response_payload) if response_payload else None,
142+
type="tool_call",
138143
)
139144

140145
if item_type == "web_search_call":
@@ -143,7 +148,7 @@ def _build_tool_call_summary( # pylint: disable=too-many-return-statements,too-
143148
id=str(getattr(output_item, "id")),
144149
name="web_search",
145150
args=args,
146-
response=None,
151+
type="tool_call",
147152
)
148153

149154
if item_type == "mcp_call":
@@ -160,7 +165,8 @@ def _build_tool_call_summary( # pylint: disable=too-many-return-statements,too-
160165
id=str(getattr(output_item, "id")),
161166
name=getattr(output_item, "name", "mcp_call"),
162167
args=args,
163-
response=getattr(output_item, "output", None),
168+
# response=getattr(output_item, "output", None),
169+
type="tool_call",
164170
)
165171

166172
if item_type == "mcp_list_tools":
@@ -178,7 +184,8 @@ def _build_tool_call_summary( # pylint: disable=too-many-return-statements,too-
178184
id=str(getattr(output_item, "id")),
179185
name="mcp_list_tools",
180186
args=args,
181-
response=None,
187+
# response=None,
188+
type="tool_call",
182189
)
183190

184191
if item_type == "mcp_approval_request":
@@ -191,7 +198,8 @@ def _build_tool_call_summary( # pylint: disable=too-many-return-statements,too-
191198
id=str(getattr(output_item, "id")),
192199
name=getattr(output_item, "name", "mcp_approval_request"),
193200
args=args,
194-
response=None,
201+
# response=None,
202+
type="tool_call",
195203
)
196204

197205
return None
@@ -400,6 +408,8 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
400408
summary = TurnSummary(
401409
llm_response=llm_response,
402410
tool_calls=tool_calls,
411+
tool_results=[],
412+
rag_chunks=[],
403413
)
404414

405415
# Extract referenced documents and token usage from Responses API response

src/app/endpoints/streaming_query.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from llama_stack_client.types.alpha.agents.agent_turn_response_stream_chunk import (
2121
AgentTurnResponseStreamChunk,
2222
)
23-
from llama_stack_client.types.alpha.agents.turn_create_params import Document
2423
from llama_stack_client.types.shared import ToolCall
2524
from llama_stack_client.types.shared.interleaved_content_item import TextContentItem
2625

@@ -51,6 +50,7 @@
5150
from models.responses import (
5251
ForbiddenResponse,
5352
InternalServerErrorResponse,
53+
PromptTooLongResponse,
5454
NotFoundResponse,
5555
QuotaExceededResponse,
5656
ServiceUnavailableResponse,
@@ -86,6 +86,7 @@
8686
404: NotFoundResponse.openapi_response(
8787
examples=["conversation", "model", "provider"]
8888
),
89+
413: PromptTooLongResponse.openapi_response(),
8990
422: UnprocessableEntityResponse.openapi_response(),
9091
429: QuotaExceededResponse.openapi_response(),
9192
500: InternalServerErrorResponse.openapi_response(examples=["configuration"]),
@@ -704,7 +705,10 @@ async def response_generator(
704705
complete response for transcript storage if enabled.
705706
"""
706707
chunk_id = 0
707-
summary = TurnSummary(llm_response="No response from the model", tool_calls=[])
708+
summary = TurnSummary(
709+
llm_response="No response from the model",
710+
tool_calls=[], tool_results=[], rag_chunks=[]
711+
)
708712

709713
# Determine media type for response formatting
710714
media_type = context.query_request.media_type or MEDIA_TYPE_JSON
@@ -1064,14 +1068,14 @@ async def retrieve_response(
10641068
toolgroups = None
10651069

10661070
# TODO: LCORE-881 - Remove if Llama Stack starts to support these mime types
1067-
documents: list[Document] = [
1068-
(
1069-
{"content": doc["content"], "mime_type": "text/plain"}
1070-
if doc["mime_type"].lower() in ("application/json", "application/xml")
1071-
else doc
1072-
)
1073-
for doc in query_request.get_documents()
1074-
]
1071+
# documents: list[Document] = [
1072+
# (
1073+
# {"content": doc["content"], "mime_type": "text/plain"}
1074+
# if doc["mime_type"].lower() in ("application/json", "application/xml")
1075+
# else doc
1076+
# )
1077+
# for doc in query_request.get_documents()
1078+
# ]
10751079

10761080
response = await agent.create_turn(
10771081
messages=[UserMessage(role="user", content=query_request.query).model_dump()],

src/app/endpoints/streaming_query_v2.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
ForbiddenResponse,
3939
InternalServerErrorResponse,
4040
NotFoundResponse,
41+
PromptTooLongResponse,
4142
QuotaExceededResponse,
4243
ServiceUnavailableResponse,
4344
StreamingQueryResponse,
@@ -70,6 +71,7 @@
7071
404: NotFoundResponse.openapi_response(
7172
examples=["conversation", "model", "provider"]
7273
),
74+
413: PromptTooLongResponse.openapi_response(),
7375
422: UnprocessableEntityResponse.openapi_response(),
7476
429: QuotaExceededResponse.openapi_response(),
7577
500: InternalServerErrorResponse.openapi_response(examples=["configuration"]),
@@ -108,7 +110,9 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat
108110
complete response for transcript storage if enabled.
109111
"""
110112
chunk_id = 0
111-
summary = TurnSummary(llm_response="", tool_calls=[])
113+
summary = TurnSummary(
114+
llm_response="", tool_calls=[], tool_results=[], rag_chunks=[]
115+
)
112116

113117
# Determine media type for response formatting
114118
media_type = context.query_request.media_type or MEDIA_TYPE_JSON
@@ -216,8 +220,10 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat
216220
ToolCallSummary(
217221
id=meta.get("call_id", item_id or "unknown"),
218222
name=meta.get("name", "tool_call"),
219-
args=arguments,
220-
response=None,
223+
args=(
224+
arguments if isinstance(arguments, dict) else {}
225+
), # Handle non-dict arguments
226+
type="tool_call",
221227
)
222228
)
223229

src/app/routers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def include_routers(app: FastAPI) -> None:
4747
app.include_router(conversations_v2.router, prefix="/v2")
4848

4949
# Note: query_v2, streaming_query_v2, and conversations_v3 are now exposed at /v1 above
50-
# The old query, streaming_query, and conversations modules are deprecated but kept for reference
50+
# The old query, streaming_query, and conversations modules are deprecated
5151

5252
# road-core does not version these endpoints
5353
app.include_router(health.router)

0 commit comments

Comments
 (0)