Skip to content

Commit c4fd509

Browse files
committed
added atomic lock to sqlite & mongo
1 parent 312d98d commit c4fd509

File tree

4 files changed

+86
-82
lines changed

4 files changed

+86
-82
lines changed

database/mongodb/mongodb.go

+52-46
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"go.mongodb.org/mongo-driver/mongo"
1111
"go.mongodb.org/mongo-driver/mongo/options"
1212
"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
13+
"go.uber.org/atomic"
1314
"io"
1415
"io/ioutil"
1516
"net/url"
@@ -40,9 +41,10 @@ var (
4041
)
4142

4243
type Mongo struct {
43-
client *mongo.Client
44-
db *mongo.Database
45-
config *Config
44+
client *mongo.Client
45+
db *mongo.Database
46+
config *Config
47+
isLocked atomic.Bool
4648
}
4749

4850
type Locking struct {
@@ -327,55 +329,59 @@ func (m *Mongo) ensureVersionTable() (err error) {
327329
// Utilizes advisory locking on the config.LockingCollection collection
328330
// This uses a unique index on the `locking_key` field.
329331
func (m *Mongo) Lock() error {
330-
if !m.config.Locking.Enabled {
331-
return nil
332-
}
333-
pid := os.Getpid()
334-
hostname, err := os.Hostname()
335-
if err != nil {
336-
hostname = fmt.Sprintf("Could not determine hostname. Error: %s", err.Error())
337-
}
338-
339-
newLockObj := lockObj{
340-
Key: lockKeyUniqueValue,
341-
Pid: pid,
342-
Hostname: hostname,
343-
CreatedAt: time.Now(),
344-
}
345-
operation := func() error {
346-
timeout, cancelFunc := context.WithTimeout(context.Background(), contextWaitTimeout)
347-
_, err := m.db.Collection(m.config.Locking.CollectionName).InsertOne(timeout, newLockObj)
348-
defer cancelFunc()
349-
return err
350-
}
351-
exponentialBackOff := backoff.NewExponentialBackOff()
352-
duration := time.Duration(m.config.Locking.Timeout) * time.Second
353-
exponentialBackOff.MaxElapsedTime = duration
354-
exponentialBackOff.MaxInterval = time.Duration(m.config.Locking.Interval) * time.Second
332+
return database.CasRestoreOnErr(&m.isLocked, false, true, database.ErrLocked, func() error {
333+
if !m.config.Locking.Enabled {
334+
return nil
335+
}
336+
pid := os.Getpid()
337+
hostname, err := os.Hostname()
338+
if err != nil {
339+
hostname = fmt.Sprintf("Could not determine hostname. Error: %s", err.Error())
340+
}
355341

356-
err = backoff.Retry(operation, exponentialBackOff)
357-
if err != nil {
358-
return database.ErrLocked
359-
}
342+
newLockObj := lockObj{
343+
Key: lockKeyUniqueValue,
344+
Pid: pid,
345+
Hostname: hostname,
346+
CreatedAt: time.Now(),
347+
}
348+
operation := func() error {
349+
timeout, cancelFunc := context.WithTimeout(context.Background(), contextWaitTimeout)
350+
_, err := m.db.Collection(m.config.Locking.CollectionName).InsertOne(timeout, newLockObj)
351+
defer cancelFunc()
352+
return err
353+
}
354+
exponentialBackOff := backoff.NewExponentialBackOff()
355+
duration := time.Duration(m.config.Locking.Timeout) * time.Second
356+
exponentialBackOff.MaxElapsedTime = duration
357+
exponentialBackOff.MaxInterval = time.Duration(m.config.Locking.Interval) * time.Second
360358

361-
return nil
359+
err = backoff.Retry(operation, exponentialBackOff)
360+
if err != nil {
361+
return database.ErrLocked
362+
}
362363

364+
return nil
365+
})
363366
}
367+
364368
func (m *Mongo) Unlock() error {
365-
if !m.config.Locking.Enabled {
366-
return nil
367-
}
369+
return database.CasRestoreOnErr(&m.isLocked, true, false, database.ErrNotLocked, func() error {
370+
if !m.config.Locking.Enabled {
371+
return nil
372+
}
368373

369-
filter := findFilter{
370-
Key: lockKeyUniqueValue,
371-
}
374+
filter := findFilter{
375+
Key: lockKeyUniqueValue,
376+
}
372377

373-
ctx, cancel := context.WithTimeout(context.Background(), contextWaitTimeout)
374-
_, err := m.db.Collection(m.config.Locking.CollectionName).DeleteMany(ctx, filter)
375-
defer cancel()
378+
ctx, cancel := context.WithTimeout(context.Background(), contextWaitTimeout)
379+
_, err := m.db.Collection(m.config.Locking.CollectionName).DeleteMany(ctx, filter)
380+
defer cancel()
376381

377-
if err != nil {
378-
return err
379-
}
380-
return nil
382+
if err != nil {
383+
return err
384+
}
385+
return nil
386+
})
381387
}

database/sqlite/sqlite.go

+5-6
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package sqlite
33
import (
44
"database/sql"
55
"fmt"
6+
"go.uber.org/atomic"
67
"io"
78
"io/ioutil"
89
nurl "net/url"
@@ -34,7 +35,7 @@ type Config struct {
3435

3536
type Sqlite struct {
3637
db *sql.DB
37-
isLocked bool
38+
isLocked atomic.Bool
3839

3940
config *Config
4041
}
@@ -177,18 +178,16 @@ func (m *Sqlite) Drop() (err error) {
177178
}
178179

179180
func (m *Sqlite) Lock() error {
180-
if m.isLocked {
181+
if !m.isLocked.CAS(false, true) {
181182
return database.ErrLocked
182183
}
183-
m.isLocked = true
184184
return nil
185185
}
186186

187187
func (m *Sqlite) Unlock() error {
188-
if !m.isLocked {
189-
return nil
188+
if !m.isLocked.CAS(true, false) {
189+
return database.ErrNotLocked
190190
}
191-
m.isLocked = false
192191
return nil
193192
}
194193

database/util.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,4 @@ func CasRestoreOnErr(lock *atomic.Bool, o, n bool, casErr error, f func() error)
3030
return err
3131
}
3232
return nil
33-
}
33+
}

database/util_test.go

+28-29
Original file line numberDiff line numberDiff line change
@@ -49,44 +49,43 @@ func TestGenerateAdvisoryLockId(t *testing.T) {
4949

5050
func TestCasRestoreOnErr(t *testing.T) {
5151
testcases := []struct {
52-
name string
53-
lock *atomic.Bool
54-
from bool
55-
to bool
56-
casErr error
57-
fErr error
58-
expectLock bool
52+
name string
53+
lock *atomic.Bool
54+
from bool
55+
to bool
56+
casErr error
57+
fErr error
58+
expectLock bool
5959
expectError error
6060
}{
6161
{
62-
name: "Test positive CAS lock",
63-
lock: atomic.NewBool(false),
64-
from: false,
65-
to: true,
66-
casErr: ErrLocked,
67-
fErr: nil,
62+
name: "Test positive CAS lock",
63+
lock: atomic.NewBool(false),
64+
from: false,
65+
to: true,
66+
casErr: ErrLocked,
67+
fErr: nil,
6868
expectError: nil,
69-
expectLock: true,
70-
69+
expectLock: true,
7170
},
7271
{
73-
name: "Test negative CAS lock",
74-
lock: atomic.NewBool(true),
75-
from: false,
76-
to: true,
77-
casErr: ErrLocked,
78-
fErr: nil,
79-
expectLock: true,
72+
name: "Test negative CAS lock",
73+
lock: atomic.NewBool(true),
74+
from: false,
75+
to: true,
76+
casErr: ErrLocked,
77+
fErr: nil,
78+
expectLock: true,
8079
expectError: ErrLocked,
8180
},
8281
{
83-
name: "Test negative with callback lock",
84-
lock: atomic.NewBool(false),
85-
from: false,
86-
to: true,
87-
casErr: ErrLocked,
88-
fErr: ErrNotLocked,
89-
expectLock: false,
82+
name: "Test negative with callback lock",
83+
lock: atomic.NewBool(false),
84+
from: false,
85+
to: true,
86+
casErr: ErrLocked,
87+
fErr: ErrNotLocked,
88+
expectLock: false,
9089
expectError: ErrNotLocked,
9190
},
9291
}

0 commit comments

Comments
 (0)