Skip to content

Commit bf7e4d6

Browse files
Bugfix: PGVector/RAG - Calculate the Vector Size based on Model Dimensions (#2865)
* Calculate the dimension size based on the chosen model.
* Added example docstring.
* Validated working notebook with sentence models of different dimensions.
* Validated that removal of model_name works.
* Second example uses conn object.
* embedding_function no longer directly references .encode
* Fixed pre-commit issue.
* Use try/except to raise error when shape is not found in embedding function.
* Re-ran notebook.
* Update autogen/agentchat/contrib/vectordb/pgvectordb.py Co-authored-by: Li Jiang <[email protected]>
* Update autogen/agentchat/contrib/vectordb/pgvectordb.py Co-authored-by: Li Jiang <[email protected]>
* Added .encode
* Removed example comment.
* Fix overwrite not working with an existing collection when the custom embedding function has a different dimension from the default one.
--------- Co-authored-by: Li Jiang <[email protected]>
1 parent 2d6c8c0 commit bf7e4d6

File tree

3 files changed

+447
-90
lines changed

3 files changed

+447
-90
lines changed

autogen/agentchat/contrib/vectordb/pgvectordb.py

+38-46
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,11 @@ class Collection:
3232
client: The PGVector client.
3333
collection_name (str): The name of the collection. Default is "documents".
3434
embedding_function (Callable): The embedding function used to generate the vector representation.
35+
Default is None. SentenceTransformer("all-MiniLM-L6-v2").encode will be used when None.
36+
Models can be chosen from:
37+
https://huggingface.co/models?library=sentence-transformers
3538
metadata (Optional[dict]): The metadata of the collection.
3639
get_or_create (Optional): The flag indicating whether to get or create the collection.
37-
model_name: (Optional str) | Sentence embedding model to use. Models can be chosen from:
38-
https://huggingface.co/models?library=sentence-transformers
3940
"""
4041

4142
def __init__(
@@ -45,7 +46,6 @@ def __init__(
4546
embedding_function: Callable = None,
4647
metadata=None,
4748
get_or_create=None,
48-
model_name="all-MiniLM-L6-v2",
4949
):
5050
"""
5151
Initialize the Collection object.
@@ -56,30 +56,26 @@ def __init__(
5656
embedding_function: The embedding function used to generate the vector representation.
5757
metadata: The metadata of the collection.
5858
get_or_create: The flag indicating whether to get or create the collection.
59-
model_name: | Sentence embedding model to use. Models can be chosen from:
60-
https://huggingface.co/models?library=sentence-transformers
6159
Returns:
6260
None
6361
"""
6462
self.client = client
65-
self.embedding_function = embedding_function
66-
self.model_name = model_name
6763
self.name = self.set_collection_name(collection_name)
6864
self.require_embeddings_or_documents = False
6965
self.ids = []
70-
try:
71-
self.embedding_function = (
72-
SentenceTransformer(self.model_name) if embedding_function is None else embedding_function
73-
)
74-
except Exception as e:
75-
logger.error(
76-
f"Validate the model name entered: {self.model_name} "
77-
f"from https://huggingface.co/models?library=sentence-transformers\nError: {e}"
78-
)
79-
raise e
66+
if embedding_function:
67+
self.embedding_function = embedding_function
68+
else:
69+
self.embedding_function = SentenceTransformer("all-MiniLM-L6-v2").encode
8070
self.metadata = metadata if metadata else {"hnsw:space": "ip", "hnsw:construction_ef": 32, "hnsw:M": 16}
8171
self.documents = ""
8272
self.get_or_create = get_or_create
73+
# Determine the model's embedding dimension by embedding a sample sentence and measuring the length of the resulting vector.
74+
sentences = [
75+
"The weather is lovely today in paradise.",
76+
]
77+
embeddings = self.embedding_function(sentences)
78+
self.dimension = len(embeddings[0])
8379

8480
def set_collection_name(self, collection_name) -> str:
8581
name = re.sub("-", "_", collection_name)
@@ -115,14 +111,14 @@ def add(self, ids: List[ItemID], documents: List, embeddings: List = None, metad
115111
elif metadatas is not None:
116112
for doc_id, metadata, document in zip(ids, metadatas, documents):
117113
metadata = re.sub("'", '"', str(metadata))
118-
embedding = self.embedding_function.encode(document)
114+
embedding = self.embedding_function(document)
119115
sql_values.append((doc_id, metadata, embedding, document))
120116
sql_string = (
121117
f"INSERT INTO {self.name} (id, metadatas, embedding, documents)\n" f"VALUES (%s, %s, %s, %s);\n"
122118
)
123119
else:
124120
for doc_id, document in zip(ids, documents):
125-
embedding = self.embedding_function.encode(document)
121+
embedding = self.embedding_function(document)
126122
sql_values.append((doc_id, document, embedding))
127123
sql_string = f"INSERT INTO {self.name} (id, documents, embedding)\n" f"VALUES (%s, %s, %s);\n"
128124
logger.debug(f"Add SQL String:\n{sql_string}\n{sql_values}")
@@ -166,7 +162,7 @@ def upsert(self, ids: List[ItemID], documents: List, embeddings: List = None, me
166162
elif metadatas is not None:
167163
for doc_id, metadata, document in zip(ids, metadatas, documents):
168164
metadata = re.sub("'", '"', str(metadata))
169-
embedding = self.embedding_function.encode(document)
165+
embedding = self.embedding_function(document)
170166
sql_values.append((doc_id, metadata, embedding, document, metadata, document, embedding))
171167
sql_string = (
172168
f"INSERT INTO {self.name} (id, metadatas, embedding, documents)\n"
@@ -176,7 +172,7 @@ def upsert(self, ids: List[ItemID], documents: List, embeddings: List = None, me
176172
)
177173
else:
178174
for doc_id, document in zip(ids, documents):
179-
embedding = self.embedding_function.encode(document)
175+
embedding = self.embedding_function(document)
180176
sql_values.append((doc_id, document, embedding, document))
181177
sql_string = (
182178
f"INSERT INTO {self.name} (id, documents, embedding)\n"
@@ -304,7 +300,7 @@ def get(
304300
)
305301
except (psycopg.errors.UndefinedTable, psycopg.errors.UndefinedColumn) as e:
306302
logger.info(f"Error executing select on non-existent table: {self.name}. Creating it instead. Error: {e}")
307-
self.create_collection(collection_name=self.name)
303+
self.create_collection(collection_name=self.name, dimension=self.dimension)
308304
logger.info(f"Created table {self.name}")
309305

310306
cursor.close()
@@ -419,7 +415,7 @@ def query(
419415
cursor = self.client.cursor()
420416
results = []
421417
for query_text in query_texts:
422-
vector = self.embedding_function.encode(query_text, convert_to_tensor=False).tolist()
418+
vector = self.embedding_function(query_text, convert_to_tensor=False).tolist()
423419
if distance_type.lower() == "cosine":
424420
index_function = "<=>"
425421
elif distance_type.lower() == "euclidean":
@@ -526,22 +522,31 @@ def delete_collection(self, collection_name: Optional[str] = None) -> None:
526522
cursor.execute(f"DROP TABLE IF EXISTS {self.name}")
527523
cursor.close()
528524

529-
def create_collection(self, collection_name: Optional[str] = None) -> None:
525+
def create_collection(
526+
self, collection_name: Optional[str] = None, dimension: Optional[Union[str, int]] = None
527+
) -> None:
530528
"""
531529
Create a new collection.
532530
533531
Args:
534532
collection_name (Optional[str]): The name of the new collection.
533+
dimension (Optional[Union[str, int]]): The dimension size of the sentence embedding model
535534
536535
Returns:
537536
None
538537
"""
539538
if collection_name:
540539
self.name = collection_name
540+
541+
if dimension:
542+
self.dimension = dimension
543+
elif self.dimension is None:
544+
self.dimension = 384
545+
541546
cursor = self.client.cursor()
542547
cursor.execute(
543548
f"CREATE TABLE {self.name} ("
544-
f"documents text, id CHAR(8) PRIMARY KEY, metadatas JSONB, embedding vector(384));"
549+
f"documents text, id CHAR(8) PRIMARY KEY, metadatas JSONB, embedding vector({self.dimension}));"
545550
f"CREATE INDEX "
546551
f'ON {self.name} USING hnsw (embedding vector_l2_ops) WITH (m = {self.metadata["hnsw:M"]}, '
547552
f'ef_construction = {self.metadata["hnsw:construction_ef"]});'
@@ -573,7 +578,6 @@ def __init__(
573578
connect_timeout: Optional[int] = 10,
574579
embedding_function: Callable = None,
575580
metadata: Optional[dict] = None,
576-
model_name: Optional[str] = "all-MiniLM-L6-v2",
577581
) -> None:
578582
"""
579583
Initialize the vector database.
@@ -591,15 +595,14 @@ def __init__(
591595
username: str | The database username to use. Default is None.
592596
password: str | The database user password to use. Default is None.
593597
connect_timeout: int | The timeout to set for the connection. Default is 10.
594-
embedding_function: Callable | The embedding function used to generate the vector representation
595-
of the documents. Default is None.
598+
embedding_function: Callable | The embedding function used to generate the vector representation.
599+
Default is None. SentenceTransformer("all-MiniLM-L6-v2").encode will be used when None.
600+
Models can be chosen from:
601+
https://huggingface.co/models?library=sentence-transformers
596602
metadata: dict | The metadata of the vector database. Default is None. If None, it will use this
597603
setting: {"hnsw:space": "ip", "hnsw:construction_ef": 32, "hnsw:M": 16}. Creates Index on table
598604
using hnsw (embedding vector_l2_ops) WITH (m = "hnsw:M", ef_construction = "hnsw:construction_ef").
599605
For more info: https://github.com/pgvector/pgvector?tab=readme-ov-file#hnsw
600-
model_name: str | Sentence embedding model to use. Models can be chosen from:
601-
https://huggingface.co/models?library=sentence-transformers
602-
603606
Returns:
604607
None
605608
"""
@@ -613,17 +616,10 @@ def __init__(
613616
password=password,
614617
connect_timeout=connect_timeout,
615618
)
616-
self.model_name = model_name
617-
try:
618-
self.embedding_function = (
619-
SentenceTransformer(self.model_name) if embedding_function is None else embedding_function
620-
)
621-
except Exception as e:
622-
logger.error(
623-
f"Validate the model name entered: {self.model_name} "
624-
f"from https://huggingface.co/models?library=sentence-transformers\nError: {e}"
625-
)
626-
raise e
619+
if embedding_function:
620+
self.embedding_function = embedding_function
621+
else:
622+
self.embedding_function = SentenceTransformer("all-MiniLM-L6-v2").encode
627623
self.metadata = metadata
628624
register_vector(self.client)
629625
self.active_collection = None
@@ -738,7 +734,6 @@ def create_collection(
738734
embedding_function=self.embedding_function,
739735
get_or_create=get_or_create,
740736
metadata=self.metadata,
741-
model_name=self.model_name,
742737
)
743738
collection.set_collection_name(collection_name=collection_name)
744739
collection.create_collection(collection_name=collection_name)
@@ -751,7 +746,6 @@ def create_collection(
751746
embedding_function=self.embedding_function,
752747
get_or_create=get_or_create,
753748
metadata=self.metadata,
754-
model_name=self.model_name,
755749
)
756750
collection.set_collection_name(collection_name=collection_name)
757751
collection.create_collection(collection_name=collection_name)
@@ -765,7 +759,6 @@ def create_collection(
765759
embedding_function=self.embedding_function,
766760
get_or_create=get_or_create,
767761
metadata=self.metadata,
768-
model_name=self.model_name,
769762
)
770763
collection.set_collection_name(collection_name=collection_name)
771764
collection.create_collection(collection_name=collection_name)
@@ -797,7 +790,6 @@ def get_collection(self, collection_name: str = None) -> Collection:
797790
client=self.client,
798791
collection_name=collection_name,
799792
embedding_function=self.embedding_function,
800-
model_name=self.model_name,
801793
)
802794
return self.active_collection
803795

0 commit comments

Comments
 (0)