From 7a5a3643ade5770eb4ad64e70bb3feed7eb1d832 Mon Sep 17 00:00:00 2001 From: Tim Frohlich Date: Wed, 8 Oct 2025 01:35:21 +0200 Subject: [PATCH 01/16] Add a tool to export search results --- orchestrator/api/api_v1/endpoints/search.py | 48 ++++ orchestrator/search/agent/agent.py | 4 +- orchestrator/search/agent/prompts.py | 15 ++ orchestrator/search/agent/state.py | 2 + orchestrator/search/agent/tools.py | 85 ++++++ orchestrator/search/export.py | 271 ++++++++++++++++++++ 6 files changed, 424 insertions(+), 1 deletion(-) create mode 100644 orchestrator/search/export.py diff --git a/orchestrator/api/api_v1/endpoints/search.py b/orchestrator/api/api_v1/endpoints/search.py index f506cacf2..57f23bc57 100644 --- a/orchestrator/api/api_v1/endpoints/search.py +++ b/orchestrator/api/api_v1/endpoints/search.py @@ -39,6 +39,7 @@ ) from orchestrator.search.core.exceptions import InvalidCursorError from orchestrator.search.core.types import EntityType, UIType +from orchestrator.search.export import ExportData from orchestrator.search.filters.definitions import generate_definitions from orchestrator.search.indexing.registry import ENTITY_CONFIG_REGISTRY from orchestrator.search.retrieval import execute_search @@ -57,6 +58,8 @@ ) from orchestrator.search.schemas.results import SearchResult, TypeDefinition from orchestrator.services.subscriptions import format_special_types +from orchestrator.settings import app_settings +from orchestrator.utils.redis_client import create_redis_asyncio_client router = APIRouter() @@ -294,3 +297,48 @@ async def list_paths( async def get_definitions() -> dict[UIType, TypeDefinition]: """Provide a static definition of operators and schemas for each UI type.""" return generate_definitions() + + +@router.get( + "/export/{token}", + response_model=dict[str, Any], +) +async def export_by_token(token: str) -> dict[str, Any]: + """Export search results using a token generated by the search agent. + + The token contains entity IDs from a previous search. This endpoint retrieves + those IDs from Redis, fetches the full entity data from the database, and returns + flattened records suitable for CSV download. + + Args: + token: Export token generated by the prepare_export agent tool + + Returns: + Dictionary containing 'page' with an array of flattened entity records. + Each record contains snake_case field names from the database with nested + relationships flattened (e.g., product_name instead of product.name). 
+ + Raises: + HTTPException: 404 if token not found or expired, 400 if data is invalid + """ + async with create_redis_asyncio_client(app_settings.CACHE_URI) as redis_client: + + # Load export data from Redis + try: + export_data = await ExportData.from_redis(token, redis_client) + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=str(e), + ) + + # Fetch the actual records + try: + export_records = export_data.fetch_records() + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), + ) + + return {"page": export_records} diff --git a/orchestrator/search/agent/agent.py b/orchestrator/search/agent/agent.py index 8112501c4..d8490a5d7 100644 --- a/orchestrator/search/agent/agent.py +++ b/orchestrator/search/agent/agent.py @@ -48,7 +48,9 @@ def build_agent_router(model: str | OpenAIModel, toolsets: list[FunctionToolset[ @router.post("/") async def agent_endpoint(request: Request) -> Response: - return await handle_ag_ui_request(agent, request, deps=StateDeps(SearchState())) + base_url = str(request.base_url).rstrip("/") + initial_state = SearchState(base_url=base_url) + return await handle_ag_ui_request(agent, request, deps=StateDeps(initial_state)) return router except Exception as e: diff --git a/orchestrator/search/agent/prompts.py b/orchestrator/search/agent/prompts.py index 8aa113c42..37bd8e46f 100644 --- a/orchestrator/search/agent/prompts.py +++ b/orchestrator/search/agent/prompts.py @@ -58,6 +58,11 @@ async def get_base_instructions() -> str: c. **Set Filters**: Call `set_filter_tree`. 3. **Execute**: Call `execute_search`. This is done for both filtered and non-filtered searches. 4. **Report**: Answer the users' question directly and summarize when appropiate. + 5. **Export (if requested)**: If the user asks to export, download, or save results as CSV/file: + - **IMPORTANT**: Export is ONLY available for SELECT actions (not COUNT or AGGREGATE) + - Call `prepare_export` to generate an export token + - The UI will automatically display a download button - you don't need to mention URLs or tokens + - Simply confirm to the user that the export is ready --- ### 4. Critical Rules @@ -65,6 +70,7 @@ async def get_base_instructions() -> str: - **NEVER GUESS PATHS IN THE DATABASE**: You *must* verify every filter path by calling `discover_filter_paths` first. If a path does not exist, you may attempt to map the question on an existing paths that are valid and available from `discover_filter_paths`. If you cannot infer a match, inform the user and do not include it in the `FilterTree`. - **USE FULL PATHS**: Always use the full, unambiguous path returned by the discovery tool. - **MATCH OPERATORS**: Only use operators that are compatible with the field type as confirmed by `get_filter_operators`. + - **EXPORT RECOGNITION**: When users say things like "export this", "download as CSV", "save these results", "export to file", or similar phrases, they are requesting an export. Call `prepare_export` to handle this. """ ) @@ -73,12 +79,19 @@ async def get_dynamic_instructions(ctx: RunContext[StateDeps[SearchState]]) -> s """Dynamically provides 'next step' coaching based on the current state.""" state = ctx.deps.state param_state_str = json.dumps(state.parameters, indent=2, default=str) if state.parameters else "Not set." 
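
For reference, the coaching pattern this hunk extends is easy to exercise in isolation. Below is a minimal, self-contained sketch of the same branching; `_State` and `next_step_guidance` are illustrative stand-ins for `SearchState` and `get_dynamic_instructions`, not part of the patch:

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class _State:
    """Stand-in for SearchState; field names follow the diff above."""

    parameters: dict[str, Any] | None = None
    results: list[dict[str, Any]] = field(default_factory=list)


def next_step_guidance(state: _State) -> str:
    """Mirror the branching that this patch adds to get_dynamic_instructions."""
    if not state.parameters or not state.parameters.get("entity_type"):
        return "The search context is not set. Call `set_search_parameters`."
    if state.results:
        return f"Search completed with {len(state.results)} results. Offer `prepare_export` if the user asks."
    return "Context is set. Analyze the user's request and run the search."


# Example: after a search has populated results, the agent is steered toward reporting/exporting.
print(next_step_guidance(_State(parameters={"entity_type": "SUBSCRIPTION"}, results=[{"id": 1}])))
```
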
+ results_count = len(state.results) if state.results else 0 next_step_guidance = "" if not state.parameters or not state.parameters.get("entity_type"): next_step_guidance = ( "INSTRUCTION: The search context is not set. Your next action is to call `set_search_parameters`." ) + elif results_count > 0: + next_step_guidance = ( + f"INSTRUCTION: Search completed with {results_count} results. " + "You can answer the user's question with these results. " + "If the user requests an export/download, call `prepare_export`." + ) else: next_step_guidance = ( "INSTRUCTION: Context is set. Now, analyze the user's request. " @@ -95,6 +108,8 @@ async def get_dynamic_instructions(ctx: RunContext[StateDeps[SearchState]]) -> s {param_state_str} ``` + **Current Results Count:** {results_count} + **{next_step_guidance}** """ ) diff --git a/orchestrator/search/agent/state.py b/orchestrator/search/agent/state.py index 9a20f155e..1414664fc 100644 --- a/orchestrator/search/agent/state.py +++ b/orchestrator/search/agent/state.py @@ -19,3 +19,5 @@ class SearchState(BaseModel): parameters: dict[str, Any] | None = None results: list[dict[str, Any]] = Field(default_factory=list) + export_data: dict[str, Any] | None = None + base_url: str | None = None diff --git a/orchestrator/search/agent/tools.py b/orchestrator/search/agent/tools.py index 73471b57c..f81b74633 100644 --- a/orchestrator/search/agent/tools.py +++ b/orchestrator/search/agent/tools.py @@ -13,6 +13,7 @@ from collections.abc import Awaitable, Callable from typing import Any, TypeVar +from uuid import uuid4 import structlog from ag_ui.core import EventType, StateSnapshotEvent @@ -32,10 +33,13 @@ ) from orchestrator.schemas.search import SearchResultsSchema from orchestrator.search.core.types import ActionType, EntityType, FilterOp +from orchestrator.search.export import ExportData from orchestrator.search.filters import FilterTree from orchestrator.search.retrieval.exceptions import FilterValidationError, PathNotFoundError from orchestrator.search.retrieval.validation import validate_filter_tree from orchestrator.search.schemas.parameters import PARAMETER_REGISTRY, BaseSearchParameters +from orchestrator.settings import app_settings +from orchestrator.utils.redis_client import create_redis_asyncio_client from .state import SearchState @@ -256,3 +260,84 @@ async def get_valid_operators() -> dict[str, list[FilterOp]]: if hasattr(type_def, "operators"): operator_map[key] = type_def.operators return operator_map + + +@search_toolset.tool +async def prepare_export( + ctx: RunContext[StateDeps[SearchState]], + max_results: int = 1000, +) -> StateSnapshotEvent: + """Executes the search with the current parameters, collects up to max_results entity IDs, stores them in Redis with a temporary token, and returns the token for export.""" + if not ctx.deps.state.parameters: + raise ValueError("No search parameters set. Run a search first to see what will be exported.") + + # Validate that export is only available for SELECT actions + action = ctx.deps.state.parameters.get("action", ActionType.SELECT) + if action != ActionType.SELECT: + raise ValueError( + f"Export is only available for SELECT actions. Current action is '{action}'. " + "Please run a SELECT search first." 
+ ) + + entity_type = EntityType(ctx.deps.state.parameters["entity_type"]) + param_class = PARAMETER_REGISTRY.get(entity_type) + if not param_class: + raise ValueError(f"Unknown entity type: {entity_type}") + + # Cap at 1000 results () + export_limit = min(max_results, 1000) + + params = param_class(**ctx.deps.state.parameters) + params.limit = export_limit + + logger.debug( + "Preparing export", + entity_type=entity_type.value, + limit=export_limit, + has_filters=params.filters is not None, + ) + + fn = SEARCH_FN_MAP[entity_type] + search_results = await fn(params) + + if not search_results.data: + raise ValueError("No results found to export. Try adjusting your search criteria.") + + entity_ids = [] + for result in search_results.data: + if entity_type == EntityType.SUBSCRIPTION: + entity_ids.append(str(result.subscription["subscription_id"])) + elif entity_type == EntityType.WORKFLOW: + entity_ids.append(str(result.workflow.name)) + elif entity_type == EntityType.PRODUCT: + entity_ids.append(str(result.product.product_id)) + elif entity_type == EntityType.PROCESS: + entity_ids.append(str(result.process.process_id)) + + # Generate export token and create export data model + export_token = str(uuid4()) + export_data = ExportData( + entity_type=entity_type, + entity_ids=entity_ids, + token=export_token, + ) + + # Store in Redis with TTL + async with create_redis_asyncio_client(app_settings.CACHE_URI) as redis_client: + await export_data.save_to_redis(redis_client, ttl=3000) + + download_url = f"{ctx.deps.state.base_url}/api/search/export/{export_token}" + + # Update state with export data so frontend can render the download button + ctx.deps.state.export_data = { + "action": "export", + "token": export_token, + "count": len(entity_ids), + "entity_type": entity_type.value, + "download_url": download_url, + "message": f"Found {len(entity_ids)} {entity_type.value.lower()}(s).", + } + + logger.debug("Export data set in state", export_data=ctx.deps.state.export_data) + + return StateSnapshotEvent(type=EventType.STATE_SNAPSHOT, snapshot=ctx.deps.state.model_dump()) diff --git a/orchestrator/search/export.py b/orchestrator/search/export.py new file mode 100644 index 000000000..50beaba07 --- /dev/null +++ b/orchestrator/search/export.py @@ -0,0 +1,271 @@ +# Copyright 2019-2025 SURF, GÉANT. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from uuid import UUID + +from pydantic import BaseModel, Field +from pydantic.config import ConfigDict +from sqlalchemy import select +from sqlalchemy.orm import selectinload + +from orchestrator.db import ( + ProcessTable, + ProductTable, + SubscriptionTable, + WorkflowTable, + db, +) +from orchestrator.search.core.types import EntityType +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from redis.asyncio.client import Redis + +# Redis namespace for export data +EXPORT_REDIS_NAMESPACE = "orchestrator:export" + + +class ExportData(BaseModel): + """Model for export data stored in Redis. 
+ + Attributes: + entity_type: The type of entities being exported + entity_ids: List of entity IDs/names to export + token: Unique export token for this data + """ + + model_config = ConfigDict(use_enum_values=True) + + entity_type: EntityType + entity_ids: list[str] + token: str | None = Field(default=None, exclude=True) + + @property + def redis_key(self) -> str: + """Redis key for storing this export data. + + Returns: + Redis key string + """ + return f"{EXPORT_REDIS_NAMESPACE}:{self.token}" + + @classmethod + async def from_redis(cls, token: str, redis_client: "Redis") -> "ExportData": + """Load export data from Redis using a token. + + Args: + token: Export token + redis_client: Redis client instance + + Returns: + ExportData instance + + Raises: + ValueError: If token not found in Redis + """ + redis_key = f"{EXPORT_REDIS_NAMESPACE}:{token}" + export_data_json = await redis_client.get(redis_key) + + if not export_data_json: + raise ValueError(f"Export token '{token}' not found or expired") + + obj = cls.model_validate_json(export_data_json) + return obj.model_copy(update={"token": token}) + + async def save_to_redis(self, redis_client: Redis, ttl: int | None = 300) -> None: + """Persist payload without token.""" + await redis_client.set(self.redis_key, self.model_dump_json(), ex=ttl) + + def fetch_records(self) -> list[dict]: + """Fetch the actual export records for this export data. + + Returns: + List of flattened entity records + + Raises: + ValueError: If entity type is not supported + """ + return fetch_export_data(self.entity_type, self.entity_ids) + + +def fetch_subscription_export_data(entity_ids: list[str]) -> list[dict]: + """Fetch subscription data for export. + + Args: + entity_ids: List of subscription IDs as strings + + Returns: + List of flattened subscription dictionaries with fields: + subscription_id, description, status, insync, start_date, end_date, + note, product_name, tag, product_type, customer_id + """ + stmt = ( + select( + SubscriptionTable.subscription_id, + SubscriptionTable.description, + SubscriptionTable.status, + SubscriptionTable.insync, + SubscriptionTable.start_date, + SubscriptionTable.end_date, + SubscriptionTable.note, + SubscriptionTable.customer_id, + ProductTable.name.label("product_name"), + ProductTable.tag, + ProductTable.product_type, + ) + .join(ProductTable, SubscriptionTable.product_id == ProductTable.product_id) + .filter(SubscriptionTable.subscription_id.in_([UUID(sid) for sid in entity_ids])) + ) + + rows = db.session.execute(stmt).all() + + return [ + { + "subscription_id": str(row.subscription_id), + "description": row.description, + "status": row.status, + "insync": row.insync, + "start_date": row.start_date.isoformat() if row.start_date else None, + "end_date": row.end_date.isoformat() if row.end_date else None, + "note": row.note, + "product_name": row.product_name, + "tag": row.tag, + "product_type": row.product_type, + "customer_id": row.customer_id, + } + for row in rows + ] + + +def fetch_workflow_export_data(entity_ids: list[str]) -> list[dict]: + """Fetch workflow data for export. 
+ + Args: + entity_ids: List of workflow names as strings + + Returns: + List of flattened workflow dictionaries with fields: + name, description, created_at, product_names (comma-separated), + product_ids (comma-separated), product_types (comma-separated) + """ + stmt = ( + select(WorkflowTable).options(selectinload(WorkflowTable.products)).filter(WorkflowTable.name.in_(entity_ids)) + ) + workflows = db.session.scalars(stmt).all() + + return [ + { + "name": w.name, + "description": w.description, + "created_at": w.created_at.isoformat() if w.created_at else None, + "product_names": ", ".join(p.name for p in w.products), + "product_ids": ", ".join(str(p.product_id) for p in w.products), + "product_types": ", ".join(p.product_type for p in w.products), + } + for w in workflows + ] + + +def fetch_product_export_data(entity_ids: list[str]) -> list[dict]: + """Fetch product data for export. + + Args: + entity_ids: List of product IDs as strings + + Returns: + List of flattened product dictionaries with fields: + product_id, name, product_type, tag, description, status, created_at + """ + stmt = ( + select(ProductTable) + .options( + selectinload(ProductTable.workflows), + selectinload(ProductTable.fixed_inputs), + selectinload(ProductTable.product_blocks), + ) + .filter(ProductTable.product_id.in_([UUID(pid) for pid in entity_ids])) + ) + products = db.session.scalars(stmt).all() + + return [ + { + "product_id": str(p.product_id), + "name": p.name, + "product_type": p.product_type, + "tag": p.tag, + "description": p.description, + "status": p.status, + "created_at": p.created_at.isoformat() if p.created_at else None, + } + for p in products + ] + + +def fetch_process_export_data(entity_ids: list[str]) -> list[dict]: + """Fetch process data for export. + + Args: + entity_ids: List of process IDs as strings + + Returns: + List of flattened process dictionaries with fields: + process_id, workflow_name, workflow_id, last_status, is_task, + created_by, started_at, last_modified_at, last_step + """ + stmt = ( + select(ProcessTable) + .options(selectinload(ProcessTable.workflow)) + .filter(ProcessTable.process_id.in_([UUID(pid) for pid in entity_ids])) + ) + processes = db.session.scalars(stmt).all() + + return [ + { + "process_id": str(p.process_id), + "workflow_name": p.workflow.name if p.workflow else None, + "workflow_id": str(p.workflow_id), + "last_status": p.last_status, + "is_task": p.is_task, + "created_by": p.created_by, + "started_at": p.started_at.isoformat() if p.started_at else None, + "last_modified_at": p.last_modified_at.isoformat() if p.last_modified_at else None, + "last_step": p.last_step, + } + for p in processes + ] + + +def fetch_export_data(entity_type: EntityType, entity_ids: list[str]) -> list[dict]: + """Fetch export data for any entity type. 
+
+    Args:
+        entity_type: The type of entities to fetch
+        entity_ids: List of entity IDs/names as strings
+
+    Returns:
+        List of flattened entity dictionaries ready for CSV export
+
+    Raises:
+        ValueError: If entity_type is not supported
+    """
+    match entity_type:
+        case EntityType.SUBSCRIPTION:
+            return fetch_subscription_export_data(entity_ids)
+        case EntityType.WORKFLOW:
+            return fetch_workflow_export_data(entity_ids)
+        case EntityType.PRODUCT:
+            return fetch_product_export_data(entity_ids)
+        case EntityType.PROCESS:
+            return fetch_process_export_data(entity_ids)
+        case _:
+            raise ValueError(f"Unsupported entity type: {entity_type}")

From 4c4839e986f769321c571592a199a0ec289e81ba Mon Sep 17 00:00:00 2001
From: Tim Frohlich
Date: Wed, 8 Oct 2025 01:39:56 +0200
Subject: [PATCH 02/16] typing

---
 orchestrator/search/export.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/orchestrator/search/export.py b/orchestrator/search/export.py
index 50beaba07..314d2b2ab 100644
--- a/orchestrator/search/export.py
+++ b/orchestrator/search/export.py
@@ -82,7 +82,7 @@ async def from_redis(cls, token: str, redis_client: "Redis") -> "ExportData":
         obj = cls.model_validate_json(export_data_json)
         return obj.model_copy(update={"token": token})
 
-    async def save_to_redis(self, redis_client: Redis, ttl: int | None = 300) -> None:
+    async def save_to_redis(self, redis_client: "Redis", ttl: int | None = 300) -> None:
         """Persist payload without token."""
         await redis_client.set(self.redis_key, self.model_dump_json(), ex=ttl)
 
From 264f9fe366c61271a507a531595f593d00162d0c Mon Sep 17 00:00:00 2001
From: Tim Frohlich
Date: Wed, 8 Oct 2025 01:59:19 +0200
Subject: [PATCH 03/16] just hardcode base url

---
 orchestrator/search/agent/agent.py | 3 +--
 orchestrator/search/agent/tools.py | 2 +-
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/orchestrator/search/agent/agent.py b/orchestrator/search/agent/agent.py
index d8490a5d7..9dd34c1db 100644
--- a/orchestrator/search/agent/agent.py
+++ b/orchestrator/search/agent/agent.py
@@ -48,8 +48,7 @@ def build_agent_router(model: str | OpenAIModel, toolsets: list[FunctionToolset[
 
     @router.post("/")
     async def agent_endpoint(request: Request) -> Response:
-        base_url = str(request.base_url).rstrip("/")
-        initial_state = SearchState(base_url=base_url)
+        initial_state = SearchState()
         return await handle_ag_ui_request(agent, request, deps=StateDeps(initial_state))
 
     return router
diff --git a/orchestrator/search/agent/tools.py b/orchestrator/search/agent/tools.py
index f81b74633..1b58bfc81 100644
--- a/orchestrator/search/agent/tools.py
+++ b/orchestrator/search/agent/tools.py
@@ -326,7 +326,7 @@ async def prepare_export(
     async with create_redis_asyncio_client(app_settings.CACHE_URI) as redis_client:
         await export_data.save_to_redis(redis_client, ttl=3000)
 
-    download_url = f"{ctx.deps.state.base_url}/api/search/export/{export_token}"
+    download_url = f"http://localhost:8080/api/search/export/{export_token}"
 
     # Update state with export data so frontend can render the download button
     ctx.deps.state.export_data = {
From d5d31a8e75166cee6b4023cdaf8e64e705e98311 Mon Sep 17 00:00:00 2001
From: Tim Frohlich
Date: Thu, 9 Oct 2025 17:09:59 +0200
Subject: [PATCH 04/16] refactor to save agent runs and queries in postgres

---
 orchestrator/agentic_app.py | 18 +--
 orchestrator/api/api_v1/api.py | 5 +
 orchestrator/api/api_v1/endpoints/agent.py | 121 ++++++++++++++++++
 orchestrator/api/api_v1/endpoints/search.py | 60 +--------
 orchestrator/db/__init__.py | 6 + 
orchestrator/db/models.py | 57 +++++++++ ...add_agent_runs_and_agent_queries_tables.py | 58 +++++++++ orchestrator/search/agent/__init__.py | 4 +- orchestrator/search/agent/agent.py | 55 ++++---- orchestrator/search/agent/prompts.py | 6 +- orchestrator/search/agent/state.py | 2 + orchestrator/search/agent/tools.py | 104 ++++++--------- orchestrator/search/export.py | 72 ----------- orchestrator/search/retrieval/__init__.py | 4 +- orchestrator/search/retrieval/engine.py | 58 ++++++--- orchestrator/search/schemas/parameters.py | 30 +++-- 16 files changed, 386 insertions(+), 274 deletions(-) create mode 100644 orchestrator/api/api_v1/endpoints/agent.py create mode 100644 orchestrator/migrations/versions/schema/2025-10-09_459f352f5aa6_add_agent_runs_and_agent_queries_tables.py diff --git a/orchestrator/agentic_app.py b/orchestrator/agentic_app.py index a656578be..b9014863a 100644 --- a/orchestrator/agentic_app.py +++ b/orchestrator/agentic_app.py @@ -44,7 +44,7 @@ def __init__( """Initialize the `LLMOrchestratorCore` class. This class extends `OrchestratorCore` with LLM features (search and agent). - It runs the search migration and mounts the agent endpoint based on feature flags. + It runs the search migration based on feature flags. Args: *args: All the normal arguments passed to the `OrchestratorCore` class. @@ -79,22 +79,6 @@ def __init__( ) raise - # Mount agent endpoint if agent is enabled - if self.llm_settings.AGENT_ENABLED: - logger.info("Initializing agent features", model=self.agent_model) - try: - from orchestrator.search.agent import build_agent_router - - agent_app = build_agent_router(self.agent_model, self.agent_tools) - self.mount("/agent", agent_app) - except ImportError as e: - logger.error( - "Unable to initialize agent features. Please install agent dependencies: " - "`pip install orchestrator-core[agent]`", - error=str(e), - ) - raise - main_typer_app = typer.Typer() main_typer_app.add_typer(cli_app, name="orchestrator", help="The orchestrator CLI commands") diff --git a/orchestrator/api/api_v1/api.py b/orchestrator/api/api_v1/api.py index 9994ee5f9..a044a413a 100644 --- a/orchestrator/api/api_v1/api.py +++ b/orchestrator/api/api_v1/api.py @@ -95,3 +95,8 @@ api_router.include_router( search.router, prefix="/search", tags=["Core", "Search"], dependencies=[Depends(authorize)] ) + +if llm_settings.AGENT_ENABLED: + from orchestrator.api.api_v1.endpoints import agent + + api_router.include_router(agent.router, prefix="/agent", tags=["Core", "Agent"], dependencies=[Depends(authorize)]) diff --git a/orchestrator/api/api_v1/endpoints/agent.py b/orchestrator/api/api_v1/endpoints/agent.py new file mode 100644 index 000000000..81eadc167 --- /dev/null +++ b/orchestrator/api/api_v1/endpoints/agent.py @@ -0,0 +1,121 @@ +# Copyright 2019-2025 SURF, GÉANT. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
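
This commit replaces the app-level `/agent` mount with a feature-flagged API router and a process-wide cached agent. The sketch below shows those two patterns (conditional `include_router` plus a `functools.cache` dependency) in a self-contained app; `AGENT_ENABLED` and `get_agent_stub` are placeholder names, not part of the patch:

```python
from functools import cache

from fastapi import APIRouter, Depends, FastAPI

AGENT_ENABLED = True  # stand-in for llm_settings.AGENT_ENABLED


@cache
def get_agent_stub() -> object:
    # Built once per process; FastAPI calls this dependency on every
    # request, but @cache hands back the same instance each time.
    return object()


app = FastAPI()

if AGENT_ENABLED:
    router = APIRouter()

    @router.get("/ping")
    async def ping(agent: object = Depends(get_agent_stub)) -> dict[str, bool]:
        return {"cached": agent is get_agent_stub()}

    app.include_router(router, prefix="/agent")
```
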
+ +from functools import cache +from structlog import get_logger +from typing import Annotated, Any + +from fastapi import APIRouter, Depends, HTTPException, Request, status +from pydantic_ai.ag_ui import StateDeps, handle_ag_ui_request +from pydantic_ai.agent import Agent +from starlette.responses import Response + +from orchestrator.db import AgentQueryTable, db +from orchestrator.llm_settings import llm_settings +from orchestrator.search.agent import build_agent_instance +from orchestrator.search.agent.state import SearchState +from orchestrator.search.retrieval import execute_search_for_export + +router = APIRouter() +logger = get_logger(__name__) + + +@cache +def get_agent() -> Agent[StateDeps[SearchState], str]: + """Dependency to provide the agent instance. + + The agent is built once and cached for the lifetime of the application. + """ + return build_agent_instance(llm_settings.AGENT_MODEL, agent_tools=None) + + +@router.post("/") +async def agent_conversation( + request: Request, + agent: Annotated[Agent[StateDeps[SearchState], str], Depends(get_agent)], +) -> Response: + """Agent conversation endpoint using pydantic-ai ag_ui protocol. + + This endpoint handles the interactive agent conversation for search. + """ + initial_state = SearchState() + return await handle_ag_ui_request(agent, request, deps=StateDeps(initial_state)) + + +@router.get( + "/runs/{run_id}/queries/{query_id}/export", + summary="Export query results by run_id and query_id", + response_model=dict[str, Any], +) +async def export_by_query_id(run_id: str, query_id: str) -> dict[str, Any]: + """Export search results using run_id and query_id. + + The query is retrieved from the database, re-executed, and results are returned + as flattened records suitable for CSV download. + + Args: + run_id: Agent run UUID + query_id: Query UUID + + Returns: + Dictionary containing 'page' with an array of flattened entity records. + Each record contains snake_case field names from the database with nested + relationships flattened (e.g., product_name instead of product.name). 
+ + Raises: + HTTPException: 404 if query not found, 400 if invalid data + """ + from uuid import UUID + + from orchestrator.search.export import fetch_export_data + + try: + query_uuid = UUID(query_id) + run_uuid = UUID(run_id) + except ValueError: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid run_id or query_id format", + ) + + agent_query = db.session.query(AgentQueryTable).filter_by(query_id=query_uuid, run_id=run_uuid).first() + + if not agent_query: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Query {query_id} not found in run {run_id}", + ) + try: + from orchestrator.search.retrieval.pagination import PaginationParams + + # Get the full query state including the embedding that was used + query_state = agent_query.get_state() + + # Create pagination params with the saved embedding to ensure consistent results + pagination_params = PaginationParams( + q_vec_override=query_state.query_embedding.tolist() if query_state.query_embedding is not None else None + ) + + search_response = await execute_search_for_export(query_state.parameters, db.session, pagination_params) + entity_ids = [res.entity_id for res in search_response.results] + + export_records = fetch_export_data(query_state.parameters.entity_type, entity_ids) + + return {"page": export_records} + + except Exception as e: + logger.error(e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Error executing export: {str(e)}", + ) diff --git a/orchestrator/api/api_v1/endpoints/search.py b/orchestrator/api/api_v1/endpoints/search.py index 57f23bc57..3140832c5 100644 --- a/orchestrator/api/api_v1/endpoints/search.py +++ b/orchestrator/api/api_v1/endpoints/search.py @@ -39,7 +39,6 @@ ) from orchestrator.search.core.exceptions import InvalidCursorError from orchestrator.search.core.types import EntityType, UIType -from orchestrator.search.export import ExportData from orchestrator.search.filters.definitions import generate_definitions from orchestrator.search.indexing.registry import ENTITY_CONFIG_REGISTRY from orchestrator.search.retrieval import execute_search @@ -58,8 +57,6 @@ ) from orchestrator.search.schemas.results import SearchResult, TypeDefinition from orchestrator.services.subscriptions import format_special_types -from orchestrator.settings import app_settings -from orchestrator.utils.redis_client import create_redis_asyncio_client router = APIRouter() @@ -134,11 +131,7 @@ async def _perform_search_and_fetch( except InvalidCursorError: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid pagination cursor") - search_response = await execute_search( - search_params=search_params, - db_session=db.session, - pagination_params=pagination_params, - ) + search_response = await execute_search(search_params, db.session, pagination_params) if not search_response.results: return SearchResultsSchema(search_metadata=search_response.metadata) @@ -182,11 +175,7 @@ async def search_subscriptions( except InvalidCursorError: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid pagination cursor") - search_response = await execute_search( - search_params=search_params, - db_session=db.session, - pagination_params=pagination_params, - ) + search_response = await execute_search(search_params, db.session, pagination_params) if not search_response.results: return SearchResultsSchema(search_metadata=search_response.metadata) @@ -297,48 +286,3 @@ async def list_paths( async def get_definitions() -> 
dict[UIType, TypeDefinition]: """Provide a static definition of operators and schemas for each UI type.""" return generate_definitions() - - -@router.get( - "/export/{token}", - response_model=dict[str, Any], -) -async def export_by_token(token: str) -> dict[str, Any]: - """Export search results using a token generated by the search agent. - - The token contains entity IDs from a previous search. This endpoint retrieves - those IDs from Redis, fetches the full entity data from the database, and returns - flattened records suitable for CSV download. - - Args: - token: Export token generated by the prepare_export agent tool - - Returns: - Dictionary containing 'page' with an array of flattened entity records. - Each record contains snake_case field names from the database with nested - relationships flattened (e.g., product_name instead of product.name). - - Raises: - HTTPException: 404 if token not found or expired, 400 if data is invalid - """ - async with create_redis_asyncio_client(app_settings.CACHE_URI) as redis_client: - - # Load export data from Redis - try: - export_data = await ExportData.from_redis(token, redis_client) - except ValueError as e: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=str(e), - ) - - # Fetch the actual records - try: - export_records = export_data.fetch_records() - except ValueError as e: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=str(e), - ) - - return {"page": export_records} diff --git a/orchestrator/db/__init__.py b/orchestrator/db/__init__.py index 6e7138e08..e304c7003 100644 --- a/orchestrator/db/__init__.py +++ b/orchestrator/db/__init__.py @@ -17,6 +17,8 @@ from orchestrator.db.database import BaseModel as DbBaseModel from orchestrator.db.database import Database, transactional from orchestrator.db.models import ( # noqa: F401 + AgentQueryTable, + AgentRunTable, EngineSettingsTable, FixedInputTable, InputStateTable, @@ -74,6 +76,8 @@ def init_database(settings: AppSettings) -> Database: __all__ = [ "transactional", + "AgentQueryTable", + "AgentRunTable", "SubscriptionTable", "ProcessSubscriptionTable", "ProcessTable", @@ -97,6 +101,8 @@ def init_database(settings: AppSettings) -> Database: ] ALL_DB_MODELS: list[type[DbBaseModel]] = [ + AgentQueryTable, + AgentRunTable, FixedInputTable, ProcessStepTable, ProcessSubscriptionTable, diff --git a/orchestrator/db/models.py b/orchestrator/db/models.py index 2d7f43b3d..4ac2c2a80 100644 --- a/orchestrator/db/models.py +++ b/orchestrator/db/models.py @@ -15,6 +15,7 @@ import enum from datetime import datetime, timezone +from typing import TYPE_CHECKING from uuid import UUID import sqlalchemy @@ -58,6 +59,9 @@ from orchestrator.utils.datetime import nowtz from orchestrator.version import GIT_COMMIT_HASH +if TYPE_CHECKING: + from orchestrator.search.schemas.parameters import AgentQueryState + logger = structlog.get_logger(__name__) TAG_LENGTH = 20 @@ -674,6 +678,59 @@ class SubscriptionSearchView(BaseModel): subscription = relationship("SubscriptionTable", foreign_keys=[subscription_id]) +class AgentRunTable(BaseModel): + """Agent conversation/session tracking.""" + + __tablename__ = "agent_runs" + + run_id = mapped_column("run_id", UUIDType, server_default=text("uuid_generate_v4()"), primary_key=True) + agent_type = mapped_column(String(50), nullable=False) + created_at = mapped_column(UtcTimestamp, server_default=text("current_timestamp()"), nullable=False) + + queries = relationship("AgentQueryTable", back_populates="run", cascade="delete", 
passive_deletes=True) + + __table_args__ = (Index("ix_agent_runs_created_at", "created_at"),) + + +class AgentQueryTable(BaseModel): + """Individual query execution within an agent run.""" + + __tablename__ = "agent_queries" + + query_id = mapped_column("query_id", UUIDType, server_default=text("uuid_generate_v4()"), primary_key=True) + run_id = mapped_column( + "run_id", UUIDType, ForeignKey("agent_runs.run_id", ondelete="CASCADE"), nullable=False, index=True + ) + query_number = mapped_column(Integer, nullable=False) + + # Search parameters as JSONB (maps to BaseSearchParameters subclasses) + parameters = mapped_column(pg.JSONB, nullable=False) + + # Query embedding for semantic search (pgvector) + query_embedding = mapped_column(Vector(1536), nullable=True) + + executed_at = mapped_column(UtcTimestamp, server_default=text("current_timestamp()"), nullable=False) + + run = relationship("AgentRunTable", back_populates="queries") + + __table_args__ = ( + Index("ix_agent_queries_run_id", "run_id"), + Index("ix_agent_queries_executed_at", "executed_at"), + UniqueConstraint("run_id", "query_number", name="uq_run_query_number"), + ) + + def get_state(self) -> "AgentQueryState": + """Reconstruct complete query state including parameters and embedding. + + Returns: + AgentQueryState with typed parameters and embedding vector. + + """ + from orchestrator.search.schemas.parameters import AgentQueryState + + return AgentQueryState.model_validate(self) + + class EngineSettingsTable(BaseModel): __tablename__ = "engine_settings" global_lock = mapped_column(Boolean(), default=False, nullable=False, primary_key=True) diff --git a/orchestrator/migrations/versions/schema/2025-10-09_459f352f5aa6_add_agent_runs_and_agent_queries_tables.py b/orchestrator/migrations/versions/schema/2025-10-09_459f352f5aa6_add_agent_runs_and_agent_queries_tables.py new file mode 100644 index 000000000..39c7214f4 --- /dev/null +++ b/orchestrator/migrations/versions/schema/2025-10-09_459f352f5aa6_add_agent_runs_and_agent_queries_tables.py @@ -0,0 +1,58 @@ +"""Add agent_runs and agent_queries tables. + +Revision ID: 459f352f5aa6 +Revises: 850dccac3b02 +Create Date: 2025-10-09 00:52:16.297143 + +""" + +import sqlalchemy as sa +from alembic import op +from pgvector.sqlalchemy import Vector +from sqlalchemy.dialects import postgresql +from sqlalchemy_utils import UUIDType + +# revision identifiers, used by Alembic. 
+revision = "459f352f5aa6" +down_revision = "850dccac3b02" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "agent_runs", + sa.Column("run_id", UUIDType(), server_default=sa.text("uuid_generate_v4()"), nullable=False), + sa.Column("agent_type", sa.String(length=50), nullable=False), + sa.Column( + "created_at", sa.TIMESTAMP(timezone=True), server_default=sa.text("current_timestamp"), nullable=False + ), + sa.PrimaryKeyConstraint("run_id"), + ) + op.create_index("ix_agent_runs_created_at", "agent_runs", ["created_at"]) + + op.create_table( + "agent_queries", + sa.Column("query_id", UUIDType(), server_default=sa.text("uuid_generate_v4()"), nullable=False), + sa.Column("run_id", UUIDType(), nullable=False), + sa.Column("query_number", sa.Integer(), nullable=False), + sa.Column("parameters", postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.Column("query_embedding", Vector(1536), nullable=True), + sa.Column( + "executed_at", sa.TIMESTAMP(timezone=True), server_default=sa.text("current_timestamp"), nullable=False + ), + sa.ForeignKeyConstraint(["run_id"], ["agent_runs.run_id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("query_id"), + sa.UniqueConstraint("run_id", "query_number", name="uq_run_query_number"), + ) + op.create_index("ix_agent_queries_run_id", "agent_queries", ["run_id"]) + op.create_index("ix_agent_queries_executed_at", "agent_queries", ["executed_at"]) + + +def downgrade() -> None: + op.drop_index("ix_agent_queries_executed_at", table_name="agent_queries") + op.drop_index("ix_agent_queries_run_id", table_name="agent_queries") + op.drop_table("agent_queries") + + op.drop_index("ix_agent_runs_created_at", table_name="agent_runs") + op.drop_table("agent_runs") diff --git a/orchestrator/search/agent/__init__.py b/orchestrator/search/agent/__init__.py index 46fa7fd53..55d92ec57 100644 --- a/orchestrator/search/agent/__init__.py +++ b/orchestrator/search/agent/__init__.py @@ -14,8 +14,8 @@ # This module requires: pydantic-ai==0.7.0, ag-ui-protocol>=0.1.8 -from orchestrator.search.agent.agent import build_agent_router +from orchestrator.search.agent.agent import build_agent_instance __all__ = [ - "build_agent_router", + "build_agent_instance", ] diff --git a/orchestrator/search/agent/agent.py b/orchestrator/search/agent/agent.py index 9dd34c1db..e3e0c8d04 100644 --- a/orchestrator/search/agent/agent.py +++ b/orchestrator/search/agent/agent.py @@ -14,13 +14,11 @@ from typing import Any import structlog -from fastapi import APIRouter, HTTPException, Request -from pydantic_ai.ag_ui import StateDeps, handle_ag_ui_request +from pydantic_ai.ag_ui import StateDeps from pydantic_ai.agent import Agent from pydantic_ai.models.openai import OpenAIModel from pydantic_ai.settings import ModelSettings from pydantic_ai.toolsets import FunctionToolset -from starlette.responses import Response from orchestrator.search.agent.prompts import get_base_instructions, get_dynamic_instructions from orchestrator.search.agent.state import SearchState @@ -29,35 +27,32 @@ logger = structlog.get_logger(__name__) -def build_agent_router(model: str | OpenAIModel, toolsets: list[FunctionToolset[Any]] | None = None) -> APIRouter: - router = APIRouter() +def build_agent_instance( + model: str | OpenAIModel, agent_tools: list[FunctionToolset[Any]] | None = None +) -> Agent[StateDeps[SearchState], str]: + """Build and configure the search agent instance. 
- try: - toolsets = toolsets + [search_toolset] if toolsets else [search_toolset] + Args: + model: The LLM model to use (string or OpenAIModel instance) + agent_tools: Optional list of additional toolsets to include - agent = Agent( - model=model, - deps_type=StateDeps[SearchState], - model_settings=ModelSettings( - parallel_tool_calls=False, - ), # https://github.com/pydantic/pydantic-ai/issues/562 - toolsets=toolsets, - ) - agent.instructions(get_base_instructions) - agent.instructions(get_dynamic_instructions) + Returns: + Configured Agent instance with StateDeps[SearchState] dependencies - @router.post("/") - async def agent_endpoint(request: Request) -> Response: - initial_state = SearchState() - return await handle_ag_ui_request(agent, request, deps=StateDeps(initial_state)) + Raises: + Exception: If agent initialization fails + """ + toolsets = agent_tools + [search_toolset] if agent_tools else [search_toolset] - return router - except Exception as e: - logger.error("Agent init failed; serving disabled stub.", error=str(e)) - error_msg = f"Agent disabled: {str(e)}" + agent = Agent( + model=model, + deps_type=StateDeps[SearchState], + model_settings=ModelSettings( + parallel_tool_calls=False, + ), # https://github.com/pydantic/pydantic-ai/issues/562 + toolsets=toolsets, + ) + agent.instructions(get_base_instructions) + agent.instructions(get_dynamic_instructions) - @router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS", "HEAD"]) - async def _disabled(path: str) -> None: - raise HTTPException(status_code=503, detail=error_msg) - - return router + return agent diff --git a/orchestrator/search/agent/prompts.py b/orchestrator/search/agent/prompts.py index 37bd8e46f..8e2a2e869 100644 --- a/orchestrator/search/agent/prompts.py +++ b/orchestrator/search/agent/prompts.py @@ -60,9 +60,9 @@ async def get_base_instructions() -> str: 4. **Report**: Answer the users' question directly and summarize when appropiate. 5. **Export (if requested)**: If the user asks to export, download, or save results as CSV/file: - **IMPORTANT**: Export is ONLY available for SELECT actions (not COUNT or AGGREGATE) - - Call `prepare_export` to generate an export token - - The UI will automatically display a download button - you don't need to mention URLs or tokens - - Simply confirm to the user that the export is ready + - Call `prepare_export` to save the query and generate a download URL + - The UI will automatically display a download button - you don't need to mention URLs or IDs + - Simply confirm to the user that the export is ready for download --- ### 4. Critical Rules diff --git a/orchestrator/search/agent/state.py b/orchestrator/search/agent/state.py index 1414664fc..4e30a2fe2 100644 --- a/orchestrator/search/agent/state.py +++ b/orchestrator/search/agent/state.py @@ -12,11 +12,13 @@ # limitations under the License. 
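
With the router logic removed from the builder, wiring the agent into an endpoint becomes the caller's job. A minimal sketch, assuming the refactored `build_agent_instance` above; the model string is illustrative:

```python
from fastapi import APIRouter, Request
from pydantic_ai.ag_ui import StateDeps, handle_ag_ui_request
from starlette.responses import Response

from orchestrator.search.agent import build_agent_instance
from orchestrator.search.agent.state import SearchState

router = APIRouter()
agent = build_agent_instance("openai:gpt-4o")  # built once at import time

@router.post("/")
async def agent_endpoint(request: Request) -> Response:
    # Fresh per-request state; tools mutate it through ctx.deps.state.
    return await handle_ag_ui_request(agent, request, deps=StateDeps(SearchState()))
```
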
from typing import Any +from uuid import UUID from pydantic import BaseModel, Field class SearchState(BaseModel): + run_id: UUID | None = None parameters: dict[str, Any] | None = None results: list[dict[str, Any]] = Field(default_factory=list) export_data: dict[str, Any] | None = None diff --git a/orchestrator/search/agent/tools.py b/orchestrator/search/agent/tools.py index 1b58bfc81..aa9a3b5be 100644 --- a/orchestrator/search/agent/tools.py +++ b/orchestrator/search/agent/tools.py @@ -13,7 +13,6 @@ from collections.abc import Awaitable, Callable from typing import Any, TypeVar -from uuid import uuid4 import structlog from ag_ui.core import EventType, StateSnapshotEvent @@ -33,13 +32,10 @@ ) from orchestrator.schemas.search import SearchResultsSchema from orchestrator.search.core.types import ActionType, EntityType, FilterOp -from orchestrator.search.export import ExportData from orchestrator.search.filters import FilterTree from orchestrator.search.retrieval.exceptions import FilterValidationError, PathNotFoundError from orchestrator.search.retrieval.validation import validate_filter_tree -from orchestrator.search.schemas.parameters import PARAMETER_REGISTRY, BaseSearchParameters -from orchestrator.settings import app_settings -from orchestrator.utils.redis_client import create_redis_asyncio_client +from orchestrator.search.schemas.parameters import BaseSearchParameters from .state import SearchState @@ -153,15 +149,10 @@ async def execute_search( if not ctx.deps.state.parameters: raise ValueError("No search parameters set") - entity_type = EntityType(ctx.deps.state.parameters["entity_type"]) - param_class = PARAMETER_REGISTRY.get(entity_type) - if not param_class: - raise ValueError(f"Unknown entity type: {entity_type}") - - params = param_class(**ctx.deps.state.parameters) + params = BaseSearchParameters.create(**ctx.deps.state.parameters) logger.debug( "Executing database search", - search_entity_type=entity_type.value, + search_entity_type=params.entity_type.value, limit=limit, has_filters=params.filters is not None, query=params.query, @@ -173,7 +164,7 @@ async def execute_search( params.limit = limit - fn = SEARCH_FN_MAP[entity_type] + fn = SEARCH_FN_MAP[params.entity_type] search_results = await fn(params) logger.debug( @@ -267,7 +258,7 @@ async def prepare_export( ctx: RunContext[StateDeps[SearchState]], max_results: int = 1000, ) -> StateSnapshotEvent: - """Executes the search with the current parameters, collects up to max_results entity IDs, stores them in Redis with a temporary token, and returns the token for export.""" + """Saves the current search query to the database and returns run_id/query_id for export.""" if not ctx.deps.state.parameters: raise ValueError("No search parameters set. Run a search first to see what will be exported.") @@ -279,63 +270,52 @@ async def prepare_export( "Please run a SELECT search first." 
)
 
-    entity_type = EntityType(ctx.deps.state.parameters["entity_type"])
-    param_class = PARAMETER_REGISTRY.get(entity_type)
-    if not param_class:
-        raise ValueError(f"Unknown entity type: {entity_type}")
-
-    # Cap at 1000 results
-    export_limit = min(max_results, 1000)
-
-    params = param_class(**ctx.deps.state.parameters)
-    params.limit = export_limit
-
-    logger.debug(
-        "Preparing export",
-        entity_type=entity_type.value,
-        limit=export_limit,
-        has_filters=params.filters is not None,
+    from orchestrator.db import AgentQueryTable, AgentRunTable, db
+
+    # Ensure we have a run_id
+    if not ctx.deps.state.run_id:
+        # Create a new agent run
+        agent_run = AgentRunTable(agent_type="search")
+        db.session.add(agent_run)
+        db.session.commit()
+        db.session.refresh(agent_run)
+        ctx.deps.state.run_id = agent_run.run_id
+        logger.debug("Created new agent run", run_id=str(agent_run.run_id))
+
+    query_number = db.session.query(AgentQueryTable).filter_by(run_id=ctx.deps.state.run_id).count() + 1
+
+    export_limit = min(max_results, BaseSearchParameters.export_limit)
+    params_dict = ctx.deps.state.parameters.copy()
+    params_dict["export_limit"] = export_limit
+
+    agent_query = AgentQueryTable(
+        run_id=ctx.deps.state.run_id,
+        query_number=query_number,
+        parameters=params_dict,
+        query_embedding=None,  # TODO: We need to save the embedding here.
     )
+    db.session.add(agent_query)
+    db.session.commit()
+    db.session.refresh(agent_query)
 
-    fn = SEARCH_FN_MAP[entity_type]
-    search_results = await fn(params)
-
-    if not search_results.data:
-        raise ValueError("No results found to export. Try adjusting your search criteria.")
-
-    entity_ids = []
-    for result in search_results.data:
-        if entity_type == EntityType.SUBSCRIPTION:
-            entity_ids.append(str(result.subscription["subscription_id"]))
-        elif entity_type == EntityType.WORKFLOW:
-            entity_ids.append(str(result.workflow.name))
-        elif entity_type == EntityType.PRODUCT:
-            entity_ids.append(str(result.product.product_id))
-        elif entity_type == EntityType.PROCESS:
-            entity_ids.append(str(result.process.process_id))
-
-    # Generate export token and create export data model
-    export_token = str(uuid4())
-    export_data = ExportData(
-        entity_type=entity_type,
-        entity_ids=entity_ids,
-        token=export_token,
+    logger.debug(
+        "Saved query for export",
+        run_id=str(ctx.deps.state.run_id),
+        query_id=str(agent_query.query_id),
+        query_number=query_number,
     )
 
-    # Store in Redis with TTL
-    async with create_redis_asyncio_client(app_settings.CACHE_URI) as redis_client:
-        await export_data.save_to_redis(redis_client, ttl=3000)
-
-    download_url = f"http://localhost:8080/api/search/export/{export_token}"
+    # Build export URL using run_id and query_id
+    base_url = ctx.deps.state.base_url or "http://localhost:8080"
+    download_url = f"{base_url}/api/agent/runs/{ctx.deps.state.run_id}/queries/{agent_query.query_id}/export"
 
     # Update state with export data so frontend can render the download button
     ctx.deps.state.export_data = {
         "action": "export",
-        "token": export_token,
-        "count": len(entity_ids),
-        "entity_type": entity_type.value,
+        "run_id": str(ctx.deps.state.run_id),
+        "query_id": str(agent_query.query_id),
         "download_url": download_url,
-        "message": f"Found {len(entity_ids)} {entity_type.value.lower()}(s).",
+        "message": f"Export ready for download (up to {export_limit} results).",
     }
 
     logger.debug("Export data set in state", export_data=ctx.deps.state.export_data)
diff --git a/orchestrator/search/export.py b/orchestrator/search/export.py
index 314d2b2ab..311a34121 100644
--- 
a/orchestrator/search/export.py +++ b/orchestrator/search/export.py @@ -13,8 +13,6 @@ from uuid import UUID -from pydantic import BaseModel, Field -from pydantic.config import ConfigDict from sqlalchemy import select from sqlalchemy.orm import selectinload @@ -26,76 +24,6 @@ db, ) from orchestrator.search.core.types import EntityType -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from redis.asyncio.client import Redis - -# Redis namespace for export data -EXPORT_REDIS_NAMESPACE = "orchestrator:export" - - -class ExportData(BaseModel): - """Model for export data stored in Redis. - - Attributes: - entity_type: The type of entities being exported - entity_ids: List of entity IDs/names to export - token: Unique export token for this data - """ - - model_config = ConfigDict(use_enum_values=True) - - entity_type: EntityType - entity_ids: list[str] - token: str | None = Field(default=None, exclude=True) - - @property - def redis_key(self) -> str: - """Redis key for storing this export data. - - Returns: - Redis key string - """ - return f"{EXPORT_REDIS_NAMESPACE}:{self.token}" - - @classmethod - async def from_redis(cls, token: str, redis_client: "Redis") -> "ExportData": - """Load export data from Redis using a token. - - Args: - token: Export token - redis_client: Redis client instance - - Returns: - ExportData instance - - Raises: - ValueError: If token not found in Redis - """ - redis_key = f"{EXPORT_REDIS_NAMESPACE}:{token}" - export_data_json = await redis_client.get(redis_key) - - if not export_data_json: - raise ValueError(f"Export token '{token}' not found or expired") - - obj = cls.model_validate_json(export_data_json) - return obj.model_copy(update={"token": token}) - - async def save_to_redis(self, redis_client: "Redis", ttl: int | None = 300) -> None: - """Persist payload without token.""" - await redis_client.set(self.redis_key, self.model_dump_json(), ex=ttl) - - def fetch_records(self) -> list[dict]: - """Fetch the actual export records for this export data. - - Returns: - List of flattened entity records - - Raises: - ValueError: If entity type is not supported - """ - return fetch_export_data(self.entity_type, self.entity_ids) def fetch_subscription_export_data(entity_ids: list[str]) -> list[dict]: diff --git a/orchestrator/search/retrieval/__init__.py b/orchestrator/search/retrieval/__init__.py index 7bb32303a..353fb6fba 100644 --- a/orchestrator/search/retrieval/__init__.py +++ b/orchestrator/search/retrieval/__init__.py @@ -11,6 +11,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .engine import execute_search +from .engine import execute_search, execute_search_for_export -__all__ = ["execute_search"] +__all__ = ["execute_search", "execute_search_for_export"] diff --git a/orchestrator/search/retrieval/engine.py b/orchestrator/search/retrieval/engine.py index b1a08f46f..374c12964 100644 --- a/orchestrator/search/retrieval/engine.py +++ b/orchestrator/search/retrieval/engine.py @@ -110,32 +110,23 @@ def _extract_matching_field_from_filters(filters: FilterTree) -> MatchingField | return MatchingField(text=text, path=pf.path, highlight_indices=[(0, len(text))]) -async def execute_search( +async def _execute_search_internal( search_params: BaseSearchParameters, db_session: Session, + limit: int, pagination_params: PaginationParams | None = None, ) -> SearchResponse: - """Execute a hybrid search and return ranked results. 
- - Builds a candidate entity query based on the given search parameters, - applies the appropriate ranking strategy, and executes the final ranked - query to retrieve results. + """Internal function to execute search with specified parameters. Args: - search_params (BaseSearchParameters): The search parameters specifying vector, fuzzy, or filter criteria. - db_session (Session): The active SQLAlchemy session for executing the query. - pagination_params (PaginationParams): Parameters controlling pagination of the search results. - limit (int, optional): The maximum number of search results to return, by default 5. + search_params: The search parameters specifying vector, fuzzy, or filter criteria. + db_session: The active SQLAlchemy session for executing the query. + limit: Maximum number of results to return. + pagination_params: Optional pagination parameters. Returns: - SearchResponse: A list of `SearchResult` objects containing entity IDs, scores, - and optional highlight metadata. - - Notes: - If no vector query, filters, or fuzzy term are provided, a warning is logged - and an empty result set is returned. + SearchResponse with results, or empty response if no search criteria provided. """ - if not search_params.vector_query and not search_params.filters and not search_params.fuzzy_term: logger.warning("No search criteria provided (vector_query, fuzzy_term, or filters).") return SearchResponse(results=[], metadata=SearchMetadata.empty()) @@ -147,8 +138,39 @@ async def execute_search( logger.debug("Using retriever", retriever_type=retriever.__class__.__name__) final_stmt = retriever.apply(candidate_query) - final_stmt = final_stmt.limit(search_params.limit) + final_stmt = final_stmt.limit(limit) logger.debug(final_stmt) result = db_session.execute(final_stmt).mappings().all() return _format_response(result, search_params, retriever.metadata) + + +async def execute_search( + search_params: BaseSearchParameters, + db_session: Session, + pagination_params: PaginationParams | None = None, +) -> SearchResponse: + """Execute a search and return ranked results.""" + return await _execute_search_internal(search_params, db_session, search_params.limit, pagination_params) + + +async def execute_search_for_export( + search_params: BaseSearchParameters, + db_session: Session, + pagination_params: PaginationParams | None = None, +) -> SearchResponse: + """Execute a search for export purposes. + + Similar to execute_search but uses export_limit instead of limit. + The pagination_params is primarily used to pass q_vec_override to ensure + the export uses the same embedding as the original search. + + Args: + search_params: The search parameters specifying vector, fuzzy, or filter criteria. + db_session: The active SQLAlchemy session for executing the query. + pagination_params: Optional pagination parameters (primarily for q_vec_override). + + Returns: + SearchResponse with results up to export_limit. 
+ """ + return await _execute_search_internal(search_params, db_session, search_params.export_limit, pagination_params) diff --git a/orchestrator/search/schemas/parameters.py b/orchestrator/search/schemas/parameters.py index 26d3ed79a..46bf0f71b 100644 --- a/orchestrator/search/schemas/parameters.py +++ b/orchestrator/search/schemas/parameters.py @@ -14,6 +14,7 @@ import uuid from typing import Any, Literal +import numpy as np from pydantic import BaseModel, ConfigDict, Field from orchestrator.search.core.types import ActionType, EntityType @@ -33,14 +34,13 @@ class BaseSearchParameters(BaseModel): ) limit: int = Field(default=10, ge=1, le=30, description="Maximum number of search results to return.") + export_limit: int = Field(default=1000, ge=1, le=10000, description="Maximum number of results to export.") model_config = ConfigDict(extra="forbid") @classmethod def create(cls, entity_type: EntityType, **kwargs: Any) -> "BaseSearchParameters": - try: - return PARAMETER_REGISTRY[entity_type](entity_type=entity_type, **kwargs) - except KeyError: - raise ValueError(f"No search parameter class found for entity type: {entity_type.value}") + """Create the correct search parameter subclass instance.""" + return cls.model_validate({"entity_type": entity_type, **kwargs}) @property def vector_query(self) -> str | None: @@ -121,9 +121,19 @@ class ProcessSearchParameters(BaseSearchParameters): ) -PARAMETER_REGISTRY: dict[EntityType, type[BaseSearchParameters]] = { - EntityType.SUBSCRIPTION: SubscriptionSearchParameters, - EntityType.PRODUCT: ProductSearchParameters, - EntityType.WORKFLOW: WorkflowSearchParameters, - EntityType.PROCESS: ProcessSearchParameters, -} +SearchParameters = ( + SubscriptionSearchParameters | ProductSearchParameters | WorkflowSearchParameters | ProcessSearchParameters +) + + +class AgentQueryState(BaseModel): + """Complete state of an agent query including parameters and embedding. + + This model combines the search parameters with the query embedding, + providing a complete snapshot of what was searched and how. 
+ """ + + parameters: SearchParameters = Field(discriminator="entity_type") + query_embedding: np.ndarray | None = Field(default=None, description="The embedding vector for semantic search") + + model_config = ConfigDict(from_attributes=True, arbitrary_types_allowed=True) From bb8a2734092281c89e5c5c1a872aa020a6388ad0 Mon Sep 17 00:00:00 2001 From: Tim Frohlich Date: Thu, 16 Oct 2025 09:51:35 +0200 Subject: [PATCH 05/16] Unified search result format, separation of search results and data fetching, saving queries --- orchestrator/api/api_v1/endpoints/agent.py | 28 +- orchestrator/api/api_v1/endpoints/search.py | 242 +++++------------- orchestrator/db/__init__.py | 6 +- orchestrator/db/models.py | 56 +++- ...add_agent_runs_and_agent_queries_tables.py | 19 +- orchestrator/schemas/search.py | 87 +------ orchestrator/search/agent/prompts.py | 2 +- orchestrator/search/agent/state.py | 27 +- orchestrator/search/agent/tools.py | 134 +++++----- orchestrator/search/retrieval/builder.py | 32 ++- orchestrator/search/retrieval/engine.py | 22 +- orchestrator/search/retrieval/pagination.py | 85 +++--- .../search/retrieval/retrievers/base.py | 78 ++++-- .../search/retrieval/retrievers/fuzzy.py | 7 +- .../search/retrieval/retrievers/hybrid.py | 8 +- .../search/retrieval/retrievers/semantic.py | 7 +- .../search/retrieval/retrievers/structured.py | 7 +- orchestrator/search/schemas/parameters.py | 30 ++- orchestrator/search/schemas/results.py | 5 +- orchestrator/settings.py | 1 + 20 files changed, 432 insertions(+), 451 deletions(-) diff --git a/orchestrator/api/api_v1/endpoints/agent.py b/orchestrator/api/api_v1/endpoints/agent.py index 81eadc167..d71db91d3 100644 --- a/orchestrator/api/api_v1/endpoints/agent.py +++ b/orchestrator/api/api_v1/endpoints/agent.py @@ -12,15 +12,15 @@ # limitations under the License. from functools import cache -from structlog import get_logger from typing import Annotated, Any from fastapi import APIRouter, Depends, HTTPException, Request, status from pydantic_ai.ag_ui import StateDeps, handle_ag_ui_request from pydantic_ai.agent import Agent from starlette.responses import Response +from structlog import get_logger -from orchestrator.db import AgentQueryTable, db +from orchestrator.db import SearchQueryTable, db from orchestrator.llm_settings import llm_settings from orchestrator.search.agent import build_agent_instance from orchestrator.search.agent.state import SearchState @@ -53,24 +53,21 @@ async def agent_conversation( @router.get( - "/runs/{run_id}/queries/{query_id}/export", - summary="Export query results by run_id and query_id", + "/queries/{query_id}/export", + summary="Export query results by query_id", response_model=dict[str, Any], ) -async def export_by_query_id(run_id: str, query_id: str) -> dict[str, Any]: - """Export search results using run_id and query_id. +async def export_by_query_id(query_id: str) -> dict[str, Any]: + """Export search results using query_id. The query is retrieved from the database, re-executed, and results are returned as flattened records suitable for CSV download. Args: - run_id: Agent run UUID query_id: Query UUID Returns: Dictionary containing 'page' with an array of flattened entity records. - Each record contains snake_case field names from the database with nested - relationships flattened (e.g., product_name instead of product.name). 
 
     Raises:
         HTTPException: 404 if query not found, 400 if invalid data
@@ -81,30 +78,27 @@ async def export_by_query_id(run_id: str, query_id: str) -> dict[str, Any]:
 
     try:
         query_uuid = UUID(query_id)
-        run_uuid = UUID(run_id)
     except ValueError:
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
-            detail="Invalid run_id or query_id format",
+            detail="Invalid query_id format",
         )
 
-    agent_query = db.session.query(AgentQueryTable).filter_by(query_id=query_uuid, run_id=run_uuid).first()
+    agent_query = db.session.query(SearchQueryTable).filter_by(query_id=query_uuid).first()
     if not agent_query:
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
-            detail=f"Query {query_id} not found in run {run_id}",
+            detail=f"Query {query_id} not found",
        )
 
     try:
         from orchestrator.search.retrieval.pagination import PaginationParams
 
         # Get the full query state including the embedding that was used
-        query_state = agent_query.get_state()
+        query_state = agent_query.to_state()
 
         # Create pagination params with the saved embedding to ensure consistent results
-        pagination_params = PaginationParams(
-            q_vec_override=query_state.query_embedding.tolist() if query_state.query_embedding is not None else None
-        )
+        pagination_params = PaginationParams(q_vec_override=query_state.query_embedding)
 
         search_response = await execute_search_for_export(query_state.parameters, db.session, pagination_params)
         entity_ids = [res.entity_id for res in search_response.results]
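The export endpoint returns JSON ({"page": [...]}) rather than a CSV stream, so file conversion happens on the consumer side. A minimal client sketch, assuming a locally running server and substituting a real query UUID for the placeholder:

# Hypothetical consumer sketch: fetch an export and write it out as CSV.
import csv

import httpx

query_id = "00000000-0000-0000-0000-000000000000"  # placeholder; use a real query_id
resp = httpx.get(f"http://localhost:8080/api/agent/queries/{query_id}/export")
resp.raise_for_status()
records = resp.json()["page"]  # flattened records, one dict per entity

if records:
    with open("export.csv", "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(records[0].keys()))
        writer.writeheader()
        writer.writerows(records)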
diff --git a/orchestrator/api/api_v1/endpoints/search.py b/orchestrator/api/api_v1/endpoints/search.py
index 3140832c5..78503e9be 100644
--- a/orchestrator/api/api_v1/endpoints/search.py
+++ b/orchestrator/api/api_v1/endpoints/search.py
@@ -11,243 +11,131 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Literal, overload
-
 from fastapi import APIRouter, HTTPException, Query, status
-from sqlalchemy import case, select
-from sqlalchemy.orm import selectinload
 
-from orchestrator.db import (
-    ProcessTable,
-    ProductTable,
-    WorkflowTable,
-    db,
-)
-from orchestrator.domain.base import SubscriptionModel
-from orchestrator.domain.context_cache import cache_subscription_models
+from orchestrator.db import db
 from orchestrator.schemas.search import (
     PageInfoSchema,
     PathsResponse,
-    ProcessSearchResult,
-    ProcessSearchSchema,
-    ProductSearchResult,
-    ProductSearchSchema,
     SearchResultsSchema,
-    SubscriptionSearchResult,
-    WorkflowSearchResult,
-    WorkflowSearchSchema,
 )
 from orchestrator.search.core.exceptions import InvalidCursorError
 from orchestrator.search.core.types import EntityType, UIType
 from orchestrator.search.filters.definitions import generate_definitions
-from orchestrator.search.indexing.registry import ENTITY_CONFIG_REGISTRY
 from orchestrator.search.retrieval import execute_search
 from orchestrator.search.retrieval.builder import build_paths_query, create_path_autocomplete_lquery, process_path_rows
 from orchestrator.search.retrieval.pagination import (
+    PaginationParams,
     create_next_page_cursor,
     process_pagination_cursor,
 )
 from orchestrator.search.retrieval.validation import is_lquery_syntactically_valid
 from orchestrator.search.schemas.parameters import (
-    BaseSearchParameters,
     ProcessSearchParameters,
     ProductSearchParameters,
+    SearchParameters,
     SubscriptionSearchParameters,
     WorkflowSearchParameters,
 )
 from orchestrator.search.schemas.results import SearchResult, TypeDefinition
-from orchestrator.services.subscriptions import format_special_types
 
 router = APIRouter()
 
 
-def _create_search_result_item(
-    entity: WorkflowTable | ProductTable | ProcessTable, entity_type: EntityType, search_info: SearchResult
-) -> WorkflowSearchResult | ProductSearchResult | ProcessSearchResult | None:
-    match entity_type:
-        case EntityType.WORKFLOW:
-            workflow_data = WorkflowSearchSchema.model_validate(entity)
-            return WorkflowSearchResult(
-                workflow=workflow_data,
-                score=search_info.score,
-                perfect_match=search_info.perfect_match,
-                matching_field=search_info.matching_field,
-            )
-        case EntityType.PRODUCT:
-            product_data = ProductSearchSchema.model_validate(entity)
-            return ProductSearchResult(
-                product=product_data,
-                score=search_info.score,
-                perfect_match=search_info.perfect_match,
-                matching_field=search_info.matching_field,
-            )
-        case EntityType.PROCESS:
-            process_data = ProcessSearchSchema.model_validate(entity)
-            return ProcessSearchResult(
-                process=process_data,
-                score=search_info.score,
-                perfect_match=search_info.perfect_match,
-                matching_field=search_info.matching_field,
-            )
-        case _:
-            return None
-
-
-@overload
-async def _perform_search_and_fetch(
-    search_params: BaseSearchParameters,
-    entity_type: Literal[EntityType.WORKFLOW],
-    eager_loads: list[Any],
-    cursor: str | None = None,
-) -> SearchResultsSchema[WorkflowSearchResult]: ...
-
-
-@overload
 async def _perform_search_and_fetch(
-    search_params: BaseSearchParameters,
-    entity_type: Literal[EntityType.PRODUCT],
-    eager_loads: list[Any],
+    search_params: SearchParameters,
     cursor: str | None = None,
-) -> SearchResultsSchema[ProductSearchResult]: ...
-
-
-@overload
-async def _perform_search_and_fetch(
-    search_params: BaseSearchParameters,
-    entity_type: Literal[EntityType.PROCESS],
-    eager_loads: list[Any],
-    cursor: str | None = None,
-) -> SearchResultsSchema[ProcessSearchResult]: ...
+ query_id: str | None = None, +) -> SearchResultsSchema[SearchResult]: + """Execute search and return results. + + Args: + search_params: Search parameters + cursor: Pagination cursor + query_id: Optional saved query ID to use for embedding retrieval + + Returns: + Search results with entity_id, score, and matching_field. + """ + # If query_id provided, retrieve saved embedding + if query_id and not cursor: + from uuid import UUID + + from orchestrator.db import SearchQueryTable + + try: + query_uuid = UUID(query_id) + except ValueError: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Invalid query_id format: {query_id}", + ) + search_query = db.session.query(SearchQueryTable).filter_by(query_id=query_uuid).first() + if not search_query: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Query {query_id} not found", + ) -async def _perform_search_and_fetch( - search_params: BaseSearchParameters, - entity_type: EntityType, - eager_loads: list[Any], - cursor: str | None = None, -) -> SearchResultsSchema[Any]: - try: - pagination_params = await process_pagination_cursor(cursor, search_params) - except InvalidCursorError: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid pagination cursor") + query_state = search_query.to_state() + search_params = query_state.parameters + pagination_params = PaginationParams(q_vec_override=query_state.query_embedding) + else: + try: + pagination_params = await process_pagination_cursor(cursor, search_params) + except InvalidCursorError: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid pagination cursor") search_response = await execute_search(search_params, db.session, pagination_params) if not search_response.results: return SearchResultsSchema(search_metadata=search_response.metadata) - next_page_cursor = create_next_page_cursor(search_response.results, pagination_params, search_params.limit) + next_page_cursor = create_next_page_cursor( + search_response.results, pagination_params, search_params.limit, search_params + ) has_next_page = next_page_cursor is not None page_info = PageInfoSchema(has_next_page=has_next_page, next_page_cursor=next_page_cursor) - config = ENTITY_CONFIG_REGISTRY[entity_type] - entity_ids = [res.entity_id for res in search_response.results] - pk_column = getattr(config.table, config.pk_name) - ordering_case = case({entity_id: i for i, entity_id in enumerate(entity_ids)}, value=pk_column) - - stmt = select(config.table).options(*eager_loads).filter(pk_column.in_(entity_ids)).order_by(ordering_case) - entities = db.session.scalars(stmt).all() - - search_info_map = {res.entity_id: res for res in search_response.results} - data = [] - for entity in entities: - entity_id = getattr(entity, config.pk_name) - search_info = search_info_map.get(str(entity_id)) - if not search_info: - continue - - search_result_item = _create_search_result_item(entity, entity_type, search_info) - if search_result_item: - data.append(search_result_item) - - return SearchResultsSchema(data=data, page_info=page_info, search_metadata=search_response.metadata) + return SearchResultsSchema( + data=search_response.results, page_info=page_info, search_metadata=search_response.metadata + ) -@router.post( - "/subscriptions", - response_model=SearchResultsSchema[SubscriptionSearchResult], -) +@router.post("/subscriptions", response_model=SearchResultsSchema[SearchResult]) async def search_subscriptions( search_params: SubscriptionSearchParameters, cursor: str | 
None = None, -) -> SearchResultsSchema[SubscriptionSearchResult]: - try: - pagination_params = await process_pagination_cursor(cursor, search_params) - except InvalidCursorError: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid pagination cursor") + query_id: str | None = Query(None, description="Optional saved query ID for embedding retrieval"), +) -> SearchResultsSchema[SearchResult]: + return await _perform_search_and_fetch(search_params, cursor, query_id) - search_response = await execute_search(search_params, db.session, pagination_params) - if not search_response.results: - return SearchResultsSchema(search_metadata=search_response.metadata) - - next_page_cursor = create_next_page_cursor(search_response.results, pagination_params, search_params.limit) - has_next_page = next_page_cursor is not None - page_info = PageInfoSchema(has_next_page=has_next_page, next_page_cursor=next_page_cursor) - - search_info_map = {res.entity_id: res for res in search_response.results} - - with cache_subscription_models(): - subscriptions_data = { - sub_id: SubscriptionModel.from_subscription(sub_id).model_dump(exclude_unset=False) - for sub_id in search_info_map - } - - results_data = [ - SubscriptionSearchResult( - subscription=format_special_types(subscriptions_data[sub_id]), - score=search_info.score, - perfect_match=search_info.perfect_match, - matching_field=search_info.matching_field, - ) - for sub_id, search_info in search_info_map.items() - ] - - return SearchResultsSchema(data=results_data, page_info=page_info, search_metadata=search_response.metadata) - - -@router.post("/workflows", response_model=SearchResultsSchema[WorkflowSearchResult]) +@router.post("/workflows", response_model=SearchResultsSchema[SearchResult]) async def search_workflows( search_params: WorkflowSearchParameters, cursor: str | None = None, -) -> SearchResultsSchema[WorkflowSearchResult]: - return await _perform_search_and_fetch( - search_params=search_params, - entity_type=EntityType.WORKFLOW, - eager_loads=[selectinload(WorkflowTable.products)], - cursor=cursor, - ) + query_id: str | None = Query(None, description="Optional saved query ID for embedding retrieval"), +) -> SearchResultsSchema[SearchResult]: + return await _perform_search_and_fetch(search_params, cursor, query_id) -@router.post("/products", response_model=SearchResultsSchema[ProductSearchResult]) +@router.post("/products", response_model=SearchResultsSchema[SearchResult]) async def search_products( search_params: ProductSearchParameters, cursor: str | None = None, -) -> SearchResultsSchema[ProductSearchResult]: - return await _perform_search_and_fetch( - search_params=search_params, - entity_type=EntityType.PRODUCT, - eager_loads=[ - selectinload(ProductTable.workflows), - selectinload(ProductTable.fixed_inputs), - selectinload(ProductTable.product_blocks), - ], - cursor=cursor, - ) + query_id: str | None = Query(None, description="Optional saved query ID for embedding retrieval"), +) -> SearchResultsSchema[SearchResult]: + return await _perform_search_and_fetch(search_params, cursor, query_id) -@router.post("/processes", response_model=SearchResultsSchema[ProcessSearchResult]) +@router.post("/processes", response_model=SearchResultsSchema[SearchResult]) async def search_processes( search_params: ProcessSearchParameters, cursor: str | None = None, -) -> SearchResultsSchema[ProcessSearchResult]: - return await _perform_search_and_fetch( - search_params=search_params, - entity_type=EntityType.PROCESS, - eager_loads=[ - 
selectinload(ProcessTable.workflow), - ], - cursor=cursor, - ) + query_id: str | None = Query(None, description="Optional saved query ID for embedding retrieval"), +) -> SearchResultsSchema[SearchResult]: + return await _perform_search_and_fetch(search_params, cursor, query_id) @router.get( diff --git a/orchestrator/db/__init__.py b/orchestrator/db/__init__.py index e304c7003..81befbc17 100644 --- a/orchestrator/db/__init__.py +++ b/orchestrator/db/__init__.py @@ -17,7 +17,6 @@ from orchestrator.db.database import BaseModel as DbBaseModel from orchestrator.db.database import Database, transactional from orchestrator.db.models import ( # noqa: F401 - AgentQueryTable, AgentRunTable, EngineSettingsTable, FixedInputTable, @@ -28,6 +27,7 @@ ProductBlockTable, ProductTable, ResourceTypeTable, + SearchQueryTable, SubscriptionCustomerDescriptionTable, SubscriptionInstanceRelationTable, SubscriptionInstanceTable, @@ -76,7 +76,7 @@ def init_database(settings: AppSettings) -> Database: __all__ = [ "transactional", - "AgentQueryTable", + "SearchQueryTable", "AgentRunTable", "SubscriptionTable", "ProcessSubscriptionTable", @@ -101,7 +101,7 @@ def init_database(settings: AppSettings) -> Database: ] ALL_DB_MODELS: list[type[DbBaseModel]] = [ - AgentQueryTable, + SearchQueryTable, AgentRunTable, FixedInputTable, ProcessStepTable, diff --git a/orchestrator/db/models.py b/orchestrator/db/models.py index 4ac2c2a80..bde0cc6cc 100644 --- a/orchestrator/db/models.py +++ b/orchestrator/db/models.py @@ -60,7 +60,7 @@ from orchestrator.version import GIT_COMMIT_HASH if TYPE_CHECKING: - from orchestrator.search.schemas.parameters import AgentQueryState + from orchestrator.search.schemas.parameters import SearchQueryState logger = structlog.get_logger(__name__) @@ -687,19 +687,23 @@ class AgentRunTable(BaseModel): agent_type = mapped_column(String(50), nullable=False) created_at = mapped_column(UtcTimestamp, server_default=text("current_timestamp()"), nullable=False) - queries = relationship("AgentQueryTable", back_populates="run", cascade="delete", passive_deletes=True) + queries = relationship("SearchQueryTable", back_populates="run", cascade="delete", passive_deletes=True) __table_args__ = (Index("ix_agent_runs_created_at", "created_at"),) -class AgentQueryTable(BaseModel): - """Individual query execution within an agent run.""" +class SearchQueryTable(BaseModel): + """Search query execution - used by both agent runs and regular API searches. 
- __tablename__ = "agent_queries" + When run_id is NULL: standalone API search query + When run_id is NOT NULL: query belongs to an agent conversation run + """ + + __tablename__ = "search_queries" query_id = mapped_column("query_id", UUIDType, server_default=text("uuid_generate_v4()"), primary_key=True) run_id = mapped_column( - "run_id", UUIDType, ForeignKey("agent_runs.run_id", ondelete="CASCADE"), nullable=False, index=True + "run_id", UUIDType, ForeignKey("agent_runs.run_id", ondelete="CASCADE"), nullable=True, index=True ) query_number = mapped_column(Integer, nullable=False) @@ -714,21 +718,45 @@ class AgentQueryTable(BaseModel): run = relationship("AgentRunTable", back_populates="queries") __table_args__ = ( - Index("ix_agent_queries_run_id", "run_id"), - Index("ix_agent_queries_executed_at", "executed_at"), - UniqueConstraint("run_id", "query_number", name="uq_run_query_number"), + Index("ix_search_queries_run_id", "run_id"), + Index("ix_search_queries_executed_at", "executed_at"), + Index("ix_search_queries_query_id", "query_id"), ) - def get_state(self) -> "AgentQueryState": - """Reconstruct complete query state including parameters and embedding. + @classmethod + def from_state( + cls, + state: "SearchQueryState", + run_id: "UUID | None" = None, + query_number: int = 1, + ) -> "SearchQueryTable": + """Create a SearchQueryTable instance from a SearchQueryState. + + Args: + state: The search query state with parameters and embedding + run_id: Optional agent run ID (NULL for regular API searches) + query_number: Query number within the run (default=1) + + Returns: + SearchQueryTable instance ready to be added to the database. + """ + return cls( + run_id=run_id, + query_number=query_number, + parameters=state.parameters.model_dump(), + query_embedding=state.query_embedding, + ) + + def to_state(self) -> "SearchQueryState": + """Convert database model to SearchQueryState. Returns: - AgentQueryState with typed parameters and embedding vector. + SearchQueryState with typed parameters and embedding vector. """ - from orchestrator.search.schemas.parameters import AgentQueryState + from orchestrator.search.schemas.parameters import SearchQueryState - return AgentQueryState.model_validate(self) + return SearchQueryState.model_validate(self) class EngineSettingsTable(BaseModel): diff --git a/orchestrator/migrations/versions/schema/2025-10-09_459f352f5aa6_add_agent_runs_and_agent_queries_tables.py b/orchestrator/migrations/versions/schema/2025-10-09_459f352f5aa6_add_agent_runs_and_agent_queries_tables.py index 39c7214f4..000cee1c3 100644 --- a/orchestrator/migrations/versions/schema/2025-10-09_459f352f5aa6_add_agent_runs_and_agent_queries_tables.py +++ b/orchestrator/migrations/versions/schema/2025-10-09_459f352f5aa6_add_agent_runs_and_agent_queries_tables.py @@ -1,4 +1,4 @@ -"""Add agent_runs and agent_queries tables. +"""Add agent_runs and search_queries tables. 
Revision ID: 459f352f5aa6 Revises: 850dccac3b02 @@ -32,9 +32,9 @@ def upgrade() -> None: op.create_index("ix_agent_runs_created_at", "agent_runs", ["created_at"]) op.create_table( - "agent_queries", + "search_queries", sa.Column("query_id", UUIDType(), server_default=sa.text("uuid_generate_v4()"), nullable=False), - sa.Column("run_id", UUIDType(), nullable=False), + sa.Column("run_id", UUIDType(), nullable=True), sa.Column("query_number", sa.Integer(), nullable=False), sa.Column("parameters", postgresql.JSONB(astext_type=sa.Text()), nullable=False), sa.Column("query_embedding", Vector(1536), nullable=True), @@ -43,16 +43,17 @@ def upgrade() -> None: ), sa.ForeignKeyConstraint(["run_id"], ["agent_runs.run_id"], ondelete="CASCADE"), sa.PrimaryKeyConstraint("query_id"), - sa.UniqueConstraint("run_id", "query_number", name="uq_run_query_number"), ) - op.create_index("ix_agent_queries_run_id", "agent_queries", ["run_id"]) - op.create_index("ix_agent_queries_executed_at", "agent_queries", ["executed_at"]) + op.create_index("ix_search_queries_run_id", "search_queries", ["run_id"]) + op.create_index("ix_search_queries_executed_at", "search_queries", ["executed_at"]) + op.create_index("ix_search_queries_query_id", "search_queries", ["query_id"]) def downgrade() -> None: - op.drop_index("ix_agent_queries_executed_at", table_name="agent_queries") - op.drop_index("ix_agent_queries_run_id", table_name="agent_queries") - op.drop_table("agent_queries") + op.drop_index("ix_search_queries_query_id", table_name="search_queries") + op.drop_index("ix_search_queries_executed_at", table_name="search_queries") + op.drop_index("ix_search_queries_run_id", table_name="search_queries") + op.drop_table("search_queries") op.drop_index("ix_agent_runs_created_at", table_name="agent_runs") op.drop_table("agent_runs") diff --git a/orchestrator/schemas/search.py b/orchestrator/schemas/search.py index d85639132..77f4263c1 100644 --- a/orchestrator/schemas/search.py +++ b/orchestrator/schemas/search.py @@ -11,14 +11,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from datetime import datetime -from typing import Any, Generic, TypeVar -from uuid import UUID +from typing import Generic, TypeVar from pydantic import BaseModel, ConfigDict, Field from orchestrator.search.core.types import SearchMetadata -from orchestrator.search.schemas.results import ComponentInfo, LeafInfo, MatchingField +from orchestrator.search.schemas.results import ComponentInfo, LeafInfo T = TypeVar("T") @@ -36,93 +34,12 @@ class ProductSchema(BaseModel): product_type: str -class SubscriptionSearchResult(BaseModel): - score: float - perfect_match: int - matching_field: MatchingField | None = None - subscription: dict[str, Any] - - class SearchResultsSchema(BaseModel, Generic[T]): data: list[T] = Field(default_factory=list) page_info: PageInfoSchema = Field(default_factory=PageInfoSchema) search_metadata: SearchMetadata | None = None -class WorkflowProductSchema(BaseModel): - """Product associated with a workflow.""" - - model_config = ConfigDict(from_attributes=True) - - product_type: str - product_id: UUID - name: str - - -class WorkflowSearchSchema(BaseModel): - """Schema for workflow search results.""" - - model_config = ConfigDict(from_attributes=True) - - name: str - products: list[WorkflowProductSchema] - description: str | None = None - created_at: datetime | None = None - - -class ProductSearchSchema(BaseModel): - """Schema for product search results.""" - - model_config = ConfigDict(from_attributes=True) - - product_id: UUID - name: str - product_type: str - tag: str | None = None - description: str | None = None - status: str | None = None - created_at: datetime | None = None - - -class ProcessSearchSchema(BaseModel): - """Schema for process search results.""" - - model_config = ConfigDict(from_attributes=True) - - process_id: UUID - workflow_name: str - workflow_id: UUID - last_status: str - is_task: bool - created_by: str | None = None - started_at: datetime - last_modified_at: datetime - last_step: str | None = None - failed_reason: str | None = None - subscription_ids: list[UUID] | None = None - - -class WorkflowSearchResult(BaseModel): - score: float - perfect_match: int - matching_field: MatchingField | None = None - workflow: WorkflowSearchSchema - - -class ProductSearchResult(BaseModel): - score: float - perfect_match: int - matching_field: MatchingField | None = None - product: ProductSearchSchema - - -class ProcessSearchResult(BaseModel): - score: float - perfect_match: int - matching_field: MatchingField | None = None - process: ProcessSearchSchema - - class PathsResponse(BaseModel): leaves: list[LeafInfo] components: list[ComponentInfo] diff --git a/orchestrator/search/agent/prompts.py b/orchestrator/search/agent/prompts.py index 8e2a2e869..4c9ff8b6e 100644 --- a/orchestrator/search/agent/prompts.py +++ b/orchestrator/search/agent/prompts.py @@ -79,7 +79,7 @@ async def get_dynamic_instructions(ctx: RunContext[StateDeps[SearchState]]) -> s """Dynamically provides 'next step' coaching based on the current state.""" state = ctx.deps.state param_state_str = json.dumps(state.parameters, indent=2, default=str) if state.parameters else "Not set." 
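+    # results_data carries a summary of the last search (count, query_id, result URLs) rather than raw rows.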
- results_count = len(state.results) if state.results else 0 + results_count = state.results_data.total_count if state.results_data else 0 next_step_guidance = "" if not state.parameters or not state.parameters.get("entity_type"): diff --git a/orchestrator/search/agent/state.py b/orchestrator/search/agent/state.py index 4e30a2fe2..027fd84fa 100644 --- a/orchestrator/search/agent/state.py +++ b/orchestrator/search/agent/state.py @@ -14,12 +14,31 @@ from typing import Any from uuid import UUID -from pydantic import BaseModel, Field +from pydantic import BaseModel + + +class ExportData(BaseModel): + """Export metadata for download.""" + + action: str = "export" + query_id: str + download_url: str + message: str + + +class SearchResultsData(BaseModel): + """Search results metadata for frontend display.""" + + action: str = "view_results" + query_id: str + results_url: str + total_count: int + message: str class SearchState(BaseModel): run_id: UUID | None = None + query_id: UUID | None = None parameters: dict[str, Any] | None = None - results: list[dict[str, Any]] = Field(default_factory=list) - export_data: dict[str, Any] | None = None - base_url: str | None = None + results_data: SearchResultsData | None = None + export_data: ExportData | None = None diff --git a/orchestrator/search/agent/tools.py b/orchestrator/search/agent/tools.py index aa9a3b5be..44c580224 100644 --- a/orchestrator/search/agent/tools.py +++ b/orchestrator/search/agent/tools.py @@ -11,8 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Awaitable, Callable -from typing import Any, TypeVar +from typing import Any import structlog from ag_ui.core import EventType, StateSnapshotEvent @@ -25,34 +24,18 @@ from orchestrator.api.api_v1.endpoints.search import ( get_definitions, list_paths, - search_processes, - search_products, - search_subscriptions, - search_workflows, ) -from orchestrator.schemas.search import SearchResultsSchema +from orchestrator.db import AgentRunTable, SearchQueryTable, db +from orchestrator.search.agent.state import ExportData, SearchResultsData, SearchState from orchestrator.search.core.types import ActionType, EntityType, FilterOp from orchestrator.search.filters import FilterTree from orchestrator.search.retrieval.exceptions import FilterValidationError, PathNotFoundError from orchestrator.search.retrieval.validation import validate_filter_tree -from orchestrator.search.schemas.parameters import BaseSearchParameters - -from .state import SearchState +from orchestrator.search.schemas.parameters import BaseSearchParameters, SearchQueryState +from orchestrator.settings import app_settings logger = structlog.get_logger(__name__) - -P = TypeVar("P", bound=BaseSearchParameters) - -SearchFn = Callable[[P], Awaitable[SearchResultsSchema[Any]]] - -SEARCH_FN_MAP: dict[EntityType, SearchFn] = { - EntityType.SUBSCRIPTION: search_subscriptions, - EntityType.WORKFLOW: search_workflows, - EntityType.PRODUCT: search_products, - EntityType.PROCESS: search_processes, -} - search_toolset: FunctionToolset[StateDeps[SearchState]] = FunctionToolset(max_retries=1) @@ -89,7 +72,7 @@ async def set_search_parameters( ) ctx.deps.state.parameters = {"action": action, "entity_type": entity_type, "filters": None, "query": final_query} - ctx.deps.state.results = [] + ctx.deps.state.results_data = None logger.debug("Search parameters set", parameters=ctx.deps.state.parameters) return StateSnapshotEvent( @@ -145,7 +128,7 @@ async def 
execute_search( ctx: RunContext[StateDeps[SearchState]], limit: int = 10, ) -> StateSnapshotEvent: - """Execute the search with the current parameters.""" + """Execute the search with the current parameters and save to database.""" if not ctx.deps.state.parameters: raise ValueError("No search parameters set") @@ -164,15 +147,46 @@ async def execute_search( params.limit = limit - fn = SEARCH_FN_MAP[params.entity_type] - search_results = await fn(params) + if not ctx.deps.state.run_id: + agent_run = AgentRunTable(agent_type="search") + db.session.add(agent_run) + db.session.commit() + ctx.deps.state.run_id = agent_run.run_id + logger.debug("Created new agent run", run_id=str(agent_run.run_id)) + + # Get query with embedding and save to DB + from orchestrator.search.retrieval.engine import execute_search + + search_response = await execute_search(params, db.session) + query_embedding = search_response.query_embedding + query_state = SearchQueryState(parameters=params, query_embedding=query_embedding) + query_number = db.session.query(SearchQueryTable).filter_by(run_id=ctx.deps.state.run_id).count() + 1 + search_query = SearchQueryTable.from_state( + state=query_state, + run_id=ctx.deps.state.run_id, + query_number=query_number, + ) + db.session.add(search_query) + db.session.commit() + ctx.deps.state.query_id = search_query.query_id + logger.debug("Saved search query", query_id=str(search_query.query_id), query_number=query_number) logger.debug( "Search completed", - total_results=len(search_results.data) if search_results.data else 0, + total_results=len(search_response.results), ) - ctx.deps.state.results = search_results.data + # Store results metadata for frontend display + # Frontend may call search endpoints with query_id parameter + entity_type = params.entity_type.value + results_url = f"{app_settings.BASE_URL}/api/search/{entity_type}s?query_id={ctx.deps.state.query_id}" + + ctx.deps.state.results_data = SearchResultsData( + query_id=str(ctx.deps.state.query_id), + results_url=results_url, + total_count=len(search_response.results), + message=f"Found {len(search_response.results)} results.", + ) return StateSnapshotEvent(type=EventType.STATE_SNAPSHOT, snapshot=ctx.deps.state.model_dump()) @@ -258,9 +272,12 @@ async def prepare_export( ctx: RunContext[StateDeps[SearchState]], max_results: int = 1000, ) -> StateSnapshotEvent: - """Saves the current search query to the database and returns run_id/query_id for export.""" + """Prepares export URL using the last executed search query.""" + if not ctx.deps.state.query_id or not ctx.deps.state.run_id: + raise ValueError("No search has been executed yet. Run a search first before exporting.") + if not ctx.deps.state.parameters: - raise ValueError("No search parameters set. Run a search first to see what will be exported.") + raise ValueError("No search parameters found. Run a search first before exporting.") # Validate that export is only available for SELECT actions action = ctx.deps.state.parameters.get("action", ActionType.SELECT) @@ -270,54 +287,33 @@ async def prepare_export( "Please run a SELECT search first." 
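+            # COUNT and AGGREGATE return computed summaries, so there are no per-entity rows to export.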
) - from orchestrator.db import AgentQueryTable, AgentRunTable, db - - # Ensure we have a run_id - if not ctx.deps.state.run_id: - # Create a new agent run - agent_run = AgentRunTable(agent_type="search") - db.session.add(agent_run) - db.session.commit() - db.session.refresh(agent_run) - ctx.deps.state.run_id = agent_run.run_id - logger.debug("Created new agent run", run_id=str(agent_run.run_id)) + # Retrieve the saved query to update export_limit if needed + agent_query = db.session.query(SearchQueryTable).filter_by(query_id=ctx.deps.state.query_id).first() + if not agent_query: + raise ValueError("Query not found in database") - query_number = db.session.query(AgentQueryTable).filter_by(run_id=ctx.deps.state.run_id).count() + 1 + export_limit = min(max_results, BaseSearchParameters.DEFAULT_EXPORT_LIMIT) - export_limit = min(max_results, BaseSearchParameters.export_limit) - params_dict = ctx.deps.state.parameters.copy() + # Update the parameters with export_limit + params_dict = agent_query.parameters.copy() params_dict["export_limit"] = export_limit - - agent_query = AgentQueryTable( - run_id=ctx.deps.state.run_id, - query_number=query_number, - parameters=params_dict, - query_embedding=None, # TODO: We need to save the embeddding here. - ) - db.session.add(agent_query) + agent_query.parameters = params_dict db.session.commit() - db.session.refresh(agent_query) logger.debug( - "Saved query for export", - run_id=str(ctx.deps.state.run_id), - query_id=str(agent_query.query_id), - query_number=query_number, + "Prepared query for export", + query_id=str(ctx.deps.state.query_id), + export_limit=export_limit, ) - # Build export URL using run_id and query_id - base_url = ctx.deps.state.base_url or "http://localhost:8080" - download_url = f"{base_url}/api/agent/runs/{ctx.deps.state.run_id}/queries/{agent_query.query_id}/export" + download_url = f"{app_settings.BASE_URL}/api/agent/queries/{ctx.deps.state.query_id}/export" - # Update state with export data so frontend can render the download button - ctx.deps.state.export_data = { - "action": "export", - "run_id": str(ctx.deps.state.run_id), - "query_id": str(agent_query.query_id), - "download_url": download_url, - "message": f"Export ready for download (up to {export_limit} results).", - } + ctx.deps.state.export_data = ExportData( + query_id=str(ctx.deps.state.query_id), + download_url=download_url, + message=f"Export ready for download (up to {export_limit} results).", + ) - logger.debug("Export data set in state", export_data=ctx.deps.state.export_data) + logger.debug("Export data set in state", export_data=ctx.deps.state.export_data.model_dump()) return StateSnapshotEvent(type=EventType.STATE_SNAPSHOT, snapshot=ctx.deps.state.model_dump()) diff --git a/orchestrator/search/retrieval/builder.py b/orchestrator/search/retrieval/builder.py index 087687561..446f3409f 100644 --- a/orchestrator/search/retrieval/builder.py +++ b/orchestrator/search/retrieval/builder.py @@ -16,6 +16,7 @@ from sqlalchemy import Select, String, cast, func, select from sqlalchemy.engine import Row +from sqlalchemy_utils import Ltree from orchestrator.db.models import AiSearchIndex from orchestrator.search.core.types import EntityType, FieldType, FilterOp, UIType @@ -43,7 +44,36 @@ def build_candidate_query(params: BaseSearchParameters) -> Select: Select: The SQLAlchemy `Select` object representing the query. 
""" - stmt = select(AiSearchIndex.entity_id).where(AiSearchIndex.entity_type == params.entity_type.value).distinct() + # Define title paths based on entity type + title_path_map = { + EntityType.SUBSCRIPTION: "subscription.description", + EntityType.PRODUCT: "product.description", + EntityType.WORKFLOW: "workflow.description", + EntityType.PROCESS: "process.workflowName", + } + + title_path = title_path_map.get(params.entity_type) + + # Subquery to get title value for each entity + title_subquery = ( + select( + AiSearchIndex.entity_id.label("title_entity_id"), + AiSearchIndex.value.label("entity_title"), + ) + .where( + AiSearchIndex.entity_type == params.entity_type.value, + AiSearchIndex.path == Ltree(title_path), + ) + .subquery() + ) + + stmt = ( + select(AiSearchIndex.entity_id, title_subquery.c.entity_title) + .select_from(AiSearchIndex) + .outerjoin(title_subquery, AiSearchIndex.entity_id == title_subquery.c.title_entity_id) + .where(AiSearchIndex.entity_type == params.entity_type.value) + .distinct() + ) if params.filters is not None: entity_id_col = AiSearchIndex.entity_id diff --git a/orchestrator/search/retrieval/engine.py b/orchestrator/search/retrieval/engine.py index 374c12964..9d8b7f48c 100644 --- a/orchestrator/search/retrieval/engine.py +++ b/orchestrator/search/retrieval/engine.py @@ -74,9 +74,15 @@ def _format_response( # Structured search (filter-only) matching_field = _extract_matching_field_from_filters(search_params.filters) + entity_title = row.get("entity_title", "") + if not isinstance(entity_title, str): + entity_title = str(entity_title) if entity_title is not None else "" + results.append( SearchResult( entity_id=str(row.entity_id), + entity_type=search_params.entity_type, + title=entity_title, score=row.score, perfect_match=row.get("perfect_match", 0), matching_field=matching_field, @@ -125,7 +131,7 @@ async def _execute_search_internal( pagination_params: Optional pagination parameters. Returns: - SearchResponse with results, or empty response if no search criteria provided. + SearchResponse with results and embedding (for internal use). 
""" if not search_params.vector_query and not search_params.filters and not search_params.fuzzy_term: logger.warning("No search criteria provided (vector_query, fuzzy_term, or filters).") @@ -134,6 +140,15 @@ async def _execute_search_internal( candidate_query = build_candidate_query(search_params) pagination_params = pagination_params or PaginationParams() + + if search_params.vector_query and not pagination_params.q_vec_override: + from orchestrator.search.core.embedding import QueryEmbedder + + q_vec = await QueryEmbedder.generate_for_text_async(search_params.vector_query) + if q_vec: + pagination_params.q_vec_override = q_vec + logger.debug("Generated embedding for vector query") + retriever = await Retriever.from_params(search_params, pagination_params) logger.debug("Using retriever", retriever_type=retriever.__class__.__name__) @@ -142,7 +157,10 @@ async def _execute_search_internal( logger.debug(final_stmt) result = db_session.execute(final_stmt).mappings().all() - return _format_response(result, search_params, retriever.metadata) + response = _format_response(result, search_params, retriever.metadata) + # Store embedding in response for agent to save to DB + response.query_embedding = pagination_params.q_vec_override + return response async def execute_search( diff --git a/orchestrator/search/retrieval/pagination.py b/orchestrator/search/retrieval/pagination.py index a72b31aec..be0b46bec 100644 --- a/orchestrator/search/retrieval/pagination.py +++ b/orchestrator/search/retrieval/pagination.py @@ -11,14 +11,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import array import base64 from dataclasses import dataclass from pydantic import BaseModel +from orchestrator.db import SearchQueryTable, db from orchestrator.search.core.exceptions import InvalidCursorError -from orchestrator.search.schemas.parameters import BaseSearchParameters +from orchestrator.search.schemas.parameters import SearchParameters, SearchQueryState from orchestrator.search.schemas.results import SearchResult @@ -29,24 +29,13 @@ class PaginationParams: page_after_score: float | None = None page_after_id: str | None = None q_vec_override: list[float] | None = None - - -def floats_to_b64(v: list[float]) -> str: - a = array.array("f", v) - return base64.urlsafe_b64encode(a.tobytes()).decode("ascii") - - -def b64_to_floats(s: str) -> list[float]: - raw = base64.urlsafe_b64decode(s.encode("ascii")) - a = array.array("f") - a.frombytes(raw) - return list(a) + query_id: str | None = None class PageCursor(BaseModel): score: float id: str - q_vec_b64: str + query_id: str | None = None def encode(self) -> str: """Encode the cursor data into a URL-safe Base64 string.""" @@ -63,34 +52,66 @@ def decode(cls, cursor: str) -> "PageCursor": raise InvalidCursorError("Invalid pagination cursor") from e -async def process_pagination_cursor(cursor: str | None, search_params: BaseSearchParameters) -> PaginationParams: +async def process_pagination_cursor(cursor: str | None, search_params: SearchParameters) -> PaginationParams: """Process pagination cursor and return pagination parameters.""" if cursor: c = PageCursor.decode(cursor) + + # If cursor has query_id, retrieve saved embedding + if c.query_id: + query = db.session.query(SearchQueryTable).filter_by(query_id=c.query_id).first() + if not query: + raise InvalidCursorError("Query not found") + + query_state = query.to_state() + + return PaginationParams( + page_after_score=c.score, + page_after_id=c.id, + 
q_vec_override=query_state.query_embedding, + query_id=c.query_id, + ) + + # No query_id - filter-only or fuzzy-only search return PaginationParams( page_after_score=c.score, page_after_id=c.id, - q_vec_override=b64_to_floats(c.q_vec_b64), ) - if search_params.vector_query: - from orchestrator.search.core.embedding import QueryEmbedder - q_vec_override = await QueryEmbedder.generate_for_text_async(search_params.vector_query) - return PaginationParams(q_vec_override=q_vec_override) + # First page, no embedding needed + # Engine will generate it return PaginationParams() def create_next_page_cursor( - search_results: list[SearchResult], pagination_params: PaginationParams, limit: int + search_results: list[SearchResult], + pagination_params: PaginationParams, + limit: int, + search_params: SearchParameters | None = None, ) -> str | None: - """Create next page cursor if there are more results.""" + """Create next page cursor if there are more results. + + On first page with hybrid search (embedding present), saves the query to database + and includes query_id in cursor for subsequent pages. + """ has_next_page = len(search_results) == limit and limit > 0 - if has_next_page: - last_item = search_results[-1] - cursor_data = PageCursor( - score=float(last_item.score), - id=last_item.entity_id, - q_vec_b64=floats_to_b64(pagination_params.q_vec_override or []), - ) - return cursor_data.encode() - return None + if not has_next_page: + return None + + # If this is the first page and we have an embedding, save to database + if not pagination_params.query_id and pagination_params.q_vec_override and search_params: + # Create query state and save to database + query_state = SearchQueryState(parameters=search_params, query_embedding=pagination_params.q_vec_override) + search_query = SearchQueryTable.from_state(state=query_state) + + db.session.add(search_query) + db.session.commit() + pagination_params.query_id = str(search_query.query_id) + + last_item = search_results[-1] + cursor_data = PageCursor( + score=float(last_item.score), + id=last_item.entity_id, + query_id=pagination_params.query_id, + ) + return cursor_data.encode() diff --git a/orchestrator/search/retrieval/retrievers/base.py b/orchestrator/search/retrieval/retrievers/base.py index 10242fa7e..8c79c4190 100644 --- a/orchestrator/search/retrieval/retrievers/base.py +++ b/orchestrator/search/retrieval/retrievers/base.py @@ -15,9 +15,11 @@ from decimal import Decimal import structlog -from sqlalchemy import BindParameter, Numeric, Select, literal +from sqlalchemy import BindParameter, Numeric, Select, literal, select +from sqlalchemy_utils import Ltree -from orchestrator.search.core.types import FieldType, SearchMetadata +from orchestrator.db.models import AiSearchIndex +from orchestrator.search.core.types import EntityType, FieldType, SearchMetadata from orchestrator.search.schemas.parameters import BaseSearchParameters from ..pagination import PaginationParams @@ -70,33 +72,28 @@ async def from_params( fallback_fuzzy_term = params.query if q_vec is not None and fallback_fuzzy_term is not None: - return RrfHybridRetriever(q_vec, fallback_fuzzy_term, pagination_params) + return RrfHybridRetriever(q_vec, fallback_fuzzy_term, pagination_params, params.entity_type) if q_vec is not None: - return SemanticRetriever(q_vec, pagination_params) + return SemanticRetriever(q_vec, pagination_params, params.entity_type) if fallback_fuzzy_term is not None: - return FuzzyRetriever(fallback_fuzzy_term, pagination_params) + return 
FuzzyRetriever(fallback_fuzzy_term, pagination_params, params.entity_type) - return StructuredRetriever(pagination_params) + return StructuredRetriever(pagination_params, params.entity_type) @classmethod async def _get_query_vector( cls, vector_query: str | None, q_vec_override: list[float] | None ) -> list[float] | None: - """Get query vector either from override or by generating from text.""" + """Get query vector from override (provided by engine.py).""" if q_vec_override: return q_vec_override - if not vector_query: - return None + if vector_query: + logger.warning( + "vector_query present but no q_vec_override provided - embedding should be generated in engine.py" + ) - from orchestrator.search.core.embedding import QueryEmbedder - - q_vec = await QueryEmbedder.generate_for_text_async(vector_query) - if not q_vec: - logger.warning("Embedding generation failed; using non-semantic retriever") - return None - - return q_vec + return None @abstractmethod def apply(self, candidate_query: Select) -> Select: @@ -116,6 +113,53 @@ def _quantize_score_for_pagination(self, score_value: float) -> BindParameter[De pas_dec = Decimal(str(score_value)).quantize(quantizer) return literal(pas_dec, type_=self.SCORE_NUMERIC_TYPE) + @staticmethod + def add_title_to_query(stmt: Select, entity_type: EntityType) -> Select: + """Add title column to a query by joining with the index table.""" + # Define title paths based on entity type + title_path_map = { + EntityType.SUBSCRIPTION: "subscription.description", + EntityType.PRODUCT: "product.description", + EntityType.WORKFLOW: "workflow.description", + EntityType.PROCESS: "process.workflowName", + } + + title_path = title_path_map.get(entity_type) + if not title_path: + # If no title path defined, return original statement + return stmt + + # Create subquery from the original statement + ranked = stmt.subquery("ranked") + + # Subquery to get title value for each entity + # Use a distinct label to avoid any column name conflicts + title_subquery = ( + select( + AiSearchIndex.entity_id.label("title_entity_id"), + AiSearchIndex.value.label("entity_title"), + ) + .where( + AiSearchIndex.entity_type == entity_type.value, + AiSearchIndex.path == Ltree(title_path), + ) + .subquery("titles") + ) + + # Build explicit column list to preserve order + columns = [ranked.c.entity_id, title_subquery.c.entity_title] + + # Add remaining columns in their original order + for col in ranked.c: + if col.name != "entity_id": + columns.append(col) + + return ( + select(*columns) + .select_from(ranked) + .outerjoin(title_subquery, ranked.c.entity_id == title_subquery.c.title_entity_id) + ) + @property @abstractmethod def metadata(self) -> SearchMetadata: diff --git a/orchestrator/search/retrieval/retrievers/fuzzy.py b/orchestrator/search/retrieval/retrievers/fuzzy.py index 7003b5b0f..df6404a43 100644 --- a/orchestrator/search/retrieval/retrievers/fuzzy.py +++ b/orchestrator/search/retrieval/retrievers/fuzzy.py @@ -15,7 +15,7 @@ from sqlalchemy.sql.expression import ColumnElement from orchestrator.db.models import AiSearchIndex -from orchestrator.search.core.types import SearchMetadata +from orchestrator.search.core.types import EntityType, SearchMetadata from ..pagination import PaginationParams from .base import Retriever @@ -24,10 +24,11 @@ class FuzzyRetriever(Retriever): """Ranks results based on the max of fuzzy text similarity scores.""" - def __init__(self, fuzzy_term: str, pagination_params: PaginationParams) -> None: + def __init__(self, fuzzy_term: str, pagination_params: 
PaginationParams, entity_type: EntityType) -> None: self.fuzzy_term = fuzzy_term self.page_after_score = pagination_params.page_after_score self.page_after_id = pagination_params.page_after_id + self.entity_type = entity_type def apply(self, candidate_query: Select) -> Select: cand = candidate_query.subquery() @@ -42,6 +43,7 @@ def apply(self, candidate_query: Select) -> Select: combined_query = ( select( AiSearchIndex.entity_id, + cand.c.entity_title, score, func.first_value(AiSearchIndex.value) .over(partition_by=AiSearchIndex.entity_id, order_by=[similarity_expr.desc(), AiSearchIndex.path.asc()]) @@ -64,6 +66,7 @@ def apply(self, candidate_query: Select) -> Select: stmt = select( final_query.c.entity_id, + final_query.c.entity_title, final_query.c.score, final_query.c.highlight_text, final_query.c.highlight_path, diff --git a/orchestrator/search/retrieval/retrievers/hybrid.py b/orchestrator/search/retrieval/retrievers/hybrid.py index be91312f1..c44cf0656 100644 --- a/orchestrator/search/retrieval/retrievers/hybrid.py +++ b/orchestrator/search/retrieval/retrievers/hybrid.py @@ -18,7 +18,7 @@ from sqlalchemy.types import TypeEngine from orchestrator.db.models import AiSearchIndex -from orchestrator.search.core.types import SearchMetadata +from orchestrator.search.core.types import EntityType, SearchMetadata from ..pagination import PaginationParams from .base import Retriever @@ -128,6 +128,7 @@ def __init__( q_vec: list[float], fuzzy_term: str, pagination_params: PaginationParams, + entity_type: "EntityType", k: int = 60, field_candidates_limit: int = 100, ) -> None: @@ -135,6 +136,7 @@ def __init__( self.fuzzy_term = fuzzy_term self.page_after_score = pagination_params.page_after_score self.page_after_id = pagination_params.page_after_id + self.entity_type = entity_type self.k = k self.field_candidates_limit = field_candidates_limit @@ -154,6 +156,7 @@ def apply(self, candidate_query: Select) -> Select: field_candidates = ( select( AiSearchIndex.entity_id, + cand.c.entity_title, AiSearchIndex.path, AiSearchIndex.value, sem_val, @@ -178,6 +181,7 @@ def apply(self, candidate_query: Select) -> Select: entity_scores = ( select( field_candidates.c.entity_id, + func.max(field_candidates.c.entity_title).label("entity_title"), func.avg(field_candidates.c.semantic_distance).label("avg_semantic_distance"), func.avg(field_candidates.c.fuzzy_score).label("avg_fuzzy_score"), ).group_by(field_candidates.c.entity_id) @@ -204,6 +208,7 @@ def apply(self, candidate_query: Select) -> Select: ranked = ( select( entity_scores.c.entity_id, + entity_scores.c.entity_title, entity_scores.c.avg_semantic_distance, entity_scores.c.avg_fuzzy_score, entity_highlights.c.highlight_text, @@ -242,6 +247,7 @@ def apply(self, candidate_query: Select) -> Select: stmt = select( ranked.c.entity_id, + ranked.c.entity_title, score, ranked.c.highlight_text, ranked.c.highlight_path, diff --git a/orchestrator/search/retrieval/retrievers/semantic.py b/orchestrator/search/retrieval/retrievers/semantic.py index 3fdfa2802..8940837b9 100644 --- a/orchestrator/search/retrieval/retrievers/semantic.py +++ b/orchestrator/search/retrieval/retrievers/semantic.py @@ -15,7 +15,7 @@ from sqlalchemy.sql.expression import ColumnElement from orchestrator.db.models import AiSearchIndex -from orchestrator.search.core.types import SearchMetadata +from orchestrator.search.core.types import EntityType, SearchMetadata from ..pagination import PaginationParams from .base import Retriever @@ -24,10 +24,11 @@ class SemanticRetriever(Retriever): """Ranks 
results based on the minimum semantic vector distance.""" - def __init__(self, vector_query: list[float], pagination_params: PaginationParams) -> None: + def __init__(self, vector_query: list[float], pagination_params: PaginationParams, entity_type: EntityType) -> None: self.vector_query = vector_query self.page_after_score = pagination_params.page_after_score self.page_after_id = pagination_params.page_after_id + self.entity_type = entity_type def apply(self, candidate_query: Select) -> Select: cand = candidate_query.subquery() @@ -49,6 +50,7 @@ def apply(self, candidate_query: Select) -> Select: combined_query = ( select( AiSearchIndex.entity_id, + cand.c.entity_title, score, func.first_value(AiSearchIndex.value) .over(partition_by=AiSearchIndex.entity_id, order_by=[dist.asc(), AiSearchIndex.path.asc()]) @@ -66,6 +68,7 @@ def apply(self, candidate_query: Select) -> Select: stmt = select( final_query.c.entity_id, + final_query.c.entity_title, final_query.c.score, final_query.c.highlight_text, final_query.c.highlight_path, diff --git a/orchestrator/search/retrieval/retrievers/structured.py b/orchestrator/search/retrieval/retrievers/structured.py index b50a093f0..4c33892c4 100644 --- a/orchestrator/search/retrieval/retrievers/structured.py +++ b/orchestrator/search/retrieval/retrievers/structured.py @@ -13,7 +13,7 @@ from sqlalchemy import Select, literal, select -from orchestrator.search.core.types import SearchMetadata +from orchestrator.search.core.types import EntityType, SearchMetadata from ..pagination import PaginationParams from .base import Retriever @@ -22,12 +22,13 @@ class StructuredRetriever(Retriever): """Applies a dummy score for purely structured searches with no text query.""" - def __init__(self, pagination_params: PaginationParams) -> None: + def __init__(self, pagination_params: PaginationParams, entity_type: EntityType) -> None: self.page_after_id = pagination_params.page_after_id + self.entity_type = entity_type def apply(self, candidate_query: Select) -> Select: cand = candidate_query.subquery() - stmt = select(cand.c.entity_id, literal(1.0).label("score")).select_from(cand) + stmt = select(cand.c.entity_id, cand.c.entity_title, literal(1.0).label("score")).select_from(cand) if self.page_after_id: stmt = stmt.where(cand.c.entity_id > self.page_after_id) diff --git a/orchestrator/search/schemas/parameters.py b/orchestrator/search/schemas/parameters.py index 46bf0f71b..ccd85b794 100644 --- a/orchestrator/search/schemas/parameters.py +++ b/orchestrator/search/schemas/parameters.py @@ -12,10 +12,9 @@ # limitations under the License. 
import uuid -from typing import Any, Literal +from typing import Any, ClassVar, Literal -import numpy as np -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, TypeAdapter from orchestrator.search.core.types import ActionType, EntityType from orchestrator.search.filters import FilterTree @@ -24,6 +23,9 @@ class BaseSearchParameters(BaseModel): """Base model with common search parameters.""" + DEFAULT_EXPORT_LIMIT: ClassVar[int] = 1000 + MAX_EXPORT_LIMIT: ClassVar[int] = 10000 + action: ActionType = Field(default=ActionType.SELECT, description="The action to perform.") entity_type: EntityType @@ -34,13 +36,18 @@ class BaseSearchParameters(BaseModel): ) limit: int = Field(default=10, ge=1, le=30, description="Maximum number of search results to return.") - export_limit: int = Field(default=1000, ge=1, le=10000, description="Maximum number of results to export.") + export_limit: int = Field( + default=DEFAULT_EXPORT_LIMIT, ge=1, le=MAX_EXPORT_LIMIT, description="Maximum number of results to export." + ) model_config = ConfigDict(extra="forbid") @classmethod - def create(cls, entity_type: EntityType, **kwargs: Any) -> "BaseSearchParameters": - """Create the correct search parameter subclass instance.""" - return cls.model_validate({"entity_type": entity_type, **kwargs}) + def create(cls, **kwargs: Any) -> "SearchParameters": + """Create the correct search parameter subclass instance based on entity_type.""" + from orchestrator.search.schemas.parameters import SearchParameters + + adapter: TypeAdapter = TypeAdapter(SearchParameters) + return adapter.validate_python(kwargs) @property def vector_query(self) -> str | None: @@ -126,14 +133,15 @@ class ProcessSearchParameters(BaseSearchParameters): ) -class AgentQueryState(BaseModel): - """Complete state of an agent query including parameters and embedding. +class SearchQueryState(BaseModel): + """Complete state of a search query including parameters and embedding. This model combines the search parameters with the query embedding, providing a complete snapshot of what was searched and how. + Used for both agent and regular API searches. 
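+    The embedding is a plain list of floats, so it round-trips cleanly through JSON and the pgvector column.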
""" parameters: SearchParameters = Field(discriminator="entity_type") - query_embedding: np.ndarray | None = Field(default=None, description="The embedding vector for semantic search") + query_embedding: list[float] | None = Field(default=None, description="The embedding vector for semantic search") - model_config = ConfigDict(from_attributes=True, arbitrary_types_allowed=True) + model_config = ConfigDict(from_attributes=True) diff --git a/orchestrator/search/schemas/results.py b/orchestrator/search/schemas/results.py index b5203d78c..7aee2e191 100644 --- a/orchestrator/search/schemas/results.py +++ b/orchestrator/search/schemas/results.py @@ -15,7 +15,7 @@ from pydantic import BaseModel, ConfigDict -from orchestrator.search.core.types import FilterOp, SearchMetadata, UIType +from orchestrator.search.core.types import EntityType, FilterOp, SearchMetadata, UIType class MatchingField(BaseModel): @@ -30,6 +30,8 @@ class SearchResult(BaseModel): """Represents a single search result item.""" entity_id: str + entity_type: EntityType + title: str score: float perfect_match: int = 0 matching_field: MatchingField | None = None @@ -40,6 +42,7 @@ class SearchResponse(BaseModel): results: list[SearchResult] metadata: SearchMetadata + query_embedding: list[float] | None = None class ValueSchema(BaseModel): diff --git a/orchestrator/settings.py b/orchestrator/settings.py index 01cae55f5..fdedc783b 100644 --- a/orchestrator/settings.py +++ b/orchestrator/settings.py @@ -57,6 +57,7 @@ class AppSettings(BaseSettings): EXECUTOR: str = ExecutorType.THREADPOOL WORKFLOWS_SWAGGER_HOST: str = "localhost" WORKFLOWS_GUI_URI: str = "http://localhost:3000" + BASE_URL: str = "http://localhost:8080" # Base URL for the API (used for generating export URLs) DATABASE_URI: PostgresDsn = "postgresql://nwa:nwa@localhost/orchestrator-core" # type: ignore MAX_WORKERS: int = 5 MAIL_SERVER: str = "localhost" From 7d92108b6e93156eb134541f3c3ac55f57714af4 Mon Sep 17 00:00:00 2001 From: Tim Frohlich Date: Thu, 16 Oct 2025 10:04:59 +0200 Subject: [PATCH 06/16] title join in builder only --- .../search/retrieval/retrievers/base.py | 61 ++----------------- .../search/retrieval/retrievers/fuzzy.py | 5 +- .../search/retrieval/retrievers/hybrid.py | 4 +- .../search/retrieval/retrievers/semantic.py | 5 +- .../search/retrieval/retrievers/structured.py | 5 +- 5 files changed, 13 insertions(+), 67 deletions(-) diff --git a/orchestrator/search/retrieval/retrievers/base.py b/orchestrator/search/retrieval/retrievers/base.py index 8c79c4190..73921a50c 100644 --- a/orchestrator/search/retrieval/retrievers/base.py +++ b/orchestrator/search/retrieval/retrievers/base.py @@ -15,11 +15,9 @@ from decimal import Decimal import structlog -from sqlalchemy import BindParameter, Numeric, Select, literal, select -from sqlalchemy_utils import Ltree +from sqlalchemy import BindParameter, Numeric, Select, literal -from orchestrator.db.models import AiSearchIndex -from orchestrator.search.core.types import EntityType, FieldType, SearchMetadata +from orchestrator.search.core.types import FieldType, SearchMetadata from orchestrator.search.schemas.parameters import BaseSearchParameters from ..pagination import PaginationParams @@ -72,13 +70,13 @@ async def from_params( fallback_fuzzy_term = params.query if q_vec is not None and fallback_fuzzy_term is not None: - return RrfHybridRetriever(q_vec, fallback_fuzzy_term, pagination_params, params.entity_type) + return RrfHybridRetriever(q_vec, fallback_fuzzy_term, pagination_params) if q_vec is not None: - return 
SemanticRetriever(q_vec, pagination_params, params.entity_type) + return SemanticRetriever(q_vec, pagination_params) if fallback_fuzzy_term is not None: - return FuzzyRetriever(fallback_fuzzy_term, pagination_params, params.entity_type) + return FuzzyRetriever(fallback_fuzzy_term, pagination_params) - return StructuredRetriever(pagination_params, params.entity_type) + return StructuredRetriever(pagination_params) @classmethod async def _get_query_vector( @@ -113,53 +111,6 @@ def _quantize_score_for_pagination(self, score_value: float) -> BindParameter[De pas_dec = Decimal(str(score_value)).quantize(quantizer) return literal(pas_dec, type_=self.SCORE_NUMERIC_TYPE) - @staticmethod - def add_title_to_query(stmt: Select, entity_type: EntityType) -> Select: - """Add title column to a query by joining with the index table.""" - # Define title paths based on entity type - title_path_map = { - EntityType.SUBSCRIPTION: "subscription.description", - EntityType.PRODUCT: "product.description", - EntityType.WORKFLOW: "workflow.description", - EntityType.PROCESS: "process.workflowName", - } - - title_path = title_path_map.get(entity_type) - if not title_path: - # If no title path defined, return original statement - return stmt - - # Create subquery from the original statement - ranked = stmt.subquery("ranked") - - # Subquery to get title value for each entity - # Use a distinct label to avoid any column name conflicts - title_subquery = ( - select( - AiSearchIndex.entity_id.label("title_entity_id"), - AiSearchIndex.value.label("entity_title"), - ) - .where( - AiSearchIndex.entity_type == entity_type.value, - AiSearchIndex.path == Ltree(title_path), - ) - .subquery("titles") - ) - - # Build explicit column list to preserve order - columns = [ranked.c.entity_id, title_subquery.c.entity_title] - - # Add remaining columns in their original order - for col in ranked.c: - if col.name != "entity_id": - columns.append(col) - - return ( - select(*columns) - .select_from(ranked) - .outerjoin(title_subquery, ranked.c.entity_id == title_subquery.c.title_entity_id) - ) - @property @abstractmethod def metadata(self) -> SearchMetadata: diff --git a/orchestrator/search/retrieval/retrievers/fuzzy.py b/orchestrator/search/retrieval/retrievers/fuzzy.py index df6404a43..9f3a3f633 100644 --- a/orchestrator/search/retrieval/retrievers/fuzzy.py +++ b/orchestrator/search/retrieval/retrievers/fuzzy.py @@ -15,7 +15,7 @@ from sqlalchemy.sql.expression import ColumnElement from orchestrator.db.models import AiSearchIndex -from orchestrator.search.core.types import EntityType, SearchMetadata +from orchestrator.search.core.types import SearchMetadata from ..pagination import PaginationParams from .base import Retriever @@ -24,11 +24,10 @@ class FuzzyRetriever(Retriever): """Ranks results based on the max of fuzzy text similarity scores.""" - def __init__(self, fuzzy_term: str, pagination_params: PaginationParams, entity_type: EntityType) -> None: + def __init__(self, fuzzy_term: str, pagination_params: PaginationParams) -> None: self.fuzzy_term = fuzzy_term self.page_after_score = pagination_params.page_after_score self.page_after_id = pagination_params.page_after_id - self.entity_type = entity_type def apply(self, candidate_query: Select) -> Select: cand = candidate_query.subquery() diff --git a/orchestrator/search/retrieval/retrievers/hybrid.py b/orchestrator/search/retrieval/retrievers/hybrid.py index c44cf0656..b7134b93b 100644 --- a/orchestrator/search/retrieval/retrievers/hybrid.py +++ 
b/orchestrator/search/retrieval/retrievers/hybrid.py @@ -18,7 +18,7 @@ from sqlalchemy.types import TypeEngine from orchestrator.db.models import AiSearchIndex -from orchestrator.search.core.types import EntityType, SearchMetadata +from orchestrator.search.core.types import SearchMetadata from ..pagination import PaginationParams from .base import Retriever @@ -128,7 +128,6 @@ def __init__( q_vec: list[float], fuzzy_term: str, pagination_params: PaginationParams, - entity_type: "EntityType", k: int = 60, field_candidates_limit: int = 100, ) -> None: @@ -136,7 +135,6 @@ def __init__( self.fuzzy_term = fuzzy_term self.page_after_score = pagination_params.page_after_score self.page_after_id = pagination_params.page_after_id - self.entity_type = entity_type self.k = k self.field_candidates_limit = field_candidates_limit diff --git a/orchestrator/search/retrieval/retrievers/semantic.py b/orchestrator/search/retrieval/retrievers/semantic.py index 8940837b9..3b8226ee3 100644 --- a/orchestrator/search/retrieval/retrievers/semantic.py +++ b/orchestrator/search/retrieval/retrievers/semantic.py @@ -15,7 +15,7 @@ from sqlalchemy.sql.expression import ColumnElement from orchestrator.db.models import AiSearchIndex -from orchestrator.search.core.types import EntityType, SearchMetadata +from orchestrator.search.core.types import SearchMetadata from ..pagination import PaginationParams from .base import Retriever @@ -24,11 +24,10 @@ class SemanticRetriever(Retriever): """Ranks results based on the minimum semantic vector distance.""" - def __init__(self, vector_query: list[float], pagination_params: PaginationParams, entity_type: EntityType) -> None: + def __init__(self, vector_query: list[float], pagination_params: PaginationParams) -> None: self.vector_query = vector_query self.page_after_score = pagination_params.page_after_score self.page_after_id = pagination_params.page_after_id - self.entity_type = entity_type def apply(self, candidate_query: Select) -> Select: cand = candidate_query.subquery() diff --git a/orchestrator/search/retrieval/retrievers/structured.py b/orchestrator/search/retrieval/retrievers/structured.py index 4c33892c4..29d546eff 100644 --- a/orchestrator/search/retrieval/retrievers/structured.py +++ b/orchestrator/search/retrieval/retrievers/structured.py @@ -13,7 +13,7 @@ from sqlalchemy import Select, literal, select -from orchestrator.search.core.types import EntityType, SearchMetadata +from orchestrator.search.core.types import SearchMetadata from ..pagination import PaginationParams from .base import Retriever @@ -22,9 +22,8 @@ class StructuredRetriever(Retriever): """Applies a dummy score for purely structured searches with no text query.""" - def __init__(self, pagination_params: PaginationParams, entity_type: EntityType) -> None: + def __init__(self, pagination_params: PaginationParams) -> None: self.page_after_id = pagination_params.page_after_id - self.entity_type = entity_type def apply(self, candidate_query: Select) -> Select: cand = candidate_query.subquery() From 851a76ffa9ec1490be66fbb3f1f35484756ad5bb Mon Sep 17 00:00:00 2001 From: Tim Frohlich Date: Thu, 16 Oct 2025 12:21:21 +0200 Subject: [PATCH 07/16] add extra column for title --- orchestrator/db/models.py | 1 + orchestrator/search/core/types.py | 1 + orchestrator/search/indexing/indexer.py | 17 ++++++++++---- orchestrator/search/indexing/registry.py | 15 ++++++++++++- orchestrator/search/llm_migration.py | 18 +++++++++++++++ orchestrator/search/retrieval/builder.py | 28 +----------------------- 
orchestrator/search/retrieval/engine.py | 2 +- orchestrator/search/schemas/results.py | 2 +- 8 files changed, 50 insertions(+), 34 deletions(-) diff --git a/orchestrator/db/models.py b/orchestrator/db/models.py index bde0cc6cc..da071d284 100644 --- a/orchestrator/db/models.py +++ b/orchestrator/db/models.py @@ -790,6 +790,7 @@ class AiSearchIndex(BaseModel): UUIDType, nullable=False, ) + entity_title = mapped_column(TEXT, nullable=True) # Ltree path for hierarchical data path = mapped_column(LtreeType, nullable=False, index=True) diff --git a/orchestrator/search/core/types.py b/orchestrator/search/core/types.py index 5589e4ceb..01f7a4425 100644 --- a/orchestrator/search/core/types.py +++ b/orchestrator/search/core/types.py @@ -289,6 +289,7 @@ def from_raw(cls, path: str, raw_value: Any) -> "ExtractedField": class IndexableRecord(TypedDict): entity_id: str entity_type: str + entity_title: str path: Ltree value: Any value_type: Any diff --git a/orchestrator/search/indexing/indexer.py b/orchestrator/search/indexing/indexer.py index 1f5ae23d8..d2906f767 100644 --- a/orchestrator/search/indexing/indexer.py +++ b/orchestrator/search/indexing/indexer.py @@ -96,6 +96,7 @@ def __init__(self, config: EntityConfig, dry_run: bool, force_index: bool, chunk self.chunk_size = chunk_size self.embedding_model = llm_settings.EMBEDDING_MODEL self.logger = logger.bind(entity_kind=config.entity_kind.value) + self._entity_titles: dict[str, str] = {} def run(self, entities: Iterable[DatabaseEntity]) -> int: """Orchestrates the entire indexing process.""" @@ -138,6 +139,8 @@ def _process_chunk(self, entity_chunk: list[DatabaseEntity], session: Session | if not entity_chunk: return 0, 0 + self._entity_titles.clear() + fields_to_upsert, paths_to_delete, identical_count = self._determine_changes(entity_chunk, session) if paths_to_delete and session is not None: @@ -174,12 +177,15 @@ def _determine_changes( entity, pk_name=self.config.pk_name, root_name=self.config.root_name ) + entity_title = self.config.get_title_from_fields(current_fields) + self._entity_titles[entity_id] = entity_title + entity_hashes = existing_hashes.get(entity_id, {}) current_paths = set() for field in current_fields: current_paths.add(field.path) - current_hash = self._compute_content_hash(field.path, field.value, field.value_type) + current_hash = self._compute_content_hash(field.path, field.value, field.value_type, entity_title) if field.path not in entity_hashes or entity_hashes[field.path] != current_hash: fields_to_upsert.append((entity_id, field)) else: @@ -301,21 +307,23 @@ def _prepare_text_for_embedding(field: ExtractedField) -> str: return f"{field.path}: {str(field.value)}" @staticmethod - def _compute_content_hash(path: str, value: Any, value_type: Any) -> str: + def _compute_content_hash(path: str, value: Any, value_type: Any, entity_title: str = "") -> str: v = "" if value is None else str(value) - content = f"{path}:{v}:{value_type}" + content = f"{path}:{v}:{value_type}:{entity_title}" return hashlib.sha256(content.encode("utf-8")).hexdigest() def _make_indexable_record( self, field: ExtractedField, entity_id: str, embedding: list[float] | None ) -> IndexableRecord: + entity_title = self._entity_titles[entity_id] return IndexableRecord( entity_id=entity_id, entity_type=self.config.entity_kind.value, + entity_title=entity_title, path=Ltree(field.path), value=field.value, value_type=field.value_type, - content_hash=self._compute_content_hash(field.path, field.value, field.value_type), + 
content_hash=self._compute_content_hash(field.path, field.value, field.value_type, entity_title), embedding=embedding if embedding else None, ) @@ -326,6 +334,7 @@ def _get_upsert_statement() -> Insert: return stmt.on_conflict_do_update( index_elements=[AiSearchIndex.entity_id, AiSearchIndex.path], set_={ + AiSearchIndex.entity_title: stmt.excluded.entity_title, AiSearchIndex.value: stmt.excluded.value, AiSearchIndex.value_type: stmt.excluded.value_type, AiSearchIndex.content_hash: stmt.excluded.content_hash, diff --git a/orchestrator/search/indexing/registry.py b/orchestrator/search/indexing/registry.py index 497bfb81f..acf10f676 100644 --- a/orchestrator/search/indexing/registry.py +++ b/orchestrator/search/indexing/registry.py @@ -25,7 +25,7 @@ WorkflowTable, ) from orchestrator.db.database import BaseModel -from orchestrator.search.core.types import EntityType +from orchestrator.search.core.types import EntityType, ExtractedField from .traverse import ( BaseTraverser, @@ -48,6 +48,7 @@ class EntityConfig(Generic[ModelT]): traverser: "type[BaseTraverser]" pk_name: str root_name: str + title_paths: list[str] # List of field paths to check for title (with fallback) def get_all_query(self, entity_id: str | None = None) -> Query | Select: query = self.table.query @@ -56,6 +57,14 @@ def get_all_query(self, entity_id: str | None = None) -> Query | Select: query = query.filter(pk_column == UUID(entity_id)) return query + def get_title_from_fields(self, fields: list[ExtractedField]) -> str: + """Extract title from fields using configured paths.""" + for title_path in self.title_paths: + for field in fields: + if field.path == title_path and field.value: + return str(field.value) + return "UNKNOWN" + @dataclass(frozen=True) class WorkflowConfig(EntityConfig[WorkflowTable]): @@ -76,6 +85,7 @@ def get_all_query(self, entity_id: str | None = None) -> Select: traverser=SubscriptionTraverser, pk_name="subscription_id", root_name="subscription", + title_paths=["subscription.description"], ), EntityType.PRODUCT: EntityConfig( entity_kind=EntityType.PRODUCT, @@ -83,6 +93,7 @@ def get_all_query(self, entity_id: str | None = None) -> Select: traverser=ProductTraverser, pk_name="product_id", root_name="product", + title_paths=["product.description", "product.name"], ), EntityType.PROCESS: EntityConfig( entity_kind=EntityType.PROCESS, @@ -90,6 +101,7 @@ def get_all_query(self, entity_id: str | None = None) -> Select: traverser=ProcessTraverser, pk_name="process_id", root_name="process", + title_paths=["process.workflow_name"], ), EntityType.WORKFLOW: WorkflowConfig( entity_kind=EntityType.WORKFLOW, @@ -97,5 +109,6 @@ def get_all_query(self, entity_id: str | None = None) -> Select: traverser=WorkflowTraverser, pk_name="workflow_id", root_name="workflow", + title_paths=["workflow.description", "workflow.name"], ), } diff --git a/orchestrator/search/llm_migration.py b/orchestrator/search/llm_migration.py index d32990b89..1cad10c1e 100644 --- a/orchestrator/search/llm_migration.py +++ b/orchestrator/search/llm_migration.py @@ -64,6 +64,7 @@ def run_migration(connection: Connection) -> None: CREATE TABLE IF NOT EXISTS {TABLE} ( entity_type TEXT NOT NULL, entity_id UUID NOT NULL, + entity_title TEXT, path LTREE NOT NULL, value TEXT NOT NULL, embedding VECTOR({TARGET_DIM}), @@ -78,6 +79,23 @@ def run_migration(connection: Connection) -> None: # Drop default connection.execute(text(f"ALTER TABLE {TABLE} ALTER COLUMN value_type DROP DEFAULT;")) + # Add entity_title column if it doesn't exist (for existing 
installations) + connection.execute( + text( + f""" + DO $$ + BEGIN + IF NOT EXISTS ( + SELECT 1 FROM information_schema.columns + WHERE table_name = '{TABLE}' AND column_name = 'entity_title' + ) THEN + ALTER TABLE {TABLE} ADD COLUMN entity_title TEXT; + END IF; + END $$; + """ + ) + ) + # Create indexes with IF NOT EXISTS connection.execute(text(f"CREATE INDEX IF NOT EXISTS ix_ai_search_index_entity_id ON {TABLE} (entity_id);")) connection.execute( diff --git a/orchestrator/search/retrieval/builder.py b/orchestrator/search/retrieval/builder.py index 446f3409f..7aabb36a9 100644 --- a/orchestrator/search/retrieval/builder.py +++ b/orchestrator/search/retrieval/builder.py @@ -16,7 +16,6 @@ from sqlalchemy import Select, String, cast, func, select from sqlalchemy.engine import Row -from sqlalchemy_utils import Ltree from orchestrator.db.models import AiSearchIndex from orchestrator.search.core.types import EntityType, FieldType, FilterOp, UIType @@ -44,33 +43,8 @@ def build_candidate_query(params: BaseSearchParameters) -> Select: Select: The SQLAlchemy `Select` object representing the query. """ - # Define title paths based on entity type - title_path_map = { - EntityType.SUBSCRIPTION: "subscription.description", - EntityType.PRODUCT: "product.description", - EntityType.WORKFLOW: "workflow.description", - EntityType.PROCESS: "process.workflowName", - } - - title_path = title_path_map.get(params.entity_type) - - # Subquery to get title value for each entity - title_subquery = ( - select( - AiSearchIndex.entity_id.label("title_entity_id"), - AiSearchIndex.value.label("entity_title"), - ) - .where( - AiSearchIndex.entity_type == params.entity_type.value, - AiSearchIndex.path == Ltree(title_path), - ) - .subquery() - ) - stmt = ( - select(AiSearchIndex.entity_id, title_subquery.c.entity_title) - .select_from(AiSearchIndex) - .outerjoin(title_subquery, AiSearchIndex.entity_id == title_subquery.c.title_entity_id) + select(AiSearchIndex.entity_id, AiSearchIndex.entity_title) .where(AiSearchIndex.entity_type == params.entity_type.value) .distinct() ) diff --git a/orchestrator/search/retrieval/engine.py b/orchestrator/search/retrieval/engine.py index 9d8b7f48c..7a8dfc906 100644 --- a/orchestrator/search/retrieval/engine.py +++ b/orchestrator/search/retrieval/engine.py @@ -82,7 +82,7 @@ def _format_response( SearchResult( entity_id=str(row.entity_id), entity_type=search_params.entity_type, - title=entity_title, + entity_title=entity_title, score=row.score, perfect_match=row.get("perfect_match", 0), matching_field=matching_field, diff --git a/orchestrator/search/schemas/results.py b/orchestrator/search/schemas/results.py index 7aee2e191..7dcb36394 100644 --- a/orchestrator/search/schemas/results.py +++ b/orchestrator/search/schemas/results.py @@ -31,7 +31,7 @@ class SearchResult(BaseModel): entity_id: str entity_type: EntityType - title: str + entity_title: str score: float perfect_match: int = 0 matching_field: MatchingField | None = None From 83e264da5cccfbe18c21ca850f12e8116e8b7473 Mon Sep 17 00:00:00 2001 From: Tim Frohlich Date: Thu, 16 Oct 2025 13:28:42 +0200 Subject: [PATCH 08/16] use consistent naming --- orchestrator/search/retrieval/retrievers/fuzzy.py | 4 ++-- orchestrator/search/retrieval/retrievers/hybrid.py | 6 +++--- orchestrator/search/retrieval/retrievers/semantic.py | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/orchestrator/search/retrieval/retrievers/fuzzy.py b/orchestrator/search/retrieval/retrievers/fuzzy.py index 9f3a3f633..885036cd7 100644 --- 
a/orchestrator/search/retrieval/retrievers/fuzzy.py +++ b/orchestrator/search/retrieval/retrievers/fuzzy.py @@ -42,7 +42,7 @@ def apply(self, candidate_query: Select) -> Select: combined_query = ( select( AiSearchIndex.entity_id, - cand.c.entity_title, + AiSearchIndex.entity_title, score, func.first_value(AiSearchIndex.value) .over(partition_by=AiSearchIndex.entity_id, order_by=[similarity_expr.desc(), AiSearchIndex.path.asc()]) @@ -59,7 +59,7 @@ def apply(self, candidate_query: Select) -> Select: literal(self.fuzzy_term).op("<%")(AiSearchIndex.value), ) ) - .distinct(AiSearchIndex.entity_id) + .distinct(AiSearchIndex.entity_id, AiSearchIndex.entity_title) ) final_query = combined_query.subquery("ranked_fuzzy") diff --git a/orchestrator/search/retrieval/retrievers/hybrid.py b/orchestrator/search/retrieval/retrievers/hybrid.py index b7134b93b..a3cc8ac3e 100644 --- a/orchestrator/search/retrieval/retrievers/hybrid.py +++ b/orchestrator/search/retrieval/retrievers/hybrid.py @@ -154,7 +154,7 @@ def apply(self, candidate_query: Select) -> Select: field_candidates = ( select( AiSearchIndex.entity_id, - cand.c.entity_title, + AiSearchIndex.entity_title, AiSearchIndex.path, AiSearchIndex.value, sem_val, @@ -179,10 +179,10 @@ def apply(self, candidate_query: Select) -> Select: entity_scores = ( select( field_candidates.c.entity_id, - func.max(field_candidates.c.entity_title).label("entity_title"), + field_candidates.c.entity_title, func.avg(field_candidates.c.semantic_distance).label("avg_semantic_distance"), func.avg(field_candidates.c.fuzzy_score).label("avg_fuzzy_score"), - ).group_by(field_candidates.c.entity_id) + ).group_by(field_candidates.c.entity_id, field_candidates.c.entity_title) ).cte("entity_scores") entity_highlights = ( diff --git a/orchestrator/search/retrieval/retrievers/semantic.py b/orchestrator/search/retrieval/retrievers/semantic.py index 3b8226ee3..36f5efc13 100644 --- a/orchestrator/search/retrieval/retrievers/semantic.py +++ b/orchestrator/search/retrieval/retrievers/semantic.py @@ -49,7 +49,7 @@ def apply(self, candidate_query: Select) -> Select: combined_query = ( select( AiSearchIndex.entity_id, - cand.c.entity_title, + AiSearchIndex.entity_title, score, func.first_value(AiSearchIndex.value) .over(partition_by=AiSearchIndex.entity_id, order_by=[dist.asc(), AiSearchIndex.path.asc()]) @@ -61,7 +61,7 @@ def apply(self, candidate_query: Select) -> Select: .select_from(AiSearchIndex) .join(cand, cand.c.entity_id == AiSearchIndex.entity_id) .where(AiSearchIndex.embedding.isnot(None)) - .distinct(AiSearchIndex.entity_id) + .distinct(AiSearchIndex.entity_id, AiSearchIndex.entity_title) ) final_query = combined_query.subquery("ranked_semantic") From 07e5f8e2973abfdc6795f127054f1015a37e6265 Mon Sep 17 00:00:00 2001 From: Tim Frohlich Date: Thu, 16 Oct 2025 13:35:55 +0200 Subject: [PATCH 09/16] use response schema for export --- orchestrator/api/api_v1/endpoints/agent.py | 11 ++++++----- orchestrator/schemas/search.py | 6 ++++++ 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/orchestrator/api/api_v1/endpoints/agent.py b/orchestrator/api/api_v1/endpoints/agent.py index d71db91d3..99f42af51 100644 --- a/orchestrator/api/api_v1/endpoints/agent.py +++ b/orchestrator/api/api_v1/endpoints/agent.py @@ -22,9 +22,11 @@ from orchestrator.db import SearchQueryTable, db from orchestrator.llm_settings import llm_settings +from orchestrator.schemas.search import ExportResponse from orchestrator.search.agent import build_agent_instance from orchestrator.search.agent.state import 
SearchState from orchestrator.search.retrieval import execute_search_for_export +from orchestrator.search.retrieval.pagination import PaginationParams router = APIRouter() logger = get_logger(__name__) @@ -55,9 +57,9 @@ async def agent_conversation( @router.get( "/queries/{query_id}/export", summary="Export query results by query_id", - response_model=dict[str, Any], + response_model=ExportResponse, ) -async def export_by_query_id(query_id: str) -> dict[str, Any]: +async def export_by_query_id(query_id: str) -> ExportResponse: """Export search results using query_id. The query is retrieved from the database, re-executed, and results are returned @@ -67,7 +69,7 @@ async def export_by_query_id(query_id: str) -> dict[str, Any]: query_id: Query UUID Returns: - Dictionary containing 'page' with an array of flattened entity records. + ExportResponse containing 'page' with an array of flattened entity records. Raises: HTTPException: 404 if query not found, 400 if invalid data @@ -92,7 +94,6 @@ async def export_by_query_id(query_id: str) -> dict[str, Any]: detail=f"Query {query_id} not found", ) try: - from orchestrator.search.retrieval.pagination import PaginationParams # Get the full query state including the embedding that was used query_state = agent_query.to_state() @@ -105,7 +106,7 @@ async def export_by_query_id(query_id: str) -> dict[str, Any]: export_records = fetch_export_data(query_state.parameters.entity_type, entity_ids) - return {"page": export_records} + return ExportResponse(page=export_records) except Exception as e: logger.error(e) diff --git a/orchestrator/schemas/search.py b/orchestrator/schemas/search.py index 77f4263c1..d6f4a7f3e 100644 --- a/orchestrator/schemas/search.py +++ b/orchestrator/schemas/search.py @@ -45,3 +45,9 @@ class PathsResponse(BaseModel): components: list[ComponentInfo] model_config = ConfigDict(extra="forbid", use_enum_values=True) + + +class ExportResponse(BaseModel): + page: list[dict] + + model_config = ConfigDict(extra="forbid") From 47a7cf652a1ce2e34b7e22c6377d0658b6b95a49 Mon Sep 17 00:00:00 2001 From: Tim Frohlich Date: Thu, 16 Oct 2025 15:48:51 +0200 Subject: [PATCH 10/16] Improve structure --- orchestrator/api/api_v1/endpoints/agent.py | 45 ++----- orchestrator/api/api_v1/endpoints/search.py | 99 +++++++--------- orchestrator/cli/search/speedtest.py | 12 +- orchestrator/search/core/exceptions.py | 6 + orchestrator/search/retrieval/__init__.py | 4 +- orchestrator/search/retrieval/engine.py | 81 +++++++++---- orchestrator/search/retrieval/pagination.py | 112 +++++++++++------- .../search/retrieval/retrievers/base.py | 55 ++++----- 8 files changed, 218 insertions(+), 196 deletions(-) diff --git a/orchestrator/api/api_v1/endpoints/agent.py b/orchestrator/api/api_v1/endpoints/agent.py index 99f42af51..40272b7f7 100644 --- a/orchestrator/api/api_v1/endpoints/agent.py +++ b/orchestrator/api/api_v1/endpoints/agent.py @@ -12,7 +12,7 @@ # limitations under the License. 
from functools import cache -from typing import Annotated, Any +from typing import Annotated from fastapi import APIRouter, Depends, HTTPException, Request, status from pydantic_ai.ag_ui import StateDeps, handle_ag_ui_request @@ -20,13 +20,13 @@ from starlette.responses import Response from structlog import get_logger -from orchestrator.db import SearchQueryTable, db +from orchestrator.db import db from orchestrator.llm_settings import llm_settings from orchestrator.schemas.search import ExportResponse from orchestrator.search.agent import build_agent_instance from orchestrator.search.agent.state import SearchState -from orchestrator.search.retrieval import execute_search_for_export -from orchestrator.search.retrieval.pagination import PaginationParams +from orchestrator.search.core.exceptions import QueryStateNotFoundError +from orchestrator.search.retrieval import execute_search_for_export, get_query_state router = APIRouter() logger = get_logger(__name__) @@ -74,40 +74,15 @@ async def export_by_query_id(query_id: str) -> ExportResponse: Raises: HTTPException: 404 if query not found, 400 if invalid data """ - from uuid import UUID - - from orchestrator.search.export import fetch_export_data - - try: - query_uuid = UUID(query_id) - except ValueError: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid query_id format", - ) - - agent_query = db.session.query(SearchQueryTable).filter_by(query_id=query_uuid).first() - - if not agent_query: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Query {query_id} not found", - ) try: - - # Get the full query state including the embedding that was used - query_state = agent_query.to_state() - - # Create pagination params with the saved embedding to ensure consistent results - pagination_params = PaginationParams(q_vec_override=query_state.query_embedding) - - search_response = await execute_search_for_export(query_state.parameters, db.session, pagination_params) - entity_ids = [res.entity_id for res in search_response.results] - - export_records = fetch_export_data(query_state.parameters.entity_type, entity_ids) - + query_state = get_query_state(query_id) + export_records = await execute_search_for_export(query_state, db.session) return ExportResponse(page=export_records) + except (ValueError, TypeError) as e: + raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e)) + except QueryStateNotFoundError as e: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) except Exception as e: logger.error(e) raise HTTPException( diff --git a/orchestrator/api/api_v1/endpoints/search.py b/orchestrator/api/api_v1/endpoints/search.py index 78503e9be..4ad35976a 100644 --- a/orchestrator/api/api_v1/endpoints/search.py +++ b/orchestrator/api/api_v1/endpoints/search.py @@ -19,10 +19,10 @@ PathsResponse, SearchResultsSchema, ) -from orchestrator.search.core.exceptions import InvalidCursorError +from orchestrator.search.core.exceptions import InvalidCursorError, QueryStateNotFoundError from orchestrator.search.core.types import EntityType, UIType from orchestrator.search.filters.definitions import generate_definitions -from orchestrator.search.retrieval import execute_search +from orchestrator.search.retrieval import execute_search, get_query_state from orchestrator.search.retrieval.builder import build_paths_query, create_path_autocomplete_lquery, process_path_rows from orchestrator.search.retrieval.pagination import ( PaginationParams, @@ -45,97 +45,86 @@ async 
def _perform_search_and_fetch( search_params: SearchParameters, cursor: str | None = None, - query_id: str | None = None, ) -> SearchResultsSchema[SearchResult]: """Execute search and return results. Args: search_params: Search parameters cursor: Pagination cursor - query_id: Optional saved query ID to use for embedding retrieval Returns: Search results with entity_id, score, and matching_field. """ - # If query_id provided, retrieve saved embedding - if query_id and not cursor: - from uuid import UUID - - from orchestrator.db import SearchQueryTable - - try: - query_uuid = UUID(query_id) - except ValueError: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Invalid query_id format: {query_id}", - ) - - search_query = db.session.query(SearchQueryTable).filter_by(query_id=query_uuid).first() - if not search_query: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Query {query_id} not found", - ) - - query_state = search_query.to_state() - search_params = query_state.parameters - pagination_params = PaginationParams(q_vec_override=query_state.query_embedding) - else: - try: - pagination_params = await process_pagination_cursor(cursor, search_params) - except InvalidCursorError: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid pagination cursor") - - search_response = await execute_search(search_params, db.session, pagination_params) - if not search_response.results: - return SearchResultsSchema(search_metadata=search_response.metadata) - - next_page_cursor = create_next_page_cursor( - search_response.results, pagination_params, search_params.limit, search_params - ) - has_next_page = next_page_cursor is not None - page_info = PageInfoSchema(has_next_page=has_next_page, next_page_cursor=next_page_cursor) - - return SearchResultsSchema( - data=search_response.results, page_info=page_info, search_metadata=search_response.metadata - ) + try: + if cursor: + try: + pagination_params = await process_pagination_cursor(cursor) + if pagination_params.query_id is None: + raise InvalidCursorError("Cursor missing query_id") + except InvalidCursorError as e: + raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e)) + + try: + query_state = get_query_state(pagination_params.query_id) + except QueryStateNotFoundError as e: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) + + search_params = query_state.parameters + query_embedding = query_state.query_embedding + else: + pagination_params = PaginationParams() + query_embedding = None + + search_response = await execute_search(search_params, db.session, pagination_params, query_embedding) + if not search_response.results: + return SearchResultsSchema(search_metadata=search_response.metadata) + + next_page_cursor = create_next_page_cursor(search_response, pagination_params, search_params) + has_next_page = next_page_cursor is not None + page_info = PageInfoSchema(has_next_page=has_next_page, next_page_cursor=next_page_cursor) + + return SearchResultsSchema( + data=search_response.results, page_info=page_info, search_metadata=search_response.metadata + ) + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Search failed: {str(e)}", + ) @router.post("/subscriptions", response_model=SearchResultsSchema[SearchResult]) async def search_subscriptions( search_params: SubscriptionSearchParameters, cursor: str | None = None, - query_id: str | 
None = Query(None, description="Optional saved query ID for embedding retrieval"), ) -> SearchResultsSchema[SearchResult]: - return await _perform_search_and_fetch(search_params, cursor, query_id) + return await _perform_search_and_fetch(search_params, cursor) @router.post("/workflows", response_model=SearchResultsSchema[SearchResult]) async def search_workflows( search_params: WorkflowSearchParameters, cursor: str | None = None, - query_id: str | None = Query(None, description="Optional saved query ID for embedding retrieval"), ) -> SearchResultsSchema[SearchResult]: - return await _perform_search_and_fetch(search_params, cursor, query_id) + return await _perform_search_and_fetch(search_params, cursor) @router.post("/products", response_model=SearchResultsSchema[SearchResult]) async def search_products( search_params: ProductSearchParameters, cursor: str | None = None, - query_id: str | None = Query(None, description="Optional saved query ID for embedding retrieval"), ) -> SearchResultsSchema[SearchResult]: - return await _perform_search_and_fetch(search_params, cursor, query_id) + return await _perform_search_and_fetch(search_params, cursor) @router.post("/processes", response_model=SearchResultsSchema[SearchResult]) async def search_processes( search_params: ProcessSearchParameters, cursor: str | None = None, - query_id: str | None = Query(None, description="Optional saved query ID for embedding retrieval"), ) -> SearchResultsSchema[SearchResult]: - return await _perform_search_and_fetch(search_params, cursor, query_id) + return await _perform_search_and_fetch(search_params, cursor) @router.get( diff --git a/orchestrator/cli/search/speedtest.py b/orchestrator/cli/search/speedtest.py index a49544e80..f3f1d9306 100644 --- a/orchestrator/cli/search/speedtest.py +++ b/orchestrator/cli/search/speedtest.py @@ -54,17 +54,19 @@ async def generate_embeddings_for_queries(queries: list[str]) -> dict[str, list[ async def run_single_query(query: str, embedding_lookup: dict[str, list[float]]) -> dict[str, Any]: search_params = BaseSearchParameters(entity_type=EntityType.SUBSCRIPTION, query=query, limit=30) + pagination_params = PaginationParams() + query_embedding = None + if is_uuid(query): - pagination_params = PaginationParams() logger.debug("Using fuzzy-only ranking for full UUID", query=query) else: - - cached_embedding = embedding_lookup[query] - pagination_params = PaginationParams(q_vec_override=cached_embedding) + query_embedding = embedding_lookup[query] with db.session as session: start_time = time.perf_counter() - response = await execute_search(search_params, session, pagination_params=pagination_params) + response = await execute_search( + search_params, session, pagination_params=pagination_params, query_embedding=query_embedding + ) end_time = time.perf_counter() return { diff --git a/orchestrator/search/core/exceptions.py b/orchestrator/search/core/exceptions.py index 0170b115c..4f16d11b3 100644 --- a/orchestrator/search/core/exceptions.py +++ b/orchestrator/search/core/exceptions.py @@ -34,3 +34,9 @@ class InvalidCursorError(SearchUtilsError): """Raised when cursor cannot be decoded.""" pass + + +class QueryStateNotFoundError(SearchUtilsError): + """Raised when a query state cannot be found in the database.""" + + pass diff --git a/orchestrator/search/retrieval/__init__.py b/orchestrator/search/retrieval/__init__.py index 353fb6fba..c8b3319a3 100644 --- a/orchestrator/search/retrieval/__init__.py +++ b/orchestrator/search/retrieval/__init__.py @@ -11,6 +11,6 @@ # See the License for 
the specific language governing permissions and # limitations under the License. -from .engine import execute_search, execute_search_for_export +from .engine import execute_search, execute_search_for_export, get_query_state -__all__ = ["execute_search", "execute_search_for_export"] +__all__ = ["execute_search", "execute_search_for_export", "get_query_state"] diff --git a/orchestrator/search/retrieval/engine.py b/orchestrator/search/retrieval/engine.py index 7a8dfc906..c22bfe54e 100644 --- a/orchestrator/search/retrieval/engine.py +++ b/orchestrator/search/retrieval/engine.py @@ -12,14 +12,18 @@ # limitations under the License. from collections.abc import Sequence +from uuid import UUID import structlog from sqlalchemy.engine.row import RowMapping from sqlalchemy.orm import Session +from orchestrator.db import SearchQueryTable, db +from orchestrator.search.core.embedding import QueryEmbedder +from orchestrator.search.core.exceptions import QueryStateNotFoundError from orchestrator.search.core.types import FilterOp, SearchMetadata from orchestrator.search.filters import FilterTree, LtreeFilter -from orchestrator.search.schemas.parameters import BaseSearchParameters +from orchestrator.search.schemas.parameters import BaseSearchParameters, SearchQueryState from orchestrator.search.schemas.results import MatchingField, SearchResponse, SearchResult from .builder import build_candidate_query @@ -121,6 +125,7 @@ async def _execute_search_internal( db_session: Session, limit: int, pagination_params: PaginationParams | None = None, + query_embedding: list[float] | None = None, ) -> SearchResponse: """Internal function to execute search with specified parameters. @@ -129,6 +134,7 @@ async def _execute_search_internal( db_session: The active SQLAlchemy session for executing the query. limit: Maximum number of results to return. pagination_params: Optional pagination parameters. + query_embedding: Optional pre-computed query embedding to use instead of generating a new one. Returns: SearchResponse with results and embedding (for internal use). 
@@ -141,15 +147,11 @@ async def _execute_search_internal( pagination_params = pagination_params or PaginationParams() - if search_params.vector_query and not pagination_params.q_vec_override: - from orchestrator.search.core.embedding import QueryEmbedder + if search_params.vector_query and not query_embedding: - q_vec = await QueryEmbedder.generate_for_text_async(search_params.vector_query) - if q_vec: - pagination_params.q_vec_override = q_vec - logger.debug("Generated embedding for vector query") + query_embedding = await QueryEmbedder.generate_for_text_async(search_params.vector_query) - retriever = await Retriever.from_params(search_params, pagination_params) + retriever = await Retriever.route(search_params, pagination_params, query_embedding) logger.debug("Using retriever", retriever_type=retriever.__class__.__name__) final_stmt = retriever.apply(candidate_query) @@ -159,7 +161,7 @@ async def _execute_search_internal( response = _format_response(result, search_params, retriever.metadata) # Store embedding in response for agent to save to DB - response.query_embedding = pagination_params.q_vec_override + response.query_embedding = query_embedding return response @@ -167,28 +169,63 @@ async def execute_search( search_params: BaseSearchParameters, db_session: Session, pagination_params: PaginationParams | None = None, + query_embedding: list[float] | None = None, ) -> SearchResponse: """Execute a search and return ranked results.""" - return await _execute_search_internal(search_params, db_session, search_params.limit, pagination_params) + return await _execute_search_internal( + search_params, db_session, search_params.limit, pagination_params, query_embedding + ) + + +def get_query_state(query_id: UUID | str) -> SearchQueryState: + """Retrieve query state from database by query_id. + + Args: + query_id: UUID or string UUID of the saved query + + Returns: + SearchQueryState loaded from database + + Raises: + ValueError: If query_id format is invalid + QueryStateNotFoundError: If query not found in database + """ + if isinstance(query_id, UUID): + query_uuid = query_id + else: + try: + query_uuid = UUID(query_id) + except (ValueError, TypeError) as e: + raise ValueError(f"Invalid query_id format: {query_id}") from e + + search_query = db.session.query(SearchQueryTable).filter_by(query_id=query_uuid).first() + if not search_query: + raise QueryStateNotFoundError(f"Query {query_uuid} not found in database") + + return search_query.to_state() async def execute_search_for_export( - search_params: BaseSearchParameters, + query_state: SearchQueryState, db_session: Session, - pagination_params: PaginationParams | None = None, -) -> SearchResponse: - """Execute a search for export purposes. - - Similar to execute_search but uses export_limit instead of limit. - The pagination_params is primarily used to pass q_vec_override to ensure - the export uses the same embedding as the original search. +) -> list[dict]: + """Execute a search for export and fetch flattened entity data. Args: - search_params: The search parameters specifying vector, fuzzy, or filter criteria. + query_state: Query state containing parameters and query_embedding. db_session: The active SQLAlchemy session for executing the query. - pagination_params: Optional pagination parameters (primarily for q_vec_override). Returns: - SearchResponse with results up to export_limit. + List of flattened entity records suitable for export. 
""" - return await _execute_search_internal(search_params, db_session, search_params.export_limit, pagination_params) + from orchestrator.search.export import fetch_export_data + + search_response = await _execute_search_internal( + search_params=query_state.parameters, + db_session=db_session, + limit=query_state.parameters.export_limit, + query_embedding=query_state.query_embedding, + ) + + entity_ids = [res.entity_id for res in search_response.results] + return fetch_export_data(query_state.parameters.entity_type, entity_ids) diff --git a/orchestrator/search/retrieval/pagination.py b/orchestrator/search/retrieval/pagination.py index be0b46bec..6aaf32cfa 100644 --- a/orchestrator/search/retrieval/pagination.py +++ b/orchestrator/search/retrieval/pagination.py @@ -12,30 +12,28 @@ # limitations under the License. import base64 -from dataclasses import dataclass +from uuid import UUID from pydantic import BaseModel from orchestrator.db import SearchQueryTable, db -from orchestrator.search.core.exceptions import InvalidCursorError +from orchestrator.search.core.exceptions import InvalidCursorError, QueryStateNotFoundError from orchestrator.search.schemas.parameters import SearchParameters, SearchQueryState -from orchestrator.search.schemas.results import SearchResult +from orchestrator.search.schemas.results import SearchResponse -@dataclass -class PaginationParams: +class PaginationParams(BaseModel): """Parameters for pagination in search queries.""" page_after_score: float | None = None page_after_id: str | None = None - q_vec_override: list[float] | None = None - query_id: str | None = None + query_id: UUID | None = None # None only for first page, always set when cursor exists class PageCursor(BaseModel): score: float id: str - query_id: str | None = None + query_id: UUID def encode(self) -> str: """Encode the cursor data into a URL-safe Base64 string.""" @@ -52,66 +50,90 @@ def decode(cls, cursor: str) -> "PageCursor": raise InvalidCursorError("Invalid pagination cursor") from e -async def process_pagination_cursor(cursor: str | None, search_params: SearchParameters) -> PaginationParams: - """Process pagination cursor and return pagination parameters.""" - if cursor: - c = PageCursor.decode(cursor) +async def process_pagination_cursor(cursor: str) -> PaginationParams: + """Decode pagination cursor and extract pagination parameters. - # If cursor has query_id, retrieve saved embedding - if c.query_id: - query = db.session.query(SearchQueryTable).filter_by(query_id=c.query_id).first() - if not query: - raise InvalidCursorError("Query not found") + Args: + cursor: Base64-encoded cursor - query_state = query.to_state() + Returns: + PaginationParams containing page position and query_id from the cursor - return PaginationParams( - page_after_score=c.score, - page_after_id=c.id, - q_vec_override=query_state.query_embedding, - query_id=c.query_id, - ) + Raises: + InvalidCursorError: If cursor cannot be decoded + """ + page_cursor = PageCursor.decode(cursor) + return PaginationParams( + page_after_score=page_cursor.score, + page_after_id=page_cursor.id, + query_id=page_cursor.query_id, + ) + + +def get_query_state(query_id: UUID | str) -> SearchQueryState: + """Retrieve query state from database by query_id. 
+ + Args: + query_id: UUID or string UUID of the saved query + + Returns: + SearchQueryState loaded from database - # No query_id - filter-only or fuzzy-only search - return PaginationParams( - page_after_score=c.score, - page_after_id=c.id, - ) + Raises: + ValueError: If query_id string format is invalid + QueryStateNotFoundError: If query not found in database + """ + + if not isinstance(query_id, UUID): + try: + query_id = UUID(query_id) + except (ValueError, TypeError) as e: + raise ValueError(f"Invalid query_id format: {query_id}") from e - # First page, no embedding needed - # Engine will generate it - return PaginationParams() + search_query = db.session.query(SearchQueryTable).filter_by(query_id=query_id).first() + if not search_query: + raise QueryStateNotFoundError(f"Query {query_id} not found in database") + + return search_query.to_state() def create_next_page_cursor( - search_results: list[SearchResult], + search_response: SearchResponse, pagination_params: PaginationParams, - limit: int, - search_params: SearchParameters | None = None, + search_params: SearchParameters, ) -> str | None: """Create next page cursor if there are more results. - On first page with hybrid search (embedding present), saves the query to database - and includes query_id in cursor for subsequent pages. + On first page, saves the query to database and includes query_id in cursor + for subsequent pages to ensure consistent parameters across pagination. + + Args: + search_response: SearchResponse containing results and query_embedding + pagination_params: Current pagination parameters (may have query_id if not first page) + search_params: Search parameters to save for pagination consistency + + Returns: + Encoded cursor for next page, or None if no more results """ - has_next_page = len(search_results) == limit and limit > 0 + has_next_page = len(search_response.results) == search_params.limit and search_params.limit > 0 if not has_next_page: return None - # If this is the first page and we have an embedding, save to database - if not pagination_params.query_id and pagination_params.q_vec_override and search_params: - # Create query state and save to database - query_state = SearchQueryState(parameters=search_params, query_embedding=pagination_params.q_vec_override) + # If this is the first page, save query state to database + if not pagination_params.query_id: + query_state = SearchQueryState(parameters=search_params, query_embedding=search_response.query_embedding) search_query = SearchQueryTable.from_state(state=query_state) db.session.add(search_query) db.session.commit() - pagination_params.query_id = str(search_query.query_id) + query_id = search_query.query_id + else: + query_id = pagination_params.query_id - last_item = search_results[-1] + last_item = search_response.results[-1] cursor_data = PageCursor( score=float(last_item.score), id=last_item.entity_id, - query_id=pagination_params.query_id, + query_id=query_id, ) return cursor_data.encode() diff --git a/orchestrator/search/retrieval/retrievers/base.py b/orchestrator/search/retrieval/retrievers/base.py index 73921a50c..02bd94c7a 100644 --- a/orchestrator/search/retrieval/retrievers/base.py +++ b/orchestrator/search/retrieval/retrievers/base.py @@ -41,58 +41,49 @@ class Retriever(ABC): ] @classmethod - async def from_params( + async def route( cls, params: BaseSearchParameters, pagination_params: PaginationParams, + query_embedding: list[float] | None = None, ) -> "Retriever": - """Create the appropriate retriever instance from search 
parameters. + """Route to the appropriate retriever instance based on search parameters. + + Selects the retriever type based on available search criteria: + - Hybrid: both embedding and fuzzy term available + - Semantic: only embedding available + - Fuzzy: only text term available (or fallback when embedding generation fails) + - Structured: only filters available Args: - params (BaseSearchParameters): Search parameters including vector queries, fuzzy terms, and filters. - pagination_params (PaginationParams): Pagination parameters for cursor-based paging. + params: Search parameters including vector queries, fuzzy terms, and filters + pagination_params: Pagination parameters for cursor-based paging + query_embedding: Query embedding for semantic search, or None if not available Returns: - Retriever: A concrete retriever instance (semantic, fuzzy, hybrid, or structured). + A concrete retriever instance based on available search criteria """ - from .fuzzy import FuzzyRetriever from .hybrid import RrfHybridRetriever from .semantic import SemanticRetriever from .structured import StructuredRetriever fuzzy_term = params.fuzzy_term - q_vec = await cls._get_query_vector(params.vector_query, pagination_params.q_vec_override) - # If semantic search was attempted but failed, fall back to fuzzy with the full query - fallback_fuzzy_term = fuzzy_term - if q_vec is None and params.vector_query is not None and params.query is not None: - fallback_fuzzy_term = params.query + # If vector_query exists but embedding generation failed, fall back to fuzzy search with full query + if query_embedding is None and params.vector_query is not None and params.query is not None: + fuzzy_term = params.query - if q_vec is not None and fallback_fuzzy_term is not None: - return RrfHybridRetriever(q_vec, fallback_fuzzy_term, pagination_params) - if q_vec is not None: - return SemanticRetriever(q_vec, pagination_params) - if fallback_fuzzy_term is not None: - return FuzzyRetriever(fallback_fuzzy_term, pagination_params) + # Select retriever based on available search criteria + if query_embedding is not None and fuzzy_term is not None: + return RrfHybridRetriever(query_embedding, fuzzy_term, pagination_params) + if query_embedding is not None: + return SemanticRetriever(query_embedding, pagination_params) + if fuzzy_term is not None: + return FuzzyRetriever(fuzzy_term, pagination_params) return StructuredRetriever(pagination_params) - @classmethod - async def _get_query_vector( - cls, vector_query: str | None, q_vec_override: list[float] | None - ) -> list[float] | None: - """Get query vector from override (provided by engine.py).""" - if q_vec_override: - return q_vec_override - - if vector_query: - logger.warning( - "vector_query present but no q_vec_override provided - embedding should be generated in engine.py" - ) - - return None - @abstractmethod def apply(self, candidate_query: Select) -> Select: """Apply the ranking logic to the given candidate query. 
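Note on the pagination rework in this patch: the first page now persists a SearchQueryState (parameters plus query embedding) and every subsequent cursor carries the saved query_id, so later pages re-run the identical query with the identical embedding. The encode/decode bodies are elided from the hunks above, so the following is only a minimal sketch of the cursor round-trip, assuming JSON serialized through URL-safe Base64 as the docstrings suggest:

    import base64
    from uuid import UUID, uuid4

    from pydantic import BaseModel


    class PageCursor(BaseModel):
        score: float
        id: str
        query_id: UUID

        def encode(self) -> str:
            # Serialize to JSON, then wrap in URL-safe Base64 (assumed encoding).
            return base64.urlsafe_b64encode(self.model_dump_json().encode()).decode()

        @classmethod
        def decode(cls, cursor: str) -> "PageCursor":
            # Inverse of encode(); the real implementation raises
            # InvalidCursorError when the cursor cannot be parsed.
            return cls.model_validate_json(base64.urlsafe_b64decode(cursor))


    cursor = PageCursor(score=0.8315, id=str(uuid4()), query_id=uuid4()).encode()
    assert PageCursor.decode(cursor).score == 0.8315

Because the cursor only stores (score, id, query_id), all search parameters live server-side in the search_queries table, which is what keeps page N consistent with page 1 even if the client resends nothing but the cursor.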
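The export path added in patches 09 and 10 then becomes a two-step flow for clients: obtain a query_id from a saved search, then fetch the flattened records from the export endpoint, which re-executes the query up to export_limit. A minimal client-side sketch, assuming the agent router is mounted under /api/agent (the mount prefix is not visible in these hunks) and using the new app_settings.BASE_URL default; download_export is an illustrative helper, not part of the patches:

    import csv

    import httpx

    BASE_URL = "http://localhost:8080"  # matches the new app_settings.BASE_URL default


    def download_export(query_id: str, outfile: str = "export.csv") -> None:
        # Re-executes the saved query server-side; the response body is an
        # ExportResponse, i.e. {"page": [<flattened record>, ...]}.
        resp = httpx.get(f"{BASE_URL}/api/agent/queries/{query_id}/export", timeout=60)
        resp.raise_for_status()
        records = resp.json()["page"]
        if not records:
            return
        # Records are flat dicts, so the first row's keys double as the CSV header.
        with open(outfile, "w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=list(records[0].keys()))
            writer.writeheader()
            writer.writerows(records)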
From d25830d4a11a76b1b5c4ece3cbed22276f04ccb5 Mon Sep 17 00:00:00 2001 From: Tim Frohlich Date: Thu, 16 Oct 2025 17:34:07 +0200 Subject: [PATCH 11/16] Move query state model to a separate file --- orchestrator/api/api_v1/endpoints/agent.py | 7 +- orchestrator/api/api_v1/endpoints/search.py | 68 +++++++++++++------- orchestrator/db/models.py | 13 +--- orchestrator/search/agent/tools.py | 8 +-- orchestrator/search/retrieval/__init__.py | 5 +- orchestrator/search/retrieval/engine.py | 34 +--------- orchestrator/search/retrieval/pagination.py | 33 ++-------- orchestrator/search/retrieval/query_state.py | 61 ++++++++++++++++++ orchestrator/search/schemas/parameters.py | 14 ---- 9 files changed, 121 insertions(+), 122 deletions(-) create mode 100644 orchestrator/search/retrieval/query_state.py diff --git a/orchestrator/api/api_v1/endpoints/agent.py b/orchestrator/api/api_v1/endpoints/agent.py index 40272b7f7..22ca54ed6 100644 --- a/orchestrator/api/api_v1/endpoints/agent.py +++ b/orchestrator/api/api_v1/endpoints/agent.py @@ -26,7 +26,7 @@ from orchestrator.search.agent import build_agent_instance from orchestrator.search.agent.state import SearchState from orchestrator.search.core.exceptions import QueryStateNotFoundError -from orchestrator.search.retrieval import execute_search_for_export, get_query_state +from orchestrator.search.retrieval import SearchQueryState, execute_search_for_export router = APIRouter() logger = get_logger(__name__) @@ -75,11 +75,10 @@ async def export_by_query_id(query_id: str) -> ExportResponse: HTTPException: 404 if query not found, 400 if invalid data """ try: - query_state = get_query_state(query_id) + query_state = SearchQueryState.load_from_id(query_id) export_records = await execute_search_for_export(query_state, db.session) return ExportResponse(page=export_records) - - except (ValueError, TypeError) as e: + except ValueError as e: raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e)) except QueryStateNotFoundError as e: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) diff --git a/orchestrator/api/api_v1/endpoints/search.py b/orchestrator/api/api_v1/endpoints/search.py index 4ad35976a..5272d81f6 100644 --- a/orchestrator/api/api_v1/endpoints/search.py +++ b/orchestrator/api/api_v1/endpoints/search.py @@ -22,7 +22,7 @@ from orchestrator.search.core.exceptions import InvalidCursorError, QueryStateNotFoundError from orchestrator.search.core.types import EntityType, UIType from orchestrator.search.filters.definitions import generate_definitions -from orchestrator.search.retrieval import execute_search, get_query_state +from orchestrator.search.retrieval import SearchQueryState, execute_search from orchestrator.search.retrieval.builder import build_paths_query, create_path_autocomplete_lquery, process_path_rows from orchestrator.search.retrieval.pagination import ( PaginationParams, @@ -43,51 +43,56 @@ async def _perform_search_and_fetch( - search_params: SearchParameters, + search_params: SearchParameters | None = None, cursor: str | None = None, + query_id: str | None = None, ) -> SearchResultsSchema[SearchResult]: - """Execute search and return results. + """Execute search with optional pagination. Args: - search_params: Search parameters - cursor: Pagination cursor + search_params: Search parameters for new search + cursor: Pagination cursor (loads saved query state) + query_id: Saved query ID to retrieve and execute Returns: Search results with entity_id, score, and matching_field. 
""" try: + # Default pagination for first page + pagination_params = PaginationParams() + if cursor: - try: - pagination_params = await process_pagination_cursor(cursor) - if pagination_params.query_id is None: - raise InvalidCursorError("Cursor missing query_id") - except InvalidCursorError as e: - raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e)) - - try: - query_state = get_query_state(pagination_params.query_id) - except QueryStateNotFoundError as e: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) - - search_params = query_state.parameters - query_embedding = query_state.query_embedding + pagination_params = await process_pagination_cursor(cursor) + if pagination_params.query_id is None: + raise InvalidCursorError("Cursor missing query_id") + query_state = SearchQueryState.load_from_id(pagination_params.query_id) + elif query_id: + query_state = SearchQueryState.load_from_id(query_id) + elif search_params: + query_state = SearchQueryState(parameters=search_params, query_embedding=None) else: - pagination_params = PaginationParams() - query_embedding = None + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Either search_params, cursor, or query_id must be provided", + ) - search_response = await execute_search(search_params, db.session, pagination_params, query_embedding) + search_response = await execute_search( + query_state.parameters, db.session, pagination_params, query_state.query_embedding + ) if not search_response.results: return SearchResultsSchema(search_metadata=search_response.metadata) - next_page_cursor = create_next_page_cursor(search_response, pagination_params, search_params) + next_page_cursor = create_next_page_cursor(search_response, pagination_params, query_state.parameters) has_next_page = next_page_cursor is not None page_info = PageInfoSchema(has_next_page=has_next_page, next_page_cursor=next_page_cursor) return SearchResultsSchema( data=search_response.results, page_info=page_info, search_metadata=search_response.metadata ) - except HTTPException: - raise + except (InvalidCursorError, ValueError) as e: + raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e)) + except QueryStateNotFoundError as e: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -163,3 +168,16 @@ async def list_paths( async def get_definitions() -> dict[UIType, TypeDefinition]: """Provide a static definition of operators and schemas for each UI type.""" return generate_definitions() + + +@router.get( + "/queries/{query_id}", + response_model=SearchResultsSchema[SearchResult], + summary="Retrieve saved search results by query_id", +) +async def get_query_results( + query_id: str, + cursor: str | None = None, +) -> SearchResultsSchema[SearchResult]: + """Retrieve and execute a saved search by query_id.""" + return await _perform_search_and_fetch(query_id=query_id, cursor=cursor) diff --git a/orchestrator/db/models.py b/orchestrator/db/models.py index da071d284..000db36d0 100644 --- a/orchestrator/db/models.py +++ b/orchestrator/db/models.py @@ -60,7 +60,7 @@ from orchestrator.version import GIT_COMMIT_HASH if TYPE_CHECKING: - from orchestrator.search.schemas.parameters import SearchQueryState + from orchestrator.search.retrieval.query_state import SearchQueryState logger = structlog.get_logger(__name__) @@ -747,17 +747,6 @@ def from_state( 
query_embedding=state.query_embedding, ) - def to_state(self) -> "SearchQueryState": - """Convert database model to SearchQueryState. - - Returns: - SearchQueryState with typed parameters and embedding vector. - - """ - from orchestrator.search.schemas.parameters import SearchQueryState - - return SearchQueryState.model_validate(self) - class EngineSettingsTable(BaseModel): __tablename__ = "engine_settings" diff --git a/orchestrator/search/agent/tools.py b/orchestrator/search/agent/tools.py index 44c580224..e8536f7cd 100644 --- a/orchestrator/search/agent/tools.py +++ b/orchestrator/search/agent/tools.py @@ -30,8 +30,9 @@ from orchestrator.search.core.types import ActionType, EntityType, FilterOp from orchestrator.search.filters import FilterTree from orchestrator.search.retrieval.exceptions import FilterValidationError, PathNotFoundError +from orchestrator.search.retrieval.query_state import SearchQueryState from orchestrator.search.retrieval.validation import validate_filter_tree -from orchestrator.search.schemas.parameters import BaseSearchParameters, SearchQueryState +from orchestrator.search.schemas.parameters import BaseSearchParameters from orchestrator.settings import app_settings logger = structlog.get_logger(__name__) @@ -177,9 +178,8 @@ async def execute_search( ) # Store results metadata for frontend display - # Frontend may call search endpoints with query_id parameter - entity_type = params.entity_type.value - results_url = f"{app_settings.BASE_URL}/api/search/{entity_type}s?query_id={ctx.deps.state.query_id}" + # Frontend calls the queries endpoint to retrieve saved search results + results_url = f"{app_settings.BASE_URL}/api/search/queries/{ctx.deps.state.query_id}" ctx.deps.state.results_data = SearchResultsData( query_id=str(ctx.deps.state.query_id), diff --git a/orchestrator/search/retrieval/__init__.py b/orchestrator/search/retrieval/__init__.py index c8b3319a3..a6f505450 100644 --- a/orchestrator/search/retrieval/__init__.py +++ b/orchestrator/search/retrieval/__init__.py @@ -11,6 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .engine import execute_search, execute_search_for_export, get_query_state +from .engine import execute_search, execute_search_for_export +from .query_state import SearchQueryState -__all__ = ["execute_search", "execute_search_for_export", "get_query_state"] +__all__ = ["execute_search", "execute_search_for_export", "SearchQueryState"] diff --git a/orchestrator/search/retrieval/engine.py b/orchestrator/search/retrieval/engine.py index c22bfe54e..cf00e71e3 100644 --- a/orchestrator/search/retrieval/engine.py +++ b/orchestrator/search/retrieval/engine.py @@ -12,22 +12,20 @@ # limitations under the License. 
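Dropping `SearchQueryTable.to_state()` works because the relocated `SearchQueryState` (shown later in this patch) declares `model_config = ConfigDict(from_attributes=True)`, so `SearchQueryState.model_validate(search_query)` can read fields straight off the ORM row. A self-contained illustration of that Pydantic mechanism:

```python
# Standalone demo of from_attributes; RowLike mimics a SQLAlchemy row
# (attribute access only, no mapping interface).
from pydantic import BaseModel, ConfigDict


class RowLike:
    parameters = {"entity_type": "SUBSCRIPTION", "query": "firewalls"}
    query_embedding = None


class StateSketch(BaseModel):
    model_config = ConfigDict(from_attributes=True)

    parameters: dict
    query_embedding: list[float] | None = None


print(StateSketch.model_validate(RowLike()))  # parameters={...} query_embedding=None
```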
from collections.abc import Sequence -from uuid import UUID import structlog from sqlalchemy.engine.row import RowMapping from sqlalchemy.orm import Session -from orchestrator.db import SearchQueryTable, db from orchestrator.search.core.embedding import QueryEmbedder -from orchestrator.search.core.exceptions import QueryStateNotFoundError from orchestrator.search.core.types import FilterOp, SearchMetadata from orchestrator.search.filters import FilterTree, LtreeFilter -from orchestrator.search.schemas.parameters import BaseSearchParameters, SearchQueryState +from orchestrator.search.schemas.parameters import BaseSearchParameters from orchestrator.search.schemas.results import MatchingField, SearchResponse, SearchResult from .builder import build_candidate_query from .pagination import PaginationParams +from .query_state import SearchQueryState from .retrievers import Retriever from .utils import generate_highlight_indices @@ -177,34 +175,6 @@ async def execute_search( ) -def get_query_state(query_id: UUID | str) -> SearchQueryState: - """Retrieve query state from database by query_id. - - Args: - query_id: UUID or string UUID of the saved query - - Returns: - SearchQueryState loaded from database - - Raises: - ValueError: If query_id format is invalid - QueryStateNotFoundError: If query not found in database - """ - if isinstance(query_id, UUID): - query_uuid = query_id - else: - try: - query_uuid = UUID(query_id) - except (ValueError, TypeError) as e: - raise ValueError(f"Invalid query_id format: {query_id}") from e - - search_query = db.session.query(SearchQueryTable).filter_by(query_id=query_uuid).first() - if not search_query: - raise QueryStateNotFoundError(f"Query {query_uuid} not found in database") - - return search_query.to_state() - - async def execute_search_for_export( query_state: SearchQueryState, db_session: Session, diff --git a/orchestrator/search/retrieval/pagination.py b/orchestrator/search/retrieval/pagination.py index 6aaf32cfa..f672b8b98 100644 --- a/orchestrator/search/retrieval/pagination.py +++ b/orchestrator/search/retrieval/pagination.py @@ -17,8 +17,8 @@ from pydantic import BaseModel from orchestrator.db import SearchQueryTable, db -from orchestrator.search.core.exceptions import InvalidCursorError, QueryStateNotFoundError -from orchestrator.search.schemas.parameters import SearchParameters, SearchQueryState +from orchestrator.search.core.exceptions import InvalidCursorError +from orchestrator.search.schemas.parameters import SearchParameters from orchestrator.search.schemas.results import SearchResponse @@ -70,33 +70,6 @@ async def process_pagination_cursor(cursor: str) -> PaginationParams: ) -def get_query_state(query_id: UUID | str) -> SearchQueryState: - """Retrieve query state from database by query_id. 
- - Args: - query_id: UUID or string UUID of the saved query - - Returns: - SearchQueryState loaded from database - - Raises: - ValueError: If query_id string format is invalid - QueryStateNotFoundError: If query not found in database - """ - - if not isinstance(query_id, UUID): - try: - query_id = UUID(query_id) - except (ValueError, TypeError) as e: - raise ValueError(f"Invalid query_id format: {query_id}") from e - - search_query = db.session.query(SearchQueryTable).filter_by(query_id=query_id).first() - if not search_query: - raise QueryStateNotFoundError(f"Query {query_id} not found in database") - - return search_query.to_state() - - def create_next_page_cursor( search_response: SearchResponse, pagination_params: PaginationParams, @@ -115,6 +88,8 @@ def create_next_page_cursor( Returns: Encoded cursor for next page, or None if no more results """ + from orchestrator.search.retrieval.query_state import SearchQueryState + has_next_page = len(search_response.results) == search_params.limit and search_params.limit > 0 if not has_next_page: return None diff --git a/orchestrator/search/retrieval/query_state.py b/orchestrator/search/retrieval/query_state.py new file mode 100644 index 000000000..2dc80574b --- /dev/null +++ b/orchestrator/search/retrieval/query_state.py @@ -0,0 +1,61 @@ +# Copyright 2019-2025 SURF, GÉANT. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from uuid import UUID + +from pydantic import BaseModel, ConfigDict, Field + +from orchestrator.db import SearchQueryTable, db +from orchestrator.search.core.exceptions import QueryStateNotFoundError +from orchestrator.search.schemas.parameters import SearchParameters + + +class SearchQueryState(BaseModel): + """State of a search query including parameters and embedding. + + This model provides a complete snapshot of what was searched and how. + Used for both agent and regular API searches. + """ + + parameters: SearchParameters = Field(discriminator="entity_type") + query_embedding: list[float] | None = Field(default=None, description="The embedding vector for semantic search") + + model_config = ConfigDict(from_attributes=True) + + @classmethod + def load_from_id(cls, query_id: UUID | str) -> "SearchQueryState": + """Load query state from database by query_id. 
+ + Args: + query_id: UUID or string UUID of the saved query + + Returns: + SearchQueryState loaded from database + + Raises: + ValueError: If query_id format is invalid + QueryStateNotFoundError: If query not found in database + """ + if isinstance(query_id, UUID): + query_uuid = query_id + else: + try: + query_uuid = UUID(query_id) + except (ValueError, TypeError) as e: + raise ValueError(f"Invalid query_id format: {query_id}") from e + + search_query = db.session.query(SearchQueryTable).filter_by(query_id=query_uuid).first() + if not search_query: + raise QueryStateNotFoundError(f"Query {query_uuid} not found in database") + + return cls.model_validate(search_query) diff --git a/orchestrator/search/schemas/parameters.py b/orchestrator/search/schemas/parameters.py index ccd85b794..0a006d9e1 100644 --- a/orchestrator/search/schemas/parameters.py +++ b/orchestrator/search/schemas/parameters.py @@ -131,17 +131,3 @@ class ProcessSearchParameters(BaseSearchParameters): SearchParameters = ( SubscriptionSearchParameters | ProductSearchParameters | WorkflowSearchParameters | ProcessSearchParameters ) - - -class SearchQueryState(BaseModel): - """Complete state of a search query including parameters and embedding. - - This model combines the search parameters with the query embedding, - providing a complete snapshot of what was searched and how. - Used for both agent and regular API searches. - """ - - parameters: SearchParameters = Field(discriminator="entity_type") - query_embedding: list[float] | None = Field(default=None, description="The embedding vector for semantic search") - - model_config = ConfigDict(from_attributes=True) From 17160a9c7f3a681abcc0972390946e771ed2981e Mon Sep 17 00:00:00 2001 From: Tim Frohlich Date: Thu, 16 Oct 2025 20:47:25 +0200 Subject: [PATCH 12/16] tool calling --- orchestrator/api/api_v1/endpoints/agent.py | 40 -------- orchestrator/api/api_v1/endpoints/search.py | 43 ++++++++- orchestrator/search/agent/prompts.py | 23 +++-- orchestrator/search/agent/state.py | 5 +- orchestrator/search/agent/tools.py | 100 +++++++++++++------- 5 files changed, 128 insertions(+), 83 deletions(-) diff --git a/orchestrator/api/api_v1/endpoints/agent.py b/orchestrator/api/api_v1/endpoints/agent.py index 22ca54ed6..eb29b43c8 100644 --- a/orchestrator/api/api_v1/endpoints/agent.py +++ b/orchestrator/api/api_v1/endpoints/agent.py @@ -20,13 +20,9 @@ from starlette.responses import Response from structlog import get_logger -from orchestrator.db import db from orchestrator.llm_settings import llm_settings -from orchestrator.schemas.search import ExportResponse from orchestrator.search.agent import build_agent_instance from orchestrator.search.agent.state import SearchState -from orchestrator.search.core.exceptions import QueryStateNotFoundError -from orchestrator.search.retrieval import SearchQueryState, execute_search_for_export router = APIRouter() logger = get_logger(__name__) @@ -52,39 +48,3 @@ async def agent_conversation( """ initial_state = SearchState() return await handle_ag_ui_request(agent, request, deps=StateDeps(initial_state)) - - -@router.get( - "/queries/{query_id}/export", - summary="Export query results by query_id", - response_model=ExportResponse, -) -async def export_by_query_id(query_id: str) -> ExportResponse: - """Export search results using query_id. - - The query is retrieved from the database, re-executed, and results are returned - as flattened records suitable for CSV download. 
- - Args: - query_id: Query UUID - - Returns: - ExportResponse containing 'page' with an array of flattened entity records. - - Raises: - HTTPException: 404 if query not found, 400 if invalid data - """ - try: - query_state = SearchQueryState.load_from_id(query_id) - export_records = await execute_search_for_export(query_state, db.session) - return ExportResponse(page=export_records) - except ValueError as e: - raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e)) - except QueryStateNotFoundError as e: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) - except Exception as e: - logger.error(e) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error executing export: {str(e)}", - ) diff --git a/orchestrator/api/api_v1/endpoints/search.py b/orchestrator/api/api_v1/endpoints/search.py index 5272d81f6..c07b2221c 100644 --- a/orchestrator/api/api_v1/endpoints/search.py +++ b/orchestrator/api/api_v1/endpoints/search.py @@ -11,10 +11,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import structlog from fastapi import APIRouter, HTTPException, Query, status from orchestrator.db import db from orchestrator.schemas.search import ( + ExportResponse, PageInfoSchema, PathsResponse, SearchResultsSchema, @@ -22,7 +24,7 @@ from orchestrator.search.core.exceptions import InvalidCursorError, QueryStateNotFoundError from orchestrator.search.core.types import EntityType, UIType from orchestrator.search.filters.definitions import generate_definitions -from orchestrator.search.retrieval import SearchQueryState, execute_search +from orchestrator.search.retrieval import SearchQueryState, execute_search, execute_search_for_export from orchestrator.search.retrieval.builder import build_paths_query, create_path_autocomplete_lquery, process_path_rows from orchestrator.search.retrieval.pagination import ( PaginationParams, @@ -40,6 +42,7 @@ from orchestrator.search.schemas.results import SearchResult, TypeDefinition router = APIRouter() +logger = structlog.get_logger(__name__) async def _perform_search_and_fetch( @@ -175,9 +178,45 @@ async def get_definitions() -> dict[UIType, TypeDefinition]: response_model=SearchResultsSchema[SearchResult], summary="Retrieve saved search results by query_id", ) -async def get_query_results( +async def get_by_query_id( query_id: str, cursor: str | None = None, ) -> SearchResultsSchema[SearchResult]: """Retrieve and execute a saved search by query_id.""" return await _perform_search_and_fetch(query_id=query_id, cursor=cursor) + + +@router.get( + "/queries/{query_id}/export", + summary="Export query results by query_id", + response_model=ExportResponse, +) +async def export_by_query_id(query_id: str) -> ExportResponse: + """Export search results using query_id. + + The query is retrieved from the database, re-executed, and results are returned + as flattened records suitable for CSV download. + + Args: + query_id: Query UUID + + Returns: + ExportResponse containing 'page' with an array of flattened entity records. 
+ + Raises: + HTTPException: 404 if query not found, 400 if invalid data + """ + try: + query_state = SearchQueryState.load_from_id(query_id) + export_records = await execute_search_for_export(query_state, db.session) + return ExportResponse(page=export_records) + except ValueError as e: + raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e)) + except QueryStateNotFoundError as e: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) + except Exception as e: + logger.error(e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Error executing export: {str(e)}", + ) diff --git a/orchestrator/search/agent/prompts.py b/orchestrator/search/agent/prompts.py index 4c9ff8b6e..41d6fbf33 100644 --- a/orchestrator/search/agent/prompts.py +++ b/orchestrator/search/agent/prompts.py @@ -50,13 +50,13 @@ async def get_base_instructions() -> str: Follow these steps in strict order: - 1. **Set Context**: Always begin by calling `set_search_parameters`. + 1. **Set Context**: If the user is asking for a NEW search, call `start_new_search`. If the user is asking a question about EXISTING results shown in the "Current Search Results" section, skip to step 4 (Report) and answer using the data provided. 2. **Analyze for Filters**: Based on the user's request, decide if specific filters are necessary. - **If filters ARE required**, follow these sub-steps: a. **Gather Intel**: Identify all needed field names, then call `discover_filter_paths` and `get_valid_operators` **once each** to get all required information. b. **Construct FilterTree**: Build the `FilterTree` object. c. **Set Filters**: Call `set_filter_tree`. - 3. **Execute**: Call `execute_search`. This is done for both filtered and non-filtered searches. + 3. **Execute**: Call `run_search`. This is done for both filtered and non-filtered searches. 4. **Report**: Answer the users' question directly and summarize when appropiate. 5. **Export (if requested)**: If the user asks to export, download, or save results as CSV/file: - **IMPORTANT**: Export is ONLY available for SELECT actions (not COUNT or AGGREGATE) @@ -82,22 +82,31 @@ async def get_dynamic_instructions(ctx: RunContext[StateDeps[SearchState]]) -> s results_count = state.results_data.total_count if state.results_data else 0 next_step_guidance = "" + results_section = "" + if not state.parameters or not state.parameters.get("entity_type"): next_step_guidance = ( - "INSTRUCTION: The search context is not set. Your next action is to call `set_search_parameters`." + "INSTRUCTION: The search context is not set. Your next action is to call `start_new_search`." ) elif results_count > 0: next_step_guidance = ( - f"INSTRUCTION: Search completed with {results_count} results. " - "You can answer the user's question with these results. " + f"INSTRUCTION: Search completed with {results_count} results shown below. " + "Answer the user's question using the actual result data provided. " "If the user requests an export/download, call `prepare_export`." ) + + # Include results data in the prompt + if state.results_data and state.results_data.results: + results_section = "\n\n**Current Search Results:**\n" + for idx, result in enumerate(state.results_data.results, 1): + results_section += f"{idx}. **{result.entity_title}** (ID: {result.entity_id}, Score: {result.score:.2f})\n" else: next_step_guidance = ( "INSTRUCTION: Context is set. Now, analyze the user's request. 
" "If specific filters ARE required, use the information-gathering tools to build a `FilterTree` and call `set_filter_tree`. " - "If no specific filters are needed, you can proceed directly to `execute_search`." + "If no specific filters are needed, you can proceed directly to `run_search`." ) + return dedent( f""" --- @@ -110,6 +119,6 @@ async def get_dynamic_instructions(ctx: RunContext[StateDeps[SearchState]]) -> s **Current Results Count:** {results_count} - **{next_step_guidance}** + **{next_step_guidance}**{results_section} """ ) diff --git a/orchestrator/search/agent/state.py b/orchestrator/search/agent/state.py index 027fd84fa..075e3a192 100644 --- a/orchestrator/search/agent/state.py +++ b/orchestrator/search/agent/state.py @@ -16,6 +16,8 @@ from pydantic import BaseModel +from orchestrator.search.schemas.results import SearchResult + class ExportData(BaseModel): """Export metadata for download.""" @@ -27,13 +29,14 @@ class ExportData(BaseModel): class SearchResultsData(BaseModel): - """Search results metadata for frontend display.""" + """Search results data for frontend display and agent context.""" action: str = "view_results" query_id: str results_url: str total_count: int message: str + results: list[SearchResult] = [] class SearchState(BaseModel): diff --git a/orchestrator/search/agent/tools.py b/orchestrator/search/agent/tools.py index e8536f7cd..2835a56e1 100644 --- a/orchestrator/search/agent/tools.py +++ b/orchestrator/search/agent/tools.py @@ -29,6 +29,7 @@ from orchestrator.search.agent.state import ExportData, SearchResultsData, SearchState from orchestrator.search.core.types import ActionType, EntityType, FilterOp from orchestrator.search.filters import FilterTree +from orchestrator.search.retrieval.engine import execute_search from orchestrator.search.retrieval.exceptions import FilterValidationError, PathNotFoundError from orchestrator.search.retrieval.query_state import SearchQueryState from orchestrator.search.retrieval.validation import validate_filter_tree @@ -49,32 +50,82 @@ def last_user_message(ctx: RunContext[StateDeps[SearchState]]) -> str | None: return None +def _set_parameters( + ctx: RunContext[StateDeps[SearchState]], + entity_type: EntityType, + action: str | ActionType, + query: str, + filters: Any | None, +) -> None: + """Internal helper to set parameters.""" + ctx.deps.state.parameters = { + "action": action, + "entity_type": entity_type, + "filters": filters, + "query": query, + } + + @search_toolset.tool -async def set_search_parameters( +async def start_new_search( ctx: RunContext[StateDeps[SearchState]], entity_type: EntityType, action: str | ActionType = ActionType.SELECT, ) -> StateSnapshotEvent: - """Sets the initial search context, like the entity type and the user's query. + """Starts a completely new search, clearing all previous state. - This MUST be the first tool called to start any new search. - Warning: Calling this tool will erase any existing filters and search results from the state. + This MUST be the first tool called when the user asks for a NEW search. + Warning: This will erase any existing filters, results, and search state. 
""" - params = ctx.deps.state.parameters or {} - is_new_search = params.get("entity_type") != entity_type.value - final_query = (last_user_message(ctx) or "") if is_new_search else params.get("query", "") + final_query = last_user_message(ctx) or "" logger.debug( - "Setting search parameters", + "Starting new search", entity_type=entity_type.value, action=action, - is_new_search=is_new_search, query=final_query, ) - ctx.deps.state.parameters = {"action": action, "entity_type": entity_type, "filters": None, "query": final_query} + # Clear all state ctx.deps.state.results_data = None - logger.debug("Search parameters set", parameters=ctx.deps.state.parameters) + ctx.deps.state.export_data = None + + # Set fresh parameters with no filters + _set_parameters(ctx, entity_type, action, final_query, None) + + logger.debug("New search started", parameters=ctx.deps.state.parameters) + + return StateSnapshotEvent( + type=EventType.STATE_SNAPSHOT, + snapshot=ctx.deps.state.model_dump(), + ) + + +@search_toolset.tool +async def set_search_parameters( + ctx: RunContext[StateDeps[SearchState]], + entity_type: EntityType, + action: str | ActionType = ActionType.SELECT, +) -> StateSnapshotEvent: + """Updates search parameters without clearing filters or results. + + Use this to modify the entity type or action while preserving existing filters. + For a completely new search, use start_new_search instead. + """ + params = ctx.deps.state.parameters or {} + existing_filters = params.get("filters") + existing_query = params.get("query", "") + + logger.debug( + "Updating search parameters", + entity_type=entity_type.value, + action=action, + preserving_filters=existing_filters is not None, + ) + + _set_parameters(ctx, entity_type, action, existing_query, existing_filters) + + logger.debug("Search parameters updated", parameters=ctx.deps.state.parameters) return StateSnapshotEvent( type=EventType.STATE_SNAPSHOT, @@ -125,7 +176,7 @@ async def set_filter_tree( @search_toolset.tool -async def execute_search( +async def run_search( ctx: RunContext[StateDeps[SearchState]], limit: int = 10, ) -> StateSnapshotEvent: @@ -156,8 +207,6 @@ async def execute_search( logger.debug("Created new agent run", run_id=str(agent_run.run_id)) # Get query with embedding and save to DB - from orchestrator.search.retrieval.engine import execute_search - search_response = await execute_search(params, db.session) query_embedding = search_response.query_embedding query_state = SearchQueryState(parameters=params, query_embedding=query_embedding) @@ -177,8 +226,7 @@ async def execute_search( total_results=len(search_response.results), ) - # Store results metadata for frontend display - # Frontend calls the queries endpoint to retrieve saved search results + # Store results data for both frontend display and agent context results_url = f"{app_settings.BASE_URL}/api/search/queries/{ctx.deps.state.query_id}" ctx.deps.state.results_data = SearchResultsData( @@ -186,6 +234,7 @@ async def execute_search( results_url=results_url, total_count=len(search_response.results), message=f"Found {len(search_response.results)} results.", + results=search_response.results, # Include actual results in state ) return StateSnapshotEvent(type=EventType.STATE_SNAPSHOT, snapshot=ctx.deps.state.model_dump()) @@ -270,7 +319,6 @@ async def get_valid_operators() -> dict[str, list[FilterOp]]: @search_toolset.tool async def prepare_export( ctx: RunContext[StateDeps[SearchState]], - max_results: int = 1000, ) -> StateSnapshotEvent: """Prepares export URL using the last 
executed search query.""" if not ctx.deps.state.query_id or not ctx.deps.state.run_id: @@ -287,31 +335,17 @@ async def prepare_export( "Please run a SELECT search first." ) - # Retrieve the saved query to update export_limit if needed - agent_query = db.session.query(SearchQueryTable).filter_by(query_id=ctx.deps.state.query_id).first() - if not agent_query: - raise ValueError("Query not found in database") - - export_limit = min(max_results, BaseSearchParameters.DEFAULT_EXPORT_LIMIT) - - # Update the parameters with export_limit - params_dict = agent_query.parameters.copy() - params_dict["export_limit"] = export_limit - agent_query.parameters = params_dict - db.session.commit() - logger.debug( "Prepared query for export", query_id=str(ctx.deps.state.query_id), - export_limit=export_limit, ) - download_url = f"{app_settings.BASE_URL}/api/agent/queries/{ctx.deps.state.query_id}/export" + download_url = f"{app_settings.BASE_URL}/api/search/queries/{ctx.deps.state.query_id}/export" ctx.deps.state.export_data = ExportData( query_id=str(ctx.deps.state.query_id), download_url=download_url, - message=f"Export ready for download (up to {export_limit} results).", + message="Export ready for download.", ) logger.debug("Export data set in state", export_data=ctx.deps.state.export_data.model_dump()) From 1f0f2c767d85f95f6edda44a1c57bcaa1772eecb Mon Sep 17 00:00:00 2001 From: Tim Frohlich Date: Wed, 22 Oct 2025 17:29:30 +0200 Subject: [PATCH 13/16] Simplify cursor logic, improve dynamic prompt instructions, use delta state updates instead of full snapshots --- orchestrator/agentic_app.py | 4 - orchestrator/api/api_v1/endpoints/agent.py | 4 +- orchestrator/api/api_v1/endpoints/search.py | 19 ++-- orchestrator/cli/search/speedtest.py | 6 +- orchestrator/search/agent/json_patch.py | 51 ++++++++++ orchestrator/search/agent/prompts.py | 50 +++++----- orchestrator/search/agent/tools.py | 97 +++++++++++++++++-- orchestrator/search/retrieval/engine.py | 16 ++- orchestrator/search/retrieval/pagination.py | 38 +------- .../search/retrieval/retrievers/base.py | 14 +-- .../search/retrieval/retrievers/fuzzy.py | 15 ++- .../search/retrieval/retrievers/hybrid.py | 13 ++- .../search/retrieval/retrievers/semantic.py | 13 ++- .../search/retrieval/retrievers/structured.py | 10 +- 14 files changed, 215 insertions(+), 135 deletions(-) create mode 100644 orchestrator/search/agent/json_patch.py diff --git a/orchestrator/agentic_app.py b/orchestrator/agentic_app.py index b9014863a..e1a0353f7 100644 --- a/orchestrator/agentic_app.py +++ b/orchestrator/agentic_app.py @@ -27,7 +27,6 @@ if TYPE_CHECKING: from pydantic_ai.models.openai import OpenAIModel - from pydantic_ai.toolsets import FunctionToolset logger = get_logger(__name__) @@ -38,7 +37,6 @@ def __init__( *args: Any, llm_settings: LLMSettings = llm_settings, agent_model: "OpenAIModel | str | None" = None, - agent_tools: "list[FunctionToolset] | None" = None, **kwargs: Any, ) -> None: """Initialize the `LLMOrchestratorCore` class. @@ -50,7 +48,6 @@ def __init__( *args: All the normal arguments passed to the `OrchestratorCore` class. llm_settings: A class of settings for the LLM agent_model: Override the agent model (defaults to llm_settings.AGENT_MODEL) - agent_tools: A list of tools that can be used by the agent **kwargs: Additional arguments passed to the `OrchestratorCore` class. 
Returns: @@ -58,7 +55,6 @@ def __init__( """ self.llm_settings = llm_settings self.agent_model = agent_model or llm_settings.AGENT_MODEL - self.agent_tools = agent_tools super().__init__(*args, **kwargs) diff --git a/orchestrator/api/api_v1/endpoints/agent.py b/orchestrator/api/api_v1/endpoints/agent.py index eb29b43c8..832b48356 100644 --- a/orchestrator/api/api_v1/endpoints/agent.py +++ b/orchestrator/api/api_v1/endpoints/agent.py @@ -14,7 +14,7 @@ from functools import cache from typing import Annotated -from fastapi import APIRouter, Depends, HTTPException, Request, status +from fastapi import APIRouter, Depends, Request from pydantic_ai.ag_ui import StateDeps, handle_ag_ui_request from pydantic_ai.agent import Agent from starlette.responses import Response @@ -34,7 +34,7 @@ def get_agent() -> Agent[StateDeps[SearchState], str]: The agent is built once and cached for the lifetime of the application. """ - return build_agent_instance(llm_settings.AGENT_MODEL, agent_tools=None) + return build_agent_instance(llm_settings.AGENT_MODEL) @router.post("/") diff --git a/orchestrator/api/api_v1/endpoints/search.py b/orchestrator/api/api_v1/endpoints/search.py index c07b2221c..ac5fa4dcc 100644 --- a/orchestrator/api/api_v1/endpoints/search.py +++ b/orchestrator/api/api_v1/endpoints/search.py @@ -26,11 +26,7 @@ from orchestrator.search.filters.definitions import generate_definitions from orchestrator.search.retrieval import SearchQueryState, execute_search, execute_search_for_export from orchestrator.search.retrieval.builder import build_paths_query, create_path_autocomplete_lquery, process_path_rows -from orchestrator.search.retrieval.pagination import ( - PaginationParams, - create_next_page_cursor, - process_pagination_cursor, -) +from orchestrator.search.retrieval.pagination import PageCursor, encode_next_page_cursor from orchestrator.search.retrieval.validation import is_lquery_syntactically_valid from orchestrator.search.schemas.parameters import ( ProcessSearchParameters, @@ -61,14 +57,11 @@ async def _perform_search_and_fetch( Search results with entity_id, score, and matching_field. 
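The cursor decoded in the body that follows is just a base64-wrapped record of `(score, id, query_id)`; `PageCursor.decode` reverses the encoding and raises `InvalidCursorError` on malformed input. A toy codec in the same spirit (the project's exact wire format is not shown here, so the field layout is an assumption):

```python
# Minimal keyset-cursor codec sketch; illustrative, not the real PageCursor.
import base64
import json


def encode_cursor(score: float, id_: str, query_id: str) -> str:
    payload = json.dumps({"score": score, "id": id_, "query_id": query_id})
    return base64.urlsafe_b64encode(payload.encode()).decode()


def decode_cursor(cursor: str) -> dict:
    try:
        return json.loads(base64.urlsafe_b64decode(cursor.encode()))
    except ValueError as e:  # covers binascii.Error and JSONDecodeError
        raise ValueError("Invalid pagination cursor") from e


token = encode_cursor(0.93, "42", "9b2f0c1e")
assert decode_cursor(token)["id"] == "42"
```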
""" try: - # Default pagination for first page - pagination_params = PaginationParams() + page_cursor: PageCursor | None = None if cursor: - pagination_params = await process_pagination_cursor(cursor) - if pagination_params.query_id is None: - raise InvalidCursorError("Cursor missing query_id") - query_state = SearchQueryState.load_from_id(pagination_params.query_id) + page_cursor = PageCursor.decode(cursor) + query_state = SearchQueryState.load_from_id(page_cursor.query_id) elif query_id: query_state = SearchQueryState.load_from_id(query_id) elif search_params: @@ -80,12 +73,12 @@ async def _perform_search_and_fetch( ) search_response = await execute_search( - query_state.parameters, db.session, pagination_params, query_state.query_embedding + query_state.parameters, db.session, page_cursor, query_state.query_embedding ) if not search_response.results: return SearchResultsSchema(search_metadata=search_response.metadata) - next_page_cursor = create_next_page_cursor(search_response, pagination_params, query_state.parameters) + next_page_cursor = encode_next_page_cursor(search_response, page_cursor, query_state.parameters) has_next_page = next_page_cursor is not None page_info = PageInfoSchema(has_next_page=has_next_page, next_page_cursor=next_page_cursor) diff --git a/orchestrator/cli/search/speedtest.py b/orchestrator/cli/search/speedtest.py index f3f1d9306..5c215c219 100644 --- a/orchestrator/cli/search/speedtest.py +++ b/orchestrator/cli/search/speedtest.py @@ -13,7 +13,6 @@ from orchestrator.search.core.types import EntityType from orchestrator.search.core.validators import is_uuid from orchestrator.search.retrieval.engine import execute_search -from orchestrator.search.retrieval.pagination import PaginationParams from orchestrator.search.schemas.parameters import BaseSearchParameters logger = structlog.get_logger(__name__) @@ -54,7 +53,6 @@ async def generate_embeddings_for_queries(queries: list[str]) -> dict[str, list[ async def run_single_query(query: str, embedding_lookup: dict[str, list[float]]) -> dict[str, Any]: search_params = BaseSearchParameters(entity_type=EntityType.SUBSCRIPTION, query=query, limit=30) - pagination_params = PaginationParams() query_embedding = None if is_uuid(query): @@ -64,9 +62,7 @@ async def run_single_query(query: str, embedding_lookup: dict[str, list[float]]) with db.session as session: start_time = time.perf_counter() - response = await execute_search( - search_params, session, pagination_params=pagination_params, query_embedding=query_embedding - ) + response = await execute_search(search_params, session, cursor=None, query_embedding=query_embedding) end_time = time.perf_counter() return { diff --git a/orchestrator/search/agent/json_patch.py b/orchestrator/search/agent/json_patch.py new file mode 100644 index 000000000..1758c5704 --- /dev/null +++ b/orchestrator/search/agent/json_patch.py @@ -0,0 +1,51 @@ +# Copyright 2019-2025 SURF, GÉANT. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Any, Literal + +from pydantic import BaseModel, Field + + +class JSONPatchOp(BaseModel): + """A JSON Patch operation (RFC 6902). + + Docs reference: https://docs.ag-ui.com/concepts/state + """ + + op: Literal["add", "remove", "replace", "move", "copy", "test"] = Field( + description="The operation to perform: add, remove, replace, move, copy, or test" + ) + path: str = Field(description="JSON Pointer (RFC 6901) to the target location") + value: Any | None = Field( + default=None, + description="The value to apply (for add, replace operations)", + ) + from_: str | None = Field( + default=None, + alias="from", + description="Source path (for move, copy operations)", + ) + + @classmethod + def upsert(cls, path: str, value: Any, existed: bool) -> "JSONPatchOp": + """Create an add or replace operation depending on whether the path existed. + + Args: + path: JSON Pointer path to the target location + value: The value to set + existed: True if the path already exists (use replace), False otherwise (use add) + + Returns: + JSONPatchOp with 'replace' if existed is True, 'add' otherwise + """ + return cls(op="replace" if existed else "add", path=path, value=value) diff --git a/orchestrator/search/agent/prompts.py b/orchestrator/search/agent/prompts.py index 41d6fbf33..e0567f4ab 100644 --- a/orchestrator/search/agent/prompts.py +++ b/orchestrator/search/agent/prompts.py @@ -50,19 +50,15 @@ async def get_base_instructions() -> str: Follow these steps in strict order: - 1. **Set Context**: If the user is asking for a NEW search, call `start_new_search`. If the user is asking a question about EXISTING results shown in the "Current Search Results" section, skip to step 4 (Report) and answer using the data provided. + 1. **Set Context**: If the user is asking for a NEW search, call `start_new_search`. 2. **Analyze for Filters**: Based on the user's request, decide if specific filters are necessary. - **If filters ARE required**, follow these sub-steps: a. **Gather Intel**: Identify all needed field names, then call `discover_filter_paths` and `get_valid_operators` **once each** to get all required information. b. **Construct FilterTree**: Build the `FilterTree` object. c. **Set Filters**: Call `set_filter_tree`. 3. **Execute**: Call `run_search`. This is done for both filtered and non-filtered searches. - 4. **Report**: Answer the users' question directly and summarize when appropiate. - 5. **Export (if requested)**: If the user asks to export, download, or save results as CSV/file: - - **IMPORTANT**: Export is ONLY available for SELECT actions (not COUNT or AGGREGATE) - - Call `prepare_export` to save the query and generate a download URL - - The UI will automatically display a download button - you don't need to mention URLs or IDs - - Simply confirm to the user that the export is ready for download + + After search execution, follow the dynamic instructions based on the current state. --- ### 4. Critical Rules @@ -70,7 +66,6 @@ async def get_base_instructions() -> str: - **NEVER GUESS PATHS IN THE DATABASE**: You *must* verify every filter path by calling `discover_filter_paths` first. If a path does not exist, you may attempt to map the question on an existing paths that are valid and available from `discover_filter_paths`. If you cannot infer a match, inform the user and do not include it in the `FilterTree`. - **USE FULL PATHS**: Always use the full, unambiguous path returned by the discovery tool. 
- **MATCH OPERATORS**: Only use operators that are compatible with the field type as confirmed by `get_filter_operators`. - - **EXPORT RECOGNITION**: When users say things like "export this", "download as CSV", "save these results", "export to file", or similar phrases, they are requesting an export. Call `prepare_export` to handle this. """ ) @@ -81,25 +76,29 @@ async def get_dynamic_instructions(ctx: RunContext[StateDeps[SearchState]]) -> s param_state_str = json.dumps(state.parameters, indent=2, default=str) if state.parameters else "Not set." results_count = state.results_data.total_count if state.results_data else 0 - next_step_guidance = "" - results_section = "" - - if not state.parameters or not state.parameters.get("entity_type"): + if state.export_data: + next_step_guidance = ( + "INSTRUCTION: Export has been prepared successfully. " + "Simply confirm to the user that the export is ready for download. " + "DO NOT include or mention the download URL - the UI will display it automatically." + ) + elif not state.parameters or not state.parameters.get("entity_type"): next_step_guidance = ( "INSTRUCTION: The search context is not set. Your next action is to call `start_new_search`." ) elif results_count > 0: - next_step_guidance = ( - f"INSTRUCTION: Search completed with {results_count} results shown below. " - "Answer the user's question using the actual result data provided. " - "If the user requests an export/download, call `prepare_export`." + next_step_guidance = dedent( + f""" + INSTRUCTION: Search completed successfully. + Found {results_count} results containing only: entity_id, title, score. + + Choose your next action based on what the user requested: + 1. **Broad/generic search** (e.g., 'show me subscriptions'): Confirm search completed and report count. Do nothing else. + 2. **Question answerable with entity_id/title/score**: Answer directly using the current results. + 3. **Question requiring other details**: Call `fetch_entity_details` first, then answer with the detailed data. + 4. **Export request** (phrases like 'export', 'download', 'save as CSV'): Call `prepare_export` directly. + """ ) - - # Include results data in the prompt - if state.results_data and state.results_data.results: - results_section = "\n\n**Current Search Results:**\n" - for idx, result in enumerate(state.results_data.results, 1): - results_section += f"{idx}. **{result.entity_title}** (ID: {result.entity_id}, Score: {result.score:.2f})\n" else: next_step_guidance = ( "INSTRUCTION: Context is set. Now, analyze the user's request. " @@ -110,7 +109,7 @@ async def get_dynamic_instructions(ctx: RunContext[StateDeps[SearchState]]) -> s return dedent( f""" --- - ### Current State & Next Action + ## CURRENT STATE **Current Search Parameters:** ```json @@ -119,6 +118,9 @@ async def get_dynamic_instructions(ctx: RunContext[StateDeps[SearchState]]) -> s **Current Results Count:** {results_count} - **{next_step_guidance}**{results_section} + --- + ## NEXT ACTION REQUIRED + + {next_step_guidance} """ ) diff --git a/orchestrator/search/agent/tools.py b/orchestrator/search/agent/tools.py index 2835a56e1..848c8a20f 100644 --- a/orchestrator/search/agent/tools.py +++ b/orchestrator/search/agent/tools.py @@ -11,10 +11,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
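A note on `JSONPatchOp.upsert`, which the rewritten tools below lean on: RFC 6902 `replace` requires the target path to already exist, while `add` creates it (and overwrites an existing member), so the emitter has to know whether the path existed when choosing the op. The same decision as a plain function:

```python
def upsert(path: str, value: object, existed: bool) -> dict:
    # "replace" demands an existing target; "add" creates one.
    return {"op": "replace" if existed else "add", "path": path, "value": value}


assert upsert("/query_id", "abc", existed=False)["op"] == "add"
assert upsert("/query_id", "def", existed=True)["op"] == "replace"
```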
+import json from typing import Any import structlog -from ag_ui.core import EventType, StateSnapshotEvent +from ag_ui.core import EventType, StateDeltaEvent, StateSnapshotEvent from pydantic_ai import RunContext from pydantic_ai.ag_ui import StateDeps from pydantic_ai.exceptions import ModelRetry @@ -26,8 +27,10 @@ list_paths, ) from orchestrator.db import AgentRunTable, SearchQueryTable, db +from orchestrator.search.agent.json_patch import JSONPatchOp from orchestrator.search.agent.state import ExportData, SearchResultsData, SearchState from orchestrator.search.core.types import ActionType, EntityType, FilterOp +from orchestrator.search.export import fetch_export_data from orchestrator.search.filters import FilterTree from orchestrator.search.retrieval.engine import execute_search from orchestrator.search.retrieval.exceptions import FilterValidationError, PathNotFoundError @@ -106,7 +109,7 @@ async def set_search_parameters( ctx: RunContext[StateDeps[SearchState]], entity_type: EntityType, action: str | ActionType = ActionType.SELECT, -) -> StateSnapshotEvent: +) -> StateDeltaEvent: """Updates search parameters without clearing filters or results. Use this to modify the entity type or action while preserving existing filters. @@ -127,9 +130,15 @@ async def set_search_parameters( logger.debug("Search parameters updated", parameters=ctx.deps.state.parameters) - return StateSnapshotEvent( - type=EventType.STATE_SNAPSHOT, - snapshot=ctx.deps.state.model_dump(), + return StateDeltaEvent( + type=EventType.STATE_DELTA, + delta=[ + JSONPatchOp.upsert( + path="/parameters", + value=ctx.deps.state.parameters, + existed=bool(params), + ) + ], ) @@ -137,7 +146,7 @@ async def set_search_parameters( async def set_filter_tree( ctx: RunContext[StateDeps[SearchState]], filters: FilterTree | None, -) -> StateSnapshotEvent: +) -> StateDeltaEvent: """Replace current filters atomically with a full FilterTree, or clear with None. Requirements: @@ -171,15 +180,25 @@ async def set_filter_tree( raise ModelRetry(f"Filter validation failed: {str(e)}. 
Please check your filter structure and try again.") filter_data = None if filters is None else filters.model_dump(mode="json", by_alias=True) + filters_existed = "filters" in ctx.deps.state.parameters ctx.deps.state.parameters["filters"] = filter_data - return StateSnapshotEvent(type=EventType.STATE_SNAPSHOT, snapshot=ctx.deps.state.model_dump()) + return StateDeltaEvent( + type=EventType.STATE_DELTA, + delta=[ + JSONPatchOp.upsert( + path="/parameters/filters", + value=filter_data, + existed=filters_existed, + ) + ], + ) @search_toolset.tool async def run_search( ctx: RunContext[StateDeps[SearchState]], limit: int = 10, -) -> StateSnapshotEvent: +) -> StateDeltaEvent: """Execute the search with the current parameters and save to database.""" if not ctx.deps.state.parameters: raise ValueError("No search parameters set") @@ -199,12 +218,15 @@ async def run_search( params.limit = limit + changes: list[JSONPatchOp] = [] + if not ctx.deps.state.run_id: agent_run = AgentRunTable(agent_type="search") db.session.add(agent_run) db.session.commit() ctx.deps.state.run_id = agent_run.run_id logger.debug("Created new agent run", run_id=str(agent_run.run_id)) + changes.append(JSONPatchOp(op="add", path="/run_id", value=str(ctx.deps.state.run_id))) # Get query with embedding and save to DB search_response = await execute_search(params, db.session) @@ -218,8 +240,10 @@ async def run_search( ) db.session.add(search_query) db.session.commit() + query_id_existed = ctx.deps.state.query_id is not None ctx.deps.state.query_id = search_query.query_id logger.debug("Saved search query", query_id=str(search_query.query_id), query_number=query_number) + changes.append(JSONPatchOp.upsert(path="/query_id", value=str(ctx.deps.state.query_id), existed=query_id_existed)) logger.debug( "Search completed", @@ -229,6 +253,7 @@ async def run_search( # Store results data for both frontend display and agent context results_url = f"{app_settings.BASE_URL}/api/search/queries/{ctx.deps.state.query_id}" + results_data_existed = ctx.deps.state.results_data is not None ctx.deps.state.results_data = SearchResultsData( query_id=str(ctx.deps.state.query_id), results_url=results_url, @@ -236,8 +261,13 @@ async def run_search( message=f"Found {len(search_response.results)} results.", results=search_response.results, # Include actual results in state ) + changes.append( + JSONPatchOp.upsert( + path="/results_data", value=ctx.deps.state.results_data.model_dump(), existed=results_data_existed + ) + ) - return StateSnapshotEvent(type=EventType.STATE_SNAPSHOT, snapshot=ctx.deps.state.model_dump()) + return StateDeltaEvent(type=EventType.STATE_DELTA, delta=changes) @search_toolset.tool @@ -316,6 +346,48 @@ async def get_valid_operators() -> dict[str, list[FilterOp]]: return operator_map +@search_toolset.tool +async def fetch_entity_details( + ctx: RunContext[StateDeps[SearchState]], + limit: int = 10, +) -> str: + """Fetch detailed entity information to answer user questions. + + Use this tool when you need detailed information about entities from the search results + to answer the user's question. This provides the same detailed data that would be + included in an export (e.g., subscription status, product details, workflow info, etc.). + + Args: + ctx: Runtime context for agent (injected). + limit: Maximum number of entities to fetch details for (default 10). + + Returns: + JSON string containing detailed entity information. + + Raises: + ValueError: If no search results are available. 
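In practice `fetch_entity_details` is a thin bridge: it slices the top-N hits out of agent state, reuses the export fetcher, and returns JSON the model can read. Roughly, with a stub standing in for `fetch_export_data` so the sketch runs on its own:

```python
import json


def fetch_export_data_stub(entity_type: str, ids: list[str]) -> list[dict]:
    # Stand-in for orchestrator.search.export.fetch_export_data.
    return [{"entity_id": i, "status": "active"} for i in ids]


results = [{"entity_id": "a1"}, {"entity_id": "b2"}, {"entity_id": "c3"}]
ids = [r["entity_id"] for r in results[:2]]  # honor the limit argument
print(json.dumps(fetch_export_data_stub("SUBSCRIPTION", ids), indent=2))
```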
+ """ + if not ctx.deps.state.results_data or not ctx.deps.state.results_data.results: + raise ValueError("No search results available. Run a search first before fetching entity details.") + + if not ctx.deps.state.parameters: + raise ValueError("No search parameters found.") + + entity_type = EntityType(ctx.deps.state.parameters["entity_type"]) + + entity_ids = [r.entity_id for r in ctx.deps.state.results_data.results[:limit]] + + logger.debug( + "Fetching detailed entity data", + entity_type=entity_type.value, + entity_count=len(entity_ids), + ) + + detailed_data = fetch_export_data(entity_type, entity_ids) + + return json.dumps(detailed_data, indent=2) + + @search_toolset.tool async def prepare_export( ctx: RunContext[StateDeps[SearchState]], @@ -350,4 +422,9 @@ async def prepare_export( logger.debug("Export data set in state", export_data=ctx.deps.state.export_data.model_dump()) - return StateSnapshotEvent(type=EventType.STATE_SNAPSHOT, snapshot=ctx.deps.state.model_dump()) + # Should use StateDelta here? Use snapshot to workaround state persistence issue + # TODO: Fix root cause; state is empty on frontend when it should have data from run_search + return StateSnapshotEvent( + type=EventType.STATE_SNAPSHOT, + snapshot=ctx.deps.state.model_dump(), + ) diff --git a/orchestrator/search/retrieval/engine.py b/orchestrator/search/retrieval/engine.py index cf00e71e3..df4b773dd 100644 --- a/orchestrator/search/retrieval/engine.py +++ b/orchestrator/search/retrieval/engine.py @@ -24,7 +24,7 @@ from orchestrator.search.schemas.results import MatchingField, SearchResponse, SearchResult from .builder import build_candidate_query -from .pagination import PaginationParams +from .pagination import PageCursor from .query_state import SearchQueryState from .retrievers import Retriever from .utils import generate_highlight_indices @@ -122,7 +122,7 @@ async def _execute_search_internal( search_params: BaseSearchParameters, db_session: Session, limit: int, - pagination_params: PaginationParams | None = None, + cursor: PageCursor | None = None, query_embedding: list[float] | None = None, ) -> SearchResponse: """Internal function to execute search with specified parameters. @@ -131,7 +131,7 @@ async def _execute_search_internal( search_params: The search parameters specifying vector, fuzzy, or filter criteria. db_session: The active SQLAlchemy session for executing the query. limit: Maximum number of results to return. - pagination_params: Optional pagination parameters. + cursor: Optional pagination cursor. query_embedding: Optional pre-computed query embedding to use instead of generating a new one. 
Returns: @@ -143,13 +143,11 @@ async def _execute_search_internal( candidate_query = build_candidate_query(search_params) - pagination_params = pagination_params or PaginationParams() - if search_params.vector_query and not query_embedding: query_embedding = await QueryEmbedder.generate_for_text_async(search_params.vector_query) - retriever = await Retriever.route(search_params, pagination_params, query_embedding) + retriever = await Retriever.route(search_params, cursor, query_embedding) logger.debug("Using retriever", retriever_type=retriever.__class__.__name__) final_stmt = retriever.apply(candidate_query) @@ -166,13 +164,11 @@ async def _execute_search_internal( async def execute_search( search_params: BaseSearchParameters, db_session: Session, - pagination_params: PaginationParams | None = None, + cursor: PageCursor | None = None, query_embedding: list[float] | None = None, ) -> SearchResponse: """Execute a search and return ranked results.""" - return await _execute_search_internal( - search_params, db_session, search_params.limit, pagination_params, query_embedding - ) + return await _execute_search_internal(search_params, db_session, search_params.limit, cursor, query_embedding) async def execute_search_for_export( diff --git a/orchestrator/search/retrieval/pagination.py b/orchestrator/search/retrieval/pagination.py index f672b8b98..a630bf149 100644 --- a/orchestrator/search/retrieval/pagination.py +++ b/orchestrator/search/retrieval/pagination.py @@ -22,14 +22,6 @@ from orchestrator.search.schemas.results import SearchResponse -class PaginationParams(BaseModel): - """Parameters for pagination in search queries.""" - - page_after_score: float | None = None - page_after_id: str | None = None - query_id: UUID | None = None # None only for first page, always set when cursor exists - - class PageCursor(BaseModel): score: float id: str @@ -50,29 +42,9 @@ def decode(cls, cursor: str) -> "PageCursor": raise InvalidCursorError("Invalid pagination cursor") from e -async def process_pagination_cursor(cursor: str) -> PaginationParams: - """Decode pagination cursor and extract pagination parameters. - - Args: - cursor: Base64-encoded cursor - - Returns: - PaginationParams containing page position and query_id from the cursor - - Raises: - InvalidCursorError: If cursor cannot be decoded - """ - page_cursor = PageCursor.decode(cursor) - return PaginationParams( - page_after_score=page_cursor.score, - page_after_id=page_cursor.id, - query_id=page_cursor.query_id, - ) - - -def create_next_page_cursor( +def encode_next_page_cursor( search_response: SearchResponse, - pagination_params: PaginationParams, + cursor: PageCursor | None, search_params: SearchParameters, ) -> str | None: """Create next page cursor if there are more results. 
@@ -82,7 +54,7 @@ def create_next_page_cursor( Args: search_response: SearchResponse containing results and query_embedding - pagination_params: Current pagination parameters (may have query_id if not first page) + cursor: Current page cursor (None for first page, PageCursor for subsequent pages) search_params: Search parameters to save for pagination consistency Returns: @@ -95,7 +67,7 @@ def create_next_page_cursor( return None # If this is the first page, save query state to database - if not pagination_params.query_id: + if cursor is None: query_state = SearchQueryState(parameters=search_params, query_embedding=search_response.query_embedding) search_query = SearchQueryTable.from_state(state=query_state) @@ -103,7 +75,7 @@ def create_next_page_cursor( db.session.commit() query_id = search_query.query_id else: - query_id = pagination_params.query_id + query_id = cursor.query_id last_item = search_response.results[-1] cursor_data = PageCursor( diff --git a/orchestrator/search/retrieval/retrievers/base.py b/orchestrator/search/retrieval/retrievers/base.py index 02bd94c7a..e7da35596 100644 --- a/orchestrator/search/retrieval/retrievers/base.py +++ b/orchestrator/search/retrieval/retrievers/base.py @@ -20,7 +20,7 @@ from orchestrator.search.core.types import FieldType, SearchMetadata from orchestrator.search.schemas.parameters import BaseSearchParameters -from ..pagination import PaginationParams +from ..pagination import PageCursor logger = structlog.get_logger(__name__) @@ -44,7 +44,7 @@ class Retriever(ABC): async def route( cls, params: BaseSearchParameters, - pagination_params: PaginationParams, + cursor: PageCursor | None, query_embedding: list[float] | None = None, ) -> "Retriever": """Route to the appropriate retriever instance based on search parameters. 
@@ -57,7 +57,7 @@ async def route( Args: params: Search parameters including vector queries, fuzzy terms, and filters - pagination_params: Pagination parameters for cursor-based paging + cursor: Pagination cursor for cursor-based paging query_embedding: Query embedding for semantic search, or None if not available Returns: @@ -76,13 +76,13 @@ async def route( # Select retriever based on available search criteria if query_embedding is not None and fuzzy_term is not None: - return RrfHybridRetriever(query_embedding, fuzzy_term, pagination_params) + return RrfHybridRetriever(query_embedding, fuzzy_term, cursor) if query_embedding is not None: - return SemanticRetriever(query_embedding, pagination_params) + return SemanticRetriever(query_embedding, cursor) if fuzzy_term is not None: - return FuzzyRetriever(fuzzy_term, pagination_params) + return FuzzyRetriever(fuzzy_term, cursor) - return StructuredRetriever(pagination_params) + return StructuredRetriever(cursor) @abstractmethod def apply(self, candidate_query: Select) -> Select: diff --git a/orchestrator/search/retrieval/retrievers/fuzzy.py b/orchestrator/search/retrieval/retrievers/fuzzy.py index 885036cd7..4dd19fee5 100644 --- a/orchestrator/search/retrieval/retrievers/fuzzy.py +++ b/orchestrator/search/retrieval/retrievers/fuzzy.py @@ -17,17 +17,16 @@ from orchestrator.db.models import AiSearchIndex from orchestrator.search.core.types import SearchMetadata -from ..pagination import PaginationParams +from ..pagination import PageCursor from .base import Retriever class FuzzyRetriever(Retriever): """Ranks results based on the max of fuzzy text similarity scores.""" - def __init__(self, fuzzy_term: str, pagination_params: PaginationParams) -> None: + def __init__(self, fuzzy_term: str, cursor: PageCursor | None) -> None: self.fuzzy_term = fuzzy_term - self.page_after_score = pagination_params.page_after_score - self.page_after_id = pagination_params.page_after_id + self.cursor = cursor def apply(self, candidate_query: Select) -> Select: cand = candidate_query.subquery() @@ -83,13 +82,13 @@ def _apply_score_pagination( self, stmt: Select, score_column: ColumnElement, entity_id_column: ColumnElement ) -> Select: """Apply standard score + entity_id pagination.""" - if self.page_after_score is not None and self.page_after_id is not None: + if self.cursor is not None: stmt = stmt.where( or_( - score_column < self.page_after_score, + score_column < self.cursor.score, and_( - score_column == self.page_after_score, - entity_id_column > self.page_after_id, + score_column == self.cursor.score, + entity_id_column > self.cursor.id, ), ) ) diff --git a/orchestrator/search/retrieval/retrievers/hybrid.py b/orchestrator/search/retrieval/retrievers/hybrid.py index a3cc8ac3e..89f56df41 100644 --- a/orchestrator/search/retrieval/retrievers/hybrid.py +++ b/orchestrator/search/retrieval/retrievers/hybrid.py @@ -20,7 +20,7 @@ from orchestrator.db.models import AiSearchIndex from orchestrator.search.core.types import SearchMetadata -from ..pagination import PaginationParams +from ..pagination import PageCursor from .base import Retriever @@ -127,14 +127,13 @@ def __init__( self, q_vec: list[float], fuzzy_term: str, - pagination_params: PaginationParams, + cursor: PageCursor | None, k: int = 60, field_candidates_limit: int = 100, ) -> None: self.q_vec = q_vec self.fuzzy_term = fuzzy_term - self.page_after_score = pagination_params.page_after_score - self.page_after_id = pagination_params.page_after_id + self.cursor = cursor self.k = k self.field_candidates_limit = 
field_candidates_limit @@ -266,12 +265,12 @@ def _apply_fused_pagination( entity_id_column: ColumnElement, ) -> Select: """Keyset paginate by fused score + id.""" - if self.page_after_score is not None and self.page_after_id is not None: - score_param = self._quantize_score_for_pagination(self.page_after_score) + if self.cursor is not None: + score_param = self._quantize_score_for_pagination(self.cursor.score) stmt = stmt.where( or_( score_column < score_param, - and_(score_column == score_param, entity_id_column > self.page_after_id), + and_(score_column == score_param, entity_id_column > self.cursor.id), ) ) return stmt diff --git a/orchestrator/search/retrieval/retrievers/semantic.py b/orchestrator/search/retrieval/retrievers/semantic.py index 36f5efc13..89702fa1d 100644 --- a/orchestrator/search/retrieval/retrievers/semantic.py +++ b/orchestrator/search/retrieval/retrievers/semantic.py @@ -17,17 +17,16 @@ from orchestrator.db.models import AiSearchIndex from orchestrator.search.core.types import SearchMetadata -from ..pagination import PaginationParams +from ..pagination import PageCursor from .base import Retriever class SemanticRetriever(Retriever): """Ranks results based on the minimum semantic vector distance.""" - def __init__(self, vector_query: list[float], pagination_params: PaginationParams) -> None: + def __init__(self, vector_query: list[float], cursor: PageCursor | None) -> None: self.vector_query = vector_query - self.page_after_score = pagination_params.page_after_score - self.page_after_id = pagination_params.page_after_id + self.cursor = cursor def apply(self, candidate_query: Select) -> Select: cand = candidate_query.subquery() @@ -85,12 +84,12 @@ def _apply_semantic_pagination( self, stmt: Select, score_column: ColumnElement, entity_id_column: ColumnElement ) -> Select: """Apply semantic score pagination with precise Decimal handling.""" - if self.page_after_score is not None and self.page_after_id is not None: - score_param = self._quantize_score_for_pagination(self.page_after_score) + if self.cursor is not None: + score_param = self._quantize_score_for_pagination(self.cursor.score) stmt = stmt.where( or_( score_column < score_param, - and_(score_column == score_param, entity_id_column > self.page_after_id), + and_(score_column == score_param, entity_id_column > self.cursor.id), ) ) return stmt diff --git a/orchestrator/search/retrieval/retrievers/structured.py b/orchestrator/search/retrieval/retrievers/structured.py index 29d546eff..0ce9f0ea6 100644 --- a/orchestrator/search/retrieval/retrievers/structured.py +++ b/orchestrator/search/retrieval/retrievers/structured.py @@ -15,22 +15,22 @@ from orchestrator.search.core.types import SearchMetadata -from ..pagination import PaginationParams +from ..pagination import PageCursor from .base import Retriever class StructuredRetriever(Retriever): """Applies a dummy score for purely structured searches with no text query.""" - def __init__(self, pagination_params: PaginationParams) -> None: - self.page_after_id = pagination_params.page_after_id + def __init__(self, cursor: PageCursor | None) -> None: + self.cursor = cursor def apply(self, candidate_query: Select) -> Select: cand = candidate_query.subquery() stmt = select(cand.c.entity_id, cand.c.entity_title, literal(1.0).label("score")).select_from(cand) - if self.page_after_id: - stmt = stmt.where(cand.c.entity_id > self.page_after_id) + if self.cursor is not None: + stmt = stmt.where(cand.c.entity_id > self.cursor.id) return stmt.order_by(cand.c.entity_id.asc()) From 
7e70187b8eeb1190977ce01115d0ac6a33fef2ed Mon Sep 17 00:00:00 2001 From: Tim Frohlich Date: Wed, 22 Oct 2025 18:19:54 +0200 Subject: [PATCH 14/16] Update tests, migration, test snapshots and include search query table in cli resize script --- orchestrator/cli/search/resize_embedding.py | 50 ++++++++------- ...add_agent_runs_and_agent_queries_tables.py | 59 ----------------- orchestrator/search/llm_migration.py | 37 +++++++++++ .../retrieval/retrievers/sql_snapshots.json | 16 ++--- .../retrieval/retrievers/test_retrievers.py | 64 ++++++++++--------- 5 files changed, 106 insertions(+), 120 deletions(-) delete mode 100644 orchestrator/migrations/versions/schema/2025-10-09_459f352f5aa6_add_agent_runs_and_agent_queries_tables.py diff --git a/orchestrator/cli/search/resize_embedding.py b/orchestrator/cli/search/resize_embedding.py index 8d944e3cf..7fa966cb3 100644 --- a/orchestrator/cli/search/resize_embedding.py +++ b/orchestrator/cli/search/resize_embedding.py @@ -4,7 +4,7 @@ from sqlalchemy.exc import SQLAlchemyError from orchestrator.db import db -from orchestrator.db.models import AiSearchIndex +from orchestrator.db.models import AiSearchIndex, SearchQueryTable from orchestrator.llm_settings import llm_settings logger = structlog.get_logger(__name__) @@ -40,17 +40,20 @@ def get_current_embedding_dimension() -> int | None: return None -def drop_all_embeddings() -> int: - """Drop all records from the ai_search_index table. +def drop_all_embeddings() -> tuple[int, int]: + """Drop all records from ai_search_index and search_queries tables. Returns: - Number of records deleted + Tuple of (ai_search_index records deleted, search_queries records deleted) """ try: - result = db.session.query(AiSearchIndex).delete() + index_deleted = db.session.query(AiSearchIndex).delete() + query_deleted = db.session.query(SearchQueryTable).delete() db.session.commit() - logger.info(f"Deleted {result} records from ai_search_index") - return result + logger.info( + f"Deleted {index_deleted} records from ai_search_index and {query_deleted} records from search_queries" + ) + return index_deleted, query_deleted except SQLAlchemyError as e: db.session.rollback() @@ -59,34 +62,34 @@ def drop_all_embeddings() -> int: def alter_embedding_column_dimension(new_dimension: int) -> None: - """Alter the embedding column to use the new dimension size. + """Alter the embedding columns in both ai_search_index and search_queries tables. 
Args: new_dimension: New vector dimension size """ try: - drop_query = text("ALTER TABLE ai_search_index DROP COLUMN IF EXISTS embedding") - db.session.execute(drop_query) + db.session.execute(text("ALTER TABLE ai_search_index DROP COLUMN IF EXISTS embedding")) + db.session.execute(text(f"ALTER TABLE ai_search_index ADD COLUMN embedding vector({new_dimension})")) - add_query = text(f"ALTER TABLE ai_search_index ADD COLUMN embedding vector({new_dimension})") - db.session.execute(add_query) + db.session.execute(text("ALTER TABLE search_queries DROP COLUMN IF EXISTS query_embedding")) + db.session.execute(text(f"ALTER TABLE search_queries ADD COLUMN query_embedding vector({new_dimension})")) db.session.commit() - logger.info(f"Altered embedding column to dimension {new_dimension}") + logger.info(f"Altered embedding columns to dimension {new_dimension} in ai_search_index and search_queries") except SQLAlchemyError as e: db.session.rollback() - logger.error("Failed to alter embedding column dimension", error=str(e)) + logger.error("Failed to alter embedding column dimensions", error=str(e)) raise @app.command("resize") def resize_embeddings_command() -> None: - """Resize vector dimensions of the ai_search_index embedding column. + """Resize vector dimensions of embedding columns in ai_search_index and search_queries tables. Compares the current embedding dimension in the database with the configured - dimension in llm_settings. If they differ, drops all records and alters the - column to match the new dimension. + dimension in llm_settings. If they differ, drops all records and alters both + embedding columns to match the new dimension. """ new_dimension = llm_settings.EMBEDDING_DIMENSION @@ -107,22 +110,25 @@ def resize_embeddings_command() -> None: logger.info("Dimension mismatch detected", current_dimension=current_dimension, new_dimension=new_dimension) - if not typer.confirm("This will DELETE ALL RECORDS from ai_search_index and alter the embedding column. Continue?"): + if not typer.confirm( + "This will DELETE ALL RECORDS from ai_search_index and search_queries tables and alter embedding columns. Continue?" + ): logger.info("Operation cancelled by user") return try: # Drop all records first. logger.info("Dropping all embedding records...") - deleted_count = drop_all_embeddings() + index_deleted, query_deleted = drop_all_embeddings() - # Then alter column dimension. - logger.info(f"Altering embedding column to dimension {new_dimension}...") + # Then alter column dimensions. + logger.info(f"Altering embedding columns to dimension {new_dimension}...") alter_embedding_column_dimension(new_dimension) logger.info( "Embedding dimension resize completed successfully", - records_deleted=deleted_count, + index_records_deleted=index_deleted, + query_records_deleted=query_deleted, new_dimension=new_dimension, ) diff --git a/orchestrator/migrations/versions/schema/2025-10-09_459f352f5aa6_add_agent_runs_and_agent_queries_tables.py b/orchestrator/migrations/versions/schema/2025-10-09_459f352f5aa6_add_agent_runs_and_agent_queries_tables.py deleted file mode 100644 index 000cee1c3..000000000 --- a/orchestrator/migrations/versions/schema/2025-10-09_459f352f5aa6_add_agent_runs_and_agent_queries_tables.py +++ /dev/null @@ -1,59 +0,0 @@ -"""Add agent_runs and search_queries tables. 
- -Revision ID: 459f352f5aa6 -Revises: 850dccac3b02 -Create Date: 2025-10-09 00:52:16.297143 - -""" - -import sqlalchemy as sa -from alembic import op -from pgvector.sqlalchemy import Vector -from sqlalchemy.dialects import postgresql -from sqlalchemy_utils import UUIDType - -# revision identifiers, used by Alembic. -revision = "459f352f5aa6" -down_revision = "850dccac3b02" -branch_labels = None -depends_on = None - - -def upgrade() -> None: - op.create_table( - "agent_runs", - sa.Column("run_id", UUIDType(), server_default=sa.text("uuid_generate_v4()"), nullable=False), - sa.Column("agent_type", sa.String(length=50), nullable=False), - sa.Column( - "created_at", sa.TIMESTAMP(timezone=True), server_default=sa.text("current_timestamp"), nullable=False - ), - sa.PrimaryKeyConstraint("run_id"), - ) - op.create_index("ix_agent_runs_created_at", "agent_runs", ["created_at"]) - - op.create_table( - "search_queries", - sa.Column("query_id", UUIDType(), server_default=sa.text("uuid_generate_v4()"), nullable=False), - sa.Column("run_id", UUIDType(), nullable=True), - sa.Column("query_number", sa.Integer(), nullable=False), - sa.Column("parameters", postgresql.JSONB(astext_type=sa.Text()), nullable=False), - sa.Column("query_embedding", Vector(1536), nullable=True), - sa.Column( - "executed_at", sa.TIMESTAMP(timezone=True), server_default=sa.text("current_timestamp"), nullable=False - ), - sa.ForeignKeyConstraint(["run_id"], ["agent_runs.run_id"], ondelete="CASCADE"), - sa.PrimaryKeyConstraint("query_id"), - ) - op.create_index("ix_search_queries_run_id", "search_queries", ["run_id"]) - op.create_index("ix_search_queries_executed_at", "search_queries", ["executed_at"]) - op.create_index("ix_search_queries_query_id", "search_queries", ["query_id"]) - - -def downgrade() -> None: - op.drop_index("ix_search_queries_query_id", table_name="search_queries") - op.drop_index("ix_search_queries_executed_at", table_name="search_queries") - op.drop_index("ix_search_queries_run_id", table_name="search_queries") - op.drop_table("search_queries") - - op.drop_index("ix_agent_runs_created_at", table_name="agent_runs") - op.drop_table("agent_runs") diff --git a/orchestrator/search/llm_migration.py b/orchestrator/search/llm_migration.py index 1cad10c1e..19bd06cee 100644 --- a/orchestrator/search/llm_migration.py +++ b/orchestrator/search/llm_migration.py @@ -37,6 +37,7 @@ def run_migration(connection: Connection) -> None: if llm_settings.LLM_FORCE_EXTENTION_MIGRATION or res.rowcount == 0: # Create PostgreSQL extensions logger.info("Attempting to run the extention creation;") + connection.execute(text('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')) connection.execute(text("CREATE EXTENSION IF NOT EXISTS ltree;")) connection.execute(text("CREATE EXTENSION IF NOT EXISTS unaccent;")) connection.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm;")) @@ -114,6 +115,42 @@ def run_migration(connection: Connection) -> None: ) ) + # Create agent_runs table + connection.execute( + text( + """ + CREATE TABLE IF NOT EXISTS agent_runs ( + run_id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), + agent_type VARCHAR(50) NOT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL + ); + """ + ) + ) + connection.execute(text("CREATE INDEX IF NOT EXISTS ix_agent_runs_created_at ON agent_runs (created_at);")) + + # Create search_queries table + connection.execute( + text( + f""" + CREATE TABLE IF NOT EXISTS search_queries ( + query_id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), + run_id UUID, + query_number INTEGER 
NOT NULL, + parameters JSONB NOT NULL, + query_embedding VECTOR({TARGET_DIM}), + executed_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL, + CONSTRAINT fk_search_queries_run_id FOREIGN KEY (run_id) REFERENCES agent_runs(run_id) ON DELETE CASCADE + ); + """ + ) + ) + connection.execute(text("CREATE INDEX IF NOT EXISTS ix_search_queries_run_id ON search_queries (run_id);")) + connection.execute( + text("CREATE INDEX IF NOT EXISTS ix_search_queries_executed_at ON search_queries (executed_at);") + ) + connection.execute(text("CREATE INDEX IF NOT EXISTS ix_search_queries_query_id ON search_queries (query_id);")) + connection.commit() logger.info("LLM migration completed successfully") diff --git a/test/unit_tests/search/retrieval/retrievers/sql_snapshots.json b/test/unit_tests/search/retrieval/retrievers/sql_snapshots.json index c17692086..cd75ee092 100644 --- a/test/unit_tests/search/retrieval/retrievers/sql_snapshots.json +++ b/test/unit_tests/search/retrieval/retrievers/sql_snapshots.json @@ -1,10 +1,10 @@ { - "FuzzyRetriever.test_basic_query_structure": "SELECT ranked_fuzzy.entity_id, ranked_fuzzy.score, ranked_fuzzy.highlight_text, ranked_fuzzy.highlight_path \nFROM (SELECT DISTINCT ON (ai_search_index.entity_id) ai_search_index.entity_id AS entity_id, CAST(round(CAST(max(word_similarity(%(word_similarity_1)s, ai_search_index.value)) OVER (PARTITION BY ai_search_index.entity_id) AS NUMERIC(38, 12)), %(round_1)s) AS NUMERIC(38, 12)) AS score, first_value(ai_search_index.value) OVER (PARTITION BY ai_search_index.entity_id ORDER BY word_similarity(%(word_similarity_1)s, ai_search_index.value) DESC, ai_search_index.path ASC) AS highlight_text, first_value(ai_search_index.path) OVER (PARTITION BY ai_search_index.entity_id ORDER BY word_similarity(%(word_similarity_1)s, ai_search_index.value) DESC, ai_search_index.path ASC) AS highlight_path \nFROM ai_search_index JOIN (SELECT DISTINCT ai_search_index.entity_id AS entity_id \nFROM ai_search_index) AS anon_1 ON anon_1.entity_id = ai_search_index.entity_id \nWHERE ai_search_index.value_type IN (__[POSTCOMPILE_value_type_1]) AND (%(param_1)s <%% ai_search_index.value)) AS ranked_fuzzy ORDER BY ranked_fuzzy.score DESC NULLS LAST, ranked_fuzzy.entity_id ASC", - "FuzzyRetriever.test_pagination_structure": "SELECT ranked_fuzzy.entity_id, ranked_fuzzy.score, ranked_fuzzy.highlight_text, ranked_fuzzy.highlight_path \nFROM (SELECT DISTINCT ON (ai_search_index.entity_id) ai_search_index.entity_id AS entity_id, CAST(round(CAST(max(word_similarity(%(word_similarity_1)s, ai_search_index.value)) OVER (PARTITION BY ai_search_index.entity_id) AS NUMERIC(38, 12)), %(round_1)s) AS NUMERIC(38, 12)) AS score, first_value(ai_search_index.value) OVER (PARTITION BY ai_search_index.entity_id ORDER BY word_similarity(%(word_similarity_1)s, ai_search_index.value) DESC, ai_search_index.path ASC) AS highlight_text, first_value(ai_search_index.path) OVER (PARTITION BY ai_search_index.entity_id ORDER BY word_similarity(%(word_similarity_1)s, ai_search_index.value) DESC, ai_search_index.path ASC) AS highlight_path \nFROM ai_search_index JOIN (SELECT DISTINCT ai_search_index.entity_id AS entity_id \nFROM ai_search_index) AS anon_1 ON anon_1.entity_id = ai_search_index.entity_id \nWHERE ai_search_index.value_type IN (__[POSTCOMPILE_value_type_1]) AND (%(param_1)s <%% ai_search_index.value)) AS ranked_fuzzy \nWHERE ranked_fuzzy.score < %(score_1)s OR ranked_fuzzy.score = %(score_2)s AND ranked_fuzzy.entity_id > %(entity_id_1)s::UUID ORDER BY ranked_fuzzy.score DESC 
NULLS LAST, ranked_fuzzy.entity_id ASC", - "RrfHybridRetriever.test_basic_query_structure": "WITH field_candidates AS \n(SELECT ai_search_index.entity_id AS entity_id, ai_search_index.path AS path, ai_search_index.value AS value, coalesce(CASE WHEN (ai_search_index.embedding IS NULL) THEN NULL ELSE ai_search_index.embedding <-> %(q_vec)s END, %(param_9)s) AS semantic_distance, word_similarity(%(word_similarity_1)s, ai_search_index.value) AS fuzzy_score \nFROM ai_search_index JOIN (SELECT DISTINCT ai_search_index.entity_id AS entity_id \nFROM ai_search_index) AS anon_1 ON anon_1.entity_id = ai_search_index.entity_id \nWHERE ai_search_index.value_type IN (__[POSTCOMPILE_value_type_1]) AND (%(param_10)s <%% ai_search_index.value) ORDER BY word_similarity(%(word_similarity_1)s, ai_search_index.value) DESC NULLS LAST, CASE WHEN (ai_search_index.embedding IS NULL) THEN NULL ELSE ai_search_index.embedding <-> %(q_vec)s END ASC NULLS LAST, ai_search_index.entity_id ASC \n LIMIT %(param_11)s), \nentity_scores AS \n(SELECT field_candidates.entity_id AS entity_id, avg(field_candidates.semantic_distance) AS avg_semantic_distance, avg(field_candidates.fuzzy_score) AS avg_fuzzy_score \nFROM field_candidates GROUP BY field_candidates.entity_id), \nentity_highlights AS \n(SELECT DISTINCT ON (field_candidates.entity_id) field_candidates.entity_id AS entity_id, first_value(field_candidates.value) OVER (PARTITION BY field_candidates.entity_id ORDER BY field_candidates.fuzzy_score DESC, field_candidates.path ASC) AS highlight_text, first_value(field_candidates.path) OVER (PARTITION BY field_candidates.entity_id ORDER BY field_candidates.fuzzy_score DESC, field_candidates.path ASC) AS highlight_path \nFROM field_candidates), \nranked_results AS \n(SELECT entity_scores.entity_id AS entity_id, entity_scores.avg_semantic_distance AS avg_semantic_distance, entity_scores.avg_fuzzy_score AS avg_fuzzy_score, entity_highlights.highlight_text AS highlight_text, entity_highlights.highlight_path AS highlight_path, dense_rank() OVER (ORDER BY entity_scores.avg_semantic_distance ASC NULLS LAST, entity_scores.entity_id ASC) AS sem_rank, dense_rank() OVER (ORDER BY entity_scores.avg_fuzzy_score DESC NULLS LAST, entity_scores.entity_id ASC) AS fuzzy_rank \nFROM entity_scores JOIN entity_highlights ON entity_scores.entity_id = entity_highlights.entity_id)\n SELECT ranked_results.entity_id, CAST(round(CAST((CAST(%(param_1)s / CAST((%(sem_rank_1)s + ranked_results.sem_rank) AS NUMERIC) + %(param_2)s / CAST((%(fuzzy_rank_1)s + ranked_results.fuzzy_rank) AS NUMERIC) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) * %(param_6)s) * CAST(CASE WHEN (ranked_results.avg_fuzzy_score >= %(avg_fuzzy_score_1)s) THEN %(param_7)s ELSE %(param_8)s END AS NUMERIC(38, 12))) / CAST((%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) * %(param_6)s + %(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) AS NUMERIC(38, 12)) AS NUMERIC(38, 12)), %(round_1)s) AS NUMERIC(38, 12)) AS score, ranked_results.highlight_text, ranked_results.highlight_path, CASE WHEN (ranked_results.avg_fuzzy_score >= %(avg_fuzzy_score_1)s) THEN %(param_7)s ELSE %(param_8)s END AS perfect_match \nFROM ranked_results ORDER BY score DESC NULLS LAST, ranked_results.entity_id ASC", - "RrfHybridRetriever.test_pagination_structure": "WITH field_candidates AS \n(SELECT 
ai_search_index.entity_id AS entity_id, ai_search_index.path AS path, ai_search_index.value AS value, coalesce(CASE WHEN (ai_search_index.embedding IS NULL) THEN NULL ELSE ai_search_index.embedding <-> %(q_vec)s END, %(param_9)s) AS semantic_distance, word_similarity(%(word_similarity_1)s, ai_search_index.value) AS fuzzy_score \nFROM ai_search_index JOIN (SELECT DISTINCT ai_search_index.entity_id AS entity_id \nFROM ai_search_index) AS anon_1 ON anon_1.entity_id = ai_search_index.entity_id \nWHERE ai_search_index.value_type IN (__[POSTCOMPILE_value_type_1]) AND (%(param_10)s <%% ai_search_index.value) ORDER BY word_similarity(%(word_similarity_1)s, ai_search_index.value) DESC NULLS LAST, CASE WHEN (ai_search_index.embedding IS NULL) THEN NULL ELSE ai_search_index.embedding <-> %(q_vec)s END ASC NULLS LAST, ai_search_index.entity_id ASC \n LIMIT %(param_11)s), \nentity_scores AS \n(SELECT field_candidates.entity_id AS entity_id, avg(field_candidates.semantic_distance) AS avg_semantic_distance, avg(field_candidates.fuzzy_score) AS avg_fuzzy_score \nFROM field_candidates GROUP BY field_candidates.entity_id), \nentity_highlights AS \n(SELECT DISTINCT ON (field_candidates.entity_id) field_candidates.entity_id AS entity_id, first_value(field_candidates.value) OVER (PARTITION BY field_candidates.entity_id ORDER BY field_candidates.fuzzy_score DESC, field_candidates.path ASC) AS highlight_text, first_value(field_candidates.path) OVER (PARTITION BY field_candidates.entity_id ORDER BY field_candidates.fuzzy_score DESC, field_candidates.path ASC) AS highlight_path \nFROM field_candidates), \nranked_results AS \n(SELECT entity_scores.entity_id AS entity_id, entity_scores.avg_semantic_distance AS avg_semantic_distance, entity_scores.avg_fuzzy_score AS avg_fuzzy_score, entity_highlights.highlight_text AS highlight_text, entity_highlights.highlight_path AS highlight_path, dense_rank() OVER (ORDER BY entity_scores.avg_semantic_distance ASC NULLS LAST, entity_scores.entity_id ASC) AS sem_rank, dense_rank() OVER (ORDER BY entity_scores.avg_fuzzy_score DESC NULLS LAST, entity_scores.entity_id ASC) AS fuzzy_rank \nFROM entity_scores JOIN entity_highlights ON entity_scores.entity_id = entity_highlights.entity_id)\n SELECT ranked_results.entity_id, CAST(round(CAST((CAST(%(param_1)s / CAST((%(sem_rank_1)s + ranked_results.sem_rank) AS NUMERIC) + %(param_2)s / CAST((%(fuzzy_rank_1)s + ranked_results.fuzzy_rank) AS NUMERIC) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) * %(param_6)s) * CAST(CASE WHEN (ranked_results.avg_fuzzy_score >= %(avg_fuzzy_score_1)s) THEN %(param_7)s ELSE %(param_8)s END AS NUMERIC(38, 12))) / CAST((%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) * %(param_6)s + %(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) AS NUMERIC(38, 12)) AS NUMERIC(38, 12)), %(round_1)s) AS NUMERIC(38, 12)) AS score, ranked_results.highlight_text, ranked_results.highlight_path, CASE WHEN (ranked_results.avg_fuzzy_score >= %(avg_fuzzy_score_1)s) THEN %(param_7)s ELSE %(param_8)s END AS perfect_match \nFROM ranked_results \nWHERE CAST(round(CAST((CAST(%(param_1)s / CAST((%(sem_rank_1)s + ranked_results.sem_rank) AS NUMERIC) + %(param_2)s / CAST((%(fuzzy_rank_1)s + ranked_results.fuzzy_rank) AS NUMERIC) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12)) + 
(%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) * %(param_6)s) * CAST(CASE WHEN (ranked_results.avg_fuzzy_score >= %(avg_fuzzy_score_1)s) THEN %(param_7)s ELSE %(param_8)s END AS NUMERIC(38, 12))) / CAST((%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) * %(param_6)s + %(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) AS NUMERIC(38, 12)) AS NUMERIC(38, 12)), %(round_1)s) AS NUMERIC(38, 12)) < %(param_12)s OR CAST(round(CAST((CAST(%(param_1)s / CAST((%(sem_rank_1)s + ranked_results.sem_rank) AS NUMERIC) + %(param_2)s / CAST((%(fuzzy_rank_1)s + ranked_results.fuzzy_rank) AS NUMERIC) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) * %(param_6)s) * CAST(CASE WHEN (ranked_results.avg_fuzzy_score >= %(avg_fuzzy_score_1)s) THEN %(param_7)s ELSE %(param_8)s END AS NUMERIC(38, 12))) / CAST((%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) * %(param_6)s + %(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) AS NUMERIC(38, 12)) AS NUMERIC(38, 12)), %(round_1)s) AS NUMERIC(38, 12)) = %(param_12)s AND ranked_results.entity_id > %(entity_id_1)s::UUID ORDER BY score DESC NULLS LAST, ranked_results.entity_id ASC", - "SemanticRetriever.test_basic_query_structure": "SELECT ranked_semantic.entity_id, ranked_semantic.score, ranked_semantic.highlight_text, ranked_semantic.highlight_path \nFROM (SELECT DISTINCT ON (ai_search_index.entity_id) ai_search_index.entity_id AS entity_id, CAST(round(CAST(%(param_1)s / CAST((%(param_2)s + CAST(min(ai_search_index.embedding <-> %(embedding_1)s) OVER (PARTITION BY ai_search_index.entity_id) AS NUMERIC(38, 12))) AS NUMERIC(38, 12)) AS NUMERIC(38, 12)), %(round_1)s) AS NUMERIC(38, 12)) AS score, first_value(ai_search_index.value) OVER (PARTITION BY ai_search_index.entity_id ORDER BY (ai_search_index.embedding <-> %(embedding_1)s) ASC, ai_search_index.path ASC) AS highlight_text, first_value(ai_search_index.path) OVER (PARTITION BY ai_search_index.entity_id ORDER BY (ai_search_index.embedding <-> %(embedding_1)s) ASC, ai_search_index.path ASC) AS highlight_path \nFROM ai_search_index JOIN (SELECT DISTINCT ai_search_index.entity_id AS entity_id \nFROM ai_search_index) AS anon_1 ON anon_1.entity_id = ai_search_index.entity_id \nWHERE ai_search_index.embedding IS NOT NULL) AS ranked_semantic ORDER BY ranked_semantic.score DESC NULLS LAST, ranked_semantic.entity_id ASC", - "SemanticRetriever.test_pagination_structure": "SELECT ranked_semantic.entity_id, ranked_semantic.score, ranked_semantic.highlight_text, ranked_semantic.highlight_path \nFROM (SELECT DISTINCT ON (ai_search_index.entity_id) ai_search_index.entity_id AS entity_id, CAST(round(CAST(%(param_1)s / CAST((%(param_2)s + CAST(min(ai_search_index.embedding <-> %(embedding_1)s) OVER (PARTITION BY ai_search_index.entity_id) AS NUMERIC(38, 12))) AS NUMERIC(38, 12)) AS NUMERIC(38, 12)), %(round_1)s) AS NUMERIC(38, 12)) AS score, first_value(ai_search_index.value) OVER (PARTITION BY ai_search_index.entity_id ORDER BY (ai_search_index.embedding <-> %(embedding_1)s) ASC, ai_search_index.path ASC) AS highlight_text, first_value(ai_search_index.path) OVER (PARTITION BY ai_search_index.entity_id ORDER BY (ai_search_index.embedding <-> %(embedding_1)s) ASC, ai_search_index.path ASC) AS 
highlight_path \nFROM ai_search_index JOIN (SELECT DISTINCT ai_search_index.entity_id AS entity_id \nFROM ai_search_index) AS anon_1 ON anon_1.entity_id = ai_search_index.entity_id \nWHERE ai_search_index.embedding IS NOT NULL) AS ranked_semantic \nWHERE ranked_semantic.score < %(param_3)s OR ranked_semantic.score = %(param_3)s AND ranked_semantic.entity_id > %(entity_id_1)s::UUID ORDER BY ranked_semantic.score DESC NULLS LAST, ranked_semantic.entity_id ASC", - "StructuredRetriever.test_basic_query_structure": "SELECT anon_1.entity_id, %(param_1)s AS score \nFROM (SELECT DISTINCT ai_search_index.entity_id AS entity_id \nFROM ai_search_index) AS anon_1 ORDER BY anon_1.entity_id ASC", - "StructuredRetriever.test_pagination_structure": "SELECT anon_1.entity_id, %(param_1)s AS score \nFROM (SELECT DISTINCT ai_search_index.entity_id AS entity_id \nFROM ai_search_index) AS anon_1 \nWHERE anon_1.entity_id > %(entity_id_1)s::UUID ORDER BY anon_1.entity_id ASC" + "FuzzyRetriever.test_basic_query_structure": "SELECT ranked_fuzzy.entity_id, ranked_fuzzy.entity_title, ranked_fuzzy.score, ranked_fuzzy.highlight_text, ranked_fuzzy.highlight_path \nFROM (SELECT DISTINCT ON (ai_search_index.entity_id, ai_search_index.entity_title) ai_search_index.entity_id AS entity_id, ai_search_index.entity_title AS entity_title, CAST(round(CAST(max(word_similarity(%(word_similarity_1)s, ai_search_index.value)) OVER (PARTITION BY ai_search_index.entity_id) AS NUMERIC(38, 12)), %(round_1)s) AS NUMERIC(38, 12)) AS score, first_value(ai_search_index.value) OVER (PARTITION BY ai_search_index.entity_id ORDER BY word_similarity(%(word_similarity_1)s, ai_search_index.value) DESC, ai_search_index.path ASC) AS highlight_text, first_value(ai_search_index.path) OVER (PARTITION BY ai_search_index.entity_id ORDER BY word_similarity(%(word_similarity_1)s, ai_search_index.value) DESC, ai_search_index.path ASC) AS highlight_path \nFROM ai_search_index JOIN (SELECT DISTINCT ai_search_index.entity_id AS entity_id, ai_search_index.entity_title AS entity_title \nFROM ai_search_index) AS anon_1 ON anon_1.entity_id = ai_search_index.entity_id \nWHERE ai_search_index.value_type IN (__[POSTCOMPILE_value_type_1]) AND (%(param_1)s <%% ai_search_index.value)) AS ranked_fuzzy ORDER BY ranked_fuzzy.score DESC NULLS LAST, ranked_fuzzy.entity_id ASC", + "FuzzyRetriever.test_pagination_structure": "SELECT ranked_fuzzy.entity_id, ranked_fuzzy.entity_title, ranked_fuzzy.score, ranked_fuzzy.highlight_text, ranked_fuzzy.highlight_path \nFROM (SELECT DISTINCT ON (ai_search_index.entity_id, ai_search_index.entity_title) ai_search_index.entity_id AS entity_id, ai_search_index.entity_title AS entity_title, CAST(round(CAST(max(word_similarity(%(word_similarity_1)s, ai_search_index.value)) OVER (PARTITION BY ai_search_index.entity_id) AS NUMERIC(38, 12)), %(round_1)s) AS NUMERIC(38, 12)) AS score, first_value(ai_search_index.value) OVER (PARTITION BY ai_search_index.entity_id ORDER BY word_similarity(%(word_similarity_1)s, ai_search_index.value) DESC, ai_search_index.path ASC) AS highlight_text, first_value(ai_search_index.path) OVER (PARTITION BY ai_search_index.entity_id ORDER BY word_similarity(%(word_similarity_1)s, ai_search_index.value) DESC, ai_search_index.path ASC) AS highlight_path \nFROM ai_search_index JOIN (SELECT DISTINCT ai_search_index.entity_id AS entity_id, ai_search_index.entity_title AS entity_title \nFROM ai_search_index) AS anon_1 ON anon_1.entity_id = ai_search_index.entity_id \nWHERE ai_search_index.value_type IN 
(__[POSTCOMPILE_value_type_1]) AND (%(param_1)s <%% ai_search_index.value)) AS ranked_fuzzy \nWHERE ranked_fuzzy.score < %(score_1)s OR ranked_fuzzy.score = %(score_2)s AND ranked_fuzzy.entity_id > %(entity_id_1)s::UUID ORDER BY ranked_fuzzy.score DESC NULLS LAST, ranked_fuzzy.entity_id ASC", + "RrfHybridRetriever.test_basic_query_structure": "WITH field_candidates AS \n(SELECT ai_search_index.entity_id AS entity_id, ai_search_index.entity_title AS entity_title, ai_search_index.path AS path, ai_search_index.value AS value, coalesce(CASE WHEN (ai_search_index.embedding IS NULL) THEN NULL ELSE ai_search_index.embedding <-> %(q_vec)s END, %(param_9)s) AS semantic_distance, word_similarity(%(word_similarity_1)s, ai_search_index.value) AS fuzzy_score \nFROM ai_search_index JOIN (SELECT DISTINCT ai_search_index.entity_id AS entity_id, ai_search_index.entity_title AS entity_title \nFROM ai_search_index) AS anon_1 ON anon_1.entity_id = ai_search_index.entity_id \nWHERE ai_search_index.value_type IN (__[POSTCOMPILE_value_type_1]) AND (%(param_10)s <%% ai_search_index.value) ORDER BY word_similarity(%(word_similarity_1)s, ai_search_index.value) DESC NULLS LAST, CASE WHEN (ai_search_index.embedding IS NULL) THEN NULL ELSE ai_search_index.embedding <-> %(q_vec)s END ASC NULLS LAST, ai_search_index.entity_id ASC \n LIMIT %(param_11)s), \nentity_scores AS \n(SELECT field_candidates.entity_id AS entity_id, field_candidates.entity_title AS entity_title, avg(field_candidates.semantic_distance) AS avg_semantic_distance, avg(field_candidates.fuzzy_score) AS avg_fuzzy_score \nFROM field_candidates GROUP BY field_candidates.entity_id, field_candidates.entity_title), \nentity_highlights AS \n(SELECT DISTINCT ON (field_candidates.entity_id) field_candidates.entity_id AS entity_id, first_value(field_candidates.value) OVER (PARTITION BY field_candidates.entity_id ORDER BY field_candidates.fuzzy_score DESC, field_candidates.path ASC) AS highlight_text, first_value(field_candidates.path) OVER (PARTITION BY field_candidates.entity_id ORDER BY field_candidates.fuzzy_score DESC, field_candidates.path ASC) AS highlight_path \nFROM field_candidates), \nranked_results AS \n(SELECT entity_scores.entity_id AS entity_id, entity_scores.entity_title AS entity_title, entity_scores.avg_semantic_distance AS avg_semantic_distance, entity_scores.avg_fuzzy_score AS avg_fuzzy_score, entity_highlights.highlight_text AS highlight_text, entity_highlights.highlight_path AS highlight_path, dense_rank() OVER (ORDER BY entity_scores.avg_semantic_distance ASC NULLS LAST, entity_scores.entity_id ASC) AS sem_rank, dense_rank() OVER (ORDER BY entity_scores.avg_fuzzy_score DESC NULLS LAST, entity_scores.entity_id ASC) AS fuzzy_rank \nFROM entity_scores JOIN entity_highlights ON entity_scores.entity_id = entity_highlights.entity_id)\n SELECT ranked_results.entity_id, ranked_results.entity_title, CAST(round(CAST((CAST(%(param_1)s / CAST((%(sem_rank_1)s + ranked_results.sem_rank) AS NUMERIC) + %(param_2)s / CAST((%(fuzzy_rank_1)s + ranked_results.fuzzy_rank) AS NUMERIC) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) * %(param_6)s) * CAST(CASE WHEN (ranked_results.avg_fuzzy_score >= %(avg_fuzzy_score_1)s) THEN %(param_7)s ELSE %(param_8)s END AS NUMERIC(38, 12))) / CAST((%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) * %(param_6)s + %(param_3)s / 
CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) AS NUMERIC(38, 12)) AS NUMERIC(38, 12)), %(round_1)s) AS NUMERIC(38, 12)) AS score, ranked_results.highlight_text, ranked_results.highlight_path, CASE WHEN (ranked_results.avg_fuzzy_score >= %(avg_fuzzy_score_1)s) THEN %(param_7)s ELSE %(param_8)s END AS perfect_match \nFROM ranked_results ORDER BY score DESC NULLS LAST, ranked_results.entity_id ASC", + "RrfHybridRetriever.test_pagination_structure": "WITH field_candidates AS \n(SELECT ai_search_index.entity_id AS entity_id, ai_search_index.entity_title AS entity_title, ai_search_index.path AS path, ai_search_index.value AS value, coalesce(CASE WHEN (ai_search_index.embedding IS NULL) THEN NULL ELSE ai_search_index.embedding <-> %(q_vec)s END, %(param_9)s) AS semantic_distance, word_similarity(%(word_similarity_1)s, ai_search_index.value) AS fuzzy_score \nFROM ai_search_index JOIN (SELECT DISTINCT ai_search_index.entity_id AS entity_id, ai_search_index.entity_title AS entity_title \nFROM ai_search_index) AS anon_1 ON anon_1.entity_id = ai_search_index.entity_id \nWHERE ai_search_index.value_type IN (__[POSTCOMPILE_value_type_1]) AND (%(param_10)s <%% ai_search_index.value) ORDER BY word_similarity(%(word_similarity_1)s, ai_search_index.value) DESC NULLS LAST, CASE WHEN (ai_search_index.embedding IS NULL) THEN NULL ELSE ai_search_index.embedding <-> %(q_vec)s END ASC NULLS LAST, ai_search_index.entity_id ASC \n LIMIT %(param_11)s), \nentity_scores AS \n(SELECT field_candidates.entity_id AS entity_id, field_candidates.entity_title AS entity_title, avg(field_candidates.semantic_distance) AS avg_semantic_distance, avg(field_candidates.fuzzy_score) AS avg_fuzzy_score \nFROM field_candidates GROUP BY field_candidates.entity_id, field_candidates.entity_title), \nentity_highlights AS \n(SELECT DISTINCT ON (field_candidates.entity_id) field_candidates.entity_id AS entity_id, first_value(field_candidates.value) OVER (PARTITION BY field_candidates.entity_id ORDER BY field_candidates.fuzzy_score DESC, field_candidates.path ASC) AS highlight_text, first_value(field_candidates.path) OVER (PARTITION BY field_candidates.entity_id ORDER BY field_candidates.fuzzy_score DESC, field_candidates.path ASC) AS highlight_path \nFROM field_candidates), \nranked_results AS \n(SELECT entity_scores.entity_id AS entity_id, entity_scores.entity_title AS entity_title, entity_scores.avg_semantic_distance AS avg_semantic_distance, entity_scores.avg_fuzzy_score AS avg_fuzzy_score, entity_highlights.highlight_text AS highlight_text, entity_highlights.highlight_path AS highlight_path, dense_rank() OVER (ORDER BY entity_scores.avg_semantic_distance ASC NULLS LAST, entity_scores.entity_id ASC) AS sem_rank, dense_rank() OVER (ORDER BY entity_scores.avg_fuzzy_score DESC NULLS LAST, entity_scores.entity_id ASC) AS fuzzy_rank \nFROM entity_scores JOIN entity_highlights ON entity_scores.entity_id = entity_highlights.entity_id)\n SELECT ranked_results.entity_id, ranked_results.entity_title, CAST(round(CAST((CAST(%(param_1)s / CAST((%(sem_rank_1)s + ranked_results.sem_rank) AS NUMERIC) + %(param_2)s / CAST((%(fuzzy_rank_1)s + ranked_results.fuzzy_rank) AS NUMERIC) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) * %(param_6)s) * CAST(CASE WHEN (ranked_results.avg_fuzzy_score >= %(avg_fuzzy_score_1)s) THEN %(param_7)s ELSE %(param_8)s END AS NUMERIC(38, 12))) / CAST((%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS 
NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) * %(param_6)s + %(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) AS NUMERIC(38, 12)) AS NUMERIC(38, 12)), %(round_1)s) AS NUMERIC(38, 12)) AS score, ranked_results.highlight_text, ranked_results.highlight_path, CASE WHEN (ranked_results.avg_fuzzy_score >= %(avg_fuzzy_score_1)s) THEN %(param_7)s ELSE %(param_8)s END AS perfect_match \nFROM ranked_results \nWHERE CAST(round(CAST((CAST(%(param_1)s / CAST((%(sem_rank_1)s + ranked_results.sem_rank) AS NUMERIC) + %(param_2)s / CAST((%(fuzzy_rank_1)s + ranked_results.fuzzy_rank) AS NUMERIC) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) * %(param_6)s) * CAST(CASE WHEN (ranked_results.avg_fuzzy_score >= %(avg_fuzzy_score_1)s) THEN %(param_7)s ELSE %(param_8)s END AS NUMERIC(38, 12))) / CAST((%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) * %(param_6)s + %(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) AS NUMERIC(38, 12)) AS NUMERIC(38, 12)), %(round_1)s) AS NUMERIC(38, 12)) < %(param_12)s OR CAST(round(CAST((CAST(%(param_1)s / CAST((%(sem_rank_1)s + ranked_results.sem_rank) AS NUMERIC) + %(param_2)s / CAST((%(fuzzy_rank_1)s + ranked_results.fuzzy_rank) AS NUMERIC) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) * %(param_6)s) * CAST(CASE WHEN (ranked_results.avg_fuzzy_score >= %(avg_fuzzy_score_1)s) THEN %(param_7)s ELSE %(param_8)s END AS NUMERIC(38, 12))) / CAST((%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) * %(param_6)s + %(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) AS NUMERIC(38, 12)) AS NUMERIC(38, 12)), %(round_1)s) AS NUMERIC(38, 12)) = %(param_12)s AND ranked_results.entity_id > %(entity_id_1)s::UUID ORDER BY score DESC NULLS LAST, ranked_results.entity_id ASC", + "SemanticRetriever.test_basic_query_structure": "SELECT ranked_semantic.entity_id, ranked_semantic.entity_title, ranked_semantic.score, ranked_semantic.highlight_text, ranked_semantic.highlight_path \nFROM (SELECT DISTINCT ON (ai_search_index.entity_id, ai_search_index.entity_title) ai_search_index.entity_id AS entity_id, ai_search_index.entity_title AS entity_title, CAST(round(CAST(%(param_1)s / CAST((%(param_2)s + CAST(min(ai_search_index.embedding <-> %(embedding_1)s) OVER (PARTITION BY ai_search_index.entity_id) AS NUMERIC(38, 12))) AS NUMERIC(38, 12)) AS NUMERIC(38, 12)), %(round_1)s) AS NUMERIC(38, 12)) AS score, first_value(ai_search_index.value) OVER (PARTITION BY ai_search_index.entity_id ORDER BY (ai_search_index.embedding <-> %(embedding_1)s) ASC, ai_search_index.path ASC) AS highlight_text, first_value(ai_search_index.path) OVER (PARTITION BY ai_search_index.entity_id ORDER BY (ai_search_index.embedding <-> %(embedding_1)s) ASC, ai_search_index.path ASC) AS highlight_path \nFROM ai_search_index JOIN (SELECT DISTINCT ai_search_index.entity_id AS entity_id, ai_search_index.entity_title AS entity_title \nFROM ai_search_index) AS anon_1 ON anon_1.entity_id = ai_search_index.entity_id \nWHERE ai_search_index.embedding IS NOT NULL) AS ranked_semantic ORDER BY ranked_semantic.score DESC NULLS LAST, ranked_semantic.entity_id 
ASC", + "SemanticRetriever.test_pagination_structure": "SELECT ranked_semantic.entity_id, ranked_semantic.entity_title, ranked_semantic.score, ranked_semantic.highlight_text, ranked_semantic.highlight_path \nFROM (SELECT DISTINCT ON (ai_search_index.entity_id, ai_search_index.entity_title) ai_search_index.entity_id AS entity_id, ai_search_index.entity_title AS entity_title, CAST(round(CAST(%(param_1)s / CAST((%(param_2)s + CAST(min(ai_search_index.embedding <-> %(embedding_1)s) OVER (PARTITION BY ai_search_index.entity_id) AS NUMERIC(38, 12))) AS NUMERIC(38, 12)) AS NUMERIC(38, 12)), %(round_1)s) AS NUMERIC(38, 12)) AS score, first_value(ai_search_index.value) OVER (PARTITION BY ai_search_index.entity_id ORDER BY (ai_search_index.embedding <-> %(embedding_1)s) ASC, ai_search_index.path ASC) AS highlight_text, first_value(ai_search_index.path) OVER (PARTITION BY ai_search_index.entity_id ORDER BY (ai_search_index.embedding <-> %(embedding_1)s) ASC, ai_search_index.path ASC) AS highlight_path \nFROM ai_search_index JOIN (SELECT DISTINCT ai_search_index.entity_id AS entity_id, ai_search_index.entity_title AS entity_title \nFROM ai_search_index) AS anon_1 ON anon_1.entity_id = ai_search_index.entity_id \nWHERE ai_search_index.embedding IS NOT NULL) AS ranked_semantic \nWHERE ranked_semantic.score < %(param_3)s OR ranked_semantic.score = %(param_3)s AND ranked_semantic.entity_id > %(entity_id_1)s::UUID ORDER BY ranked_semantic.score DESC NULLS LAST, ranked_semantic.entity_id ASC", + "StructuredRetriever.test_basic_query_structure": "SELECT anon_1.entity_id, anon_1.entity_title, %(param_1)s AS score \nFROM (SELECT DISTINCT ai_search_index.entity_id AS entity_id, ai_search_index.entity_title AS entity_title \nFROM ai_search_index) AS anon_1 ORDER BY anon_1.entity_id ASC", + "StructuredRetriever.test_pagination_structure": "SELECT anon_1.entity_id, anon_1.entity_title, %(param_1)s AS score \nFROM (SELECT DISTINCT ai_search_index.entity_id AS entity_id, ai_search_index.entity_title AS entity_title \nFROM ai_search_index) AS anon_1 \nWHERE anon_1.entity_id > %(entity_id_1)s::UUID ORDER BY anon_1.entity_id ASC" } diff --git a/test/unit_tests/search/retrieval/retrievers/test_retrievers.py b/test/unit_tests/search/retrieval/retrievers/test_retrievers.py index aaa5957b4..73927c0b8 100644 --- a/test/unit_tests/search/retrieval/retrievers/test_retrievers.py +++ b/test/unit_tests/search/retrieval/retrievers/test_retrievers.py @@ -11,13 +11,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import uuid + import pytest from sqlalchemy import literal, select from sqlalchemy.dialects import postgresql from orchestrator.db import db from orchestrator.db.models import AiSearchIndex -from orchestrator.search.retrieval.pagination import PaginationParams +from orchestrator.search.retrieval.pagination import PageCursor from orchestrator.search.retrieval.retrievers.fuzzy import FuzzyRetriever from orchestrator.search.retrieval.retrievers.hybrid import RrfHybridRetriever, compute_rrf_hybrid_score_sql from orchestrator.search.retrieval.retrievers.semantic import SemanticRetriever @@ -34,8 +36,16 @@ def compile_query_to_sql(query) -> str: @pytest.fixture def candidate_query(): - """Basic candidate query that returns entity IDs.""" - return select(AiSearchIndex.entity_id.label("entity_id")).distinct() + """Basic candidate query that returns entity IDs and titles.""" + return select( + AiSearchIndex.entity_id.label("entity_id"), AiSearchIndex.entity_title.label("entity_title") + ).distinct() + + +@pytest.fixture +def query_id(): + """Fixed query_id for pagination tests.""" + return uuid.uuid4() class TestStructuredRetriever: @@ -43,18 +53,17 @@ class TestStructuredRetriever: def test_basic_query_structure(self, candidate_query, request): """Test basic structured retrieval query structure.""" - pagination_params = PaginationParams() - retriever = StructuredRetriever(pagination_params) + retriever = StructuredRetriever(cursor=None) query = retriever.apply(candidate_query) sql = compile_query_to_sql(query) assert_sql_matches_snapshot("StructuredRetriever.test_basic_query_structure", sql, request) - def test_pagination_structure(self, candidate_query, request): + def test_pagination_structure(self, candidate_query, query_id, request): """Test pagination adds WHERE clause with correct comparison operator.""" - pagination_params = PaginationParams(page_after_id="test-id-123") - retriever = StructuredRetriever(pagination_params) + cursor = PageCursor(score=1.0, id="test-id-123", query_id=query_id) + retriever = StructuredRetriever(cursor=cursor) query = retriever.apply(candidate_query) sql = compile_query_to_sql(query) @@ -63,8 +72,7 @@ def test_pagination_structure(self, candidate_query, request): def test_metadata(self): """Test metadata returns correct search type.""" - pagination_params = PaginationParams() - retriever = StructuredRetriever(pagination_params) + retriever = StructuredRetriever(cursor=None) metadata = retriever.metadata @@ -76,18 +84,17 @@ class TestFuzzyRetriever: def test_basic_query_structure(self, candidate_query, request): """Test fuzzy retrieval query structure with all components.""" - pagination_params = PaginationParams() - retriever = FuzzyRetriever("test query", pagination_params) + retriever = FuzzyRetriever("test query", cursor=None) query = retriever.apply(candidate_query) sql = compile_query_to_sql(query) assert_sql_matches_snapshot("FuzzyRetriever.test_basic_query_structure", sql, request) - def test_pagination_structure(self, candidate_query, request): + def test_pagination_structure(self, candidate_query, query_id, request): """Test pagination with score and id adds correct WHERE clause.""" - pagination_params = PaginationParams(page_after_score=0.85, page_after_id="entity-123") - retriever = FuzzyRetriever("test", pagination_params) + cursor = PageCursor(score=0.85, id="entity-123", query_id=query_id) + retriever = FuzzyRetriever("test", cursor=cursor) query = retriever.apply(candidate_query) sql = compile_query_to_sql(query) @@ -96,8 +103,7 @@ def 
test_pagination_structure(self, candidate_query, request): def test_metadata(self): """Test metadata returns correct search type.""" - pagination_params = PaginationParams() - retriever = FuzzyRetriever("test", pagination_params) + retriever = FuzzyRetriever("test", cursor=None) metadata = retriever.metadata @@ -109,20 +115,19 @@ class TestSemanticRetriever: def test_basic_query_structure(self, candidate_query, request): """Test semantic retrieval query structure with all components.""" - pagination_params = PaginationParams() query_vector = [0.1, 0.2, 0.3] - retriever = SemanticRetriever(query_vector, pagination_params) + retriever = SemanticRetriever(query_vector, cursor=None) query = retriever.apply(candidate_query) sql = compile_query_to_sql(query) assert_sql_matches_snapshot("SemanticRetriever.test_basic_query_structure", sql, request) - def test_pagination_structure(self, candidate_query, request): + def test_pagination_structure(self, candidate_query, query_id, request): """Test pagination with score and id adds correct WHERE clause.""" - pagination_params = PaginationParams(page_after_score=0.92, page_after_id="entity-456") + cursor = PageCursor(score=0.92, id="entity-456", query_id=query_id) query_vector = [0.1, 0.2, 0.3] - retriever = SemanticRetriever(query_vector, pagination_params) + retriever = SemanticRetriever(query_vector, cursor=cursor) query = retriever.apply(candidate_query) sql = compile_query_to_sql(query) @@ -131,9 +136,8 @@ def test_pagination_structure(self, candidate_query, request): def test_metadata(self): """Test metadata returns correct search type.""" - pagination_params = PaginationParams() query_vector = [0.1, 0.2, 0.3] - retriever = SemanticRetriever(query_vector, pagination_params) + retriever = SemanticRetriever(query_vector, cursor=None) metadata = retriever.metadata @@ -145,20 +149,19 @@ class TestRrfHybridRetriever: def test_basic_query_structure(self, candidate_query, request): """Test hybrid RRF query structure with all CTEs.""" - pagination_params = PaginationParams() query_vector = [0.1, 0.2, 0.3] - retriever = RrfHybridRetriever(query_vector, "test", pagination_params) + retriever = RrfHybridRetriever(query_vector, "test", cursor=None) query = retriever.apply(candidate_query) sql = compile_query_to_sql(query) assert_sql_matches_snapshot("RrfHybridRetriever.test_basic_query_structure", sql, request) - def test_pagination_structure(self, candidate_query, request): + def test_pagination_structure(self, candidate_query, query_id, request): """Test that pagination adds score and entity_id comparison logic.""" - pagination_params = PaginationParams(page_after_score=0.95, page_after_id="entity-789") + cursor = PageCursor(score=0.95, id="entity-789", query_id=query_id) query_vector = [0.1, 0.2, 0.3] - retriever = RrfHybridRetriever(query_vector, "test", pagination_params) + retriever = RrfHybridRetriever(query_vector, "test", cursor=cursor) query = retriever.apply(candidate_query) sql = compile_query_to_sql(query) @@ -167,9 +170,8 @@ def test_pagination_structure(self, candidate_query, request): def test_metadata(self): """Test metadata returns correct search type.""" - pagination_params = PaginationParams() query_vector = [0.1, 0.2, 0.3] - retriever = RrfHybridRetriever(query_vector, "test", pagination_params) + retriever = RrfHybridRetriever(query_vector, "test", cursor=None) metadata = retriever.metadata From 3dcf9c6f41bdabae608b049aa56085ec612b793a Mon Sep 17 00:00:00 2001 From: Tim Frohlich Date: Wed, 22 Oct 2025 23:42:47 +0200 Subject: [PATCH 
15/16] remove unused tool --- orchestrator/search/agent/tools.py | 40 +----------------------------- 1 file changed, 1 insertion(+), 39 deletions(-) diff --git a/orchestrator/search/agent/tools.py b/orchestrator/search/agent/tools.py index 848c8a20f..5532c40ea 100644 --- a/orchestrator/search/agent/tools.py +++ b/orchestrator/search/agent/tools.py @@ -104,44 +104,6 @@ async def start_new_search( ) -@search_toolset.tool -async def set_search_parameters( - ctx: RunContext[StateDeps[SearchState]], - entity_type: EntityType, - action: str | ActionType = ActionType.SELECT, -) -> StateDeltaEvent: - """Updates search parameters without clearing filters or results. - - Use this to modify the entity type or action while preserving existing filters. - For a completely new search, use start_new_search instead. - """ - params = ctx.deps.state.parameters or {} - existing_filters = params.get("filters") - existing_query = params.get("query", "") - - logger.debug( - "Updating search parameters", - entity_type=entity_type.value, - action=action, - preserving_filters=existing_filters is not None, - ) - - _set_parameters(ctx, entity_type, action, existing_query, existing_filters) - - logger.debug("Search parameters updated", parameters=ctx.deps.state.parameters) - - return StateDeltaEvent( - type=EventType.STATE_DELTA, - delta=[ - JSONPatchOp.upsert( - path="/parameters", - value=ctx.deps.state.parameters, - existed=bool(params), - ) - ], - ) - - @search_toolset.tool(retries=2) async def set_filter_tree( ctx: RunContext[StateDeps[SearchState]], @@ -155,7 +117,7 @@ async def set_filter_tree( - See the FilterTree schema examples for the exact shape. """ if ctx.deps.state.parameters is None: - raise ModelRetry("Search parameters are not initialized. Call set_search_parameters first.") + raise ModelRetry("Search parameters are not initialized. Call start_new_search first.") entity_type = EntityType(ctx.deps.state.parameters["entity_type"]) From 5edf8a702524fab54820e1e2b15664e62f955200 Mon Sep 17 00:00:00 2001 From: Tim Frohlich Date: Thu, 23 Oct 2025 02:47:52 +0200 Subject: [PATCH 16/16] prevent connection stacking while agent runs --- orchestrator/search/agent/tools.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/orchestrator/search/agent/tools.py b/orchestrator/search/agent/tools.py index 5532c40ea..5a0e48154 100644 --- a/orchestrator/search/agent/tools.py +++ b/orchestrator/search/agent/tools.py @@ -184,8 +184,11 @@ async def run_search( if not ctx.deps.state.run_id: agent_run = AgentRunTable(agent_type="search") + db.session.add(agent_run) db.session.commit() + db.session.expire_all() # Release connection to prevent stacking while agent runs + ctx.deps.state.run_id = agent_run.run_id logger.debug("Created new agent run", run_id=str(agent_run.run_id)) changes.append(JSONPatchOp(op="add", path="/run_id", value=str(ctx.deps.state.run_id))) @@ -202,6 +205,8 @@ async def run_search( ) db.session.add(search_query) db.session.commit() + db.session.expire_all() + query_id_existed = ctx.deps.state.query_id is not None ctx.deps.state.query_id = search_query.query_id logger.debug("Saved search query", query_id=str(search_query.query_id), query_number=query_number)
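A minimal, self-contained sketch of the keyset-pagination pattern the retriever patches above converge on. The PageCursor field names and the (score, entity_id) predicate mirror pagination.py and FuzzyRetriever._apply_score_pagination in the diffs; the encode/decode bodies and the stand-in ranked_results table are illustrative assumptions, not the module's actual implementation.

import base64
import uuid

from pydantic import BaseModel
from sqlalchemy import Column, MetaData, Numeric, Select, String, Table, and_, or_, select


class PageCursor(BaseModel):
    # Field names follow orchestrator/search/retrieval/pagination.py; the
    # encode/decode bodies below are illustrative stand-ins.
    score: float
    id: str
    query_id: uuid.UUID

    def encode(self) -> str:
        return base64.urlsafe_b64encode(self.model_dump_json().encode()).decode()

    @classmethod
    def decode(cls, token: str) -> "PageCursor":
        return cls.model_validate_json(base64.urlsafe_b64decode(token))


# Hypothetical table standing in for a retriever's ranked output.
metadata = MetaData()
ranked = Table("ranked_results", metadata, Column("entity_id", String), Column("score", Numeric))


def apply_keyset_pagination(stmt: Select, cursor: PageCursor | None) -> Select:
    """Keep only rows strictly after the cursor in (score DESC, entity_id ASC) order."""
    if cursor is not None:
        stmt = stmt.where(
            or_(
                ranked.c.score < cursor.score,
                and_(ranked.c.score == cursor.score, ranked.c.entity_id > cursor.id),
            )
        )
    return stmt.order_by(ranked.c.score.desc(), ranked.c.entity_id.asc())


# Round trip: the server hands out cursor.encode() as an opaque token; the next
# request rebuilds the predicate from PageCursor.decode(token).
first_page = apply_keyset_pagination(select(ranked.c.entity_id, ranked.c.score), None)
token = PageCursor(score=0.85, id="entity-123", query_id=uuid.uuid4()).encode()
next_page = apply_keyset_pagination(select(ranked.c.entity_id, ranked.c.score), PageCursor.decode(token))

Tie-breaking on entity_id keeps page boundaries stable when many rows share a score, which is why every retriever orders by score DESC, entity_id ASC, and why StructuredRetriever, whose score is a constant, can page on entity_id alone.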