|
| 1 | +import os |
| 2 | + |
| 3 | +import orjson |
| 4 | +from astrapy.admin import parse_api_endpoint |
| 5 | +from loguru import logger |
| 6 | + |
| 7 | +from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store |
| 8 | +from langflow.helpers import docs_to_data |
| 9 | +from langflow.inputs import DictInput, FloatInput |
| 10 | +from langflow.io import ( |
| 11 | + BoolInput, |
| 12 | + DataInput, |
| 13 | + DropdownInput, |
| 14 | + HandleInput, |
| 15 | + IntInput, |
| 16 | + MultilineInput, |
| 17 | + SecretStrInput, |
| 18 | + StrInput, |
| 19 | +) |
| 20 | +from langflow.schema import Data |
| 21 | + |
| 22 | + |
class AstraGraphVectorStoreComponent(LCVectorStoreComponent):
    """Langflow component wrapping ``langchain_astradb.AstraDBGraphVectorStore``.

    Builds (and optionally pre-populates) a graph vector store backed by an
    Astra DB collection, and exposes similarity / score-threshold / MMR search
    over it. Ingested ``Data`` inputs are converted to LangChain documents and
    added to the store when it is built.
    """

    display_name: str = "Astra DB Graph"
    description: str = "Implementation of Graph Vector Store using Astra DB"
    documentation: str = "https://python.langchain.com/api_reference/astradb/graph_vectorstores/langchain_astradb.graph_vectorstores.AstraDBGraphVectorStore.html"
    name = "AstraDBGraph"
    icon: str = "AstraDB"

    inputs = [
        SecretStrInput(
            name="token",
            display_name="Astra DB Application Token",
            info="Authentication token for accessing Astra DB.",
            value="ASTRA_DB_APPLICATION_TOKEN",
            required=True,
            advanced=os.getenv("ASTRA_ENHANCED", "false").lower() == "true",
        ),
        SecretStrInput(
            name="api_endpoint",
            display_name="Database" if os.getenv("ASTRA_ENHANCED", "false").lower() == "true" else "API Endpoint",
            info="API endpoint URL for the Astra DB service.",
            value="ASTRA_DB_API_ENDPOINT",
            required=True,
        ),
        StrInput(
            name="collection_name",
            display_name="Collection Name",
            info="The name of the collection within Astra DB where the vectors will be stored.",
            required=True,
        ),
        StrInput(
            name="link_to_metadata_key",
            display_name="Outgoing links metadata key",
            info="Metadata key used for outgoing links.",
            advanced=True,
        ),
        StrInput(
            name="link_from_metadata_key",
            display_name="Incoming links metadata key",
            info="Metadata key used for incoming links.",
            advanced=True,
        ),
        # NOTE: the original declared this input twice; the duplicate was removed.
        StrInput(
            name="namespace",
            display_name="Namespace",
            info="Optional namespace within Astra DB to use for the collection.",
            advanced=True,
        ),
        MultilineInput(
            name="search_input",
            display_name="Search Input",
        ),
        DataInput(
            name="ingest_data",
            display_name="Ingest Data",
            is_list=True,
        ),
        HandleInput(
            name="embedding",
            display_name="Embedding Model",
            input_types=["Embeddings"],
            info="Embedding model.",
            required=True,
        ),
        DropdownInput(
            name="metric",
            display_name="Metric",
            info="Optional distance metric for vector comparisons in the vector store.",
            options=["cosine", "dot_product", "euclidean"],
            value="cosine",
            advanced=True,
        ),
        IntInput(
            name="batch_size",
            display_name="Batch Size",
            info="Optional number of data to process in a single batch.",
            advanced=True,
        ),
        IntInput(
            name="bulk_insert_batch_concurrency",
            display_name="Bulk Insert Batch Concurrency",
            info="Optional concurrency level for bulk insert operations.",
            advanced=True,
        ),
        IntInput(
            name="bulk_insert_overwrite_concurrency",
            display_name="Bulk Insert Overwrite Concurrency",
            info="Optional concurrency level for bulk insert operations that overwrite existing data.",
            advanced=True,
        ),
        IntInput(
            name="bulk_delete_concurrency",
            display_name="Bulk Delete Concurrency",
            info="Optional concurrency level for bulk delete operations.",
            advanced=True,
        ),
        DropdownInput(
            name="setup_mode",
            display_name="Setup Mode",
            info="Configuration mode for setting up the vector store, with options like 'Sync', or 'Off'.",
            options=["Sync", "Off"],
            advanced=True,
            value="Sync",
        ),
        BoolInput(
            name="pre_delete_collection",
            display_name="Pre Delete Collection",
            info="Boolean flag to determine whether to delete the collection before creating a new one.",
            advanced=True,
            value=False,
        ),
        StrInput(
            name="metadata_indexing_include",
            display_name="Metadata Indexing Include",
            info="Optional list of metadata fields to include in the indexing.",
            advanced=True,
            is_list=True,
        ),
        StrInput(
            name="metadata_indexing_exclude",
            display_name="Metadata Indexing Exclude",
            info="Optional list of metadata fields to exclude from the indexing.",
            advanced=True,
            is_list=True,
        ),
        StrInput(
            name="collection_indexing_policy",
            display_name="Collection Indexing Policy",
            info='Optional JSON string for the "indexing" field of the collection. '
            "See https://docs.datastax.com/en/astra-db-serverless/api-reference/collections.html#the-indexing-option",
            advanced=True,
        ),
        IntInput(
            name="number_of_results",
            display_name="Number of Results",
            info="Number of results to return.",
            advanced=True,
            value=4,
        ),
        DropdownInput(
            name="search_type",
            display_name="Search Type",
            info="Search type to use",
            options=["Similarity", "Similarity with score threshold", "MMR (Max Marginal Relevance)"],
            value="Similarity",
            advanced=True,
        ),
        FloatInput(
            name="search_score_threshold",
            display_name="Search Score Threshold",
            info="Minimum similarity score threshold for search results. "
            "(when using 'Similarity with score threshold')",
            value=0,
            advanced=True,
        ),
        DictInput(
            name="search_filter",
            display_name="Search Metadata Filter",
            info="Optional dictionary of filters to apply to the search query.",
            advanced=True,
            is_list=True,
        ),
    ]

    @check_cached_vector_store
    def build_vector_store(self):
        """Create the AstraDBGraphVectorStore and ingest any provided documents.

        Returns:
            The configured ``AstraDBGraphVectorStore`` instance.

        Raises:
            ImportError: If ``langchain-astradb`` is not installed.
            ValueError: If the store cannot be initialized.
        """
        try:
            from langchain_astradb import AstraDBGraphVectorStore
            from langchain_astradb.utils.astradb import SetupMode
        except ImportError as e:
            msg = (
                "Could not import langchain Astra DB integration package. "
                "Please install it with `pip install langchain-astradb`."
            )
            raise ImportError(msg) from e

        # parse_api_endpoint returns None for endpoints it cannot parse; guard
        # against that instead of raising AttributeError on `.environment`.
        parsed_endpoint = parse_api_endpoint(self.api_endpoint)
        environment = parsed_endpoint.environment if parsed_endpoint is not None else None

        # Normalize empty include/exclude lists to None: the store treats the
        # three indexing options as mutually exclusive, and passing empty lists
        # for both would look like two options being set at once.
        indexing_include = [s for s in (self.metadata_indexing_include or []) if s] or None
        indexing_exclude = [s for s in (self.metadata_indexing_exclude or []) if s] or None

        # The policy input is a JSON *string*, while the store expects a dict:
        # parse it. (The previous `orjson.dumps` call produced doubly-encoded
        # bytes and was a bug.)
        indexing_policy = orjson.loads(self.collection_indexing_policy) if self.collection_indexing_policy else None

        try:
            vector_store = AstraDBGraphVectorStore(
                embedding=self.embedding,
                collection_name=self.collection_name,
                link_to_metadata_key=self.link_to_metadata_key or "links_to",
                link_from_metadata_key=self.link_from_metadata_key or "links_from",
                token=self.token,
                api_endpoint=self.api_endpoint,
                namespace=self.namespace or None,
                environment=environment,
                metric=self.metric,
                batch_size=self.batch_size or None,
                bulk_insert_batch_concurrency=self.bulk_insert_batch_concurrency or None,
                bulk_insert_overwrite_concurrency=self.bulk_insert_overwrite_concurrency or None,
                bulk_delete_concurrency=self.bulk_delete_concurrency or None,
                setup_mode=SetupMode[self.setup_mode.upper()],
                pre_delete_collection=self.pre_delete_collection,
                metadata_indexing_include=indexing_include,
                metadata_indexing_exclude=indexing_exclude,
                collection_indexing_policy=indexing_policy,
            )
        except Exception as e:
            msg = f"Error initializing AstraDBGraphVectorStore: {e}"
            raise ValueError(msg) from e

        self._add_documents_to_vector_store(vector_store)

        return vector_store

    def _add_documents_to_vector_store(self, vector_store) -> None:
        """Convert ``ingest_data`` entries to LangChain documents and add them.

        Raises:
            TypeError: If any ingest input is not a ``Data`` object.
            ValueError: If the store rejects the documents.
        """
        documents = []
        for _input in self.ingest_data or []:
            if isinstance(_input, Data):
                documents.append(_input.to_lc_document())
            else:
                msg = "Vector Store Inputs must be Data objects."
                raise TypeError(msg)

        if documents:
            logger.debug(f"Adding {len(documents)} documents to the Vector Store.")
            try:
                vector_store.add_documents(documents)
            except Exception as e:
                msg = f"Error adding documents to AstraDBGraphVectorStore: {e}"
                raise ValueError(msg) from e
        else:
            logger.debug("No documents to add to the Vector Store.")

    def _map_search_type(self) -> str:
        """Map the UI search-type label to the LangChain search-type keyword."""
        if self.search_type == "Similarity with score threshold":
            return "similarity_score_threshold"
        if self.search_type == "MMR (Max Marginal Relevance)":
            return "mmr"
        return "similarity"

    def _build_search_args(self) -> dict:
        """Assemble keyword arguments for the vector-store search call."""
        args = {
            "k": self.number_of_results,
            "score_threshold": self.search_score_threshold,
        }

        if self.search_filter:
            # Drop entries with empty keys or values before filtering.
            clean_filter = {k: v for k, v in self.search_filter.items() if k and v}
            if len(clean_filter) > 0:
                args["filter"] = clean_filter
        return args

    def search_documents(self, vector_store=None) -> list[Data]:
        """Run the configured search against the store and return ``Data`` results.

        Builds the vector store if one is not supplied. Returns an empty list
        when no (non-blank) search input is provided.

        Raises:
            ValueError: If the underlying search call fails.
        """
        if not vector_store:
            vector_store = self.build_vector_store()

        logger.debug(f"Search input: {self.search_input}")
        logger.debug(f"Search type: {self.search_type}")
        logger.debug(f"Number of results: {self.number_of_results}")

        if self.search_input and isinstance(self.search_input, str) and self.search_input.strip():
            try:
                search_type = self._map_search_type()
                search_args = self._build_search_args()

                docs = vector_store.search(query=self.search_input, search_type=search_type, **search_args)
            except Exception as e:
                msg = f"Error performing search in AstraDBGraphVectorStore: {e}"
                raise ValueError(msg) from e

            logger.debug(f"Retrieved documents: {len(docs)}")

            data = docs_to_data(docs)
            logger.debug(f"Converted documents to data: {len(data)}")
            self.status = data
            return data
        logger.debug("No search input provided. Skipping search.")
        return []

    def get_retriever_kwargs(self):
        """Return retriever configuration derived from the search settings."""
        search_args = self._build_search_args()
        return {
            "search_type": self._map_search_type(),
            "search_kwargs": search_args,
        }
0 commit comments