Skip to content

Commit

Permalink
Merge pull request #38 from aiplanethub/vectordb
Browse files Browse the repository at this point in the history
Added vectordb config and chromadb support
  • Loading branch information
tarun-aiplanet committed Apr 30, 2024
2 parents b6f00ac + b644c9e commit 29c598a
Show file tree
Hide file tree
Showing 10 changed files with 234 additions and 34 deletions.
4 changes: 2 additions & 2 deletions src/beyondllm/embeddings/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from pydantic import BaseModel

class EmbeddingConfig(BaseModel):
    """Base configuration model for all Embeddings.

    This class can be extended to include more fields specific to certain
    Embeddings.
    """
    # Deliberately empty: subclasses add provider-specific fields.
    pass

Expand Down
14 changes: 8 additions & 6 deletions src/beyondllm/retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from beyondllm.retrievers.utils import generate_qa_dataset, evaluate_from_dataset
import pandas as pd

def auto_retriever(data,embed_model=None,type="normal",top_k=4,**kwargs):
def auto_retriever(data=None,embed_model=None,type="normal",top_k=4,vectordb=None,**kwargs):
"""
Automatically selects and initializes a retriever based on the specified type.
Parameters:
Expand All @@ -16,6 +16,7 @@ def auto_retriever(data,embed_model=None,type="normal",top_k=4,**kwargs):
type (str): The type of retriever to use. Options include 'normal', 'flag-rerank',
'cross-rerank', and 'hybrid'. Defaults to 'normal'.
top_k (int): The number of top results to retrieve. Defaults to 4.
vectordb (VectorDb): The vectordb to use for retrieval
Additional parameters:
reranker: Name of the reranking model to be used. To be specified only for type = 'flag-rerank' and 'cross-rerank'
mode: Possible options are 'AND' or 'OR'. To be specified only for type = 'hybrid. 'AND' mode will retrieve nodes in common between
Expand All @@ -27,22 +28,23 @@ def auto_retriever(data,embed_model=None,type="normal",top_k=4,**kwargs):
data = <your dataset here>
embed_model = <pass your embed model here>
vector_store = <pass your vector-store object here>
retriever = auto_retriever(data=data, embed_model=embed_model, type="normal", top_k=5)
retriever = auto_retriever(data=data, embed_model=embed_model, type="normal", top_k=5, vectordb=vector_store)
"""
if embed_model is None:
embed_model = GeminiEmbeddings()
if type == 'normal':
retriever = NormalRetriever(data,embed_model,top_k,**kwargs)
retriever = NormalRetriever(data,embed_model,top_k,vectordb,**kwargs)
elif type == 'flag-rerank':
from .retrievers.flagReranker import FlagEmbeddingRerankRetriever
retriever = FlagEmbeddingRerankRetriever(data,embed_model,top_k,**kwargs)
retriever = FlagEmbeddingRerankRetriever(data,embed_model,top_k,vectordb,**kwargs)
elif type == 'cross-rerank':
from .retrievers.crossEncoderReranker import CrossEncoderRerankRetriever
retriever = CrossEncoderRerankRetriever(data,embed_model,top_k,**kwargs)
retriever = CrossEncoderRerankRetriever(data,embed_model,top_k,vectordb,**kwargs)
elif type == 'hybrid':
from .retrievers.hybridRetriever import HybridRetriever
retriever = HybridRetriever(data,embed_model,top_k,**kwargs)
retriever = HybridRetriever(data,embed_model,top_k,vectordb,**kwargs)
else:
raise NotImplementedError(f"Retriever for the type '{type}' is not implemented.")

