Skip to content

Commit

Permalink
[R4R] - {develop}: fix consensys audit issue: cs-6.14 (#1204)
Browse files Browse the repository at this point in the history
* add partyIDtoP2PIDLock and optimize all mutex usage

* add SetPartyIDtoP2PID

* unify concurrent map

---------

Co-authored-by: Raymond <[email protected]>
  • Loading branch information
HaoyangLiu and wukongcheng authored Jul 2, 2023
1 parent cff85de commit 548f9b2
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 90 deletions.
148 changes: 81 additions & 67 deletions tss/node/tsslib/common/tss.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,63 +11,69 @@ import (

"github.com/binance-chain/tss-lib/tss"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"

"github.com/mantlenetworkio/mantle/l2geth/crypto"
abnormal2 "github.com/mantlenetworkio/mantle/tss/node/tsslib/abnormal"
"github.com/mantlenetworkio/mantle/tss/node/tsslib/conversion"
"github.com/mantlenetworkio/mantle/tss/node/tsslib/messages"
"github.com/mantlenetworkio/mantle/tss/node/tsslib/p2p"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)

type TssCommon struct {
conf TssConfig
logger zerolog.Logger
partyLock *sync.Mutex
partyInfo *abnormal2.PartyInfo
PartyIDtoP2PID map[string]peer.ID
unConfirmedMsgLock *sync.Mutex
unConfirmedMessages map[string]*LocalCacheItem
localPeerID string
broadcastChannel chan *messages.BroadcastMsgChan
TssMsg chan *p2p.Message
P2PPeersLock *sync.RWMutex
P2PPeers []peer.ID // most of tss message are broadcast, we store the peers ID to avoid iterating
msgID string
privateKey *ecdsa.PrivateKey
taskDone chan struct{}
abnormalMgr *abnormal2.Manager
finishedPeers map[string]bool
culprits []*tss.PartyID
culpritsLock *sync.RWMutex
conf TssConfig
logger zerolog.Logger

partyLock sync.RWMutex
partyInfo *abnormal2.PartyInfo

partyIDtoP2PIDMap *sync.Map // map[string]peer.ID
unConfirmedMessagesMap *sync.Map // map[string]*LocalCacheItem

localPeerID string
broadcastChannel chan *messages.BroadcastMsgChan
TssMsg chan *p2p.Message

P2PPeersLock sync.RWMutex
P2PPeers []peer.ID

msgID string
privateKey *ecdsa.PrivateKey
taskDone chan struct{}
abnormalMgr *abnormal2.Manager
finishedPeers map[string]bool

culpritsLock sync.RWMutex
culprits []*tss.PartyID

cachedWireBroadcastMsgLists *sync.Map
cachedWireUnicastMsgLists *sync.Map
thresHold int
threshHold int
}

func NewTssCommon(peerID string, broadcastChannel chan *messages.BroadcastMsgChan, conf TssConfig, msgID string, privKey *ecdsa.PrivateKey, thresHold int) *TssCommon {
return &TssCommon{
conf: conf,
logger: log.With().Str("module", "tsscommon").Logger(),
partyLock: &sync.Mutex{},
partyLock: sync.RWMutex{},
partyInfo: nil,
PartyIDtoP2PID: make(map[string]peer.ID),
unConfirmedMsgLock: &sync.Mutex{},
unConfirmedMessages: make(map[string]*LocalCacheItem),
partyIDtoP2PIDMap: &sync.Map{},
unConfirmedMessagesMap: &sync.Map{},
broadcastChannel: broadcastChannel,
TssMsg: make(chan *p2p.Message),
P2PPeersLock: &sync.RWMutex{},
P2PPeersLock: sync.RWMutex{},
P2PPeers: nil,
msgID: msgID,
localPeerID: peerID,
privateKey: privKey,
taskDone: make(chan struct{}),
abnormalMgr: abnormal2.NewAbnormalManager(),
finishedPeers: make(map[string]bool),
culpritsLock: &sync.RWMutex{},
culpritsLock: sync.RWMutex{},
cachedWireBroadcastMsgLists: &sync.Map{},
cachedWireUnicastMsgLists: &sync.Map{},
thresHold: thresHold,
threshHold: thresHold,
}
}

Expand Down Expand Up @@ -147,8 +153,23 @@ func (t *TssCommon) GetTaskDone() chan struct{} {
return t.taskDone
}

func (t *TssCommon) GetThresHold() int {
return t.thresHold
func (t *TssCommon) GetThreshHold() int {
return t.threshHold
}

func (t *TssCommon) InsertPartyIDtoP2PID(newMap map[string]peer.ID) {
for k, v := range newMap {
t.partyIDtoP2PIDMap.Store(k, v)
}
}

func (t *TssCommon) GetPartyIDtoP2PID() map[string]peer.ID {
result := make(map[string]peer.ID)
t.partyIDtoP2PIDMap.Range(func(key, value any) bool {
result[key.(string)] = value.(peer.ID)
return true
})
return result
}

func (t *TssCommon) GetAbnormalMgr() *abnormal2.Manager {
Expand All @@ -162,8 +183,8 @@ func (t *TssCommon) SetPartyInfo(partyInfo *abnormal2.PartyInfo) {
}

func (t *TssCommon) getPartyInfo() *abnormal2.PartyInfo {
t.partyLock.Lock()
defer t.partyLock.Unlock()
t.partyLock.RLock()
defer t.partyLock.RUnlock()
return t.partyInfo
}

Expand Down Expand Up @@ -248,14 +269,14 @@ func (t *TssCommon) updateLocal(wireMsg *messages.WireMessage) error {
return fmt.Errorf("get message from unknown party %s", partyID.Id)
}

dataOwnerPeerID, ok := t.PartyIDtoP2PID[wireMsg.Routing.From.Id]
dataOwnerPeerID, ok := t.partyIDtoP2PIDMap.Load(wireMsg.Routing.From.Id)
if !ok {
t.logger.Error().Msg("fail to find the peer ID of this party")
return errors.New("fail to find the peer")
}
// here we log down this peer as the latest unicast peer
if !wireMsg.Routing.IsBroadcast {
t.abnormalMgr.SetLastUnicastPeer(dataOwnerPeerID, wireMsg.RoundInfo)
t.abnormalMgr.SetLastUnicastPeer(dataOwnerPeerID.(peer.ID), wireMsg.RoundInfo)
}

var bulkMsg BulkWireMsg
Expand Down Expand Up @@ -306,13 +327,20 @@ func (t *TssCommon) updateLocal(wireMsg *messages.WireMessage) error {
}
return false
}
t.culpritsLock.RLock()
if len(t.culprits) != 0 && partyInlist(partyID, t.culprits) {
t.logger.Error().Msgf("the malicious party (party ID:%s) try to send incorrect message to me (party ID:%s)", partyID.Id, localMsgParty.PartyID().Id)
t.culpritsLock.RUnlock()
return errors.New("tss share verification failed")
err = func() error {
t.culpritsLock.RLock()
defer t.culpritsLock.RUnlock()
if len(t.culprits) != 0 && partyInlist(partyID, t.culprits) {
return errors.New("tss share verification failed")
}
return nil
}()
if err != nil {
t.logger.Err(err).Msgf("the malicious party (party ID:%s) try to send incorrect message to me (party ID:%s)",
partyID.Id, localMsgParty.PartyID().Id)
return err
}
t.culpritsLock.RUnlock()

job := newJob(localMsgParty, bulkMsg.WiredBulkMsg, round.MsgIdentifier, partyID, bulkMsg.Routing.IsBroadcast)
tssJobChan <- job
}
Expand Down Expand Up @@ -361,12 +389,12 @@ func (t *TssCommon) sendBulkMsg(wiredMsgType string, tssMsgType messages.TSSMess
t.P2PPeersLock.RUnlock()
} else {
for _, each := range r.To {
peerID, ok := t.PartyIDtoP2PID[each.Id]
peerID, ok := t.partyIDtoP2PIDMap.Load(each.Id)
if !ok {
t.logger.Error().Msg("error in find the P2P ID")
continue
}
peerIDs = append(peerIDs, peerID)
peerIDs = append(peerIDs, peerID.(peer.ID))
}
}
t.renderToP2P(&messages.BroadcastMsgChan{
Expand Down Expand Up @@ -437,19 +465,13 @@ func (t *TssCommon) applyShare(localCacheItem *LocalCacheItem, key string, msgTy
}
t.logger.Debug().Msgf("remove key: %s", key)
// the information had been confirmed by all party , we don't need it anymore
t.removeKey(key)
t.unConfirmedMessagesMap.Delete(key)
return nil
}

func (t *TssCommon) removeKey(key string) {
t.unConfirmedMsgLock.Lock()
defer t.unConfirmedMsgLock.Unlock()
delete(t.unConfirmedMessages, key)
}

func (t *TssCommon) hashCheck(localCacheItem *LocalCacheItem, threshold int) error {
dataOwner := localCacheItem.Msg.Routing.From
dataOwnerP2PID, ok := t.PartyIDtoP2PID[dataOwner.Id]
dataOwnerP2PID, ok := t.partyIDtoP2PIDMap.Load(dataOwner.Id)
if !ok {
t.logger.Warn().Msgf("error in find the data Owner P2PID\n")
return errors.New("error in find the data Owner P2PID")
Expand All @@ -464,7 +486,7 @@ func (t *TssCommon) hashCheck(localCacheItem *LocalCacheItem, threshold int) err

targetHashValue := localCacheItem.Hash
for P2PID := range localCacheItem.ConfirmedList {
if P2PID == dataOwnerP2PID.String() {
if P2PID == dataOwnerP2PID.(peer.ID).String() {
t.logger.Warn().Msgf("we detect that the data owner try to send the hash for his own message\n")
delete(localCacheItem.ConfirmedList, P2PID)
return abnormal2.ErrHashFromOwner
Expand Down Expand Up @@ -602,15 +624,15 @@ func (t *TssCommon) processTSSMsg(wireMsg *messages.WireMessage, msgType message
return fmt.Errorf("fail to calculate hash of the wire message: %w", err)
}
localCacheItem := t.TryGetLocalCacheItem(key)
if nil == localCacheItem {
t.logger.Debug().Msgf("++%s doesn't exist yet,add a new one", key)
if localCacheItem == nil {
t.logger.Debug().Msgf("%s doesn't exist yet,add a new one", key)
localCacheItem = NewLocalCacheItem(wireMsg, msgHash)
t.updateLocalUnconfirmedMessages(key, localCacheItem)
t.unConfirmedMessagesMap.Store(key, localCacheItem)
} else {
// this means we received the broadcast confirm message from other party first
t.logger.Debug().Msgf("==%s exist", key)
t.logger.Debug().Msgf("%s exist", key)
if localCacheItem.Msg == nil {
t.logger.Debug().Msgf("==%s exist, set message", key)
t.logger.Debug().Msgf("%s exist, set message", key)
localCacheItem.Msg = wireMsg
localCacheItem.Hash = msgHash
}
Expand All @@ -624,17 +646,9 @@ func (t *TssCommon) processTSSMsg(wireMsg *messages.WireMessage, msgType message
}

func (t *TssCommon) TryGetLocalCacheItem(key string) *LocalCacheItem {
t.unConfirmedMsgLock.Lock()
defer t.unConfirmedMsgLock.Unlock()
localCacheItem, ok := t.unConfirmedMessages[key]
localCacheItem, ok := t.unConfirmedMessagesMap.Load(key)
if !ok {
return nil
}
return localCacheItem
}

func (t *TssCommon) updateLocalUnconfirmedMessages(key string, cacheItem *LocalCacheItem) {
t.unConfirmedMsgLock.Lock()
defer t.unConfirmedMsgLock.Unlock()
t.unConfirmedMessages[key] = cacheItem
return localCacheItem.(*LocalCacheItem)
}
20 changes: 11 additions & 9 deletions tss/node/tsslib/conversion/conversion.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@ import (
"encoding/hex"
"errors"
"fmt"
"math/big"
"sort"
"strconv"

"github.com/binance-chain/tss-lib/crypto"
"github.com/binance-chain/tss-lib/tss"
"github.com/btcsuite/btcd/btcec"
ethcrypto "github.com/ethereum/go-ethereum/crypto"
crypto2 "github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/peer"
"math/big"
"sort"
"strconv"
)

func GetParties(keys []string, localPartyKey string) ([]*tss.PartyID, *tss.PartyID, error) {
Expand Down Expand Up @@ -52,15 +53,16 @@ func SetupPartyIDMap(partiesID []*tss.PartyID) map[string]*tss.PartyID {
return partyIDMap
}

func SetupIDMaps(parties map[string]*tss.PartyID, partyIDtoP2PID map[string]peer.ID) error {
func GeneratePartyIDtoP2PIDMaps(parties map[string]*tss.PartyID) (map[string]peer.ID, error) {
partyIDtoP2PID := make(map[string]peer.ID)
for id, party := range parties {
peerID, err := GetPeerIDFromPartyID(party)
if err != nil {
return err
return nil, err
}
partyIDtoP2PID[id] = peerID
}
return nil
return partyIDtoP2PID, nil
}

func GetPeerIDFromPartyID(partyID *tss.PartyID) (peer.ID, error) {
Expand Down Expand Up @@ -107,10 +109,10 @@ func BytesToHashString(msg []byte) (string, error) {
return hex.EncodeToString(h.Sum(nil)), nil
}

func GetTssPubKey(pubKeyPoint *crypto.ECPoint) (string, []byte,[]byte, error) {
func GetTssPubKey(pubKeyPoint *crypto.ECPoint) (string, []byte, []byte, error) {
// we check whether the point is on curve according to Kudelski report
if pubKeyPoint == nil || !isOnCurve(pubKeyPoint.X(), pubKeyPoint.Y()) {
return "", nil,nil, errors.New("invalid points")
return "", nil, nil, errors.New("invalid points")
}
tssPubKey := btcec.PublicKey{
Curve: btcec.S256(),
Expand All @@ -123,7 +125,7 @@ func GetTssPubKey(pubKeyPoint *crypto.ECPoint) (string, []byte,[]byte, error) {

pubKeyBytes := tssPubKey.SerializeUncompressed()
pubKeyBytes = pubKeyBytes[1:]
return pubKeyStr, address,pubKeyBytes, nil
return pubKeyStr, address, pubKeyBytes, nil
}

func PartyIDtoPubKey(party *tss.PartyID) (string, error) {
Expand Down
11 changes: 6 additions & 5 deletions tss/node/tsslib/keygen/tss_keygen.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func (tKeyGen *TssKeyGen) GenerateNewKey(keygenReq Request) (*bcrypto.ECPoint, e
keyGenLocalStateItem := storage2.KeygenLocalState{
ParticipantKeys: tKeyGen.ParticipantKeys,
LocalPartyKey: tKeyGen.localNodePubKey,
Threshold: tKeyGen.tssCommonStruct.GetThresHold(),
Threshold: tKeyGen.tssCommonStruct.GetThreshHold(),
}

if err != nil {
Expand All @@ -108,11 +108,12 @@ func (tKeyGen *TssKeyGen) GenerateNewKey(keygenReq Request) (*bcrypto.ECPoint, e
abnormalMgr := tKeyGen.tssCommonStruct.GetAbnormalMgr()
keyGenParty := keygen.NewLocalParty(params, outCh, endCh, *tKeyGen.preParams)
partyIDMap := conversion.SetupPartyIDMap(partiesID)
err1 := conversion.SetupIDMaps(partyIDMap, tKeyGen.tssCommonStruct.PartyIDtoP2PID)
if err1 != nil {
tKeyGen.logger.Error().Msgf("error in creating mapping between partyID and P2P ID")
partyIDtoP2PIDMaps, err := conversion.GeneratePartyIDtoP2PIDMaps(partyIDMap)
if err != nil {
tKeyGen.logger.Err(err).Msgf("error in creating mapping between partyID and P2P ID")
return nil, err
}
tKeyGen.tssCommonStruct.InsertPartyIDtoP2PID(partyIDtoP2PIDMaps)
// we never run multi keygen, so the moniker is set to default empty value
partyInfo := &abnormal.PartyInfo{
Party: keyGenParty,
Expand All @@ -122,7 +123,7 @@ func (tKeyGen *TssKeyGen) GenerateNewKey(keygenReq Request) (*bcrypto.ECPoint, e
tKeyGen.tssCommonStruct.SetPartyInfo(partyInfo)
abnormalMgr.SetPartyInfo(keyGenParty, partyIDMap)
tKeyGen.tssCommonStruct.P2PPeersLock.Lock()
tKeyGen.tssCommonStruct.P2PPeers = conversion.GetPeersID(tKeyGen.tssCommonStruct.PartyIDtoP2PID, tKeyGen.tssCommonStruct.GetLocalPeerID())
tKeyGen.tssCommonStruct.P2PPeers = conversion.GetPeersID(tKeyGen.tssCommonStruct.GetPartyIDtoP2PID(), tKeyGen.tssCommonStruct.GetLocalPeerID())
tKeyGen.tssCommonStruct.P2PPeersLock.Unlock()
var keyGenWg sync.WaitGroup
keyGenWg.Add(2)
Expand Down
15 changes: 6 additions & 9 deletions tss/node/tsslib/keysign/tss_keysign.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,21 +90,18 @@ func (tKeySign *TssKeySign) SignMessage(msgToSign []byte, localStateItem storage
tKeySign.logger.Info().Msgf("message: (%s) keysign parties: %+v", m.String(), parties)
eachLocalPartyID.Moniker = moniker
tKeySign.localParties = nil
params := tss.NewParameters(btcec.S256(), ctx, eachLocalPartyID, len(partiesID), tKeySign.GetTssCommonStruct().GetThresHold())
params := tss.NewParameters(btcec.S256(), ctx, eachLocalPartyID, len(partiesID), tKeySign.GetTssCommonStruct().GetThreshHold())
keySignParty := signing.NewLocalParty(m, params, localStateItem.LocalData, outCh, endCh)

abnormalMgr := tKeySign.tssCommonStruct.GetAbnormalMgr()
partyIDMap := conversion.SetupPartyIDMap(partiesID)
err = conversion.SetupIDMaps(partyIDMap, tKeySign.tssCommonStruct.PartyIDtoP2PID)
partyIDtoP2PIDMap, err := conversion.GeneratePartyIDtoP2PIDMaps(partyIDMap)
if err != nil {
tKeySign.logger.Error().Err(err).Msgf("error in creating mapping between partyID and P2P ID")
return nil, err
}
err = conversion.SetupIDMaps(partyIDMap, abnormalMgr.PartyIDtoP2PID)
if err != nil {
tKeySign.logger.Error().Err(err).Msgf("error in creating mapping between partyID and P2P ID")
tKeySign.logger.Err(err).Msgf("error in creating mapping between partyID and P2P ID")
return nil, err
}
tKeySign.tssCommonStruct.InsertPartyIDtoP2PID(partyIDtoP2PIDMap)
abnormalMgr.PartyIDtoP2PID = partyIDtoP2PIDMap

tKeySign.tssCommonStruct.SetPartyInfo(&abnormal.PartyInfo{
Party: keySignParty,
Expand All @@ -114,7 +111,7 @@ func (tKeySign *TssKeySign) SignMessage(msgToSign []byte, localStateItem storage
abnormalMgr.SetPartyInfo(keySignParty, partyIDMap)

tKeySign.tssCommonStruct.P2PPeersLock.Lock()
tKeySign.tssCommonStruct.P2PPeers = conversion.GetPeersID(tKeySign.tssCommonStruct.PartyIDtoP2PID, tKeySign.tssCommonStruct.GetLocalPeerID())
tKeySign.tssCommonStruct.P2PPeers = conversion.GetPeersID(tKeySign.tssCommonStruct.GetPartyIDtoP2PID(), tKeySign.tssCommonStruct.GetLocalPeerID())
tKeySign.tssCommonStruct.P2PPeersLock.Unlock()
var keySignWg sync.WaitGroup
keySignWg.Add(2)
Expand Down

0 comments on commit 548f9b2

Please sign in to comment.