Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add OpenAPI chain with OpenAI functions #97

Merged
merged 9 commits into from
Jul 8, 2023
28 changes: 25 additions & 3 deletions ix/chains/asyncio.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from typing import Optional
from typing import Optional, Any, Dict
from asgiref.sync import sync_to_async
from langchain.callbacks.manager import AsyncCallbackManagerForToolRun
from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForChainRun,
)


class SyncToAsync:
class SyncToAsyncRun:
"""
Mixin to convert a chain or tool to run asynchronously by using
sync_to_async to convert the run method to a coroutine.
Expand All @@ -20,3 +23,22 @@ async def _arun(
"""Use the tool asynchronously."""
result = await sync_to_async(self._run)(query, run_manager=run_manager)
return result


class SyncToAsyncCall:
"""
Mixin to convert a chain or tool to run asynchronously by using
sync_to_async to convert the _call method to a coroutine.

This doesn't provide full async support, but it does allow for
the chain/tool to work.
"""

async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> str:
"""Use the tool asynchronously."""
result = await sync_to_async(self._call)(inputs, run_manager=run_manager)
return result
22 changes: 22 additions & 0 deletions ix/chains/fixture_src/openai_functions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from ix.chains.fixture_src.common import VERBOSE
from ix.chains.fixture_src.targets import LLM_TARGET, MEMORY_TARGET, PROMPT_TARGET

FUNCTION_SCHEMA = {
"class_path": "ix.chains.functions.FunctionSchema",
"type": "tool",
Expand Down Expand Up @@ -40,3 +43,22 @@
"name": "function_call",
"type": "string",
}

OPENAPI_CHAIN = {
"class_path": "ix.chains.openapi.get_openapi_chain_async",
"type": "chain",
"name": "OpenAPI with OpenAI Functions",
"description": "Chain that uses OpenAI Functions to call OpenAPI endpoints.",
"connectors": [LLM_TARGET, MEMORY_TARGET, PROMPT_TARGET],
"fields": [
VERBOSE,
{
"name": "spec",
"type": "string",
"label": "OpenAPI URL",
"style": {
"width": "500px",
},
},
],
}
8 changes: 7 additions & 1 deletion ix/chains/management/commands/import_langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

from ix.chains.fixture_src.agents import AGENTS
from ix.chains.fixture_src.artifacts import ARTIFACT_MEMORY, SAVE_ARTIFACT
from ix.chains.fixture_src.chains import LLM_CHAIN, LLM_TOOL_CHAIN, LLM_REPLY
from ix.chains.fixture_src.chains import (
LLM_CHAIN,
LLM_TOOL_CHAIN,
LLM_REPLY,
)
from ix.chains.fixture_src.chat_memory_backend import (
FILESYSTEM_MEMORY_BACKEND,
REDIS_MEMORY_BACKEND,
Expand All @@ -26,6 +30,7 @@
from ix.chains.fixture_src.openai_functions import (
FUNCTION_SCHEMA,
FUNCTION_OUTPUT_PARSER,
OPENAPI_CHAIN,
)
from ix.chains.fixture_src.prompts import CHAT_PROMPT_TEMPLATE
from ix.chains.fixture_src.routing import SEQUENCE, MAP_SUBCHAIN
Expand Down Expand Up @@ -77,6 +82,7 @@
[
FUNCTION_SCHEMA,
FUNCTION_OUTPUT_PARSER,
OPENAPI_CHAIN,
]
)

Expand Down
34 changes: 34 additions & 0 deletions ix/chains/openapi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import logging
from unittest import mock
from langchain.chains.openai_functions.openapi import (
get_openapi_chain,
SimpleRequestChain,
)
from langchain.prompts import ChatPromptTemplate

from ix.chains.asyncio import SyncToAsyncCall

logger = logging.getLogger(__name__)


class AsyncSimpleRequestChainRun(SyncToAsyncCall, SimpleRequestChain):
pass


def get_openapi_chain_async(**kwargs):
"""
Extremely hacky way of injecting asyncio support into LangChain's function.
Done within this wrapper function to limit the scope of the patch.
"""
with mock.patch(
"langchain.chains.openai_functions.openapi.SimpleRequestChain",
new=AsyncSimpleRequestChainRun,
):
# modified to use `user_input` for consistency with other chains
if "prompt" not in kwargs:
kwargs["prompt"] = ChatPromptTemplate.from_template(
"Use the provided API's to respond to this user query:\n\n{user_input}"
)

chain = get_openapi_chain(**kwargs)
return chain
4 changes: 2 additions & 2 deletions ix/tools/arxiv.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from langchain import ArxivAPIWrapper
from langchain.tools import BaseTool, ArxivQueryRun

from ix.chains.asyncio import SyncToAsync
from ix.chains.asyncio import SyncToAsyncRun
from ix.chains.loaders.tools import extract_tool_kwargs
from typing import Any


class AsyncArxivQueryRun(SyncToAsync, ArxivQueryRun):
class AsyncArxivQueryRun(SyncToAsyncRun, ArxivQueryRun):
pass


Expand Down
4 changes: 2 additions & 2 deletions ix/tools/bing.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from langchain.tools import BaseTool, BingSearchRun
from langchain.utilities import BingSearchAPIWrapper

from ix.chains.asyncio import SyncToAsync
from ix.chains.asyncio import SyncToAsyncRun
from ix.chains.loaders.tools import extract_tool_kwargs


class AsyncBingSearchRun(SyncToAsync, BingSearchRun):
class AsyncBingSearchRun(SyncToAsyncRun, BingSearchRun):
pass


Expand Down
4 changes: 2 additions & 2 deletions ix/tools/duckduckgo.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from langchain.tools import DuckDuckGoSearchRun, BaseTool
from langchain.utilities import DuckDuckGoSearchAPIWrapper

from ix.chains.asyncio import SyncToAsync
from ix.chains.asyncio import SyncToAsyncRun
from ix.chains.loaders.tools import extract_tool_kwargs


class AsyncDuckDuckGoSearchRun(SyncToAsync, DuckDuckGoSearchRun):
class AsyncDuckDuckGoSearchRun(SyncToAsyncRun, DuckDuckGoSearchRun):
pass


Expand Down
4 changes: 2 additions & 2 deletions ix/tools/google.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from ix.chains.asyncio import SyncToAsync
from ix.chains.asyncio import SyncToAsyncRun
from ix.chains.loaders.tools import extract_tool_kwargs
from typing import Any

Expand All @@ -23,7 +23,7 @@ def get_google_serper_results_json(**kwargs: Any) -> BaseTool:
return GoogleSerperResults(api_wrapper=wrapper, **tool_kwargs)


class AsyncGoogleSearchResults(SyncToAsync, GoogleSearchResults):
class AsyncGoogleSearchResults(SyncToAsyncRun, GoogleSearchResults):
pass


Expand Down
4 changes: 2 additions & 2 deletions ix/tools/graphql.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from langchain.tools import BaseTool, BaseGraphQLTool
from langchain.utilities import GraphQLAPIWrapper

from ix.chains.asyncio import SyncToAsync
from ix.chains.asyncio import SyncToAsyncRun
from ix.chains.loaders.tools import extract_tool_kwargs


class AsyncGraphQLTool(SyncToAsync, BaseGraphQLTool):
class AsyncGraphQLTool(SyncToAsyncRun, BaseGraphQLTool):
pass


Expand Down
4 changes: 2 additions & 2 deletions ix/tools/pubmed.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from langchain.tools import PubmedQueryRun, BaseTool
from langchain.utilities import PubMedAPIWrapper

from ix.chains.asyncio import SyncToAsync
from ix.chains.asyncio import SyncToAsyncRun
from ix.chains.loaders.tools import extract_tool_kwargs
from typing import Any


class AsyncPubmedQueryRun(SyncToAsync, PubmedQueryRun):
class AsyncPubmedQueryRun(SyncToAsyncRun, PubmedQueryRun):
pass


Expand Down
4 changes: 2 additions & 2 deletions ix/tools/wikipedia.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
from langchain import WikipediaAPIWrapper
from langchain.tools import WikipediaQueryRun, BaseTool

from ix.chains.asyncio import SyncToAsync
from ix.chains.asyncio import SyncToAsyncRun
from ix.chains.loaders.tools import extract_tool_kwargs


class AsyncWikipediaQueryRun(SyncToAsync, WikipediaQueryRun):
class AsyncWikipediaQueryRun(SyncToAsyncRun, WikipediaQueryRun):
pass


Expand Down
4 changes: 2 additions & 2 deletions ix/tools/wolfram_alpha.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from langchain import WolframAlphaAPIWrapper
from langchain.tools import BaseTool, WolframAlphaQueryRun

from ix.chains.asyncio import SyncToAsync
from ix.chains.asyncio import SyncToAsyncRun
from ix.chains.loaders.tools import extract_tool_kwargs
from typing import Any


class AsyncWolframAlphaQueryRun(SyncToAsync, WolframAlphaQueryRun):
class AsyncWolframAlphaQueryRun(SyncToAsyncRun, WolframAlphaQueryRun):
pass


Expand Down