From b8bc1c2d6323282657a96787cee7633067f0289f Mon Sep 17 00:00:00 2001 From: Bagatur Date: Mon, 4 Mar 2024 13:45:37 -0800 Subject: [PATCH] fmt --- backend/extraction/utils.py | 37 +++++++++------- backend/server/query_analysis.py | 13 ++---- docs/source/notebooks/query_analysis.ipynb | 51 ++++++++++++++++++---- 3 files changed, 67 insertions(+), 34 deletions(-) diff --git a/backend/extraction/utils.py b/backend/extraction/utils.py index cdf38ba..b59c4e2 100644 --- a/backend/extraction/utils.py +++ b/backend/extraction/utils.py @@ -1,6 +1,8 @@ """Adapters to convert between different formats.""" from __future__ import annotations +from typing import Union + from langchain_core.utils.json_schema import dereference_refs @@ -21,30 +23,33 @@ def _rm_titles(kv: dict) -> dict: def convert_json_schema_to_openai_schema( - schema: dict, + schema: Union[dict, list], *, rm_titles: bool = True, multi: bool = True, ) -> dict: """Convert JSON schema to a corresponding OpenAI function call.""" - if multi: - # Wrap the schema in an object called "Root" with a property called: "data" - # which will be a json array of the original schema. - schema_ = { - "type": "object", - "properties": { - "data": { - "type": "array", - "items": dereference_refs(schema), - }, - }, - "required": ["data"], - } - else: + if not multi: raise NotImplementedError("Only multi is supported for now.") + # Wrap the schema in an object called "Root" with a property called: "data" + # which will be a json array of the original schema. + if isinstance(schema, dict): + schema_ = dereference_refs(schema) + else: + schema_ = {"anyOf": [dereference_refs(s) for s in schema]} + params = { + "type": "object", + "properties": { + "data": { + "type": "array", + "items": schema_, + }, + }, + "required": ["data"], + } return { "name": "query_analyzer", "description": "Generate optimized queries matching the given schema.", - "parameters": _rm_titles(schema_) if rm_titles else schema_, + "parameters": _rm_titles(params) if rm_titles else params, } diff --git a/backend/server/query_analysis.py b/backend/server/query_analysis.py index d8cbc91..2f56f6f 100644 --- a/backend/server/query_analysis.py +++ b/backend/server/query_analysis.py @@ -1,7 +1,7 @@ from __future__ import annotations import json -from typing import Any, Dict, List, Optional, Sequence +from typing import Any, Dict, List, Optional, Sequence, Union from fastapi import HTTPException from jsonschema import Draft202012Validator, exceptions @@ -41,9 +41,9 @@ class QueryAnalysisRequest(CustomUserType): messages: List[AnyMessage] = Field( ..., description="The messages to generates queries from." 
) - json_schema: Dict[str, Any] = Field( + json_schema: Union[Dict[str, Any], List[Dict[str, Any]]] = Field( ..., - description="JSON schema that describes what a query looks like", + description="JSON schema(s) that describes what a query looks like", alias="schema", ) instructions: Optional[str] = Field( @@ -56,7 +56,7 @@ class QueryAnalysisRequest(CustomUserType): @validator("json_schema") def validate_schema(cls, v: Any) -> Dict[str, Any]: """Validate the schema.""" - validate_json_schema(v) + # validate_json_schema(v) return v @@ -154,11 +154,6 @@ async def query_analyzer(request: QueryAnalysisRequest) -> QueryAnalysisResponse """An end point to generate queries from a list of messages.""" # TODO: Add validation for model context window size schema = request.json_schema - try: - Draft202012Validator.check_schema(schema) - except exceptions.ValidationError as e: - raise HTTPException(status_code=422, detail=f"Invalid schema: {e.message}") - openai_function = convert_json_schema_to_openai_schema(schema) function_name = openai_function["name"] prompt = _make_prompt_template( diff --git a/docs/source/notebooks/query_analysis.ipynb b/docs/source/notebooks/query_analysis.ipynb index 66c7a1f..9ebe727 100644 --- a/docs/source/notebooks/query_analysis.ipynb +++ b/docs/source/notebooks/query_analysis.ipynb @@ -10,7 +10,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 13, "id": "b123c960-a0b4-4d5e-b15f-729de23974f5", "metadata": { "tags": [] @@ -22,7 +22,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 14, "id": "19dafdeb-63c5-4218-b0f9-fc20754369be", "metadata": { "tags": [] @@ -34,21 +34,29 @@ "from langchain_core.pydantic_v1 import BaseModel, Field\n", "\n", "\n", - "class Search(BaseModel):\n", - " \"\"\"Search over a database of tutorial videos about a software library.\"\"\" # noqa\n", + "class YouTubeSearch(BaseModel):\n", + " \"\"\"Search over a database of tutorial videos about a software library.\"\"\"\n", "\n", " query: str = Field(\n", " ...,\n", - " description=\"Similarity search query applied to video transcripts.\", # noqa\n", + " description=\"Similarity search query applied to video transcripts.\",\n", " )\n", " publish_year: Optional[int] = Field(\n", " None, description=\"Year video was published\"\n", + " )\n", + " \n", + "class APISearch(BaseModel):\n", + " \"\"\"Search over an API reference for software library.\"\"\"\n", + "\n", + " query: str = Field(\n", + " ...,\n", + " description=\"Similarity search query applied to reference documentation.\",\n", " )" ] }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 15, "id": "bf79ef88-b816-46aa-addf-9366b7ebdcaf", "metadata": { "tags": [] @@ -60,7 +68,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 16, "id": "553d7dbc-9117-4834-83b1-11e28a513170", "metadata": { "tags": [] @@ -72,7 +80,7 @@ "{'data': [{'query': 'RAG agent tutorial', 'publish_year': 2023}]}" ] }, - "execution_count": 20, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -82,7 +90,32 @@ "\n", "messages = [HumanMessage(\"RAG agent tutorial from 2023\")]\n", "response = runnable.invoke(\n", - " {\"messages\": messages, \"schema\": Search.schema()}\n", + " {\"messages\": messages, \"schema\": [YouTubeSearch.schema(), APISearch.schema()]}\n", + ")\n", + "response" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "62ee5772", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'data': [{'query': 
'arguments RunnablePassthrough.assign accept'}]}" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "messages = [HumanMessage(\"what arguments does RunnablePassthrough.assign accept\")]\n", + "response = runnable.invoke(\n", + " {\"messages\": messages, \"schema\": [YouTubeSearch.schema(), APISearch.schema()]}\n", ")\n", "response" ]
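For readers skimming the diff, the payload shape that the updated convert_json_schema_to_openai_schema builds can be sketched standalone. This is an illustration, not the repo's code: sketch_openai_schema is a hypothetical name, and it skips the $ref dereferencing (dereference_refs) and title stripping (_rm_titles) that the real helper performs.

from typing import Union


def sketch_openai_schema(schema: Union[dict, list]) -> dict:
    # A single schema stays the "data" array's item type; a list is wrapped in
    # "anyOf", so each generated query may match any of the supplied schemas.
    items = schema if isinstance(schema, dict) else {"anyOf": list(schema)}
    return {
        "name": "query_analyzer",
        "description": "Generate optimized queries matching the given schema.",
        "parameters": {
            "type": "object",
            "properties": {"data": {"type": "array", "items": items}},
            "required": ["data"],
        },
    }


youtube = {"type": "object", "properties": {"query": {"type": "string"}, "publish_year": {"type": "integer"}}}
api = {"type": "object", "properties": {"query": {"type": "string"}}}
print(sketch_openai_schema([youtube, api])["parameters"]["properties"]["data"]["items"])
# -> {'anyOf': [<YouTubeSearch-like schema>, <APISearch-like schema>]}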
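Note that with this change the request is no longer validated server-side: validate_json_schema is commented out and the Draft202012Validator.check_schema call is gone, so a malformed schema will only surface when the model call misbehaves. A caller that wants early feedback can run the same jsonschema check client-side; a minimal sketch, where check_schemas is a hypothetical helper rather than part of the patch:

from typing import Any, Dict, Iterable

from jsonschema import Draft202012Validator
from jsonschema.exceptions import SchemaError


def check_schemas(schemas: Iterable[Dict[str, Any]]) -> None:
    """Raise ValueError if any of the given JSON schemas is itself malformed."""
    for schema in schemas:
        try:
            Draft202012Validator.check_schema(schema)
        except SchemaError as e:
            raise ValueError(f"Invalid schema: {e.message}") from e


# Passes silently for well-formed schemas; e.g. {"type": "strnig"} would raise.
check_schemas([{"type": "object", "properties": {"query": {"type": "string"}}}])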
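The notebook cells invoke a runnable that is constructed in an earlier, unchanged cell, so the diff does not show how it is built. A sketch of one way a client could be wired up using LangServe's RemoteRunnable; the host, port, and "/query_analysis" mount path are assumptions, not taken from the patch:

from langchain_core.messages import HumanMessage
from langserve import RemoteRunnable

runnable = RemoteRunnable("http://localhost:8000/query_analysis")  # hypothetical URL

youtube_schema = {
    "title": "YouTubeSearch",
    "type": "object",
    "properties": {"query": {"type": "string"}, "publish_year": {"type": "integer"}},
    "required": ["query"],
}
api_schema = {
    "title": "APISearch",
    "type": "object",
    "properties": {"query": {"type": "string"}},
    "required": ["query"],
}

response = runnable.invoke(
    {
        "messages": [HumanMessage("what arguments does RunnablePassthrough.assign accept")],
        # "schema" now accepts either a single JSON schema or a list of them.
        "schema": [youtube_schema, api_schema],
    }
)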