Skip to content

Commit 6d7bf78

Browse files
namkwangwoo authored and olgamurraft committed
core[minor]: Support asynchronous in InMemoryVectorStore (langchain-ai#24472)
### Description
* Support asynchronous operation in InMemoryVectorStore.
* Since embeddings may be called asynchronously, ensure that both the asynchronous and synchronous functions operate correctly.
1 parent c5b12fd commit 6d7bf78

File tree

2 files changed

+164
-21
lines changed

2 files changed

+164
-21
lines changed

libs/core/langchain_core/vectorstores/in_memory.py

+54-12
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
Any,
99
Callable,
1010
Dict,
11-
Iterable,
1211
List,
1312
Optional,
1413
Sequence,
@@ -74,6 +73,27 @@ def upsert(self, items: Sequence[Document], /, **kwargs: Any) -> UpsertResponse:
7473
"failed": [],
7574
}
7675

76+
async def aupsert(
    self, items: Sequence[Document], /, **kwargs: Any
) -> UpsertResponse:
    """Asynchronously embed the given documents and write them to the store.

    The page content of every item is embedded in a single call to
    ``self.embedding.aembed_documents``. Each document is then written to
    ``self.store`` under its id; an entry already stored under that id is
    overwritten.

    Args:
        items: Documents to add or update. A document whose ``id`` is falsy
            is assigned a freshly generated UUID4 string.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        A response whose ``"succeeded"`` list holds the ids written, in input
        order, and whose ``"failed"`` list is always empty.
    """
    vectors = await self.embedding.aembed_documents(
        [item.page_content for item in items]
    )
    ids = []
    for item, vector in zip(items, vectors):
        # Keep a caller-supplied id; otherwise mint a new one.
        doc_id = item.id if item.id else str(uuid.uuid4())
        ids.append(doc_id)
        self.store[doc_id] = {
            "id": doc_id,
            "vector": vector,
            "text": item.page_content,
            "metadata": item.metadata,
        }
    return {
        "succeeded": ids,
        "failed": [],
    }
96+
7797
def get_by_ids(self, ids: Sequence[str], /) -> List[Document]:
7898
"""Get documents by their ids.
7999
@@ -108,14 +128,6 @@ async def aget_by_ids(self, ids: Sequence[str], /) -> List[Document]:
108128
"""
109129
return self.get_by_ids(ids)
110130

