Skip to content

Commit

Permalink
cassandra/astradb: hybrid search support
Browse files Browse the repository at this point in the history
  • Loading branch information
nicoloboschi committed Jun 27, 2024
1 parent 7af8b6b commit e868454
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 33 deletions.
10 changes: 9 additions & 1 deletion src/backend/base/langflow/base/vectorstores/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,15 @@ def build_base_retriever(self) -> Retriever: # type: ignore[type-var]
"""
vector_store = self.build_vector_store()
if hasattr(vector_store, "as_retriever"):
retriever = vector_store.as_retriever()
retriever = vector_store.as_retriever(**self.get_retriever_kwargs())
if self.status is None:
self.status = "Retriever built successfully."
return retriever
else:
raise ValueError(f"Vector Store {vector_store.__class__.__name__} does not have an as_retriever method.")

def search_documents(self) -> List[Data]:
logger.info("here we go ", self.search_query + " " + self.search_type + " " + self.number_of_results)
"""
Search for documents in the Chroma vector store.
"""
Expand All @@ -106,3 +107,10 @@ def search_documents(self) -> List[Data]:
)
self.status = search_results
return search_results

def get_retriever_kwargs(self):
"""
Get the retriever kwargs. Implementations can override this method to provide custom retriever kwargs.
"""
return {}

78 changes: 59 additions & 19 deletions src/backend/base/langflow/components/vectorstores/AstraDB.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from loguru import logger

from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.helpers import docs_to_data
from langflow.inputs import FloatInput, DictInput
from langflow.io import (
BoolInput,
DataInput,
Expand Down Expand Up @@ -121,19 +123,34 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
info="Optional dictionary defining the indexing policy for the collection.",
advanced=True,
),
IntInput(
name="number_of_results",
display_name="Number of Results",
info="Number of results to return.",
advanced=True,
value=4,
),
DropdownInput(
name="search_type",
display_name="Search Type",
options=["Similarity", "MMR"],
info="Search type to use",
options=["Similarity", "Similarity with score threshold", "MMR (Max Marginal Relevance)"],
value="Similarity",
advanced=True,
),
IntInput(
name="number_of_results",
display_name="Number of Results",
info="Number of results to return.",
FloatInput(
name="search_score_threshold",
display_name="Search Score Threshold",
info="Minimum similarity score threshold for search results. (when using 'Similarity with score threshold')",
value=0,
advanced=True,
value=4,
),
DictInput(
name="search_filter",
display_name="Search Metadata Filter",
info="Optional dictionary of filters to apply to the search query.",
advanced=True,
is_list=True
),
]

Expand Down Expand Up @@ -220,6 +237,14 @@ def _add_documents_to_vector_store(self, vector_store):
else:
logger.debug("No documents to add to the Vector Store.")


def _map_search_type(self):
if self.search_type == "Similarity with score threshold":
return "similarity_score_threshold"
elif self.search_type == "MMR (Max Marginal Relevance)":
return "mmr"
else:
return "similarity"
def search_documents(self) -> list[Data]:
vector_store = self.build_vector_store()

Expand All @@ -229,24 +254,18 @@ def search_documents(self) -> list[Data]:

if self.search_input and isinstance(self.search_input, str) and self.search_input.strip():
try:
if self.search_type == "Similarity":
docs = vector_store.similarity_search(
query=self.search_input,
k=self.number_of_results,
)
elif self.search_type == "MMR":
docs = vector_store.max_marginal_relevance_search(
query=self.search_input,
k=self.number_of_results,
)
else:
raise ValueError(f"Invalid search type: {self.search_type}")
search_type = self._map_search_type()
search_args = self._build_search_args()

docs = vector_store.search(query=self.search_input,
search_type=search_type,
**search_args)
except Exception as e:
raise ValueError(f"Error performing search in AstraDBVectorStore: {str(e)}") from e

logger.debug(f"Retrieved documents: {len(docs)}")

data = [Data.from_document(doc) for doc in docs]
data = docs_to_data(docs)
logger.debug(f"Converted documents to data: {len(data)}")
self.status = data
return data
Expand All @@ -263,3 +282,24 @@ def _astradb_collection_to_data(self, collection):
for item in data_dict:
data.append(Data(content=item["content"]))
return data

def _build_search_args(self):
args = {
"k": self.number_of_results,
"score_threshold": self.search_score_threshold,
}

if self.search_filter:
clean_filter = {
k: v for k, v in self.search_filter.items() if k and v
}
if len(clean_filter) > 0:
args["filter"] = clean_filter
return args

def get_retriever_kwargs(self):
search_args = self._build_search_args()
return {
"search_type": self._map_search_type(),
"search_kwargs": search_args,
}
99 changes: 86 additions & 13 deletions src/backend/base/langflow/components/vectorstores/Cassandra.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.helpers.data import docs_to_data
from langflow.inputs import DictInput
from langflow.inputs import DictInput, FloatInput, BoolInput
from langflow.io import (
DataInput,
DropdownInput,
Expand Down Expand Up @@ -62,12 +62,7 @@ class CassandraVectorStoreComponent(LCVectorStoreComponent):
value=16,
advanced=True,
),
MessageTextInput(
name="body_index_options",
display_name="Body Index Options",
info="Optional options used to create the body index.",
advanced=True,
),

DropdownInput(
name="setup_mode",
display_name="Setup Mode",
Expand Down Expand Up @@ -97,6 +92,41 @@ class CassandraVectorStoreComponent(LCVectorStoreComponent):
value=4,
advanced=True,
),
DropdownInput(
name="search_type",
display_name="Search Type",
info="Search type to use",
options=["Similarity", "Similarity with score threshold", "MMR (Max Marginal Relevance)"],
value="Similarity",
advanced=True,
),
FloatInput(
name="search_score_threshold",
display_name="Search Score Threshold",
info="Minimum similarity score threshold for search results. (when using 'Similarity with score threshold')",
value=0,
advanced=True,
),
DictInput(
name="search_filter",
display_name="Search Metadata Filter",
info="Optional dictionary of filters to apply to the search query.",
advanced=True,
is_list=True
),
MessageTextInput(
name="body_search",
display_name="Search Body",
info="Document textual search terms to apply to the search query.",
advanced=True,
),
BoolInput(
name="enable_body_search",
display_name="Enable Body Search",
info="Flag to enable body search. This must be enabled BEFORE the table is created.",
value=False,
advanced=True,
),
]

def build_vector_store(self) -> Cassandra:
Expand Down Expand Up @@ -148,6 +178,11 @@ def _build_cassandra(self) -> Cassandra:
else:
documents.append(_input)

if self.enable_body_search:
body_index_options = [("index_analyzer", "STANDARD")]
else:
body_index_options = None

if documents:
table = Cassandra.from_documents(
documents=documents,
Expand All @@ -156,7 +191,7 @@ def _build_cassandra(self) -> Cassandra:
keyspace=self.keyspace,
ttl_seconds=self.ttl_seconds,
batch_size=self.batch_size,
body_index_options=self.body_index_options,
body_index_options=body_index_options,
)

else:
Expand All @@ -165,21 +200,32 @@ def _build_cassandra(self) -> Cassandra:
table_name=self.table_name,
keyspace=self.keyspace,
ttl_seconds=self.ttl_seconds,
body_index_options=self.body_index_options,
body_index_options=body_index_options,
setup_mode=self.setup_mode,
)

return table

def _map_search_type(self):
if self.search_type == "Similarity with score threshold":
return "similarity_score_threshold"
elif self.search_type == "MMR (Max Marginal Relevance)":
return "mmr"
else:
return "similarity"

def search_documents(self) -> List[Data]:
vector_store = self._build_cassandra()


if self.search_query and isinstance(self.search_query, str) and self.search_query.strip():
try:
docs = vector_store.similarity_search(
query=self.search_query,
k=self.number_of_results,
)
search_type = self._map_search_type()
search_args = self._build_search_args()

docs = vector_store.search(query=self.search_query,
search_type=search_type,
**search_args)
except KeyError as e:
if "content" in str(e):
raise ValueError(
Expand All @@ -193,3 +239,30 @@ def search_documents(self) -> List[Data]:
return data
else:
return []

def _build_search_args(self):
args = {
"k": self.number_of_results,
"score_threshold": self.search_score_threshold,
}

if self.search_filter:
clean_filter = {
k: v for k, v in self.search_filter.items() if k and v
}
if len(clean_filter) > 0:
args["filter"] = clean_filter
if self.body_search:
if not self.enable_body_search:
raise ValueError(
"You should enable body search when creating the table to search the body field."
)
args["body_search"] = self.body_search
return args

def get_retriever_kwargs(self):
search_args = self._build_search_args()
return {
"search_type": self._map_search_type(),
"search_kwargs": search_args,
}

0 comments on commit e868454

Please sign in to comment.