From d7838163ba6b03744a4a5160cc3292b643262137 Mon Sep 17 00:00:00 2001
From: Bagatur
Date: Mon, 4 Mar 2024 13:18:18 -0800
Subject: [PATCH] wip: query analysis

---
 README.md                                  |   6 +-
 backend/db/models.py                       |  68 ++++++-
 backend/extraction/utils.py                |   6 +-
 backend/poetry.lock                        |  70 ++++---
 backend/pyproject.toml                     |   2 +-
 backend/server/main.py                     |  13 ++
 backend/server/query_analysis.py           | 171 +++++++++++++++++
 backend/server/settings.py                 |   2 +-
 docs/source/notebooks/query_analysis.ipynb | 211 +++++++++++++++++++++
 9 files changed, 513 insertions(+), 36 deletions(-)
 create mode 100644 backend/server/query_analysis.py
 create mode 100644 docs/source/notebooks/query_analysis.ipynb

diff --git a/README.md b/README.md
index 6e70129..1189880 100644
--- a/README.md
+++ b/README.md
@@ -173,13 +173,13 @@ poetry install --with lint,dev,test
 Run the following script to create a database and schema:
 
 ```sh
-python -m scripts.run_migrations create
+poetry run python -m scripts.run_migrations create
 ```
 
 From `/backend`:
 
 ```sh
-OPENAI_API_KEY=[YOUR API KEY] python -m server.main
+OPENAI_API_KEY=[YOUR API KEY] poetry run python -m server.main
 ```
 
 ### Testing
@@ -189,7 +189,7 @@ separate from the main database. It will have the same schema as the main
 database.
 
 ```sh
-python -m scripts.run_migrations create-test-db
+poetry run python -m scripts.run_migrations create-test-db
 ```
 
 Run the tests
diff --git a/backend/db/models.py b/backend/db/models.py
index 9025668..e67ac08 100644
--- a/backend/db/models.py
+++ b/backend/db/models.py
@@ -6,7 +6,6 @@
 from sqlalchemy.dialects.postgresql import JSONB, UUID
 from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.orm import Session, relationship, sessionmaker
-from sqlalchemy.sql import func
 
 from server.settings import get_postgres_url
 
@@ -123,3 +122,70 @@ class Example(TimestampedModel):
 
     def __repr__(self) -> str:
         return f"<Example(uuid={self.uuid})>"
+
+
+class QueryAnalyzer(TimestampedModel):
+    __tablename__ = "query_analyzers"
+
+    name = Column(
+        String(100),
+        nullable=False,
+        server_default="",
+        comment="The name of the query analyzer.",
+    )
+    schema = Column(
+        JSONB,
+        nullable=False,
+        comment="JSON Schema that describes the schema of the query.",
+    )
+    description = Column(
+        String(100),
+        nullable=False,
+        server_default="",
+        comment="Surfaced via UI to the users.",
+    )
+    instruction = Column(
+        Text, nullable=False, comment="The prompt to the language model."
+    )  # TODO: This will need to evolve
+
+    examples = relationship("QueryAnalysisExample", backref="query_analyzer")
+
+    def __repr__(self) -> str:
+        return f"<QueryAnalyzer(uuid={self.uuid})>"
+
+
+class QueryAnalysisExample(TimestampedModel):
+    """A representation of an example.
+
+    Examples consist of content together with the expected output.
+
+    The output is a JSON object that is expected to be generated from the content.
+
+    The JSON object should be valid according to the schema of the associated query analyzer.
+
+    The JSON object is defined by the schema of the associated query analyzer, so
+    it's perfectly fine for a given example to represent the generation
+    of multiple queries from the content, since
+    the JSON schema can represent a list of objects.
+    """
+
+    __tablename__ = "query_analysis_examples"
+
+    content = Column(
+        JSONB,
+        nullable=False,
+        comment="The input portion of the example.",
+    )
+    output = Column(
+        JSONB,
+        comment="The output associated with the example.",
+    )
+    query_analyzer_id = Column(
+        UUID(as_uuid=True),
+        ForeignKey("query_analyzers.uuid", ondelete="CASCADE"),
+        nullable=False,
+        comment="Foreign key referencing the associated query analyzer.",
+    )
+
+    def __repr__(self) -> str:
+        return f"<QueryAnalysisExample(uuid={self.uuid})>"
diff --git a/backend/extraction/utils.py b/backend/extraction/utils.py
index f6835cd..cdf38ba 100644
--- a/backend/extraction/utils.py
+++ b/backend/extraction/utils.py
@@ -43,10 +43,8 @@ def convert_json_schema_to_openai_schema(
     else:
         raise NotImplementedError("Only multi is supported for now.")
 
-    schema_.pop("definitions", None)
-
     return {
-        "name": "extractor",
-        "description": "Extract information matching the given schema.",
+        "name": "query_analyzer",
+        "description": "Generate optimized queries matching the given schema.",
         "parameters": _rm_titles(schema_) if rm_titles else schema_,
     }
diff --git a/backend/poetry.lock b/backend/poetry.lock
index ec9f429..2c2c8a7 100644
--- a/backend/poetry.lock
+++ b/backend/poetry.lock
@@ -1737,13 +1737,13 @@ test = ["hatch", "ipykernel", "openapi-core (>=0.18.0,<0.19.0)", "openapi-spec-v
 
 [[package]]
 name = "langchain"
-version = "0.1.5"
+version = "0.1.10"
 description = "Building applications with LLMs through composability"
 optional = false
 python-versions = ">=3.8.1,<4.0"
 files = [
-    {file = "langchain-0.1.5-py3-none-any.whl", hash = "sha256:4614118d4a95b2e7ba3611a0b6b21707a259a21652a04fbe3c31205bcf3fcd50"},
-    {file = "langchain-0.1.5.tar.gz", hash = "sha256:69603a5bb21b044ddea69d38131dbbf47475afdf79728644faa67d1ad325d652"},
+    {file = "langchain-0.1.10-py3-none-any.whl", hash = "sha256:dcc1c0968b8d946a812155584ecbbeda690c930c3ee27bb5ecc113d954f6cf1a"},
+    {file = "langchain-0.1.10.tar.gz", hash = "sha256:17951bcd6d74adc74aa081f260ef5514c449488815314420b7e0f8349f15d932"},
 ]
 
 [package.dependencies]
 aiohttp = ">=3.8.3,<4.0.0"
 async-timeout = {version = ">=4.0.0,<5.0.0", markers = "python_version < \"3.11\""}
 dataclasses-json = ">=0.5.7,<0.7"
 jsonpatch = ">=1.33,<2.0"
-langchain-community = ">=0.0.17,<0.1"
-langchain-core = ">=0.1.16,<0.2"
-langsmith = ">=0.0.83,<0.1"
+langchain-community = ">=0.0.25,<0.1"
+langchain-core = ">=0.1.28,<0.2"
+langchain-text-splitters = ">=0.0.1,<0.1"
+langsmith = ">=0.1.0,<0.2.0"
 numpy = ">=1,<2"
 pydantic = ">=1,<3"
 PyYAML = ">=5.3"
@@ -1762,7 +1763,7 @@ SQLAlchemy = ">=1.4,<3"
 tenacity = ">=8.1.0,<9.0.0"
 
 [package.extras]
-azure = ["azure-ai-formrecognizer (>=3.2.1,<4.0.0)", "azure-ai-textanalytics (>=5.3.0,<6.0.0)", "azure-ai-vision (>=0.11.1b1,<0.12.0)", "azure-cognitiveservices-speech (>=1.28.0,<2.0.0)", "azure-core (>=1.26.4,<2.0.0)", "azure-cosmos (>=4.4.0b1,<5.0.0)", "azure-identity (>=1.12.0,<2.0.0)", "azure-search-documents (==11.4.0b8)", "openai (<2)"]
+azure = ["azure-ai-formrecognizer (>=3.2.1,<4.0.0)", "azure-ai-textanalytics (>=5.3.0,<6.0.0)", "azure-cognitiveservices-speech (>=1.28.0,<2.0.0)", "azure-core (>=1.26.4,<2.0.0)", "azure-cosmos (>=4.4.0b1,<5.0.0)", "azure-identity (>=1.12.0,<2.0.0)", "azure-search-documents (==11.4.0b8)", "openai (<2)"]
 clarifai = ["clarifai (>=9.1.0)"]
 cli = ["typer (>=0.9.0,<0.10.0)"]
 cohere = ["cohere (>=4,<5)"]
@@ -1777,20 +1778,20 @@ text-helpers = ["chardet (>=5.1.0,<6.0.0)"]
 
 [[package]]
 name = "langchain-community"
-version = "0.0.18"
+version = "0.0.25"
 description = "Community contributed LangChain integrations."
 optional = false
 python-versions = ">=3.8.1,<4.0"
 files = [
-    {file = "langchain_community-0.0.18-py3-none-any.whl", hash = "sha256:b87e20c1fa3f37e9608d7ccc08b4d8ed86f875b8c1e735d0464ae986e41c5a71"},
-    {file = "langchain_community-0.0.18.tar.gz", hash = "sha256:f044f331b418f16148b76929f27cc2107fce2d190ea3fae0cdaf155ceda9892f"},
+    {file = "langchain_community-0.0.25-py3-none-any.whl", hash = "sha256:09b931ba710b1a10e449396d59f38575e0554acd527287937c33a2c4abdc6d83"},
+    {file = "langchain_community-0.0.25.tar.gz", hash = "sha256:b6c8c14cd6ec2635e51e3974bf78a8de3b959bbedb4af55aad164f8cf392f0c5"},
 ]
 
 [package.dependencies]
 aiohttp = ">=3.8.3,<4.0.0"
 dataclasses-json = ">=0.5.7,<0.7"
-langchain-core = ">=0.1.19,<0.2"
-langsmith = ">=0.0.83,<0.1"
+langchain-core = ">=0.1.28,<0.2.0"
+langsmith = ">=0.1.0,<0.2.0"
 numpy = ">=1,<2"
 PyYAML = ">=5.3"
 requests = ">=2,<3"
@@ -1799,23 +1800,23 @@ tenacity = ">=8.1.0,<9.0.0"
 
 [package.extras]
 cli = ["typer (>=0.9.0,<0.10.0)"]
-extended-testing = ["aiosqlite (>=0.19.0,<0.20.0)", "aleph-alpha-client (>=2.15.0,<3.0.0)", "anthropic (>=0.3.11,<0.4.0)", "arxiv (>=1.4,<2.0)", "assemblyai (>=0.17.0,<0.18.0)", "atlassian-python-api (>=3.36.0,<4.0.0)", "azure-ai-documentintelligence (>=1.0.0b1,<2.0.0)", "beautifulsoup4 (>=4,<5)", "bibtexparser (>=1.4.0,<2.0.0)", "cassio (>=0.1.0,<0.2.0)", "chardet (>=5.1.0,<6.0.0)", "cohere (>=4,<5)", "databricks-vectorsearch (>=0.21,<0.22)", "datasets (>=2.15.0,<3.0.0)", "dgml-utils (>=0.3.0,<0.4.0)", "elasticsearch (>=8.12.0,<9.0.0)", "esprima (>=4.0.1,<5.0.0)", "faiss-cpu (>=1,<2)", "feedparser (>=6.0.10,<7.0.0)", "fireworks-ai (>=0.9.0,<0.10.0)", "geopandas (>=0.13.1,<0.14.0)", "gitpython (>=3.1.32,<4.0.0)", "google-cloud-documentai (>=2.20.1,<3.0.0)", "gql (>=3.4.1,<4.0.0)", "gradientai (>=1.4.0,<2.0.0)", "hdbcli (>=2.19.21,<3.0.0)", "hologres-vector (>=0.0.6,<0.0.7)", "html2text (>=2020.1.16,<2021.0.0)", "httpx (>=0.24.1,<0.25.0)", "javelin-sdk (>=0.1.8,<0.2.0)", "jinja2 (>=3,<4)", "jq (>=1.4.1,<2.0.0)", "jsonschema (>1)", "lxml (>=4.9.2,<5.0.0)", "markdownify (>=0.11.6,<0.12.0)", "motor (>=3.3.1,<4.0.0)", "msal (>=1.25.0,<2.0.0)", "mwparserfromhell (>=0.6.4,<0.7.0)", "mwxml (>=0.3.3,<0.4.0)", "newspaper3k (>=0.2.8,<0.3.0)", "numexpr (>=2.8.6,<3.0.0)", "nvidia-riva-client (>=2.14.0,<3.0.0)", "oci (>=2.119.1,<3.0.0)", "openai (<2)", "openapi-pydantic (>=0.3.2,<0.4.0)", "oracle-ads (>=2.9.1,<3.0.0)", "pandas (>=2.0.1,<3.0.0)", "pdfminer-six (>=20221105,<20221106)", "pgvector (>=0.1.6,<0.2.0)", "praw (>=7.7.1,<8.0.0)", "psychicapi (>=0.8.0,<0.9.0)", "py-trello (>=0.19.0,<0.20.0)", "pymupdf (>=1.22.3,<2.0.0)", "pypdf (>=3.4.0,<4.0.0)", "pypdfium2 (>=4.10.0,<5.0.0)", "pyspark (>=3.4.0,<4.0.0)", "rank-bm25 (>=0.2.2,<0.3.0)", "rapidfuzz (>=3.1.1,<4.0.0)", "rapidocr-onnxruntime (>=1.3.2,<2.0.0)", "rdflib (==7.0.0)", "requests-toolbelt (>=1.0.0,<2.0.0)", "rspace_client (>=2.5.0,<3.0.0)", "scikit-learn (>=1.2.2,<2.0.0)", "sqlite-vss (>=0.1.2,<0.2.0)", "streamlit (>=1.18.0,<2.0.0)", "sympy (>=1.12,<2.0)", "telethon (>=1.28.5,<2.0.0)", "timescale-vector (>=0.0.1,<0.0.2)", "tqdm (>=4.48.0)", "upstash-redis (>=0.15.0,<0.16.0)", "xata (>=1.0.0a7,<2.0.0)", "xmltodict (>=0.13.0,<0.14.0)", "zhipuai (>=1.0.7,<2.0.0)"]
+extended-testing = ["aiosqlite (>=0.19.0,<0.20.0)", "aleph-alpha-client (>=2.15.0,<3.0.0)", "anthropic (>=0.3.11,<0.4.0)", "arxiv (>=1.4,<2.0)", "assemblyai (>=0.17.0,<0.18.0)", "atlassian-python-api (>=3.36.0,<4.0.0)", "azure-ai-documentintelligence (>=1.0.0b1,<2.0.0)", "beautifulsoup4 (>=4,<5)", "bibtexparser (>=1.4.0,<2.0.0)", "cassio (>=0.1.0,<0.2.0)", "chardet (>=5.1.0,<6.0.0)", "cohere (>=4,<5)", "databricks-vectorsearch (>=0.21,<0.22)", "datasets (>=2.15.0,<3.0.0)", "dgml-utils (>=0.3.0,<0.4.0)", "elasticsearch (>=8.12.0,<9.0.0)", "esprima (>=4.0.1,<5.0.0)", "faiss-cpu (>=1,<2)", "feedparser (>=6.0.10,<7.0.0)", "fireworks-ai (>=0.9.0,<0.10.0)", "geopandas (>=0.13.1,<0.14.0)", "gitpython (>=3.1.32,<4.0.0)", "google-cloud-documentai (>=2.20.1,<3.0.0)", "gql (>=3.4.1,<4.0.0)", "gradientai (>=1.4.0,<2.0.0)", "hdbcli (>=2.19.21,<3.0.0)", "hologres-vector (>=0.0.6,<0.0.7)", "html2text (>=2020.1.16,<2021.0.0)", "httpx (>=0.24.1,<0.25.0)", "javelin-sdk (>=0.1.8,<0.2.0)", "jinja2 (>=3,<4)", "jq (>=1.4.1,<2.0.0)", "jsonschema (>1)", "lxml (>=4.9.2,<5.0.0)", "markdownify (>=0.11.6,<0.12.0)", "motor (>=3.3.1,<4.0.0)", "msal (>=1.25.0,<2.0.0)", "mwparserfromhell (>=0.6.4,<0.7.0)", "mwxml (>=0.3.3,<0.4.0)", "newspaper3k (>=0.2.8,<0.3.0)", "numexpr (>=2.8.6,<3.0.0)", "nvidia-riva-client (>=2.14.0,<3.0.0)", "oci (>=2.119.1,<3.0.0)", "openai (<2)", "openapi-pydantic (>=0.3.2,<0.4.0)", "oracle-ads (>=2.9.1,<3.0.0)", "pandas (>=2.0.1,<3.0.0)", "pdfminer-six (>=20221105,<20221106)", "pgvector (>=0.1.6,<0.2.0)", "praw (>=7.7.1,<8.0.0)", "psychicapi (>=0.8.0,<0.9.0)", "py-trello (>=0.19.0,<0.20.0)", "pymupdf (>=1.22.3,<2.0.0)", "pypdf (>=3.4.0,<4.0.0)", "pypdfium2 (>=4.10.0,<5.0.0)", "pyspark (>=3.4.0,<4.0.0)", "rank-bm25 (>=0.2.2,<0.3.0)", "rapidfuzz (>=3.1.1,<4.0.0)", "rapidocr-onnxruntime (>=1.3.2,<2.0.0)", "rdflib (==7.0.0)", "requests-toolbelt (>=1.0.0,<2.0.0)", "rspace_client (>=2.5.0,<3.0.0)", "scikit-learn (>=1.2.2,<2.0.0)", "sqlite-vss (>=0.1.2,<0.2.0)", "streamlit (>=1.18.0,<2.0.0)", "sympy (>=1.12,<2.0)", "telethon (>=1.28.5,<2.0.0)", "timescale-vector (>=0.0.1,<0.0.2)", "tqdm (>=4.48.0)", "tree-sitter (>=0.20.2,<0.21.0)", "tree-sitter-languages (>=1.8.0,<2.0.0)", "upstash-redis (>=0.15.0,<0.16.0)", "xata (>=1.0.0a7,<2.0.0)", "xmltodict (>=0.13.0,<0.14.0)", "zhipuai (>=1.0.7,<2.0.0)"]
 
 [[package]]
 name = "langchain-core"
-version = "0.1.19"
+version = "0.1.28"
 description = "Building applications with LLMs through composability"
 optional = false
 python-versions = ">=3.8.1,<4.0"
 files = [
-    {file = "langchain_core-0.1.19-py3-none-any.whl", hash = "sha256:46b5fd54181df5aa6d3041d61beb2b91e5437b6742274e7924a97734ed62cf43"},
-    {file = "langchain_core-0.1.19.tar.gz", hash = "sha256:30539190a63dff53e995f10aefb943b4f7e01aba4bf28fd1e13016b040c0e9da"},
+    {file = "langchain_core-0.1.28-py3-none-any.whl", hash = "sha256:f40ca31257e003eb404e6275345d13a1b3839df147153684bab56bb8d80162c6"},
+    {file = "langchain_core-0.1.28.tar.gz", hash = "sha256:04e761a513200b6e5b5818613821945799c07bc5349087d7692e50823107c9d6"},
 ]
 
 [package.dependencies]
 anyio = ">=3,<5"
 jsonpatch = ">=1.33,<2.0"
-langsmith = ">=0.0.83,<0.1"
+langsmith = ">=0.1.0,<0.2.0"
 packaging = ">=23.2,<24.0"
 pydantic = ">=1,<3"
 PyYAML = ">=5.3"
@@ -1827,21 +1828,37 @@ extended-testing = ["jinja2 (>=3,<4)"]
 
 [[package]]
 name = "langchain-openai"
-version = "0.0.6"
+version = "0.0.8"
 description = "An integration package connecting OpenAI and LangChain"
 optional = false
 python-versions = ">=3.8.1,<4.0"
 files = [
-    {file = "langchain_openai-0.0.6-py3-none-any.whl", hash = "sha256:2ef040e4447a26a9d3bd45dfac9cefa00797ea58555a3d91ab4f88699eb3a005"},
-    {file = "langchain_openai-0.0.6.tar.gz", hash = "sha256:f5c4ebe46f2c8635c8f0c26cc8df27700aacafea025410e418d5a080039974dd"},
+    {file = "langchain_openai-0.0.8-py3-none-any.whl", hash = "sha256:4862fc72cecbee0240aaa6df0234d5893dd30cd33ca23ac5cfdd86c11d2c44df"},
+    {file = "langchain_openai-0.0.8.tar.gz", hash = "sha256:b7aba7fcc52305e78b08197ebc54fc45cc06dbc40ba5b913bc48a22b30a4f5c9"},
 ]
 
 [package.dependencies]
-langchain-core = ">=0.1.16,<0.2"
-numpy = ">=1,<2"
+langchain-core = ">=0.1.27,<0.2.0"
 openai = ">=1.10.0,<2.0.0"
 tiktoken = ">=0.5.2,<1"
 
+[[package]]
+name = "langchain-text-splitters"
+version = "0.0.1"
+description = "LangChain text splitting utilities"
+optional = false
+python-versions = ">=3.8.1,<4.0"
+files = [
+    {file = "langchain_text_splitters-0.0.1-py3-none-any.whl", hash = "sha256:f5b802f873f5ff6a8b9259ff34d53ed989666ef4e1582e6d1adb3b5520e3839a"},
+    {file = "langchain_text_splitters-0.0.1.tar.gz", hash = "sha256:ac459fa98799f5117ad5425a9330b21961321e30bc19a2a2f9f761ddadd62aa1"},
+]
+
+[package.dependencies]
+langchain-core = ">=0.1.28,<0.2.0"
+
+[package.extras]
+extended-testing = ["lxml (>=5.1.0,<6.0.0)"]
+
 [[package]]
 name = "langserve"
 version = "0.0.45"
@@ -1866,16 +1883,17 @@ server = ["fastapi (>=0.90.1,<1)", "sse-starlette (>=1.3.0,<2.0.0)"]
 
 [[package]]
 name = "langsmith"
-version = "0.0.92"
+version = "0.1.16"
 description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
 optional = false
 python-versions = ">=3.8.1,<4.0"
 files = [
-    {file = "langsmith-0.0.92-py3-none-any.whl", hash = "sha256:ddcf65e3b5ca11893ae8ef9816ce2a11a089d051be491886e43a2c4556b88fd0"},
-    {file = "langsmith-0.0.92.tar.gz", hash = "sha256:61a3a502222bdd221b7f592b6fc14756d74c4fc088aa6bd8834b92adfe9ee583"},
+    {file = "langsmith-0.1.16-py3-none-any.whl", hash = "sha256:48dd9472b656561af520892fff50cc7789a8a1b96b97799d18d9ad2046c296c1"},
+    {file = "langsmith-0.1.16.tar.gz", hash = "sha256:7db99c209091d75cd1d32c1bcebfe476656d9d662cd7faa61425635f7fe6533e"},
 ]
 
 [package.dependencies]
+orjson = ">=3.9.14,<4.0.0"
 pydantic = ">=1,<3"
 requests = ">=2,<3"
 
@@ -4700,4 +4718,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.8.1"
-content-hash = "adb5aa5abbe85ad9a450118e1c91fdd578c187076b102e586315705fe0a241ae"
+content-hash = "97f9563925d6adad46c479d89608f654f02042464536f3693bcc4adf0d5b137e"
diff --git a/backend/pyproject.toml b/backend/pyproject.toml
index 9d11f37..1b1a70c 100644
--- a/backend/pyproject.toml
+++ b/backend/pyproject.toml
@@ -14,7 +14,7 @@ fastapi = "^0.109.2"
 langserve = "^0.0.45"
 uvicorn = "^0.27.1"
 pydantic = "^1.10"
-langchain-openai = "^0.0.6"
+langchain-openai = "^0.0.8"
 jsonschema = "^4.21.1"
 sse-starlette = "^2.0.0"
 alembic = "^1.13.1"
diff --git a/backend/server/main.py b/backend/server/main.py
index 917533f..582940e 100644
--- a/backend/server/main.py
+++ b/backend/server/main.py
@@ -10,6 +10,11 @@
     ExtractResponse,
     extraction_runnable,
 )
+from server.query_analysis import (
+    QueryAnalysisRequest,
+    QueryAnalysisResponse,
+    query_analyzer,
+)
 
 app = FastAPI(
     title="Extraction Powered by LangChain",
@@ -56,6 +61,14 @@ def ready():
     enabled_endpoints=["invoke", "batch"],
 )
 
+add_routes(
+    app,
+    query_analyzer.with_types(
+        input_type=QueryAnalysisRequest, output_type=QueryAnalysisResponse
+    ),
+    path="/query_analysis",
+    enabled_endpoints=["invoke", "batch"],
+)
 
 if __name__ == "__main__":
     import uvicorn
diff --git a/backend/server/query_analysis.py b/backend/server/query_analysis.py
new file mode 100644
index 0000000..d8cbc91
--- /dev/null
+++ b/backend/server/query_analysis.py
@@ -0,0 +1,171 @@
+from __future__ import annotations
+
+import json
+from typing import Any, Dict, List, Optional, Sequence
+
+from fastapi import HTTPException
+from jsonschema import Draft202012Validator, exceptions
+from langchain_core.messages import AIMessage, AnyMessage
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain_core.prompts.chat import MessageLikeRepresentation
+from langchain_core.runnables import chain
+from langserve import CustomUserType
+from pydantic import BaseModel, Field, validator
+from typing_extensions import TypedDict
+
+from db.models import QueryAnalysisExample as DBQueryAnalysisExample
+from db.models import QueryAnalyzer
+from extraction.utils import convert_json_schema_to_openai_schema
+from server.settings import get_model
+from server.validators import validate_json_schema
+
+# Instantiate the model
+model = get_model()
+
+
+class QueryAnalysisExample(BaseModel):
+    """An example query analysis.
+
+    This example consists of input messages and the expected queries.
+    """
+
+    messages: List[AnyMessage] = Field(..., description="The input messages")
+    output: List[Dict[str, Any]] = Field(
+        ..., description="The expected output of the example. A list of objects."
+    )
+
+
+class QueryAnalysisRequest(CustomUserType):
+    """Request body for the query analyzer endpoint."""
+
+    messages: List[AnyMessage] = Field(
+        ..., description="The messages to generate queries from."
+    )
+    json_schema: Dict[str, Any] = Field(
+        ...,
+        description="JSON schema that describes what a query looks like.",
+        alias="schema",
+    )
+    instructions: Optional[str] = Field(
+        None, description="Supplemental system instructions."
+    )
+    examples: Optional[List[QueryAnalysisExample]] = Field(
+        None, description="Examples of optimized queries."
+    )
+
+    @validator("json_schema")
+    def validate_schema(cls, v: Any) -> Dict[str, Any]:
+        """Validate the schema."""
+        validate_json_schema(v)
+        return v
+
+
+class QueryAnalysisResponse(TypedDict):
+    """Response body for the query analysis endpoint."""
+
+    data: List[Any]
+
+
+def _deduplicate(
+    responses: Sequence[QueryAnalysisResponse],
+) -> QueryAnalysisResponse:
+    """Deduplicate the results.
+
+    The deduplication is done by comparing the serialized JSON of each of the results
+    and only keeping the unique ones.
+    """
+    unique = []
+    seen = set()
+    for response in responses:
+        for data_item in response["data"]:
+            # Serialize the data item for comparison purposes
+            serialized = json.dumps(data_item, sort_keys=True)
+            if serialized not in seen:
+                seen.add(serialized)
+                unique.append(data_item)
+
+    return {
+        "data": unique,
+    }
+
+
+def _cast_example_to_dict(example: DBQueryAnalysisExample) -> Dict[str, Any]:
+    """Cast example record to dictionary."""
+    return {
+        "messages": example.content,
+        "output": example.output,
+    }
+
+
+def _make_prompt_template(
+    instructions: Optional[str],
+    examples: Optional[Sequence[QueryAnalysisExample]],
+    function_name: str,
+) -> ChatPromptTemplate:
+    """Make a prompt template from instructions and examples."""
+    prefix = (
+        "You are a world class expert at converting user questions into database "
+        "queries. Given a question, return a list of database queries optimized to "
+        "retrieve the most relevant results."
+    )
+    if instructions:
+        system_message = ("system", f"{prefix}\n\n{instructions}")
+    else:
+        system_message = ("system", prefix)
+    prompt_components: List[MessageLikeRepresentation] = [system_message]
+    if examples is not None:
+        for example in examples:
+            # TODO: We'll need to refactor this at some point to
+            # support other encoding strategies. The function calling logic here
+            # has some hard-coded assumptions (e.g., name of parameters like `data`).
+            function_call = {
+                "arguments": json.dumps(
+                    {
+                        "data": example.output,
+                    }
+                ),
+                "name": function_name,
+            }
+            prompt_components.extend(
+                [
+                    *example.messages,
+                    AIMessage(
+                        content="", additional_kwargs={"function_call": function_call}
+                    ),
+                ]
+            )
+
+    prompt_components.append(MessagesPlaceholder("input"))
+    return ChatPromptTemplate.from_messages(prompt_components)
+
+
+# PUBLIC API
+
+
+def get_examples_from_query_analyzer(
+    query_analyzer: QueryAnalyzer,
+) -> List[Dict[str, Any]]:
+    """Get examples from a query analyzer."""
+    return [_cast_example_to_dict(example) for example in query_analyzer.examples]
+
+
+@chain
+async def query_analyzer(request: QueryAnalysisRequest) -> QueryAnalysisResponse:
+    """An endpoint to generate queries from a list of messages."""
+    # TODO: Add validation for model context window size
+    schema = request.json_schema
+    try:
+        Draft202012Validator.check_schema(schema)
+    except exceptions.ValidationError as e:
+        raise HTTPException(status_code=422, detail=f"Invalid schema: {e.message}")
+
+    openai_function = convert_json_schema_to_openai_schema(schema)
+    function_name = openai_function["name"]
+    prompt = _make_prompt_template(
+        request.instructions,
+        request.examples,
+        function_name,
+    )
+    runnable = prompt | model.with_structured_output(openai_function)
+
+    return await runnable.ainvoke({"input": request.messages})
diff --git a/backend/server/settings.py b/backend/server/settings.py
index 86dca67..87f9ca1 100644
--- a/backend/server/settings.py
+++ b/backend/server/settings.py
@@ -5,7 +5,7 @@ from langchain_openai import ChatOpenAI
 from sqlalchemy.engine import URL
 
-MODEL_NAME = "gpt-3.5-turbo"
+MODEL_NAME = "gpt-3.5-turbo-0125"
 CHUNK_SIZE = int(4_096 * 0.8)
 
 # Max concurrency for the model.
 MAX_CONCURRENCY = 1
diff --git a/docs/source/notebooks/query_analysis.ipynb b/docs/source/notebooks/query_analysis.ipynb
new file mode 100644
index 0000000..66c7a1f
--- /dev/null
+++ b/docs/source/notebooks/query_analysis.ipynb
@@ -0,0 +1,211 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "7e43ed67-9fbb-4d6c-9a5d-8c4addeb2ed5",
+   "metadata": {},
+   "source": [
+    "# Query analysis"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "id": "b123c960-a0b4-4d5e-b15f-729de23974f5",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "from langserve import RemoteRunnable"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "id": "19dafdeb-63c5-4218-b0f9-fc20754369be",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "from typing import Optional\n",
+    "\n",
+    "from langchain_core.pydantic_v1 import BaseModel, Field\n",
+    "\n",
+    "\n",
+    "class Search(BaseModel):\n",
+    "    \"\"\"Search over a database of tutorial videos about a software library.\"\"\" # noqa\n",
+    "\n",
+    "    query: str = Field(\n",
+    "        ...,\n",
+    "        description=\"Similarity search query applied to video transcripts.\", # noqa\n",
+    "    )\n",
+    "    publish_year: Optional[int] = Field(\n",
+    "        None, description=\"Year video was published\"\n",
+    "    )"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "id": "bf79ef88-b816-46aa-addf-9366b7ebdcaf",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "runnable = RemoteRunnable(\"http://localhost:8000/query_analysis/\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "id": "553d7dbc-9117-4834-83b1-11e28a513170",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'data': [{'query': 'RAG agent tutorial', 'publish_year': 2023}]}"
+      ]
+     },
+     "execution_count": 20,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from langchain_core.messages import HumanMessage\n",
+    "\n",
+    "messages = [HumanMessage(\"RAG agent tutorial from 2023\")]\n",
+    "response = runnable.invoke(\n",
+    "    {\"messages\": messages, \"schema\": Search.schema()}\n",
+    ")\n",
+    "response"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "c70d8d7c-5f0b-4757-92b7-cdd40f351275",
+   "metadata": {},
+   "source": [
+    "Add instructions:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "id": "97294409-6daf-418d-9cbe-f44946245e35",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'data': [{'query': 'RAG agent tutorial', 'publish_year': 2023}]}"
+      ]
+     },
+     "execution_count": 23,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "instructions = (\n",
+    "    \"Always expand acronym RAG to Retrieval Augmented Generation. \"\n",
+    "    \"NEVER INCLUDE RAG IN THE SEARCH\"\n",
+    ")\n",
+    "\n",
+    "response = runnable.invoke(\n",
+    "    {\n",
+    "        \"messages\": messages,\n",
+    "        \"schema\": Search.schema(),\n",
+    "        \"instructions\": instructions,\n",
+    "    }\n",
+    ")\n",
+    "response"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "24b4a123-7841-465b-b43b-1db439c45fa7",
+   "metadata": {},
+   "source": [
+    "Add few-shot examples:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "id": "bae9416d-abd4-4b41-90c2-3144c8566483",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'data': [{'query': 'RAG agent tutorial', 'publish_year': 2023}]}"
+      ]
+     },
+     "execution_count": 24,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "examples = [\n",
+    "    {\n",
+    "        \"messages\": [HumanMessage(\"RAG from scratch series\")],\n",
+    "        \"output\": [\n",
+    "            {\"query\": \"Retrieval Augmented Generation from scratch\"}\n",
+    "        ],\n",
+    "    }\n",
+    "]\n",
+    "\n",
+    "response = runnable.invoke(\n",
+    "    {\n",
+    "        \"messages\": messages,\n",
+    "        \"schema\": Search.schema(),\n",
+    "        \"instructions\": instructions,\n",
+    "        \"examples\": examples,\n",
+    "    }\n",
+    ")\n",
+    "response"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "d8619550",
+   "metadata": {},
+   "source": [
+    "Trace: https://smith.langchain.com/public/7a95a5af-7f89-4312-90b8-88cc408de1a7/r"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "extract-venv",
+   "language": "python",
+   "name": "extract-venv"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.4"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
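Beyond the notebook, the new route can be smoke-tested over plain HTTP once the server is running as described in the README. The sketch below is illustrative rather than part of the patch: it assumes LangServe's convention of serving the runnable at `POST /query_analysis/invoke` with the request body nested under an `input` key, and it hand-writes a JSON schema that mirrors what `Search.schema()` emits in the notebook.

```python
import requests

# Illustrative smoke test; assumes `server.main` is listening on localhost:8000.
payload = {
    "input": {
        "messages": [{"type": "human", "content": "RAG agent tutorial from 2023"}],
        "schema": {  # hand-written stand-in for Search.schema()
            "title": "Search",
            "description": "Search over a database of tutorial videos.",
            "type": "object",
            "properties": {
                "query": {"type": "string"},
                "publish_year": {"type": "integer"},
            },
            "required": ["query"],
        },
    }
}

response = requests.post("http://localhost:8000/query_analysis/invoke", json=payload)
response.raise_for_status()
print(response.json())  # LangServe wraps the result, e.g. {"output": {"data": [...]}, ...}
```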
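For reviewers unfamiliar with the few-shot encoding in `_make_prompt_template`: each stored example is rendered as its original messages followed by an empty `AIMessage` whose `function_call` carries the expected output, JSON-dumped under the hard-coded `data` key that the TODO in the patch calls out. A minimal, self-contained sketch of the prompt this produces, using the same `langchain_core` APIs the module imports:

```python
import json

from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

# One stored example: input messages plus the queries the model should emit.
example_messages = [HumanMessage("RAG from scratch series")]
example_output = [{"query": "Retrieval Augmented Generation from scratch"}]

# Expected output is encoded as a fake prior function call by the assistant.
function_call = {
    "name": "query_analyzer",
    "arguments": json.dumps({"data": example_output}),
}

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a world class expert at converting user questions into database queries."),
        *example_messages,
        AIMessage(content="", additional_kwargs={"function_call": function_call}),
        MessagesPlaceholder("input"),
    ]
)

# The live request's messages are injected via the "input" placeholder.
print(prompt.invoke({"input": [HumanMessage("RAG agent tutorial from 2023")]}))
```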
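Finally, `_deduplicate` compares response items by their canonical JSON serialization (`sort_keys=True`), so two queries that differ only in key order collapse into one. The same logic, restated as a standalone snippet:

```python
import json
from typing import Any, Dict, List


def deduplicate(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Keep the first occurrence of each item, comparing canonical JSON."""
    unique, seen = [], set()
    for item in items:
        key = json.dumps(item, sort_keys=True)  # key order no longer matters
        if key not in seen:
            seen.add(key)
            unique.append(item)
    return unique


# {"a": 1, "b": 2} and {"b": 2, "a": 1} serialize identically with sort_keys=True.
assert deduplicate([{"a": 1, "b": 2}, {"b": 2, "a": 1}]) == [{"a": 1, "b": 2}]
```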