Skip to content

Commit

Permalink
community: fixed bug in GraphVectorStoreRetriever (langchain-ai#27846)
Browse files Browse the repository at this point in the history
Description:

This fixes an issue that mistakenly created in
langchain-ai#27253. The issue
currently exists only in `langchain-community==0.3.4`.

Test cases were added to prevent this issue in the future.

Co-authored-by: Erick Friis <[email protected]>
  • Loading branch information
2 people authored and yanomaly committed Nov 8, 2024
1 parent 0dd871e commit 1bbbf4f
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 7 deletions.
21 changes: 14 additions & 7 deletions libs/community/langchain_community/graph_vectorstores/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
ClassVar,
Optional,
Sequence,
cast,
)

from langchain_core._api import beta
Expand Down Expand Up @@ -701,7 +702,7 @@ def as_retriever(self, **kwargs: Any) -> GraphVectorStoreRetriever:
docsearch.as_retriever(search_kwargs={'k': 1})
"""
return GraphVectorStoreRetriever(vector_store=self, **kwargs)
return GraphVectorStoreRetriever(vectorstore=self, **kwargs)


@beta(message="Added in version 0.3.1 of langchain_community. API subject to change.")
Expand Down Expand Up @@ -837,8 +838,8 @@ class GraphVectorStoreRetriever(VectorStoreRetriever):
retriever = graph_vectorstore.as_retriever(search_kwargs={"score_threshold": 0.5})
""" # noqa: E501

vector_store: GraphVectorStore
"""GraphVectorStore to use for retrieval."""
vectorstore: VectorStore
"""VectorStore to use for retrieval."""
search_type: str = "traversal"
"""Type of search to perform. Defaults to "traversal"."""
allowed_search_types: ClassVar[Collection[str]] = (
Expand All @@ -849,14 +850,20 @@ class GraphVectorStoreRetriever(VectorStoreRetriever):
"mmr_traversal",
)

@property
def graph_vectorstore(self) -> GraphVectorStore:
return cast(GraphVectorStore, self.vectorstore)

def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> list[Document]:
if self.search_type == "traversal":
return list(self.vector_store.traversal_search(query, **self.search_kwargs))
return list(
self.graph_vectorstore.traversal_search(query, **self.search_kwargs)
)
elif self.search_type == "mmr_traversal":
return list(
self.vector_store.mmr_traversal_search(query, **self.search_kwargs)
self.graph_vectorstore.mmr_traversal_search(query, **self.search_kwargs)
)
else:
return super()._get_relevant_documents(query, run_manager=run_manager)
Expand All @@ -867,14 +874,14 @@ async def _aget_relevant_documents(
if self.search_type == "traversal":
return [
doc
async for doc in self.vector_store.atraversal_search(
async for doc in self.graph_vectorstore.atraversal_search(
query, **self.search_kwargs
)
]
elif self.search_type == "mmr_traversal":
return [
doc
async for doc in self.vector_store.ammr_traversal_search(
async for doc in self.graph_vectorstore.ammr_traversal_search(
query, **self.search_kwargs
)
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,17 @@ def test_gvs_traversal_search_sync(
ts_labels = {doc.metadata["label"] for doc in ts_response}
assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"}

# verify the same works as a retriever
retriever = g_store.as_retriever(
search_type="traversal", search_kwargs={"k": 2, "depth": 2}
)

ts_labels = {
doc.metadata["label"]
for doc in retriever.get_relevant_documents(query="[2, 10]")
}
assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"}

async def test_gvs_traversal_search_async(
self,
populated_graph_vector_store_d2: CassandraGraphVectorStore,
Expand All @@ -453,6 +464,17 @@ async def test_gvs_traversal_search_async(
# so ordering is not deterministic:
assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"}

# verify the same works as a retriever
retriever = g_store.as_retriever(
search_type="traversal", search_kwargs={"k": 2, "depth": 2}
)

ts_labels = {
doc.metadata["label"]
for doc in await retriever.aget_relevant_documents(query="[2, 10]")
}
assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"}

def test_gvs_mmr_traversal_search_sync(
self,
populated_graph_vector_store_d2: CassandraGraphVectorStore,
Expand Down

0 comments on commit 1bbbf4f

Please sign in to comment.