111-
async def aadd_texts(
112-
self,
113-
texts: Iterable[str],
114-
metadatas: Optional[List[dict]] = None,
115-
**kwargs: Any,
116-
) -> List[str]:
117-
return self.add_texts(texts, metadatas, **kwargs)
118-
119131
def _similarity_search_with_score_by_vector(
120132
self,
121133
embedding: List[float],
@@ -172,7 +184,13 @@ def similarity_search_with_score(
172184
async def asimilarity_search_with_score(
173185
self, query: str, k: int = 4, **kwargs: Any
174186
) -> List[Tuple[Document, float]]:
175-
return self.similarity_search_with_score(query, k, **kwargs)
187+
embedding = await self.embedding.aembed_query(query)
188+
docs = self.similarity_search_with_score_by_vector(
189+
embedding,
190+
k,
191+
**kwargs,
192+
)
193+
return docs
176194

177195
def similarity_search_by_vector(
178196
self,
@@ -200,7 +218,10 @@ def similarity_search(
200218
async def asimilarity_search(
201219
self, query: str, k: int = 4, **kwargs: Any
202220
) -> List[Document]:
203-
return self.similarity_search(query, k, **kwargs)
221+
return [
222+
doc
223+
for doc, _ in await self.asimilarity_search_with_score(query, k, **kwargs)
224+
]
204225

205226
def max_marginal_relevance_search_by_vector(
206227
self,
@@ -249,6 +270,23 @@ def max_marginal_relevance_search(
249270
**kwargs,
250271
)
251272

273+
async def amax_marginal_relevance_search(
    self,
    query: str,
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    **kwargs: Any,
) -> List[Document]:
    """Asynchronously run a maximal-marginal-relevance search for a query.

    Only the query embedding is awaited (``self.embedding.aembed_query``);
    the selection itself is delegated to the synchronous
    ``max_marginal_relevance_search_by_vector``.

    Args:
        query: Text to embed and search for.
        k: Forwarded unchanged to the by-vector search.
        fetch_k: Forwarded unchanged to the by-vector search.
        lambda_mult: Forwarded unchanged to the by-vector search.
        **kwargs: Extra keyword arguments forwarded to the by-vector search.

    Returns:
        The documents returned by ``max_marginal_relevance_search_by_vector``.
    """
    embedding_vector = await self.embedding.aembed_query(query)
    return self.max_marginal_relevance_search_by_vector(
        embedding_vector,
        k,
        fetch_k,
        lambda_mult=lambda_mult,
        **kwargs,
    )
289+
252290
@classmethod
253291
def from_texts(
254292
cls,
@@ -271,7 +309,11 @@ async def afrom_texts(
271309
metadatas: Optional[List[dict]] = None,
272310
**kwargs: Any,
273311
) -> "InMemoryVectorStore":
274-
return cls.from_texts(texts, embedding, metadatas, **kwargs)
312+
store = cls(
313+
embedding=embedding,
314+
)
315+
await store.aadd_texts(texts=texts, metadatas=metadatas, **kwargs)
316+
return store
275317

276318
@classmethod
277319
def load(

libs/core/tests/unit_tests/vectorstores/test_in_memory.py

+110-9
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from pathlib import Path
2+
from unittest.mock import AsyncMock, Mock
23

34
import pytest
45
from langchain_standard_tests.integration_tests.vectorstores import (
@@ -24,43 +25,65 @@ async def vectorstore(self) -> InMemoryVectorStore:
2425
return InMemoryVectorStore(embedding=self.get_embeddings())
2526

2627

27-
async def test_inmemory() -> None:
28-
"""Test end to end construction and search."""
28+
async def test_inmemory_similarity_search() -> None:
29+
"""Test end to end similarity search."""
2930
store = await InMemoryVectorStore.afrom_texts(
30-
["foo", "bar", "baz"], DeterministicFakeEmbedding(size=6)
31+
["foo", "bar", "baz"], DeterministicFakeEmbedding(size=3)
3132
)
32-
output = await store.asimilarity_search("foo", k=1)
33+
34+
# Check sync version
35+
output = store.similarity_search("foo", k=1)
3336
assert output == [Document(page_content="foo", id=AnyStr())]
3437

38+
# Check async version
3539
output = await store.asimilarity_search("bar", k=2)
3640
assert output == [
3741
Document(page_content="bar", id=AnyStr()),
3842
Document(page_content="baz", id=AnyStr()),
3943
]
4044

41-
output2 = await store.asimilarity_search_with_score("bar", k=2)
42-
assert output2[0][1] > output2[1][1]
45+
46+
async def test_inmemory_similarity_search_with_score() -> None:
    """Test end to end similarity search with score."""
    store = await InMemoryVectorStore.afrom_texts(
        ["foo", "bar", "baz"], DeterministicFakeEmbedding(size=3)
    )

    # Check sync version
    output = store.similarity_search_with_score("foo", k=1)
    assert output[0][0].page_content == "foo"

    # Check async version: the first hit must score higher than the second.
    output = await store.asimilarity_search_with_score("bar", k=2)
    assert output[0][1] > output[1][1]
4357

4458

4559
async def test_add_by_ids() -> None:
60+
"""Test add texts with ids."""
4661
vectorstore = InMemoryVectorStore(embedding=DeterministicFakeEmbedding(size=6))
4762

4863
# Check sync version
4964
ids1 = vectorstore.add_texts(["foo", "bar", "baz"], ids=["1", "2", "3"])
5065
assert ids1 == ["1", "2", "3"]
5166
assert sorted(vectorstore.store.keys()) == ["1", "2", "3"]
5267

68+
# Check async version
5369
ids2 = await vectorstore.aadd_texts(["foo", "bar", "baz"], ids=["4", "5", "6"])
5470
assert ids2 == ["4", "5", "6"]
5571
assert sorted(vectorstore.store.keys()) == ["1", "2", "3", "4", "5", "6"]
5672

5773

5874
async def test_inmemory_mmr() -> None:
75+
"""Test MMR search"""
5976
texts = ["foo", "foo", "fou", "foy"]
6077
docsearch = await InMemoryVectorStore.afrom_texts(
6178
texts, DeterministicFakeEmbedding(size=6)
6279
)
6380
# make sure we can k > docstore size
81+
output = docsearch.max_marginal_relevance_search("foo", k=10, lambda_mult=0.1)
82+
assert len(output) == len(texts)
83+
assert output[0] == Document(page_content="foo", id=AnyStr())
84+
assert output[1] == Document(page_content="foy", id=AnyStr())
85+
86+
# Check async version
6487
output = await docsearch.amax_marginal_relevance_search(
6588
"foo", k=10, lambda_mult=0.1
6689
)
@@ -85,13 +108,91 @@ async def test_inmemory_dump_load(tmp_path: Path) -> None:
85108

86109

87110
async def test_inmemory_filter() -> None:
88-
"""Test end to end construction and search."""
111+
"""Test end to end construction and search with filter."""
89112
store = await InMemoryVectorStore.afrom_texts(
90113
["foo", "bar"],
91114
DeterministicFakeEmbedding(size=6),
92115
[{"id": 1}, {"id": 2}],
93116
)
117+
118+
# Check sync version
119+
output = store.similarity_search("fee", filter=lambda doc: doc.metadata["id"] == 1)
120+
assert output == [Document(page_content="foo", metadata={"id": 1}, id=AnyStr())]
121+
122+
# filter with not stored document id
94123
output = await store.asimilarity_search(
95-
"baz", filter=lambda doc: doc.metadata["id"] == 1
124+
"baz", filter=lambda doc: doc.metadata["id"] == 3
96125
)
97-
assert output == [Document(page_content="foo", metadata={"id": 1}, id=AnyStr())]
126+
assert output == []
127+
128+
129+
async def test_inmemory_upsert() -> None:
    """Test upsert documents."""
    embedding = DeterministicFakeEmbedding(size=2)
    store = InMemoryVectorStore(embedding=embedding)

    # Check sync version
    store.upsert([Document(page_content="foo", id="1")])
    assert sorted(store.store.keys()) == ["1"]

    # Check async version
    await store.aupsert([Document(page_content="bar", id="2")])
    assert sorted(store.store.keys()) == ["1", "2"]

    # update existing document
    await store.aupsert(
        [Document(page_content="baz", id="2", metadata={"metadata": "value"})]
    )
    item = store.store["2"]

    # The stored entry must reflect the new text, vector and metadata.
    baz_vector = embedding.embed_query("baz")
    assert item == {
        "id": "2",
        "text": "baz",
        "vector": baz_vector,
        "metadata": {"metadata": "value"},
    }
155+
156+
157+
async def test_inmemory_get_by_ids() -> None:
    """Test get by ids."""

    store = InMemoryVectorStore(embedding=DeterministicFakeEmbedding(size=3))

    store.upsert(
        [
            Document(page_content="foo", id="1", metadata={"metadata": "value"}),
            Document(page_content="bar", id="2"),
            Document(page_content="baz", id="3"),
        ],
    )

    # Check sync version
    output = store.get_by_ids(["1", "2"])
    assert output == [
        Document(page_content="foo", id="1", metadata={"metadata": "value"}),
        Document(page_content="bar", id="2"),
    ]

    # Check async version; id "5" was never stored, so it is expected to be
    # absent from the result.
    output = await store.aget_by_ids(["1", "3", "5"])
    assert output == [
        Document(page_content="foo", id="1", metadata={"metadata": "value"}),
        Document(page_content="baz", id="3"),
    ]
183+
184+
185+
async def test_inmemory_call_embeddings_async() -> None:
    """Test that the async store methods await the async embedding methods."""
    embeddings_mock = Mock(
        wraps=DeterministicFakeEmbedding(size=3),
        aembed_documents=AsyncMock(),
        aembed_query=AsyncMock(),
    )
    store = InMemoryVectorStore(embedding=embeddings_mock)

    # Pass a list of texts: a bare string would be iterated character by
    # character rather than treated as one text.
    await store.aadd_texts(["foo"])
    await store.asimilarity_search("foo", k=1)

    # Ensure the async embedding function is called
    assert embeddings_mock.aembed_documents.await_count == 1
    assert embeddings_mock.aembed_query.await_count == 1

0 commit comments

Comments
 (0)