Skip to content

Commit

Permalink
feat: add IDs to Insert result
Browse files Browse the repository at this point in the history
  • Loading branch information
maxbolgarin committed Dec 16, 2024
1 parent 1f593e0 commit e3d9298
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 38 deletions.
3 changes: 2 additions & 1 deletion async.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ type AsyncCollection struct {
// Tasks in different queues will be executed in parallel.
func (ac *AsyncCollection) Insert(queueKey, taskName string, records ...any) {
ac.push(queueKey, taskName, "insert", func(ctx context.Context) error {
return ac.coll.Insert(ctx, records...)
_, err := ac.coll.Insert(ctx, records...)
return err
})
}

Expand Down
26 changes: 20 additions & 6 deletions collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,17 +162,31 @@ func (m *Collection) Distinct(ctx context.Context, dest any, field string, filte
}

// Insert inserts a document or many documents into the collection.
// It returns IDs of the inserted documents.
// Internally InsertMany uses bulk write.
func (m *Collection) Insert(ctx context.Context, records ...any) (err error) {
func (m *Collection) Insert(ctx context.Context, records ...any) (ids []bson.ObjectID, err error) {
if len(records) == 0 {
return nil
return nil, nil
}

ids = make([]bson.ObjectID, len(records))
if len(records) == 1 {
_, err = m.coll.InsertOne(ctx, records[0])
res, err := m.coll.InsertOne(ctx, records[0])
if err != nil {
return nil, HandleMongoError(err)
}
ids[0], _ = res.InsertedID.(bson.ObjectID)

} else {
_, err = m.coll.InsertMany(ctx, records)
}
return HandleMongoError(err)
res, err := m.coll.InsertMany(ctx, records)
if err != nil {
return nil, HandleMongoError(err)
}
for i, id := range res.InsertedIDs {
ids[i], _ = id.(bson.ObjectID)
}
}
return ids, nil
}

// Upsert replaces a document in the collection or inserts it if it doesn't exist.
Expand Down
4 changes: 3 additions & 1 deletion generic.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package mongox
import (
"context"

"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/mongo"
)

Expand Down Expand Up @@ -67,8 +68,9 @@ func Distinct[T any](ctx context.Context, coll *Collection, field string, filter
}

// Insert inserts a document(s) into the collection.
// It returns IDs of the inserted documents.
// Internally InsertMany uses bulk write.
func Insert(ctx context.Context, coll *Collection, record ...any) error {
func Insert(ctx context.Context, coll *Collection, record ...any) ([]bson.ObjectID, error) {
return coll.Insert(ctx, record...)
}

Expand Down
58 changes: 28 additions & 30 deletions mongox_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ import (
"go.mongodb.org/mongo-driver/v2/mongo"
)

// TODO: async

var client *mongox.Client

const (
Expand Down Expand Up @@ -61,11 +59,11 @@ func TestIndexAndText(t *testing.T) {
t.Error(err)
}
entity1 := newTestEntity("1")
err = db.Collection(indexSingleCollection).Insert(ctx, entity1)
_, err = db.Collection(indexSingleCollection).Insert(ctx, entity1)
if err != nil {
t.Error(err)
}
err = db.Collection(indexSingleCollection).Insert(ctx, entity1)
_, err = db.Collection(indexSingleCollection).Insert(ctx, entity1)
if err == nil {
t.Error("expected error, got nil")
}
Expand All @@ -82,19 +80,19 @@ func TestIndexAndText(t *testing.T) {
t.Error(err)
}
entity1 := newTestEntity("1")
err = db.Collection(indexManyCollection).Insert(ctx, entity1)
_, err = db.Collection(indexManyCollection).Insert(ctx, entity1)
if err != nil {
t.Error(err)
}
err = db.Collection(indexManyCollection).Insert(ctx, entity1)
_, err = db.Collection(indexManyCollection).Insert(ctx, entity1)
if err == nil {
t.Error("expected error, got nil")
}
err = db.Collection(indexManyCollection).Insert(ctx, newTestEntity("1"))
_, err = db.Collection(indexManyCollection).Insert(ctx, newTestEntity("1"))
if err != nil {
t.Error(err)
}
err = db.Collection(indexManyCollection).Insert(ctx, newTestEntity("1"), newTestEntity("1"), newTestEntity("1"))
_, err = db.Collection(indexManyCollection).Insert(ctx, newTestEntity("1"), newTestEntity("1"), newTestEntity("1"))
if err != nil {
t.Error(err)
}
Expand All @@ -103,14 +101,14 @@ func TestIndexAndText(t *testing.T) {
t.Run("Text", func(t *testing.T) {
entity1 := newTestEntity("1")
entity1.Name = "Running tool: /usr/local/go/bin/go test -timeout 45s -run ^TestFind$ github.com/maxbolgarin/mongox"
err := db.Collection(textCollection).Insert(ctx, entity1)
_, err := db.Collection(textCollection).Insert(ctx, entity1)
if err != nil {
t.Error(err)
}

entity2 := newTestEntity("2")
entity2.Name = "Pairs in tool must be in the form NewF(key1, value1, key2, value2, ...)"
err = db.Collection(textCollection).Insert(ctx, entity2)
_, err = db.Collection(textCollection).Insert(ctx, entity2)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -170,7 +168,7 @@ func TestInsertFindDelete(t *testing.T) {
t.Run("FindOne_Replace_Upsert_DeleteOne", func(t *testing.T) {
entity1 := newTestEntity("1")
_, err := db.WithTransaction(ctx, func(ctx context.Context) (any, error) {
err := db.Collection(findOneCollection).Insert(ctx, entity1)
_, err := db.Collection(findOneCollection).Insert(ctx, entity1)
if err != nil {
return nil, err
}
Expand All @@ -181,16 +179,16 @@ func TestInsertFindDelete(t *testing.T) {
t.Errorf("expected error %v, got %v", mongox.ErrIllegalOperation, err)
}

err = db.Collection(findOneCollection).Insert(ctx, entity1)
_, err = db.Collection(findOneCollection).Insert(ctx, entity1)
if err != nil {
t.Error(err)
}
entity2 := newTestEntity("2")
err = mongox.Insert(ctx, db.Collection(findOneCollection), entity2)
_, err = mongox.Insert(ctx, db.Collection(findOneCollection), entity2)
if err != nil {
t.Error(err)
}
err = db.Collection(findOneCollection).Insert(ctx, newTestEntity("3"), newTestEntity("4"))
_, err = db.Collection(findOneCollection).Insert(ctx, newTestEntity("3"), newTestEntity("4"))
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -254,7 +252,7 @@ func TestInsertFindDelete(t *testing.T) {

t.Run("Find_DeleteMany", func(t *testing.T) {
entity2, entity3, entity4 := newTestEntity("2"), newTestEntity("3"), newTestEntity("4")
err := db.Collection(findCollection).Insert(ctx, entity2, entity3, entity4)
_, err := db.Collection(findCollection).Insert(ctx, entity2, entity3, entity4)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -295,7 +293,7 @@ func TestInsertFindDelete(t *testing.T) {
}

for i := 0; i < 10; i++ {
err = db.Collection(findAllCollection).Insert(ctx, newTestEntity(strconv.Itoa(i)), newTestEntity(strconv.Itoa(i)),
_, err = db.Collection(findAllCollection).Insert(ctx, newTestEntity(strconv.Itoa(i)), newTestEntity(strconv.Itoa(i)),
newTestEntity(strconv.Itoa(i)), newTestEntity(strconv.Itoa(i)), newTestEntity(strconv.Itoa(i)), newTestEntity(strconv.Itoa(i)),
newTestEntity(strconv.Itoa(i)), newTestEntity(strconv.Itoa(i)), newTestEntity(strconv.Itoa(i)), newTestEntity(strconv.Itoa(i)))
if err != nil {
Expand Down Expand Up @@ -400,7 +398,7 @@ func TestUpdate(t *testing.T) {
f = mongox.M{"id": "1"}
)

err := coll.Insert(ctx, entity)
_, err := coll.Insert(ctx, entity)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -439,7 +437,7 @@ func TestUpdate(t *testing.T) {
entity.Name = ""
testUpdate(t, ctx, db, entity, mongox.M{"name": nil})

err = coll.Insert(ctx, entity, entity, entity, entity, entity, entity, entity, entity, entity)
_, err = coll.Insert(ctx, entity, entity, entity, entity, entity, entity, entity, entity, entity)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -477,7 +475,7 @@ func TestUpdate(t *testing.T) {
}

newEntity := newTestEntity("1")
err = coll.Insert(ctx, newEntity)
_, err = coll.Insert(ctx, newEntity)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -630,7 +628,7 @@ func TestError(t *testing.T) {
t.Run("Error_NilArguments", func(t *testing.T) {
coll := db.Collection(errorNilArgCollection)

err := coll.Insert(ctx, newTestEntity("1"))
_, err := coll.Insert(ctx, newTestEntity("1"))
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -677,12 +675,12 @@ func TestError(t *testing.T) {
t.Error(err)
}

err = coll.Insert(ctx, nil)
_, err = coll.Insert(ctx, nil)
if !errors.Is(err, mongox.ErrInvalidArgument) {
t.Errorf("expected error %v, got %v", mongox.ErrInvalidArgument, err)
}

err = coll.Insert(ctx, []any{newTestEntity("2"), nil})
_, err = coll.Insert(ctx, []any{newTestEntity("2"), nil})
if !errors.Is(err, mongox.ErrInvalidArgument) {
t.Errorf("expected error %v, got %v", mongox.ErrInvalidArgument, err)
}
Expand Down Expand Up @@ -733,7 +731,7 @@ func TestError(t *testing.T) {
t.Errorf("expected error %v, got %v", mongox.ErrNotFound, err)
}

err = coll.Insert(ctx, newTestEntity("1"))
_, err = coll.Insert(ctx, newTestEntity("1"))
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -763,7 +761,7 @@ func TestError(t *testing.T) {
t.Run("Error_InvalidArguments", func(t *testing.T) {
coll := db.Collection(errorInvalidArgCollection)

err := coll.Insert(ctx, newTestEntity("1"))
_, err := coll.Insert(ctx, newTestEntity("1"))
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -811,12 +809,12 @@ func TestError(t *testing.T) {
t.Error(err)
}

err = coll.Insert(ctx, 1)
_, err = coll.Insert(ctx, 1)
if !errors.Is(err, mongox.ErrInvalidArgument) {
t.Errorf("expected error %v, got %v", mongox.ErrInvalidArgument, err)
}

err = coll.Insert(ctx, []any{newTestEntity("2"), 1})
_, err = coll.Insert(ctx, []any{newTestEntity("2"), 1})
if !errors.Is(err, mongox.ErrInvalidArgument) {
t.Errorf("expected error %v, got %v", mongox.ErrInvalidArgument, err)
}
Expand Down Expand Up @@ -904,7 +902,7 @@ func TestError(t *testing.T) {
t.Run("Error_InvalidFilter", func(t *testing.T) {
coll := db.Collection(errorInvalidFilterCollection)

err := coll.Insert(ctx, newTestEntity("1"))
_, err := coll.Insert(ctx, newTestEntity("1"))
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -959,7 +957,7 @@ func TestError(t *testing.T) {
t.Run("Error_InvalidUpdate", func(t *testing.T) {
coll := db.Collection(errorInvalidUpdCollection)

err := coll.Insert(ctx, newTestEntity("1"))
_, err := coll.Insert(ctx, newTestEntity("1"))
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -1237,11 +1235,11 @@ func TestError(t *testing.T) {
if err != nil {
t.Error(err)
}
err = coll.Insert(ctx, newTestEntity("1"))
_, err = coll.Insert(ctx, newTestEntity("1"))
if err != nil {
t.Error(err)
}
err = coll.Insert(ctx, newTestEntity("1"))
_, err = coll.Insert(ctx, newTestEntity("1"))
if !errors.Is(err, mongox.ErrDuplicate) {
t.Errorf("expected error %v, got %v", mongox.ErrDuplicateKey, err)
}
Expand Down

0 comments on commit e3d9298

Please sign in to comment.