Commit

fix
nicoloboschi committed Jul 2, 2024
1 parent fd2d154 commit f339531
Showing 2 changed files with 46 additions and 21 deletions.
14 changes: 10 additions & 4 deletions src/backend/base/langflow/components/vectorstores/AstraDB.py
@@ -1,5 +1,6 @@
 from loguru import logger
 
+from langchain_core.vectorstores import VectorStore
 from langflow.base.vectorstores.model import LCVectorStoreComponent
 from langflow.helpers import docs_to_data
 from langflow.inputs import FloatInput, DictInput
@@ -22,6 +23,8 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
     documentation: str = "https://python.langchain.com/docs/integrations/vectorstores/astradb"
     icon: str = "AstraDB"
 
+    _cached_vectorstore: VectorStore = None
+
     inputs = [
         StrInput(
             name="collection_name",
@@ -158,6 +161,8 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
     ]
 
     def _build_vector_store_no_ingest(self):
+        if self._cached_vectorstore:
+            return self._cached_vectorstore
         try:
             from langchain_astradb import AstraDBVectorStore
             from langchain_astradb.utils.astradb import SetupMode
@@ -216,13 +221,13 @@ def _build_vector_store_no_ingest(self):
         except Exception as e:
             raise ValueError(f"Error initializing AstraDBVectorStore: {str(e)}") from e
 
+        self._cached_vectorstore = vector_store
+
         return vector_store
 
     def build_vector_store(self):
         vector_store = self._build_vector_store_no_ingest()
-        if hasattr(self, "ingest_data") and self.ingest_data:
-            logger.debug("Ingesting data into the Vector Store.")
-            self._add_documents_to_vector_store(vector_store)
+        self._add_documents_to_vector_store(vector_store)
         return vector_store
 
     def _add_documents_to_vector_store(self, vector_store):
@@ -233,7 +238,7 @@ def _add_documents_to_vector_store(self, vector_store):
            else:
                raise ValueError("Vector Store Inputs must be Data objects.")
 
-        if documents and self.embedding is not None:
+        if documents:
             logger.debug(f"Adding {len(documents)} documents to the Vector Store.")
             try:
                 vector_store.add_documents(documents)
@@ -252,6 +257,7 @@ def _map_search_type(self):
 
     def search_documents(self) -> list[Data]:
         vector_store = self._build_vector_store_no_ingest()
+        self._add_documents_to_vector_store(vector_store)
 
         logger.debug(f"Search input: {self.search_input}")
         logger.debug(f"Search type: {self.search_type}")
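The AstraDB change memoizes the built store on the component (the new _cached_vectorstore attribute) and routes both build_vector_store and search_documents through _add_documents_to_vector_store, so a search can ingest pending data without rebuilding the store. Below is a minimal, hypothetical sketch of that caching pattern; _Store, CachingComponent and the sample ingest_data are simplified stand-ins, not the actual Langflow class or the AstraDBVectorStore API.

class _Store:
    """Stand-in for the real vector store so the sketch runs without a database."""

    def __init__(self):
        self.docs = []

    def add_documents(self, docs):
        self.docs.extend(docs)


class CachingComponent:
    _cached_vectorstore = None
    ingest_data = ["first document", "second document"]  # assumed sample input

    def _build_vector_store_no_ingest(self):
        # Reuse the store that was already built instead of reconnecting on each call.
        if self._cached_vectorstore:
            return self._cached_vectorstore
        vector_store = _Store()  # stands in for the expensive store construction
        self._cached_vectorstore = vector_store
        return vector_store

    def _add_documents_to_vector_store(self, vector_store):
        vector_store.add_documents(self.ingest_data)

    def build_vector_store(self):
        vector_store = self._build_vector_store_no_ingest()
        self._add_documents_to_vector_store(vector_store)
        return vector_store

    def search_documents(self, query):
        # Searching also ingests first, mirroring the new search_documents above.
        vector_store = self._build_vector_store_no_ingest()
        self._add_documents_to_vector_store(vector_store)
        return [doc for doc in vector_store.docs if query in doc]


component = CachingComponent()
store = component.build_vector_store()
assert component._build_vector_store_no_ingest() is store  # second call hits the cache
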
53 changes: 36 additions & 17 deletions src/backend/base/langflow/components/vectorstores/Cassandra.py
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import List
 
 from langchain_community.vectorstores import Cassandra
 
@@ -15,6 +15,7 @@
     SecretStrInput,
 )
 from langflow.schema import Data
+from loguru import logger
 
 
 class CassandraVectorStoreComponent(LCVectorStoreComponent):
@@ -23,6 +24,8 @@ class CassandraVectorStoreComponent(LCVectorStoreComponent):
     documentation = "https://python.langchain.com/docs/modules/data_connection/vectorstores/integrations/cassandra"
     icon = "Cassandra"
 
+    _cached_vectorstore: Cassandra = None
+
     inputs = [
         MessageTextInput(
             name="database_ref",
@@ -131,11 +134,14 @@ class CassandraVectorStoreComponent(LCVectorStoreComponent):
     ]
 
     def build_vector_store(self) -> Cassandra:
-        return self._build_cassandra(ingest=True)
+        return self._build_cassandra()
 
-    def _build_cassandra(self, ingest: bool) -> Cassandra:
+    def _build_cassandra(self) -> Cassandra:
+        if self._cached_vectorstore:
+            return self._cached_vectorstore
         try:
             import cassio
+            from langchain_community.utilities.cassandra import SetupMode
         except ImportError:
             raise ImportError(
                 "Could not import cassio integration package. " "Please install it with `pip install cassio`."
@@ -167,43 +173,48 @@ def _build_cassandra(self, ingest: bool) -> Cassandra:
                 password=self.token,
                 cluster_kwargs=self.cluster_kwargs,
             )
-        ttl_seconds: Optional[int] = self.ttl_seconds
-
         documents = []
 
-        if ingest:
-            for _input in self.ingest_data or []:
-                if isinstance(_input, Data):
-                    documents.append(_input.to_lc_document())
-                else:
-                    documents.append(_input)
+        for _input in self.ingest_data or []:
+            if isinstance(_input, Data):
+                documents.append(_input.to_lc_document())
+            else:
+                documents.append(_input)
 
         if self.enable_body_search:
             body_index_options = [("index_analyzer", "STANDARD")]
         else:
             body_index_options = None
 
+        if self.setup_mode == "Off":
+            setup_mode = SetupMode.OFF
+        elif self.setup_mode == "Sync":
+            setup_mode = SetupMode.SYNC
+        else:
+            setup_mode = SetupMode.ASYNC
+
         if documents:
+            logger.debug(f"Adding {len(documents)} documents to the Vector Store.")
             table = Cassandra.from_documents(
                 documents=documents,
                 embedding=self.embedding,
                 table_name=self.table_name,
                 keyspace=self.keyspace,
-                ttl_seconds=ttl_seconds,
+                ttl_seconds=self.ttl_seconds or None,
                 batch_size=self.batch_size,
                 body_index_options=body_index_options,
             )
-
         else:
+            logger.debug("No documents to add to the Vector Store.")
             table = Cassandra(
                 embedding=self.embedding,
                 table_name=self.table_name,
                 keyspace=self.keyspace,
-                ttl_seconds=ttl_seconds,
+                ttl_seconds=self.ttl_seconds or None,
                 body_index_options=body_index_options,
-                setup_mode=self.setup_mode,
+                setup_mode=setup_mode,
             )
-
+        self._cached_vectorstore = table
         return table
 
     def _map_search_type(self):
@@ -215,13 +226,19 @@ def _map_search_type(self):
         return "similarity"
 
     def search_documents(self) -> List[Data]:
-        vector_store = self._build_cassandra(ingest=False)
+        vector_store = self._build_cassandra()
 
+        logger.debug(f"Search input: {self.search_query}")
+        logger.debug(f"Search type: {self.search_type}")
+        logger.debug(f"Number of results: {self.number_of_results}")
+
         if self.search_query and isinstance(self.search_query, str) and self.search_query.strip():
             try:
                 search_type = self._map_search_type()
                 search_args = self._build_search_args()
 
+                logger.debug(f"Search args: {str(search_args)}")
+
                 docs = vector_store.search(query=self.search_query, search_type=search_type, **search_args)
             except KeyError as e:
                 if "content" in str(e):
@@ -231,6 +248,8 @@ else:
                 else:
                     raise e
 
+            logger.debug(f"Retrieved documents: {len(docs)}")
+
             data = docs_to_data(docs)
             self.status = data
             return data
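The Cassandra component gets the same _cached_vectorstore memoization, drops the ingest flag from _build_cassandra, passes ttl_seconds or None so a falsy value becomes None, maps the setup_mode string input onto the SetupMode enum it now imports, and adds debug logging around ingestion and search. A small sketch of the string-to-enum mapping, assuming only the SetupMode enum imported in the diff (with OFF, SYNC and ASYNC members); map_setup_mode is an illustrative helper, not part of the component:

from langchain_community.utilities.cassandra import SetupMode


def map_setup_mode(value: str) -> SetupMode:
    # Mirrors the if/elif/else chain added to _build_cassandra:
    # "Off" and "Sync" are matched explicitly, anything else falls back to async.
    if value == "Off":
        return SetupMode.OFF
    if value == "Sync":
        return SetupMode.SYNC
    return SetupMode.ASYNC


assert map_setup_mode("Sync") is SetupMode.SYNC
assert map_setup_mode("Async") is SetupMode.ASYNC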
