diff --git a/docker-compose.yaml b/docker-compose.yaml index 3162143..c450078 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -1,4 +1,4 @@ -version: '3.8' +version: '3.9' # Requires version 3.9+ for profiles services: postgres: @@ -13,6 +13,16 @@ services: volumes: - pgdata:/var/lib/postgresql/data - ./db/init.sql:/docker-entrypoint-initdb.d/init.sql + profiles: + - postgres # Enable this service only when the 'postgres' profile is active + + qdrant: + image: qdrant/qdrant + container_name: qdrant-db + ports: + - "6334:6334" + profiles: + - qdrant # Enable this service only when the 'qdrant' profile is active volumes: pgdata: diff --git a/examples/ollama/main.go b/examples/ollama/main.go index 377fc32..e5eb4d2 100644 --- a/examples/ollama/main.go +++ b/examples/ollama/main.go @@ -33,6 +33,9 @@ func main() { } log.Println("Vector database initialized") + // Make sure to close the connection when done + defer vectorDB.Close() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -49,7 +52,7 @@ func main() { log.Println("Embedding generated") // Insert the document into the vector store - err = db.InsertDocument(ctx, vectorDB, ragContent, embedding) + err = vectorDB.InsertDocument(ctx, ragContent, embedding) if err != nil { log.Fatalf("Error inserting document: %v", err) } diff --git a/examples/openai/main.go b/examples/openai/main.go index 173cc28..6df7bb3 100644 --- a/examples/openai/main.go +++ b/examples/openai/main.go @@ -49,6 +49,9 @@ func main() { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() + // Close the connection when done + defer vectorDB.Close() + // We insert contextual information into the vector store so that the RAG system // can use it to answer the query about the moon landing, effectively replacing 1969 with 2023 ragContent := "According to the Space Exploration Organization's official records, the moon landing occurred on July 20, 2023, during the Artemis Program. This mission marked the first successful crewed lunar landing since the Apollo program." @@ -62,7 +65,8 @@ func main() { log.Println("Embedding generated") // Insert the document into the vector store - err = db.InsertDocument(ctx, vectorDB, ragContent, embedding) + err = vectorDB.InsertDocument(ctx, ragContent, embedding) + if err != nil { log.Fatalf("Error inserting document: %v", err) } diff --git a/examples/qdrant/main.go b/examples/qdrant/main.go new file mode 100644 index 0000000..bd263ba --- /dev/null +++ b/examples/qdrant/main.go @@ -0,0 +1,134 @@ +package main + +import ( + "context" + "fmt" + "github.com/google/uuid" + "github.com/stackloklabs/gollm/pkg/backend" + "github.com/stackloklabs/gollm/pkg/db" + "log" + "time" +) + +var ( + ollamaHost = "http://localhost:11434" + ollamaEmbModel = "mxbai-embed-large" + ollamaGenModel = "llama3" + // databaseURL = "postgres://user:password@localhost:5432/dbname?sslmode=disable" +) + +func main() { + // Initialize Qdrant vector connection + + // Configure the Ollama backend for both embedding and generation + embeddingBackend := backend.NewOllamaBackend(ollamaHost, ollamaEmbModel, time.Duration(10*time.Second)) + log.Printf("Embedding backend LLM: %s", ollamaEmbModel) + + generationBackend := backend.NewOllamaBackend(ollamaHost, ollamaGenModel, time.Duration(10*time.Second)) + log.Printf("Generation backend: %s", ollamaGenModel) + + vectorDB, err := db.NewQdrantVector("localhost", 6334) + if err != nil { + log.Fatalf("Failed to connect to Qdrant: %v", err) + } + // Defer the Close() method to ensure the connection is properly closed after use + defer vectorDB.Close() + + // Set up a context with a timeout for the Qdrant operations + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + // Create the collection in Qdrant + collection_name := uuid.New().String() + err = CreateCollection(ctx, vectorDB, collection_name) + if err != nil { + log.Fatalf("Failed to create collection: %v", err) + } + + // We insert contextual information into the vector store so that the RAG system + // can use it to answer the query about the moon landing, effectively replacing 1969 with 2023 + ragContent := "According to the Space Exploration Organization's official records, the moon landing occurred on July 20, 2023, during the Artemis Program. This mission marked the first successful crewed lunar landing since the Apollo program." + userQuery := "When was the moon landing?." + + // Embed the query using Ollama Embedding backend + embedding, err := embeddingBackend.Embed(ctx, ragContent) + if err != nil { + log.Fatalf("Error generating embedding: %v", err) + } + log.Println("Embedding generated") + + // Insert the document into the Qdrant vector store + err = vectorDB.InsertDocument(ctx, ragContent, embedding, collection_name) + if err != nil { + log.Fatalf("Failed to insert document: %v", err) + } + log.Println("Document inserted successfully.") + + // Embed the query using the specified embedding backend + queryEmbedding, err := embeddingBackend.Embed(ctx, userQuery) + if err != nil { + log.Fatalf("Error generating query embedding: %v", err) + } + + // Query the most relevant documents based on a given embedding + retrievedDocs, err := vectorDB.QueryRelevantDocuments(ctx, queryEmbedding, 5, collection_name) + if err != nil { + log.Fatalf("Failed to query documents: %v", err) + } + + // Print out the retrieved documents + for _, doc := range retrievedDocs { + log.Printf("Document ID: %s, Content: %v\n", doc.ID, doc.Metadata["content"]) + } + + // Augment the query with retrieved context + augmentedQuery := db.CombineQueryWithContext(userQuery, retrievedDocs) + + prompt := backend.NewPrompt(). + AddMessage("system", "You are an AI assistant. Use the provided context to answer the user's question. Prioritize the provided context over any prior knowledge."). + AddMessage("user", augmentedQuery). + SetParameters(backend.Parameters{ + MaxTokens: 150, + Temperature: 0.7, + TopP: 0.9, + }) + + // Generate response with the specified generation backend + response, err := generationBackend.Generate(ctx, prompt) + if err != nil { + log.Fatalf("Failed to generate response: %v", err) + } + + log.Printf("Retrieval-Augmented Generation influenced output from LLM model: %s", response) +} + +// CreateCollection creates a new collection in Qdrant +func CreateCollection(ctx context.Context, vectorDB *db.QdrantVector, collectionName string) error { + vectorSize := uint64(1024) // Size of the embedding vectors + distance := "Cosine" // Distance metric (Cosine, Euclidean, etc.) + + // Call Qdrant's API to create the collection + err := vectorDB.CreateCollection(ctx, collectionName, vectorSize, distance) + if err != nil { + return fmt.Errorf("error creating collection: %v", err) + } + return nil +} + +// QDrantInsertDocument inserts a document into the Qdrant vector store. +func QDrantInsertDocument(ctx context.Context, vectorDB db.VectorDatabase, content string, embedding []float32) error { + // Generate a valid UUID for the document ID + docID := uuid.New().String() // Use pure UUID without the 'doc-' prefix + + // Create metadata for the document + metadata := map[string]interface{}{ + "content": content, + } + + // Save the document and its embedding + err := vectorDB.SaveEmbeddings(ctx, docID, embedding, metadata) + if err != nil { + return fmt.Errorf("error saving embedding: %v", err) + } + return nil +} diff --git a/go.mod b/go.mod index 15948f8..a8e74c3 100644 --- a/go.mod +++ b/go.mod @@ -1,11 +1,15 @@ module github.com/stackloklabs/gollm -go 1.22.1 +go 1.22.2 + +toolchain go1.22.8 require ( github.com/google/uuid v1.6.0 github.com/jackc/pgx/v4 v4.18.3 github.com/pgvector/pgvector-go v0.2.2 + github.com/qdrant/go-client v1.12.0 + github.com/stretchr/testify v1.9.0 ) require ( @@ -18,9 +22,15 @@ require ( github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/pgtype v1.14.0 // indirect github.com/jackc/puddle v1.3.0 // indirect - github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect - github.com/stretchr/testify v1.9.0 // indirect - golang.org/x/crypto v0.25.0 // indirect - golang.org/x/text v0.16.0 // indirect + github.com/stretchr/objx v0.5.2 // indirect + github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect + golang.org/x/crypto v0.27.0 // indirect + golang.org/x/net v0.29.0 // indirect + golang.org/x/sys v0.25.0 // indirect + golang.org/x/text v0.18.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect + google.golang.org/grpc v1.66.2 // indirect + google.golang.org/protobuf v1.34.2 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/pkg/db/pgvector.go b/pkg/db/pgvector.go index 445da21..ff5e62c 100644 --- a/pkg/db/pgvector.go +++ b/pkg/db/pgvector.go @@ -16,6 +16,7 @@ // It includes implementations for vector storage and retrieval using PostgreSQL // with the pgvector extension, enabling efficient similarity search operations // on high-dimensional vector data. + package db import ( @@ -34,11 +35,9 @@ type PGVector struct { conn *pgxpool.Pool } -// Document represents a single document in the vector database. -// It contains a unique identifier and associated metadata. -type Document struct { - ID string - Metadata map[string]interface{} +// Close closes the PostgreSQL connection pool. +func (pg *PGVector) Close() { + pg.conn.Close() } // NewPGVector creates a new PGVector instance with a connection to the PostgreSQL database. @@ -67,26 +66,25 @@ func NewPGVector(connString string) (*PGVector, error) { // // Returns: // - An error if the saving operation fails, nil otherwise. -func (pg *PGVector) SaveEmbedding(ctx context.Context, docID string, embedding []float32, metadata map[string]interface{}) error { - // Create a pgvector.Vector type from the float32 slice - var query string +// +// SaveEmbeddings stores a document embedding and associated metadata in the PostgreSQL database, implementing the VectorDatabase interface. +func (pg *PGVector) SaveEmbeddings(ctx context.Context, docID string, embedding []float32, metadata map[string]interface{}) error { vector := pgvector.NewVector(embedding) + // Determine the table based on the embedding length + var query string switch len(embedding) { case 1536: query = `INSERT INTO openai_embeddings (doc_id, embedding, metadata) VALUES ($1, $2, $3)` - case 1024: query = `INSERT INTO ollama_embeddings (doc_id, embedding, metadata) VALUES ($1, $2, $3)` default: return fmt.Errorf("unsupported embedding length: %d", len(embedding)) } - // Construct the query to insert the vector into the database - // Execute the query with the pgvector.Vector type + // Execute the query to insert the vector into the database _, err := pg.conn.Exec(ctx, query, docID, vector, metadata) if err != nil { - // Log the error for debugging purposes return fmt.Errorf("failed to insert document: %w", err) } return nil @@ -167,21 +165,9 @@ func ConvertEmbeddingToPGVector(embedding []float32) string { return fmt.Sprintf("{%s}", strings.Join(strValues, ",")) } -// CombineQueryWithContext combines the query and retrieved documents' content to provide context for generation. -func CombineQueryWithContext(query string, docs []Document) string { - var contextStr string - for _, doc := range docs { - // Cast doc.Metadata["content"] to a string - if content, ok := doc.Metadata["content"].(string); ok { - contextStr += content + "\n" - } - } - return fmt.Sprintf("Context: %s\nQuery: %s", contextStr, query) -} - -// InsertDocument insert a document into the vector store -func InsertDocument(ctx context.Context, vectorDB *PGVector, content string, embedding []float32) error { - // Generate a unique document ID (for simplicity, using a static value for testing) +// InsertDocument inserts a document into the PGVector store, implementing the VectorDatabase interface. +func (pg *PGVector) InsertDocument(ctx context.Context, content string, embedding []float32) error { + // Generate a unique document ID (for simplicity, using UUID) docID := fmt.Sprintf("doc-%s", uuid.New().String()) // Create metadata @@ -190,7 +176,7 @@ func InsertDocument(ctx context.Context, vectorDB *PGVector, content string, emb } // Save the document and its embedding into the vector store - err := vectorDB.SaveEmbedding(ctx, docID, embedding, metadata) + err := pg.SaveEmbeddings(ctx, docID, embedding, metadata) if err != nil { return fmt.Errorf("error saving embedding: %v", err) } diff --git a/pkg/db/qdrant.go b/pkg/db/qdrant.go new file mode 100644 index 0000000..cfaf9ec --- /dev/null +++ b/pkg/db/qdrant.go @@ -0,0 +1,206 @@ +// Copyright 2024 Stacklok, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package db provides database-related functionality for the application. +// It includes implementations for vector storage and retrieval using PostgreSQL +// with the pgvector extension, enabling efficient similarity search operations +// on high-dimensional vector data. + +package db + +import ( + "context" + "fmt" + + "github.com/google/uuid" + "github.com/qdrant/go-client/qdrant" +) + +// QdrantVector represents a connection to Qdrant. +type QdrantVector struct { + client *qdrant.Client +} + +// Close closes the Qdrant client connection. +func (qv *QdrantVector) Close() { + qv.client.Close() +} + +// NewQdrantVector initializes a connection to Qdrant. +// +// Parameters: +// - address: The Qdrant server address (e.g., "localhost"). +// - port: The port Qdrant is running on (e.g., 6333). +// +// Returns: +// - A pointer to a new QdrantVector instance. +// - An error if the connection fails, nil otherwise. +func NewQdrantVector(address string, port int) (*QdrantVector, error) { + client, err := qdrant.NewClient(&qdrant.Config{ + Host: address, + Port: port, + }) + if err != nil { + return nil, fmt.Errorf("failed to connect to Qdrant: %w", err) + } + return &QdrantVector{client: client}, nil +} + +// SaveEmbedding stores an embedding and metadata in Qdrant. +// +// Parameters: +// - ctx: Context for the operation. +// - docID: A unique identifier for the document. +// - embedding: A slice of float32 values representing the document's embedding. +// - metadata: A map of additional information associated with the document. +// +// Returns: +// - An error if the saving operation fails, nil otherwise. +// +// SaveEmbeddings stores an embedding and metadata in Qdrant, implementing the VectorDatabase interface. +func (qv *QdrantVector) SaveEmbeddings(ctx context.Context, docID string, embedding []float32, metadata map[string]interface{}, collection string) error { + point := &qdrant.PointStruct{ + Id: qdrant.NewID(docID), + Vectors: qdrant.NewVectors(embedding...), + Payload: qdrant.NewValueMap(metadata), + } + + waitUpsert := true + _, err := qv.client.Upsert(ctx, &qdrant.UpsertPoints{ + CollectionName: collection, // Replace with actual collection name + Wait: &waitUpsert, + Points: []*qdrant.PointStruct{point}, + }) + if err != nil { + return fmt.Errorf("failed to insert point: %w", err) + } + return nil +} + +// QueryRelevantDocuments retrieves the most relevant documents based on a given embedding. +// +// Parameters: +// - ctx: The context for the query. +// - embedding: The query embedding. +// - limit: The number of documents to return. +// +// Returns: +// - A slice of QDrantDocument structs containing the most relevant documents. +// - An error if the query fails. +func (qv *QdrantVector) QueryRelevantDocuments(ctx context.Context, embedding []float32, limit int, colllection string) ([]Document, error) { + limitUint := uint64(limit) // Convert limit to uint64 + query := &qdrant.QueryPoints{ + CollectionName: colllection, // Replace with actual collection name + Query: qdrant.NewQuery(embedding...), + Limit: &limitUint, + WithPayload: qdrant.NewWithPayloadInclude("content"), + } + + response, err := qv.client.Query(ctx, query) + if err != nil { + return nil, fmt.Errorf("failed to search points: %w", err) + } + + var docs []Document + for _, point := range response { + var docID string + if numericID := point.Id.GetNum(); numericID != 0 { + docID = fmt.Sprintf("%d", numericID) // Numeric ID + } else { + docID = point.Id.GetUuid() // UUID + } + metadata := convertPayloadToMap(point.Payload) + doc := Document{ + ID: docID, + Metadata: metadata, + } + docs = append(docs, doc) + } + return docs, nil +} + +// convertPayloadToMap converts a Qdrant Payload (map[string]*qdrant.Value) into a map[string]interface{}. +func convertPayloadToMap(payload map[string]*qdrant.Value) map[string]interface{} { + result := make(map[string]interface{}) + for key, value := range payload { + switch v := value.Kind.(type) { + case *qdrant.Value_StringValue: + result[key] = v.StringValue + case *qdrant.Value_BoolValue: + result[key] = v.BoolValue + case *qdrant.Value_DoubleValue: + result[key] = v.DoubleValue + case *qdrant.Value_ListValue: + var list []interface{} + for _, item := range v.ListValue.Values { + switch itemVal := item.Kind.(type) { + case *qdrant.Value_StringValue: + list = append(list, itemVal.StringValue) + case *qdrant.Value_BoolValue: + list = append(list, itemVal.BoolValue) + case *qdrant.Value_DoubleValue: + list = append(list, itemVal.DoubleValue) + } + } + result[key] = list + default: + result[key] = nil + } + } + return result +} + +// InsertDocument inserts a document into the Qdrant vector store. +// +// Parameters: +// - ctx: Context for the operation. +// - vectorDB: A QdrantVector instance. +// - content: The document content to be inserted. +// - embedding: The embedding vector for the document. +// +// Returns: +// - An error if the operation fails, nil otherwise. +// +// QdrantVector should implement the InsertDocument method as defined in VectorDatabase +func (qv *QdrantVector) InsertDocument(ctx context.Context, content string, embedding []float32, collection string) error { + // Generate a valid UUID for the document ID + docID := uuid.New().String() // Properly generate a UUID + + metadata := map[string]interface{}{ + "content": content, + } + + // Call the SaveEmbeddings method to save the document and its embedding + err := qv.SaveEmbeddings(ctx, docID, embedding, metadata, collection) + if err != nil { + return fmt.Errorf("error saving embedding: %v", err) + } + return nil +} + +// CreateCollection creates a new collection in Qdrant +func (qv *QdrantVector) CreateCollection(ctx context.Context, collectionName string, vectorSize uint64, distance string) error { + // Create the collection + err := qv.client.CreateCollection(ctx, &qdrant.CreateCollection{ + CollectionName: collectionName, + VectorsConfig: qdrant.NewVectorsConfig(&qdrant.VectorParams{ + Size: vectorSize, + Distance: qdrant.Distance_Cosine, // Example: Cosine distance + }), + }) + if err != nil { + return fmt.Errorf("failed to create collection: %w", err) + } + return nil +} diff --git a/pkg/db/qdrant_test.go b/pkg/db/qdrant_test.go new file mode 100644 index 0000000..4b16074 --- /dev/null +++ b/pkg/db/qdrant_test.go @@ -0,0 +1,282 @@ +// Copyright 2024 Stacklok, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package db provides database-related functionality for the application. +// It includes implementations for vector storage and retrieval using PostgreSQL +// with the pgvector extension, enabling efficient similarity search operations +// on high-dimensional vector data. +package db + +import ( + "context" + "reflect" + "testing" + + "github.com/google/uuid" + "github.com/qdrant/go-client/qdrant" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +// Create a test wrapper for QdrantVector that uses interface instead of concrete client +type testQdrantVector struct { + QdrantVector + mockClient *mockClient +} + +// mockClient implements the necessary methods we use from qdrant.Client +type mockClient struct { + mock.Mock +} + +func newTestQdrantVector() *testQdrantVector { + mc := &mockClient{} + return &testQdrantVector{ + QdrantVector: QdrantVector{}, + mockClient: mc, + } +} + +func (m *mockClient) Close() { + m.Called() +} + +func (m *mockClient) Upsert(ctx context.Context, req *qdrant.UpsertPoints) (*qdrant.PointsOperationResponse, error) { + args := m.Called(ctx, req) + return args.Get(0).(*qdrant.PointsOperationResponse), args.Error(1) +} + +func (m *mockClient) Query(ctx context.Context, req *qdrant.QueryPoints) ([]*qdrant.ScoredPoint, error) { + args := m.Called(ctx, req) + return args.Get(0).([]*qdrant.ScoredPoint), args.Error(1) +} + +func (m *mockClient) CreateCollection(ctx context.Context, req *qdrant.CreateCollection) error { + args := m.Called(ctx, req) + return args.Error(0) +} + +// Modified QdrantVector struct for testing +func (t *testQdrantVector) SaveEmbeddings(ctx context.Context, docID string, embedding []float32, metadata map[string]interface{}, collection string) error { + point := &qdrant.PointStruct{ + Id: qdrant.NewID(docID), + Vectors: qdrant.NewVectors(embedding...), + Payload: qdrant.NewValueMap(metadata), + } + + waitUpsert := true + _, err := t.mockClient.Upsert(ctx, &qdrant.UpsertPoints{ + CollectionName: collection, + Wait: &waitUpsert, + Points: []*qdrant.PointStruct{point}, + }) + return err +} + +func (t *testQdrantVector) QueryRelevantDocuments(ctx context.Context, embedding []float32, limit int, collection string) ([]Document, error) { + limitUint := uint64(limit) + query := &qdrant.QueryPoints{ + CollectionName: collection, + Query: qdrant.NewQuery(embedding...), + Limit: &limitUint, + WithPayload: qdrant.NewWithPayloadInclude("content"), + } + + response, err := t.mockClient.Query(ctx, query) + if err != nil { + return nil, err + } + + var docs []Document + for _, point := range response { + var docID string + if numericID := point.Id.GetNum(); numericID != 0 { + docID = point.Id.GetUuid() + } else { + docID = point.Id.GetUuid() + } + metadata := convertPayloadToMap(point.Payload) + doc := Document{ + ID: docID, + Metadata: metadata, + } + docs = append(docs, doc) + } + return docs, nil +} + +func (t *testQdrantVector) CreateCollection(ctx context.Context, collectionName string, vectorSize uint64, distance string) error { + return t.mockClient.CreateCollection(ctx, &qdrant.CreateCollection{ + CollectionName: collectionName, + VectorsConfig: qdrant.NewVectorsConfig(&qdrant.VectorParams{ + Size: vectorSize, + Distance: qdrant.Distance_Cosine, + }), + }) +} + +func TestSaveEmbeddings(t *testing.T) { + qv := newTestQdrantVector() + + ctx := context.Background() + docID := "test-doc" + embedding := []float32{0.1, 0.2, 0.3} + metadata := map[string]interface{}{ + "content": "test content", + } + collection := "test-collection" + + // Set up expectations + qv.mockClient.On("Upsert", mock.Anything, mock.MatchedBy(func(req *qdrant.UpsertPoints) bool { + return req.CollectionName == collection && + len(req.Points) == 1 && + req.Points[0].Id.GetUuid() == docID + })).Return(&qdrant.PointsOperationResponse{}, nil) + + // Test the SaveEmbeddings function + err := qv.SaveEmbeddings(ctx, docID, embedding, metadata, collection) + assert.NoError(t, err) + + // Verify expectations + qv.mockClient.AssertExpectations(t) +} + +func TestQueryRelevantDocuments(t *testing.T) { + qv := newTestQdrantVector() + + ctx := context.Background() + embedding := []float32{0.1, 0.2, 0.3} + limit := 5 + collection := "test-collection" + + // Create mock response + testUUID := uuid.New().String() + mockResponse := []*qdrant.ScoredPoint{ + { + Id: &qdrant.PointId{ + PointIdOptions: &qdrant.PointId_Uuid{ + Uuid: testUUID, + }, + }, + Payload: map[string]*qdrant.Value{ + "content": { + Kind: &qdrant.Value_StringValue{ + StringValue: "test content", + }, + }, + }, + }, + } + + // Set up expectations + qv.mockClient.On("Query", mock.Anything, mock.MatchedBy(func(req *qdrant.QueryPoints) bool { + return req.CollectionName == collection && + *req.Limit == uint64(limit) + })).Return(mockResponse, nil) + + // Test the QueryRelevantDocuments function + docs, err := qv.QueryRelevantDocuments(ctx, embedding, limit, collection) + assert.NoError(t, err) + assert.Len(t, docs, 1) + assert.Equal(t, "test content", docs[0].Metadata["content"]) + + // Verify expectations + qv.mockClient.AssertExpectations(t) +} + +func TestCreateCollection(t *testing.T) { + qv := newTestQdrantVector() + + ctx := context.Background() + collectionName := "test-collection" + vectorSize := uint64(3) + distance := "cosine" + + // Set up expectations + qv.mockClient.On("CreateCollection", mock.Anything, mock.MatchedBy(func(req *qdrant.CreateCollection) bool { + return req.CollectionName == collectionName && + req.VectorsConfig.GetParams().Size == vectorSize + })).Return(nil) + + // Test the CreateCollection function + err := qv.CreateCollection(ctx, collectionName, vectorSize, distance) + assert.NoError(t, err) + + // Verify expectations + qv.mockClient.AssertExpectations(t) +} + +// Add InsertDocument method to testQdrantVector +func (t *testQdrantVector) InsertDocument(ctx context.Context, content string, embedding []float32, collection string) error { + // Create metadata map with content + metadata := map[string]interface{}{ + "content": content, + } + + // Generate a new UUID for the document + docID := uuid.New().String() + + // Use our mock-aware SaveEmbeddings method + return t.SaveEmbeddings(ctx, docID, embedding, metadata, collection) +} + +func TestInsertDocument(t *testing.T) { + qv := newTestQdrantVector() + + ctx := context.Background() + content := "test content" + embedding := []float32{0.1, 0.2, 0.3} + collection := "test-collection" + + // Set up expectations for the mock client + qv.mockClient.On("Upsert", mock.Anything, mock.MatchedBy(func(req *qdrant.UpsertPoints) bool { + if len(req.Points) != 1 { + return false + } + point := req.Points[0] + + // Check collection name + if req.CollectionName != collection { + return false + } + + // Check payload contains correct content + payload := point.Payload + contentValue, exists := payload["content"] + if !exists { + return false + } + stringValue, ok := contentValue.Kind.(*qdrant.Value_StringValue) + if !ok { + return false + } + if stringValue.StringValue != content { + return false + } + + // Check vectors + if !reflect.DeepEqual(point.Vectors.GetVector().Data, embedding) { + return false + } + + return true + })).Return(&qdrant.PointsOperationResponse{}, nil) + + // Test the InsertDocument function + err := qv.InsertDocument(ctx, content, embedding, collection) + assert.NoError(t, err) + + // Verify expectations + qv.mockClient.AssertExpectations(t) +} diff --git a/pkg/db/vectordb.go b/pkg/db/vectordb.go new file mode 100644 index 0000000..1efeb5a --- /dev/null +++ b/pkg/db/vectordb.go @@ -0,0 +1,50 @@ +// Copyright 2024 Stacklok, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package db provides database-related functionality for the application. +// It includes implementations for vector storage and retrieval using PostgreSQL +// with the pgvector extension, enabling efficient similarity search operations +// on high-dimensional vector data. + +package db + +import ( + "context" + "fmt" +) + +// Document represents a single document in the vector database. +// It contains a unique identifier and associated metadata. +type Document struct { + ID string + Metadata map[string]interface{} +} + +// VectorDatabase is the interface that both QdrantVector and PGVector implement +type VectorDatabase interface { + InsertDocument(ctx context.Context, content string, embedding []float32) error + QueryRelevantDocuments(ctx context.Context, embedding []float32, backend string) ([]Document, error) + SaveEmbeddings(ctx context.Context, docID string, embedding []float32, metadata map[string]interface{}) error +} + +// CombineQueryWithContext combines the user's query with the relevant retrieved documents' content +func CombineQueryWithContext(query string, retrievedDocs []Document) string { + var context string + for _, doc := range retrievedDocs { + // Include the content of each retrieved document in the context + context += fmt.Sprintf("%s\n", doc.Metadata["content"]) + } + // Construct the augmented query with the retrieved context and the user's query + return fmt.Sprintf("Context: %s\n\nQuery: %s", context, query) +}