Expand Down
4 changes: 3 additions & 1 deletion src/beyondllm/retrievers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@ class BaseRetriever:
data: The dataset to be indexed or retrieved from.
embed_model: The embedding model used to generate embeddings for the data.
top_k: The top k similarity search results to be retrieved
vectordb: The vectordb to be used for retrieval
"""
def __init__(self, data, embed_model, vectordb, **kwargs):
    """Store the retriever's shared state.

    Args:
        data: The dataset to be indexed or retrieved from (may be None
            when loading from an existing vector store).
        embed_model: The embedding model used to generate embeddings.
        vectordb: The vectordb to be used for retrieval (may be None).
    """
    self.data = data
    self.embed_model = embed_model
    self.vectordb = vectordb

def load_index(self):
    """Abstract hook: concrete retrievers must build and return their index."""
    raise NotImplementedError("This method should be implemented by subclasses.")
Expand Down
34 changes: 29 additions & 5 deletions src/beyondllm/retrievers/crossEncoderReranker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from beyondllm.retrievers.base import BaseRetriever
from llama_index.core import VectorStoreIndex, ServiceContext
from llama_index.core import VectorStoreIndex, ServiceContext,StorageContext
from llama_index.core.schema import QueryBundle
import sys
import subprocess
Expand Down Expand Up @@ -45,10 +45,34 @@ def __init__(self, data, embed_model, top_k,*args, **kwargs):
self.reranker = kwargs.get('reranker',"cross-encoder/ms-marco-MiniLM-L-2-v2")

def load_index(self):
    """Return a VectorStoreIndex for retrieval.

    Builds the index from an existing vector store when no raw data was
    supplied, otherwise indexes the provided data.
    """
    if self.data is None:
        index = self.initialize_from_vector_store()
    else:
        index = self.initialize_from_data()
    return index

def initialize_from_vector_store(self):
    """Build a VectorStoreIndex directly from the configured vector store.

    Raises:
        ValueError: if no vector store was supplied.
    """
    if self.vectordb is None:
        raise ValueError("Vector store must be provided if no data is passed")
    return VectorStoreIndex.from_vector_store(
        self.vectordb,
        embed_model=self.embed_model,
    )


def initialize_from_data(self):
    """Build a VectorStoreIndex from the supplied data.

    Uses the configured vector store as the storage backend when one is
    supplied; otherwise builds a purely in-memory index.
    """
    # Fixed: identity check 'is None' instead of '== None' (PEP 8).
    if self.vectordb is None:
        index = VectorStoreIndex(
            self.data, embed_model=self.embed_model
        )
    else:
        storage_context = StorageContext.from_defaults(vector_store=self.vectordb)
        index = VectorStoreIndex(
            self.data, storage_context=storage_context, embed_model=self.embed_model
        )
    return index

def retrieve(self, query):
Expand Down
34 changes: 29 additions & 5 deletions src/beyondllm/retrievers/flagReranker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from beyondllm.retrievers.base import BaseRetriever
from llama_index.core import VectorStoreIndex, ServiceContext
from llama_index.core import VectorStoreIndex, ServiceContext, StorageContext
import sys
import subprocess
try:
Expand Down Expand Up @@ -44,10 +44,34 @@ def __init__(self, data, embed_model, top_k,*args, **kwargs):
self.reranker = kwargs.get('reranker',"BAAI/bge-reranker-large")

def load_index(self):
    """Return a VectorStoreIndex for retrieval.

    Builds the index from an existing vector store when no raw data was
    supplied, otherwise indexes the provided data.
    """
    if self.data is None:
        index = self.initialize_from_vector_store()
    else:
        index = self.initialize_from_data()
    return index

def initialize_from_vector_store(self):
    """Build a VectorStoreIndex directly from the configured vector store.

    Raises:
        ValueError: if no vector store was supplied.
    """
    if self.vectordb is None:
        raise ValueError("Vector store must be provided if no data is passed")
    return VectorStoreIndex.from_vector_store(
        self.vectordb,
        embed_model=self.embed_model,
    )


def initialize_from_data(self):
    """Build a VectorStoreIndex from the supplied data.

    Uses the configured vector store as the storage backend when one is
    supplied; otherwise builds a purely in-memory index.
    """
    # Fixed: identity check 'is None' instead of '== None' (PEP 8).
    if self.vectordb is None:
        index = VectorStoreIndex(
            self.data, embed_model=self.embed_model
        )
    else:
        storage_context = StorageContext.from_defaults(vector_store=self.vectordb)
        index = VectorStoreIndex(
            self.data, storage_context=storage_context, embed_model=self.embed_model
        )
    return index

def retrieve(self, query):
Expand Down
40 changes: 32 additions & 8 deletions src/beyondllm/retrievers/hybridRetriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,39 @@ def __init__(self, data, embed_model, top_k,*args, **kwargs):
raise ValueError("Invalid mode. Mode must be 'AND' or 'OR'.")

def load_index(self):
    """Return the (vector_index, keyword_index) pair for hybrid retrieval.

    Raises:
        ValueError: if no data was supplied — the keyword index cannot be
            rebuilt from a vector store alone.
    """
    if self.data is None:
        raise ValueError("Data needs to be passed for keyword retrieval.")
    vector_index, keyword_index = self.initialize_from_data()
    return vector_index, keyword_index

def initialize_from_data(self):
    """Build the vector index and keyword index used by the hybrid retriever.

    The vector index is backed by the configured vector store when one is
    supplied; the keyword index is always built in memory from the data.
    """
    # Fixed: identity check 'is None' instead of '== None' (PEP 8).
    if self.vectordb is None:
        vector_index = VectorStoreIndex(
            self.data, embed_model=self.embed_model
        )
    else:
        storage_context = StorageContext.from_defaults(vector_store=self.vectordb)
        vector_index = VectorStoreIndex(
            self.data, storage_context=storage_context, embed_model=self.embed_model
        )
    # The keyword index was built identically in both branches — build it once.
    keyword_index = SimpleKeywordTableIndex(
        self.data, service_context=ServiceContext.from_defaults(llm=None, embed_model=None)
    )
    return vector_index, keyword_index

def as_retriever(self):
vector_index, keyword_index = self.load_index()
Expand Down
40 changes: 33 additions & 7 deletions src/beyondllm/retrievers/normalRetriever.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from beyondllm.retrievers.base import BaseRetriever
from llama_index.core import VectorStoreIndex, ServiceContext
from llama_index.core import VectorStoreIndex, ServiceContext, StorageContext

class NormalRetriever(BaseRetriever):
"""
Expand All @@ -14,27 +14,53 @@ class NormalRetriever(BaseRetriever):
results = retriever.retrieve("<your query>")
"""
def __init__(self, data, embed_model, top_k, vectordb, *args, **kwargs):
    """
    Initializes a NormalRetriever instance.

    Args:
        data: The dataset to be indexed (or None when loading from vectordb).
        embed_model: The embedding model to use.
        top_k: The number of top results to retrieve.
        vectordb: The vectordb to use for retrieval.
    """
    super().__init__(data, embed_model, vectordb, *args, **kwargs)
    # NOTE(review): data/embed_model/vectordb are already assigned by the
    # base-class __init__; the re-assignments below are redundant but kept
    # for backward compatibility.
    self.embed_model = embed_model
    self.data = data
    self.top_k = top_k
    self.vectordb = vectordb

def load_index(self):
    """Return a VectorStoreIndex for retrieval.

    Builds the index from an existing vector store when no raw data was
    supplied, otherwise indexes the provided data.
    """
    if self.data is None:
        index = self.initialize_from_vector_store()
    else:
        index = self.initialize_from_data()
    return index

def initialize_from_vector_store(self):
    """Build a VectorStoreIndex directly from the configured vector store.

    Raises:
        ValueError: if no vector store was supplied.
    """
    if self.vectordb is None:
        raise ValueError("Vector store must be provided if no data is passed")
    return VectorStoreIndex.from_vector_store(
        self.vectordb,
        embed_model=self.embed_model,
    )


def initialize_from_data(self):
    """Build a VectorStoreIndex from the supplied data.

    Uses the configured vector store as the storage backend when one is
    supplied; otherwise builds a purely in-memory index.
    """
    # Fixed: identity check 'is None' instead of '== None' (PEP 8).
    if self.vectordb is None:
        index = VectorStoreIndex(
            self.data, embed_model=self.embed_model
        )
    else:
        storage_context = StorageContext.from_defaults(vector_store=self.vectordb)
        index = VectorStoreIndex(
            self.data, storage_context=storage_context, embed_model=self.embed_model
        )
    return index

def retrieve(self, query):
    """Run *query* through the configured retriever and return its results."""
    return self.as_retriever().retrieve(query)
Expand Down
1 change: 1 addition & 0 deletions src/beyondllm/vectordb/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .chroma import ChromaVectorDb
24 changes: 24 additions & 0 deletions src/beyondllm/vectordb/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from pydantic import BaseModel

class VectorDbConfig(BaseModel):
    """Base configuration model for all vector databases.

    This class can be extended to include more fields specific to certain
    vector stores (e.g. collection name, persistence path).
    """
    # Deliberately empty: subclasses add store-specific fields.
    pass

class VectorDb(BaseModel):
    """Abstract interface that concrete vector-store wrappers implement."""

    def _subclass_responsibility(self):
        # Shared raise so every abstract method reports the same message.
        raise NotImplementedError("This method should be implemented by subclasses.")

    def load(self):
        """Initialize the underlying vector-store client."""
        self._subclass_responsibility()

    def add(self, *args, **kwargs):
        """Add nodes/embeddings to the store."""
        self._subclass_responsibility()

    def stores_text(self, *args, **kwargs):
        """Whether the backing store keeps the raw text alongside vectors."""
        self._subclass_responsibility()

    def is_embedding_query(self, *args, **kwargs):
        """Whether queries are expected to carry an embedding."""
        self._subclass_responsibility()

    def query(self, *args, **kwargs):
        """Run a similarity query against the store."""
        self._subclass_responsibility()
73 changes: 73 additions & 0 deletions src/beyondllm/vectordb/chroma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from beyondllm.vectordb.base import VectorDb, VectorDbConfig
from dataclasses import dataclass, field
import warnings
# Silence warnings emitted by the chroma / llama-index stack at import time.
warnings.filterwarnings("ignore")
import subprocess,sys
# Soft dependency: offer to pip-install the Chroma integration when missing.
# NOTE(review): this prompts via input() at import time, which will block in
# non-interactive environments (CI, servers) — confirm this is intended.
try:
    from llama_index.vector_stores.chroma import ChromaVectorStore
except ImportError:
    user_agree = input("The feature you're trying to use requires an additional library(s):llama_index.vector_stores.chroma. Would you like to install it now? [y/N]: ")
    if user_agree.lower() == 'y':
        subprocess.check_call([sys.executable, "-m", "pip", "install", "llama_index.vector_stores.chroma"])
        from llama_index.vector_stores.chroma import ChromaVectorStore
    else:
        raise ImportError("The required 'llama_index.vector_stores.chroma' is not installed.")
import chromadb

@dataclass
class ChromaVectorDb:
    """Chroma-backed vector store wrapper.

    Example:
        from beyondllm.vectordb import ChromaVectorDb
        vectordb = ChromaVectorDb(collection_name="quickstart",persist_directory="./db/chroma/")
    """
    collection_name: str
    persist_directory: str = ""

    def __post_init__(self):
        # Fixed: falsy check replaces 'self.persist_directory=="" or ...==None'.
        # An empty/absent directory means a non-persistent, in-memory client.
        if not self.persist_directory:
            self.chroma_client = chromadb.EphemeralClient()
        else:
            self.chroma_client = chromadb.PersistentClient(self.persist_directory)
        self.load()

    def load(self):
        """Create (or reuse) the collection and wrap it in a ChromaVectorStore.

        Sets ``self.client`` and returns it.

        Raises:
            ImportError: if the llama-index Chroma integration is missing.
            Exception: if the collection cannot be created or wrapped.
        """
        try:
            from llama_index.vector_stores.chroma import ChromaVectorStore
        except ImportError:  # fixed: was a bare 'except:' that hid other errors
            raise ImportError("ChromaVectorStore library is not installed. Please install it with ``pip install llama_index.vector_stores.chroma``.")

        try:
            # Reuse the collection if it exists; otherwise create it fresh.
            try:
                chroma_collection = self.chroma_client.get_collection(self.collection_name)
            except Exception:
                chroma_collection = self.chroma_client.create_collection(self.collection_name)
            vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
            self.client = vector_store
        except Exception as e:
            raise Exception(f"Failed to load the Chroma Vectorstore: {e}") from e

        return self.client

    def add(self, *args, **kwargs):
        """Delegate to the underlying ChromaVectorStore.add."""
        return self.client.add(*args, **kwargs)

    def stores_text(self, *args, **kwargs):
        # NOTE(review): on ChromaVectorStore this is a property, not a callable —
        # calling it may raise TypeError; confirm against llama-index.
        return self.client.stores_text(*args, **kwargs)

    def is_embedding_query(self, *args, **kwargs):
        # NOTE(review): likely a property on ChromaVectorStore as well — verify.
        return self.client.is_embedding_query(*args, **kwargs)

    def query(self, *args, **kwargs):
        """Delegate a similarity query to the underlying ChromaVectorStore."""
        return self.client.query(*args, **kwargs)

    def load_from_kwargs(self, kwargs):
        """Rebuild configuration from a kwargs dict and reload the store."""
        # Fixed: removed the bogus @staticmethod decorator — the method takes
        # 'self', so the decorator broke every instance call.
        embed_config = VectorDbConfig(**kwargs)
        self.config = embed_config
        self.load()

0 comments on commit 29c598a

Please sign in to comment.