Skip to content

Commit a844cdf

Browse files
committed
added atomic lock to sqlite & mongo
1 parent 312d98d commit a844cdf

File tree

2 files changed

+57
-52
lines changed

2 files changed

+57
-52
lines changed

database/mongodb/mongodb.go

+52-46
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"go.mongodb.org/mongo-driver/mongo"
1111
"go.mongodb.org/mongo-driver/mongo/options"
1212
"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
13+
"go.uber.org/atomic"
1314
"io"
1415
"io/ioutil"
1516
"net/url"
@@ -40,9 +41,10 @@ var (
4041
)
4142

4243
type Mongo struct {
43-
client *mongo.Client
44-
db *mongo.Database
45-
config *Config
44+
client *mongo.Client
45+
db *mongo.Database
46+
config *Config
47+
isLocked atomic.Bool
4648
}
4749

4850
type Locking struct {
@@ -327,55 +329,59 @@ func (m *Mongo) ensureVersionTable() (err error) {
327329
// Utilizes advisory locking on the config.LockingCollection collection
328330
// This uses a unique index on the `locking_key` field.
329331
func (m *Mongo) Lock() error {
330-
if !m.config.Locking.Enabled {
331-
return nil
332-
}
333-
pid := os.Getpid()
334-
hostname, err := os.Hostname()
335-
if err != nil {
336-
hostname = fmt.Sprintf("Could not determine hostname. Error: %s", err.Error())
337-
}
338-
339-
newLockObj := lockObj{
340-
Key: lockKeyUniqueValue,
341-
Pid: pid,
342-
Hostname: hostname,
343-
CreatedAt: time.Now(),
344-
}
345-
operation := func() error {
346-
timeout, cancelFunc := context.WithTimeout(context.Background(), contextWaitTimeout)
347-
_, err := m.db.Collection(m.config.Locking.CollectionName).InsertOne(timeout, newLockObj)
348-
defer cancelFunc()
349-
return err
350-
}
351-
exponentialBackOff := backoff.NewExponentialBackOff()
352-
duration := time.Duration(m.config.Locking.Timeout) * time.Second
353-
exponentialBackOff.MaxElapsedTime = duration
354-
exponentialBackOff.MaxInterval = time.Duration(m.config.Locking.Interval) * time.Second
332+
return database.CasRestoreOnErr(&m.isLocked, false, true, database.ErrLocked, func() error {
333+
if !m.config.Locking.Enabled {
334+
return nil
335+
}
336+
pid := os.Getpid()
337+
hostname, err := os.Hostname()
338+
if err != nil {
339+
hostname = fmt.Sprintf("Could not determine hostname. Error: %s", err.Error())
340+
}
355341

356-
err = backoff.Retry(operation, exponentialBackOff)
357-
if err != nil {
358-
return database.ErrLocked
359-
}
342+
newLockObj := lockObj{
343+
Key: lockKeyUniqueValue,
344+
Pid: pid,
345+
Hostname: hostname,
346+
CreatedAt: time.Now(),
347+
}
348+
operation := func() error {
349+
timeout, cancelFunc := context.WithTimeout(context.Background(), contextWaitTimeout)
350+
_, err := m.db.Collection(m.config.Locking.CollectionName).InsertOne(timeout, newLockObj)
351+
defer cancelFunc()
352+
return err
353+
}
354+
exponentialBackOff := backoff.NewExponentialBackOff()
355+
duration := time.Duration(m.config.Locking.Timeout) * time.Second
356+
exponentialBackOff.MaxElapsedTime = duration
357+
exponentialBackOff.MaxInterval = time.Duration(m.config.Locking.Interval) * time.Second
360358

361-
return nil
359+
err = backoff.Retry(operation, exponentialBackOff)
360+
if err != nil {
361+
return database.ErrLocked
362+
}
362363

364+
return nil
365+
})
363366
}
367+
364368
func (m *Mongo) Unlock() error {
365-
if !m.config.Locking.Enabled {
366-
return nil
367-
}
369+
return database.CasRestoreOnErr(&m.isLocked, true, false, database.ErrNotLocked, func() error {
370+
if !m.config.Locking.Enabled {
371+
return nil
372+
}
368373

369-
filter := findFilter{
370-
Key: lockKeyUniqueValue,
371-
}
374+
filter := findFilter{
375+
Key: lockKeyUniqueValue,
376+
}
372377

373-
ctx, cancel := context.WithTimeout(context.Background(), contextWaitTimeout)
374-
_, err := m.db.Collection(m.config.Locking.CollectionName).DeleteMany(ctx, filter)
375-
defer cancel()
378+
ctx, cancel := context.WithTimeout(context.Background(), contextWaitTimeout)
379+
_, err := m.db.Collection(m.config.Locking.CollectionName).DeleteMany(ctx, filter)
380+
defer cancel()
376381

377-
if err != nil {
378-
return err
379-
}
380-
return nil
382+
if err != nil {
383+
return err
384+
}
385+
return nil
386+
})
381387
}

database/sqlite/sqlite.go

+5-6
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package sqlite
33
import (
44
"database/sql"
55
"fmt"
6+
"go.uber.org/atomic"
67
"io"
78
"io/ioutil"
89
nurl "net/url"
@@ -34,7 +35,7 @@ type Config struct {
3435

3536
type Sqlite struct {
3637
db *sql.DB
37-
isLocked bool
38+
isLocked atomic.Bool
3839

3940
config *Config
4041
}
@@ -177,18 +178,16 @@ func (m *Sqlite) Drop() (err error) {
177178
}
178179

179180
func (m *Sqlite) Lock() error {
180-
if m.isLocked {
181+
if !m.isLocked.CAS(false, true) {
181182
return database.ErrLocked
182183
}
183-
m.isLocked = true
184184
return nil
185185
}
186186

187187
func (m *Sqlite) Unlock() error {
188-
if !m.isLocked {
189-
return nil
188+
if !m.isLocked.CAS(true, false) {
189+
return database.ErrNotLocked
190190
}
191-
m.isLocked = false
192191
return nil
193192
}
194193

0 commit comments

Comments
 (0)