Skip to content

Commit

Permalink
Replace mutex with atomic value
Browse files Browse the repository at this point in the history
  • Loading branch information
asamuj committed Aug 6, 2024
1 parent b3cad7b commit bf4562d
Showing 1 changed file with 14 additions and 21 deletions.
35 changes: 14 additions & 21 deletions pkg/wallet/dsbackend.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"reflect"
"strings"
"sync"
"sync/atomic"

"github.com/awnumar/memguard"
"github.com/filecoin-project/go-address"
Expand Down Expand Up @@ -50,7 +51,7 @@ type DSBackend struct {
password *memguard.Enclave
unLocked map[address.Address]*key.KeyInfo

state int
state atomic.Int64
}

var _ Backend = (*DSBackend)(nil)
Expand Down Expand Up @@ -198,11 +199,8 @@ func (backend *DSBackend) putKeyInfo(ctx context.Context, ki *key.KeyInfo) error
if err := backend.ds.Put(ctx, ds.NewKey(key.Address.String()), keyJSON); err != nil {
return errors.Wrapf(err, "failed to store new address: %s", key.Address.String())
}

backend.lk.Lock()
backend.cache[addr] = struct{}{}
backend.unLocked[addr] = ki
backend.lk.Unlock()
return nil
}

Expand Down Expand Up @@ -285,9 +283,7 @@ func (backend *DSBackend) getKey(ctx context.Context, addr address.Address, pass
}

func (backend *DSBackend) LockWallet(ctx context.Context) error {
backend.lk.Lock()
defer backend.lk.Unlock()
if backend.state == Lock {
if backend.state.Load() == Lock {
return fmt.Errorf("already locked")
}

Expand All @@ -296,24 +292,24 @@ func (backend *DSBackend) LockWallet(ctx context.Context) error {
}

for _, addr := range backend.Addresses(ctx) {
backend.lk.Lock()
delete(backend.unLocked, addr)
backend.lk.Unlock()
}
backend.cleanPassword()
backend.state = Lock
backend.state.Store(Lock)

return nil
}

// UnLockWallet unlock wallet with password, decrypt local key in db and save to protected memory
func (backend *DSBackend) UnLockWallet(ctx context.Context, password []byte) error {
backend.lk.Lock()
defer func() {
backend.lk.Unlock()
for i := range password {
password[i] = 0
}
}()
if backend.state == Unlock {
if backend.state.Load() == Unlock {
return fmt.Errorf("already unlocked")
}

Expand All @@ -327,17 +323,17 @@ func (backend *DSBackend) UnLockWallet(ctx context.Context, password []byte) err
return err
}

backend.lk.Lock()
backend.unLocked[addr] = ki
backend.lk.Unlock()
}
backend.state = Unlock
backend.state.Store(Unlock)

return nil
}

// SetPassword set password for wallet , and wallet used this password to encrypt private key
func (backend *DSBackend) SetPassword(ctx context.Context, password []byte) error {
backend.lk.Lock()
defer backend.lk.Unlock()
if backend.password != nil {
return ErrRepeatPassword
}
Expand All @@ -347,13 +343,12 @@ func (backend *DSBackend) SetPassword(ctx context.Context, password []byte) erro
if err != nil {
return err
}

backend.lk.Lock()
backend.unLocked[addr] = ki
}
if backend.state == undetermined {
backend.state = Unlock
backend.lk.Unlock()
}

backend.state.CompareAndSwap(undetermined, Unlock)
backend.setPassword(password)

return nil
Expand All @@ -366,9 +361,7 @@ func (backend *DSBackend) HasPassword() bool {

// WalletState return wallet state(lock/unlock)
func (backend *DSBackend) WalletState(ctx context.Context) int {
backend.lk.Lock()
defer backend.lk.Unlock()
return backend.state
return int(backend.state.Load())
}

func (backend *DSBackend) setPassword(password []byte) {
Expand Down

0 comments on commit bf4562d

Please sign in to comment.