@@ -32,10 +32,11 @@ class Collection:
32
32
client: The PGVector client.
33
33
collection_name (str): The name of the collection. Default is "documents".
34
34
embedding_function (Callable): The embedding function used to generate the vector representation.
35
+ Default is None. SentenceTransformer("all-MiniLM-L6-v2").encode will be used when None.
36
+ Models can be chosen from:
37
+ https://huggingface.co/models?library=sentence-transformers
35
38
metadata (Optional[dict]): The metadata of the collection.
36
39
get_or_create (Optional): The flag indicating whether to get or create the collection.
37
- model_name: (Optional str) | Sentence embedding model to use. Models can be chosen from:
38
- https://huggingface.co/models?library=sentence-transformers
39
40
"""
40
41
41
42
def __init__ (
@@ -45,7 +46,6 @@ def __init__(
45
46
embedding_function : Callable = None ,
46
47
metadata = None ,
47
48
get_or_create = None ,
48
- model_name = "all-MiniLM-L6-v2" ,
49
49
):
50
50
"""
51
51
Initialize the Collection object.
@@ -56,30 +56,26 @@ def __init__(
56
56
embedding_function: The embedding function used to generate the vector representation.
57
57
metadata: The metadata of the collection.
58
58
get_or_create: The flag indicating whether to get or create the collection.
59
- model_name: | Sentence embedding model to use. Models can be chosen from:
60
- https://huggingface.co/models?library=sentence-transformers
61
59
Returns:
62
60
None
63
61
"""
64
62
self .client = client
65
- self .embedding_function = embedding_function
66
- self .model_name = model_name
67
63
self .name = self .set_collection_name (collection_name )
68
64
self .require_embeddings_or_documents = False
69
65
self .ids = []
70
- try :
71
- self .embedding_function = (
72
- SentenceTransformer (self .model_name ) if embedding_function is None else embedding_function
73
- )
74
- except Exception as e :
75
- logger .error (
76
- f"Validate the model name entered: { self .model_name } "
77
- f"from https://huggingface.co/models?library=sentence-transformers\n Error: { e } "
78
- )
79
- raise e
66
+ if embedding_function :
67
+ self .embedding_function = embedding_function
68
+ else :
69
+ self .embedding_function = SentenceTransformer ("all-MiniLM-L6-v2" ).encode
80
70
self .metadata = metadata if metadata else {"hnsw:space" : "ip" , "hnsw:construction_ef" : 32 , "hnsw:M" : 16 }
81
71
self .documents = ""
82
72
self .get_or_create = get_or_create
73
+ # Infer the embedding dimension by embedding one sample sentence and measuring the resulting vector's length.
74
+ sentences = [
75
+ "The weather is lovely today in paradise." ,
76
+ ]
77
+ embeddings = self .embedding_function (sentences )
78
+ self .dimension = len (embeddings [0 ])
83
79
84
80
def set_collection_name (self , collection_name ) -> str :
85
81
name = re .sub ("-" , "_" , collection_name )
@@ -115,14 +111,14 @@ def add(self, ids: List[ItemID], documents: List, embeddings: List = None, metad
115
111
elif metadatas is not None :
116
112
for doc_id , metadata , document in zip (ids , metadatas , documents ):
117
113
metadata = re .sub ("'" , '"' , str (metadata ))
118
- embedding = self .embedding_function . encode (document )
114
+ embedding = self .embedding_function (document )
119
115
sql_values .append ((doc_id , metadata , embedding , document ))
120
116
sql_string = (
121
117
f"INSERT INTO { self .name } (id, metadatas, embedding, documents)\n " f"VALUES (%s, %s, %s, %s);\n "
122
118
)
123
119
else :
124
120
for doc_id , document in zip (ids , documents ):
125
- embedding = self .embedding_function . encode (document )
121
+ embedding = self .embedding_function (document )
126
122
sql_values .append ((doc_id , document , embedding ))
127
123
sql_string = f"INSERT INTO { self .name } (id, documents, embedding)\n " f"VALUES (%s, %s, %s);\n "
128
124
logger .debug (f"Add SQL String:\n { sql_string } \n { sql_values } " )
@@ -166,7 +162,7 @@ def upsert(self, ids: List[ItemID], documents: List, embeddings: List = None, me
166
162
elif metadatas is not None :
167
163
for doc_id , metadata , document in zip (ids , metadatas , documents ):
168
164
metadata = re .sub ("'" , '"' , str (metadata ))
169
- embedding = self .embedding_function . encode (document )
165
+ embedding = self .embedding_function (document )
170
166
sql_values .append ((doc_id , metadata , embedding , document , metadata , document , embedding ))
171
167
sql_string = (
172
168
f"INSERT INTO { self .name } (id, metadatas, embedding, documents)\n "
@@ -176,7 +172,7 @@ def upsert(self, ids: List[ItemID], documents: List, embeddings: List = None, me
176
172
)
177
173
else :
178
174
for doc_id , document in zip (ids , documents ):
179
- embedding = self .embedding_function . encode (document )
175
+ embedding = self .embedding_function (document )
180
176
sql_values .append ((doc_id , document , embedding , document ))
181
177
sql_string = (
182
178
f"INSERT INTO { self .name } (id, documents, embedding)\n "
@@ -304,7 +300,7 @@ def get(
304
300
)
305
301
except (psycopg .errors .UndefinedTable , psycopg .errors .UndefinedColumn ) as e :
306
302
logger .info (f"Error executing select on non-existent table: { self .name } . Creating it instead. Error: { e } " )
307
- self .create_collection (collection_name = self .name )
303
+ self .create_collection (collection_name = self .name , dimension = self . dimension )
308
304
logger .info (f"Created table { self .name } " )
309
305
310
306
cursor .close ()
@@ -419,7 +415,7 @@ def query(
419
415
cursor = self .client .cursor ()
420
416
results = []
421
417
for query_text in query_texts :
422
- vector = self .embedding_function . encode (query_text , convert_to_tensor = False ).tolist ()
418
+ vector = self .embedding_function (query_text , convert_to_tensor = False ).tolist ()
423
419
if distance_type .lower () == "cosine" :
424
420
index_function = "<=>"
425
421
elif distance_type .lower () == "euclidean" :
@@ -526,22 +522,31 @@ def delete_collection(self, collection_name: Optional[str] = None) -> None:
526
522
cursor .execute (f"DROP TABLE IF EXISTS { self .name } " )
527
523
cursor .close ()
528
524
529
- def create_collection (self , collection_name : Optional [str ] = None ) -> None :
525
+ def create_collection (
526
+ self , collection_name : Optional [str ] = None , dimension : Optional [Union [str , int ]] = None
527
+ ) -> None :
530
528
"""
531
529
Create a new collection.
532
530
533
531
Args:
534
532
collection_name (Optional[str]): The name of the new collection.
533
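+ dimension (Optional[Union[str, int]]): The size of the embedding vectors, used for the `embedding vector(...)` column. If omitted, the collection's current dimension is kept, falling back to 384 when none is set.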
535
534
536
535
Returns:
537
536
None
538
537
"""
539
538
if collection_name :
540
539
self .name = collection_name
540
+
541
+ if dimension :
542
+ self .dimension = dimension
543
+ elif self .dimension is None :
544
+ self .dimension = 384
545
+
541
546
cursor = self .client .cursor ()
542
547
cursor .execute (
543
548
f"CREATE TABLE { self .name } ("
544
- f"documents text, id CHAR(8) PRIMARY KEY, metadatas JSONB, embedding vector(384 ));"
549
+ f"documents text, id CHAR(8) PRIMARY KEY, metadatas JSONB, embedding vector({ self . dimension } ));"
545
550
f"CREATE INDEX "
546
551
f'ON { self .name } USING hnsw (embedding vector_l2_ops) WITH (m = { self .metadata ["hnsw:M" ]} , '
547
552
f'ef_construction = { self .metadata ["hnsw:construction_ef" ]} );'
@@ -573,7 +578,6 @@ def __init__(
573
578
connect_timeout : Optional [int ] = 10 ,
574
579
embedding_function : Callable = None ,
575
580
metadata : Optional [dict ] = None ,
576
- model_name : Optional [str ] = "all-MiniLM-L6-v2" ,
577
581
) -> None :
578
582
"""
579
583
Initialize the vector database.
@@ -591,15 +595,14 @@ def __init__(
591
595
username: str | The database username to use. Default is None.
592
596
password: str | The database user password to use. Default is None.
593
597
connect_timeout: int | The timeout to set for the connection. Default is 10.
594
- embedding_function: Callable | The embedding function used to generate the vector representation
595
- of the documents. Default is None.
598
+ embedding_function: Callable | The embedding function used to generate the vector representation.
599
+ Default is None. SentenceTransformer("all-MiniLM-L6-v2").encode will be used when None.
600
+ Models can be chosen from:
601
+ https://huggingface.co/models?library=sentence-transformers
596
602
metadata: dict | The metadata of the vector database. Default is None. If None, it will use this
597
603
setting: {"hnsw:space": "ip", "hnsw:construction_ef": 32, "hnsw:M": 16}. Creates Index on table
598
604
using hnsw (embedding vector_l2_ops) WITH (m = hnsw:M) ef_construction = "hnsw:construction_ef".
599
605
For more info: https://github.com/pgvector/pgvector?tab=readme-ov-file#hnsw
600
- model_name: str | Sentence embedding model to use. Models can be chosen from:
601
- https://huggingface.co/models?library=sentence-transformers
602
-
603
606
Returns:
604
607
None
605
608
"""
@@ -613,17 +616,10 @@ def __init__(
613
616
password = password ,
614
617
connect_timeout = connect_timeout ,
615
618
)
616
- self .model_name = model_name
617
- try :
618
- self .embedding_function = (
619
- SentenceTransformer (self .model_name ) if embedding_function is None else embedding_function
620
- )
621
- except Exception as e :
622
- logger .error (
623
- f"Validate the model name entered: { self .model_name } "
624
- f"from https://huggingface.co/models?library=sentence-transformers\n Error: { e } "
625
- )
626
- raise e
619
+ if embedding_function :
620
+ self .embedding_function = embedding_function
621
+ else :
622
+ self .embedding_function = SentenceTransformer ("all-MiniLM-L6-v2" ).encode
627
623
self .metadata = metadata
628
624
register_vector (self .client )
629
625
self .active_collection = None
@@ -738,7 +734,6 @@ def create_collection(
738
734
embedding_function = self .embedding_function ,
739
735
get_or_create = get_or_create ,
740
736
metadata = self .metadata ,
741
- model_name = self .model_name ,
742
737
)
743
738
collection .set_collection_name (collection_name = collection_name )
744
739
collection .create_collection (collection_name = collection_name )
@@ -751,7 +746,6 @@ def create_collection(
751
746
embedding_function = self .embedding_function ,
752
747
get_or_create = get_or_create ,
753
748
metadata = self .metadata ,
754
- model_name = self .model_name ,
755
749
)
756
750
collection .set_collection_name (collection_name = collection_name )
757
751
collection .create_collection (collection_name = collection_name )
@@ -765,7 +759,6 @@ def create_collection(
765
759
embedding_function = self .embedding_function ,
766
760
get_or_create = get_or_create ,
767
761
metadata = self .metadata ,
768
- model_name = self .model_name ,
769
762
)
770
763
collection .set_collection_name (collection_name = collection_name )
771
764
collection .create_collection (collection_name = collection_name )
@@ -797,7 +790,6 @@ def get_collection(self, collection_name: str = None) -> Collection:
797
790
client = self .client ,
798
791
collection_name = collection_name ,
799
792
embedding_function = self .embedding_function ,
800
- model_name = self .model_name ,
801
793
)
802
794
return self .active_collection
803
795
0 commit comments