Commit

fmt
baskaryan committed Mar 4, 2024
1 parent d783816 commit b8bc1c2
Showing 3 changed files with 67 additions and 34 deletions.
37 changes: 21 additions & 16 deletions backend/extraction/utils.py
@@ -1,6 +1,8 @@
 """Adapters to convert between different formats."""
 from __future__ import annotations
 
+from typing import Union
+
 from langchain_core.utils.json_schema import dereference_refs
 
 
@@ -21,30 +23,33 @@ def _rm_titles(kv: dict) -> dict:
 
 
 def convert_json_schema_to_openai_schema(
-    schema: dict,
+    schema: Union[dict, list],
     *,
     rm_titles: bool = True,
     multi: bool = True,
 ) -> dict:
     """Convert JSON schema to a corresponding OpenAI function call."""
-    if multi:
-        # Wrap the schema in an object called "Root" with a property called: "data"
-        # which will be a json array of the original schema.
-        schema_ = {
-            "type": "object",
-            "properties": {
-                "data": {
-                    "type": "array",
-                    "items": dereference_refs(schema),
-                },
-            },
-            "required": ["data"],
-        }
-    else:
+    if not multi:
         raise NotImplementedError("Only multi is supported for now.")
+    # Wrap the schema in an object called "Root" with a property called: "data"
+    # which will be a json array of the original schema.
+    if isinstance(schema, dict):
+        schema_ = dereference_refs(schema)
+    else:
+        schema_ = {"anyOf": [dereference_refs(s) for s in schema]}
+    params = {
+        "type": "object",
+        "properties": {
+            "data": {
+                "type": "array",
+                "items": schema_,
+            },
+        },
+        "required": ["data"],
+    }
 
     return {
         "name": "query_analyzer",
         "description": "Generate optimized queries matching the given schema.",
-        "parameters": _rm_titles(schema_) if rm_titles else schema_,
+        "parameters": _rm_titles(params) if rm_titles else params,
     }
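For reference, a rough sketch of what the updated helper returns when handed a list of schemas: each schema is dereferenced, the set is merged under an anyOf, and the result becomes the items of the required "data" array. The two inline schemas and the import path below are assumptions for illustration (they stand in for Pydantic .schema() output and assume backend/ is importable); the exact result also depends on dereference_refs and on rm_titles stripping title keys.

# Illustrative only: hand-written schemas standing in for Pydantic .schema() output.
from extraction.utils import convert_json_schema_to_openai_schema  # assumes backend/ is on sys.path

youtube_schema = {
    "title": "YouTubeSearch",
    "type": "object",
    "properties": {"query": {"type": "string"}, "publish_year": {"type": "integer"}},
    "required": ["query"],
}
api_schema = {
    "title": "APISearch",
    "type": "object",
    "properties": {"query": {"type": "string"}},
    "required": ["query"],
}

openai_function = convert_json_schema_to_openai_schema([youtube_schema, api_schema])
# openai_function["parameters"] is roughly:
# {
#     "type": "object",
#     "properties": {
#         "data": {
#             "type": "array",
#             "items": {"anyOf": [<YouTubeSearch schema>, <APISearch schema>]},
#         },
#     },
#     "required": ["data"],
# }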
13 changes: 4 additions & 9 deletions backend/server/query_analysis.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import json
-from typing import Any, Dict, List, Optional, Sequence
+from typing import Any, Dict, List, Optional, Sequence, Union
 
 from fastapi import HTTPException
 from jsonschema import Draft202012Validator, exceptions
@@ -41,9 +41,9 @@ class QueryAnalysisRequest(CustomUserType):
     messages: List[AnyMessage] = Field(
         ..., description="The messages to generates queries from."
     )
-    json_schema: Dict[str, Any] = Field(
+    json_schema: Union[Dict[str, Any], List[Dict[str, Any]]] = Field(
         ...,
-        description="JSON schema that describes what a query looks like",
+        description="JSON schema(s) that describes what a query looks like",
         alias="schema",
     )
     instructions: Optional[str] = Field(
@@ -56,7 +56,7 @@ class QueryAnalysisRequest(CustomUserType):
     @validator("json_schema")
     def validate_schema(cls, v: Any) -> Dict[str, Any]:
         """Validate the schema."""
-        validate_json_schema(v)
+        # validate_json_schema(v)
         return v
 
 
@@ -154,11 +154,6 @@ async def query_analyzer(request: QueryAnalysisRequest) -> QueryAnalysisResponse
     """An end point to generate queries from a list of messages."""
     # TODO: Add validation for model context window size
     schema = request.json_schema
-    try:
-        Draft202012Validator.check_schema(schema)
-    except exceptions.ValidationError as e:
-        raise HTTPException(status_code=422, detail=f"Invalid schema: {e.message}")
-
     openai_function = convert_json_schema_to_openai_schema(schema)
     function_name = openai_function["name"]
     prompt = _make_prompt_template(
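The schema checks above are removed, presumably because they assumed a single schema dict rather than the new Union[Dict, List[Dict]] input. If validation were wanted again, a list-aware variant could reuse the validator and exception types already imported in this module. The helper below is only a sketch under that assumption; the name check_schemas is not part of the commit.

# Sketch only: a list-aware replacement for the removed schema check.
from typing import Any, Dict, List, Union

from fastapi import HTTPException
from jsonschema import Draft202012Validator, exceptions


def check_schemas(schema: Union[Dict[str, Any], List[Dict[str, Any]]]) -> None:
    """Raise HTTP 422 if any supplied JSON schema is itself malformed."""
    schemas = schema if isinstance(schema, list) else [schema]
    for s in schemas:
        try:
            Draft202012Validator.check_schema(s)
        except exceptions.SchemaError as e:
            raise HTTPException(status_code=422, detail=f"Invalid schema: {e.message}")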
51 changes: 42 additions & 9 deletions docs/source/notebooks/query_analysis.ipynb
@@ -10,7 +10,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 17,
+   "execution_count": 13,
    "id": "b123c960-a0b4-4d5e-b15f-729de23974f5",
    "metadata": {
     "tags": []
@@ -22,7 +22,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 18,
+   "execution_count": 14,
    "id": "19dafdeb-63c5-4218-b0f9-fc20754369be",
    "metadata": {
     "tags": []
@@ -34,21 +34,29 @@
    "from langchain_core.pydantic_v1 import BaseModel, Field\n",
    "\n",
    "\n",
-   "class Search(BaseModel):\n",
-   "    \"\"\"Search over a database of tutorial videos about a software library.\"\"\"  # noqa\n",
+   "class YouTubeSearch(BaseModel):\n",
+   "    \"\"\"Search over a database of tutorial videos about a software library.\"\"\"\n",
    "\n",
    "    query: str = Field(\n",
    "        ...,\n",
-   "        description=\"Similarity search query applied to video transcripts.\",  # noqa\n",
+   "        description=\"Similarity search query applied to video transcripts.\",\n",
    "    )\n",
    "    publish_year: Optional[int] = Field(\n",
    "        None, description=\"Year video was published\"\n",
+   "    )\n",
+   "    \n",
+   "class APISearch(BaseModel):\n",
+   "    \"\"\"Search over an API reference for software library.\"\"\"\n",
+   "\n",
+   "    query: str = Field(\n",
+   "        ...,\n",
+   "        description=\"Similarity search query applied to reference documentation.\",\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
-   "execution_count": 19,
+   "execution_count": 15,
   "id": "bf79ef88-b816-46aa-addf-9366b7ebdcaf",
   "metadata": {
    "tags": []
@@ -60,7 +68,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 20,
+   "execution_count": 16,
    "id": "553d7dbc-9117-4834-83b1-11e28a513170",
    "metadata": {
     "tags": []
@@ -72,7 +80,7 @@
        "{'data': [{'query': 'RAG agent tutorial', 'publish_year': 2023}]}"
       ]
      },
-     "execution_count": 20,
+     "execution_count": 16,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -82,7 +90,32 @@
    "\n",
    "messages = [HumanMessage(\"RAG agent tutorial from 2023\")]\n",
    "response = runnable.invoke(\n",
-   "    {\"messages\": messages, \"schema\": Search.schema()}\n",
+   "    {\"messages\": messages, \"schema\": [YouTubeSearch.schema(), APISearch.schema()]}\n",
    ")\n",
    "response"
   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "id": "62ee5772",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'data': [{'query': 'arguments RunnablePassthrough.assign accept'}]}"
+      ]
+     },
+     "execution_count": 17,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "messages = [HumanMessage(\"what arguments does RunnablePassthrough.assign accept\")]\n",
+    "response = runnable.invoke(\n",
+    "    {\"messages\": messages, \"schema\": [YouTubeSearch.schema(), APISearch.schema()]}\n",
+    ")\n",
+    "response"
+   ]
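As a reading aid (not part of the notebook diff): the runnable returns plain dicts under "data" that may conform to either schema. One way a caller might map them back onto the notebook's Pydantic models is sketched below; the publish_year heuristic and the parse_queries helper are illustrative assumptions, not something the notebook defines.

# Sketch only: rehydrate the returned dicts into the models defined in the notebook.
from typing import List, Union


def parse_queries(response: dict) -> List[Union[YouTubeSearch, APISearch]]:
    parsed: List[Union[YouTubeSearch, APISearch]] = []
    for item in response["data"]:
        # Crude heuristic: only YouTubeSearch declares publish_year.
        if "publish_year" in item:
            parsed.append(YouTubeSearch(**item))
        else:
            parsed.append(APISearch(**item))
    return parsed


parse_queries({"data": [{"query": "RAG agent tutorial", "publish_year": 2023}]})
# -> [YouTubeSearch(query='RAG agent tutorial', publish_year=2023)]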
