From 0bfc015389f9d24d8244910417a352228849aa56 Mon Sep 17 00:00:00 2001 From: Luke Hinds Date: Sat, 12 Oct 2024 22:54:45 +0100 Subject: [PATCH 1/5] Add QDrant --- docker-compose.yaml | 12 ++- examples/qdrant/main.go | 81 +++++++++++++++++ go.mod | 17 ++-- pkg/db/qdrant.go | 188 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 292 insertions(+), 6 deletions(-) create mode 100644 examples/qdrant/main.go create mode 100644 pkg/db/qdrant.go diff --git a/docker-compose.yaml b/docker-compose.yaml index 3162143..c450078 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -1,4 +1,4 @@ -version: '3.8' +version: '3.9' # Requires version 3.9+ for profiles services: postgres: @@ -13,6 +13,16 @@ services: volumes: - pgdata:/var/lib/postgresql/data - ./db/init.sql:/docker-entrypoint-initdb.d/init.sql + profiles: + - postgres # Enable this service only when the 'postgres' profile is active + + qdrant: + image: qdrant/qdrant + container_name: qdrant-db + ports: + - "6334:6334" + profiles: + - qdrant # Enable this service only when the 'qdrant' profile is active volumes: pgdata: diff --git a/examples/qdrant/main.go b/examples/qdrant/main.go new file mode 100644 index 0000000..e11b498 --- /dev/null +++ b/examples/qdrant/main.go @@ -0,0 +1,81 @@ +package main + +import ( + "context" + "fmt" + "github.com/google/uuid" + "github.com/stackloklabs/gollm/pkg/db" + "log" + "time" +) + +func main() { + qdrantVector, err := db.NewQdrantVector("localhost", 6334) + if err != nil { + log.Fatalf("Failed to connect to Qdrant: %v", err) + } + defer qdrantVector.Close() + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + // Create the collection + err = CreateCollection(ctx, qdrantVector) + if err != nil { + log.Fatalf("Failed to create collection: %v", err) + } + + // Example embedding and content + embedding := []float32{0.05, 0.61, 0.76, 0.74} + content := "This is a test document." 
+ + // Insert the document into the Qdrant vector store + err = QDrantInsertDocument(ctx, qdrantVector, content, embedding) + if err != nil { + log.Fatalf("Failed to insert document: %v", err) + } + log.Println("Document inserted successfully.") + + // Query the most relevant documents based on a given embedding + docs, err := qdrantVector.QueryRelevantDocuments(ctx, embedding, 5) + if err != nil { + log.Fatalf("Failed to query documents: %v", err) + } + + // Print out the results + for _, doc := range docs { + log.Printf("Document ID: %s, Content: %v\n", doc.ID, doc.Metadata["content"]) + } +} + +// CreateCollection creates a new collection in Qdrant +func CreateCollection(ctx context.Context, vectorDB *db.QdrantVector) error { + collectionName := "gollm" // Replace with your collection name + vectorSize := uint64(4) // Size of the embedding vectors + distance := "Cosine" // Distance metric (Cosine, Euclidean, etc.) + + // Call Qdrant's API to create the collection + err := vectorDB.CreateCollection(ctx, collectionName, vectorSize, distance) + if err != nil { + return fmt.Errorf("error creating collection: %v", err) + } + return nil +} + +// QDrantInsertDocument inserts a document into the Qdrant vector store. 
+func QDrantInsertDocument(ctx context.Context, vectorDB *db.QdrantVector, content string, embedding []float32) error { + // Generate a valid UUID for the document ID + docID := uuid.New().String() // Use pure UUID without the 'doc-' prefix + + // Create metadata for the document + metadata := map[string]interface{}{ + "content": content, + } + + // Save the document and its embedding + err := vectorDB.SaveEmbedding(ctx, docID, embedding, metadata) + if err != nil { + return fmt.Errorf("error saving embedding: %v", err) + } + return nil +} diff --git a/go.mod b/go.mod index 15948f8..1149756 100644 --- a/go.mod +++ b/go.mod @@ -1,11 +1,14 @@ module github.com/stackloklabs/gollm -go 1.22.1 +go 1.22.2 + +toolchain go1.22.8 require ( github.com/google/uuid v1.6.0 github.com/jackc/pgx/v4 v4.18.3 github.com/pgvector/pgvector-go v0.2.2 + github.com/qdrant/go-client v1.12.0 ) require ( @@ -18,9 +21,13 @@ require ( github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/pgtype v1.14.0 // indirect github.com/jackc/puddle v1.3.0 // indirect - github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect - github.com/stretchr/testify v1.9.0 // indirect - golang.org/x/crypto v0.25.0 // indirect - golang.org/x/text v0.16.0 // indirect + github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect + golang.org/x/crypto v0.27.0 // indirect + golang.org/x/net v0.29.0 // indirect + golang.org/x/sys v0.25.0 // indirect + golang.org/x/text v0.18.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect + google.golang.org/grpc v1.66.2 // indirect + google.golang.org/protobuf v1.34.2 // indirect ) diff --git a/pkg/db/qdrant.go b/pkg/db/qdrant.go new file mode 100644 index 0000000..4105875 --- /dev/null +++ b/pkg/db/qdrant.go @@ -0,0 +1,188 @@ +package db + +import ( + "context" + "fmt" + "github.com/qdrant/go-client/qdrant" +) + +// QdrantVector 
represents a connection to Qdrant. +type QdrantVector struct { + client *qdrant.Client +} + +// QDrantDocument represents a document stored in the Qdrant vector database. +type QDrantDocument struct { + ID string + Metadata map[string]interface{} +} + +// NewQdrantVector initializes a connection to Qdrant. +// +// Parameters: +// - address: The Qdrant server address (e.g., "localhost"). +// - port: The port Qdrant is running on (e.g., 6333). +// +// Returns: +// - A pointer to a new QdrantVector instance. +// - An error if the connection fails, nil otherwise. +func NewQdrantVector(address string, port int) (*QdrantVector, error) { + client, err := qdrant.NewClient(&qdrant.Config{ + Host: address, + Port: port, + }) + if err != nil { + return nil, fmt.Errorf("failed to connect to Qdrant: %w", err) + } + return &QdrantVector{client: client}, nil +} + +// Close closes the Qdrant client connection. +func (qv *QdrantVector) Close() { + qv.client.Close() +} + +// SaveEmbedding stores an embedding and metadata in Qdrant. +// +// Parameters: +// - ctx: Context for the operation. +// - docID: A unique identifier for the document. +// - embedding: A slice of float32 values representing the document's embedding. +// - metadata: A map of additional information associated with the document. +// +// Returns: +// - An error if the saving operation fails, nil otherwise. 
+func (qv *QdrantVector) SaveEmbedding(ctx context.Context, docID string, embedding []float32, metadata map[string]interface{}) error { + point := &qdrant.PointStruct{ + Id: qdrant.NewID(docID), + Vectors: qdrant.NewVectors(embedding...), + Payload: qdrant.NewValueMap(metadata), + } + + waitUpsert := true + _, err := qv.client.Upsert(ctx, &qdrant.UpsertPoints{ + CollectionName: "gollm", // Replace with actual collection name + Wait: &waitUpsert, + Points: []*qdrant.PointStruct{point}, + }) + if err != nil { + return fmt.Errorf("failed to insert point: %w", err) + } + return nil +} + +// QueryRelevantDocuments retrieves the most relevant documents based on a given embedding. +// +// Parameters: +// - ctx: The context for the query. +// - embedding: The query embedding. +// - limit: The number of documents to return. +// +// Returns: +// - A slice of QDrantDocument structs containing the most relevant documents. +// - An error if the query fails. +func (qv *QdrantVector) QueryRelevantDocuments(ctx context.Context, embedding []float32, limit int) ([]QDrantDocument, error) { + limitUint := uint64(limit) // Convert limit to uint64 + query := &qdrant.QueryPoints{ + CollectionName: "gollm", // Replace with actual collection name + Query: qdrant.NewQuery(embedding...), + Limit: &limitUint, + WithPayload: qdrant.NewWithPayloadInclude("content"), + } + + response, err := qv.client.Query(ctx, query) + if err != nil { + return nil, fmt.Errorf("failed to search points: %w", err) + } + + var docs []QDrantDocument + for _, point := range response { + var docID string + if numericID := point.Id.GetNum(); numericID != 0 { + docID = fmt.Sprintf("%d", numericID) // Numeric ID + } else { + docID = point.Id.GetUuid() // UUID + } + metadata := convertPayloadToMap(point.Payload) + doc := QDrantDocument{ + ID: docID, + Metadata: metadata, + } + docs = append(docs, doc) + } + return docs, nil +} + +// convertPayloadToMap converts a Qdrant Payload (map[string]*qdrant.Value) into a 
map[string]interface{}. +func convertPayloadToMap(payload map[string]*qdrant.Value) map[string]interface{} { + result := make(map[string]interface{}) + for key, value := range payload { + switch v := value.Kind.(type) { + case *qdrant.Value_StringValue: + result[key] = v.StringValue + case *qdrant.Value_BoolValue: + result[key] = v.BoolValue + case *qdrant.Value_DoubleValue: + result[key] = v.DoubleValue + case *qdrant.Value_ListValue: + var list []interface{} + for _, item := range v.ListValue.Values { + switch itemVal := item.Kind.(type) { + case *qdrant.Value_StringValue: + list = append(list, itemVal.StringValue) + case *qdrant.Value_BoolValue: + list = append(list, itemVal.BoolValue) + case *qdrant.Value_DoubleValue: + list = append(list, itemVal.DoubleValue) + } + } + result[key] = list + default: + result[key] = nil + } + } + return result +} + +// QDrantInsertDocument inserts a document into the Qdrant vector store. +// +// Parameters: +// - ctx: Context for the operation. +// - vectorDB: A QdrantVector instance. +// - content: The document content to be inserted. +// - embedding: The embedding vector for the document. +// +// Returns: +// - An error if the operation fails, nil otherwise. 
+func QDrantInsertDocument(ctx context.Context, vectorDB *QdrantVector, content string, embedding []float32) error { + // Generate a unique document ID (e.g., using UUID) + docID := fmt.Sprintf("doc-%s", qdrant.NewID("")) + + // Create metadata for the document + metadata := map[string]interface{}{ + "content": content, + } + + // Save the document and its embedding + err := vectorDB.SaveEmbedding(ctx, docID, embedding, metadata) + if err != nil { + return fmt.Errorf("error saving embedding: %v", err) + } + return nil +} + +// CreateCollection creates a new collection in Qdrant +func (qv *QdrantVector) CreateCollection(ctx context.Context, collectionName string, vectorSize uint64, distance string) error { + // Create the collection + err := qv.client.CreateCollection(ctx, &qdrant.CreateCollection{ + CollectionName: collectionName, + VectorsConfig: qdrant.NewVectorsConfig(&qdrant.VectorParams{ + Size: vectorSize, + Distance: qdrant.Distance_Cosine, // Example: Cosine distance + }), + }) + if err != nil { + return fmt.Errorf("failed to create collection: %w", err) + } + return nil +} From 41388d2c54eb699a1739893649471170d9fc859e Mon Sep 17 00:00:00 2001 From: Luke Hinds Date: Sun, 13 Oct 2024 09:08:09 +0100 Subject: [PATCH 2/5] Add QDrant --- examples/qdrant/main.go | 13 ++++++++----- pkg/db/pgvector.go | 8 +------- pkg/db/qdrant.go | 39 ++++++++++++++++++++++++++------------- pkg/db/vectordb.go | 37 +++++++++++++++++++++++++++++++++++++ 4 files changed, 72 insertions(+), 25 deletions(-) create mode 100644 pkg/db/vectordb.go diff --git a/examples/qdrant/main.go b/examples/qdrant/main.go index e11b498..16016e7 100644 --- a/examples/qdrant/main.go +++ b/examples/qdrant/main.go @@ -10,16 +10,19 @@ import ( ) func main() { + // Initialize Qdrant vector connection qdrantVector, err := db.NewQdrantVector("localhost", 6334) if err != nil { log.Fatalf("Failed to connect to Qdrant: %v", err) } + // Defer the Close() method to ensure the connection is properly closed after use 
defer qdrantVector.Close() + // Set up a context with a timeout for the Qdrant operations ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - // Create the collection + // Create the collection in Qdrant err = CreateCollection(ctx, qdrantVector) if err != nil { log.Fatalf("Failed to create collection: %v", err) @@ -42,7 +45,7 @@ func main() { log.Fatalf("Failed to query documents: %v", err) } - // Print out the results + // Print out the retrieved documents for _, doc := range docs { log.Printf("Document ID: %s, Content: %v\n", doc.ID, doc.Metadata["content"]) } @@ -50,9 +53,9 @@ func main() { // CreateCollection creates a new collection in Qdrant func CreateCollection(ctx context.Context, vectorDB *db.QdrantVector) error { - collectionName := "gollm" // Replace with your collection name - vectorSize := uint64(4) // Size of the embedding vectors - distance := "Cosine" // Distance metric (Cosine, Euclidean, etc.) + collectionName := "gollm2" // Replace with your collection name + vectorSize := uint64(4) // Size of the embedding vectors + distance := "Cosine" // Distance metric (Cosine, Euclidean, etc.) // Call Qdrant's API to create the collection err := vectorDB.CreateCollection(ctx, collectionName, vectorSize, distance) diff --git a/pkg/db/pgvector.go b/pkg/db/pgvector.go index 445da21..26d1177 100644 --- a/pkg/db/pgvector.go +++ b/pkg/db/pgvector.go @@ -16,6 +16,7 @@ // It includes implementations for vector storage and retrieval using PostgreSQL // with the pgvector extension, enabling efficient similarity search operations // on high-dimensional vector data. + package db import ( @@ -34,13 +35,6 @@ type PGVector struct { conn *pgxpool.Pool } -// Document represents a single document in the vector database. -// It contains a unique identifier and associated metadata. 
-type Document struct { - ID string - Metadata map[string]interface{} -} - // NewPGVector creates a new PGVector instance with a connection to the PostgreSQL database. // // Parameters: diff --git a/pkg/db/qdrant.go b/pkg/db/qdrant.go index 4105875..ef5fc56 100644 --- a/pkg/db/qdrant.go +++ b/pkg/db/qdrant.go @@ -1,3 +1,22 @@ +// Copyright 2024 Stacklok, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package db provides database-related functionality for the application. +// It includes implementations for vector storage and retrieval using PostgreSQL +// with the pgvector extension, enabling efficient similarity search operations +// on high-dimensional vector data. + package db import ( @@ -11,10 +30,9 @@ type QdrantVector struct { client *qdrant.Client } -// QDrantDocument represents a document stored in the Qdrant vector database. -type QDrantDocument struct { - ID string - Metadata map[string]interface{} +// Close closes the Qdrant client connection. +func (qv *QdrantVector) Close() { + qv.client.Close() } // NewQdrantVector initializes a connection to Qdrant. @@ -37,11 +55,6 @@ func NewQdrantVector(address string, port int) (*QdrantVector, error) { return &QdrantVector{client: client}, nil } -// Close closes the Qdrant client connection. -func (qv *QdrantVector) Close() { - qv.client.Close() -} - // SaveEmbedding stores an embedding and metadata in Qdrant. 
// // Parameters: @@ -81,7 +94,7 @@ func (qv *QdrantVector) SaveEmbedding(ctx context.Context, docID string, embeddi // Returns: // - A slice of QDrantDocument structs containing the most relevant documents. // - An error if the query fails. -func (qv *QdrantVector) QueryRelevantDocuments(ctx context.Context, embedding []float32, limit int) ([]QDrantDocument, error) { +func (qv *QdrantVector) QueryRelevantDocuments(ctx context.Context, embedding []float32, limit int) ([]Document, error) { limitUint := uint64(limit) // Convert limit to uint64 query := &qdrant.QueryPoints{ CollectionName: "gollm", // Replace with actual collection name @@ -95,7 +108,7 @@ func (qv *QdrantVector) QueryRelevantDocuments(ctx context.Context, embedding [] return nil, fmt.Errorf("failed to search points: %w", err) } - var docs []QDrantDocument + var docs []Document for _, point := range response { var docID string if numericID := point.Id.GetNum(); numericID != 0 { @@ -104,7 +117,7 @@ func (qv *QdrantVector) QueryRelevantDocuments(ctx context.Context, embedding [] docID = point.Id.GetUuid() // UUID } metadata := convertPayloadToMap(point.Payload) - doc := QDrantDocument{ + doc := Document{ ID: docID, Metadata: metadata, } @@ -154,7 +167,7 @@ func convertPayloadToMap(payload map[string]*qdrant.Value) map[string]interface{ // // Returns: // - An error if the operation fails, nil otherwise. 
-func QDrantInsertDocument(ctx context.Context, vectorDB *QdrantVector, content string, embedding []float32) error { +func (qv *QdrantVector) InsertDocument(ctx context.Context, vectorDB *QdrantVector, content string, embedding []float32) error { // Generate a unique document ID (e.g., using UUID) docID := fmt.Sprintf("doc-%s", qdrant.NewID("")) diff --git a/pkg/db/vectordb.go b/pkg/db/vectordb.go new file mode 100644 index 0000000..f640814 --- /dev/null +++ b/pkg/db/vectordb.go @@ -0,0 +1,37 @@ +// Copyright 2024 Stacklok, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package db provides database-related functionality for the application. +// It includes implementations for vector storage and retrieval using PostgreSQL +// with the pgvector extension, enabling efficient similarity search operations +// on high-dimensional vector data. + +package db + +import ( + "context" +) + +// Document represents a single document in the vector database. +// It contains a unique identifier and associated metadata. 
+type Document struct { + ID string + Metadata map[string]interface{} +} + +// VectorDatabase is the interface that both QdrantVector and PGVector implement +type VectorDatabase interface { + InsertDocument(ctx context.Context, content string, embedding []float32) error + QueryRelevantDocuments(ctx context.Context, embedding []float32, backend string) ([]Document, error) +} From 21ad96cffcfa1cc5fa56f6779e61c715b10ac11e Mon Sep 17 00:00:00 2001 From: Luke Hinds Date: Sun, 13 Oct 2024 10:59:35 +0100 Subject: [PATCH 3/5] Add QDrant --- examples/ollama/main.go | 8 ++++---- examples/openai/main.go | 6 +++++- examples/qdrant/main.go | 12 ++++++------ pkg/db/pgvector.go | 26 +++++++++++++++----------- pkg/db/qdrant.go | 19 ++++++++++++------- pkg/db/qdrant_test.go | 1 + pkg/db/vectordb.go | 1 + 7 files changed, 44 insertions(+), 29 deletions(-) create mode 100644 pkg/db/qdrant_test.go diff --git a/examples/ollama/main.go b/examples/ollama/main.go index 0190683..822dbdb 100644 --- a/examples/ollama/main.go +++ b/examples/ollama/main.go @@ -25,14 +25,11 @@ func main() { var generationBackend backend.Backend // Choose the backend for embeddings based on the config - embeddingBackend = backend.NewOllamaBackend(ollamaHost, ollamaEmbModel) - log.Printf("Embedding backend LLM: %s", ollamaEmbModel) // Choose the backend for generation based on the config generationBackend = backend.NewOllamaBackend(ollamaHost, ollamaGenModel) - log.Printf("Generation backend: %s", ollamaGenModel) // Initialize the vector database @@ -42,6 +39,9 @@ func main() { } log.Println("Vector database initialized") + // Make sure to close the connection when done + defer vectorDB.Close() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -58,7 +58,7 @@ func main() { log.Println("Embedding generated") // Insert the document into the vector store - err = db.InsertDocument(ctx, vectorDB, ragContent, embedding) + err = vectorDB.InsertDocument(ctx, ragContent, 
embedding) if err != nil { log.Fatalf("Error inserting document: %v", err) } diff --git a/examples/openai/main.go b/examples/openai/main.go index 9d41bcc..c437c44 100644 --- a/examples/openai/main.go +++ b/examples/openai/main.go @@ -49,6 +49,9 @@ func main() { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() + // Close the connection when done + defer vectorDB.Close() + // We insert contextual information into the vector store so that the RAG system // can use it to answer the query about the moon landing, effectively replacing 1969 with 2023 ragContent := "According to the Space Exploration Organization's official records, the moon landing occurred on July 20, 2023, during the Artemis Program. This mission marked the first successful crewed lunar landing since the Apollo program." @@ -62,7 +65,8 @@ func main() { log.Println("Embedding generated") // Insert the document into the vector store - err = db.InsertDocument(ctx, vectorDB, ragContent, embedding) + err = vectorDB.InsertDocument(ctx, ragContent, embedding) + if err != nil { log.Fatalf("Error inserting document: %v", err) } diff --git a/examples/qdrant/main.go b/examples/qdrant/main.go index 16016e7..1acf3a4 100644 --- a/examples/qdrant/main.go +++ b/examples/qdrant/main.go @@ -33,7 +33,7 @@ func main() { content := "This is a test document." // Insert the document into the Qdrant vector store - err = QDrantInsertDocument(ctx, qdrantVector, content, embedding) + err = qdrantVector.InsertDocument(ctx, content, embedding) if err != nil { log.Fatalf("Failed to insert document: %v", err) } @@ -53,9 +53,9 @@ func main() { // CreateCollection creates a new collection in Qdrant func CreateCollection(ctx context.Context, vectorDB *db.QdrantVector) error { - collectionName := "gollm2" // Replace with your collection name - vectorSize := uint64(4) // Size of the embedding vectors - distance := "Cosine" // Distance metric (Cosine, Euclidean, etc.) 
+ collectionName := "sddd" // Replace with your collection name + vectorSize := uint64(4) // Size of the embedding vectors + distance := "Cosine" // Distance metric (Cosine, Euclidean, etc.) // Call Qdrant's API to create the collection err := vectorDB.CreateCollection(ctx, collectionName, vectorSize, distance) @@ -66,7 +66,7 @@ func CreateCollection(ctx context.Context, vectorDB *db.QdrantVector) error { } // QDrantInsertDocument inserts a document into the Qdrant vector store. -func QDrantInsertDocument(ctx context.Context, vectorDB *db.QdrantVector, content string, embedding []float32) error { +func QDrantInsertDocument(ctx context.Context, vectorDB db.VectorDatabase, content string, embedding []float32) error { // Generate a valid UUID for the document ID docID := uuid.New().String() // Use pure UUID without the 'doc-' prefix @@ -76,7 +76,7 @@ func QDrantInsertDocument(ctx context.Context, vectorDB *db.QdrantVector, conten } // Save the document and its embedding - err := vectorDB.SaveEmbedding(ctx, docID, embedding, metadata) + err := vectorDB.SaveEmbeddings(ctx, docID, embedding, metadata) if err != nil { return fmt.Errorf("error saving embedding: %v", err) } diff --git a/pkg/db/pgvector.go b/pkg/db/pgvector.go index 26d1177..1ea8392 100644 --- a/pkg/db/pgvector.go +++ b/pkg/db/pgvector.go @@ -35,6 +35,11 @@ type PGVector struct { conn *pgxpool.Pool } +// Close closes the PostgreSQL connection pool. +func (pg *PGVector) Close() { + pg.conn.Close() +} + // NewPGVector creates a new PGVector instance with a connection to the PostgreSQL database. // // Parameters: @@ -61,26 +66,25 @@ func NewPGVector(connString string) (*PGVector, error) { // // Returns: // - An error if the saving operation fails, nil otherwise. 
-func (pg *PGVector) SaveEmbedding(ctx context.Context, docID string, embedding []float32, metadata map[string]interface{}) error { - // Create a pgvector.Vector type from the float32 slice - var query string +// +// SaveEmbeddings stores a document embedding and associated metadata in the PostgreSQL database, implementing the VectorDatabase interface. +func (pg *PGVector) SaveEmbeddings(ctx context.Context, docID string, embedding []float32, metadata map[string]interface{}) error { vector := pgvector.NewVector(embedding) + // Determine the table based on the embedding length + var query string switch len(embedding) { case 1536: query = `INSERT INTO openai_embeddings (doc_id, embedding, metadata) VALUES ($1, $2, $3)` - case 1024: query = `INSERT INTO ollama_embeddings (doc_id, embedding, metadata) VALUES ($1, $2, $3)` default: return fmt.Errorf("unsupported embedding length: %d", len(embedding)) } - // Construct the query to insert the vector into the database - // Execute the query with the pgvector.Vector type + // Execute the query to insert the vector into the database _, err := pg.conn.Exec(ctx, query, docID, vector, metadata) if err != nil { - // Log the error for debugging purposes return fmt.Errorf("failed to insert document: %w", err) } return nil @@ -173,9 +177,9 @@ func CombineQueryWithContext(query string, docs []Document) string { return fmt.Sprintf("Context: %s\nQuery: %s", contextStr, query) } -// InsertDocument insert a document into the vector store -func InsertDocument(ctx context.Context, vectorDB *PGVector, content string, embedding []float32) error { - // Generate a unique document ID (for simplicity, using a static value for testing) +// InsertDocument inserts a document into the PGVector store, implementing the VectorDatabase interface. 
+func (pg *PGVector) InsertDocument(ctx context.Context, content string, embedding []float32) error { + // Generate a unique document ID (for simplicity, using UUID) docID := fmt.Sprintf("doc-%s", uuid.New().String()) // Create metadata @@ -184,7 +188,7 @@ func InsertDocument(ctx context.Context, vectorDB *PGVector, content string, emb } // Save the document and its embedding into the vector store - err := vectorDB.SaveEmbedding(ctx, docID, embedding, metadata) + err := pg.SaveEmbeddings(ctx, docID, embedding, metadata) if err != nil { return fmt.Errorf("error saving embedding: %v", err) } diff --git a/pkg/db/qdrant.go b/pkg/db/qdrant.go index ef5fc56..4587f91 100644 --- a/pkg/db/qdrant.go +++ b/pkg/db/qdrant.go @@ -22,6 +22,8 @@ package db import ( "context" "fmt" + + "github.com/google/uuid" "github.com/qdrant/go-client/qdrant" ) @@ -65,7 +67,9 @@ func NewQdrantVector(address string, port int) (*QdrantVector, error) { // // Returns: // - An error if the saving operation fails, nil otherwise. -func (qv *QdrantVector) SaveEmbedding(ctx context.Context, docID string, embedding []float32, metadata map[string]interface{}) error { +// +// SaveEmbeddings stores an embedding and metadata in Qdrant, implementing the VectorDatabase interface. +func (qv *QdrantVector) SaveEmbeddings(ctx context.Context, docID string, embedding []float32, metadata map[string]interface{}) error { point := &qdrant.PointStruct{ Id: qdrant.NewID(docID), Vectors: qdrant.NewVectors(embedding...), @@ -167,17 +171,18 @@ func convertPayloadToMap(payload map[string]*qdrant.Value) map[string]interface{ // // Returns: // - An error if the operation fails, nil otherwise. 
-func (qv *QdrantVector) InsertDocument(ctx context.Context, vectorDB *QdrantVector, content string, embedding []float32) error { - // Generate a unique document ID (e.g., using UUID) - docID := fmt.Sprintf("doc-%s", qdrant.NewID("")) +// +// QdrantVector should implement the InsertDocument method as defined in VectorDatabase +func (qv *QdrantVector) InsertDocument(ctx context.Context, content string, embedding []float32) error { + // Generate a valid UUID for the document ID + docID := uuid.New().String() // Properly generate a UUID - // Create metadata for the document metadata := map[string]interface{}{ "content": content, } - // Save the document and its embedding - err := vectorDB.SaveEmbedding(ctx, docID, embedding, metadata) + // Call the SaveEmbeddings method to save the document and its embedding + err := qv.SaveEmbeddings(ctx, docID, embedding, metadata) if err != nil { return fmt.Errorf("error saving embedding: %v", err) } diff --git a/pkg/db/qdrant_test.go b/pkg/db/qdrant_test.go new file mode 100644 index 0000000..3a49c63 --- /dev/null +++ b/pkg/db/qdrant_test.go @@ -0,0 +1 @@ +package db diff --git a/pkg/db/vectordb.go b/pkg/db/vectordb.go index f640814..2b79504 100644 --- a/pkg/db/vectordb.go +++ b/pkg/db/vectordb.go @@ -34,4 +34,5 @@ type Document struct { type VectorDatabase interface { InsertDocument(ctx context.Context, content string, embedding []float32) error QueryRelevantDocuments(ctx context.Context, embedding []float32, backend string) ([]Document, error) + SaveEmbeddings(ctx context.Context, docID string, embedding []float32, metadata map[string]interface{}) error } From 4af7246252d3a35e5d9481173ce1375bad05087c Mon Sep 17 00:00:00 2001 From: Luke Hinds Date: Sun, 13 Oct 2024 22:12:33 +0100 Subject: [PATCH 4/5] Add QDrant --- examples/qdrant/main.go | 10 +++++----- pkg/db/qdrant.go | 2 +- pkg/db/qdrant_test.go | 1 - 3 files changed, 6 insertions(+), 7 deletions(-) delete mode 100644 pkg/db/qdrant_test.go diff --git a/examples/qdrant/main.go 
b/examples/qdrant/main.go index 1acf3a4..280c611 100644 --- a/examples/qdrant/main.go +++ b/examples/qdrant/main.go @@ -23,7 +23,8 @@ func main() { defer cancel() // Create the collection in Qdrant - err = CreateCollection(ctx, qdrantVector) + collection_name := uuid.New().String() + err = CreateCollection(ctx, qdrantVector, collection_name) if err != nil { log.Fatalf("Failed to create collection: %v", err) } @@ -52,10 +53,9 @@ func main() { } // CreateCollection creates a new collection in Qdrant -func CreateCollection(ctx context.Context, vectorDB *db.QdrantVector) error { - collectionName := "sddd" // Replace with your collection name - vectorSize := uint64(4) // Size of the embedding vectors - distance := "Cosine" // Distance metric (Cosine, Euclidean, etc.) +func CreateCollection(ctx context.Context, vectorDB *db.QdrantVector, collectionName string) error { + vectorSize := uint64(4) // Size of the embedding vectors + distance := "Cosine" // Distance metric (Cosine, Euclidean, etc.) // Call Qdrant's API to create the collection err := vectorDB.CreateCollection(ctx, collectionName, vectorSize, distance) diff --git a/pkg/db/qdrant.go b/pkg/db/qdrant.go index 4587f91..71971fb 100644 --- a/pkg/db/qdrant.go +++ b/pkg/db/qdrant.go @@ -161,7 +161,7 @@ func convertPayloadToMap(payload map[string]*qdrant.Value) map[string]interface{ return result } -// QDrantInsertDocument inserts a document into the Qdrant vector store. +// InsertDocument inserts a document into the Qdrant vector store. // // Parameters: // - ctx: Context for the operation. 
diff --git a/pkg/db/qdrant_test.go b/pkg/db/qdrant_test.go deleted file mode 100644 index 3a49c63..0000000 --- a/pkg/db/qdrant_test.go +++ /dev/null @@ -1 +0,0 @@ -package db From 186e1e6cef804644a9b0cee746bbe12b595fe789 Mon Sep 17 00:00:00 2001 From: Luke Hinds Date: Sat, 19 Oct 2024 16:48:36 +0100 Subject: [PATCH 5/5] Implements support for qdrant as a backend Example code included, this extends upon the generic backend interface --- examples/qdrant/main.go | 72 ++++++++-- go.mod | 3 + pkg/db/pgvector.go | 12 -- pkg/db/qdrant.go | 12 +- pkg/db/qdrant_test.go | 282 ++++++++++++++++++++++++++++++++++++++++ pkg/db/vectordb.go | 12 ++ 6 files changed, 364 insertions(+), 29 deletions(-) create mode 100644 pkg/db/qdrant_test.go diff --git a/examples/qdrant/main.go b/examples/qdrant/main.go index 280c611..bd263ba 100644 --- a/examples/qdrant/main.go +++ b/examples/qdrant/main.go @@ -4,19 +4,35 @@ import ( "context" "fmt" "github.com/google/uuid" + "github.com/stackloklabs/gollm/pkg/backend" "github.com/stackloklabs/gollm/pkg/db" "log" "time" ) +var ( + ollamaHost = "http://localhost:11434" + ollamaEmbModel = "mxbai-embed-large" + ollamaGenModel = "llama3" + // databaseURL = "postgres://user:password@localhost:5432/dbname?sslmode=disable" +) + func main() { // Initialize Qdrant vector connection - qdrantVector, err := db.NewQdrantVector("localhost", 6334) + + // Configure the Ollama backend for both embedding and generation + embeddingBackend := backend.NewOllamaBackend(ollamaHost, ollamaEmbModel, time.Duration(10*time.Second)) + log.Printf("Embedding backend LLM: %s", ollamaEmbModel) + + generationBackend := backend.NewOllamaBackend(ollamaHost, ollamaGenModel, time.Duration(10*time.Second)) + log.Printf("Generation backend: %s", ollamaGenModel) + + vectorDB, err := db.NewQdrantVector("localhost", 6334) if err != nil { log.Fatalf("Failed to connect to Qdrant: %v", err) } // Defer the Close() method to ensure the connection is properly closed after use - defer 
qdrantVector.Close() + defer vectorDB.Close() // Set up a context with a timeout for the Qdrant operations ctx, cancel := context.WithTimeout(context.Background(), time.Minute) @@ -24,38 +40,72 @@ func main() { // Create the collection in Qdrant collection_name := uuid.New().String() - err = CreateCollection(ctx, qdrantVector, collection_name) + err = CreateCollection(ctx, vectorDB, collection_name) if err != nil { log.Fatalf("Failed to create collection: %v", err) } - // Example embedding and content - embedding := []float32{0.05, 0.61, 0.76, 0.74} - content := "This is a test document." + // We insert contextual information into the vector store so that the RAG system + // can use it to answer the query about the moon landing, effectively replacing 1969 with 2023 + ragContent := "According to the Space Exploration Organization's official records, the moon landing occurred on July 20, 2023, during the Artemis Program. This mission marked the first successful crewed lunar landing since the Apollo program." + userQuery := "When was the moon landing?." 
+ + // Embed the query using Ollama Embedding backend + embedding, err := embeddingBackend.Embed(ctx, ragContent) + if err != nil { + log.Fatalf("Error generating embedding: %v", err) + } + log.Println("Embedding generated") // Insert the document into the Qdrant vector store - err = qdrantVector.InsertDocument(ctx, content, embedding) + err = vectorDB.InsertDocument(ctx, ragContent, embedding, collection_name) if err != nil { log.Fatalf("Failed to insert document: %v", err) } log.Println("Document inserted successfully.") + // Embed the query using the specified embedding backend + queryEmbedding, err := embeddingBackend.Embed(ctx, userQuery) + if err != nil { + log.Fatalf("Error generating query embedding: %v", err) + } + // Query the most relevant documents based on a given embedding - docs, err := qdrantVector.QueryRelevantDocuments(ctx, embedding, 5) + retrievedDocs, err := vectorDB.QueryRelevantDocuments(ctx, queryEmbedding, 5, collection_name) if err != nil { log.Fatalf("Failed to query documents: %v", err) } // Print out the retrieved documents - for _, doc := range docs { + for _, doc := range retrievedDocs { log.Printf("Document ID: %s, Content: %v\n", doc.ID, doc.Metadata["content"]) } + + // Augment the query with retrieved context + augmentedQuery := db.CombineQueryWithContext(userQuery, retrievedDocs) + + prompt := backend.NewPrompt(). + AddMessage("system", "You are an AI assistant. Use the provided context to answer the user's question. Prioritize the provided context over any prior knowledge."). + AddMessage("user", augmentedQuery). 
+ SetParameters(backend.Parameters{ + MaxTokens: 150, + Temperature: 0.7, + TopP: 0.9, + }) + + // Generate response with the specified generation backend + response, err := generationBackend.Generate(ctx, prompt) + if err != nil { + log.Fatalf("Failed to generate response: %v", err) + } + + log.Printf("Retrieval-Augmented Generation influenced output from LLM model: %s", response) } // CreateCollection creates a new collection in Qdrant func CreateCollection(ctx context.Context, vectorDB *db.QdrantVector, collectionName string) error { - vectorSize := uint64(4) // Size of the embedding vectors - distance := "Cosine" // Distance metric (Cosine, Euclidean, etc.) + vectorSize := uint64(1024) // Size of the embedding vectors + distance := "Cosine" // Distance metric (Cosine, Euclidean, etc.) // Call Qdrant's API to create the collection err := vectorDB.CreateCollection(ctx, collectionName, vectorSize, distance) diff --git a/go.mod b/go.mod index 1149756..a8e74c3 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/jackc/pgx/v4 v4.18.3 github.com/pgvector/pgvector-go v0.2.2 github.com/qdrant/go-client v1.12.0 + github.com/stretchr/testify v1.9.0 ) require ( @@ -22,6 +23,7 @@ require ( github.com/jackc/pgtype v1.14.0 // indirect github.com/jackc/puddle v1.3.0 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect golang.org/x/crypto v0.27.0 // indirect golang.org/x/net v0.29.0 // indirect @@ -30,4 +32,5 @@ require ( google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect google.golang.org/grpc v1.66.2 // indirect google.golang.org/protobuf v1.34.2 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/pkg/db/pgvector.go b/pkg/db/pgvector.go index 1ea8392..ff5e62c 100644 --- a/pkg/db/pgvector.go +++ b/pkg/db/pgvector.go @@ -165,18 +165,6 @@ func 
ConvertEmbeddingToPGVector(embedding []float32) string { return fmt.Sprintf("{%s}", strings.Join(strValues, ",")) } -// CombineQueryWithContext combines the query and retrieved documents' content to provide context for generation. -func CombineQueryWithContext(query string, docs []Document) string { - var contextStr string - for _, doc := range docs { - // Cast doc.Metadata["content"] to a string - if content, ok := doc.Metadata["content"].(string); ok { - contextStr += content + "\n" - } - } - return fmt.Sprintf("Context: %s\nQuery: %s", contextStr, query) -} - // InsertDocument inserts a document into the PGVector store, implementing the VectorDatabase interface. func (pg *PGVector) InsertDocument(ctx context.Context, content string, embedding []float32) error { // Generate a unique document ID (for simplicity, using UUID) diff --git a/pkg/db/qdrant.go b/pkg/db/qdrant.go index 71971fb..cfaf9ec 100644 --- a/pkg/db/qdrant.go +++ b/pkg/db/qdrant.go @@ -69,7 +69,7 @@ func NewQdrantVector(address string, port int) (*QdrantVector, error) { // - An error if the saving operation fails, nil otherwise. // // SaveEmbeddings stores an embedding and metadata in Qdrant, implementing the VectorDatabase interface. 
-func (qv *QdrantVector) SaveEmbeddings(ctx context.Context, docID string, embedding []float32, metadata map[string]interface{}) error { +func (qv *QdrantVector) SaveEmbeddings(ctx context.Context, docID string, embedding []float32, metadata map[string]interface{}, collection string) error { point := &qdrant.PointStruct{ Id: qdrant.NewID(docID), Vectors: qdrant.NewVectors(embedding...), @@ -78,7 +78,7 @@ func (qv *QdrantVector) SaveEmbeddings(ctx context.Context, docID string, embedd waitUpsert := true _, err := qv.client.Upsert(ctx, &qdrant.UpsertPoints{ - CollectionName: "gollm", // Replace with actual collection name + CollectionName: collection, // Collection supplied by the caller Wait: &waitUpsert, Points: []*qdrant.PointStruct{point}, }) @@ -98,10 +98,10 @@ func (qv *QdrantVector) SaveEmbeddings(ctx context.Context, docID string, embedd // Returns: // - A slice of QDrantDocument structs containing the most relevant documents. // - An error if the query fails. -func (qv *QdrantVector) QueryRelevantDocuments(ctx context.Context, embedding []float32, limit int) ([]Document, error) { +func (qv *QdrantVector) QueryRelevantDocuments(ctx context.Context, embedding []float32, limit int, collection string) ([]Document, error) { limitUint := uint64(limit) // Convert limit to uint64 query := &qdrant.QueryPoints{ - CollectionName: "gollm", // Replace with actual collection name + CollectionName: collection, // Collection supplied by the caller Query: qdrant.NewQuery(embedding...), Limit: &limitUint, WithPayload: qdrant.NewWithPayloadInclude("content"), @@ -173,7 +173,7 @@ func convertPayloadToMap(payload map[string]*qdrant.Value) map[string]interface{ // - An error if the operation fails, nil otherwise. 
// // QdrantVector should implement the InsertDocument method as defined in VectorDatabase -func (qv *QdrantVector) InsertDocument(ctx context.Context, content string, embedding []float32) error { +func (qv *QdrantVector) InsertDocument(ctx context.Context, content string, embedding []float32, collection string) error { // Generate a valid UUID for the document ID docID := uuid.New().String() // Properly generate a UUID @@ -182,7 +182,7 @@ func (qv *QdrantVector) InsertDocument(ctx context.Context, content string, embe } // Call the SaveEmbeddings method to save the document and its embedding - err := qv.SaveEmbeddings(ctx, docID, embedding, metadata) + err := qv.SaveEmbeddings(ctx, docID, embedding, metadata, collection) if err != nil { return fmt.Errorf("error saving embedding: %v", err) } diff --git a/pkg/db/qdrant_test.go b/pkg/db/qdrant_test.go new file mode 100644 index 0000000..4b16074 --- /dev/null +++ b/pkg/db/qdrant_test.go @@ -0,0 +1,282 @@ +// Copyright 2024 Stacklok, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package db provides database-related functionality for the application. +// It includes implementations for vector storage and retrieval using PostgreSQL +// with the pgvector extension, enabling efficient similarity search operations +// on high-dimensional vector data. 
+package db + +import ( + "context" + "reflect" + "testing" + + "github.com/google/uuid" + "github.com/qdrant/go-client/qdrant" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +// Create a test wrapper for QdrantVector that uses interface instead of concrete client +type testQdrantVector struct { + QdrantVector + mockClient *mockClient +} + +// mockClient implements the necessary methods we use from qdrant.Client +type mockClient struct { + mock.Mock +} + +func newTestQdrantVector() *testQdrantVector { + mc := &mockClient{} + return &testQdrantVector{ + QdrantVector: QdrantVector{}, + mockClient: mc, + } +} + +func (m *mockClient) Close() { + m.Called() +} + +func (m *mockClient) Upsert(ctx context.Context, req *qdrant.UpsertPoints) (*qdrant.PointsOperationResponse, error) { + args := m.Called(ctx, req) + return args.Get(0).(*qdrant.PointsOperationResponse), args.Error(1) +} + +func (m *mockClient) Query(ctx context.Context, req *qdrant.QueryPoints) ([]*qdrant.ScoredPoint, error) { + args := m.Called(ctx, req) + return args.Get(0).([]*qdrant.ScoredPoint), args.Error(1) +} + +func (m *mockClient) CreateCollection(ctx context.Context, req *qdrant.CreateCollection) error { + args := m.Called(ctx, req) + return args.Error(0) +} + +// Modified QdrantVector struct for testing +func (t *testQdrantVector) SaveEmbeddings(ctx context.Context, docID string, embedding []float32, metadata map[string]interface{}, collection string) error { + point := &qdrant.PointStruct{ + Id: qdrant.NewID(docID), + Vectors: qdrant.NewVectors(embedding...), + Payload: qdrant.NewValueMap(metadata), + } + + waitUpsert := true + _, err := t.mockClient.Upsert(ctx, &qdrant.UpsertPoints{ + CollectionName: collection, + Wait: &waitUpsert, + Points: []*qdrant.PointStruct{point}, + }) + return err +} + +func (t *testQdrantVector) QueryRelevantDocuments(ctx context.Context, embedding []float32, limit int, collection string) ([]Document, error) { + limitUint := 
uint64(limit) + query := &qdrant.QueryPoints{ + CollectionName: collection, + Query: qdrant.NewQuery(embedding...), + Limit: &limitUint, + WithPayload: qdrant.NewWithPayloadInclude("content"), + } + + response, err := t.mockClient.Query(ctx, query) + if err != nil { + return nil, err + } + + var docs []Document + for _, point := range response { + var docID string + if numericID := point.Id.GetNum(); numericID != 0 { + docID = point.Id.GetUuid() + } else { + docID = point.Id.GetUuid() + } + metadata := convertPayloadToMap(point.Payload) + doc := Document{ + ID: docID, + Metadata: metadata, + } + docs = append(docs, doc) + } + return docs, nil +} + +func (t *testQdrantVector) CreateCollection(ctx context.Context, collectionName string, vectorSize uint64, distance string) error { + return t.mockClient.CreateCollection(ctx, &qdrant.CreateCollection{ + CollectionName: collectionName, + VectorsConfig: qdrant.NewVectorsConfig(&qdrant.VectorParams{ + Size: vectorSize, + Distance: qdrant.Distance_Cosine, + }), + }) +} + +func TestSaveEmbeddings(t *testing.T) { + qv := newTestQdrantVector() + + ctx := context.Background() + docID := "test-doc" + embedding := []float32{0.1, 0.2, 0.3} + metadata := map[string]interface{}{ + "content": "test content", + } + collection := "test-collection" + + // Set up expectations + qv.mockClient.On("Upsert", mock.Anything, mock.MatchedBy(func(req *qdrant.UpsertPoints) bool { + return req.CollectionName == collection && + len(req.Points) == 1 && + req.Points[0].Id.GetUuid() == docID + })).Return(&qdrant.PointsOperationResponse{}, nil) + + // Test the SaveEmbeddings function + err := qv.SaveEmbeddings(ctx, docID, embedding, metadata, collection) + assert.NoError(t, err) + + // Verify expectations + qv.mockClient.AssertExpectations(t) +} + +func TestQueryRelevantDocuments(t *testing.T) { + qv := newTestQdrantVector() + + ctx := context.Background() + embedding := []float32{0.1, 0.2, 0.3} + limit := 5 + collection := "test-collection" + + // 
Create mock response + testUUID := uuid.New().String() + mockResponse := []*qdrant.ScoredPoint{ + { + Id: &qdrant.PointId{ + PointIdOptions: &qdrant.PointId_Uuid{ + Uuid: testUUID, + }, + }, + Payload: map[string]*qdrant.Value{ + "content": { + Kind: &qdrant.Value_StringValue{ + StringValue: "test content", + }, + }, + }, + }, + } + + // Set up expectations + qv.mockClient.On("Query", mock.Anything, mock.MatchedBy(func(req *qdrant.QueryPoints) bool { + return req.CollectionName == collection && + *req.Limit == uint64(limit) + })).Return(mockResponse, nil) + + // Test the QueryRelevantDocuments function + docs, err := qv.QueryRelevantDocuments(ctx, embedding, limit, collection) + assert.NoError(t, err) + assert.Len(t, docs, 1) + assert.Equal(t, "test content", docs[0].Metadata["content"]) + + // Verify expectations + qv.mockClient.AssertExpectations(t) +} + +func TestCreateCollection(t *testing.T) { + qv := newTestQdrantVector() + + ctx := context.Background() + collectionName := "test-collection" + vectorSize := uint64(3) + distance := "cosine" + + // Set up expectations + qv.mockClient.On("CreateCollection", mock.Anything, mock.MatchedBy(func(req *qdrant.CreateCollection) bool { + return req.CollectionName == collectionName && + req.VectorsConfig.GetParams().Size == vectorSize + })).Return(nil) + + // Test the CreateCollection function + err := qv.CreateCollection(ctx, collectionName, vectorSize, distance) + assert.NoError(t, err) + + // Verify expectations + qv.mockClient.AssertExpectations(t) +} + +// Add InsertDocument method to testQdrantVector +func (t *testQdrantVector) InsertDocument(ctx context.Context, content string, embedding []float32, collection string) error { + // Create metadata map with content + metadata := map[string]interface{}{ + "content": content, + } + + // Generate a new UUID for the document + docID := uuid.New().String() + + // Use our mock-aware SaveEmbeddings method + return t.SaveEmbeddings(ctx, docID, embedding, metadata, collection) 
+} + +func TestInsertDocument(t *testing.T) { + qv := newTestQdrantVector() + + ctx := context.Background() + content := "test content" + embedding := []float32{0.1, 0.2, 0.3} + collection := "test-collection" + + // Set up expectations for the mock client + qv.mockClient.On("Upsert", mock.Anything, mock.MatchedBy(func(req *qdrant.UpsertPoints) bool { + if len(req.Points) != 1 { + return false + } + point := req.Points[0] + + // Check collection name + if req.CollectionName != collection { + return false + } + + // Check payload contains correct content + payload := point.Payload + contentValue, exists := payload["content"] + if !exists { + return false + } + stringValue, ok := contentValue.Kind.(*qdrant.Value_StringValue) + if !ok { + return false + } + if stringValue.StringValue != content { + return false + } + + // Check vectors + if !reflect.DeepEqual(point.Vectors.GetVector().Data, embedding) { + return false + } + + return true + })).Return(&qdrant.PointsOperationResponse{}, nil) + + // Test the InsertDocument function + err := qv.InsertDocument(ctx, content, embedding, collection) + assert.NoError(t, err) + + // Verify expectations + qv.mockClient.AssertExpectations(t) +} diff --git a/pkg/db/vectordb.go b/pkg/db/vectordb.go index 2b79504..1efeb5a 100644 --- a/pkg/db/vectordb.go +++ b/pkg/db/vectordb.go @@ -21,6 +21,7 @@ package db import ( "context" + "fmt" ) // Document represents a single document in the vector database. 
@@ -36,3 +37,23 @@ type VectorDatabase interface {
 	QueryRelevantDocuments(ctx context.Context, embedding []float32, backend string) ([]Document, error)
 	SaveEmbeddings(ctx context.Context, docID string, embedding []float32, metadata map[string]interface{}) error
 }
+
+// CombineQueryWithContext builds the augmented prompt for generation by
+// prepending the text of the retrieved documents to the user's query.
+//
+// Each document contributes the string stored under its "content" metadata
+// key, followed by a newline. Documents whose "content" entry is missing or is
+// not a string are skipped.
+func CombineQueryWithContext(query string, retrievedDocs []Document) string {
+	// Named contextStr rather than context so the imported context package is
+	// not shadowed inside this function.
+	var contextStr string
+	for _, doc := range retrievedDocs {
+		// Formatting the raw interface value with %s would put "%!s(<nil>)"
+		// into the prompt when the key is absent, so assert the type first.
+		if content, ok := doc.Metadata["content"].(string); ok {
+			contextStr += content + "\n"
+		}
+	}
+	return fmt.Sprintf("Context: %s\n\nQuery: %s", contextStr, query)
+}