Skip to content

Commit dc9b75e

Browse files
committed
Add AstraGraphVectorStoreComponent
1 parent e93397d commit dc9b75e

File tree

1 file changed

+305
-0
lines changed

1 file changed

+305
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,305 @@
1+
import os
2+
3+
import orjson
4+
from astrapy.admin import parse_api_endpoint
5+
from loguru import logger
6+
7+
from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
8+
from langflow.helpers import docs_to_data
9+
from langflow.inputs import DictInput, FloatInput
10+
from langflow.io import (
11+
BoolInput,
12+
DataInput,
13+
DropdownInput,
14+
HandleInput,
15+
IntInput,
16+
MultilineInput,
17+
SecretStrInput,
18+
StrInput,
19+
)
20+
from langflow.schema import Data
21+
22+
23+
class AstraGraphVectorStoreComponent(LCVectorStoreComponent):
    """Langflow vector-store component backed by langchain-astradb's AstraDBGraphVectorStore.

    Builds (and optionally populates) an Astra DB graph vector store from the
    configured connection inputs, and exposes similarity / MMR / score-threshold
    search over it via ``search_documents``.
    """

    display_name: str = "Astra DB Graph"
    description: str = "Implementation of Graph Vector Store using Astra DB"
    documentation: str = "https://python.langchain.com/api_reference/astradb/graph_vectorstores/langchain_astradb.graph_vectorstores.AstraDBGraphVectorStore.html"
    name = "AstraDBGraph"
    icon: str = "AstraDB"

    inputs = [
        SecretStrInput(
            name="token",
            display_name="Astra DB Application Token",
            info="Authentication token for accessing Astra DB.",
            value="ASTRA_DB_APPLICATION_TOKEN",
            required=True,
            advanced=os.getenv("ASTRA_ENHANCED", "false").lower() == "true",
        ),
        SecretStrInput(
            name="api_endpoint",
            display_name="Database" if os.getenv("ASTRA_ENHANCED", "false").lower() == "true" else "API Endpoint",
            info="API endpoint URL for the Astra DB service.",
            value="ASTRA_DB_API_ENDPOINT",
            required=True,
        ),
        StrInput(
            name="collection_name",
            display_name="Collection Name",
            info="The name of the collection within Astra DB where the vectors will be stored.",
            required=True,
        ),
        StrInput(
            name="link_to_metadata_key",
            display_name="Outgoing links metadata key",
            info="Metadata key used for outgoing links.",
            advanced=True,
        ),
        StrInput(
            name="link_from_metadata_key",
            display_name="Incoming links metadata key",
            info="Metadata key used for incoming links.",
            advanced=True,
        ),
        # NOTE: the original declared this "namespace" input twice; the duplicate
        # (between "ingest_data" and "embedding") has been removed.
        StrInput(
            name="namespace",
            display_name="Namespace",
            info="Optional namespace within Astra DB to use for the collection.",
            advanced=True,
        ),
        MultilineInput(
            name="search_input",
            display_name="Search Input",
        ),
        DataInput(
            name="ingest_data",
            display_name="Ingest Data",
            is_list=True,
        ),
        HandleInput(
            name="embedding",
            display_name="Embedding Model",
            input_types=["Embeddings"],
            info="Embedding model.",
            required=True,
        ),
        DropdownInput(
            name="metric",
            display_name="Metric",
            info="Optional distance metric for vector comparisons in the vector store.",
            options=["cosine", "dot_product", "euclidean"],
            value="cosine",
            advanced=True,
        ),
        IntInput(
            name="batch_size",
            display_name="Batch Size",
            info="Optional number of data to process in a single batch.",
            advanced=True,
        ),
        IntInput(
            name="bulk_insert_batch_concurrency",
            display_name="Bulk Insert Batch Concurrency",
            info="Optional concurrency level for bulk insert operations.",
            advanced=True,
        ),
        IntInput(
            name="bulk_insert_overwrite_concurrency",
            display_name="Bulk Insert Overwrite Concurrency",
            info="Optional concurrency level for bulk insert operations that overwrite existing data.",
            advanced=True,
        ),
        IntInput(
            name="bulk_delete_concurrency",
            display_name="Bulk Delete Concurrency",
            info="Optional concurrency level for bulk delete operations.",
            advanced=True,
        ),
        DropdownInput(
            name="setup_mode",
            display_name="Setup Mode",
            info="Configuration mode for setting up the vector store, with options like 'Sync', or 'Off'.",
            options=["Sync", "Off"],
            advanced=True,
            value="Sync",
        ),
        BoolInput(
            name="pre_delete_collection",
            display_name="Pre Delete Collection",
            info="Boolean flag to determine whether to delete the collection before creating a new one.",
            advanced=True,
            value=False,
        ),
        StrInput(
            name="metadata_indexing_include",
            display_name="Metadata Indexing Include",
            info="Optional list of metadata fields to include in the indexing.",
            advanced=True,
            is_list=True,
        ),
        StrInput(
            name="metadata_indexing_exclude",
            display_name="Metadata Indexing Exclude",
            info="Optional list of metadata fields to exclude from the indexing.",
            advanced=True,
            is_list=True,
        ),
        StrInput(
            name="collection_indexing_policy",
            display_name="Collection Indexing Policy",
            info='Optional JSON string for the "indexing" field of the collection. '
            "See https://docs.datastax.com/en/astra-db-serverless/api-reference/collections.html#the-indexing-option",
            advanced=True,
        ),
        IntInput(
            name="number_of_results",
            display_name="Number of Results",
            info="Number of results to return.",
            advanced=True,
            value=4,
        ),
        DropdownInput(
            name="search_type",
            display_name="Search Type",
            info="Search type to use",
            options=["Similarity", "Similarity with score threshold", "MMR (Max Marginal Relevance)"],
            value="Similarity",
            advanced=True,
        ),
        FloatInput(
            name="search_score_threshold",
            display_name="Search Score Threshold",
            info="Minimum similarity score threshold for search results. "
            "(when using 'Similarity with score threshold')",
            value=0,
            advanced=True,
        ),
        DictInput(
            name="search_filter",
            display_name="Search Metadata Filter",
            info="Optional dictionary of filters to apply to the search query.",
            advanced=True,
            is_list=True,
        ),
    ]

    @check_cached_vector_store
    def build_vector_store(self):
        """Instantiate the AstraDBGraphVectorStore from the component inputs.

        Returns:
            The constructed ``AstraDBGraphVectorStore``, after ingesting any
            documents supplied via ``ingest_data``.

        Raises:
            ImportError: If ``langchain-astradb`` is not installed.
            ValueError: If the store cannot be initialized.
        """
        try:
            from langchain_astradb import AstraDBGraphVectorStore
            from langchain_astradb.utils.astradb import SetupMode
        except ImportError as e:
            msg = (
                "Could not import langchain Astra DB integration package. "
                "Please install it with `pip install langchain-astradb`."
            )
            raise ImportError(msg) from e

        try:
            vector_store = AstraDBGraphVectorStore(
                embedding=self.embedding,
                collection_name=self.collection_name,
                # Fall back to the library's conventional metadata keys when unset.
                link_to_metadata_key=self.link_to_metadata_key or "links_to",
                link_from_metadata_key=self.link_from_metadata_key or "links_from",
                token=self.token,
                api_endpoint=self.api_endpoint,
                namespace=self.namespace or None,
                # NOTE(review): parse_api_endpoint may return None for a malformed
                # endpoint; the AttributeError would surface as the ValueError below.
                environment=parse_api_endpoint(self.api_endpoint).environment,
                metric=self.metric,
                batch_size=self.batch_size or None,
                bulk_insert_batch_concurrency=self.bulk_insert_batch_concurrency or None,
                bulk_insert_overwrite_concurrency=self.bulk_insert_overwrite_concurrency or None,
                bulk_delete_concurrency=self.bulk_delete_concurrency or None,
                setup_mode=SetupMode[self.setup_mode.upper()],
                pre_delete_collection=self.pre_delete_collection,
                # Pass None (not []) when unset so the library's mutual-exclusion
                # check on the indexing options is not tripped by two empty lists.
                metadata_indexing_include=[s for s in self.metadata_indexing_include if s] or None,
                metadata_indexing_exclude=[s for s in self.metadata_indexing_exclude if s] or None,
                # The input is a JSON *string*; parse it into the dict the library
                # expects (the original serialized it with orjson.dumps, producing
                # double-encoded bytes).
                collection_indexing_policy=orjson.loads(self.collection_indexing_policy)
                if self.collection_indexing_policy
                else None,
            )
        except Exception as e:
            msg = f"Error initializing AstraDBGraphVectorStore: {e}"
            raise ValueError(msg) from e

        self._add_documents_to_vector_store(vector_store)

        return vector_store

    def _add_documents_to_vector_store(self, vector_store) -> None:
        """Convert ``ingest_data`` entries to LangChain documents and add them.

        Raises:
            TypeError: If an ingest entry is not a ``Data`` object.
            ValueError: If the store rejects the documents.
        """
        documents = []
        for _input in self.ingest_data or []:
            if isinstance(_input, Data):
                documents.append(_input.to_lc_document())
            else:
                msg = "Vector Store Inputs must be Data objects."
                raise TypeError(msg)

        if documents:
            logger.debug(f"Adding {len(documents)} documents to the Vector Store.")
            try:
                vector_store.add_documents(documents)
            except Exception as e:
                msg = f"Error adding documents to AstraDBGraphVectorStore: {e}"
                raise ValueError(msg) from e
        else:
            logger.debug("No documents to add to the Vector Store.")

    def _map_search_type(self) -> str:
        """Map the UI search-type label to the LangChain search-type identifier."""
        if self.search_type == "Similarity with score threshold":
            return "similarity_score_threshold"
        if self.search_type == "MMR (Max Marginal Relevance)":
            return "mmr"
        return "similarity"

    def _build_search_args(self):
        """Assemble keyword arguments for the vector-store search call."""
        args = {
            "k": self.number_of_results,
            "score_threshold": self.search_score_threshold,
        }

        if self.search_filter:
            # Drop entries with empty keys or values before forwarding the filter.
            clean_filter = {k: v for k, v in self.search_filter.items() if k and v}
            if len(clean_filter) > 0:
                args["filter"] = clean_filter
        return args

    def search_documents(self, vector_store=None) -> list[Data]:
        """Run the configured search and return results as ``Data`` objects.

        Args:
            vector_store: Optional pre-built store; one is built if omitted.

        Returns:
            Matching documents converted to ``Data``; empty list when no
            (non-blank) search input was provided.

        Raises:
            ValueError: If the underlying search call fails.
        """
        if not vector_store:
            vector_store = self.build_vector_store()

        logger.debug(f"Search input: {self.search_input}")
        logger.debug(f"Search type: {self.search_type}")
        logger.debug(f"Number of results: {self.number_of_results}")

        if self.search_input and isinstance(self.search_input, str) and self.search_input.strip():
            try:
                search_type = self._map_search_type()
                search_args = self._build_search_args()

                docs = vector_store.search(query=self.search_input, search_type=search_type, **search_args)
            except Exception as e:
                msg = f"Error performing search in AstraDBGraphVectorStore: {e}"
                raise ValueError(msg) from e

            logger.debug(f"Retrieved documents: {len(docs)}")

            data = docs_to_data(docs)
            logger.debug(f"Converted documents to data: {len(data)}")
            self.status = data
            return data
        logger.debug("No search input provided. Skipping search.")
        return []

    def get_retriever_kwargs(self):
        """Return retriever configuration derived from the search inputs."""
        search_args = self._build_search_args()
        return {
            "search_type": self._map_search_type(),
            "search_kwargs": search_args,
        }

0 commit comments

Comments
 (0)