Skip to content

Commit

Permalink
Merge pull request #30 from philippgille/add-validations
Browse files Browse the repository at this point in the history
Add more validations
  • Loading branch information
philippgille authored Mar 2, 2024
2 parents ee92d94 + 20c2f9a commit 6ceec95
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 17 deletions.
11 changes: 7 additions & 4 deletions collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,16 +107,19 @@ func (c *Collection) Count() int {
// - where: Conditional filtering on metadata. Optional.
// - whereDocument: Conditional filtering on documents. Optional.
func (c *Collection) Query(ctx context.Context, queryText string, nResults int, where, whereDocument map[string]string) ([]Result, error) {
if queryText == "" {
return nil, errors.New("queryText is empty")
}
if nResults <= 0 {
return nil, errors.New("nResults must be > 0")
}

c.documentsLock.RLock()
defer c.documentsLock.RUnlock()
if len(c.documents) == 0 {
return nil, nil
}

if nResults <= 0 {
return nil, errors.New("nResults must be > 0")
}

// Validate whereDocument operators
for k := range whereDocument {
if !slices.Contains(supportedFilters, k) {
Expand Down
3 changes: 3 additions & 0 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ func NewPersistentDB(path string) (*DB, error) {
// - embeddingFunc: Optional function to use to embed documents.
// Uses the default embedding function if not provided.
func (db *DB) CreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc) (*Collection, error) {
if name == "" {
return nil, errors.New("collection name is empty")
}
if embeddingFunc == nil {
embeddingFunc = NewEmbeddingFuncDefault()
}
Expand Down
35 changes: 22 additions & 13 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,30 @@ func TestDB_CreateCollection(t *testing.T) {
return []float32{-0.1, 0.1, 0.2}, nil
}

// Create collection
db := chromem.NewDB()
c, err := db.CreateCollection(name, metadata, embeddingFunc)
if err != nil {
t.Error("expected no error, got", err)
}
if c == nil {
t.Error("expected collection, got nil")
}

// Check expectations
if c.Name != name {
t.Error("expected name", name, "got", c.Name)
}
// TODO: Check metadata etc when they become accessible
t.Run("OK", func(t *testing.T) {
c, err := db.CreateCollection(name, metadata, embeddingFunc)
if err != nil {
t.Error("expected no error, got", err)
}
if c == nil {
t.Error("expected collection, got nil")
}

// Check expectations
if c.Name != name {
t.Error("expected name", name, "got", c.Name)
}
// TODO: Check metadata etc when they become accessible
})

t.Run("NOK - Empty name", func(t *testing.T) {
_, err := db.CreateCollection("", metadata, embeddingFunc)
if err == nil {
t.Error("expected error, got nil")
}
})
}

func TestDB_ListCollections(t *testing.T) {
Expand Down

0 comments on commit 6ceec95

Please sign in to comment.