Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Retriever: MultiQueryRetriever #177

Merged
merged 8 commits into from
Sep 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions ix/chains/fixture_src/retriever.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from langchain.retrievers import MultiQueryRetriever
from langchain.vectorstores.base import VectorStoreRetriever

from ix.chains.fixture_src.targets import (
RETRIEVER_TARGET,
LLM_TARGET,
PROMPT_TARGET,
)
from ix.api.components.types import NodeTypeField
from ix.chains.fixture_src.targets import VECTORSTORE_TARGET

Expand All @@ -26,8 +32,24 @@
),
}

RETRIEVERS = [
VECTORSTORE_RETRIEVER,
]

__all__ = ["RETRIEVERS", "VECTORSTORE_RETRIEVER_CLASS_PATH"]
MULTI_QUERY_RETRIEVER_CLASS_PATH = (
"langchain.retrievers.multi_query.MultiQueryRetriever.from_llm"
)
MULTI_QUERY_RETRIEVER = {
"class_path": MULTI_QUERY_RETRIEVER_CLASS_PATH,
"type": "retriever",
"name": "MultiQueryRetriever",
"description": "MultiQueryRetriever",
"connectors": [RETRIEVER_TARGET, LLM_TARGET, PROMPT_TARGET],
"fields": [] + NodeTypeField.get_fields(MultiQueryRetriever, include=["parse_key"]),
}


RETRIEVERS = [VECTORSTORE_RETRIEVER, MULTI_QUERY_RETRIEVER]

__all__ = [
"RETRIEVERS",
"VECTORSTORE_RETRIEVER_CLASS_PATH",
"MULTI_QUERY_RETRIEVER_CLASS_PATH",
]
41 changes: 41 additions & 0 deletions ix/chains/fixtures/node_types.json
Original file line number Diff line number Diff line change
Expand Up @@ -2378,6 +2378,47 @@
"config_schema": {}
}
},
{
"model": "chains.nodetype",
"pk": "7b954f7c-4449-46b4-9013-382ce3c94c2a",
"fields": {
"name": "MultiQueryRetriever",
"description": "MultiQueryRetriever",
"class_path": "langchain.retrievers.multi_query.MultiQueryRetriever.from_llm",
"type": "retriever",
"display_type": "node",
"connectors": [
{
"key": "retriever",
"type": "target",
"as_type": "retriever",
"required": true,
"source_type": [
"retriever",
"vectorstore"
]
},
{
"key": "llm",
"type": "target",
"required": true,
"source_type": "llm"
},
{
"key": "prompt",
"type": "target",
"source_type": "prompt"
}
],
"fields": [],
"child_field": null,
"config_schema": {
"type": "object",
"required": [],
"properties": {}
}
}
},
{
"model": "chains.nodetype",
"pk": "7d49f24a-18e2-4db9-8be3-bbbe6aa25440",
Expand Down
14 changes: 13 additions & 1 deletion ix/chains/loaders/retriever.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from copy import deepcopy
from typing import List

from asgiref.sync import sync_to_async
from langchain.schema import BaseRetriever
from langchain.vectorstores import VectorStore

Expand All @@ -11,6 +12,17 @@
from ix.utils.importlib import import_class


async def async_aget_relevant_documents(self, *args, **kwargs):
"""Async wrapper for BaseRetriever._get_relevant_documents"""
return await sync_to_async(self._get_relevant_documents)(*args, **kwargs)


# HAX: monkeypatch asyncio support into BaseRetriever. This is a gigantic hack, but
# it's the easiest way to get support for all retrievers without having to
# modify langchain or implement lots of custom wrappers.
setattr(BaseRetriever, "_aget_relevant_documents", async_aget_relevant_documents)


def load_retriever_property(
node_group: List[ChainNode], context: IxContext
) -> BaseRetriever:
Expand All @@ -25,7 +37,7 @@ def load_retriever_property(
node = node_group[0]
component_class = import_class(node.class_path)

if issubclass(component_class, VectorStore):
if isinstance(component_class, type) and issubclass(component_class, VectorStore):
# unpack retriever fields from vectorstore config
config = deepcopy(node.config)
retriever_fields = get_vectorstore_retriever_fieldnames(node.class_path)
Expand Down
2 changes: 1 addition & 1 deletion ix/chains/management/commands/import_langchain.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from django.core.management.base import BaseCommand

from ix.api.components.types import NodeTypeField
from ix.api.chains.types import NodeType as NodeTypePydantic
from ix.api.components.types import NodeType as NodeTypePydantic
from ix.chains.fixture_src.agent_interaction import AGENT_INTERACTION_CHAINS

from ix.chains.fixture_src.agents import AGENTS
Expand Down
Loading