diff --git a/orchestrator/agentic_app.py b/orchestrator/agentic_app.py index a656578be..e1a0353f7 100644 --- a/orchestrator/agentic_app.py +++ b/orchestrator/agentic_app.py @@ -27,7 +27,6 @@ if TYPE_CHECKING: from pydantic_ai.models.openai import OpenAIModel - from pydantic_ai.toolsets import FunctionToolset logger = get_logger(__name__) @@ -38,19 +37,17 @@ def __init__( *args: Any, llm_settings: LLMSettings = llm_settings, agent_model: "OpenAIModel | str | None" = None, - agent_tools: "list[FunctionToolset] | None" = None, **kwargs: Any, ) -> None: """Initialize the `LLMOrchestratorCore` class. This class extends `OrchestratorCore` with LLM features (search and agent). - It runs the search migration and mounts the agent endpoint based on feature flags. + It runs the search migration based on feature flags. Args: *args: All the normal arguments passed to the `OrchestratorCore` class. llm_settings: A class of settings for the LLM agent_model: Override the agent model (defaults to llm_settings.AGENT_MODEL) - agent_tools: A list of tools that can be used by the agent **kwargs: Additional arguments passed to the `OrchestratorCore` class. Returns: @@ -58,7 +55,6 @@ def __init__( """ self.llm_settings = llm_settings self.agent_model = agent_model or llm_settings.AGENT_MODEL - self.agent_tools = agent_tools super().__init__(*args, **kwargs) @@ -79,22 +75,6 @@ def __init__( ) raise - # Mount agent endpoint if agent is enabled - if self.llm_settings.AGENT_ENABLED: - logger.info("Initializing agent features", model=self.agent_model) - try: - from orchestrator.search.agent import build_agent_router - - agent_app = build_agent_router(self.agent_model, self.agent_tools) - self.mount("/agent", agent_app) - except ImportError as e: - logger.error( - "Unable to initialize agent features. Please install agent dependencies: " - "`pip install orchestrator-core[agent]`", - error=str(e), - ) - raise - main_typer_app = typer.Typer() main_typer_app.add_typer(cli_app, name="orchestrator", help="The orchestrator CLI commands") diff --git a/orchestrator/api/api_v1/api.py b/orchestrator/api/api_v1/api.py index 9994ee5f9..a044a413a 100644 --- a/orchestrator/api/api_v1/api.py +++ b/orchestrator/api/api_v1/api.py @@ -95,3 +95,8 @@ api_router.include_router( search.router, prefix="/search", tags=["Core", "Search"], dependencies=[Depends(authorize)] ) + +if llm_settings.AGENT_ENABLED: + from orchestrator.api.api_v1.endpoints import agent + + api_router.include_router(agent.router, prefix="/agent", tags=["Core", "Agent"], dependencies=[Depends(authorize)]) diff --git a/orchestrator/api/api_v1/endpoints/agent.py b/orchestrator/api/api_v1/endpoints/agent.py new file mode 100644 index 000000000..832b48356 --- /dev/null +++ b/orchestrator/api/api_v1/endpoints/agent.py @@ -0,0 +1,50 @@ +# Copyright 2019-2025 SURF, GÉANT. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
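For context: with the agent mount moved out of `LLMOrchestratorCore`, the endpoint only exists when `AGENT_ENABLED` is set. A minimal sketch of how the flag-guarded include composes, assuming default settings; the test-client call and request body are illustrative only (the real protocol payload is produced by the ag-ui frontend):

```python
# Sketch: flag-guarded router registration, mirroring api_v1/api.py above.
from fastapi import FastAPI
from fastapi.testclient import TestClient

from orchestrator.llm_settings import llm_settings

app = FastAPI()

if llm_settings.AGENT_ENABLED:
    # Import deferred so agent dependencies are only needed when enabled.
    from orchestrator.api.api_v1.endpoints import agent

    app.include_router(agent.router, prefix="/api/agent")

client = TestClient(app)
# Placeholder body; a real request carries the full ag-ui protocol payload.
response = client.post("/api/agent/", json={"messages": []})
```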
+ +from functools import cache +from typing import Annotated + +from fastapi import APIRouter, Depends, Request +from pydantic_ai.ag_ui import StateDeps, handle_ag_ui_request +from pydantic_ai.agent import Agent +from starlette.responses import Response +from structlog import get_logger + +from orchestrator.llm_settings import llm_settings +from orchestrator.search.agent import build_agent_instance +from orchestrator.search.agent.state import SearchState + +router = APIRouter() +logger = get_logger(__name__) + + +@cache +def get_agent() -> Agent[StateDeps[SearchState], str]: + """Dependency to provide the agent instance. + + The agent is built once and cached for the lifetime of the application. + """ + return build_agent_instance(llm_settings.AGENT_MODEL) + + +@router.post("/") +async def agent_conversation( + request: Request, + agent: Annotated[Agent[StateDeps[SearchState], str], Depends(get_agent)], +) -> Response: + """Agent conversation endpoint using pydantic-ai ag_ui protocol. + + This endpoint handles the interactive agent conversation for search. + """ + initial_state = SearchState() + return await handle_ag_ui_request(agent, request, deps=StateDeps(initial_state)) diff --git a/orchestrator/api/api_v1/endpoints/search.py b/orchestrator/api/api_v1/endpoints/search.py index f506cacf2..ac5fa4dcc 100644 --- a/orchestrator/api/api_v1/endpoints/search.py +++ b/orchestrator/api/api_v1/endpoints/search.py @@ -11,251 +11,121 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Literal, overload - +import structlog from fastapi import APIRouter, HTTPException, Query, status -from sqlalchemy import case, select -from sqlalchemy.orm import selectinload - -from orchestrator.db import ( - ProcessTable, - ProductTable, - WorkflowTable, - db, -) -from orchestrator.domain.base import SubscriptionModel -from orchestrator.domain.context_cache import cache_subscription_models + +from orchestrator.db import db from orchestrator.schemas.search import ( + ExportResponse, PageInfoSchema, PathsResponse, - ProcessSearchResult, - ProcessSearchSchema, - ProductSearchResult, - ProductSearchSchema, SearchResultsSchema, - SubscriptionSearchResult, - WorkflowSearchResult, - WorkflowSearchSchema, ) -from orchestrator.search.core.exceptions import InvalidCursorError +from orchestrator.search.core.exceptions import InvalidCursorError, QueryStateNotFoundError from orchestrator.search.core.types import EntityType, UIType from orchestrator.search.filters.definitions import generate_definitions -from orchestrator.search.indexing.registry import ENTITY_CONFIG_REGISTRY -from orchestrator.search.retrieval import execute_search +from orchestrator.search.retrieval import SearchQueryState, execute_search, execute_search_for_export from orchestrator.search.retrieval.builder import build_paths_query, create_path_autocomplete_lquery, process_path_rows -from orchestrator.search.retrieval.pagination import ( - create_next_page_cursor, - process_pagination_cursor, -) +from orchestrator.search.retrieval.pagination import PageCursor, encode_next_page_cursor from orchestrator.search.retrieval.validation import is_lquery_syntactically_valid from orchestrator.search.schemas.parameters import ( - BaseSearchParameters, ProcessSearchParameters, ProductSearchParameters, + SearchParameters, SubscriptionSearchParameters, WorkflowSearchParameters, ) from orchestrator.search.schemas.results import SearchResult, TypeDefinition -from 
orchestrator.services.subscriptions import format_special_types router = APIRouter() +logger = structlog.get_logger(__name__) -def _create_search_result_item( - entity: WorkflowTable | ProductTable | ProcessTable, entity_type: EntityType, search_info: SearchResult -) -> WorkflowSearchResult | ProductSearchResult | ProcessSearchResult | None: - match entity_type: - case EntityType.WORKFLOW: - workflow_data = WorkflowSearchSchema.model_validate(entity) - return WorkflowSearchResult( - workflow=workflow_data, - score=search_info.score, - perfect_match=search_info.perfect_match, - matching_field=search_info.matching_field, - ) - case EntityType.PRODUCT: - product_data = ProductSearchSchema.model_validate(entity) - return ProductSearchResult( - product=product_data, - score=search_info.score, - perfect_match=search_info.perfect_match, - matching_field=search_info.matching_field, - ) - case EntityType.PROCESS: - process_data = ProcessSearchSchema.model_validate(entity) - return ProcessSearchResult( - process=process_data, - score=search_info.score, - perfect_match=search_info.perfect_match, - matching_field=search_info.matching_field, - ) - case _: - return None - - -@overload async def _perform_search_and_fetch( - search_params: BaseSearchParameters, - entity_type: Literal[EntityType.WORKFLOW], - eager_loads: list[Any], + search_params: SearchParameters | None = None, cursor: str | None = None, -) -> SearchResultsSchema[WorkflowSearchResult]: ... - + query_id: str | None = None, +) -> SearchResultsSchema[SearchResult]: + """Execute search with optional pagination. + + Args: + search_params: Search parameters for new search + cursor: Pagination cursor (loads saved query state) + query_id: Saved query ID to retrieve and execute + + Returns: + Search results with entity_id, score, and matching_field. + """ + try: + page_cursor: PageCursor | None = None + + if cursor: + page_cursor = PageCursor.decode(cursor) + query_state = SearchQueryState.load_from_id(page_cursor.query_id) + elif query_id: + query_state = SearchQueryState.load_from_id(query_id) + elif search_params: + query_state = SearchQueryState(parameters=search_params, query_embedding=None) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Either search_params, cursor, or query_id must be provided", + ) -@overload -async def _perform_search_and_fetch( - search_params: BaseSearchParameters, - entity_type: Literal[EntityType.PRODUCT], - eager_loads: list[Any], - cursor: str | None = None, -) -> SearchResultsSchema[ProductSearchResult]: ... + search_response = await execute_search( + query_state.parameters, db.session, page_cursor, query_state.query_embedding + ) + if not search_response.results: + return SearchResultsSchema(search_metadata=search_response.metadata) + next_page_cursor = encode_next_page_cursor(search_response, page_cursor, query_state.parameters) + has_next_page = next_page_cursor is not None + page_info = PageInfoSchema(has_next_page=has_next_page, next_page_cursor=next_page_cursor) -@overload -async def _perform_search_and_fetch( - search_params: BaseSearchParameters, - entity_type: Literal[EntityType.PROCESS], - eager_loads: list[Any], - cursor: str | None = None, -) -> SearchResultsSchema[ProcessSearchResult]: ... 
+ return SearchResultsSchema( + data=search_response.results, page_info=page_info, search_metadata=search_response.metadata + ) + except (InvalidCursorError, ValueError) as e: + raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e)) + except QueryStateNotFoundError as e: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Search failed: {str(e)}", + ) -async def _perform_search_and_fetch( - search_params: BaseSearchParameters, - entity_type: EntityType, - eager_loads: list[Any], - cursor: str | None = None, -) -> SearchResultsSchema[Any]: - try: - pagination_params = await process_pagination_cursor(cursor, search_params) - except InvalidCursorError: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid pagination cursor") - - search_response = await execute_search( - search_params=search_params, - db_session=db.session, - pagination_params=pagination_params, - ) - if not search_response.results: - return SearchResultsSchema(search_metadata=search_response.metadata) - - next_page_cursor = create_next_page_cursor(search_response.results, pagination_params, search_params.limit) - has_next_page = next_page_cursor is not None - page_info = PageInfoSchema(has_next_page=has_next_page, next_page_cursor=next_page_cursor) - - config = ENTITY_CONFIG_REGISTRY[entity_type] - entity_ids = [res.entity_id for res in search_response.results] - pk_column = getattr(config.table, config.pk_name) - ordering_case = case({entity_id: i for i, entity_id in enumerate(entity_ids)}, value=pk_column) - - stmt = select(config.table).options(*eager_loads).filter(pk_column.in_(entity_ids)).order_by(ordering_case) - entities = db.session.scalars(stmt).all() - - search_info_map = {res.entity_id: res for res in search_response.results} - data = [] - for entity in entities: - entity_id = getattr(entity, config.pk_name) - search_info = search_info_map.get(str(entity_id)) - if not search_info: - continue - - search_result_item = _create_search_result_item(entity, entity_type, search_info) - if search_result_item: - data.append(search_result_item) - - return SearchResultsSchema(data=data, page_info=page_info, search_metadata=search_response.metadata) - - -@router.post( - "/subscriptions", - response_model=SearchResultsSchema[SubscriptionSearchResult], -) +@router.post("/subscriptions", response_model=SearchResultsSchema[SearchResult]) async def search_subscriptions( search_params: SubscriptionSearchParameters, cursor: str | None = None, -) -> SearchResultsSchema[SubscriptionSearchResult]: - try: - pagination_params = await process_pagination_cursor(cursor, search_params) - except InvalidCursorError: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid pagination cursor") - - search_response = await execute_search( - search_params=search_params, - db_session=db.session, - pagination_params=pagination_params, - ) - - if not search_response.results: - return SearchResultsSchema(search_metadata=search_response.metadata) - - next_page_cursor = create_next_page_cursor(search_response.results, pagination_params, search_params.limit) - has_next_page = next_page_cursor is not None - page_info = PageInfoSchema(has_next_page=has_next_page, next_page_cursor=next_page_cursor) - - search_info_map = {res.entity_id: res for res in search_response.results} - - with cache_subscription_models(): - subscriptions_data = { - sub_id: 
SubscriptionModel.from_subscription(sub_id).model_dump(exclude_unset=False) - for sub_id in search_info_map - } - - results_data = [ - SubscriptionSearchResult( - subscription=format_special_types(subscriptions_data[sub_id]), - score=search_info.score, - perfect_match=search_info.perfect_match, - matching_field=search_info.matching_field, - ) - for sub_id, search_info in search_info_map.items() - ] - - return SearchResultsSchema(data=results_data, page_info=page_info, search_metadata=search_response.metadata) +) -> SearchResultsSchema[SearchResult]: + return await _perform_search_and_fetch(search_params, cursor) -@router.post("/workflows", response_model=SearchResultsSchema[WorkflowSearchResult]) +@router.post("/workflows", response_model=SearchResultsSchema[SearchResult]) async def search_workflows( search_params: WorkflowSearchParameters, cursor: str | None = None, -) -> SearchResultsSchema[WorkflowSearchResult]: - return await _perform_search_and_fetch( - search_params=search_params, - entity_type=EntityType.WORKFLOW, - eager_loads=[selectinload(WorkflowTable.products)], - cursor=cursor, - ) +) -> SearchResultsSchema[SearchResult]: + return await _perform_search_and_fetch(search_params, cursor) -@router.post("/products", response_model=SearchResultsSchema[ProductSearchResult]) +@router.post("/products", response_model=SearchResultsSchema[SearchResult]) async def search_products( search_params: ProductSearchParameters, cursor: str | None = None, -) -> SearchResultsSchema[ProductSearchResult]: - return await _perform_search_and_fetch( - search_params=search_params, - entity_type=EntityType.PRODUCT, - eager_loads=[ - selectinload(ProductTable.workflows), - selectinload(ProductTable.fixed_inputs), - selectinload(ProductTable.product_blocks), - ], - cursor=cursor, - ) - - -@router.post("/processes", response_model=SearchResultsSchema[ProcessSearchResult]) +) -> SearchResultsSchema[SearchResult]: + return await _perform_search_and_fetch(search_params, cursor) + + +@router.post("/processes", response_model=SearchResultsSchema[SearchResult]) async def search_processes( search_params: ProcessSearchParameters, cursor: str | None = None, -) -> SearchResultsSchema[ProcessSearchResult]: - return await _perform_search_and_fetch( - search_params=search_params, - entity_type=EntityType.PROCESS, - eager_loads=[ - selectinload(ProcessTable.workflow), - ], - cursor=cursor, - ) +) -> SearchResultsSchema[SearchResult]: + return await _perform_search_and_fetch(search_params, cursor) @router.get( @@ -294,3 +164,52 @@ async def list_paths( async def get_definitions() -> dict[UIType, TypeDefinition]: """Provide a static definition of operators and schemas for each UI type.""" return generate_definitions() + + +@router.get( + "/queries/{query_id}", + response_model=SearchResultsSchema[SearchResult], + summary="Retrieve saved search results by query_id", +) +async def get_by_query_id( + query_id: str, + cursor: str | None = None, +) -> SearchResultsSchema[SearchResult]: + """Retrieve and execute a saved search by query_id.""" + return await _perform_search_and_fetch(query_id=query_id, cursor=cursor) + + +@router.get( + "/queries/{query_id}/export", + summary="Export query results by query_id", + response_model=ExportResponse, +) +async def export_by_query_id(query_id: str) -> ExportResponse: + """Export search results using query_id. + + The query is retrieved from the database, re-executed, and results are returned + as flattened records suitable for CSV download. 
+
+    Args:
+        query_id: Query UUID
+
+    Returns:
+        ExportResponse containing 'page' with an array of flattened entity records.
+
+    Raises:
+        HTTPException: 404 if query not found, 422 if invalid data
+    """
+    try:
+        query_state = SearchQueryState.load_from_id(query_id)
+        export_records = await execute_search_for_export(query_state, db.session)
+        return ExportResponse(page=export_records)
+    except ValueError as e:
+        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e))
+    except QueryStateNotFoundError as e:
+        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
+    except Exception as e:
+        logger.error(e)
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"Error executing export: {str(e)}",
+        )
diff --git a/orchestrator/cli/search/resize_embedding.py b/orchestrator/cli/search/resize_embedding.py
index 8d944e3cf..7fa966cb3 100644
--- a/orchestrator/cli/search/resize_embedding.py
+++ b/orchestrator/cli/search/resize_embedding.py
@@ -4,7 +4,7 @@
 from sqlalchemy.exc import SQLAlchemyError

 from orchestrator.db import db
-from orchestrator.db.models import AiSearchIndex
+from orchestrator.db.models import AiSearchIndex, SearchQueryTable
 from orchestrator.llm_settings import llm_settings

 logger = structlog.get_logger(__name__)
@@ -40,17 +40,20 @@ def get_current_embedding_dimension() -> int | None:
         return None


-def drop_all_embeddings() -> int:
-    """Drop all records from the ai_search_index table.
+def drop_all_embeddings() -> tuple[int, int]:
+    """Drop all records from ai_search_index and search_queries tables.

     Returns:
-        Number of records deleted
+        Tuple of (ai_search_index records deleted, search_queries records deleted)
     """
     try:
-        result = db.session.query(AiSearchIndex).delete()
+        index_deleted = db.session.query(AiSearchIndex).delete()
+        query_deleted = db.session.query(SearchQueryTable).delete()
         db.session.commit()
-        logger.info(f"Deleted {result} records from ai_search_index")
-        return result
+        logger.info(
+            f"Deleted {index_deleted} records from ai_search_index and {query_deleted} records from search_queries"
+        )
+        return index_deleted, query_deleted

     except SQLAlchemyError as e:
         db.session.rollback()
@@ -59,34 +62,34 @@
 def alter_embedding_column_dimension(new_dimension: int) -> None:
-    """Alter the embedding column to use the new dimension size.
+    """Alter the embedding columns in both ai_search_index and search_queries tables.
Args: new_dimension: New vector dimension size """ try: - drop_query = text("ALTER TABLE ai_search_index DROP COLUMN IF EXISTS embedding") - db.session.execute(drop_query) + db.session.execute(text("ALTER TABLE ai_search_index DROP COLUMN IF EXISTS embedding")) + db.session.execute(text(f"ALTER TABLE ai_search_index ADD COLUMN embedding vector({new_dimension})")) - add_query = text(f"ALTER TABLE ai_search_index ADD COLUMN embedding vector({new_dimension})") - db.session.execute(add_query) + db.session.execute(text("ALTER TABLE search_queries DROP COLUMN IF EXISTS query_embedding")) + db.session.execute(text(f"ALTER TABLE search_queries ADD COLUMN query_embedding vector({new_dimension})")) db.session.commit() - logger.info(f"Altered embedding column to dimension {new_dimension}") + logger.info(f"Altered embedding columns to dimension {new_dimension} in ai_search_index and search_queries") except SQLAlchemyError as e: db.session.rollback() - logger.error("Failed to alter embedding column dimension", error=str(e)) + logger.error("Failed to alter embedding column dimensions", error=str(e)) raise @app.command("resize") def resize_embeddings_command() -> None: - """Resize vector dimensions of the ai_search_index embedding column. + """Resize vector dimensions of embedding columns in ai_search_index and search_queries tables. Compares the current embedding dimension in the database with the configured - dimension in llm_settings. If they differ, drops all records and alters the - column to match the new dimension. + dimension in llm_settings. If they differ, drops all records and alters both + embedding columns to match the new dimension. """ new_dimension = llm_settings.EMBEDDING_DIMENSION @@ -107,22 +110,25 @@ def resize_embeddings_command() -> None: logger.info("Dimension mismatch detected", current_dimension=current_dimension, new_dimension=new_dimension) - if not typer.confirm("This will DELETE ALL RECORDS from ai_search_index and alter the embedding column. Continue?"): + if not typer.confirm( + "This will DELETE ALL RECORDS from ai_search_index and search_queries tables and alter embedding columns. Continue?" + ): logger.info("Operation cancelled by user") return try: # Drop all records first. logger.info("Dropping all embedding records...") - deleted_count = drop_all_embeddings() + index_deleted, query_deleted = drop_all_embeddings() - # Then alter column dimension. - logger.info(f"Altering embedding column to dimension {new_dimension}...") + # Then alter column dimensions. 
+ logger.info(f"Altering embedding columns to dimension {new_dimension}...") alter_embedding_column_dimension(new_dimension) logger.info( "Embedding dimension resize completed successfully", - records_deleted=deleted_count, + index_records_deleted=index_deleted, + query_records_deleted=query_deleted, new_dimension=new_dimension, ) diff --git a/orchestrator/cli/search/speedtest.py b/orchestrator/cli/search/speedtest.py index a49544e80..5c215c219 100644 --- a/orchestrator/cli/search/speedtest.py +++ b/orchestrator/cli/search/speedtest.py @@ -13,7 +13,6 @@ from orchestrator.search.core.types import EntityType from orchestrator.search.core.validators import is_uuid from orchestrator.search.retrieval.engine import execute_search -from orchestrator.search.retrieval.pagination import PaginationParams from orchestrator.search.schemas.parameters import BaseSearchParameters logger = structlog.get_logger(__name__) @@ -54,17 +53,16 @@ async def generate_embeddings_for_queries(queries: list[str]) -> dict[str, list[ async def run_single_query(query: str, embedding_lookup: dict[str, list[float]]) -> dict[str, Any]: search_params = BaseSearchParameters(entity_type=EntityType.SUBSCRIPTION, query=query, limit=30) + query_embedding = None + if is_uuid(query): - pagination_params = PaginationParams() logger.debug("Using fuzzy-only ranking for full UUID", query=query) else: - - cached_embedding = embedding_lookup[query] - pagination_params = PaginationParams(q_vec_override=cached_embedding) + query_embedding = embedding_lookup[query] with db.session as session: start_time = time.perf_counter() - response = await execute_search(search_params, session, pagination_params=pagination_params) + response = await execute_search(search_params, session, cursor=None, query_embedding=query_embedding) end_time = time.perf_counter() return { diff --git a/orchestrator/db/__init__.py b/orchestrator/db/__init__.py index 6e7138e08..81befbc17 100644 --- a/orchestrator/db/__init__.py +++ b/orchestrator/db/__init__.py @@ -17,6 +17,7 @@ from orchestrator.db.database import BaseModel as DbBaseModel from orchestrator.db.database import Database, transactional from orchestrator.db.models import ( # noqa: F401 + AgentRunTable, EngineSettingsTable, FixedInputTable, InputStateTable, @@ -26,6 +27,7 @@ ProductBlockTable, ProductTable, ResourceTypeTable, + SearchQueryTable, SubscriptionCustomerDescriptionTable, SubscriptionInstanceRelationTable, SubscriptionInstanceTable, @@ -74,6 +76,8 @@ def init_database(settings: AppSettings) -> Database: __all__ = [ "transactional", + "SearchQueryTable", + "AgentRunTable", "SubscriptionTable", "ProcessSubscriptionTable", "ProcessTable", @@ -97,6 +101,8 @@ def init_database(settings: AppSettings) -> Database: ] ALL_DB_MODELS: list[type[DbBaseModel]] = [ + SearchQueryTable, + AgentRunTable, FixedInputTable, ProcessStepTable, ProcessSubscriptionTable, diff --git a/orchestrator/db/models.py b/orchestrator/db/models.py index 2d7f43b3d..000db36d0 100644 --- a/orchestrator/db/models.py +++ b/orchestrator/db/models.py @@ -15,6 +15,7 @@ import enum from datetime import datetime, timezone +from typing import TYPE_CHECKING from uuid import UUID import sqlalchemy @@ -58,6 +59,9 @@ from orchestrator.utils.datetime import nowtz from orchestrator.version import GIT_COMMIT_HASH +if TYPE_CHECKING: + from orchestrator.search.retrieval.query_state import SearchQueryState + logger = structlog.get_logger(__name__) TAG_LENGTH = 20 @@ -674,6 +678,76 @@ class SubscriptionSearchView(BaseModel): subscription = 
relationship("SubscriptionTable", foreign_keys=[subscription_id])


+class AgentRunTable(BaseModel):
+    """Agent conversation/session tracking."""
+
+    __tablename__ = "agent_runs"
+
+    run_id = mapped_column("run_id", UUIDType, server_default=text("uuid_generate_v4()"), primary_key=True)
+    agent_type = mapped_column(String(50), nullable=False)
+    created_at = mapped_column(UtcTimestamp, server_default=text("current_timestamp()"), nullable=False)
+
+    queries = relationship("SearchQueryTable", back_populates="run", cascade="delete", passive_deletes=True)
+
+    __table_args__ = (Index("ix_agent_runs_created_at", "created_at"),)
+
+
+class SearchQueryTable(BaseModel):
+    """Search query execution - used by both agent runs and regular API searches.
+
+    When run_id is NULL: standalone API search query
+    When run_id is NOT NULL: query belongs to an agent conversation run
+    """
+
+    __tablename__ = "search_queries"
+
+    query_id = mapped_column("query_id", UUIDType, server_default=text("uuid_generate_v4()"), primary_key=True)
+    run_id = mapped_column(
+        "run_id", UUIDType, ForeignKey("agent_runs.run_id", ondelete="CASCADE"), nullable=True, index=True
+    )
+    query_number = mapped_column(Integer, nullable=False)
+
+    # Search parameters as JSONB (maps to BaseSearchParameters subclasses)
+    parameters = mapped_column(pg.JSONB, nullable=False)
+
+    # Query embedding for semantic search (pgvector)
+    query_embedding = mapped_column(Vector(1536), nullable=True)
+
+    executed_at = mapped_column(UtcTimestamp, server_default=text("current_timestamp()"), nullable=False)
+
+    run = relationship("AgentRunTable", back_populates="queries")
+
+    # run_id is already indexed via index=True on the column above.
+    __table_args__ = (
+        Index("ix_search_queries_executed_at", "executed_at"),
+        Index("ix_search_queries_query_id", "query_id"),
+    )
+
+    @classmethod
+    def from_state(
+        cls,
+        state: "SearchQueryState",
+        run_id: "UUID | None" = None,
+        query_number: int = 1,
+    ) -> "SearchQueryTable":
+        """Create a SearchQueryTable instance from a SearchQueryState.
+
+        Args:
+            state: The search query state with parameters and embedding
+            run_id: Optional agent run ID (NULL for regular API searches)
+            query_number: Query number within the run (default=1)
+
+        Returns:
+            SearchQueryTable instance ready to be added to the database.
+        """
+        return cls(
+            run_id=run_id,
+            query_number=query_number,
+            parameters=state.parameters.model_dump(),
+            query_embedding=state.query_embedding,
+        )
+
+
 class EngineSettingsTable(BaseModel):
     __tablename__ = "engine_settings"

     global_lock = mapped_column(Boolean(), default=False, nullable=False, primary_key=True)
@@ -705,6 +779,7 @@ class AiSearchIndex(BaseModel):
         UUIDType,
         nullable=False,
     )
+    entity_title = mapped_column(TEXT, nullable=True)

     # Ltree path for hierarchical data
     path = mapped_column(LtreeType, nullable=False, index=True)
diff --git a/orchestrator/schemas/search.py b/orchestrator/schemas/search.py
index d85639132..d6f4a7f3e 100644
--- a/orchestrator/schemas/search.py
+++ b/orchestrator/schemas/search.py
@@ -11,14 +11,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
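Taken together with the endpoints earlier in this diff, `SearchQueryTable.from_state` is the single persistence path for both standalone API searches (`run_id` left NULL) and agent runs. A minimal sketch of the round trip, assuming an active database session; `load_from_id` and the parameter constructor are used as they appear elsewhere in this diff:

```python
# Sketch: persist a query state, then reload it by query_id as the
# /search/queries/{query_id} endpoint does.
from orchestrator.db import SearchQueryTable, db
from orchestrator.search.core.types import EntityType
from orchestrator.search.retrieval.query_state import SearchQueryState
from orchestrator.search.schemas.parameters import BaseSearchParameters

params = BaseSearchParameters(entity_type=EntityType.SUBSCRIPTION, query="fiber", limit=10)
state = SearchQueryState(parameters=params, query_embedding=None)

row = SearchQueryTable.from_state(state)  # run_id stays NULL: standalone API search
db.session.add(row)
db.session.commit()

# Re-executing later only needs the saved state, not the original request.
reloaded = SearchQueryState.load_from_id(str(row.query_id))
```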
-from datetime import datetime -from typing import Any, Generic, TypeVar -from uuid import UUID +from typing import Generic, TypeVar from pydantic import BaseModel, ConfigDict, Field from orchestrator.search.core.types import SearchMetadata -from orchestrator.search.schemas.results import ComponentInfo, LeafInfo, MatchingField +from orchestrator.search.schemas.results import ComponentInfo, LeafInfo T = TypeVar("T") @@ -36,95 +34,20 @@ class ProductSchema(BaseModel): product_type: str -class SubscriptionSearchResult(BaseModel): - score: float - perfect_match: int - matching_field: MatchingField | None = None - subscription: dict[str, Any] - - class SearchResultsSchema(BaseModel, Generic[T]): data: list[T] = Field(default_factory=list) page_info: PageInfoSchema = Field(default_factory=PageInfoSchema) search_metadata: SearchMetadata | None = None -class WorkflowProductSchema(BaseModel): - """Product associated with a workflow.""" - - model_config = ConfigDict(from_attributes=True) - - product_type: str - product_id: UUID - name: str - - -class WorkflowSearchSchema(BaseModel): - """Schema for workflow search results.""" - - model_config = ConfigDict(from_attributes=True) - - name: str - products: list[WorkflowProductSchema] - description: str | None = None - created_at: datetime | None = None - - -class ProductSearchSchema(BaseModel): - """Schema for product search results.""" - - model_config = ConfigDict(from_attributes=True) - - product_id: UUID - name: str - product_type: str - tag: str | None = None - description: str | None = None - status: str | None = None - created_at: datetime | None = None - - -class ProcessSearchSchema(BaseModel): - """Schema for process search results.""" - - model_config = ConfigDict(from_attributes=True) - - process_id: UUID - workflow_name: str - workflow_id: UUID - last_status: str - is_task: bool - created_by: str | None = None - started_at: datetime - last_modified_at: datetime - last_step: str | None = None - failed_reason: str | None = None - subscription_ids: list[UUID] | None = None - - -class WorkflowSearchResult(BaseModel): - score: float - perfect_match: int - matching_field: MatchingField | None = None - workflow: WorkflowSearchSchema - - -class ProductSearchResult(BaseModel): - score: float - perfect_match: int - matching_field: MatchingField | None = None - product: ProductSearchSchema - - -class ProcessSearchResult(BaseModel): - score: float - perfect_match: int - matching_field: MatchingField | None = None - process: ProcessSearchSchema - - class PathsResponse(BaseModel): leaves: list[LeafInfo] components: list[ComponentInfo] model_config = ConfigDict(extra="forbid", use_enum_values=True) + + +class ExportResponse(BaseModel): + page: list[dict] + + model_config = ConfigDict(extra="forbid") diff --git a/orchestrator/search/agent/__init__.py b/orchestrator/search/agent/__init__.py index 46fa7fd53..55d92ec57 100644 --- a/orchestrator/search/agent/__init__.py +++ b/orchestrator/search/agent/__init__.py @@ -14,8 +14,8 @@ # This module requires: pydantic-ai==0.7.0, ag-ui-protocol>=0.1.8 -from orchestrator.search.agent.agent import build_agent_router +from orchestrator.search.agent.agent import build_agent_instance __all__ = [ - "build_agent_router", + "build_agent_instance", ] diff --git a/orchestrator/search/agent/agent.py b/orchestrator/search/agent/agent.py index 8112501c4..e3e0c8d04 100644 --- a/orchestrator/search/agent/agent.py +++ b/orchestrator/search/agent/agent.py @@ -14,13 +14,11 @@ from typing import Any import structlog -from fastapi 
import APIRouter, HTTPException, Request -from pydantic_ai.ag_ui import StateDeps, handle_ag_ui_request +from pydantic_ai.ag_ui import StateDeps from pydantic_ai.agent import Agent from pydantic_ai.models.openai import OpenAIModel from pydantic_ai.settings import ModelSettings from pydantic_ai.toolsets import FunctionToolset -from starlette.responses import Response from orchestrator.search.agent.prompts import get_base_instructions, get_dynamic_instructions from orchestrator.search.agent.state import SearchState @@ -29,34 +27,32 @@ logger = structlog.get_logger(__name__) -def build_agent_router(model: str | OpenAIModel, toolsets: list[FunctionToolset[Any]] | None = None) -> APIRouter: - router = APIRouter() +def build_agent_instance( + model: str | OpenAIModel, agent_tools: list[FunctionToolset[Any]] | None = None +) -> Agent[StateDeps[SearchState], str]: + """Build and configure the search agent instance. - try: - toolsets = toolsets + [search_toolset] if toolsets else [search_toolset] + Args: + model: The LLM model to use (string or OpenAIModel instance) + agent_tools: Optional list of additional toolsets to include - agent = Agent( - model=model, - deps_type=StateDeps[SearchState], - model_settings=ModelSettings( - parallel_tool_calls=False, - ), # https://github.com/pydantic/pydantic-ai/issues/562 - toolsets=toolsets, - ) - agent.instructions(get_base_instructions) - agent.instructions(get_dynamic_instructions) + Returns: + Configured Agent instance with StateDeps[SearchState] dependencies - @router.post("/") - async def agent_endpoint(request: Request) -> Response: - return await handle_ag_ui_request(agent, request, deps=StateDeps(SearchState())) + Raises: + Exception: If agent initialization fails + """ + toolsets = agent_tools + [search_toolset] if agent_tools else [search_toolset] - return router - except Exception as e: - logger.error("Agent init failed; serving disabled stub.", error=str(e)) - error_msg = f"Agent disabled: {str(e)}" + agent = Agent( + model=model, + deps_type=StateDeps[SearchState], + model_settings=ModelSettings( + parallel_tool_calls=False, + ), # https://github.com/pydantic/pydantic-ai/issues/562 + toolsets=toolsets, + ) + agent.instructions(get_base_instructions) + agent.instructions(get_dynamic_instructions) - @router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS", "HEAD"]) - async def _disabled(path: str) -> None: - raise HTTPException(status_code=503, detail=error_msg) - - return router + return agent diff --git a/orchestrator/search/agent/json_patch.py b/orchestrator/search/agent/json_patch.py new file mode 100644 index 000000000..1758c5704 --- /dev/null +++ b/orchestrator/search/agent/json_patch.py @@ -0,0 +1,51 @@ +# Copyright 2019-2025 SURF, GÉANT. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Literal + +from pydantic import BaseModel, Field + + +class JSONPatchOp(BaseModel): + """A JSON Patch operation (RFC 6902). 
+ + Docs reference: https://docs.ag-ui.com/concepts/state + """ + + op: Literal["add", "remove", "replace", "move", "copy", "test"] = Field( + description="The operation to perform: add, remove, replace, move, copy, or test" + ) + path: str = Field(description="JSON Pointer (RFC 6901) to the target location") + value: Any | None = Field( + default=None, + description="The value to apply (for add, replace operations)", + ) + from_: str | None = Field( + default=None, + alias="from", + description="Source path (for move, copy operations)", + ) + + @classmethod + def upsert(cls, path: str, value: Any, existed: bool) -> "JSONPatchOp": + """Create an add or replace operation depending on whether the path existed. + + Args: + path: JSON Pointer path to the target location + value: The value to set + existed: True if the path already exists (use replace), False otherwise (use add) + + Returns: + JSONPatchOp with 'replace' if existed is True, 'add' otherwise + """ + return cls(op="replace" if existed else "add", path=path, value=value) diff --git a/orchestrator/search/agent/prompts.py b/orchestrator/search/agent/prompts.py index 8aa113c42..e0567f4ab 100644 --- a/orchestrator/search/agent/prompts.py +++ b/orchestrator/search/agent/prompts.py @@ -50,14 +50,15 @@ async def get_base_instructions() -> str: Follow these steps in strict order: - 1. **Set Context**: Always begin by calling `set_search_parameters`. + 1. **Set Context**: If the user is asking for a NEW search, call `start_new_search`. 2. **Analyze for Filters**: Based on the user's request, decide if specific filters are necessary. - **If filters ARE required**, follow these sub-steps: a. **Gather Intel**: Identify all needed field names, then call `discover_filter_paths` and `get_valid_operators` **once each** to get all required information. b. **Construct FilterTree**: Build the `FilterTree` object. c. **Set Filters**: Call `set_filter_tree`. - 3. **Execute**: Call `execute_search`. This is done for both filtered and non-filtered searches. - 4. **Report**: Answer the users' question directly and summarize when appropiate. + 3. **Execute**: Call `run_search`. This is done for both filtered and non-filtered searches. + + After search execution, follow the dynamic instructions based on the current state. --- ### 4. Critical Rules @@ -73,28 +74,53 @@ async def get_dynamic_instructions(ctx: RunContext[StateDeps[SearchState]]) -> s """Dynamically provides 'next step' coaching based on the current state.""" state = ctx.deps.state param_state_str = json.dumps(state.parameters, indent=2, default=str) if state.parameters else "Not set." + results_count = state.results_data.total_count if state.results_data else 0 - next_step_guidance = "" - if not state.parameters or not state.parameters.get("entity_type"): + if state.export_data: + next_step_guidance = ( + "INSTRUCTION: Export has been prepared successfully. " + "Simply confirm to the user that the export is ready for download. " + "DO NOT include or mention the download URL - the UI will display it automatically." + ) + elif not state.parameters or not state.parameters.get("entity_type"): next_step_guidance = ( - "INSTRUCTION: The search context is not set. Your next action is to call `set_search_parameters`." + "INSTRUCTION: The search context is not set. Your next action is to call `start_new_search`." + ) + elif results_count > 0: + next_step_guidance = dedent( + f""" + INSTRUCTION: Search completed successfully. + Found {results_count} results containing only: entity_id, title, score. 
+ + Choose your next action based on what the user requested: + 1. **Broad/generic search** (e.g., 'show me subscriptions'): Confirm search completed and report count. Do nothing else. + 2. **Question answerable with entity_id/title/score**: Answer directly using the current results. + 3. **Question requiring other details**: Call `fetch_entity_details` first, then answer with the detailed data. + 4. **Export request** (phrases like 'export', 'download', 'save as CSV'): Call `prepare_export` directly. + """ ) else: next_step_guidance = ( "INSTRUCTION: Context is set. Now, analyze the user's request. " "If specific filters ARE required, use the information-gathering tools to build a `FilterTree` and call `set_filter_tree`. " - "If no specific filters are needed, you can proceed directly to `execute_search`." + "If no specific filters are needed, you can proceed directly to `run_search`." ) + return dedent( f""" --- - ### Current State & Next Action + ## CURRENT STATE **Current Search Parameters:** ```json {param_state_str} ``` - **{next_step_guidance}** + **Current Results Count:** {results_count} + + --- + ## NEXT ACTION REQUIRED + + {next_step_guidance} """ ) diff --git a/orchestrator/search/agent/state.py b/orchestrator/search/agent/state.py index 9a20f155e..075e3a192 100644 --- a/orchestrator/search/agent/state.py +++ b/orchestrator/search/agent/state.py @@ -12,10 +12,36 @@ # limitations under the License. from typing import Any +from uuid import UUID -from pydantic import BaseModel, Field +from pydantic import BaseModel + +from orchestrator.search.schemas.results import SearchResult + + +class ExportData(BaseModel): + """Export metadata for download.""" + + action: str = "export" + query_id: str + download_url: str + message: str + + +class SearchResultsData(BaseModel): + """Search results data for frontend display and agent context.""" + + action: str = "view_results" + query_id: str + results_url: str + total_count: int + message: str + results: list[SearchResult] = [] class SearchState(BaseModel): + run_id: UUID | None = None + query_id: UUID | None = None parameters: dict[str, Any] | None = None - results: list[dict[str, Any]] = Field(default_factory=list) + results_data: SearchResultsData | None = None + export_data: ExportData | None = None diff --git a/orchestrator/search/agent/tools.py b/orchestrator/search/agent/tools.py index 73471b57c..5a0e48154 100644 --- a/orchestrator/search/agent/tools.py +++ b/orchestrator/search/agent/tools.py @@ -11,11 +11,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
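The tools below emit `StateDeltaEvent`s built from `JSONPatchOp`, so the add-versus-replace decision in `upsert` is what keeps the frontend state consistent across repeated searches. A small sketch of those semantics, grounded in the `json_patch.py` definition above (example values are placeholders):

```python
# Sketch: upsert emits "add" for a path that did not exist yet and
# "replace" once it does, per RFC 6902.
from orchestrator.search.agent.json_patch import JSONPatchOp

first = JSONPatchOp.upsert(path="/query_id", value="q-1", existed=False)
assert first.op == "add"

later = JSONPatchOp.upsert(path="/query_id", value="q-2", existed=True)
assert later.op == "replace"

# "from" is a Python keyword, hence the aliased from_ field on move/copy ops.
move = JSONPatchOp(op="move", path="/archived", **{"from": "/results_data"})
assert move.from_ == "/results_data"
```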
-from collections.abc import Awaitable, Callable -from typing import Any, TypeVar +import json +from typing import Any import structlog -from ag_ui.core import EventType, StateSnapshotEvent +from ag_ui.core import EventType, StateDeltaEvent, StateSnapshotEvent from pydantic_ai import RunContext from pydantic_ai.ag_ui import StateDeps from pydantic_ai.exceptions import ModelRetry @@ -25,34 +25,22 @@ from orchestrator.api.api_v1.endpoints.search import ( get_definitions, list_paths, - search_processes, - search_products, - search_subscriptions, - search_workflows, ) -from orchestrator.schemas.search import SearchResultsSchema +from orchestrator.db import AgentRunTable, SearchQueryTable, db +from orchestrator.search.agent.json_patch import JSONPatchOp +from orchestrator.search.agent.state import ExportData, SearchResultsData, SearchState from orchestrator.search.core.types import ActionType, EntityType, FilterOp +from orchestrator.search.export import fetch_export_data from orchestrator.search.filters import FilterTree +from orchestrator.search.retrieval.engine import execute_search from orchestrator.search.retrieval.exceptions import FilterValidationError, PathNotFoundError +from orchestrator.search.retrieval.query_state import SearchQueryState from orchestrator.search.retrieval.validation import validate_filter_tree -from orchestrator.search.schemas.parameters import PARAMETER_REGISTRY, BaseSearchParameters - -from .state import SearchState +from orchestrator.search.schemas.parameters import BaseSearchParameters +from orchestrator.settings import app_settings logger = structlog.get_logger(__name__) - -P = TypeVar("P", bound=BaseSearchParameters) - -SearchFn = Callable[[P], Awaitable[SearchResultsSchema[Any]]] - -SEARCH_FN_MAP: dict[EntityType, SearchFn] = { - EntityType.SUBSCRIPTION: search_subscriptions, - EntityType.WORKFLOW: search_workflows, - EntityType.PRODUCT: search_products, - EntityType.PROCESS: search_processes, -} - search_toolset: FunctionToolset[StateDeps[SearchState]] = FunctionToolset(max_retries=1) @@ -65,32 +53,50 @@ def last_user_message(ctx: RunContext[StateDeps[SearchState]]) -> str | None: return None +def _set_parameters( + ctx: RunContext[StateDeps[SearchState]], + entity_type: EntityType, + action: str | ActionType, + query: str, + filters: Any | None, +) -> None: + """Internal helper to set parameters.""" + ctx.deps.state.parameters = { + "action": action, + "entity_type": entity_type, + "filters": filters, + "query": query, + } + + @search_toolset.tool -async def set_search_parameters( +async def start_new_search( ctx: RunContext[StateDeps[SearchState]], entity_type: EntityType, action: str | ActionType = ActionType.SELECT, ) -> StateSnapshotEvent: - """Sets the initial search context, like the entity type and the user's query. + """Starts a completely new search, clearing all previous state. - This MUST be the first tool called to start any new search. - Warning: Calling this tool will erase any existing filters and search results from the state. + This MUST be the first tool called when the user asks for a NEW search. + Warning: This will erase any existing filters, results, and search state. 
""" - params = ctx.deps.state.parameters or {} - is_new_search = params.get("entity_type") != entity_type.value - final_query = (last_user_message(ctx) or "") if is_new_search else params.get("query", "") + final_query = last_user_message(ctx) or "" logger.debug( - "Setting search parameters", + "Starting new search", entity_type=entity_type.value, action=action, - is_new_search=is_new_search, query=final_query, ) - ctx.deps.state.parameters = {"action": action, "entity_type": entity_type, "filters": None, "query": final_query} - ctx.deps.state.results = [] - logger.debug("Search parameters set", parameters=ctx.deps.state.parameters) + # Clear all state + ctx.deps.state.results_data = None + ctx.deps.state.export_data = None + + # Set fresh parameters with no filters + _set_parameters(ctx, entity_type, action, final_query, None) + + logger.debug("New search started", parameters=ctx.deps.state.parameters) return StateSnapshotEvent( type=EventType.STATE_SNAPSHOT, @@ -102,7 +108,7 @@ async def set_search_parameters( async def set_filter_tree( ctx: RunContext[StateDeps[SearchState]], filters: FilterTree | None, -) -> StateSnapshotEvent: +) -> StateDeltaEvent: """Replace current filters atomically with a full FilterTree, or clear with None. Requirements: @@ -111,7 +117,7 @@ async def set_filter_tree( - See the FilterTree schema examples for the exact shape. """ if ctx.deps.state.parameters is None: - raise ModelRetry("Search parameters are not initialized. Call set_search_parameters first.") + raise ModelRetry("Search parameters are not initialized. Call start_new_search first.") entity_type = EntityType(ctx.deps.state.parameters["entity_type"]) @@ -136,28 +142,33 @@ async def set_filter_tree( raise ModelRetry(f"Filter validation failed: {str(e)}. Please check your filter structure and try again.") filter_data = None if filters is None else filters.model_dump(mode="json", by_alias=True) + filters_existed = "filters" in ctx.deps.state.parameters ctx.deps.state.parameters["filters"] = filter_data - return StateSnapshotEvent(type=EventType.STATE_SNAPSHOT, snapshot=ctx.deps.state.model_dump()) + return StateDeltaEvent( + type=EventType.STATE_DELTA, + delta=[ + JSONPatchOp.upsert( + path="/parameters/filters", + value=filter_data, + existed=filters_existed, + ) + ], + ) @search_toolset.tool -async def execute_search( +async def run_search( ctx: RunContext[StateDeps[SearchState]], limit: int = 10, -) -> StateSnapshotEvent: - """Execute the search with the current parameters.""" +) -> StateDeltaEvent: + """Execute the search with the current parameters and save to database.""" if not ctx.deps.state.parameters: raise ValueError("No search parameters set") - entity_type = EntityType(ctx.deps.state.parameters["entity_type"]) - param_class = PARAMETER_REGISTRY.get(entity_type) - if not param_class: - raise ValueError(f"Unknown entity type: {entity_type}") - - params = param_class(**ctx.deps.state.parameters) + params = BaseSearchParameters.create(**ctx.deps.state.parameters) logger.debug( "Executing database search", - search_entity_type=entity_type.value, + search_entity_type=params.entity_type.value, limit=limit, has_filters=params.filters is not None, query=params.query, @@ -169,17 +180,61 @@ async def execute_search( params.limit = limit - fn = SEARCH_FN_MAP[entity_type] - search_results = await fn(params) + changes: list[JSONPatchOp] = [] + + if not ctx.deps.state.run_id: + agent_run = AgentRunTable(agent_type="search") + + db.session.add(agent_run) + db.session.commit() + db.session.expire_all() # 
Release connection to prevent stacking while agent runs + + ctx.deps.state.run_id = agent_run.run_id + logger.debug("Created new agent run", run_id=str(agent_run.run_id)) + changes.append(JSONPatchOp(op="add", path="/run_id", value=str(ctx.deps.state.run_id))) + + # Get query with embedding and save to DB + search_response = await execute_search(params, db.session) + query_embedding = search_response.query_embedding + query_state = SearchQueryState(parameters=params, query_embedding=query_embedding) + query_number = db.session.query(SearchQueryTable).filter_by(run_id=ctx.deps.state.run_id).count() + 1 + search_query = SearchQueryTable.from_state( + state=query_state, + run_id=ctx.deps.state.run_id, + query_number=query_number, + ) + db.session.add(search_query) + db.session.commit() + db.session.expire_all() + + query_id_existed = ctx.deps.state.query_id is not None + ctx.deps.state.query_id = search_query.query_id + logger.debug("Saved search query", query_id=str(search_query.query_id), query_number=query_number) + changes.append(JSONPatchOp.upsert(path="/query_id", value=str(ctx.deps.state.query_id), existed=query_id_existed)) logger.debug( "Search completed", - total_results=len(search_results.data) if search_results.data else 0, + total_results=len(search_response.results), ) - ctx.deps.state.results = search_results.data + # Store results data for both frontend display and agent context + results_url = f"{app_settings.BASE_URL}/api/search/queries/{ctx.deps.state.query_id}" + + results_data_existed = ctx.deps.state.results_data is not None + ctx.deps.state.results_data = SearchResultsData( + query_id=str(ctx.deps.state.query_id), + results_url=results_url, + total_count=len(search_response.results), + message=f"Found {len(search_response.results)} results.", + results=search_response.results, # Include actual results in state + ) + changes.append( + JSONPatchOp.upsert( + path="/results_data", value=ctx.deps.state.results_data.model_dump(), existed=results_data_existed + ) + ) - return StateSnapshotEvent(type=EventType.STATE_SNAPSHOT, snapshot=ctx.deps.state.model_dump()) + return StateDeltaEvent(type=EventType.STATE_DELTA, delta=changes) @search_toolset.tool @@ -256,3 +311,87 @@ async def get_valid_operators() -> dict[str, list[FilterOp]]: if hasattr(type_def, "operators"): operator_map[key] = type_def.operators return operator_map + + +@search_toolset.tool +async def fetch_entity_details( + ctx: RunContext[StateDeps[SearchState]], + limit: int = 10, +) -> str: + """Fetch detailed entity information to answer user questions. + + Use this tool when you need detailed information about entities from the search results + to answer the user's question. This provides the same detailed data that would be + included in an export (e.g., subscription status, product details, workflow info, etc.). + + Args: + ctx: Runtime context for agent (injected). + limit: Maximum number of entities to fetch details for (default 10). + + Returns: + JSON string containing detailed entity information. + + Raises: + ValueError: If no search results are available. + """ + if not ctx.deps.state.results_data or not ctx.deps.state.results_data.results: + raise ValueError("No search results available. 
Run a search first before fetching entity details.")
+
+    if not ctx.deps.state.parameters:
+        raise ValueError("No search parameters found.")
+
+    entity_type = EntityType(ctx.deps.state.parameters["entity_type"])
+
+    entity_ids = [r.entity_id for r in ctx.deps.state.results_data.results[:limit]]
+
+    logger.debug(
+        "Fetching detailed entity data",
+        entity_type=entity_type.value,
+        entity_count=len(entity_ids),
+    )
+
+    detailed_data = fetch_export_data(entity_type, entity_ids)
+
+    return json.dumps(detailed_data, indent=2)
+
+
+@search_toolset.tool
+async def prepare_export(
+    ctx: RunContext[StateDeps[SearchState]],
+) -> StateSnapshotEvent:
+    """Prepares export URL using the last executed search query."""
+    if not ctx.deps.state.query_id or not ctx.deps.state.run_id:
+        raise ValueError("No search has been executed yet. Run a search first before exporting.")
+
+    if not ctx.deps.state.parameters:
+        raise ValueError("No search parameters found. Run a search first before exporting.")
+
+    # Validate that export is only available for SELECT actions
+    action = ctx.deps.state.parameters.get("action", ActionType.SELECT)
+    if action != ActionType.SELECT:
+        raise ValueError(
+            f"Export is only available for SELECT actions. Current action is '{action}'. "
+            "Please run a SELECT search first."
+        )
+
+    logger.debug(
+        "Prepared query for export",
+        query_id=str(ctx.deps.state.query_id),
+    )
+
+    download_url = f"{app_settings.BASE_URL}/api/search/queries/{ctx.deps.state.query_id}/export"
+
+    ctx.deps.state.export_data = ExportData(
+        query_id=str(ctx.deps.state.query_id),
+        download_url=download_url,
+        message="Export ready for download.",
+    )
+
+    logger.debug("Export data set in state", export_data=ctx.deps.state.export_data.model_dump())
+
+    # Should use StateDelta here? Use snapshot to work around state persistence issue
+    # TODO: Fix root cause; state is empty on frontend when it should have data from run_search
+    return StateSnapshotEvent(
+        type=EventType.STATE_SNAPSHOT,
+        snapshot=ctx.deps.state.model_dump(),
+    )
diff --git a/orchestrator/search/core/exceptions.py b/orchestrator/search/core/exceptions.py
index 0170b115c..4f16d11b3 100644
--- a/orchestrator/search/core/exceptions.py
+++ b/orchestrator/search/core/exceptions.py
@@ -34,3 +34,9 @@ class InvalidCursorError(SearchUtilsError):
     """Raised when cursor cannot be decoded."""

     pass
+
+
+class QueryStateNotFoundError(SearchUtilsError):
+    """Raised when a query state cannot be found in the database."""
+
+    pass
diff --git a/orchestrator/search/core/types.py b/orchestrator/search/core/types.py
index 5589e4ceb..01f7a4425 100644
--- a/orchestrator/search/core/types.py
+++ b/orchestrator/search/core/types.py
@@ -289,6 +289,7 @@ def from_raw(cls, path: str, raw_value: Any) -> "ExtractedField":


 class IndexableRecord(TypedDict):
     entity_id: str
     entity_type: str
+    entity_title: str
     path: Ltree
     value: Any
     value_type: Any
diff --git a/orchestrator/search/export.py b/orchestrator/search/export.py
new file mode 100644
index 000000000..311a34121
--- /dev/null
+++ b/orchestrator/search/export.py
@@ -0,0 +1,199 @@
+# Copyright 2019-2025 SURF, GÉANT.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from uuid import UUID + +from sqlalchemy import select +from sqlalchemy.orm import selectinload + +from orchestrator.db import ( + ProcessTable, + ProductTable, + SubscriptionTable, + WorkflowTable, + db, +) +from orchestrator.search.core.types import EntityType + + +def fetch_subscription_export_data(entity_ids: list[str]) -> list[dict]: + """Fetch subscription data for export. + + Args: + entity_ids: List of subscription IDs as strings + + Returns: + List of flattened subscription dictionaries with fields: + subscription_id, description, status, insync, start_date, end_date, + note, product_name, tag, product_type, customer_id + """ + stmt = ( + select( + SubscriptionTable.subscription_id, + SubscriptionTable.description, + SubscriptionTable.status, + SubscriptionTable.insync, + SubscriptionTable.start_date, + SubscriptionTable.end_date, + SubscriptionTable.note, + SubscriptionTable.customer_id, + ProductTable.name.label("product_name"), + ProductTable.tag, + ProductTable.product_type, + ) + .join(ProductTable, SubscriptionTable.product_id == ProductTable.product_id) + .filter(SubscriptionTable.subscription_id.in_([UUID(sid) for sid in entity_ids])) + ) + + rows = db.session.execute(stmt).all() + + return [ + { + "subscription_id": str(row.subscription_id), + "description": row.description, + "status": row.status, + "insync": row.insync, + "start_date": row.start_date.isoformat() if row.start_date else None, + "end_date": row.end_date.isoformat() if row.end_date else None, + "note": row.note, + "product_name": row.product_name, + "tag": row.tag, + "product_type": row.product_type, + "customer_id": row.customer_id, + } + for row in rows + ] + + +def fetch_workflow_export_data(entity_ids: list[str]) -> list[dict]: + """Fetch workflow data for export. + + Args: + entity_ids: List of workflow names as strings + + Returns: + List of flattened workflow dictionaries with fields: + name, description, created_at, product_names (comma-separated), + product_ids (comma-separated), product_types (comma-separated) + """ + stmt = ( + select(WorkflowTable).options(selectinload(WorkflowTable.products)).filter(WorkflowTable.name.in_(entity_ids)) + ) + workflows = db.session.scalars(stmt).all() + + return [ + { + "name": w.name, + "description": w.description, + "created_at": w.created_at.isoformat() if w.created_at else None, + "product_names": ", ".join(p.name for p in w.products), + "product_ids": ", ".join(str(p.product_id) for p in w.products), + "product_types": ", ".join(p.product_type for p in w.products), + } + for w in workflows + ] + + +def fetch_product_export_data(entity_ids: list[str]) -> list[dict]: + """Fetch product data for export. 
+ + Args: + entity_ids: List of product IDs as strings + + Returns: + List of flattened product dictionaries with fields: + product_id, name, product_type, tag, description, status, created_at + """ + stmt = ( + select(ProductTable) + .options( + selectinload(ProductTable.workflows), + selectinload(ProductTable.fixed_inputs), + selectinload(ProductTable.product_blocks), + ) + .filter(ProductTable.product_id.in_([UUID(pid) for pid in entity_ids])) + ) + products = db.session.scalars(stmt).all() + + return [ + { + "product_id": str(p.product_id), + "name": p.name, + "product_type": p.product_type, + "tag": p.tag, + "description": p.description, + "status": p.status, + "created_at": p.created_at.isoformat() if p.created_at else None, + } + for p in products + ] + + +def fetch_process_export_data(entity_ids: list[str]) -> list[dict]: + """Fetch process data for export. + + Args: + entity_ids: List of process IDs as strings + + Returns: + List of flattened process dictionaries with fields: + process_id, workflow_name, workflow_id, last_status, is_task, + created_by, started_at, last_modified_at, last_step + """ + stmt = ( + select(ProcessTable) + .options(selectinload(ProcessTable.workflow)) + .filter(ProcessTable.process_id.in_([UUID(pid) for pid in entity_ids])) + ) + processes = db.session.scalars(stmt).all() + + return [ + { + "process_id": str(p.process_id), + "workflow_name": p.workflow.name if p.workflow else None, + "workflow_id": str(p.workflow_id), + "last_status": p.last_status, + "is_task": p.is_task, + "created_by": p.created_by, + "started_at": p.started_at.isoformat() if p.started_at else None, + "last_modified_at": p.last_modified_at.isoformat() if p.last_modified_at else None, + "last_step": p.last_step, + } + for p in processes + ] + + +def fetch_export_data(entity_type: EntityType, entity_ids: list[str]) -> list[dict]: + """Fetch export data for any entity type. 
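
Every fetch_*_export_data helper returns flat dictionaries, so serializing a result set to CSV needs only the standard library; a minimal sketch (the file path is a placeholder):

import csv

def write_export_csv(rows: list[dict], path: str) -> None:
    # Flat dicts from the export helpers map directly onto CSV columns.
    if not rows:
        return
    with open(path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
        writer.writeheader()
        writer.writerows(rows)
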
+ + Args: + entity_type: The type of entities to fetch + entity_ids: List of entity IDs/names as strings + + Returns: + List of flattened entity dictionaries ready for CSV export + + Raises: + ValueError: If entity_type is not supported + """ + match entity_type: + case EntityType.SUBSCRIPTION: + return fetch_subscription_export_data(entity_ids) + case EntityType.WORKFLOW: + return fetch_workflow_export_data(entity_ids) + case EntityType.PRODUCT: + return fetch_product_export_data(entity_ids) + case EntityType.PROCESS: + return fetch_process_export_data(entity_ids) + case _: + raise ValueError(f"Unsupported entity type: {entity_type}") diff --git a/orchestrator/search/indexing/indexer.py b/orchestrator/search/indexing/indexer.py index 1f5ae23d8..d2906f767 100644 --- a/orchestrator/search/indexing/indexer.py +++ b/orchestrator/search/indexing/indexer.py @@ -96,6 +96,7 @@ def __init__(self, config: EntityConfig, dry_run: bool, force_index: bool, chunk self.chunk_size = chunk_size self.embedding_model = llm_settings.EMBEDDING_MODEL self.logger = logger.bind(entity_kind=config.entity_kind.value) + self._entity_titles: dict[str, str] = {} def run(self, entities: Iterable[DatabaseEntity]) -> int: """Orchestrates the entire indexing process.""" @@ -138,6 +139,8 @@ def _process_chunk(self, entity_chunk: list[DatabaseEntity], session: Session | if not entity_chunk: return 0, 0 + self._entity_titles.clear() + fields_to_upsert, paths_to_delete, identical_count = self._determine_changes(entity_chunk, session) if paths_to_delete and session is not None: @@ -174,12 +177,15 @@ def _determine_changes( entity, pk_name=self.config.pk_name, root_name=self.config.root_name ) + entity_title = self.config.get_title_from_fields(current_fields) + self._entity_titles[entity_id] = entity_title + entity_hashes = existing_hashes.get(entity_id, {}) current_paths = set() for field in current_fields: current_paths.add(field.path) - current_hash = self._compute_content_hash(field.path, field.value, field.value_type) + current_hash = self._compute_content_hash(field.path, field.value, field.value_type, entity_title) if field.path not in entity_hashes or entity_hashes[field.path] != current_hash: fields_to_upsert.append((entity_id, field)) else: @@ -301,21 +307,23 @@ def _prepare_text_for_embedding(field: ExtractedField) -> str: return f"{field.path}: {str(field.value)}" @staticmethod - def _compute_content_hash(path: str, value: Any, value_type: Any) -> str: + def _compute_content_hash(path: str, value: Any, value_type: Any, entity_title: str = "") -> str: v = "" if value is None else str(value) - content = f"{path}:{v}:{value_type}" + content = f"{path}:{v}:{value_type}:{entity_title}" return hashlib.sha256(content.encode("utf-8")).hexdigest() def _make_indexable_record( self, field: ExtractedField, entity_id: str, embedding: list[float] | None ) -> IndexableRecord: + entity_title = self._entity_titles[entity_id] return IndexableRecord( entity_id=entity_id, entity_type=self.config.entity_kind.value, + entity_title=entity_title, path=Ltree(field.path), value=field.value, value_type=field.value_type, - content_hash=self._compute_content_hash(field.path, field.value, field.value_type), + content_hash=self._compute_content_hash(field.path, field.value, field.value_type, entity_title), embedding=embedding if embedding else None, ) @@ -326,6 +334,7 @@ def _get_upsert_statement() -> Insert: return stmt.on_conflict_do_update( index_elements=[AiSearchIndex.entity_id, AiSearchIndex.path], set_={ + AiSearchIndex.entity_title: 
stmt.excluded.entity_title, AiSearchIndex.value: stmt.excluded.value, AiSearchIndex.value_type: stmt.excluded.value_type, AiSearchIndex.content_hash: stmt.excluded.content_hash, diff --git a/orchestrator/search/indexing/registry.py b/orchestrator/search/indexing/registry.py index 497bfb81f..acf10f676 100644 --- a/orchestrator/search/indexing/registry.py +++ b/orchestrator/search/indexing/registry.py @@ -25,7 +25,7 @@ WorkflowTable, ) from orchestrator.db.database import BaseModel -from orchestrator.search.core.types import EntityType +from orchestrator.search.core.types import EntityType, ExtractedField from .traverse import ( BaseTraverser, @@ -48,6 +48,7 @@ class EntityConfig(Generic[ModelT]): traverser: "type[BaseTraverser]" pk_name: str root_name: str + title_paths: list[str] # List of field paths to check for title (with fallback) def get_all_query(self, entity_id: str | None = None) -> Query | Select: query = self.table.query @@ -56,6 +57,14 @@ def get_all_query(self, entity_id: str | None = None) -> Query | Select: query = query.filter(pk_column == UUID(entity_id)) return query + def get_title_from_fields(self, fields: list[ExtractedField]) -> str: + """Extract title from fields using configured paths.""" + for title_path in self.title_paths: + for field in fields: + if field.path == title_path and field.value: + return str(field.value) + return "UNKNOWN" + @dataclass(frozen=True) class WorkflowConfig(EntityConfig[WorkflowTable]): @@ -76,6 +85,7 @@ def get_all_query(self, entity_id: str | None = None) -> Select: traverser=SubscriptionTraverser, pk_name="subscription_id", root_name="subscription", + title_paths=["subscription.description"], ), EntityType.PRODUCT: EntityConfig( entity_kind=EntityType.PRODUCT, @@ -83,6 +93,7 @@ def get_all_query(self, entity_id: str | None = None) -> Select: traverser=ProductTraverser, pk_name="product_id", root_name="product", + title_paths=["product.description", "product.name"], ), EntityType.PROCESS: EntityConfig( entity_kind=EntityType.PROCESS, @@ -90,6 +101,7 @@ def get_all_query(self, entity_id: str | None = None) -> Select: traverser=ProcessTraverser, pk_name="process_id", root_name="process", + title_paths=["process.workflow_name"], ), EntityType.WORKFLOW: WorkflowConfig( entity_kind=EntityType.WORKFLOW, @@ -97,5 +109,6 @@ def get_all_query(self, entity_id: str | None = None) -> Select: traverser=WorkflowTraverser, pk_name="workflow_id", root_name="workflow", + title_paths=["workflow.description", "workflow.name"], ), } diff --git a/orchestrator/search/llm_migration.py b/orchestrator/search/llm_migration.py index d32990b89..19bd06cee 100644 --- a/orchestrator/search/llm_migration.py +++ b/orchestrator/search/llm_migration.py @@ -37,6 +37,7 @@ def run_migration(connection: Connection) -> None: if llm_settings.LLM_FORCE_EXTENTION_MIGRATION or res.rowcount == 0: # Create PostgreSQL extensions logger.info("Attempting to run the extention creation;") + connection.execute(text('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')) connection.execute(text("CREATE EXTENSION IF NOT EXISTS ltree;")) connection.execute(text("CREATE EXTENSION IF NOT EXISTS unaccent;")) connection.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm;")) @@ -64,6 +65,7 @@ def run_migration(connection: Connection) -> None: CREATE TABLE IF NOT EXISTS {TABLE} ( entity_type TEXT NOT NULL, entity_id UUID NOT NULL, + entity_title TEXT, path LTREE NOT NULL, value TEXT NOT NULL, embedding VECTOR({TARGET_DIM}), @@ -78,6 +80,23 @@ def run_migration(connection: Connection) -> None: # Drop 
default connection.execute(text(f"ALTER TABLE {TABLE} ALTER COLUMN value_type DROP DEFAULT;")) + # Add entity_title column if it doesn't exist (for existing installations) + connection.execute( + text( + f""" + DO $$ + BEGIN + IF NOT EXISTS ( + SELECT 1 FROM information_schema.columns + WHERE table_name = '{TABLE}' AND column_name = 'entity_title' + ) THEN + ALTER TABLE {TABLE} ADD COLUMN entity_title TEXT; + END IF; + END $$; + """ + ) + ) + # Create indexes with IF NOT EXISTS connection.execute(text(f"CREATE INDEX IF NOT EXISTS ix_ai_search_index_entity_id ON {TABLE} (entity_id);")) connection.execute( @@ -96,6 +115,42 @@ def run_migration(connection: Connection) -> None: ) ) + # Create agent_runs table + connection.execute( + text( + """ + CREATE TABLE IF NOT EXISTS agent_runs ( + run_id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), + agent_type VARCHAR(50) NOT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL + ); + """ + ) + ) + connection.execute(text("CREATE INDEX IF NOT EXISTS ix_agent_runs_created_at ON agent_runs (created_at);")) + + # Create search_queries table + connection.execute( + text( + f""" + CREATE TABLE IF NOT EXISTS search_queries ( + query_id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), + run_id UUID, + query_number INTEGER NOT NULL, + parameters JSONB NOT NULL, + query_embedding VECTOR({TARGET_DIM}), + executed_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL, + CONSTRAINT fk_search_queries_run_id FOREIGN KEY (run_id) REFERENCES agent_runs(run_id) ON DELETE CASCADE + ); + """ + ) + ) + connection.execute(text("CREATE INDEX IF NOT EXISTS ix_search_queries_run_id ON search_queries (run_id);")) + connection.execute( + text("CREATE INDEX IF NOT EXISTS ix_search_queries_executed_at ON search_queries (executed_at);") + ) + connection.execute(text("CREATE INDEX IF NOT EXISTS ix_search_queries_query_id ON search_queries (query_id);")) + connection.commit() logger.info("LLM migration completed successfully") diff --git a/orchestrator/search/retrieval/__init__.py b/orchestrator/search/retrieval/__init__.py index 7bb32303a..a6f505450 100644 --- a/orchestrator/search/retrieval/__init__.py +++ b/orchestrator/search/retrieval/__init__.py @@ -11,6 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .engine import execute_search +from .engine import execute_search, execute_search_for_export +from .query_state import SearchQueryState -__all__ = ["execute_search"] +__all__ = ["execute_search", "execute_search_for_export", "SearchQueryState"] diff --git a/orchestrator/search/retrieval/builder.py b/orchestrator/search/retrieval/builder.py index 087687561..7aabb36a9 100644 --- a/orchestrator/search/retrieval/builder.py +++ b/orchestrator/search/retrieval/builder.py @@ -43,7 +43,11 @@ def build_candidate_query(params: BaseSearchParameters) -> Select: Select: The SQLAlchemy `Select` object representing the query. 
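
Stepping back to the registry change above: get_title_from_fields resolves an entity title by first-match-wins over the configured title_paths. A self-contained sketch of the same lookup, using a simplified stand-in for ExtractedField:

from dataclasses import dataclass

@dataclass
class Field:  # simplified stand-in for ExtractedField
    path: str
    value: object

def resolve_title(title_paths: list[str], fields: list[Field]) -> str:
    # First configured path with a truthy value wins; "UNKNOWN" otherwise.
    for title_path in title_paths:
        for field in fields:
            if field.path == title_path and field.value:
                return str(field.value)
    return "UNKNOWN"

# Products prefer description, then name, per ENTITY_CONFIG_REGISTRY:
fields = [Field("product.name", "L2VPN"), Field("product.description", "")]
assert resolve_title(["product.description", "product.name"], fields) == "L2VPN"
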
""" - stmt = select(AiSearchIndex.entity_id).where(AiSearchIndex.entity_type == params.entity_type.value).distinct() + stmt = ( + select(AiSearchIndex.entity_id, AiSearchIndex.entity_title) + .where(AiSearchIndex.entity_type == params.entity_type.value) + .distinct() + ) if params.filters is not None: entity_id_col = AiSearchIndex.entity_id diff --git a/orchestrator/search/retrieval/engine.py b/orchestrator/search/retrieval/engine.py index b1a08f46f..df4b773dd 100644 --- a/orchestrator/search/retrieval/engine.py +++ b/orchestrator/search/retrieval/engine.py @@ -17,13 +17,15 @@ from sqlalchemy.engine.row import RowMapping from sqlalchemy.orm import Session +from orchestrator.search.core.embedding import QueryEmbedder from orchestrator.search.core.types import FilterOp, SearchMetadata from orchestrator.search.filters import FilterTree, LtreeFilter from orchestrator.search.schemas.parameters import BaseSearchParameters from orchestrator.search.schemas.results import MatchingField, SearchResponse, SearchResult from .builder import build_candidate_query -from .pagination import PaginationParams +from .pagination import PageCursor +from .query_state import SearchQueryState from .retrievers import Retriever from .utils import generate_highlight_indices @@ -74,9 +76,15 @@ def _format_response( # Structured search (filter-only) matching_field = _extract_matching_field_from_filters(search_params.filters) + entity_title = row.get("entity_title", "") + if not isinstance(entity_title, str): + entity_title = str(entity_title) if entity_title is not None else "" + results.append( SearchResult( entity_id=str(row.entity_id), + entity_type=search_params.entity_type, + entity_title=entity_title, score=row.score, perfect_match=row.get("perfect_match", 0), matching_field=matching_field, @@ -110,45 +118,80 @@ def _extract_matching_field_from_filters(filters: FilterTree) -> MatchingField | return MatchingField(text=text, path=pf.path, highlight_indices=[(0, len(text))]) -async def execute_search( +async def _execute_search_internal( search_params: BaseSearchParameters, db_session: Session, - pagination_params: PaginationParams | None = None, + limit: int, + cursor: PageCursor | None = None, + query_embedding: list[float] | None = None, ) -> SearchResponse: - """Execute a hybrid search and return ranked results. - - Builds a candidate entity query based on the given search parameters, - applies the appropriate ranking strategy, and executes the final ranked - query to retrieve results. + """Internal function to execute search with specified parameters. Args: - search_params (BaseSearchParameters): The search parameters specifying vector, fuzzy, or filter criteria. - db_session (Session): The active SQLAlchemy session for executing the query. - pagination_params (PaginationParams): Parameters controlling pagination of the search results. - limit (int, optional): The maximum number of search results to return, by default 5. + search_params: The search parameters specifying vector, fuzzy, or filter criteria. + db_session: The active SQLAlchemy session for executing the query. + limit: Maximum number of results to return. + cursor: Optional pagination cursor. + query_embedding: Optional pre-computed query embedding to use instead of generating a new one. Returns: - SearchResponse: A list of `SearchResult` objects containing entity IDs, scores, - and optional highlight metadata. - - Notes: - If no vector query, filters, or fuzzy term are provided, a warning is logged - and an empty result set is returned. 
+ SearchResponse with results and embedding (for internal use). """ - if not search_params.vector_query and not search_params.filters and not search_params.fuzzy_term: logger.warning("No search criteria provided (vector_query, fuzzy_term, or filters).") return SearchResponse(results=[], metadata=SearchMetadata.empty()) candidate_query = build_candidate_query(search_params) - pagination_params = pagination_params or PaginationParams() - retriever = await Retriever.from_params(search_params, pagination_params) + if search_params.vector_query and not query_embedding: + + query_embedding = await QueryEmbedder.generate_for_text_async(search_params.vector_query) + + retriever = await Retriever.route(search_params, cursor, query_embedding) logger.debug("Using retriever", retriever_type=retriever.__class__.__name__) final_stmt = retriever.apply(candidate_query) - final_stmt = final_stmt.limit(search_params.limit) + final_stmt = final_stmt.limit(limit) logger.debug(final_stmt) result = db_session.execute(final_stmt).mappings().all() - return _format_response(result, search_params, retriever.metadata) + response = _format_response(result, search_params, retriever.metadata) + # Store embedding in response for agent to save to DB + response.query_embedding = query_embedding + return response + + +async def execute_search( + search_params: BaseSearchParameters, + db_session: Session, + cursor: PageCursor | None = None, + query_embedding: list[float] | None = None, +) -> SearchResponse: + """Execute a search and return ranked results.""" + return await _execute_search_internal(search_params, db_session, search_params.limit, cursor, query_embedding) + + +async def execute_search_for_export( + query_state: SearchQueryState, + db_session: Session, +) -> list[dict]: + """Execute a search for export and fetch flattened entity data. + + Args: + query_state: Query state containing parameters and query_embedding. + db_session: The active SQLAlchemy session for executing the query. + + Returns: + List of flattened entity records suitable for export. + """ + from orchestrator.search.export import fetch_export_data + + search_response = await _execute_search_internal( + search_params=query_state.parameters, + db_session=db_session, + limit=query_state.parameters.export_limit, + query_embedding=query_state.query_embedding, + ) + + entity_ids = [res.entity_id for res in search_response.results] + return fetch_export_data(query_state.parameters.entity_type, entity_ids) diff --git a/orchestrator/search/retrieval/pagination.py b/orchestrator/search/retrieval/pagination.py index a72b31aec..a630bf149 100644 --- a/orchestrator/search/retrieval/pagination.py +++ b/orchestrator/search/retrieval/pagination.py @@ -11,42 +11,21 @@ # See the License for the specific language governing permissions and # limitations under the License. 
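
Taken together, the pieces above compose into a short export path. A sketch under the assumption that some caller (for example the export endpoint) supplies a stored query_id; the imports and call signatures come from this diff:

from orchestrator.db import db
from orchestrator.search.retrieval import SearchQueryState, execute_search_for_export

async def export_query(query_id: str) -> list[dict]:
    # load_from_id raises QueryStateNotFoundError for an unknown id
    # and ValueError for a malformed one.
    query_state = SearchQueryState.load_from_id(query_id)
    return await execute_search_for_export(query_state, db.session)
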
-import array
 import base64
-from dataclasses import dataclass
+from uuid import UUID
 
 from pydantic import BaseModel
 
+from orchestrator.db import SearchQueryTable, db
 from orchestrator.search.core.exceptions import InvalidCursorError
-from orchestrator.search.schemas.parameters import BaseSearchParameters
-from orchestrator.search.schemas.results import SearchResult
-
-
-@dataclass
-class PaginationParams:
-    """Parameters for pagination in search queries."""
-
-    page_after_score: float | None = None
-    page_after_id: str | None = None
-    q_vec_override: list[float] | None = None
-
-
-def floats_to_b64(v: list[float]) -> str:
-    a = array.array("f", v)
-    return base64.urlsafe_b64encode(a.tobytes()).decode("ascii")
-
-
-def b64_to_floats(s: str) -> list[float]:
-    raw = base64.urlsafe_b64decode(s.encode("ascii"))
-    a = array.array("f")
-    a.frombytes(raw)
-    return list(a)
+from orchestrator.search.schemas.parameters import SearchParameters
+from orchestrator.search.schemas.results import SearchResponse
 
 
 class PageCursor(BaseModel):
     score: float
     id: str
-    q_vec_b64: str
+    query_id: UUID
 
     def encode(self) -> str:
         """Encode the cursor data into a URL-safe Base64 string."""
@@ -63,34 +42,45 @@ def decode(cls, cursor: str) -> "PageCursor":
         raise InvalidCursorError("Invalid pagination cursor") from e
 
 
-async def process_pagination_cursor(cursor: str | None, search_params: BaseSearchParameters) -> PaginationParams:
-    """Process pagination cursor and return pagination parameters."""
-    if cursor:
-        c = PageCursor.decode(cursor)
-        return PaginationParams(
-            page_after_score=c.score,
-            page_after_id=c.id,
-            q_vec_override=b64_to_floats(c.q_vec_b64),
-        )
-    if search_params.vector_query:
-        from orchestrator.search.core.embedding import QueryEmbedder
-
-        q_vec_override = await QueryEmbedder.generate_for_text_async(search_params.vector_query)
-        return PaginationParams(q_vec_override=q_vec_override)
-    return PaginationParams()
-
-
-def create_next_page_cursor(
-    search_results: list[SearchResult], pagination_params: PaginationParams, limit: int
+def encode_next_page_cursor(
+    search_response: SearchResponse,
+    cursor: PageCursor | None,
+    search_params: SearchParameters,
 ) -> str | None:
-    """Create next page cursor if there are more results."""
-    has_next_page = len(search_results) == limit and limit > 0
-    if has_next_page:
-        last_item = search_results[-1]
-        cursor_data = PageCursor(
-            score=float(last_item.score),
-            id=last_item.entity_id,
-            q_vec_b64=floats_to_b64(pagination_params.q_vec_override or []),
-        )
-        return cursor_data.encode()
-    return None
+    """Create the next page cursor if there are more results.
+
+    On the first page, saves the query state to the database and includes its
+    query_id in the cursor so that subsequent pages reuse the same parameters.
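
With the embedding moved out of the cursor and into search_queries, a cursor token stays small. A quick round-trip sketch (the score, id, and UUID are placeholder values):

from uuid import uuid4

from orchestrator.search.retrieval.pagination import PageCursor

cursor = PageCursor(score=0.731, id="entity-123", query_id=uuid4())
token = cursor.encode()  # URL-safe Base64 string handed to the client
restored = PageCursor.decode(token)  # malformed tokens raise InvalidCursorError
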
+ + Args: + search_response: SearchResponse containing results and query_embedding + cursor: Current page cursor (None for first page, PageCursor for subsequent pages) + search_params: Search parameters to save for pagination consistency + + Returns: + Encoded cursor for next page, or None if no more results + """ + from orchestrator.search.retrieval.query_state import SearchQueryState + + has_next_page = len(search_response.results) == search_params.limit and search_params.limit > 0 + if not has_next_page: + return None + + # If this is the first page, save query state to database + if cursor is None: + query_state = SearchQueryState(parameters=search_params, query_embedding=search_response.query_embedding) + search_query = SearchQueryTable.from_state(state=query_state) + + db.session.add(search_query) + db.session.commit() + query_id = search_query.query_id + else: + query_id = cursor.query_id + + last_item = search_response.results[-1] + cursor_data = PageCursor( + score=float(last_item.score), + id=last_item.entity_id, + query_id=query_id, + ) + return cursor_data.encode() diff --git a/orchestrator/search/retrieval/query_state.py b/orchestrator/search/retrieval/query_state.py new file mode 100644 index 000000000..2dc80574b --- /dev/null +++ b/orchestrator/search/retrieval/query_state.py @@ -0,0 +1,61 @@ +# Copyright 2019-2025 SURF, GÉANT. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from uuid import UUID + +from pydantic import BaseModel, ConfigDict, Field + +from orchestrator.db import SearchQueryTable, db +from orchestrator.search.core.exceptions import QueryStateNotFoundError +from orchestrator.search.schemas.parameters import SearchParameters + + +class SearchQueryState(BaseModel): + """State of a search query including parameters and embedding. + + This model provides a complete snapshot of what was searched and how. + Used for both agent and regular API searches. + """ + + parameters: SearchParameters = Field(discriminator="entity_type") + query_embedding: list[float] | None = Field(default=None, description="The embedding vector for semantic search") + + model_config = ConfigDict(from_attributes=True) + + @classmethod + def load_from_id(cls, query_id: UUID | str) -> "SearchQueryState": + """Load query state from database by query_id. 
+ + Args: + query_id: UUID or string UUID of the saved query + + Returns: + SearchQueryState loaded from database + + Raises: + ValueError: If query_id format is invalid + QueryStateNotFoundError: If query not found in database + """ + if isinstance(query_id, UUID): + query_uuid = query_id + else: + try: + query_uuid = UUID(query_id) + except (ValueError, TypeError) as e: + raise ValueError(f"Invalid query_id format: {query_id}") from e + + search_query = db.session.query(SearchQueryTable).filter_by(query_id=query_uuid).first() + if not search_query: + raise QueryStateNotFoundError(f"Query {query_uuid} not found in database") + + return cls.model_validate(search_query) diff --git a/orchestrator/search/retrieval/retrievers/base.py b/orchestrator/search/retrieval/retrievers/base.py index 10242fa7e..e7da35596 100644 --- a/orchestrator/search/retrieval/retrievers/base.py +++ b/orchestrator/search/retrieval/retrievers/base.py @@ -20,7 +20,7 @@ from orchestrator.search.core.types import FieldType, SearchMetadata from orchestrator.search.schemas.parameters import BaseSearchParameters -from ..pagination import PaginationParams +from ..pagination import PageCursor logger = structlog.get_logger(__name__) @@ -41,62 +41,48 @@ class Retriever(ABC): ] @classmethod - async def from_params( + async def route( cls, params: BaseSearchParameters, - pagination_params: PaginationParams, + cursor: PageCursor | None, + query_embedding: list[float] | None = None, ) -> "Retriever": - """Create the appropriate retriever instance from search parameters. + """Route to the appropriate retriever instance based on search parameters. + + Selects the retriever type based on available search criteria: + - Hybrid: both embedding and fuzzy term available + - Semantic: only embedding available + - Fuzzy: only text term available (or fallback when embedding generation fails) + - Structured: only filters available Args: - params (BaseSearchParameters): Search parameters including vector queries, fuzzy terms, and filters. - pagination_params (PaginationParams): Pagination parameters for cursor-based paging. + params: Search parameters including vector queries, fuzzy terms, and filters + cursor: Pagination cursor for cursor-based paging + query_embedding: Query embedding for semantic search, or None if not available Returns: - Retriever: A concrete retriever instance (semantic, fuzzy, hybrid, or structured). 
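
The selection logic in route, shown just below, reduces to a small decision table; an equivalent standalone sketch of the same precedence (plain strings stand in for the retriever classes):

def pick_retriever(query_embedding, fuzzy_term, vector_query, full_query):
    # Mirrors Retriever.route: fall back to fuzzy search on the full query
    # when an embedding was wanted (vector_query set) but not produced.
    if query_embedding is None and vector_query is not None and full_query is not None:
        fuzzy_term = full_query
    if query_embedding is not None and fuzzy_term is not None:
        return "hybrid"
    if query_embedding is not None:
        return "semantic"
    if fuzzy_term is not None:
        return "fuzzy"
    return "structured"

assert pick_retriever(None, None, "fiber", "fiber to the home") == "fuzzy"
assert pick_retriever([0.1], "fiber", "fiber", "fiber") == "hybrid"
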
+ A concrete retriever instance based on available search criteria """ - from .fuzzy import FuzzyRetriever from .hybrid import RrfHybridRetriever from .semantic import SemanticRetriever from .structured import StructuredRetriever fuzzy_term = params.fuzzy_term - q_vec = await cls._get_query_vector(params.vector_query, pagination_params.q_vec_override) - - # If semantic search was attempted but failed, fall back to fuzzy with the full query - fallback_fuzzy_term = fuzzy_term - if q_vec is None and params.vector_query is not None and params.query is not None: - fallback_fuzzy_term = params.query - - if q_vec is not None and fallback_fuzzy_term is not None: - return RrfHybridRetriever(q_vec, fallback_fuzzy_term, pagination_params) - if q_vec is not None: - return SemanticRetriever(q_vec, pagination_params) - if fallback_fuzzy_term is not None: - return FuzzyRetriever(fallback_fuzzy_term, pagination_params) - - return StructuredRetriever(pagination_params) - - @classmethod - async def _get_query_vector( - cls, vector_query: str | None, q_vec_override: list[float] | None - ) -> list[float] | None: - """Get query vector either from override or by generating from text.""" - if q_vec_override: - return q_vec_override - - if not vector_query: - return None - from orchestrator.search.core.embedding import QueryEmbedder + # If vector_query exists but embedding generation failed, fall back to fuzzy search with full query + if query_embedding is None and params.vector_query is not None and params.query is not None: + fuzzy_term = params.query - q_vec = await QueryEmbedder.generate_for_text_async(vector_query) - if not q_vec: - logger.warning("Embedding generation failed; using non-semantic retriever") - return None + # Select retriever based on available search criteria + if query_embedding is not None and fuzzy_term is not None: + return RrfHybridRetriever(query_embedding, fuzzy_term, cursor) + if query_embedding is not None: + return SemanticRetriever(query_embedding, cursor) + if fuzzy_term is not None: + return FuzzyRetriever(fuzzy_term, cursor) - return q_vec + return StructuredRetriever(cursor) @abstractmethod def apply(self, candidate_query: Select) -> Select: diff --git a/orchestrator/search/retrieval/retrievers/fuzzy.py b/orchestrator/search/retrieval/retrievers/fuzzy.py index 7003b5b0f..4dd19fee5 100644 --- a/orchestrator/search/retrieval/retrievers/fuzzy.py +++ b/orchestrator/search/retrieval/retrievers/fuzzy.py @@ -17,17 +17,16 @@ from orchestrator.db.models import AiSearchIndex from orchestrator.search.core.types import SearchMetadata -from ..pagination import PaginationParams +from ..pagination import PageCursor from .base import Retriever class FuzzyRetriever(Retriever): """Ranks results based on the max of fuzzy text similarity scores.""" - def __init__(self, fuzzy_term: str, pagination_params: PaginationParams) -> None: + def __init__(self, fuzzy_term: str, cursor: PageCursor | None) -> None: self.fuzzy_term = fuzzy_term - self.page_after_score = pagination_params.page_after_score - self.page_after_id = pagination_params.page_after_id + self.cursor = cursor def apply(self, candidate_query: Select) -> Select: cand = candidate_query.subquery() @@ -42,6 +41,7 @@ def apply(self, candidate_query: Select) -> Select: combined_query = ( select( AiSearchIndex.entity_id, + AiSearchIndex.entity_title, score, func.first_value(AiSearchIndex.value) .over(partition_by=AiSearchIndex.entity_id, order_by=[similarity_expr.desc(), AiSearchIndex.path.asc()]) @@ -58,12 +58,13 @@ def apply(self, 
candidate_query: Select) -> Select: literal(self.fuzzy_term).op("<%")(AiSearchIndex.value), ) ) - .distinct(AiSearchIndex.entity_id) + .distinct(AiSearchIndex.entity_id, AiSearchIndex.entity_title) ) final_query = combined_query.subquery("ranked_fuzzy") stmt = select( final_query.c.entity_id, + final_query.c.entity_title, final_query.c.score, final_query.c.highlight_text, final_query.c.highlight_path, @@ -81,13 +82,13 @@ def _apply_score_pagination( self, stmt: Select, score_column: ColumnElement, entity_id_column: ColumnElement ) -> Select: """Apply standard score + entity_id pagination.""" - if self.page_after_score is not None and self.page_after_id is not None: + if self.cursor is not None: stmt = stmt.where( or_( - score_column < self.page_after_score, + score_column < self.cursor.score, and_( - score_column == self.page_after_score, - entity_id_column > self.page_after_id, + score_column == self.cursor.score, + entity_id_column > self.cursor.id, ), ) ) diff --git a/orchestrator/search/retrieval/retrievers/hybrid.py b/orchestrator/search/retrieval/retrievers/hybrid.py index be91312f1..89f56df41 100644 --- a/orchestrator/search/retrieval/retrievers/hybrid.py +++ b/orchestrator/search/retrieval/retrievers/hybrid.py @@ -20,7 +20,7 @@ from orchestrator.db.models import AiSearchIndex from orchestrator.search.core.types import SearchMetadata -from ..pagination import PaginationParams +from ..pagination import PageCursor from .base import Retriever @@ -127,14 +127,13 @@ def __init__( self, q_vec: list[float], fuzzy_term: str, - pagination_params: PaginationParams, + cursor: PageCursor | None, k: int = 60, field_candidates_limit: int = 100, ) -> None: self.q_vec = q_vec self.fuzzy_term = fuzzy_term - self.page_after_score = pagination_params.page_after_score - self.page_after_id = pagination_params.page_after_id + self.cursor = cursor self.k = k self.field_candidates_limit = field_candidates_limit @@ -154,6 +153,7 @@ def apply(self, candidate_query: Select) -> Select: field_candidates = ( select( AiSearchIndex.entity_id, + AiSearchIndex.entity_title, AiSearchIndex.path, AiSearchIndex.value, sem_val, @@ -178,9 +178,10 @@ def apply(self, candidate_query: Select) -> Select: entity_scores = ( select( field_candidates.c.entity_id, + field_candidates.c.entity_title, func.avg(field_candidates.c.semantic_distance).label("avg_semantic_distance"), func.avg(field_candidates.c.fuzzy_score).label("avg_fuzzy_score"), - ).group_by(field_candidates.c.entity_id) + ).group_by(field_candidates.c.entity_id, field_candidates.c.entity_title) ).cte("entity_scores") entity_highlights = ( @@ -204,6 +205,7 @@ def apply(self, candidate_query: Select) -> Select: ranked = ( select( entity_scores.c.entity_id, + entity_scores.c.entity_title, entity_scores.c.avg_semantic_distance, entity_scores.c.avg_fuzzy_score, entity_highlights.c.highlight_text, @@ -242,6 +244,7 @@ def apply(self, candidate_query: Select) -> Select: stmt = select( ranked.c.entity_id, + ranked.c.entity_title, score, ranked.c.highlight_text, ranked.c.highlight_path, @@ -262,12 +265,12 @@ def _apply_fused_pagination( entity_id_column: ColumnElement, ) -> Select: """Keyset paginate by fused score + id.""" - if self.page_after_score is not None and self.page_after_id is not None: - score_param = self._quantize_score_for_pagination(self.page_after_score) + if self.cursor is not None: + score_param = self._quantize_score_for_pagination(self.cursor.score) stmt = stmt.where( or_( score_column < score_param, - and_(score_column == score_param, entity_id_column > 
self.page_after_id), + and_(score_column == score_param, entity_id_column > self.cursor.id), ) ) return stmt diff --git a/orchestrator/search/retrieval/retrievers/semantic.py b/orchestrator/search/retrieval/retrievers/semantic.py index 3fdfa2802..89702fa1d 100644 --- a/orchestrator/search/retrieval/retrievers/semantic.py +++ b/orchestrator/search/retrieval/retrievers/semantic.py @@ -17,17 +17,16 @@ from orchestrator.db.models import AiSearchIndex from orchestrator.search.core.types import SearchMetadata -from ..pagination import PaginationParams +from ..pagination import PageCursor from .base import Retriever class SemanticRetriever(Retriever): """Ranks results based on the minimum semantic vector distance.""" - def __init__(self, vector_query: list[float], pagination_params: PaginationParams) -> None: + def __init__(self, vector_query: list[float], cursor: PageCursor | None) -> None: self.vector_query = vector_query - self.page_after_score = pagination_params.page_after_score - self.page_after_id = pagination_params.page_after_id + self.cursor = cursor def apply(self, candidate_query: Select) -> Select: cand = candidate_query.subquery() @@ -49,6 +48,7 @@ def apply(self, candidate_query: Select) -> Select: combined_query = ( select( AiSearchIndex.entity_id, + AiSearchIndex.entity_title, score, func.first_value(AiSearchIndex.value) .over(partition_by=AiSearchIndex.entity_id, order_by=[dist.asc(), AiSearchIndex.path.asc()]) @@ -60,12 +60,13 @@ def apply(self, candidate_query: Select) -> Select: .select_from(AiSearchIndex) .join(cand, cand.c.entity_id == AiSearchIndex.entity_id) .where(AiSearchIndex.embedding.isnot(None)) - .distinct(AiSearchIndex.entity_id) + .distinct(AiSearchIndex.entity_id, AiSearchIndex.entity_title) ) final_query = combined_query.subquery("ranked_semantic") stmt = select( final_query.c.entity_id, + final_query.c.entity_title, final_query.c.score, final_query.c.highlight_text, final_query.c.highlight_path, @@ -83,12 +84,12 @@ def _apply_semantic_pagination( self, stmt: Select, score_column: ColumnElement, entity_id_column: ColumnElement ) -> Select: """Apply semantic score pagination with precise Decimal handling.""" - if self.page_after_score is not None and self.page_after_id is not None: - score_param = self._quantize_score_for_pagination(self.page_after_score) + if self.cursor is not None: + score_param = self._quantize_score_for_pagination(self.cursor.score) stmt = stmt.where( or_( score_column < score_param, - and_(score_column == score_param, entity_id_column > self.page_after_id), + and_(score_column == score_param, entity_id_column > self.cursor.id), ) ) return stmt diff --git a/orchestrator/search/retrieval/retrievers/structured.py b/orchestrator/search/retrieval/retrievers/structured.py index b50a093f0..0ce9f0ea6 100644 --- a/orchestrator/search/retrieval/retrievers/structured.py +++ b/orchestrator/search/retrieval/retrievers/structured.py @@ -15,22 +15,22 @@ from orchestrator.search.core.types import SearchMetadata -from ..pagination import PaginationParams +from ..pagination import PageCursor from .base import Retriever class StructuredRetriever(Retriever): """Applies a dummy score for purely structured searches with no text query.""" - def __init__(self, pagination_params: PaginationParams) -> None: - self.page_after_id = pagination_params.page_after_id + def __init__(self, cursor: PageCursor | None) -> None: + self.cursor = cursor def apply(self, candidate_query: Select) -> Select: cand = candidate_query.subquery() - stmt = select(cand.c.entity_id, 
literal(1.0).label("score")).select_from(cand) + stmt = select(cand.c.entity_id, cand.c.entity_title, literal(1.0).label("score")).select_from(cand) - if self.page_after_id: - stmt = stmt.where(cand.c.entity_id > self.page_after_id) + if self.cursor is not None: + stmt = stmt.where(cand.c.entity_id > self.cursor.id) return stmt.order_by(cand.c.entity_id.asc()) diff --git a/orchestrator/search/schemas/parameters.py b/orchestrator/search/schemas/parameters.py index 26d3ed79a..0a006d9e1 100644 --- a/orchestrator/search/schemas/parameters.py +++ b/orchestrator/search/schemas/parameters.py @@ -12,9 +12,9 @@ # limitations under the License. import uuid -from typing import Any, Literal +from typing import Any, ClassVar, Literal -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, TypeAdapter from orchestrator.search.core.types import ActionType, EntityType from orchestrator.search.filters import FilterTree @@ -23,6 +23,9 @@ class BaseSearchParameters(BaseModel): """Base model with common search parameters.""" + DEFAULT_EXPORT_LIMIT: ClassVar[int] = 1000 + MAX_EXPORT_LIMIT: ClassVar[int] = 10000 + action: ActionType = Field(default=ActionType.SELECT, description="The action to perform.") entity_type: EntityType @@ -33,14 +36,18 @@ class BaseSearchParameters(BaseModel): ) limit: int = Field(default=10, ge=1, le=30, description="Maximum number of search results to return.") + export_limit: int = Field( + default=DEFAULT_EXPORT_LIMIT, ge=1, le=MAX_EXPORT_LIMIT, description="Maximum number of results to export." + ) model_config = ConfigDict(extra="forbid") @classmethod - def create(cls, entity_type: EntityType, **kwargs: Any) -> "BaseSearchParameters": - try: - return PARAMETER_REGISTRY[entity_type](entity_type=entity_type, **kwargs) - except KeyError: - raise ValueError(f"No search parameter class found for entity type: {entity_type.value}") + def create(cls, **kwargs: Any) -> "SearchParameters": + """Create the correct search parameter subclass instance based on entity_type.""" + from orchestrator.search.schemas.parameters import SearchParameters + + adapter: TypeAdapter = TypeAdapter(SearchParameters) + return adapter.validate_python(kwargs) @property def vector_query(self) -> str | None: @@ -121,9 +128,6 @@ class ProcessSearchParameters(BaseSearchParameters): ) -PARAMETER_REGISTRY: dict[EntityType, type[BaseSearchParameters]] = { - EntityType.SUBSCRIPTION: SubscriptionSearchParameters, - EntityType.PRODUCT: ProductSearchParameters, - EntityType.WORKFLOW: WorkflowSearchParameters, - EntityType.PROCESS: ProcessSearchParameters, -} +SearchParameters = ( + SubscriptionSearchParameters | ProductSearchParameters | WorkflowSearchParameters | ProcessSearchParameters +) diff --git a/orchestrator/search/schemas/results.py b/orchestrator/search/schemas/results.py index b5203d78c..7dcb36394 100644 --- a/orchestrator/search/schemas/results.py +++ b/orchestrator/search/schemas/results.py @@ -15,7 +15,7 @@ from pydantic import BaseModel, ConfigDict -from orchestrator.search.core.types import FilterOp, SearchMetadata, UIType +from orchestrator.search.core.types import EntityType, FilterOp, SearchMetadata, UIType class MatchingField(BaseModel): @@ -30,6 +30,8 @@ class SearchResult(BaseModel): """Represents a single search result item.""" entity_id: str + entity_type: EntityType + entity_title: str score: float perfect_match: int = 0 matching_field: MatchingField | None = None @@ -40,6 +42,7 @@ class SearchResponse(BaseModel): results: list[SearchResult] metadata: 
SearchMetadata + query_embedding: list[float] | None = None class ValueSchema(BaseModel): diff --git a/orchestrator/settings.py b/orchestrator/settings.py index 01cae55f5..fdedc783b 100644 --- a/orchestrator/settings.py +++ b/orchestrator/settings.py @@ -57,6 +57,7 @@ class AppSettings(BaseSettings): EXECUTOR: str = ExecutorType.THREADPOOL WORKFLOWS_SWAGGER_HOST: str = "localhost" WORKFLOWS_GUI_URI: str = "http://localhost:3000" + BASE_URL: str = "http://localhost:8080" # Base URL for the API (used for generating export URLs) DATABASE_URI: PostgresDsn = "postgresql://nwa:nwa@localhost/orchestrator-core" # type: ignore MAX_WORKERS: int = 5 MAIL_SERVER: str = "localhost" diff --git a/test/unit_tests/search/retrieval/retrievers/sql_snapshots.json b/test/unit_tests/search/retrieval/retrievers/sql_snapshots.json index c17692086..cd75ee092 100644 --- a/test/unit_tests/search/retrieval/retrievers/sql_snapshots.json +++ b/test/unit_tests/search/retrieval/retrievers/sql_snapshots.json @@ -1,10 +1,10 @@ { - "FuzzyRetriever.test_basic_query_structure": "SELECT ranked_fuzzy.entity_id, ranked_fuzzy.score, ranked_fuzzy.highlight_text, ranked_fuzzy.highlight_path \nFROM (SELECT DISTINCT ON (ai_search_index.entity_id) ai_search_index.entity_id AS entity_id, CAST(round(CAST(max(word_similarity(%(word_similarity_1)s, ai_search_index.value)) OVER (PARTITION BY ai_search_index.entity_id) AS NUMERIC(38, 12)), %(round_1)s) AS NUMERIC(38, 12)) AS score, first_value(ai_search_index.value) OVER (PARTITION BY ai_search_index.entity_id ORDER BY word_similarity(%(word_similarity_1)s, ai_search_index.value) DESC, ai_search_index.path ASC) AS highlight_text, first_value(ai_search_index.path) OVER (PARTITION BY ai_search_index.entity_id ORDER BY word_similarity(%(word_similarity_1)s, ai_search_index.value) DESC, ai_search_index.path ASC) AS highlight_path \nFROM ai_search_index JOIN (SELECT DISTINCT ai_search_index.entity_id AS entity_id \nFROM ai_search_index) AS anon_1 ON anon_1.entity_id = ai_search_index.entity_id \nWHERE ai_search_index.value_type IN (__[POSTCOMPILE_value_type_1]) AND (%(param_1)s <%% ai_search_index.value)) AS ranked_fuzzy ORDER BY ranked_fuzzy.score DESC NULLS LAST, ranked_fuzzy.entity_id ASC", - "FuzzyRetriever.test_pagination_structure": "SELECT ranked_fuzzy.entity_id, ranked_fuzzy.score, ranked_fuzzy.highlight_text, ranked_fuzzy.highlight_path \nFROM (SELECT DISTINCT ON (ai_search_index.entity_id) ai_search_index.entity_id AS entity_id, CAST(round(CAST(max(word_similarity(%(word_similarity_1)s, ai_search_index.value)) OVER (PARTITION BY ai_search_index.entity_id) AS NUMERIC(38, 12)), %(round_1)s) AS NUMERIC(38, 12)) AS score, first_value(ai_search_index.value) OVER (PARTITION BY ai_search_index.entity_id ORDER BY word_similarity(%(word_similarity_1)s, ai_search_index.value) DESC, ai_search_index.path ASC) AS highlight_text, first_value(ai_search_index.path) OVER (PARTITION BY ai_search_index.entity_id ORDER BY word_similarity(%(word_similarity_1)s, ai_search_index.value) DESC, ai_search_index.path ASC) AS highlight_path \nFROM ai_search_index JOIN (SELECT DISTINCT ai_search_index.entity_id AS entity_id \nFROM ai_search_index) AS anon_1 ON anon_1.entity_id = ai_search_index.entity_id \nWHERE ai_search_index.value_type IN (__[POSTCOMPILE_value_type_1]) AND (%(param_1)s <%% ai_search_index.value)) AS ranked_fuzzy \nWHERE ranked_fuzzy.score < %(score_1)s OR ranked_fuzzy.score = %(score_2)s AND ranked_fuzzy.entity_id > %(entity_id_1)s::UUID ORDER BY ranked_fuzzy.score DESC NULLS LAST, 
ranked_fuzzy.entity_id ASC", - "RrfHybridRetriever.test_basic_query_structure": "WITH field_candidates AS \n(SELECT ai_search_index.entity_id AS entity_id, ai_search_index.path AS path, ai_search_index.value AS value, coalesce(CASE WHEN (ai_search_index.embedding IS NULL) THEN NULL ELSE ai_search_index.embedding <-> %(q_vec)s END, %(param_9)s) AS semantic_distance, word_similarity(%(word_similarity_1)s, ai_search_index.value) AS fuzzy_score \nFROM ai_search_index JOIN (SELECT DISTINCT ai_search_index.entity_id AS entity_id \nFROM ai_search_index) AS anon_1 ON anon_1.entity_id = ai_search_index.entity_id \nWHERE ai_search_index.value_type IN (__[POSTCOMPILE_value_type_1]) AND (%(param_10)s <%% ai_search_index.value) ORDER BY word_similarity(%(word_similarity_1)s, ai_search_index.value) DESC NULLS LAST, CASE WHEN (ai_search_index.embedding IS NULL) THEN NULL ELSE ai_search_index.embedding <-> %(q_vec)s END ASC NULLS LAST, ai_search_index.entity_id ASC \n LIMIT %(param_11)s), \nentity_scores AS \n(SELECT field_candidates.entity_id AS entity_id, avg(field_candidates.semantic_distance) AS avg_semantic_distance, avg(field_candidates.fuzzy_score) AS avg_fuzzy_score \nFROM field_candidates GROUP BY field_candidates.entity_id), \nentity_highlights AS \n(SELECT DISTINCT ON (field_candidates.entity_id) field_candidates.entity_id AS entity_id, first_value(field_candidates.value) OVER (PARTITION BY field_candidates.entity_id ORDER BY field_candidates.fuzzy_score DESC, field_candidates.path ASC) AS highlight_text, first_value(field_candidates.path) OVER (PARTITION BY field_candidates.entity_id ORDER BY field_candidates.fuzzy_score DESC, field_candidates.path ASC) AS highlight_path \nFROM field_candidates), \nranked_results AS \n(SELECT entity_scores.entity_id AS entity_id, entity_scores.avg_semantic_distance AS avg_semantic_distance, entity_scores.avg_fuzzy_score AS avg_fuzzy_score, entity_highlights.highlight_text AS highlight_text, entity_highlights.highlight_path AS highlight_path, dense_rank() OVER (ORDER BY entity_scores.avg_semantic_distance ASC NULLS LAST, entity_scores.entity_id ASC) AS sem_rank, dense_rank() OVER (ORDER BY entity_scores.avg_fuzzy_score DESC NULLS LAST, entity_scores.entity_id ASC) AS fuzzy_rank \nFROM entity_scores JOIN entity_highlights ON entity_scores.entity_id = entity_highlights.entity_id)\n SELECT ranked_results.entity_id, CAST(round(CAST((CAST(%(param_1)s / CAST((%(sem_rank_1)s + ranked_results.sem_rank) AS NUMERIC) + %(param_2)s / CAST((%(fuzzy_rank_1)s + ranked_results.fuzzy_rank) AS NUMERIC) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) * %(param_6)s) * CAST(CASE WHEN (ranked_results.avg_fuzzy_score >= %(avg_fuzzy_score_1)s) THEN %(param_7)s ELSE %(param_8)s END AS NUMERIC(38, 12))) / CAST((%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) * %(param_6)s + %(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) AS NUMERIC(38, 12)) AS NUMERIC(38, 12)), %(round_1)s) AS NUMERIC(38, 12)) AS score, ranked_results.highlight_text, ranked_results.highlight_path, CASE WHEN (ranked_results.avg_fuzzy_score >= %(avg_fuzzy_score_1)s) THEN %(param_7)s ELSE %(param_8)s END AS perfect_match \nFROM ranked_results ORDER BY score DESC NULLS LAST, ranked_results.entity_id ASC", - "RrfHybridRetriever.test_pagination_structure": "WITH field_candidates AS \n(SELECT 
ai_search_index.entity_id AS entity_id, ai_search_index.path AS path, ai_search_index.value AS value, coalesce(CASE WHEN (ai_search_index.embedding IS NULL) THEN NULL ELSE ai_search_index.embedding <-> %(q_vec)s END, %(param_9)s) AS semantic_distance, word_similarity(%(word_similarity_1)s, ai_search_index.value) AS fuzzy_score \nFROM ai_search_index JOIN (SELECT DISTINCT ai_search_index.entity_id AS entity_id \nFROM ai_search_index) AS anon_1 ON anon_1.entity_id = ai_search_index.entity_id \nWHERE ai_search_index.value_type IN (__[POSTCOMPILE_value_type_1]) AND (%(param_10)s <%% ai_search_index.value) ORDER BY word_similarity(%(word_similarity_1)s, ai_search_index.value) DESC NULLS LAST, CASE WHEN (ai_search_index.embedding IS NULL) THEN NULL ELSE ai_search_index.embedding <-> %(q_vec)s END ASC NULLS LAST, ai_search_index.entity_id ASC \n LIMIT %(param_11)s), \nentity_scores AS \n(SELECT field_candidates.entity_id AS entity_id, avg(field_candidates.semantic_distance) AS avg_semantic_distance, avg(field_candidates.fuzzy_score) AS avg_fuzzy_score \nFROM field_candidates GROUP BY field_candidates.entity_id), \nentity_highlights AS \n(SELECT DISTINCT ON (field_candidates.entity_id) field_candidates.entity_id AS entity_id, first_value(field_candidates.value) OVER (PARTITION BY field_candidates.entity_id ORDER BY field_candidates.fuzzy_score DESC, field_candidates.path ASC) AS highlight_text, first_value(field_candidates.path) OVER (PARTITION BY field_candidates.entity_id ORDER BY field_candidates.fuzzy_score DESC, field_candidates.path ASC) AS highlight_path \nFROM field_candidates), \nranked_results AS \n(SELECT entity_scores.entity_id AS entity_id, entity_scores.avg_semantic_distance AS avg_semantic_distance, entity_scores.avg_fuzzy_score AS avg_fuzzy_score, entity_highlights.highlight_text AS highlight_text, entity_highlights.highlight_path AS highlight_path, dense_rank() OVER (ORDER BY entity_scores.avg_semantic_distance ASC NULLS LAST, entity_scores.entity_id ASC) AS sem_rank, dense_rank() OVER (ORDER BY entity_scores.avg_fuzzy_score DESC NULLS LAST, entity_scores.entity_id ASC) AS fuzzy_rank \nFROM entity_scores JOIN entity_highlights ON entity_scores.entity_id = entity_highlights.entity_id)\n SELECT ranked_results.entity_id, CAST(round(CAST((CAST(%(param_1)s / CAST((%(sem_rank_1)s + ranked_results.sem_rank) AS NUMERIC) + %(param_2)s / CAST((%(fuzzy_rank_1)s + ranked_results.fuzzy_rank) AS NUMERIC) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) * %(param_6)s) * CAST(CASE WHEN (ranked_results.avg_fuzzy_score >= %(avg_fuzzy_score_1)s) THEN %(param_7)s ELSE %(param_8)s END AS NUMERIC(38, 12))) / CAST((%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) * %(param_6)s + %(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) AS NUMERIC(38, 12)) AS NUMERIC(38, 12)), %(round_1)s) AS NUMERIC(38, 12)) AS score, ranked_results.highlight_text, ranked_results.highlight_path, CASE WHEN (ranked_results.avg_fuzzy_score >= %(avg_fuzzy_score_1)s) THEN %(param_7)s ELSE %(param_8)s END AS perfect_match \nFROM ranked_results \nWHERE CAST(round(CAST((CAST(%(param_1)s / CAST((%(sem_rank_1)s + ranked_results.sem_rank) AS NUMERIC) + %(param_2)s / CAST((%(fuzzy_rank_1)s + ranked_results.fuzzy_rank) AS NUMERIC) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12)) + 
(%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) * %(param_6)s) * CAST(CASE WHEN (ranked_results.avg_fuzzy_score >= %(avg_fuzzy_score_1)s) THEN %(param_7)s ELSE %(param_8)s END AS NUMERIC(38, 12))) / CAST((%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) * %(param_6)s + %(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) AS NUMERIC(38, 12)) AS NUMERIC(38, 12)), %(round_1)s) AS NUMERIC(38, 12)) < %(param_12)s OR CAST(round(CAST((CAST(%(param_1)s / CAST((%(sem_rank_1)s + ranked_results.sem_rank) AS NUMERIC) + %(param_2)s / CAST((%(fuzzy_rank_1)s + ranked_results.fuzzy_rank) AS NUMERIC) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) * %(param_6)s) * CAST(CASE WHEN (ranked_results.avg_fuzzy_score >= %(avg_fuzzy_score_1)s) THEN %(param_7)s ELSE %(param_8)s END AS NUMERIC(38, 12))) / CAST((%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) * %(param_6)s + %(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) AS NUMERIC(38, 12)) AS NUMERIC(38, 12)), %(round_1)s) AS NUMERIC(38, 12)) = %(param_12)s AND ranked_results.entity_id > %(entity_id_1)s::UUID ORDER BY score DESC NULLS LAST, ranked_results.entity_id ASC",
- "SemanticRetriever.test_basic_query_structure": "SELECT ranked_semantic.entity_id, ranked_semantic.score, ranked_semantic.highlight_text, ranked_semantic.highlight_path \nFROM (SELECT DISTINCT ON (ai_search_index.entity_id) ai_search_index.entity_id AS entity_id, CAST(round(CAST(%(param_1)s / CAST((%(param_2)s + CAST(min(ai_search_index.embedding <-> %(embedding_1)s) OVER (PARTITION BY ai_search_index.entity_id) AS NUMERIC(38, 12))) AS NUMERIC(38, 12)) AS NUMERIC(38, 12)), %(round_1)s) AS NUMERIC(38, 12)) AS score, first_value(ai_search_index.value) OVER (PARTITION BY ai_search_index.entity_id ORDER BY (ai_search_index.embedding <-> %(embedding_1)s) ASC, ai_search_index.path ASC) AS highlight_text, first_value(ai_search_index.path) OVER (PARTITION BY ai_search_index.entity_id ORDER BY (ai_search_index.embedding <-> %(embedding_1)s) ASC, ai_search_index.path ASC) AS highlight_path \nFROM ai_search_index JOIN (SELECT DISTINCT ai_search_index.entity_id AS entity_id \nFROM ai_search_index) AS anon_1 ON anon_1.entity_id = ai_search_index.entity_id \nWHERE ai_search_index.embedding IS NOT NULL) AS ranked_semantic ORDER BY ranked_semantic.score DESC NULLS LAST, ranked_semantic.entity_id ASC",
- "SemanticRetriever.test_pagination_structure": "SELECT ranked_semantic.entity_id, ranked_semantic.score, ranked_semantic.highlight_text, ranked_semantic.highlight_path \nFROM (SELECT DISTINCT ON (ai_search_index.entity_id) ai_search_index.entity_id AS entity_id, CAST(round(CAST(%(param_1)s / CAST((%(param_2)s + CAST(min(ai_search_index.embedding <-> %(embedding_1)s) OVER (PARTITION BY ai_search_index.entity_id) AS NUMERIC(38, 12))) AS NUMERIC(38, 12)) AS NUMERIC(38, 12)), %(round_1)s) AS NUMERIC(38, 12)) AS score, first_value(ai_search_index.value) OVER (PARTITION BY ai_search_index.entity_id ORDER BY (ai_search_index.embedding <-> %(embedding_1)s) ASC, ai_search_index.path ASC) AS highlight_text, first_value(ai_search_index.path) OVER (PARTITION BY ai_search_index.entity_id ORDER BY (ai_search_index.embedding <-> %(embedding_1)s) ASC, ai_search_index.path ASC) AS highlight_path \nFROM ai_search_index JOIN (SELECT DISTINCT ai_search_index.entity_id AS entity_id \nFROM ai_search_index) AS anon_1 ON anon_1.entity_id = ai_search_index.entity_id \nWHERE ai_search_index.embedding IS NOT NULL) AS ranked_semantic \nWHERE ranked_semantic.score < %(param_3)s OR ranked_semantic.score = %(param_3)s AND ranked_semantic.entity_id > %(entity_id_1)s::UUID ORDER BY ranked_semantic.score DESC NULLS LAST, ranked_semantic.entity_id ASC",
- "StructuredRetriever.test_basic_query_structure": "SELECT anon_1.entity_id, %(param_1)s AS score \nFROM (SELECT DISTINCT ai_search_index.entity_id AS entity_id \nFROM ai_search_index) AS anon_1 ORDER BY anon_1.entity_id ASC",
- "StructuredRetriever.test_pagination_structure": "SELECT anon_1.entity_id, %(param_1)s AS score \nFROM (SELECT DISTINCT ai_search_index.entity_id AS entity_id \nFROM ai_search_index) AS anon_1 \nWHERE anon_1.entity_id > %(entity_id_1)s::UUID ORDER BY anon_1.entity_id ASC"
+ "FuzzyRetriever.test_basic_query_structure": "SELECT ranked_fuzzy.entity_id, ranked_fuzzy.entity_title, ranked_fuzzy.score, ranked_fuzzy.highlight_text, ranked_fuzzy.highlight_path \nFROM (SELECT DISTINCT ON (ai_search_index.entity_id, ai_search_index.entity_title) ai_search_index.entity_id AS entity_id, ai_search_index.entity_title AS entity_title, CAST(round(CAST(max(word_similarity(%(word_similarity_1)s, ai_search_index.value)) OVER (PARTITION BY ai_search_index.entity_id) AS NUMERIC(38, 12)), %(round_1)s) AS NUMERIC(38, 12)) AS score, first_value(ai_search_index.value) OVER (PARTITION BY ai_search_index.entity_id ORDER BY word_similarity(%(word_similarity_1)s, ai_search_index.value) DESC, ai_search_index.path ASC) AS highlight_text, first_value(ai_search_index.path) OVER (PARTITION BY ai_search_index.entity_id ORDER BY word_similarity(%(word_similarity_1)s, ai_search_index.value) DESC, ai_search_index.path ASC) AS highlight_path \nFROM ai_search_index JOIN (SELECT DISTINCT ai_search_index.entity_id AS entity_id, ai_search_index.entity_title AS entity_title \nFROM ai_search_index) AS anon_1 ON anon_1.entity_id = ai_search_index.entity_id \nWHERE ai_search_index.value_type IN (__[POSTCOMPILE_value_type_1]) AND (%(param_1)s <%% ai_search_index.value)) AS ranked_fuzzy ORDER BY ranked_fuzzy.score DESC NULLS LAST, ranked_fuzzy.entity_id ASC",
+ "FuzzyRetriever.test_pagination_structure": "SELECT ranked_fuzzy.entity_id, ranked_fuzzy.entity_title, ranked_fuzzy.score, ranked_fuzzy.highlight_text, ranked_fuzzy.highlight_path \nFROM (SELECT DISTINCT ON (ai_search_index.entity_id, ai_search_index.entity_title) ai_search_index.entity_id AS entity_id, ai_search_index.entity_title AS entity_title, CAST(round(CAST(max(word_similarity(%(word_similarity_1)s, ai_search_index.value)) OVER (PARTITION BY ai_search_index.entity_id) AS NUMERIC(38, 12)), %(round_1)s) AS NUMERIC(38, 12)) AS score, first_value(ai_search_index.value) OVER (PARTITION BY ai_search_index.entity_id ORDER BY word_similarity(%(word_similarity_1)s, ai_search_index.value) DESC, ai_search_index.path ASC) AS highlight_text, first_value(ai_search_index.path) OVER (PARTITION BY ai_search_index.entity_id ORDER BY word_similarity(%(word_similarity_1)s, ai_search_index.value) DESC, ai_search_index.path ASC) AS highlight_path \nFROM ai_search_index JOIN (SELECT DISTINCT ai_search_index.entity_id AS entity_id, ai_search_index.entity_title AS entity_title \nFROM ai_search_index) AS anon_1 ON anon_1.entity_id = ai_search_index.entity_id \nWHERE ai_search_index.value_type IN (__[POSTCOMPILE_value_type_1]) AND (%(param_1)s <%% ai_search_index.value)) AS ranked_fuzzy \nWHERE ranked_fuzzy.score < %(score_1)s OR ranked_fuzzy.score = %(score_2)s AND ranked_fuzzy.entity_id > %(entity_id_1)s::UUID ORDER BY ranked_fuzzy.score DESC NULLS LAST, ranked_fuzzy.entity_id ASC",
+ "RrfHybridRetriever.test_basic_query_structure": "WITH field_candidates AS \n(SELECT ai_search_index.entity_id AS entity_id, ai_search_index.entity_title AS entity_title, ai_search_index.path AS path, ai_search_index.value AS value, coalesce(CASE WHEN (ai_search_index.embedding IS NULL) THEN NULL ELSE ai_search_index.embedding <-> %(q_vec)s END, %(param_9)s) AS semantic_distance, word_similarity(%(word_similarity_1)s, ai_search_index.value) AS fuzzy_score \nFROM ai_search_index JOIN (SELECT DISTINCT ai_search_index.entity_id AS entity_id, ai_search_index.entity_title AS entity_title \nFROM ai_search_index) AS anon_1 ON anon_1.entity_id = ai_search_index.entity_id \nWHERE ai_search_index.value_type IN (__[POSTCOMPILE_value_type_1]) AND (%(param_10)s <%% ai_search_index.value) ORDER BY word_similarity(%(word_similarity_1)s, ai_search_index.value) DESC NULLS LAST, CASE WHEN (ai_search_index.embedding IS NULL) THEN NULL ELSE ai_search_index.embedding <-> %(q_vec)s END ASC NULLS LAST, ai_search_index.entity_id ASC \n LIMIT %(param_11)s), \nentity_scores AS \n(SELECT field_candidates.entity_id AS entity_id, field_candidates.entity_title AS entity_title, avg(field_candidates.semantic_distance) AS avg_semantic_distance, avg(field_candidates.fuzzy_score) AS avg_fuzzy_score \nFROM field_candidates GROUP BY field_candidates.entity_id, field_candidates.entity_title), \nentity_highlights AS \n(SELECT DISTINCT ON (field_candidates.entity_id) field_candidates.entity_id AS entity_id, first_value(field_candidates.value) OVER (PARTITION BY field_candidates.entity_id ORDER BY field_candidates.fuzzy_score DESC, field_candidates.path ASC) AS highlight_text, first_value(field_candidates.path) OVER (PARTITION BY field_candidates.entity_id ORDER BY field_candidates.fuzzy_score DESC, field_candidates.path ASC) AS highlight_path \nFROM field_candidates), \nranked_results AS \n(SELECT entity_scores.entity_id AS entity_id, entity_scores.entity_title AS entity_title, entity_scores.avg_semantic_distance AS avg_semantic_distance, entity_scores.avg_fuzzy_score AS avg_fuzzy_score, entity_highlights.highlight_text AS highlight_text, entity_highlights.highlight_path AS highlight_path, dense_rank() OVER (ORDER BY entity_scores.avg_semantic_distance ASC NULLS LAST, entity_scores.entity_id ASC) AS sem_rank, dense_rank() OVER (ORDER BY entity_scores.avg_fuzzy_score DESC NULLS LAST, entity_scores.entity_id ASC) AS fuzzy_rank \nFROM entity_scores JOIN entity_highlights ON entity_scores.entity_id = entity_highlights.entity_id)\n SELECT ranked_results.entity_id, ranked_results.entity_title, CAST(round(CAST((CAST(%(param_1)s / CAST((%(sem_rank_1)s + ranked_results.sem_rank) AS NUMERIC) + %(param_2)s / CAST((%(fuzzy_rank_1)s + ranked_results.fuzzy_rank) AS NUMERIC) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) * %(param_6)s) * CAST(CASE WHEN (ranked_results.avg_fuzzy_score >= %(avg_fuzzy_score_1)s) THEN %(param_7)s ELSE %(param_8)s END AS NUMERIC(38, 12))) / CAST((%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) * %(param_6)s + %(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) AS NUMERIC(38, 12)) AS NUMERIC(38, 12)), %(round_1)s) AS NUMERIC(38, 12)) AS score, ranked_results.highlight_text, ranked_results.highlight_path, CASE WHEN (ranked_results.avg_fuzzy_score >= %(avg_fuzzy_score_1)s) THEN %(param_7)s ELSE %(param_8)s END AS perfect_match \nFROM ranked_results ORDER BY score DESC NULLS LAST, ranked_results.entity_id ASC",
+ "RrfHybridRetriever.test_pagination_structure": "WITH field_candidates AS \n(SELECT ai_search_index.entity_id AS entity_id, ai_search_index.entity_title AS entity_title, ai_search_index.path AS path, ai_search_index.value AS value, coalesce(CASE WHEN (ai_search_index.embedding IS NULL) THEN NULL ELSE ai_search_index.embedding <-> %(q_vec)s END, %(param_9)s) AS semantic_distance, word_similarity(%(word_similarity_1)s, ai_search_index.value) AS fuzzy_score \nFROM ai_search_index JOIN (SELECT DISTINCT ai_search_index.entity_id AS entity_id, ai_search_index.entity_title AS entity_title \nFROM ai_search_index) AS anon_1 ON anon_1.entity_id = ai_search_index.entity_id \nWHERE ai_search_index.value_type IN (__[POSTCOMPILE_value_type_1]) AND (%(param_10)s <%% ai_search_index.value) ORDER BY word_similarity(%(word_similarity_1)s, ai_search_index.value) DESC NULLS LAST, CASE WHEN (ai_search_index.embedding IS NULL) THEN NULL ELSE ai_search_index.embedding <-> %(q_vec)s END ASC NULLS LAST, ai_search_index.entity_id ASC \n LIMIT %(param_11)s), \nentity_scores AS \n(SELECT field_candidates.entity_id AS entity_id, field_candidates.entity_title AS entity_title, avg(field_candidates.semantic_distance) AS avg_semantic_distance, avg(field_candidates.fuzzy_score) AS avg_fuzzy_score \nFROM field_candidates GROUP BY field_candidates.entity_id, field_candidates.entity_title), \nentity_highlights AS \n(SELECT DISTINCT ON (field_candidates.entity_id) field_candidates.entity_id AS entity_id, first_value(field_candidates.value) OVER (PARTITION BY field_candidates.entity_id ORDER BY field_candidates.fuzzy_score DESC, field_candidates.path ASC) AS highlight_text, first_value(field_candidates.path) OVER (PARTITION BY field_candidates.entity_id ORDER BY field_candidates.fuzzy_score DESC, field_candidates.path ASC) AS highlight_path \nFROM field_candidates), \nranked_results AS \n(SELECT entity_scores.entity_id AS entity_id, entity_scores.entity_title AS entity_title, entity_scores.avg_semantic_distance AS avg_semantic_distance, entity_scores.avg_fuzzy_score AS avg_fuzzy_score, entity_highlights.highlight_text AS highlight_text, entity_highlights.highlight_path AS highlight_path, dense_rank() OVER (ORDER BY entity_scores.avg_semantic_distance ASC NULLS LAST, entity_scores.entity_id ASC) AS sem_rank, dense_rank() OVER (ORDER BY entity_scores.avg_fuzzy_score DESC NULLS LAST, entity_scores.entity_id ASC) AS fuzzy_rank \nFROM entity_scores JOIN entity_highlights ON entity_scores.entity_id = entity_highlights.entity_id)\n SELECT ranked_results.entity_id, ranked_results.entity_title, CAST(round(CAST((CAST(%(param_1)s / CAST((%(sem_rank_1)s + ranked_results.sem_rank) AS NUMERIC) + %(param_2)s / CAST((%(fuzzy_rank_1)s + ranked_results.fuzzy_rank) AS NUMERIC) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) * %(param_6)s) * CAST(CASE WHEN (ranked_results.avg_fuzzy_score >= %(avg_fuzzy_score_1)s) THEN %(param_7)s ELSE %(param_8)s END AS NUMERIC(38, 12))) / CAST((%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) * %(param_6)s + %(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) AS NUMERIC(38, 12)) AS NUMERIC(38, 12)), %(round_1)s) AS NUMERIC(38, 12)) AS score, ranked_results.highlight_text, ranked_results.highlight_path, CASE WHEN (ranked_results.avg_fuzzy_score >= %(avg_fuzzy_score_1)s) THEN %(param_7)s ELSE %(param_8)s END AS perfect_match \nFROM ranked_results \nWHERE CAST(round(CAST((CAST(%(param_1)s / CAST((%(sem_rank_1)s + ranked_results.sem_rank) AS NUMERIC) + %(param_2)s / CAST((%(fuzzy_rank_1)s + ranked_results.fuzzy_rank) AS NUMERIC) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) * %(param_6)s) * CAST(CASE WHEN (ranked_results.avg_fuzzy_score >= %(avg_fuzzy_score_1)s) THEN %(param_7)s ELSE %(param_8)s END AS NUMERIC(38, 12))) / CAST((%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) * %(param_6)s + %(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) AS NUMERIC(38, 12)) AS NUMERIC(38, 12)), %(round_1)s) AS NUMERIC(38, 12)) < %(param_12)s OR CAST(round(CAST((CAST(%(param_1)s / CAST((%(sem_rank_1)s + ranked_results.sem_rank) AS NUMERIC) + %(param_2)s / CAST((%(fuzzy_rank_1)s + ranked_results.fuzzy_rank) AS NUMERIC) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) * %(param_6)s) * CAST(CASE WHEN (ranked_results.avg_fuzzy_score >= %(avg_fuzzy_score_1)s) THEN %(param_7)s ELSE %(param_8)s END AS NUMERIC(38, 12))) / CAST((%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12)) + (%(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) * %(param_6)s + %(param_3)s / CAST((%(param_4)s + %(param_5)s) AS NUMERIC(38, 12))) AS NUMERIC(38, 12)) AS NUMERIC(38, 12)), %(round_1)s) AS NUMERIC(38, 12)) = %(param_12)s AND ranked_results.entity_id > %(entity_id_1)s::UUID ORDER BY score DESC NULLS LAST, ranked_results.entity_id ASC",
+ "SemanticRetriever.test_basic_query_structure": "SELECT ranked_semantic.entity_id, ranked_semantic.entity_title, ranked_semantic.score, ranked_semantic.highlight_text, ranked_semantic.highlight_path \nFROM (SELECT DISTINCT ON (ai_search_index.entity_id, ai_search_index.entity_title) ai_search_index.entity_id AS entity_id, ai_search_index.entity_title AS entity_title, CAST(round(CAST(%(param_1)s / CAST((%(param_2)s + CAST(min(ai_search_index.embedding <-> %(embedding_1)s) OVER (PARTITION BY ai_search_index.entity_id) AS NUMERIC(38, 12))) AS NUMERIC(38, 12)) AS NUMERIC(38, 12)), %(round_1)s) AS NUMERIC(38, 12)) AS score, first_value(ai_search_index.value) OVER (PARTITION BY ai_search_index.entity_id ORDER BY (ai_search_index.embedding <-> %(embedding_1)s) ASC, ai_search_index.path ASC) AS highlight_text, first_value(ai_search_index.path) OVER (PARTITION BY ai_search_index.entity_id ORDER BY (ai_search_index.embedding <-> %(embedding_1)s) ASC, ai_search_index.path ASC) AS highlight_path \nFROM ai_search_index JOIN (SELECT DISTINCT ai_search_index.entity_id AS entity_id, ai_search_index.entity_title AS entity_title \nFROM ai_search_index) AS anon_1 ON anon_1.entity_id = ai_search_index.entity_id \nWHERE ai_search_index.embedding IS NOT NULL) AS ranked_semantic ORDER BY ranked_semantic.score DESC NULLS LAST, ranked_semantic.entity_id ASC",
+ "SemanticRetriever.test_pagination_structure": "SELECT ranked_semantic.entity_id, ranked_semantic.entity_title, ranked_semantic.score, ranked_semantic.highlight_text, ranked_semantic.highlight_path \nFROM (SELECT DISTINCT ON (ai_search_index.entity_id, ai_search_index.entity_title) ai_search_index.entity_id AS entity_id, ai_search_index.entity_title AS entity_title, CAST(round(CAST(%(param_1)s / CAST((%(param_2)s + CAST(min(ai_search_index.embedding <-> %(embedding_1)s) OVER (PARTITION BY ai_search_index.entity_id) AS NUMERIC(38, 12))) AS NUMERIC(38, 12)) AS NUMERIC(38, 12)), %(round_1)s) AS NUMERIC(38, 12)) AS score, first_value(ai_search_index.value) OVER (PARTITION BY ai_search_index.entity_id ORDER BY (ai_search_index.embedding <-> %(embedding_1)s) ASC, ai_search_index.path ASC) AS highlight_text, first_value(ai_search_index.path) OVER (PARTITION BY ai_search_index.entity_id ORDER BY (ai_search_index.embedding <-> %(embedding_1)s) ASC, ai_search_index.path ASC) AS highlight_path \nFROM ai_search_index JOIN (SELECT DISTINCT ai_search_index.entity_id AS entity_id, ai_search_index.entity_title AS entity_title \nFROM ai_search_index) AS anon_1 ON anon_1.entity_id = ai_search_index.entity_id \nWHERE ai_search_index.embedding IS NOT NULL) AS ranked_semantic \nWHERE ranked_semantic.score < %(param_3)s OR ranked_semantic.score = %(param_3)s AND ranked_semantic.entity_id > %(entity_id_1)s::UUID ORDER BY ranked_semantic.score DESC NULLS LAST, ranked_semantic.entity_id ASC",
+ "StructuredRetriever.test_basic_query_structure": "SELECT anon_1.entity_id, anon_1.entity_title, %(param_1)s AS score \nFROM (SELECT DISTINCT ai_search_index.entity_id AS entity_id, ai_search_index.entity_title AS entity_title \nFROM ai_search_index) AS anon_1 ORDER BY anon_1.entity_id ASC",
+ "StructuredRetriever.test_pagination_structure": "SELECT anon_1.entity_id, anon_1.entity_title, %(param_1)s AS score \nFROM (SELECT DISTINCT ai_search_index.entity_id AS entity_id, ai_search_index.entity_title AS entity_title \nFROM ai_search_index) AS anon_1 \nWHERE anon_1.entity_id > %(entity_id_1)s::UUID ORDER BY anon_1.entity_id ASC"
 }
diff --git a/test/unit_tests/search/retrieval/retrievers/test_retrievers.py b/test/unit_tests/search/retrieval/retrievers/test_retrievers.py
index aaa5957b4..73927c0b8 100644
--- a/test/unit_tests/search/retrieval/retrievers/test_retrievers.py
+++ b/test/unit_tests/search/retrieval/retrievers/test_retrievers.py
@@ -11,13 +11,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import uuid
+
 import pytest
 from sqlalchemy import literal, select
 from sqlalchemy.dialects import postgresql

 from orchestrator.db import db
 from orchestrator.db.models import AiSearchIndex
-from orchestrator.search.retrieval.pagination import PaginationParams
+from orchestrator.search.retrieval.pagination import PageCursor
 from orchestrator.search.retrieval.retrievers.fuzzy import FuzzyRetriever
 from orchestrator.search.retrieval.retrievers.hybrid import RrfHybridRetriever, compute_rrf_hybrid_score_sql
 from orchestrator.search.retrieval.retrievers.semantic import SemanticRetriever
@@ -34,8 +36,16 @@ def compile_query_to_sql(query) -> str:

 @pytest.fixture
 def candidate_query():
-    """Basic candidate query that returns entity IDs."""
-    return select(AiSearchIndex.entity_id.label("entity_id")).distinct()
+    """Basic candidate query that returns entity IDs and titles."""
+    return select(
+        AiSearchIndex.entity_id.label("entity_id"), AiSearchIndex.entity_title.label("entity_title")
+    ).distinct()
+
+
+@pytest.fixture
+def query_id():
+    """Fixed query_id for pagination tests."""
+    return uuid.uuid4()


 class TestStructuredRetriever:
@@ -43,18 +53,17 @@ class TestStructuredRetriever:

     def test_basic_query_structure(self, candidate_query, request):
         """Test basic structured retrieval query structure."""
-        pagination_params = PaginationParams()
-        retriever = StructuredRetriever(pagination_params)
+        retriever = StructuredRetriever(cursor=None)

         query = retriever.apply(candidate_query)
         sql = compile_query_to_sql(query)

         assert_sql_matches_snapshot("StructuredRetriever.test_basic_query_structure", sql, request)

-    def test_pagination_structure(self, candidate_query, request):
+    def test_pagination_structure(self, candidate_query, query_id, request):
         """Test pagination adds WHERE clause with correct comparison operator."""
-        pagination_params = PaginationParams(page_after_id="test-id-123")
-        retriever = StructuredRetriever(pagination_params)
+        cursor = PageCursor(score=1.0, id="test-id-123", query_id=query_id)
+        retriever = StructuredRetriever(cursor=cursor)

         query = retriever.apply(candidate_query)
         sql = compile_query_to_sql(query)
@@ -63,8 +72,7 @@ def test_pagination_structure(self, candidate_query, request):

     def test_metadata(self):
         """Test metadata returns correct search type."""
-        pagination_params = PaginationParams()
-        retriever = StructuredRetriever(pagination_params)
+        retriever = StructuredRetriever(cursor=None)

         metadata = retriever.metadata
@@ -76,18 +84,17 @@ class TestFuzzyRetriever:

     def test_basic_query_structure(self, candidate_query, request):
         """Test fuzzy retrieval query structure with all components."""
-        pagination_params = PaginationParams()
-        retriever = FuzzyRetriever("test query", pagination_params)
+        retriever = FuzzyRetriever("test query", cursor=None)

         query = retriever.apply(candidate_query)
         sql = compile_query_to_sql(query)

         assert_sql_matches_snapshot("FuzzyRetriever.test_basic_query_structure", sql, request)

-    def test_pagination_structure(self, candidate_query, request):
+    def test_pagination_structure(self, candidate_query, query_id, request):
         """Test pagination with score and id adds correct WHERE clause."""
-        pagination_params = PaginationParams(page_after_score=0.85, page_after_id="entity-123")
-        retriever = FuzzyRetriever("test", pagination_params)
+        cursor = PageCursor(score=0.85, id="entity-123", query_id=query_id)
+        retriever = FuzzyRetriever("test", cursor=cursor)

         query = retriever.apply(candidate_query)
         sql = compile_query_to_sql(query)
@@ -96,8 +103,7 @@ def test_pagination_structure(self, candidate_query, request):

     def test_metadata(self):
         """Test metadata returns correct search type."""
-        pagination_params = PaginationParams()
-        retriever = FuzzyRetriever("test", pagination_params)
+        retriever = FuzzyRetriever("test", cursor=None)

         metadata = retriever.metadata
@@ -109,20 +115,19 @@ class TestSemanticRetriever:

     def test_basic_query_structure(self, candidate_query, request):
         """Test semantic retrieval query structure with all components."""
-        pagination_params = PaginationParams()
         query_vector = [0.1, 0.2, 0.3]
-        retriever = SemanticRetriever(query_vector, pagination_params)
+        retriever = SemanticRetriever(query_vector, cursor=None)

         query = retriever.apply(candidate_query)
         sql = compile_query_to_sql(query)

         assert_sql_matches_snapshot("SemanticRetriever.test_basic_query_structure", sql, request)

-    def test_pagination_structure(self, candidate_query, request):
+    def test_pagination_structure(self, candidate_query, query_id, request):
         """Test pagination with score and id adds correct WHERE clause."""
-        pagination_params = PaginationParams(page_after_score=0.92, page_after_id="entity-456")
+        cursor = PageCursor(score=0.92, id="entity-456", query_id=query_id)
         query_vector = [0.1, 0.2, 0.3]
-        retriever = SemanticRetriever(query_vector, pagination_params)
+        retriever = SemanticRetriever(query_vector, cursor=cursor)

         query = retriever.apply(candidate_query)
         sql = compile_query_to_sql(query)
@@ -131,9 +136,8 @@ def test_pagination_structure(self, candidate_query, request):

     def test_metadata(self):
         """Test metadata returns correct search type."""
-        pagination_params = PaginationParams()
         query_vector = [0.1, 0.2, 0.3]
-        retriever = SemanticRetriever(query_vector, pagination_params)
+        retriever = SemanticRetriever(query_vector, cursor=None)

         metadata = retriever.metadata
@@ -145,20 +149,19 @@ class TestRrfHybridRetriever:

     def test_basic_query_structure(self, candidate_query, request):
         """Test hybrid RRF query structure with all CTEs."""
-        pagination_params = PaginationParams()
         query_vector = [0.1, 0.2, 0.3]
-        retriever = RrfHybridRetriever(query_vector, "test", pagination_params)
+        retriever = RrfHybridRetriever(query_vector, "test", cursor=None)

         query = retriever.apply(candidate_query)
         sql = compile_query_to_sql(query)

         assert_sql_matches_snapshot("RrfHybridRetriever.test_basic_query_structure", sql, request)

-    def test_pagination_structure(self, candidate_query, request):
+    def test_pagination_structure(self, candidate_query, query_id, request):
         """Test that pagination adds score and entity_id comparison logic."""
-        pagination_params = PaginationParams(page_after_score=0.95, page_after_id="entity-789")
+        cursor = PageCursor(score=0.95, id="entity-789", query_id=query_id)
         query_vector = [0.1, 0.2, 0.3]
-        retriever = RrfHybridRetriever(query_vector, "test", pagination_params)
+        retriever = RrfHybridRetriever(query_vector, "test", cursor=cursor)

         query = retriever.apply(candidate_query)
         sql = compile_query_to_sql(query)
@@ -167,9 +170,8 @@ def test_pagination_structure(self, candidate_query, request):

     def test_metadata(self):
         """Test metadata returns correct search type."""
-        pagination_params = PaginationParams()
         query_vector = [0.1, 0.2, 0.3]
-        retriever = RrfHybridRetriever(query_vector, "test", pagination_params)
+        retriever = RrfHybridRetriever(query_vector, "test", cursor=None)

         metadata = retriever.metadata
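
Note on the RrfHybridRetriever snapshots: the generated SQL ranks every entity twice (dense_rank() over avg_semantic_distance and over avg_fuzzy_score), fuses the two ranks reciprocally, adds a bonus when avg_fuzzy_score clears a perfect-match threshold, and divides by the maximum attainable value so the score lands in [0, 1]. A minimal Python sketch of that scoring shape; the constants (k, w_sem, w_fuzzy, perfect_threshold, boost_weight) are illustrative placeholders, not values read out of compute_rrf_hybrid_score_sql:

def rrf_hybrid_score(
    sem_rank: int,
    fuzzy_rank: int,
    avg_fuzzy_score: float,
    k: int = 60,                      # assumed RRF smoothing constant
    w_sem: float = 1.0,               # assumed semantic weight
    w_fuzzy: float = 1.0,             # assumed fuzzy weight
    perfect_threshold: float = 1.0,   # assumed perfect-match cutoff
    boost_weight: float = 0.5,        # assumed perfect-match bonus
) -> float:
    """Reciprocal rank fusion of a semantic and a fuzzy rank, normalized to [0, 1]."""
    base = w_sem / (k + sem_rank) + w_fuzzy / (k + fuzzy_rank)
    boost = boost_weight if avg_fuzzy_score >= perfect_threshold else 0.0
    # Best possible rank is 1, so this is the ceiling the SQL divides by.
    max_attainable = w_sem / (k + 1) + w_fuzzy / (k + 1) + boost_weight
    return (base + boost) / max_attainable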
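
Note on the *_pagination_structure snapshots: with PageCursor the retrievers paginate by keyset instead of offset. The cursor carries the (score, entity_id) of the last row served, and each retriever emits a WHERE clause selecting rows strictly after that pair under the ORDER BY score DESC, entity_id ASC sort, exactly the `score < :s OR score = :s AND entity_id > :id` predicate visible above. A self-contained SQLAlchemy sketch of the predicate; the table and column names are stand-ins for the ranked subqueries, not the real ai_search_index schema:

from sqlalchemy import Column, MetaData, Numeric, String, Table, and_, or_, select

metadata_obj = MetaData()
ranked = Table("ranked", metadata_obj, Column("entity_id", String), Column("score", Numeric))

def after_cursor(score: float, entity_id: str):
    """Rows strictly after (score, entity_id) in (score DESC, entity_id ASC) order."""
    return or_(
        ranked.c.score < score,  # strictly worse score: always after the cursor
        and_(ranked.c.score == score, ranked.c.entity_id > entity_id),  # ties broken by id
    )

stmt = (
    select(ranked.c.entity_id, ranked.c.score)
    .where(after_cursor(0.85, "entity-123"))
    .order_by(ranked.c.score.desc(), ranked.c.entity_id.asc())
)

Unlike OFFSET, this predicate stays stable when rows are inserted between requests and lets the database seek directly to the cursor position, which is why the tests only need to assert the shape of the WHERE clause.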