@@ -10,6 +10,7 @@ import (
10
10
"go.mongodb.org/mongo-driver/mongo"
11
11
"go.mongodb.org/mongo-driver/mongo/options"
12
12
"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
13
+ "go.uber.org/atomic"
13
14
"io"
14
15
"io/ioutil"
15
16
"net/url"
40
41
)
41
42
42
43
type Mongo struct {
43
- client * mongo.Client
44
- db * mongo.Database
45
- config * Config
44
+ client * mongo.Client
45
+ db * mongo.Database
46
+ config * Config
47
+ isLocked atomic.Bool
46
48
}
47
49
48
50
type Locking struct {
@@ -327,55 +329,59 @@ func (m *Mongo) ensureVersionTable() (err error) {
327
329
// Utilizes advisory locking on the config.LockingCollection collection
328
330
// This uses a unique index on the `locking_key` field.
329
331
func (m * Mongo ) Lock () error {
330
- if ! m .config .Locking .Enabled {
331
- return nil
332
- }
333
- pid := os .Getpid ()
334
- hostname , err := os .Hostname ()
335
- if err != nil {
336
- hostname = fmt .Sprintf ("Could not determine hostname. Error: %s" , err .Error ())
337
- }
338
-
339
- newLockObj := lockObj {
340
- Key : lockKeyUniqueValue ,
341
- Pid : pid ,
342
- Hostname : hostname ,
343
- CreatedAt : time .Now (),
344
- }
345
- operation := func () error {
346
- timeout , cancelFunc := context .WithTimeout (context .Background (), contextWaitTimeout )
347
- _ , err := m .db .Collection (m .config .Locking .CollectionName ).InsertOne (timeout , newLockObj )
348
- defer cancelFunc ()
349
- return err
350
- }
351
- exponentialBackOff := backoff .NewExponentialBackOff ()
352
- duration := time .Duration (m .config .Locking .Timeout ) * time .Second
353
- exponentialBackOff .MaxElapsedTime = duration
354
- exponentialBackOff .MaxInterval = time .Duration (m .config .Locking .Interval ) * time .Second
332
+ return database .CasRestoreOnErr (& m .isLocked , false , true , database .ErrLocked , func () error {
333
+ if ! m .config .Locking .Enabled {
334
+ return nil
335
+ }
336
+ pid := os .Getpid ()
337
+ hostname , err := os .Hostname ()
338
+ if err != nil {
339
+ hostname = fmt .Sprintf ("Could not determine hostname. Error: %s" , err .Error ())
340
+ }
355
341
356
- err = backoff .Retry (operation , exponentialBackOff )
357
- if err != nil {
358
- return database .ErrLocked
359
- }
342
+ newLockObj := lockObj {
343
+ Key : lockKeyUniqueValue ,
344
+ Pid : pid ,
345
+ Hostname : hostname ,
346
+ CreatedAt : time .Now (),
347
+ }
348
+ operation := func () error {
349
+ timeout , cancelFunc := context .WithTimeout (context .Background (), contextWaitTimeout )
350
+ _ , err := m .db .Collection (m .config .Locking .CollectionName ).InsertOne (timeout , newLockObj )
351
+ defer cancelFunc ()
352
+ return err
353
+ }
354
+ exponentialBackOff := backoff .NewExponentialBackOff ()
355
+ duration := time .Duration (m .config .Locking .Timeout ) * time .Second
356
+ exponentialBackOff .MaxElapsedTime = duration
357
+ exponentialBackOff .MaxInterval = time .Duration (m .config .Locking .Interval ) * time .Second
360
358
361
- return nil
359
+ err = backoff .Retry (operation , exponentialBackOff )
360
+ if err != nil {
361
+ return database .ErrLocked
362
+ }
362
363
364
+ return nil
365
+ })
363
366
}
367
+
364
368
func (m * Mongo ) Unlock () error {
365
- if ! m .config .Locking .Enabled {
366
- return nil
367
- }
369
+ return database .CasRestoreOnErr (& m .isLocked , true , false , database .ErrNotLocked , func () error {
370
+ if ! m .config .Locking .Enabled {
371
+ return nil
372
+ }
368
373
369
- filter := findFilter {
370
- Key : lockKeyUniqueValue ,
371
- }
374
+ filter := findFilter {
375
+ Key : lockKeyUniqueValue ,
376
+ }
372
377
373
- ctx , cancel := context .WithTimeout (context .Background (), contextWaitTimeout )
374
- _ , err := m .db .Collection (m .config .Locking .CollectionName ).DeleteMany (ctx , filter )
375
- defer cancel ()
378
+ ctx , cancel := context .WithTimeout (context .Background (), contextWaitTimeout )
379
+ _ , err := m .db .Collection (m .config .Locking .CollectionName ).DeleteMany (ctx , filter )
380
+ defer cancel ()
376
381
377
- if err != nil {
378
- return err
379
- }
380
- return nil
382
+ if err != nil {
383
+ return err
384
+ }
385
+ return nil
386
+ })
381
387
}
0 commit comments