diff --git a/go.sum b/go.sum index edbb5ea2e090..4d18c8c20ece 100644 --- a/go.sum +++ b/go.sum @@ -101,6 +101,7 @@ github.com/influxdata/influxdb v1.2.3-0.20180221223340-01288bdb0883 h1:FSeK4fZCo github.com/influxdata/influxdb v1.2.3-0.20180221223340-01288bdb0883/go.mod h1:qZna6X/4elxqT3yI9iZYdZrWWdeFOOprn86kgg4+IzY= github.com/jackpal/go-nat-pmp v1.0.2-0.20160603034137-1fa385a6f458 h1:6OvNmYgJyexcZ3pYbTI9jWx5tHo1Dee/tWbLMfPe2TA= github.com/jackpal/go-nat-pmp v1.0.2-0.20160603034137-1fa385a6f458/go.mod h1:QPH045xvCAeXUZOxsnwmrtiCoxIr9eob+4orBN1SBKc= +github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af h1:pmfjZENx5imkbgOkpRUYLnmbU7UEFbjtDA2hxJ1ichM= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= github.com/julienschmidt/httprouter v1.1.1-0.20170430222011-975b5c4c7c21 h1:F/iKcka0K2LgnKy/fgSBf235AETtm1n1TvBzqu40LE0= github.com/julienschmidt/httprouter v1.1.1-0.20170430222011-975b5c4c7c21/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= diff --git a/internal/web3ext/web3ext.go b/internal/web3ext/web3ext.go index bc105ef37c28..7ac7762e4522 100644 --- a/internal/web3ext/web3ext.go +++ b/internal/web3ext/web3ext.go @@ -33,6 +33,7 @@ var Modules = map[string]string{ "swarmfs": SwarmfsJs, "txpool": TxpoolJs, "les": LESJs, + "lespay": LESPAYJs, } const ChequebookJs = ` @@ -856,3 +857,45 @@ web3._extend({ ] }); ` + +const LESPAYJs = ` +web3._extend({ + property: 'lespay', + methods: + [ + new web3._extend.Method({ + name: 'connection', + call: 'lespay_connection', + params: 6 + }), + new web3._extend.Method({ + name: 'deposit', + call: 'lespay_deposit', + params: 4 + }), + new web3._extend.Method({ + name: 'buyTokens', + call: 'lespay_buyTokens', + params: 6 + }), + new web3._extend.Method({ + name: 'getBalance', + call: 'lespay_getBalance', + params: 2 + }), + new web3._extend.Method({ + name: 'info', + call: 'lespay_info', + params: 2 + }), + new web3._extend.Method({ + name: 'remoteInfo', + call: 'lespay_remoteInfo', + params: 3 + }), + ], + properties: + [ + ] +}); +` diff --git a/les/api.go b/les/api.go index ad511c9d6b59..0fb5b5c096f3 100644 --- a/les/api.go +++ b/les/api.go @@ -17,14 +17,17 @@ package les import ( + "context" "errors" "fmt" - "math" + "reflect" "time" "github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/p2p/discv5" "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/rlp" ) var ( @@ -35,8 +38,6 @@ var ( errNoPriority = errors.New("priority too low to raise capacity") ) -const maxBalance = math.MaxInt64 - // PrivateLightServerAPI provides an API to access the LES light server. type PrivateLightServerAPI struct { server *LesServer @@ -150,7 +151,7 @@ func (api *PrivateLightServerAPI) setParams(params map[string]interface{}, clien setFactor(&negFactors.requestFactor) case !defParams && name == "capacity": if capacity, ok := value.(float64); ok && uint64(capacity) >= api.server.minCapacity { - err = api.server.clientPool.setCapacity(client, uint64(capacity)) + _, _, err = api.server.clientPool.setCapacity(client.id, client.freeID, uint64(capacity), 0, true) // Don't have to call factor update explicitly. It's already done // in setCapacity function. } else { @@ -184,7 +185,7 @@ func (api *PrivateLightServerAPI) SetClientParams(ids []enode.ID, params map[str if client != nil { update, err := api.setParams(params, client, nil, nil) if update { - client.updatePriceFactors() + updatePriceFactors(&client.balanceTracker, client.posFactors, client.negFactors, client.capacity) } return err } else { @@ -352,3 +353,166 @@ func (api *PrivateLightAPI) GetCheckpointContractAddress() (string, error) { } return api.backend.oracle.config.Address.Hex(), nil } + +type PrivateLespayAPI struct { + peerSet *peerSet + clientHandler *clientHandler + dht *discv5.Network + tokenSale *tokenSale +} + +// NewPrivateLespayAPI creates a new LESPAY API. +func NewPrivateLespayAPI(peerSet *peerSet, clientHandler *clientHandler, dht *discv5.Network, tokenSale *tokenSale) *PrivateLespayAPI { + return &PrivateLespayAPI{ + peerSet: peerSet, + clientHandler: clientHandler, + dht: dht, + tokenSale: tokenSale, + } +} + +func (api *PrivateLespayAPI) makeCall(ctx context.Context, remote bool, nodeStr string, cmd []byte) ([]byte, error) { + var ( + id enode.ID + freeID string + peer *peer + node *enode.Node + err error + ) + if nodeStr != "" { + if id, err = enode.ParseID(nodeStr); err == nil { + if peer = api.peerSet.Peer(peerIdToString(id)); peer == nil { + return nil, errors.New("peer not connected") + } + freeID = peer.freeClientId() + } else { + var err error + if node, err = enode.Parse(enode.ValidSchemes, nodeStr); err == nil { + id = node.ID() + freeID = node.IP().String() + } else { + return nil, err + } + } + } + + if remote { + var ( + reply []byte + cancelFn func() bool + ) + delivered := make(chan struct{}) + if peer != nil { + // remote call to a connected peer through LES + if api.clientHandler == nil { + return nil, errors.New("client handler not available") + } + cancelFn = api.clientHandler.makeLespayCall(peer, cmd, func(r []byte) bool { + reply = r + close(delivered) + return reply != nil + }) + } else { + // remote call through UDP TALK + if api.dht == nil { + return nil, errors.New("UDP DHT not available") + } + cancelFn = api.dht.SendTalkRequest(node, "lespay", [][]byte{cmd}, func(payload interface{}) bool { + fmt.Println("dht delivered", payload, reflect.TypeOf(payload)) + if replies, ok := payload.([]interface{}); ok && len(replies) == 1 { + reply, ok = replies[0].([]byte) + } + close(delivered) + return reply != nil + }) + } + select { + case <-time.After(time.Second * 5): + cancelFn() + return nil, errors.New("timeout") + case <-ctx.Done(): + cancelFn() + return nil, ctx.Err() + case <-delivered: + if len(reply) == 0 { + return nil, errors.New("unknown command") + } + return reply, nil + } + } else { + if api.tokenSale == nil { + return nil, errors.New("token sale module not available") + } + // execute call locally + return api.tokenSale.runCommand(cmd, id, freeID), nil + } + +} + +func (api *PrivateLespayAPI) Connection(ctx context.Context, remote bool, node string, requestedCapacity, stayConnected uint64, paymentModule []string, setCap bool) (results tsConnectionResults, err error) { + params := tsConnectionParams{requestedCapacity, stayConnected, paymentModule, setCap} + enc, _ := rlp.EncodeToBytes(¶ms) + var resEnc []byte + resEnc, err = api.makeCall(ctx, remote, node, append([]byte{tsConnection}, enc...)) + if err != nil { + return + } + err = rlp.DecodeBytes(resEnc, &results) + return +} + +func (api *PrivateLespayAPI) Deposit(ctx context.Context, remote bool, node string, paymentModule string, proofOfPayment []byte) (results tsDepositResults, err error) { + params := tsDepositParams{paymentModule, proofOfPayment} + enc, _ := rlp.EncodeToBytes(¶ms) + var resEnc []byte + resEnc, err = api.makeCall(ctx, remote, node, append([]byte{tsDeposit}, enc...)) + if err != nil { + return + } + err = rlp.DecodeBytes(resEnc, &results) + return +} + +func (api *PrivateLespayAPI) BuyTokens(ctx context.Context, remote bool, node string, maxSpend, minReceive uint64, relative, spendAll bool) (results tsBuyTokensResults, err error) { + params := tsBuyTokensParams{maxSpend, minReceive, relative, spendAll} + enc, _ := rlp.EncodeToBytes(¶ms) + var resEnc []byte + resEnc, err = api.makeCall(ctx, remote, node, append([]byte{tsBuyTokens}, enc...)) + if err != nil { + return + } + err = rlp.DecodeBytes(resEnc, &results) + return +} + +func (api *PrivateLespayAPI) GetBalance(ctx context.Context, remote bool, node string) (results tsGetBalanceResults, err error) { + var resEnc []byte + resEnc, err = api.makeCall(ctx, remote, node, []byte{tsGetBalance}) + if err != nil { + return + } + err = rlp.DecodeBytes(resEnc, &results) + return +} + +func (api *PrivateLespayAPI) Info(ctx context.Context, remote bool, node string) (results tsInfoResults, err error) { + var resEnc []byte + resEnc, err = api.makeCall(ctx, remote, node, []byte{tsInfo}) + if err != nil { + return + } + err = rlp.DecodeBytes(resEnc, &results) + return +} + +func (api *PrivateLespayAPI) ReceiverInfo(ctx context.Context, remote bool, node string, receiverIDs []string) (results tsReceiverInfoResults, err error) { + params := tsReceiverInfoParams(receiverIDs) + enc, _ := rlp.EncodeToBytes(¶ms) + var resEnc []byte + resEnc, err = api.makeCall(ctx, remote, node, append([]byte{tsReceiverInfo}, enc...)) + if err != nil { + return + } + err = rlp.DecodeBytes(resEnc, &results) + return +} diff --git a/les/balance.go b/les/balance.go index 51cef15c803d..1d3cbd143ead 100644 --- a/les/balance.go +++ b/les/balance.go @@ -17,12 +17,15 @@ package les import ( + "math" "sync" "time" "github.com/ethereum/go-ethereum/common/mclock" ) +const maxBalance = math.MaxInt64 + const ( balanceCallbackQueue = iota balanceCallbackZero @@ -65,6 +68,7 @@ type balanceCallback struct { } // init initializes balanceTracker +// Note: capacity should never be zero func (bt *balanceTracker) init(clock mclock.Clock, capacity uint64) { bt.clock = clock bt.initTime, bt.lastUpdate = clock.Now(), clock.Now() // Init timestamps @@ -96,11 +100,36 @@ func (bt *balanceTracker) stop(now mclock.AbsTime) { // balance is zero then negative balance translates to a positive priority. func (bt *balanceTracker) balanceToPriority(b balance) int64 { if b.pos > 0 { - return ^int64(b.pos / bt.capacity) + return -int64(b.pos / bt.capacity) } return int64(b.neg) } +func (bt *balanceTracker) posBalanceMissing(targetPriority int64, targetCapacity uint64, after time.Duration) uint64 { + if targetPriority > 0 { + negPrice := uint64(float64(after) * bt.negTimeFactor) + if negPrice+bt.balance.neg < uint64(targetPriority) { + return 0 + } + if uint64(targetPriority) > bt.balance.neg && bt.negTimeFactor > 1e-100 { + if negTime := time.Duration(float64(uint64(targetPriority)-bt.balance.neg) / bt.negTimeFactor); negTime < after { + after -= negTime + } else { + after = 0 + } + } + targetPriority = 0 + } + posRequired := uint64(float64(-targetPriority)*float64(targetCapacity)+float64(after)*bt.timeFactor) + 1 + if posRequired >= maxBalance { + return math.MaxUint64 // target not reachable + } + if posRequired > bt.balance.pos { + return posRequired - bt.balance.pos + } + return 0 +} + // reducedBalance estimates the reduced balance at a given time in the fututre based // on the current balance, the time factor and an estimated average request cost per time ratio func (bt *balanceTracker) reducedBalance(at mclock.AbsTime, avgReqCost float64) balance { @@ -136,7 +165,7 @@ func (bt *balanceTracker) timeUntil(priority int64) (time.Duration, bool) { return 0, false } if priority < 0 { - newBalance := uint64(^priority) * bt.capacity + newBalance := uint64(-priority) * bt.capacity if newBalance > bt.balance.pos { return 0, false } @@ -161,6 +190,7 @@ func (bt *balanceTracker) timeUntil(priority int64) (time.Duration, bool) { } // setCapacity updates the capacity value used for priority calculation +// Note: capacity should never be zero func (bt *balanceTracker) setCapacity(capacity uint64) { bt.lock.Lock() defer bt.lock.Unlock() @@ -262,12 +292,12 @@ func (bt *balanceTracker) updateAfter(dt time.Duration) { } // requestCost should be called after serving a request for the given peer -func (bt *balanceTracker) requestCost(cost uint64) { +func (bt *balanceTracker) requestCost(cost uint64) uint64 { bt.lock.Lock() defer bt.lock.Unlock() if bt.stopped { - return + return 0 } now := bt.clock.Now() bt.addBalance(now) @@ -295,6 +325,7 @@ func (bt *balanceTracker) requestCost(cost uint64) { } } bt.sumReqCost += cost + return bt.balance.pos } // getBalance returns the current positive and negative balance diff --git a/les/balance_test.go b/les/balance_test.go index b571c2cc5c2d..2cf4fcc7307d 100644 --- a/les/balance_test.go +++ b/les/balance_test.go @@ -141,8 +141,8 @@ func TestBalanceToPriority(t *testing.T) { neg uint64 priority int64 }{ - {1000, 0, ^int64(1)}, - {2000, 0, ^int64(2)}, // Higher balance, lower priority value + {1000, 0, -1}, + {2000, 0, -2}, // Higher balance, lower priority value {0, 0, 0}, {0, 1000, 1000}, } @@ -172,16 +172,16 @@ func TestEstimatedPriority(t *testing.T) { reqCost uint64 // single request cost priority int64 // expected estimated priority }{ - {time.Second, time.Second, 0, ^int64(58)}, - {0, time.Second, 0, ^int64(58)}, + {time.Second, time.Second, 0, -58}, + {0, time.Second, 0, -58}, // 2 seconds time cost, 1 second estimated time cost, 10^9 request cost, // 10^9 estimated request cost per second. - {time.Second, time.Second, 1000000000, ^int64(55)}, + {time.Second, time.Second, 1000000000, -55}, // 3 seconds time cost, 3 second estimated time cost, 10^9*2 request cost, // 4*10^9 estimated request cost. - {time.Second, 3 * time.Second, 1000000000, ^int64(48)}, + {time.Second, 3 * time.Second, 1000000000, -48}, // All positive balance is used up {time.Second * 55, 0, 0, 0}, @@ -213,7 +213,7 @@ func TestCallbackChecking(t *testing.T) { priority int64 expDiff time.Duration }{ - {^int64(500), time.Millisecond * 500}, + {-500, time.Millisecond * 500}, {0, time.Second}, {int64(time.Second), 2 * time.Second}, } diff --git a/les/client.go b/les/client.go index c460f4c09d5d..26c5175bcfe2 100644 --- a/les/client.go +++ b/les/client.go @@ -48,6 +48,7 @@ import ( type LightEthereum struct { lesCommons + srvr *p2p.Server reqDist *requestDistributor retriever *retrieveManager odr *LesOdr @@ -206,6 +207,12 @@ func (s *LightEthereum) APIs() []rpc.API { Service: NewPrivateLightAPI(&s.lesCommons), Public: false, }, + { + Namespace: "lespay", + Version: "1.0", + Service: NewPrivateLespayAPI(s.lesCommons.peers, s.handler, s.srvr.DiscV5, nil), + Public: false, + }, }...) } @@ -235,6 +242,7 @@ func (s *LightEthereum) Protocols() []p2p.Protocol { // light ethereum protocol implementation. func (s *LightEthereum) Start(srvr *p2p.Server) error { log.Warn("Light client mode is an experimental feature") + s.srvr = srvr // Start bloom request workers. s.wg.Add(bloomServiceThreads) diff --git a/les/client_handler.go b/les/client_handler.go index 7fdb1657194c..485a1395b591 100644 --- a/les/client_handler.go +++ b/les/client_handler.go @@ -17,6 +17,7 @@ package les import ( + "fmt" "math/big" "sync" "time" @@ -40,6 +41,9 @@ type clientHandler struct { downloader *downloader.Downloader backend *LightEthereum + lespayReplyHandlers map[uint64]func([]byte) bool + lespayReplyLock sync.Mutex + closeCh chan struct{} wg sync.WaitGroup // WaitGroup used to track all connected peers. syncDone func() // Test hooks when syncing is done. @@ -47,9 +51,10 @@ type clientHandler struct { func newClientHandler(ulcServers []string, ulcFraction int, checkpoint *params.TrustedCheckpoint, backend *LightEthereum) *clientHandler { handler := &clientHandler{ - checkpoint: checkpoint, - backend: backend, - closeCh: make(chan struct{}), + checkpoint: checkpoint, + backend: backend, + closeCh: make(chan struct{}), + lespayReplyHandlers: make(map[uint64]func([]byte) bool), } if ulcServers != nil { ulc, err := newULC(ulcServers, ulcFraction) @@ -111,28 +116,48 @@ func (h *clientHandler) handle(p *peer) error { p.Log().Debug("Light Ethereum handshake failed", "err", err) return err } - // Register the peer locally - if err := h.backend.peers.Register(p); err != nil { - p.Log().Error("Light Ethereum peer registration failed", "err", err) - return err - } - serverConnectionGauge.Update(int64(h.backend.peers.Len())) - connectedAt := mclock.Now() - defer func() { - h.backend.peers.Unregister(p.id) + var ( + connectedAt mclock.AbsTime + lastActive bool + ) + activate := func() { + // Register the peer locally + if err := h.backend.peers.Register(p); err != nil { + p.Log().Error("Light Ethereum peer registration failed", "err", err) + return + } + serverConnectionGauge.Update(int64(h.backend.peers.Len())) + connectedAt = mclock.Now() + h.fetcher.announce(p, p.headInfo) + lastActive = true + } + deactivate := func() { + h.backend.peers.Unregister(p) connectionTimer.Update(time.Duration(mclock.Now() - connectedAt)) serverConnectionGauge.Update(int64(h.backend.peers.Len())) + lastActive = false + } + defer func() { + if lastActive { + deactivate() + } + h.backend.peers.Disconnect(p.id) }() - h.fetcher.announce(p, p.headInfo) - // pool entry can be nil during the unit test. if p.poolEntry != nil { h.backend.serverPool.registered(p.poolEntry) } + // Spawn a main loop to handle all incoming messages. for { + if p.active && !lastActive { + activate() + } + if !p.active && lastActive { + deactivate() + } if err := h.handleMsg(p); err != nil { p.Log().Debug("Light Ethereum message handling failed", "err", err) p.fcServer.DumpLogs() @@ -156,7 +181,10 @@ func (h *clientHandler) handleMsg(p *peer) error { } defer msg.Discard() - var deliverMsg *Msg + var ( + deliverMsg *Msg + responseError bool + ) // Handle the message depending on its contents switch msg.Code { @@ -192,13 +220,15 @@ func (h *clientHandler) handleMsg(p *peer) error { case BlockHeadersMsg: p.Log().Trace("Received block header response message") var resp struct { - ReqID, BV uint64 - Headers []*types.Header + ReqID uint64 + SF stateFeedback + Headers []*types.Header } + resp.SF.protocolVersion = p.version if err := msg.Decode(&resp); err != nil { return errResp(ErrDecode, "msg %v: %v", msg, err) } - p.fcServer.ReceivedReply(resp.ReqID, resp.BV) + p.fcServer.ReceivedReply(resp.ReqID, resp.SF.BV) if h.fetcher.requestedID(resp.ReqID) { h.fetcher.deliverHeaders(p, resp.ReqID, resp.Headers) } else { @@ -209,13 +239,15 @@ func (h *clientHandler) handleMsg(p *peer) error { case BlockBodiesMsg: p.Log().Trace("Received block bodies response") var resp struct { - ReqID, BV uint64 - Data []*types.Body + ReqID uint64 + SF stateFeedback + Data []*types.Body } + resp.SF.protocolVersion = p.version if err := msg.Decode(&resp); err != nil { return errResp(ErrDecode, "msg %v: %v", msg, err) } - p.fcServer.ReceivedReply(resp.ReqID, resp.BV) + p.fcServer.ReceivedReply(resp.ReqID, resp.SF.BV) deliverMsg = &Msg{ MsgType: MsgBlockBodies, ReqID: resp.ReqID, @@ -224,13 +256,15 @@ func (h *clientHandler) handleMsg(p *peer) error { case CodeMsg: p.Log().Trace("Received code response") var resp struct { - ReqID, BV uint64 - Data [][]byte + ReqID uint64 + SF stateFeedback + Data [][]byte } + resp.SF.protocolVersion = p.version if err := msg.Decode(&resp); err != nil { return errResp(ErrDecode, "msg %v: %v", msg, err) } - p.fcServer.ReceivedReply(resp.ReqID, resp.BV) + p.fcServer.ReceivedReply(resp.ReqID, resp.SF.BV) deliverMsg = &Msg{ MsgType: MsgCode, ReqID: resp.ReqID, @@ -239,13 +273,15 @@ func (h *clientHandler) handleMsg(p *peer) error { case ReceiptsMsg: p.Log().Trace("Received receipts response") var resp struct { - ReqID, BV uint64 - Receipts []types.Receipts + ReqID uint64 + SF stateFeedback + Receipts []types.Receipts } + resp.SF.protocolVersion = p.version if err := msg.Decode(&resp); err != nil { return errResp(ErrDecode, "msg %v: %v", msg, err) } - p.fcServer.ReceivedReply(resp.ReqID, resp.BV) + p.fcServer.ReceivedReply(resp.ReqID, resp.SF.BV) deliverMsg = &Msg{ MsgType: MsgReceipts, ReqID: resp.ReqID, @@ -254,13 +290,15 @@ func (h *clientHandler) handleMsg(p *peer) error { case ProofsV2Msg: p.Log().Trace("Received les/2 proofs response") var resp struct { - ReqID, BV uint64 - Data light.NodeList + ReqID uint64 + SF stateFeedback + Data light.NodeList } + resp.SF.protocolVersion = p.version if err := msg.Decode(&resp); err != nil { return errResp(ErrDecode, "msg %v: %v", msg, err) } - p.fcServer.ReceivedReply(resp.ReqID, resp.BV) + p.fcServer.ReceivedReply(resp.ReqID, resp.SF.BV) deliverMsg = &Msg{ MsgType: MsgProofsV2, ReqID: resp.ReqID, @@ -269,13 +307,15 @@ func (h *clientHandler) handleMsg(p *peer) error { case HelperTrieProofsMsg: p.Log().Trace("Received helper trie proof response") var resp struct { - ReqID, BV uint64 - Data HelperTrieResps + ReqID uint64 + SF stateFeedback + Data HelperTrieResps } + resp.SF.protocolVersion = p.version if err := msg.Decode(&resp); err != nil { return errResp(ErrDecode, "msg %v: %v", msg, err) } - p.fcServer.ReceivedReply(resp.ReqID, resp.BV) + p.fcServer.ReceivedReply(resp.ReqID, resp.SF.BV) deliverMsg = &Msg{ MsgType: MsgHelperTrieProofs, ReqID: resp.ReqID, @@ -284,13 +324,15 @@ func (h *clientHandler) handleMsg(p *peer) error { case TxStatusMsg: p.Log().Trace("Received tx status response") var resp struct { - ReqID, BV uint64 - Status []light.TxStatus + ReqID uint64 + SF stateFeedback + Status []light.TxStatus } + resp.SF.protocolVersion = p.version if err := msg.Decode(&resp); err != nil { return errResp(ErrDecode, "msg %v: %v", msg, err) } - p.fcServer.ReceivedReply(resp.ReqID, resp.BV) + p.fcServer.ReceivedReply(resp.ReqID, resp.SF.BV) deliverMsg = &Msg{ MsgType: MsgTxStatus, ReqID: resp.ReqID, @@ -301,13 +343,37 @@ func (h *clientHandler) handleMsg(p *peer) error { h.backend.retriever.frozen(p) p.Log().Debug("Service stopped") case ResumeMsg: - var bv uint64 - if err := msg.Decode(&bv); err != nil { + var sf stateFeedback + sf.protocolVersion = p.version + if err := msg.Decode(&sf); err != nil { return errResp(ErrDecode, "msg %v: %v", msg, err) } - p.fcServer.ResumeFreeze(bv) + p.fcServer.ResumeFreeze(sf.BV) p.freezeServer(false) p.Log().Debug("Service resumed") + case LespayReplyMsg: + fmt.Println("LespayReply received") + p.Log().Trace("Received tx status response") + var resp struct { + ReqID uint64 + Reply []byte + } + if err := msg.Decode(&resp); err != nil { + fmt.Println("LespayReply decode err", err) + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + fmt.Println("LespayReply decoded", resp) + h.lespayReplyLock.Lock() + if handler := h.lespayReplyHandlers[resp.ReqID]; handler != nil { + fmt.Println("handler found") + delete(h.lespayReplyHandlers, resp.ReqID) + responseError = !handler(resp.Reply) + } else { + fmt.Println("handler not found") + responseError = true + } + h.lespayReplyLock.Unlock() + default: p.Log().Trace("Received invalid message", "code", msg.Code) return errResp(ErrInvalidMsgCode, "%v", msg.Code) @@ -315,17 +381,46 @@ func (h *clientHandler) handleMsg(p *peer) error { // Deliver the received response to retriever. if deliverMsg != nil { if err := h.backend.retriever.deliver(p, deliverMsg); err != nil { - p.responseErrors++ - if p.responseErrors > maxResponseErrors { - return err - } + responseError = true + } + } + if responseError { + p.responseErrors++ + if p.responseErrors > maxResponseErrors { + return err } } return nil } +func (h *clientHandler) makeLespayCall(p *peer, cmd []byte, handler func([]byte) bool) func() bool { + reqID := genReqID() + h.lespayReplyLock.Lock() + h.lespayReplyHandlers[reqID] = handler + h.lespayReplyLock.Unlock() + if p.SendLespay(reqID, cmd) != nil { + h.lespayReplyLock.Lock() + delete(h.lespayReplyHandlers, reqID) + h.lespayReplyLock.Unlock() + return nil + } + fmt.Println("Lespay sent") + return func() bool { + h.lespayReplyLock.Lock() + cancel := h.lespayReplyHandlers[reqID] != nil + if cancel { + delete(h.lespayReplyHandlers, reqID) + } + h.lespayReplyLock.Unlock() + if cancel { + handler(nil) + } + return cancel + } +} + func (h *clientHandler) removePeer(id string) { - h.backend.peers.Unregister(id) + h.backend.peers.Disconnect(id) } type peerConnection struct { diff --git a/les/clientpool.go b/les/clientpool.go index da76f08b91fd..42a7d06fc28e 100644 --- a/les/clientpool.go +++ b/les/clientpool.go @@ -39,18 +39,21 @@ const ( negBalanceExpTC = time.Hour // time constant for exponentially reducing negative balance fixedPointMultiplier = 0x1000000 // constant to convert logarithms to fixed point format lazyQueueRefresh = time.Second * 10 // refresh period of the connected queue + tryActivatePeriod = time.Second * 5 // periodically check whether inactive clients can be activated + dropInactiveCycles = 2 // number of activation check periods after non-priority inactive peers are dropped persistCumulativeTimeRefresh = time.Minute * 5 // refresh period of the cumulative running time persistence posBalanceCacheLimit = 8192 // the maximum number of cached items in positive balance queue negBalanceCacheLimit = 8192 // the maximum number of cached items in negative balance queue + fullRatioTC = time.Hour - // connectedBias is applied to already connected clients So that + // activeBias is applied to already connected clients So that // already connected client won't be kicked out very soon and we // can ensure all connected clients can have enough time to request // or sync some data. // // todo(rjl493456442) make it configurable. It can be the option of // free trial time! - connectedBias = time.Minute * 3 + activeBias = time.Minute * 3 ) // clientPool implements a client database that assigns a priority to each client @@ -60,7 +63,7 @@ const ( // then negative balance is accumulated. // // Balance tracking and priority calculation for connected clients is done by -// balanceTracker. connectedQueue ensures that clients with the lowest positive or +// balanceTracker. activeQueue ensures that clients with the lowest positive or // highest negative balance get evicted when the total capacity allowance is full // and new clients with a better balance want to connect. // @@ -82,19 +85,27 @@ type clientPool struct { closed bool removePeer func(enode.ID) - connectedMap map[enode.ID]*clientInfo - connectedQueue *prque.LazyQueue + connectedMap map[enode.ID]*clientInfo + activeQueue *prque.LazyQueue + inactiveQueue *prque.Prque + dropInactivePeers map[uint64][]*clientInfo + dropInactiveCounter uint64 + + activeBalances, inactiveBalances uint64 + lastConnectedBalanceUpdate, fullRatioLastUpdate mclock.AbsTime + fullRatio float64 defaultPosFactors, defaultNegFactors priceFactors - connLimit int // The maximum number of connections that clientpool can support - capLimit uint64 // The maximum cumulative capacity that clientpool can support - connectedCap uint64 // The sum of the capacity of the current clientpool connected - priorityConnected uint64 // The sum of the capacity of currently connected priority clients - freeClientCap uint64 // The capacity value of each free client - startTime mclock.AbsTime // The timestamp at which the clientpool started running - cumulativeTime int64 // The cumulative running time of clientpool at the start point. - disableBias bool // Disable connection bias(used in testing) + activeLimit int // The maximum number of connections that clientpool can support + capLimit uint64 // The maximum cumulative capacity that clientpool can support + activeCap uint64 // The sum of the capacity of the current clientpool connected + priorityActive uint64 // The sum of the capacity of currently connected priority clients + minCap uint64 // The minimal capacity value allowed for any client + freeClientCap uint64 // The capacity value of each free client + startTime mclock.AbsTime // The timestamp at which the clientpool started running + cumulativeTime int64 // The cumulative running time of clientpool at the start point. + disableBias bool // Disable connection bias(used in testing) } // clientPeer represents a client in the pool. @@ -113,36 +124,38 @@ type clientPeer interface { type clientInfo struct { address string id enode.ID + freeID string + active bool connectedAt mclock.AbsTime capacity uint64 priority bool pool *clientPool peer clientPeer - queueIndex int // position in connectedQueue + queueIndex int // position in activeQueue balanceTracker balanceTracker posFactors, negFactors priceFactors balanceMetaInfo string } -// connSetIndex callback updates clientInfo item index in connectedQueue +// connSetIndex callback updates clientInfo item index in activeQueue func connSetIndex(a interface{}, index int) { a.(*clientInfo).queueIndex = index } -// connPriority callback returns actual priority of clientInfo item in connectedQueue +// connPriority callback returns actual priority of clientInfo item in activeQueue func connPriority(a interface{}, now mclock.AbsTime) int64 { c := a.(*clientInfo) return c.balanceTracker.getPriority(now) } -// connMaxPriority callback returns estimated maximum priority of clientInfo item in connectedQueue +// connMaxPriority callback returns estimated maximum priority of clientInfo item in activeQueue func connMaxPriority(a interface{}, until mclock.AbsTime) int64 { c := a.(*clientInfo) pri := c.balanceTracker.estimatedPriority(until, true) c.balanceTracker.addCallback(balanceCallbackQueue, pri+1, func() { c.pool.lock.Lock() - if c.queueIndex != -1 { - c.pool.connectedQueue.Update(c.queueIndex) + if c.active && c.queueIndex != -1 { + c.pool.activeQueue.Update(c.queueIndex) } c.pool.lock.Unlock() }) @@ -159,18 +172,40 @@ type priceFactors struct { } // newClientPool creates a new client pool -func newClientPool(db ethdb.Database, freeClientCap uint64, clock mclock.Clock, removePeer func(enode.ID)) *clientPool { +func newClientPool(db ethdb.Database, minCap, freeClientCap uint64, clock mclock.Clock, removePeer func(enode.ID)) *clientPool { ndb := newNodeDB(db, clock) pool := &clientPool{ - ndb: ndb, - clock: clock, - connectedMap: make(map[enode.ID]*clientInfo), - connectedQueue: prque.NewLazyQueue(connSetIndex, connPriority, connMaxPriority, clock, lazyQueueRefresh), - freeClientCap: freeClientCap, - removePeer: removePeer, - startTime: clock.Now(), - cumulativeTime: ndb.getCumulativeTime(), - stopCh: make(chan struct{}), + ndb: ndb, + clock: clock, + connectedMap: make(map[enode.ID]*clientInfo), + activeQueue: prque.NewLazyQueue(connSetIndex, connPriority, connMaxPriority, clock, lazyQueueRefresh), + inactiveQueue: prque.New(connSetIndex), + dropInactivePeers: make(map[uint64][]*clientInfo), + minCap: minCap, + freeClientCap: freeClientCap, + removePeer: removePeer, + startTime: clock.Now(), + cumulativeTime: ndb.getCumulativeTime(), + stopCh: make(chan struct{}), + } + // calculate total token balance amount + var start enode.ID + for { + ids := pool.ndb.getPosBalanceIDs(start, enode.ID{}, 1000) + var stop bool + l := len(ids) + if l == 1000 { + l-- + start = ids[l] + } else { + stop = true + } + for i := 0; i < l; i++ { + pool.inactiveBalances += pool.ndb.getOrNewPB(ids[i]).value + } + if stop { + break + } } // If the negative balance of free client is even lower than 1, // delete this entry. @@ -183,8 +218,16 @@ func newClientPool(db ethdb.Database, freeClientCap uint64, clock mclock.Clock, select { case <-clock.After(lazyQueueRefresh): pool.lock.Lock() - pool.connectedQueue.Refresh() + pool.activeQueue.Refresh() pool.lock.Unlock() + case <-pool.stopCh: + return + } + } + }() + go func() { + for { + select { case <-clock.After(persistCumulativeTimeRefresh): pool.ndb.setCumulativeTime(pool.logOffset(clock.Now())) case <-pool.stopCh: @@ -192,6 +235,25 @@ func newClientPool(db ethdb.Database, freeClientCap uint64, clock mclock.Clock, } } }() + go func() { + for { + select { + case <-clock.After(tryActivatePeriod): + pool.lock.Lock() + pool.tryActivateClients() + for _, c := range pool.dropInactivePeers[pool.dropInactiveCounter] { + if _, ok := pool.connectedMap[c.id]; ok && !c.active && !c.priority { + pool.disconnectLocked(c.peer) + } + } + delete(pool.dropInactivePeers, pool.dropInactiveCounter) + pool.dropInactiveCounter++ + pool.lock.Unlock() + case <-pool.stopCh: + return + } + } + }() return pool } @@ -205,118 +267,136 @@ func (f *clientPool) stop() { f.ndb.close() } +func (f *clientPool) updateFullRatio() { + full := float64(1) + if f.priorityActive < f.capLimit { + freeCap := f.capLimit - f.priorityActive + if freeCap > f.freeClientCap { + freeCapThreshold := f.capLimit / 4 + if freeCap > freeCapThreshold { + full = 0 + } else { + full = float64(freeCapThreshold-freeCap) / float64(freeCapThreshold-f.freeClientCap) + } + } + } + now := f.clock.Now() + dt := now - f.fullRatioLastUpdate + f.fullRatioLastUpdate = now + if dt < 0 { + dt = 0 + } + d := math.Exp(-float64(dt) / float64(fullRatioTC)) + f.fullRatio = full - (full-f.fullRatio)*d +} + +func (f *clientPool) totalTokenLimit() uint64 { + f.lock.Lock() + defer f.lock.Unlock() + + f.updateFullRatio() + d := 1 - f.fullRatio + if d > 0.5 { + d = -math.Log(0.5/d) * float64(fullRatioTC) + } else { + d = 0 + } + return uint64(d * float64(f.capLimit) * f.defaultPosFactors.capacityFactor) +} + +func (f *clientPool) totalTokenAmount() uint64 { + f.lock.Lock() + defer f.lock.Unlock() + + now := f.clock.Now() + if now > f.lastConnectedBalanceUpdate+mclock.AbsTime(time.Second) { + f.activeBalances = 0 + for _, c := range f.connectedMap { + pos, _ := c.balanceTracker.getBalance(now) + f.activeBalances += pos + } + f.lastConnectedBalanceUpdate = now + } + return f.activeBalances + f.inactiveBalances +} + // connect should be called after a successful handshake. If the connection was // rejected, there is no need to call disconnect. -func (f *clientPool) connect(peer clientPeer, capacity uint64) bool { +func (f *clientPool) connect(peer clientPeer, reqCapacity uint64) (uint64, error) { f.lock.Lock() defer f.lock.Unlock() // Short circuit if clientPool is already closed. if f.closed { - return false + return 0, fmt.Errorf("Client pool is already closed") } // Dedup connected peers. id, freeID := peer.ID(), peer.freeClientId() if _, ok := f.connectedMap[id]; ok { clientRejectedMeter.Mark(1) log.Debug("Client already connected", "address", freeID, "id", peerIdToString(id)) - return false + return 0, fmt.Errorf("Client already connected address = %s id = %s", freeID, peerIdToString(id)) } - // Create a clientInfo but do not add it yet - var ( - posBalance uint64 - negBalance uint64 - now = f.clock.Now() - ) pb := f.ndb.getOrNewPB(id) - posBalance = pb.value - nb := f.ndb.getOrNewNB(freeID) - if nb.logValue != 0 { - negBalance = uint64(math.Exp(float64(nb.logValue-f.logOffset(now))/fixedPointMultiplier) * float64(time.Second)) - } e := &clientInfo{ + capacity: reqCapacity, pool: f, peer: peer, address: freeID, queueIndex: -1, id: id, - connectedAt: now, - priority: posBalance != 0, + freeID: freeID, + connectedAt: f.clock.Now(), + priority: pb.value != 0, posFactors: f.defaultPosFactors, negFactors: f.defaultNegFactors, balanceMetaInfo: pb.meta, } - // If the client is a free client, assign with a low free capacity, - // Otherwise assign with the given value(priority client) - if !e.priority || capacity == 0 { - capacity = f.freeClientCap + missing, capacity := f.capAvailable(id, freeID, reqCapacity, 0, true) + f.connectedMap[id] = e + if missing != 0 { + // capacity is not available, add client to inactive queue + f.initBalanceTracker(&e.balanceTracker, pb, nb, capacity, false) + f.inactiveQueue.Push(e, -connPriority(e, f.clock.Now())) + return 0, nil } + // capacity is available, add client + e.active = true e.capacity = capacity - - // Starts a balance tracker - e.balanceTracker.init(f.clock, capacity) - e.balanceTracker.setBalance(posBalance, negBalance) - e.updatePriceFactors() - - // If the number of clients already connected in the clientpool exceeds its - // capacity, evict some clients with lowest priority. - // - // If the priority of the newly added client is lower than the priority of - // all connected clients, the client is rejected. - newCapacity := f.connectedCap + capacity - newCount := f.connectedQueue.Size() + 1 - if newCapacity > f.capLimit || newCount > f.connLimit { - var ( - kickList []*clientInfo - kickPriority int64 - ) - f.connectedQueue.MultiPop(func(data interface{}, priority int64) bool { - c := data.(*clientInfo) - kickList = append(kickList, c) - kickPriority = priority - newCapacity -= c.capacity - newCount-- - return newCapacity > f.capLimit || newCount > f.connLimit - }) - bias := connectedBias - if f.disableBias { - bias = 0 - } - if newCapacity > f.capLimit || newCount > f.connLimit || (e.balanceTracker.estimatedPriority(now+mclock.AbsTime(bias), false)-kickPriority) > 0 { - for _, c := range kickList { - f.connectedQueue.Push(c) - } - clientRejectedMeter.Mark(1) - log.Debug("Client rejected", "address", freeID, "id", peerIdToString(id)) - return false - } - // accept new client, drop old ones - for _, c := range kickList { - f.dropClient(c, now, true) - } - } - + f.initBalanceTracker(&e.balanceTracker, pb, nb, capacity, true) // Register new client to connection queue. - f.connectedMap[id] = e - f.connectedQueue.Push(e) - f.connectedCap += e.capacity + f.inactiveBalances -= pb.value + f.activeBalances += pb.value + f.activeQueue.Push(e) + f.activeCap += e.capacity // If the current client is a paid client, monitor the status of client, // downgrade it to normal client if positive balance is used up. if e.priority { - f.priorityConnected += capacity + f.updateFullRatio() + f.priorityActive += capacity e.balanceTracker.addCallback(balanceCallbackZero, 0, func() { f.balanceExhausted(id) }) } - // If the capacity of client is not the default value(free capacity), notify - // it to update capacity. - if e.capacity != f.freeClientCap { - e.peer.updateCapacity(e.capacity) - } - totalConnectedGauge.Update(int64(f.connectedCap)) + totalConnectedGauge.Update(int64(f.activeCap)) clientConnectedMeter.Mark(1) log.Debug("Client accepted", "address", freeID) - return true + return e.capacity, nil +} + +func (f *clientPool) initBalanceTracker(bt *balanceTracker, pb posBalance, nb negBalance, capacity uint64, active bool) { + posBalance := pb.value + var negBalance uint64 + if nb.logValue != 0 { + negBalance = uint64(math.Exp(float64(nb.logValue-f.logOffset(f.clock.Now()))/fixedPointMultiplier) * float64(time.Second)) + } + bt.init(f.clock, capacity) + bt.setBalance(posBalance, negBalance) + if active { + updatePriceFactors(bt, f.defaultPosFactors, f.defaultNegFactors, capacity) + } else { + zeroPriceFactors(bt) + } } // disconnect should be called when a connection is terminated. If the disconnection @@ -326,17 +406,100 @@ func (f *clientPool) disconnect(p clientPeer) { f.lock.Lock() defer f.lock.Unlock() + f.disconnectLocked(p) +} + +func (f *clientPool) disconnectLocked(p clientPeer) { // Short circuit if client pool is already closed. if f.closed { return } - // Short circuit if the peer hasn't been registered. - e := f.connectedMap[p.ID()] - if e == nil { + e, ok := f.connectedMap[p.ID()] + if !ok { log.Debug("Client not connected", "address", p.freeClientId(), "id", peerIdToString(p.ID())) return } - f.dropClient(e, f.clock.Now(), false) + tryActivate := e.active + if e.active { + f.deactivateClient(e, false) + } + f.finalizeBalance(e, f.clock.Now()) + f.inactiveQueue.Remove(e.queueIndex) + delete(f.connectedMap, e.id) + clientDisconnectedMeter.Mark(1) + log.Debug("Client disconnected", "address", e.address) + if tryActivate { + f.tryActivateClients() + } +} + +// capAvailable checks whether the current priority level of the given client is enough to +// connect or change capacity to the requested level and then stay connected for at least +// the specified duration. If not then the additional required amount of positive balance is returned. +func (f *clientPool) capAvailable(id enode.ID, freeID string, capacity uint64, minConnTime time.Duration, kick bool) (uint64, uint64) { + var missing uint64 + if capacity == 0 { + capacity = f.freeClientCap + } + if capacity < f.minCap { + capacity = f.minCap + } + newCapacity := f.activeCap + capacity + newCount := f.activeQueue.Size() + 1 + client := f.connectedMap[id] + if client != nil && client.active { + newCapacity -= client.capacity + newCount-- + } + if newCapacity > f.capLimit || newCount > f.activeLimit { + var ( + popList []*clientInfo + targetPriority int64 + ) + f.activeQueue.MultiPop(func(data interface{}, priority int64) bool { + c := data.(*clientInfo) + popList = append(popList, c) + if c != client { + targetPriority = priority + newCapacity -= c.capacity + newCount-- + } + return newCapacity > f.capLimit || newCount > f.activeLimit + }) + if newCapacity > f.capLimit || newCount > f.activeLimit { + missing = math.MaxUint64 + } else { + var bt *balanceTracker + if client != nil { + bt = &client.balanceTracker + } else { + bt = &balanceTracker{} + f.initBalanceTracker(bt, f.ndb.getOrNewPB(id), f.ndb.getOrNewNB(freeID), capacity, true) + } + if capacity != f.freeClientCap && targetPriority >= 0 { + targetPriority = -1 + } + bias := activeBias + if f.disableBias { + bias = 0 + } + if bias < minConnTime { + bias = minConnTime + } + missing = bt.posBalanceMissing(targetPriority, capacity, bias) + } + if missing != 0 { + kick = false + } + for _, c := range popList { + if kick && c != client { + f.deactivateClient(c, true) + } else { + f.activeQueue.Push(c) + } + } + } + return missing, capacity } // forClients iterates through a list of clients, calling the callback for each one. @@ -371,27 +534,58 @@ func (f *clientPool) setDefaultFactors(posFactors, negFactors priceFactors) { f.defaultNegFactors = negFactors } -// dropClient removes a client from the connected queue and finalizes its balance. -// If kick is true then it also initiates the disconnection. -func (f *clientPool) dropClient(e *clientInfo, now mclock.AbsTime, kick bool) { - if _, ok := f.connectedMap[e.id]; !ok { +func (f *clientPool) deactivateClient(e *clientInfo, scheduleDrop bool) { + if _, ok := f.connectedMap[e.id]; !ok || !e.active { return } - f.finalizeBalance(e, now) - f.connectedQueue.Remove(e.queueIndex) - delete(f.connectedMap, e.id) - f.connectedCap -= e.capacity + f.activeQueue.Remove(e.queueIndex) + f.activeCap -= e.capacity if e.priority { - f.priorityConnected -= e.capacity + f.updateFullRatio() + f.priorityActive -= e.capacity } - totalConnectedGauge.Update(int64(f.connectedCap)) - if kick { - clientKickedMeter.Mark(1) - log.Debug("Client kicked out", "address", e.address) - f.removePeer(e.id) - } else { - clientDisconnectedMeter.Mark(1) - log.Debug("Client disconnected", "address", e.address) + e.active = false + e.peer.updateCapacity(0) + totalConnectedGauge.Update(int64(f.activeCap)) + f.inactiveQueue.Push(e, -connPriority(e, f.clock.Now())) + if scheduleDrop { + f.dropInactivePeers[f.dropInactiveCounter+dropInactiveCycles] = append(f.dropInactivePeers[f.dropInactiveCounter+dropInactiveCycles], e) + } +} + +func (f *clientPool) tryActivateClients() { + now := f.clock.Now() + for f.inactiveQueue.Size() != 0 { + e := f.inactiveQueue.PopItem().(*clientInfo) + missing, capacity := f.capAvailable(e.id, e.freeID, e.capacity, 0, true) + if missing != 0 { + f.inactiveQueue.Push(e, -connPriority(e, now)) + return + } + // capacity is available, activate client + e.active = true + e.capacity = capacity + e.peer.updateCapacity(capacity) + balance, _ := e.balanceTracker.getBalance(now) + e.balanceTracker.setCapacity(capacity) + updatePriceFactors(&e.balanceTracker, f.defaultPosFactors, f.defaultNegFactors, capacity) + // Register activated client to connection queue. + f.inactiveBalances -= balance + f.activeBalances += balance + f.activeQueue.Push(e) + f.activeCap += e.capacity + + // If the current client is a paid client, monitor the status of client, + // downgrade it to normal client if positive balance is used up. + if e.priority { + f.updateFullRatio() + f.priorityActive += capacity + e.balanceTracker.addCallback(balanceCallbackZero, 0, func() { f.balanceExhausted(e.id) }) + } + e.peer.updateCapacity(e.capacity) + totalConnectedGauge.Update(int64(f.activeCap)) + clientConnectedMeter.Mark(1) + log.Debug("Client activated", "address", e.freeID) } } @@ -401,7 +595,7 @@ func (f *clientPool) capacityInfo() (uint64, uint64, uint64) { f.lock.Lock() defer f.lock.Unlock() - return f.capLimit, f.connectedCap, f.priorityConnected + return f.capLimit, f.activeCap, f.priorityActive } // finalizeBalance stops the balance tracker, retrieves the final balances and @@ -409,6 +603,8 @@ func (f *clientPool) capacityInfo() (uint64, uint64, uint64) { func (f *clientPool) finalizeBalance(c *clientInfo, now mclock.AbsTime) { c.balanceTracker.stop(now) pos, neg := c.balanceTracker.getBalance(now) + f.inactiveBalances += pos + f.activeBalances -= pos pb, nb := f.ndb.getOrNewPB(c.id), f.ndb.getOrNewNB(c.address) pb.value = pos @@ -434,12 +630,13 @@ func (f *clientPool) balanceExhausted(id enode.ID) { return } if c.priority { - f.priorityConnected -= c.capacity + f.updateFullRatio() + f.priorityActive -= c.capacity } c.priority = false if c.capacity != f.freeClientCap { - f.connectedCap += f.freeClientCap - c.capacity - totalConnectedGauge.Update(int64(f.connectedCap)) + f.activeCap += f.freeClientCap - c.capacity + totalConnectedGauge.Update(int64(f.activeCap)) c.capacity = f.freeClientCap c.balanceTracker.setCapacity(c.capacity) c.peer.updateCapacity(c.capacity) @@ -449,82 +646,75 @@ func (f *clientPool) balanceExhausted(id enode.ID) { f.ndb.setPB(id, pb) } -// setConnLimit sets the maximum number and total capacity of connected clients, +// setactiveLimit sets the maximum number and total capacity of connected clients, // dropping some of them if necessary. func (f *clientPool) setLimits(totalConn int, totalCap uint64) { f.lock.Lock() defer f.lock.Unlock() - f.connLimit = totalConn + f.updateFullRatio() + f.activeLimit = totalConn f.capLimit = totalCap - if f.connectedCap > f.capLimit || f.connectedQueue.Size() > f.connLimit { - f.connectedQueue.MultiPop(func(data interface{}, priority int64) bool { - f.dropClient(data.(*clientInfo), mclock.Now(), true) - return f.connectedCap > f.capLimit || f.connectedQueue.Size() > f.connLimit + if f.activeCap > f.capLimit || f.activeQueue.Size() > f.activeLimit { + f.activeQueue.MultiPop(func(data interface{}, priority int64) bool { + f.deactivateClient(data.(*clientInfo), true) + return f.activeCap > f.capLimit || f.activeQueue.Size() > f.activeLimit }) + } else { + f.tryActivateClients() } } // setCapacity sets the assigned capacity of a connected client -func (f *clientPool) setCapacity(c *clientInfo, capacity uint64) error { - if f.connectedMap[c.id] != c { - return fmt.Errorf("client %064x is not connected", c.id[:]) - } - if c.capacity == capacity { - return nil - } - if !c.priority { - return errNoPriority - } - oldCapacity := c.capacity - c.capacity = capacity - f.connectedCap += capacity - oldCapacity - c.balanceTracker.setCapacity(capacity) - f.connectedQueue.Update(c.queueIndex) - if f.connectedCap > f.capLimit { - var kickList []*clientInfo - kick := true - f.connectedQueue.MultiPop(func(data interface{}, priority int64) bool { - client := data.(*clientInfo) - kickList = append(kickList, client) - f.connectedCap -= client.capacity - if client == c { - kick = false - } - return kick && (f.connectedCap > f.capLimit) - }) - if kick { - now := mclock.Now() - for _, c := range kickList { - f.dropClient(c, now, true) - } - } else { - c.capacity = oldCapacity - c.balanceTracker.setCapacity(oldCapacity) - for _, c := range kickList { - f.connectedCap += c.capacity - f.connectedQueue.Push(c) - } - return errNoPriority +func (f *clientPool) setCapacity(id enode.ID, freeID string, capacity uint64, minConnTime time.Duration, setCap bool) (uint64, uint64, error) { + c := f.connectedMap[id] + if c != nil { + if c.capacity == capacity { + return 0, capacity, nil } } - totalConnectedGauge.Update(int64(f.connectedCap)) - f.priorityConnected += capacity - oldCapacity - c.updatePriceFactors() - c.peer.updateCapacity(c.capacity) - return nil + var missing uint64 + missing, capacity = f.capAvailable(id, freeID, capacity, 0, setCap && c != nil) + if missing != 0 { + return missing, capacity, errNoPriority + } + // capacity update is possible + if setCap { + if c == nil { + return 0, capacity, fmt.Errorf("client %064x is not connected", c.id[:]) + } + f.activeCap += capacity - c.capacity + f.updateFullRatio() + f.priorityActive += capacity - c.capacity + c.capacity = capacity + c.balanceTracker.setCapacity(capacity) + f.activeQueue.Update(c.queueIndex) + totalConnectedGauge.Update(int64(f.activeCap)) + updatePriceFactors(&c.balanceTracker, c.posFactors, c.negFactors, c.capacity) + c.peer.updateCapacity(c.capacity) + f.tryActivateClients() + } + return 0, capacity, nil } -// requestCost feeds request cost after serving a request from the given peer. -func (f *clientPool) requestCost(p *peer, cost uint64) { +func (f *clientPool) setCapacityLocked(id enode.ID, freeID string, capacity uint64, minConnTime time.Duration, setCap bool) (uint64, uint64, error) { f.lock.Lock() defer f.lock.Unlock() - info, exist := f.connectedMap[p.ID()] - if !exist || f.closed { - return + return f.setCapacity(id, freeID, capacity, minConnTime, setCap) +} + +// requestCost feeds request cost after serving a request from the given peer and +// returns the remaining token balance +func (f *clientPool) requestCost(p *peer, cost uint64) uint64 { + f.lock.Lock() + defer f.lock.Unlock() + + c := f.connectedMap[p.ID()] + if c == nil || f.closed { + return 0 } - info.balanceTracker.requestCost(cost) + return c.balanceTracker.requestCost(cost) } // logOffset calculates the time-dependent offset for the logarithmic @@ -539,10 +729,15 @@ func (f *clientPool) logOffset(now mclock.AbsTime) int64 { return f.cumulativeTime + cumulativeTime } -// setClientPriceFactors sets the pricing factors for an individual connected client -func (c *clientInfo) updatePriceFactors() { - c.balanceTracker.setFactors(true, c.negFactors.timeFactor+float64(c.capacity)*c.negFactors.capacityFactor/1000000, c.negFactors.requestFactor) - c.balanceTracker.setFactors(false, c.posFactors.timeFactor+float64(c.capacity)*c.posFactors.capacityFactor/1000000, c.posFactors.requestFactor) +// updatePriceFactors sets the pricing factors for an individual connected client +func updatePriceFactors(bt *balanceTracker, posFactors, negFactors priceFactors, capacity uint64) { + bt.setFactors(true, negFactors.timeFactor+float64(capacity)*negFactors.capacityFactor/1000000, negFactors.requestFactor) + bt.setFactors(false, posFactors.timeFactor+float64(capacity)*posFactors.capacityFactor/1000000, posFactors.requestFactor) +} + +func zeroPriceFactors(bt *balanceTracker) { + bt.setFactors(true, 0, 0) + bt.setFactors(false, 0, 0) } // getPosBalance retrieves a single positive balance entry from cache or the database @@ -550,7 +745,12 @@ func (f *clientPool) getPosBalance(id enode.ID) posBalance { f.lock.Lock() defer f.lock.Unlock() - return f.ndb.getOrNewPB(id) + if c := f.connectedMap[id]; c != nil { + pb, _ := c.balanceTracker.getBalance(mclock.Now()) + return posBalance{value: pb, meta: c.balanceMetaInfo} + } else { + return f.ndb.getOrNewPB(id) + } } // addBalance updates the balance of a client (either overwrites it or adds to it). @@ -582,18 +782,32 @@ func (f *clientPool) addBalance(id enode.ID, amount int64, meta string) (uint64, f.ndb.setPB(id, pb) if c != nil { c.balanceTracker.setBalance(pb.value, negBalance) - if !c.priority && pb.value > 0 { - // The capacity should be adjusted based on the requirement, - // but we have no idea about the new capacity, need a second - // call to udpate it. + if c.active { + f.activeQueue.Update(c.queueIndex) + if !c.priority && pb.value > 0 { + // The capacity should be adjusted based on the requirement, + // but we have no idea about the new capacity, need a second + // call to udpate it. + f.updateFullRatio() + f.priorityActive += c.capacity + c.balanceTracker.addCallback(balanceCallbackZero, 0, func() { f.balanceExhausted(id) }) + } + c.balanceMetaInfo = meta + f.activeBalances += pb.value - oldBalance + } else { + f.inactiveQueue.Remove(c.queueIndex) + f.inactiveQueue.Push(c, -connPriority(c, f.clock.Now())) + f.inactiveBalances += pb.value - oldBalance + } + if pb.value > 0 { c.priority = true - f.priorityConnected += c.capacity - c.balanceTracker.addCallback(balanceCallbackZero, 0, func() { f.balanceExhausted(id) }) + // if balance is set to zero then reverting to non-priority status + // is handled by the balanceExhausted callback } - // if balance is set to zero then reverting to non-priority status - // is handled by the balanceExhausted callback - c.balanceMetaInfo = meta + } else { + f.inactiveBalances += pb.value - oldBalance } + f.tryActivateClients() return oldBalance, pb.value, nil } diff --git a/les/clientpool_test.go b/les/clientpool_test.go index 06f782ac96ef..5b4170ec001a 100644 --- a/les/clientpool_test.go +++ b/les/clientpool_test.go @@ -56,29 +56,34 @@ func TestClientPoolL100C300P20(t *testing.T) { const testClientPoolTicks = 100000 -type poolTestPeer int - -func (i poolTestPeer) ID() enode.ID { - return enode.ID{byte(i % 256), byte(i >> 8)} +type poolTestPeer struct { + index int + disconnCh chan int + cap uint64 } -func (i poolTestPeer) freeClientId() string { - return fmt.Sprintf("addr #%d", i) +func newPoolTestPeer(i int, disconnCh chan int) *poolTestPeer { + return &poolTestPeer{index: i, disconnCh: disconnCh} } -func (i poolTestPeer) updateCapacity(uint64) {} - -type poolTestPeerWithCap struct { - poolTestPeer +func (i *poolTestPeer) ID() enode.ID { + return enode.ID{byte(i.index % 256), byte(i.index >> 8)} +} - cap uint64 +func (i *poolTestPeer) freeClientId() string { + return fmt.Sprintf("addr #%d", i) } -func (i *poolTestPeerWithCap) updateCapacity(cap uint64) { i.cap = cap } +func (i *poolTestPeer) updateCapacity(cap uint64) { + i.cap = cap + if cap == 0 && i.disconnCh != nil { + i.disconnCh <- i.index + } +} -func (i poolTestPeer) freezeClient() {} +func (i *poolTestPeer) freezeClient() {} -func testClientPool(t *testing.T, connLimit, clientCount, paidCount int, randomDisconnect bool) { +func testClientPool(t *testing.T, activeLimit, clientCount, paidCount int, randomDisconnect bool) { rand.Seed(time.Now().UnixNano()) var ( clock mclock.Simulated @@ -89,15 +94,16 @@ func testClientPool(t *testing.T, connLimit, clientCount, paidCount int, randomD disconnFn = func(id enode.ID) { disconnCh <- int(id[0]) + int(id[1])<<8 } - pool = newClientPool(db, 1, &clock, disconnFn) + pool = newClientPool(db, 1, 1, &clock, disconnFn) ) + pool.disableBias = true - pool.setLimits(connLimit, uint64(connLimit)) + pool.setLimits(activeLimit, uint64(activeLimit)) pool.setDefaultFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) // pool should accept new peers up to its connected limit - for i := 0; i < connLimit; i++ { - if pool.connect(poolTestPeer(i), 0) { + for i := 0; i < activeLimit; i++ { + if cap, _ := pool.connect(newPoolTestPeer(i, disconnCh), 0); cap != 0 { connected[i] = true } else { t.Fatalf("Test peer #%d rejected", i) @@ -111,28 +117,30 @@ func testClientPool(t *testing.T, connLimit, clientCount, paidCount int, randomD // give a positive balance to some of the peers amount := testClientPoolTicks / 2 * int64(time.Second) // enough for half of the simulation period for i := 0; i < paidCount; i++ { - pool.addBalance(poolTestPeer(i).ID(), amount, "") + pool.addBalance(newPoolTestPeer(i, disconnCh).ID(), amount, "") } } i := rand.Intn(clientCount) if connected[i] { if randomDisconnect { - pool.disconnect(poolTestPeer(i)) + pool.disconnect(newPoolTestPeer(i, disconnCh)) connected[i] = false connTicks[i] += tickCounter } } else { - if pool.connect(poolTestPeer(i), 0) { + if cap, _ := pool.connect(newPoolTestPeer(i, disconnCh), 0); cap != 0 { connected[i] = true connTicks[i] -= tickCounter + } else { + pool.disconnect(newPoolTestPeer(i, disconnCh)) } } pollDisconnects: for { select { case i := <-disconnCh: - pool.disconnect(poolTestPeer(i)) + pool.disconnect(newPoolTestPeer(i, disconnCh)) if connected[i] { connTicks[i] += tickCounter connected[i] = false @@ -143,10 +151,10 @@ func testClientPool(t *testing.T, connLimit, clientCount, paidCount int, randomD } } - expTicks := testClientPoolTicks/2*connLimit/clientCount + testClientPoolTicks/2*(connLimit-paidCount)/(clientCount-paidCount) + expTicks := testClientPoolTicks/2*activeLimit/clientCount + testClientPoolTicks/2*(activeLimit-paidCount)/(clientCount-paidCount) expMin := expTicks - expTicks/5 expMax := expTicks + expTicks/5 - paidTicks := testClientPoolTicks/2*connLimit/clientCount + testClientPoolTicks/2 + paidTicks := testClientPoolTicks/2*activeLimit/clientCount + testClientPoolTicks/2 paidMin := paidTicks - paidTicks/5 paidMax := paidTicks + paidTicks/5 @@ -172,15 +180,15 @@ func TestConnectPaidClient(t *testing.T) { clock mclock.Simulated db = rawdb.NewMemoryDatabase() ) - pool := newClientPool(db, 1, &clock, nil) + pool := newClientPool(db, 1, 1, &clock, nil) defer pool.stop() pool.setLimits(10, uint64(10)) pool.setDefaultFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) // Add balance for an external client and mark it as paid client - pool.addBalance(poolTestPeer(0).ID(), 1000, "") + pool.addBalance(newPoolTestPeer(0, nil).ID(), 1000, "") - if !pool.connect(poolTestPeer(0), 10) { + if cap, _ := pool.connect(newPoolTestPeer(0, nil), 10); cap == 0 { t.Fatalf("Failed to connect paid client") } } @@ -190,16 +198,16 @@ func TestConnectPaidClientToSmallPool(t *testing.T) { clock mclock.Simulated db = rawdb.NewMemoryDatabase() ) - pool := newClientPool(db, 1, &clock, nil) + pool := newClientPool(db, 1, 1, &clock, nil) defer pool.stop() pool.setLimits(10, uint64(10)) // Total capacity limit is 10 pool.setDefaultFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) // Add balance for an external client and mark it as paid client - pool.addBalance(poolTestPeer(0).ID(), 1000, "") + pool.addBalance(newPoolTestPeer(0, nil).ID(), 1000, "") // Connect a fat paid client to pool, should reject it. - if pool.connect(poolTestPeer(0), 100) { + if cap, _ := pool.connect(newPoolTestPeer(0, nil), 100); cap != 0 { t.Fatalf("Connected fat paid client, should reject it") } } @@ -210,23 +218,23 @@ func TestConnectPaidClientToFullPool(t *testing.T) { db = rawdb.NewMemoryDatabase() ) removeFn := func(enode.ID) {} // Noop - pool := newClientPool(db, 1, &clock, removeFn) + pool := newClientPool(db, 1, 1, &clock, removeFn) defer pool.stop() pool.setLimits(10, uint64(10)) // Total capacity limit is 10 pool.setDefaultFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) for i := 0; i < 10; i++ { - pool.addBalance(poolTestPeer(i).ID(), 1000000000, "") - pool.connect(poolTestPeer(i), 1) + pool.addBalance(newPoolTestPeer(i, nil).ID(), 1000000000, "") + pool.connect(newPoolTestPeer(i, nil), 1) } - pool.addBalance(poolTestPeer(11).ID(), 1000, "") // Add low balance to new paid client - if pool.connect(poolTestPeer(11), 1) { + pool.addBalance(newPoolTestPeer(11, nil).ID(), 1000, "") // Add low balance to new paid client + if cap, _ := pool.connect(newPoolTestPeer(11, nil), 1); cap != 0 { t.Fatalf("Low balance paid client should be rejected") } clock.Run(time.Second) - pool.addBalance(poolTestPeer(12).ID(), 1000000000*60*3, "") // Add high balance to new paid client - if !pool.connect(poolTestPeer(12), 1) { - t.Fatalf("High balance paid client should be accpected") + pool.addBalance(newPoolTestPeer(12, nil).ID(), 1000000000*60*3+1, "") // Add high balance to new paid client + if cap, _ := pool.connect(newPoolTestPeer(12, nil), 1); cap == 0 { + t.Fatalf("High balance paid client should be accepted") } } @@ -237,19 +245,19 @@ func TestPaidClientKickedOut(t *testing.T) { kickedCh = make(chan int, 1) ) removeFn := func(id enode.ID) { kickedCh <- int(id[0]) } - pool := newClientPool(db, 1, &clock, removeFn) + pool := newClientPool(db, 1, 1, &clock, removeFn) defer pool.stop() pool.setLimits(10, uint64(10)) // Total capacity limit is 10 pool.setDefaultFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) for i := 0; i < 10; i++ { - pool.addBalance(poolTestPeer(i).ID(), 1000000000, "") // 1 second allowance - pool.connect(poolTestPeer(i), 1) + pool.addBalance(newPoolTestPeer(i, kickedCh).ID(), 1000000000, "") // 1 second allowance + pool.connect(newPoolTestPeer(i, kickedCh), 1) clock.Run(time.Millisecond) } clock.Run(time.Second) - clock.Run(connectedBias) - if !pool.connect(poolTestPeer(11), 0) { + clock.Run(activeBias) + if cap, _ := pool.connect(newPoolTestPeer(11, kickedCh), 0); cap == 0 { t.Fatalf("Free client should be accectped") } select { @@ -267,11 +275,11 @@ func TestConnectFreeClient(t *testing.T) { clock mclock.Simulated db = rawdb.NewMemoryDatabase() ) - pool := newClientPool(db, 1, &clock, nil) + pool := newClientPool(db, 1, 1, &clock, nil) defer pool.stop() pool.setLimits(10, uint64(10)) pool.setDefaultFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) - if !pool.connect(poolTestPeer(0), 10) { + if cap, _ := pool.connect(newPoolTestPeer(0, nil), 10); cap == 0 { t.Fatalf("Failed to connect free client") } } @@ -282,24 +290,24 @@ func TestConnectFreeClientToFullPool(t *testing.T) { db = rawdb.NewMemoryDatabase() ) removeFn := func(enode.ID) {} // Noop - pool := newClientPool(db, 1, &clock, removeFn) + pool := newClientPool(db, 1, 1, &clock, removeFn) defer pool.stop() pool.setLimits(10, uint64(10)) // Total capacity limit is 10 pool.setDefaultFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) for i := 0; i < 10; i++ { - pool.connect(poolTestPeer(i), 1) + pool.connect(newPoolTestPeer(i, nil), 1) } - if pool.connect(poolTestPeer(11), 1) { + if cap, _ := pool.connect(newPoolTestPeer(11, nil), 1); cap != 0 { t.Fatalf("New free client should be rejected") } clock.Run(time.Minute) - if pool.connect(poolTestPeer(12), 1) { + if cap, _ := pool.connect(newPoolTestPeer(12, nil), 1); cap != 0 { t.Fatalf("New free client should be rejected") } clock.Run(time.Millisecond) clock.Run(4 * time.Minute) - if !pool.connect(poolTestPeer(13), 1) { + if cap, _ := pool.connect(newPoolTestPeer(13, nil), 1); cap == 0 { t.Fatalf("Old client connects more than 5min should be kicked") } } @@ -311,21 +319,22 @@ func TestFreeClientKickedOut(t *testing.T) { kicked = make(chan int, 10) ) removeFn := func(id enode.ID) { kicked <- int(id[0]) } - pool := newClientPool(db, 1, &clock, removeFn) + pool := newClientPool(db, 1, 1, &clock, removeFn) defer pool.stop() pool.setLimits(10, uint64(10)) // Total capacity limit is 10 pool.setDefaultFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) for i := 0; i < 10; i++ { - pool.connect(poolTestPeer(i), 1) + pool.connect(newPoolTestPeer(i, kicked), 1) clock.Run(time.Millisecond) } - if pool.connect(poolTestPeer(10), 1) { + if cap, _ := pool.connect(newPoolTestPeer(10, kicked), 1); cap != 0 { t.Fatalf("New free client should be rejected") } + pool.disconnect(newPoolTestPeer(10, kicked)) clock.Run(5 * time.Minute) for i := 0; i < 10; i++ { - pool.connect(poolTestPeer(i+10), 1) + pool.connect(newPoolTestPeer(i+10, kicked), 1) } for i := 0; i < 10; i++ { select { @@ -346,17 +355,17 @@ func TestPositiveBalanceCalculation(t *testing.T) { kicked = make(chan int, 10) ) removeFn := func(id enode.ID) { kicked <- int(id[0]) } // Noop - pool := newClientPool(db, 1, &clock, removeFn) + pool := newClientPool(db, 1, 1, &clock, removeFn) defer pool.stop() pool.setLimits(10, uint64(10)) // Total capacity limit is 10 pool.setDefaultFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) - pool.addBalance(poolTestPeer(0).ID(), int64(time.Minute*3), "") - pool.connect(poolTestPeer(0), 10) + pool.addBalance(newPoolTestPeer(0, kicked).ID(), int64(time.Minute*3), "") + pool.connect(newPoolTestPeer(0, kicked), 10) clock.Run(time.Minute) - pool.disconnect(poolTestPeer(0)) - pb := pool.ndb.getOrNewPB(poolTestPeer(0).ID()) + pool.disconnect(newPoolTestPeer(0, kicked)) + pb := pool.ndb.getOrNewPB(newPoolTestPeer(0, kicked).ID()) if pb.value != uint64(time.Minute*2) { t.Fatalf("Positive balance mismatch, want %v, got %v", uint64(time.Minute*2), pb.value) } @@ -369,16 +378,14 @@ func TestDowngradePriorityClient(t *testing.T) { kicked = make(chan int, 10) ) removeFn := func(id enode.ID) { kicked <- int(id[0]) } // Noop - pool := newClientPool(db, 1, &clock, removeFn) + pool := newClientPool(db, 1, 1, &clock, removeFn) defer pool.stop() pool.setLimits(10, uint64(10)) // Total capacity limit is 10 pool.setDefaultFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) - p := &poolTestPeerWithCap{ - poolTestPeer: poolTestPeer(0), - } + p := newPoolTestPeer(0, kicked) pool.addBalance(p.ID(), int64(time.Minute), "") - pool.connect(p, 10) + p.cap, _ = pool.connect(p, 10) if p.cap != 10 { t.Fatalf("The capcacity of priority peer hasn't been updated, got: %d", p.cap) } @@ -388,13 +395,13 @@ func TestDowngradePriorityClient(t *testing.T) { if p.cap != 1 { t.Fatalf("The capcacity of peer should be downgraded, got: %d", p.cap) } - pb := pool.ndb.getOrNewPB(poolTestPeer(0).ID()) + pb := pool.ndb.getOrNewPB(newPoolTestPeer(0, kicked).ID()) if pb.value != 0 { t.Fatalf("Positive balance mismatch, want %v, got %v", 0, pb.value) } - pool.addBalance(poolTestPeer(0).ID(), int64(time.Minute), "") - pb = pool.ndb.getOrNewPB(poolTestPeer(0).ID()) + pool.addBalance(newPoolTestPeer(0, kicked).ID(), int64(time.Minute), "") + pb = pool.ndb.getOrNewPB(newPoolTestPeer(0, kicked).ID()) if pb.value != uint64(time.Minute) { t.Fatalf("Positive balance mismatch, want %v, got %v", uint64(time.Minute), pb.value) } @@ -402,36 +409,34 @@ func TestDowngradePriorityClient(t *testing.T) { func TestNegativeBalanceCalculation(t *testing.T) { var ( - clock mclock.Simulated - db = rawdb.NewMemoryDatabase() - kicked = make(chan int, 10) + clock mclock.Simulated + db = rawdb.NewMemoryDatabase() ) - removeFn := func(id enode.ID) { kicked <- int(id[0]) } // Noop - pool := newClientPool(db, 1, &clock, removeFn) + pool := newClientPool(db, 1, 1, &clock, nil) defer pool.stop() pool.setLimits(10, uint64(10)) // Total capacity limit is 10 pool.setDefaultFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) for i := 0; i < 10; i++ { - pool.connect(poolTestPeer(i), 1) + pool.connect(newPoolTestPeer(i, nil), 1) } clock.Run(time.Second) for i := 0; i < 10; i++ { - pool.disconnect(poolTestPeer(i)) - nb := pool.ndb.getOrNewNB(poolTestPeer(i).freeClientId()) + pool.disconnect(newPoolTestPeer(i, nil)) + nb := pool.ndb.getOrNewNB(newPoolTestPeer(i, nil).freeClientId()) if nb.logValue != 0 { t.Fatalf("Short connection shouldn't be recorded") } } for i := 0; i < 10; i++ { - pool.connect(poolTestPeer(i), 1) + pool.connect(newPoolTestPeer(i, nil), 1) } clock.Run(time.Minute) for i := 0; i < 10; i++ { - pool.disconnect(poolTestPeer(i)) - nb := pool.ndb.getOrNewNB(poolTestPeer(i).freeClientId()) + pool.disconnect(newPoolTestPeer(i, nil)) + nb := pool.ndb.getOrNewNB(newPoolTestPeer(i, nil).freeClientId()) nb.logValue -= pool.logOffset(clock.Now()) nb.logValue /= fixedPointMultiplier if nb.logValue != int64(math.Log(float64(time.Minute/time.Second))) { @@ -541,3 +546,83 @@ func TestNodeDBExpiration(t *testing.T) { t.Fatalf("Failed to evict useless negative balances, want %v, got %d", 4, iterated) } } + +func TestInactiveClient(t *testing.T) { + var ( + clock mclock.Simulated + db = rawdb.NewMemoryDatabase() + ) + pool := newClientPool(db, 1, 1, &clock, nil) + defer pool.stop() + pool.setLimits(2, uint64(2)) // Total capacity limit is 10 + + p1 := newPoolTestPeer(1, nil) + p2 := newPoolTestPeer(2, nil) + p3 := newPoolTestPeer(3, nil) + pool.addBalance(p1.ID(), 1000, "") + pool.addBalance(p3.ID(), 2000, "") + // p1: 1000 p2: 0 p3: 2000 + p1.cap, _ = pool.connect(p1, 1) + if p1.cap != 1 { + t.Fatalf("Failed to connect peer #1") + } + p2.cap, _ = pool.connect(p2, 1) + if p2.cap != 1 { + t.Fatalf("Failed to connect peer #2") + } + p3.cap, _ = pool.connect(p3, 1) + if p3.cap != 1 { + t.Fatalf("Failed to connect peer #3") + } + if p2.cap != 0 { + t.Fatalf("Failed to deactivate peer #2") + } + pool.addBalance(p2.ID(), 3000, "") + // p1: 1000 p2: 3000 p3: 2000 + if p2.cap != 1 { + t.Fatalf("Failed to activate peer #2") + } + if p1.cap != 0 { + t.Fatalf("Failed to deactivate peer #1") + } + pool.addBalance(p2.ID(), -2500, "") + // p1: 1000 p2: 500 p3: 2000 + if p1.cap != 1 { + t.Fatalf("Failed to activate peer #1") + } + if p2.cap != 0 { + t.Fatalf("Failed to deactivate peer #2") + } + pool.setDefaultFactors(priceFactors{1e-9, 0, 0}, priceFactors{1e-9, 0, 0}) + p4 := newPoolTestPeer(4, nil) + pool.addBalance(p4.ID(), 1500, "") + // p1: 1000 p2: 500 p3: 2000 p4: 1500 + p4.cap, _ = pool.connect(p4, 1) + if p4.cap != 1 { + t.Fatalf("Failed to activate peer #4") + } + if p1.cap != 0 { + t.Fatalf("Failed to deactivate peer #1") + } + clock.Run(time.Second * 600) + // manually trigger a check to avoid a long real-time wait + pool.lock.Lock() + pool.tryActivateClients() + pool.lock.Unlock() + // p1: 1000 p2: 500 p3: 2000 p4: 900 + if p1.cap != 1 { + t.Fatalf("Failed to activate peer #1") + } + if p4.cap != 0 { + t.Fatalf("Failed to deactivate peer #4") + } + pool.disconnect(p2) + pool.disconnect(p4) + pool.addBalance(p1.ID(), -1000, "") + if p1.cap != 1 { + t.Fatalf("Should not deactivate peer #1") + } + if p2.cap != 0 { + t.Fatalf("Should not activate peer #2") + } +} diff --git a/les/costtracker.go b/les/costtracker.go index 81da04566007..abf8618ffde9 100644 --- a/les/costtracker.go +++ b/les/costtracker.go @@ -159,15 +159,9 @@ func newCostTracker(db ethdb.Database, config *eth.Config) (*costTracker, uint64 } ct.gfLoop() costList := ct.makeCostList(ct.globalFactor() * 1.25) - for _, c := range costList { - amount := minBufferReqAmount[c.MsgCode] - cost := c.BaseCost + amount*c.ReqCost - if cost > ct.minBufLimit { - ct.minBufLimit = cost - } - } - ct.minBufLimit *= uint64(minBufferMultiplier) - return ct, (ct.minBufLimit-1)/bufLimitRatio + 1 + var minRecharge uint64 + ct.minBufLimit, minRecharge = costList.decode(ProtocolLengths[ServerProtocolVersions[len(ServerProtocolVersions)-1]]).reqParams() + return ct, minRecharge } // stop stops the cost tracker and saves the cost factor statistics to the database @@ -480,6 +474,22 @@ func (table requestCostTable) getMaxCost(code, amount uint64) uint64 { return costs.baseCost + amount*costs.reqCost } +func (table requestCostTable) reqParams() (minRecharge, minBufLimit uint64) { + for code, c := range table { + amount := minBufferReqAmount[code] + cost := c.baseCost + amount*c.reqCost + if cost > minBufLimit { + minBufLimit = cost + } + } + minBufLimit *= uint64(minBufferMultiplier) + if minBufLimit < 1 { + minBufLimit = 1 + } + minRecharge = (minBufLimit-1)/bufLimitRatio + 1 + return +} + // decode converts a cost list to a cost table func (list RequestCostList) decode(protocolLength uint64) requestCostTable { table := make(requestCostTable) diff --git a/les/flowcontrol/control.go b/les/flowcontrol/control.go index 490013677c63..de636231ad08 100644 --- a/les/flowcontrol/control.go +++ b/les/flowcontrol/control.go @@ -185,6 +185,13 @@ func (node *ClientNode) UpdateParams(params ServerParams) { } } +func (node *ClientNode) Params() ServerParams { + node.lock.Lock() + defer node.lock.Unlock() + + return node.params +} + // updateParams updates the flow control parameters of the node func (node *ClientNode) updateParams(params ServerParams, now mclock.AbsTime) { diff := int64(params.BufLimit - node.params.BufLimit) diff --git a/les/handler_test.go b/les/handler_test.go index aad8d18e45c0..08796b49e803 100644 --- a/les/handler_test.go +++ b/les/handler_test.go @@ -38,20 +38,29 @@ import ( "github.com/ethereum/go-ethereum/trie" ) -func expectResponse(r p2p.MsgReader, msgcode, reqID, bv uint64, data interface{}) error { +func expectResponse(r p2p.MsgReader, protocol int, msgcode, reqID, bv, cost uint64, data interface{}) error { type resp struct { - ReqID, BV uint64 - Data interface{} + ReqID uint64 + SF stateFeedback + Data interface{} } - return p2p.ExpectMsg(r, msgcode, resp{reqID, bv, data}) + sf := stateFeedback{ + protocolVersion: protocol, + stateFeedbackV4: stateFeedbackV4{ + BV: bv, + RealCost: cost, + TokenBalance: 0, + }, + } + return p2p.ExpectMsg(r, msgcode, resp{reqID, sf, data}) } // Tests that block headers can be retrieved from a remote chain based on user queries. -func TestGetBlockHeadersLes2(t *testing.T) { testGetBlockHeaders(t, 2) } func TestGetBlockHeadersLes3(t *testing.T) { testGetBlockHeaders(t, 3) } +func TestGetBlockHeadersLes4(t *testing.T) { testGetBlockHeaders(t, 4) } func testGetBlockHeaders(t *testing.T, protocol int) { - server, tearDown := newServerEnv(t, downloader.MaxHashFetch+15, protocol, nil, false, true, 0) + server, tearDown := newServerEnv(t, downloader.MaxHashFetch+15, protocol, nil, false, true, 0, true) defer tearDown() bc := server.handler.blockchain @@ -170,18 +179,18 @@ func testGetBlockHeaders(t *testing.T, protocol int) { cost := server.peer.peer.GetRequestCost(GetBlockHeadersMsg, int(tt.query.Amount)) sendRequest(server.peer.app, GetBlockHeadersMsg, reqID, cost, tt.query) - if err := expectResponse(server.peer.app, BlockHeadersMsg, reqID, testBufLimit, headers); err != nil { + if err := expectResponse(server.peer.app, protocol, BlockHeadersMsg, reqID, testBufLimit, cost, headers); err != nil { t.Errorf("test %d: headers mismatch: %v", i, err) } } } // Tests that block contents can be retrieved from a remote chain based on their hashes. -func TestGetBlockBodiesLes2(t *testing.T) { testGetBlockBodies(t, 2) } func TestGetBlockBodiesLes3(t *testing.T) { testGetBlockBodies(t, 3) } +func TestGetBlockBodiesLes4(t *testing.T) { testGetBlockBodies(t, 4) } func testGetBlockBodies(t *testing.T, protocol int) { - server, tearDown := newServerEnv(t, downloader.MaxBlockFetch+15, protocol, nil, false, true, 0) + server, tearDown := newServerEnv(t, downloader.MaxBlockFetch+15, protocol, nil, false, true, 0, true) defer tearDown() bc := server.handler.blockchain @@ -248,19 +257,19 @@ func testGetBlockBodies(t *testing.T, protocol int) { // Send the hash request and verify the response cost := server.peer.peer.GetRequestCost(GetBlockBodiesMsg, len(hashes)) sendRequest(server.peer.app, GetBlockBodiesMsg, reqID, cost, hashes) - if err := expectResponse(server.peer.app, BlockBodiesMsg, reqID, testBufLimit, bodies); err != nil { + if err := expectResponse(server.peer.app, protocol, BlockBodiesMsg, reqID, testBufLimit, cost, bodies); err != nil { t.Errorf("test %d: bodies mismatch: %v", i, err) } } } // Tests that the contract codes can be retrieved based on account addresses. -func TestGetCodeLes2(t *testing.T) { testGetCode(t, 2) } func TestGetCodeLes3(t *testing.T) { testGetCode(t, 3) } +func TestGetCodeLes4(t *testing.T) { testGetCode(t, 4) } func testGetCode(t *testing.T, protocol int) { // Assemble the test environment - server, tearDown := newServerEnv(t, 4, protocol, nil, false, true, 0) + server, tearDown := newServerEnv(t, 4, protocol, nil, false, true, 0, true) defer tearDown() bc := server.handler.blockchain @@ -280,17 +289,17 @@ func testGetCode(t *testing.T, protocol int) { cost := server.peer.peer.GetRequestCost(GetCodeMsg, len(codereqs)) sendRequest(server.peer.app, GetCodeMsg, 42, cost, codereqs) - if err := expectResponse(server.peer.app, CodeMsg, 42, testBufLimit, codes); err != nil { + if err := expectResponse(server.peer.app, protocol, CodeMsg, 42, testBufLimit, cost, codes); err != nil { t.Errorf("codes mismatch: %v", err) } } // Tests that the stale contract codes can't be retrieved based on account addresses. -func TestGetStaleCodeLes2(t *testing.T) { testGetStaleCode(t, 2) } func TestGetStaleCodeLes3(t *testing.T) { testGetStaleCode(t, 3) } +func TestGetStaleCodeLes4(t *testing.T) { testGetStaleCode(t, 4) } func testGetStaleCode(t *testing.T, protocol int) { - server, tearDown := newServerEnv(t, core.TriesInMemory+4, protocol, nil, false, true, 0) + server, tearDown := newServerEnv(t, core.TriesInMemory+4, protocol, nil, false, true, 0, true) defer tearDown() bc := server.handler.blockchain @@ -301,7 +310,7 @@ func testGetStaleCode(t *testing.T, protocol int) { } cost := server.peer.peer.GetRequestCost(GetCodeMsg, 1) sendRequest(server.peer.app, GetCodeMsg, 42, cost, []*CodeReq{req}) - if err := expectResponse(server.peer.app, CodeMsg, 42, testBufLimit, expected); err != nil { + if err := expectResponse(server.peer.app, protocol, CodeMsg, 42, testBufLimit, cost, expected); err != nil { t.Errorf("codes mismatch: %v", err) } } @@ -311,12 +320,12 @@ func testGetStaleCode(t *testing.T, protocol int) { } // Tests that the transaction receipts can be retrieved based on hashes. -func TestGetReceiptLes2(t *testing.T) { testGetReceipt(t, 2) } func TestGetReceiptLes3(t *testing.T) { testGetReceipt(t, 3) } +func TestGetReceiptLes4(t *testing.T) { testGetReceipt(t, 4) } func testGetReceipt(t *testing.T, protocol int) { // Assemble the test environment - server, tearDown := newServerEnv(t, 4, protocol, nil, false, true, 0) + server, tearDown := newServerEnv(t, 4, protocol, nil, false, true, 0, true) defer tearDown() bc := server.handler.blockchain @@ -333,18 +342,18 @@ func testGetReceipt(t *testing.T, protocol int) { // Send the hash request and verify the response cost := server.peer.peer.GetRequestCost(GetReceiptsMsg, len(hashes)) sendRequest(server.peer.app, GetReceiptsMsg, 42, cost, hashes) - if err := expectResponse(server.peer.app, ReceiptsMsg, 42, testBufLimit, receipts); err != nil { + if err := expectResponse(server.peer.app, protocol, ReceiptsMsg, 42, testBufLimit, cost, receipts); err != nil { t.Errorf("receipts mismatch: %v", err) } } // Tests that trie merkle proofs can be retrieved -func TestGetProofsLes2(t *testing.T) { testGetProofs(t, 2) } func TestGetProofsLes3(t *testing.T) { testGetProofs(t, 3) } +func TestGetProofsLes4(t *testing.T) { testGetProofs(t, 4) } func testGetProofs(t *testing.T, protocol int) { // Assemble the test environment - server, tearDown := newServerEnv(t, 4, protocol, nil, false, true, 0) + server, tearDown := newServerEnv(t, 4, protocol, nil, false, true, 0, true) defer tearDown() bc := server.handler.blockchain @@ -369,17 +378,17 @@ func testGetProofs(t *testing.T, protocol int) { // Send the proof request and verify the response cost := server.peer.peer.GetRequestCost(GetProofsV2Msg, len(proofreqs)) sendRequest(server.peer.app, GetProofsV2Msg, 42, cost, proofreqs) - if err := expectResponse(server.peer.app, ProofsV2Msg, 42, testBufLimit, proofsV2.NodeList()); err != nil { + if err := expectResponse(server.peer.app, protocol, ProofsV2Msg, 42, testBufLimit, cost, proofsV2.NodeList()); err != nil { t.Errorf("proofs mismatch: %v", err) } } // Tests that the stale contract codes can't be retrieved based on account addresses. -func TestGetStaleProofLes2(t *testing.T) { testGetStaleProof(t, 2) } func TestGetStaleProofLes3(t *testing.T) { testGetStaleProof(t, 3) } +func TestGetStaleProofLes4(t *testing.T) { testGetStaleProof(t, 4) } func testGetStaleProof(t *testing.T, protocol int) { - server, tearDown := newServerEnv(t, core.TriesInMemory+4, protocol, nil, false, true, 0) + server, tearDown := newServerEnv(t, core.TriesInMemory+4, protocol, nil, false, true, 0, true) defer tearDown() bc := server.handler.blockchain @@ -402,7 +411,7 @@ func testGetStaleProof(t *testing.T, protocol int) { t.Prove(account, 0, proofsV2) expected = proofsV2.NodeList() } - if err := expectResponse(server.peer.app, ProofsV2Msg, 42, testBufLimit, expected); err != nil { + if err := expectResponse(server.peer.app, protocol, ProofsV2Msg, 42, testBufLimit, cost, expected); err != nil { t.Errorf("codes mismatch: %v", err) } } @@ -412,8 +421,8 @@ func testGetStaleProof(t *testing.T, protocol int) { } // Tests that CHT proofs can be correctly retrieved. -func TestGetCHTProofsLes2(t *testing.T) { testGetCHTProofs(t, 2) } func TestGetCHTProofsLes3(t *testing.T) { testGetCHTProofs(t, 3) } +func TestGetCHTProofsLes4(t *testing.T) { testGetCHTProofs(t, 4) } func testGetCHTProofs(t *testing.T, protocol int) { config := light.TestServerIndexerConfig @@ -427,7 +436,7 @@ func testGetCHTProofs(t *testing.T, protocol int) { time.Sleep(10 * time.Millisecond) } } - server, tearDown := newServerEnv(t, int(config.ChtSize+config.ChtConfirms), protocol, waitIndexers, false, true, 0) + server, tearDown := newServerEnv(t, int(config.ChtSize+config.ChtConfirms), protocol, waitIndexers, false, true, 0, true) defer tearDown() bc := server.handler.blockchain @@ -455,13 +464,13 @@ func testGetCHTProofs(t *testing.T, protocol int) { // Send the proof request and verify the response cost := server.peer.peer.GetRequestCost(GetHelperTrieProofsMsg, len(requestsV2)) sendRequest(server.peer.app, GetHelperTrieProofsMsg, 42, cost, requestsV2) - if err := expectResponse(server.peer.app, HelperTrieProofsMsg, 42, testBufLimit, proofsV2); err != nil { + if err := expectResponse(server.peer.app, protocol, HelperTrieProofsMsg, 42, testBufLimit, cost, proofsV2); err != nil { t.Errorf("proofs mismatch: %v", err) } } -func TestGetBloombitsProofsLes2(t *testing.T) { testGetBloombitsProofs(t, 2) } func TestGetBloombitsProofsLes3(t *testing.T) { testGetBloombitsProofs(t, 3) } +func TestGetBloombitsProofsLes4(t *testing.T) { testGetBloombitsProofs(t, 4) } // Tests that bloombits proofs can be correctly retrieved. func testGetBloombitsProofs(t *testing.T, protocol int) { @@ -476,7 +485,7 @@ func testGetBloombitsProofs(t *testing.T, protocol int) { time.Sleep(10 * time.Millisecond) } } - server, tearDown := newServerEnv(t, int(config.BloomTrieSize+config.BloomTrieConfirms), protocol, waitIndexers, false, true, 0) + server, tearDown := newServerEnv(t, int(config.BloomTrieSize+config.BloomTrieConfirms), protocol, waitIndexers, false, true, 0, true) defer tearDown() bc := server.handler.blockchain @@ -504,17 +513,17 @@ func testGetBloombitsProofs(t *testing.T, protocol int) { // Send the proof request and verify the response cost := server.peer.peer.GetRequestCost(GetHelperTrieProofsMsg, len(requests)) sendRequest(server.peer.app, GetHelperTrieProofsMsg, 42, cost, requests) - if err := expectResponse(server.peer.app, HelperTrieProofsMsg, 42, testBufLimit, proofs); err != nil { + if err := expectResponse(server.peer.app, protocol, HelperTrieProofsMsg, 42, testBufLimit, cost, proofs); err != nil { t.Errorf("bit %d: proofs mismatch: %v", bit, err) } } } -func TestTransactionStatusLes2(t *testing.T) { testTransactionStatus(t, 2) } func TestTransactionStatusLes3(t *testing.T) { testTransactionStatus(t, 3) } +func TestTransactionStatusLes4(t *testing.T) { testTransactionStatus(t, 4) } func testTransactionStatus(t *testing.T, protocol int) { - server, tearDown := newServerEnv(t, 0, protocol, nil, false, true, 0) + server, tearDown := newServerEnv(t, 0, protocol, nil, false, true, 0, true) defer tearDown() server.handler.addTxsSync = true @@ -524,14 +533,15 @@ func testTransactionStatus(t *testing.T, protocol int) { test := func(tx *types.Transaction, send bool, expStatus light.TxStatus) { reqID++ + var cost uint64 if send { - cost := server.peer.peer.GetRequestCost(SendTxV2Msg, 1) + cost = server.peer.peer.GetRequestCost(SendTxV2Msg, 1) sendRequest(server.peer.app, SendTxV2Msg, reqID, cost, types.Transactions{tx}) } else { - cost := server.peer.peer.GetRequestCost(GetTxStatusMsg, 1) + cost = server.peer.peer.GetRequestCost(GetTxStatusMsg, 1) sendRequest(server.peer.app, GetTxStatusMsg, reqID, cost, []common.Hash{tx.Hash()}) } - if err := expectResponse(server.peer.app, TxStatusMsg, reqID, testBufLimit, []light.TxStatus{expStatus}); err != nil { + if err := expectResponse(server.peer.app, protocol, TxStatusMsg, reqID, testBufLimit, cost, []light.TxStatus{expStatus}); err != nil { t.Errorf("transaction status mismatch") } } @@ -606,8 +616,11 @@ func testTransactionStatus(t *testing.T, protocol int) { test(tx2, false, light.TxStatus{Status: core.TxStatusPending}) } -func TestStopResumeLes3(t *testing.T) { - server, tearDown := newServerEnv(t, 0, 3, nil, true, true, testBufLimit/10) +func TestStopResumeLes3(t *testing.T) { testStopResume(t, 3) } +func TestStopResumeLes4(t *testing.T) { testStopResume(t, 4) } + +func testStopResume(t *testing.T, protocol int) { + server, tearDown := newServerEnv(t, 0, protocol, nil, true, true, testBufLimit/10, true) defer tearDown() server.handler.server.costTracker.testing = true @@ -627,7 +640,7 @@ func TestStopResumeLes3(t *testing.T) { for expBuf >= testCost { req() expBuf -= testCost - if err := expectResponse(server.peer.app, BlockHeadersMsg, reqID, expBuf, []*types.Header{header}); err != nil { + if err := expectResponse(server.peer.app, protocol, BlockHeadersMsg, reqID, expBuf, testCost, []*types.Header{header}); err != nil { t.Errorf("expected response and failed: %v", err) } } @@ -646,7 +659,15 @@ func TestStopResumeLes3(t *testing.T) { // expect a ResumeMsg with the partially recharged buffer value expBuf += testBufRecharge * wait - if err := p2p.ExpectMsg(server.peer.app, ResumeMsg, expBuf); err != nil { + sf := stateFeedback{ + protocolVersion: protocol, + stateFeedbackV4: stateFeedbackV4{ + BV: expBuf, + RealCost: 0, + TokenBalance: 0, + }, + } + if err := p2p.ExpectMsg(server.peer.app, ResumeMsg, sf); err != nil { t.Errorf("expected ResumeMsg and failed: %v", err) } } diff --git a/les/metrics.go b/les/metrics.go index 9ef8c365180c..12780346b65c 100644 --- a/les/metrics.go +++ b/les/metrics.go @@ -40,6 +40,8 @@ var ( miscInTxsTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic/txs", nil) miscInTxStatusPacketsMeter = metrics.NewRegisteredMeter("les/misc/in/packets/txStatus", nil) miscInTxStatusTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic/txStatus", nil) + miscInLespayPacketsMeter = metrics.NewRegisteredMeter("les/misc/in/packets/lespay", nil) + miscInLespayTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic/lespay", nil) miscOutPacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets/total", nil) miscOutTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/total", nil) @@ -59,6 +61,8 @@ var ( miscOutTxsTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/txs", nil) miscOutTxStatusPacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets/txStatus", nil) miscOutTxStatusTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/txStatus", nil) + miscOutLespayPacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets/lespay", nil) + miscOutLespayTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/lespay", nil) miscServingTimeHeaderTimer = metrics.NewRegisteredTimer("les/misc/serve/header", nil) miscServingTimeBodyTimer = metrics.NewRegisteredTimer("les/misc/serve/body", nil) @@ -68,6 +72,7 @@ var ( miscServingTimeHelperTrieTimer = metrics.NewRegisteredTimer("les/misc/serve/helperTrie", nil) miscServingTimeTxTimer = metrics.NewRegisteredTimer("les/misc/serve/txs", nil) miscServingTimeTxStatusTimer = metrics.NewRegisteredTimer("les/misc/serve/txStatus", nil) + miscServingTimeLespayTimer = metrics.NewRegisteredTimer("les/misc/serve/lespay", nil) connectionTimer = metrics.NewRegisteredTimer("les/connection/duration", nil) serverConnectionGauge = metrics.NewRegisteredGauge("les/connection/server", nil) diff --git a/les/odr_test.go b/les/odr_test.go index 7d10878226d0..b6836a057fe2 100644 --- a/les/odr_test.go +++ b/les/odr_test.go @@ -38,8 +38,8 @@ import ( type odrTestFn func(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte -func TestOdrGetBlockLes2(t *testing.T) { testOdr(t, 2, 1, true, odrGetBlock) } func TestOdrGetBlockLes3(t *testing.T) { testOdr(t, 3, 1, true, odrGetBlock) } +func TestOdrGetBlockLes4(t *testing.T) { testOdr(t, 4, 1, true, odrGetBlock) } func odrGetBlock(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte { var block *types.Block @@ -55,8 +55,8 @@ func odrGetBlock(ctx context.Context, db ethdb.Database, config *params.ChainCon return rlp } -func TestOdrGetReceiptsLes2(t *testing.T) { testOdr(t, 2, 1, true, odrGetReceipts) } func TestOdrGetReceiptsLes3(t *testing.T) { testOdr(t, 3, 1, true, odrGetReceipts) } +func TestOdrGetReceiptsLes4(t *testing.T) { testOdr(t, 4, 1, true, odrGetReceipts) } func odrGetReceipts(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte { var receipts types.Receipts @@ -76,8 +76,8 @@ func odrGetReceipts(ctx context.Context, db ethdb.Database, config *params.Chain return rlp } -func TestOdrAccountsLes2(t *testing.T) { testOdr(t, 2, 1, true, odrAccounts) } func TestOdrAccountsLes3(t *testing.T) { testOdr(t, 3, 1, true, odrAccounts) } +func TestOdrAccountsLes4(t *testing.T) { testOdr(t, 4, 1, true, odrAccounts) } func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte { dummyAddr := common.HexToAddress("1234567812345678123456781234567812345678") @@ -105,8 +105,8 @@ func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainCon return res } -func TestOdrContractCallLes2(t *testing.T) { testOdr(t, 2, 2, true, odrContractCall) } func TestOdrContractCallLes3(t *testing.T) { testOdr(t, 3, 2, true, odrContractCall) } +func TestOdrContractCallLes4(t *testing.T) { testOdr(t, 4, 2, true, odrContractCall) } type callmsg struct { types.Message @@ -155,8 +155,8 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai return res } -func TestOdrTxStatusLes2(t *testing.T) { testOdr(t, 2, 1, false, odrTxStatus) } func TestOdrTxStatusLes3(t *testing.T) { testOdr(t, 3, 1, false, odrTxStatus) } +func TestOdrTxStatusLes4(t *testing.T) { testOdr(t, 4, 1, false, odrTxStatus) } func odrTxStatus(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte { var txs types.Transactions @@ -236,7 +236,7 @@ func testOdr(t *testing.T, protocol int, expFail uint64, checkCached bool, fn od // still expect all retrievals to pass, now data should be cached locally if checkCached { - client.handler.backend.peers.Unregister(client.peer.peer.id) + client.handler.backend.peers.Disconnect(client.peer.peer.id) time.Sleep(time.Millisecond * 10) // ensure that all peerSetNotify callbacks are executed test(5) } diff --git a/les/peer.go b/les/peer.go index 801147df6baf..602ede7c9399 100644 --- a/les/peer.go +++ b/les/peer.go @@ -101,6 +101,9 @@ type peer struct { responseCount uint64 invalidCount uint32 + active bool + activate, deactivate func() + poolEntry *poolEntry hasBlock func(common.Hash, uint64, bool) bool responseErrors int @@ -113,6 +116,8 @@ type peer struct { fcParams flowcontrol.ServerParams fcCosts requestCostTable + getBalance func() posBalance + trusted, server bool onlyAnnounce bool chainSince, chainRecent uint64 @@ -198,7 +203,19 @@ func (p *peer) freezeClient() { time.Sleep(freezeCheckPeriod) } else { atomic.StoreUint32(&p.frozen, 0) - p.SendResume(bufValue) + var balance uint64 + if p.getBalance != nil { + balance = p.getBalance().value + } + sf := stateFeedback{ + protocolVersion: p.version, + stateFeedbackV4: stateFeedbackV4{ + BV: bufValue, + RealCost: 0, + TokenBalance: balance, + }, + } + p.SendResume(sf) break } } @@ -283,12 +300,20 @@ func (p *peer) updateCapacity(cap uint64) { p.responseLock.Lock() defer p.responseLock.Unlock() - p.fcParams = flowcontrol.ServerParams{MinRecharge: cap, BufLimit: cap * bufLimitRatio} - p.fcClient.UpdateParams(p.fcParams) - var kvList keyValueList - kvList = kvList.add("flowControl/MRR", cap) - kvList = kvList.add("flowControl/BL", cap*bufLimitRatio) - p.queueSend(func() { p.SendAnnounce(announceData{Update: kvList}) }) + if !p.active && cap != 0 && p.activate != nil { + p.activate() + } + if cap != 0 || p.version >= lpv4 { + p.fcParams = flowcontrol.ServerParams{MinRecharge: cap, BufLimit: cap * bufLimitRatio} + p.fcClient.UpdateParams(p.fcParams) + var kvList keyValueList + kvList = kvList.add("flowControl/BL", cap*bufLimitRatio) + kvList = kvList.add("flowControl/MRR", cap) + p.queueSend(func() { p.SendAnnounce(announceData{Update: kvList}) }) + } + if p.active && cap == 0 && p.deactivate != nil { + p.deactivate() + } } func (p *peer) responseID() uint64 { @@ -314,12 +339,13 @@ type reply struct { } // send sends the reply with the calculated buffer value -func (r *reply) send(bv uint64) error { +func (r *reply) send(sf stateFeedback) error { type resp struct { - ReqID, BV uint64 - Data rlp.RawValue + ReqID uint64 + SF stateFeedback + Data rlp.RawValue } - return p2p.Send(r.w, r.msgcode, resp{r.reqID, bv, r.data}) + return p2p.Send(r.w, r.msgcode, resp{r.reqID, sf, r.data}) } // size returns the RLP encoded size of the message data @@ -394,8 +420,8 @@ func (p *peer) SendStop() error { } // SendResume notifies the client about getting out of frozen state -func (p *peer) SendResume(bv uint64) error { - return p2p.Send(p.rw, ResumeMsg, bv) +func (p *peer) SendResume(sf stateFeedback) error { + return p2p.Send(p.rw, ResumeMsg, sf) } // ReplyBlockHeaders creates a reply with a batch of block headers @@ -501,6 +527,18 @@ func (p *peer) SendTxs(reqID, cost uint64, txs rlp.RawValue) error { return sendRequest(p.rw, SendTxV2Msg, reqID, cost, txs) } +// SendLespay sends a set of commands to the service token sale module +func (p *peer) SendLespay(reqID uint64, cmd []byte) error { + p.Log().Debug("Sending batch of lespay commands", "size", len(cmd)) + return sendRequest(p.rw, LespayMsg, reqID, 0, cmd) +} + +// ReplyLespay sends a set of replies to lespay commands +func (p *peer) ReplyLespay(reqID uint64, reply []byte) error { + p.Log().Debug("Sending batch of lespay replies", "size", len(reply)) + return sendRequest(p.rw, LespayReplyMsg, reqID, 0, reply) +} + type keyValueEntry struct { Key string Value rlp.RawValue @@ -601,8 +639,15 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis send = send.add("serveRecentState", stateRecent) send = send.add("txRelay", nil) } - send = send.add("flowControl/BL", server.defParams.BufLimit) - send = send.add("flowControl/MRR", server.defParams.MinRecharge) + + p.active = p.version < lpv4 + if p.active { + p.fcParams = server.defParams + } else { + p.fcParams = flowcontrol.ServerParams{} + } + send = send.add("flowControl/BL", p.fcParams.BufLimit) + send = send.add("flowControl/MRR", p.fcParams.MinRecharge) var costList RequestCostList if server.costTracker.testCostList != nil { @@ -612,7 +657,6 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis } send = send.add("flowControl/MRC", costList) p.fcCosts = costList.decode(ProtocolLengths[uint(p.version)]) - p.fcParams = server.defParams // Add advertised checkpoint and register block height which // client can verify the checkpoint validity. @@ -683,7 +727,7 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis // set default announceType on server side p.announceType = announceTypeSimple } - p.fcClient = flowcontrol.NewClientNode(server.fcManager, server.defParams) + p.fcClient = flowcontrol.NewClientNode(server.fcManager, p.fcParams) } } else { if recv.get("serveChainSince", &p.chainSince) != nil { @@ -720,6 +764,7 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis p.fcParams = sParams p.fcServer = flowcontrol.NewServerNode(sParams, &mclock.System{}) p.fcCosts = MRC.decode(ProtocolLengths[uint(p.version)]) + p.active = p.paramsUseful() recv.get("checkpoint/value", &p.checkpoint) recv.get("checkpoint/registerHeight", &p.checkpointNumber) @@ -745,10 +790,12 @@ func (p *peer) updateFlowControl(update keyValueMap) { } // If any of the flow control params is nil, refuse to update. var params flowcontrol.ServerParams + updated := false if update.get("flowControl/BL", ¶ms.BufLimit) == nil && update.get("flowControl/MRR", ¶ms.MinRecharge) == nil { // todo can light client set a minimal acceptable flow control params? p.fcParams = params p.fcServer.UpdateParams(params) + updated = true } var MRC RequestCostList if update.get("flowControl/MRC", &MRC) == nil { @@ -756,9 +803,18 @@ func (p *peer) updateFlowControl(update keyValueMap) { for code, cost := range costUpdate { p.fcCosts[code] = cost } + updated = true + } + if updated { + p.active = p.paramsUseful() } } +func (p *peer) paramsUseful() bool { + reqRecharge, reqBufLimit := p.fcCosts.reqParams() + return p.fcParams.MinRecharge >= reqRecharge && p.fcParams.BufLimit >= reqBufLimit +} + // String implements fmt.Stringer. func (p *peer) String() string { return fmt.Sprintf("Peer %s [%s]", p.id, @@ -776,16 +832,17 @@ type peerSetNotify interface { // peerSet represents the collection of active peers currently participating in // the Light Ethereum sub-protocol. type peerSet struct { - peers map[string]*peer - lock sync.RWMutex - notifyList []peerSetNotify - closed bool + active, inactive map[string]*peer + lock sync.RWMutex + notifyList []peerSetNotify + closed bool } // newPeerSet creates a new peer set to track the active participants. func newPeerSet() *peerSet { return &peerSet{ - peers: make(map[string]*peer), + active: make(map[string]*peer), + inactive: make(map[string]*peer), } } @@ -793,8 +850,8 @@ func newPeerSet() *peerSet { func (ps *peerSet) notify(n peerSetNotify) { ps.lock.Lock() ps.notifyList = append(ps.notifyList, n) - peers := make([]*peer, 0, len(ps.peers)) - for _, p := range ps.peers { + peers := make([]*peer, 0, len(ps.active)) + for _, p := range ps.active { peers = append(peers, p) } ps.lock.Unlock() @@ -812,12 +869,17 @@ func (ps *peerSet) Register(p *peer) error { ps.lock.Unlock() return errClosed } - if _, ok := ps.peers[p.id]; ok { + if _, ok := ps.active[p.id]; ok { ps.lock.Unlock() return errAlreadyRegistered } - ps.peers[p.id] = p - p.sendQueue = newExecQueue(100) + if _, ok := ps.inactive[p.id]; ok { + delete(ps.inactive, p.id) + } else { + p.sendQueue = newExecQueue(100) + } + ps.active[p.id] = p + peers := make([]peerSetNotify, len(ps.notifyList)) copy(peers, ps.notifyList) ps.lock.Unlock() @@ -829,14 +891,15 @@ func (ps *peerSet) Register(p *peer) error { } // Unregister removes a remote peer from the active set, disabling any further -// actions to/from that particular entity. It also initiates disconnection at the networking layer. -func (ps *peerSet) Unregister(id string) error { +// actions to/from that particular entity. +func (ps *peerSet) Unregister(p *peer) error { ps.lock.Lock() - if p, ok := ps.peers[id]; !ok { + if _, ok := ps.active[p.id]; !ok { ps.lock.Unlock() return errNotRegistered } else { - delete(ps.peers, id) + delete(ps.active, p.id) + ps.inactive[p.id] = p peers := make([]peerSetNotify, len(ps.notifyList)) copy(peers, ps.notifyList) ps.lock.Unlock() @@ -844,22 +907,47 @@ func (ps *peerSet) Unregister(id string) error { for _, n := range peers { n.unregisterPeer(p) } + return nil + } +} - p.sendQueue.quit() - p.Peer.Disconnect(p2p.DiscUselessPeer) +// Disconnect removes a remote peer from either the active or inactive set and +// initiates disconnection at the networking layer. +func (ps *peerSet) Disconnect(id string) error { + ps.lock.Lock() - return nil + var ( + peers []peerSetNotify + p *peer + ok bool + ) + if p, ok = ps.active[id]; ok { + delete(ps.active, p.id) + peers = make([]peerSetNotify, len(ps.notifyList)) + copy(peers, ps.notifyList) + } else if p, ok = ps.inactive[id]; ok { + delete(ps.inactive, id) + } else { + ps.lock.Unlock() + return errNotRegistered + } + ps.lock.Unlock() + for _, n := range peers { + n.unregisterPeer(p) } + p.sendQueue.quit() + p.Peer.Disconnect(p2p.DiscUselessPeer) + return nil } -// AllPeerIDs returns a list of all registered peer IDs +// AllPeerIDs returns a list of all active peer IDs func (ps *peerSet) AllPeerIDs() []string { ps.lock.RLock() defer ps.lock.RUnlock() - res := make([]string, len(ps.peers)) + res := make([]string, len(ps.active)) idx := 0 - for id := range ps.peers { + for id := range ps.active { res[idx] = id idx++ } @@ -871,15 +959,18 @@ func (ps *peerSet) Peer(id string) *peer { ps.lock.RLock() defer ps.lock.RUnlock() - return ps.peers[id] + if p, ok := ps.active[id]; ok { + return p + } + return ps.inactive[id] } -// Len returns if the current number of peers in the set. +// Len returns if the current number of peers in the active set. func (ps *peerSet) Len() int { ps.lock.RLock() defer ps.lock.RUnlock() - return len(ps.peers) + return len(ps.active) } // BestPeer retrieves the known peer with the currently highest total difficulty. @@ -891,7 +982,7 @@ func (ps *peerSet) BestPeer() *peer { bestPeer *peer bestTd *big.Int ) - for _, p := range ps.peers { + for _, p := range ps.active { if td := p.Td(); bestPeer == nil || td.Cmp(bestTd) > 0 { bestPeer, bestTd = p, td } @@ -899,14 +990,14 @@ func (ps *peerSet) BestPeer() *peer { return bestPeer } -// AllPeers returns all peers in a list +// AllPeers returns all active peers in a list func (ps *peerSet) AllPeers() []*peer { ps.lock.RLock() defer ps.lock.RUnlock() - list := make([]*peer, len(ps.peers)) + list := make([]*peer, len(ps.active)) i := 0 - for _, peer := range ps.peers { + for _, peer := range ps.active { list[i] = peer i++ } @@ -919,7 +1010,10 @@ func (ps *peerSet) Close() { ps.lock.Lock() defer ps.lock.Unlock() - for _, p := range ps.peers { + for _, p := range ps.active { + p.Disconnect(p2p.DiscQuitting) + } + for _, p := range ps.inactive { p.Disconnect(p2p.DiscQuitting) } ps.closed = true diff --git a/les/protocol.go b/les/protocol.go index 36af88aea6d0..f1fd3520a6da 100644 --- a/les/protocol.go +++ b/les/protocol.go @@ -33,17 +33,18 @@ import ( const ( lpv2 = 2 lpv3 = 3 + lpv4 = 4 ) // Supported versions of the les protocol (first is primary) var ( - ClientProtocolVersions = []uint{lpv2, lpv3} - ServerProtocolVersions = []uint{lpv2, lpv3} + ClientProtocolVersions = []uint{lpv2, lpv3, lpv4} + ServerProtocolVersions = []uint{lpv2, lpv3, lpv4} AdvertiseProtocolVersions = []uint{lpv2} // clients are searching for the first advertised protocol in the list ) // Number of implemented message corresponding to different protocol versions. -var ProtocolLengths = map[uint]uint64{lpv2: 22, lpv3: 24} +var ProtocolLengths = map[uint]uint64{lpv2: 22, lpv3: 24, lpv4: 26} const ( NetworkId = 1 @@ -74,6 +75,9 @@ const ( // Protocol messages introduced in LPV3 StopMsg = 0x16 ResumeMsg = 0x17 + // Protocol messages introduced in LPV4 + LespayMsg = 0x18 + LespayReplyMsg = 0x19 ) type requestInfo struct { @@ -235,3 +239,28 @@ func (hn *hashOrNumber) DecodeRLP(s *rlp.Stream) error { type CodeData []struct { Value []byte } + +type stateFeedbackV4 struct { + BV, RealCost, TokenBalance uint64 +} + +type stateFeedback struct { + protocolVersion int + stateFeedbackV4 +} + +func (sf stateFeedback) EncodeRLP(w io.Writer) error { + if sf.protocolVersion >= lpv4 { + return rlp.Encode(w, sf.stateFeedbackV4) + } else { + return rlp.Encode(w, sf.BV) + } +} + +func (sf *stateFeedback) DecodeRLP(s *rlp.Stream) error { + if sf.protocolVersion >= lpv4 { + return s.Decode(&sf.stateFeedbackV4) + } else { + return s.Decode(&sf.BV) + } +} diff --git a/les/request_test.go b/les/request_test.go index 8d09703c57ef..9037fd5423ec 100644 --- a/les/request_test.go +++ b/les/request_test.go @@ -36,22 +36,22 @@ func secAddr(addr common.Address) []byte { type accessTestFn func(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest -func TestBlockAccessLes2(t *testing.T) { testAccess(t, 2, tfBlockAccess) } func TestBlockAccessLes3(t *testing.T) { testAccess(t, 3, tfBlockAccess) } +func TestBlockAccessLes4(t *testing.T) { testAccess(t, 4, tfBlockAccess) } func tfBlockAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest { return &light.BlockRequest{Hash: bhash, Number: number} } -func TestReceiptsAccessLes2(t *testing.T) { testAccess(t, 2, tfReceiptsAccess) } func TestReceiptsAccessLes3(t *testing.T) { testAccess(t, 3, tfReceiptsAccess) } +func TestReceiptsAccessLes4(t *testing.T) { testAccess(t, 4, tfReceiptsAccess) } func tfReceiptsAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest { return &light.ReceiptsRequest{Hash: bhash, Number: number} } -func TestTrieEntryAccessLes2(t *testing.T) { testAccess(t, 2, tfTrieEntryAccess) } func TestTrieEntryAccessLes3(t *testing.T) { testAccess(t, 3, tfTrieEntryAccess) } +func TestTrieEntryAccessLes4(t *testing.T) { testAccess(t, 4, tfTrieEntryAccess) } func tfTrieEntryAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest { if number := rawdb.ReadHeaderNumber(db, bhash); number != nil { @@ -60,8 +60,8 @@ func tfTrieEntryAccess(db ethdb.Database, bhash common.Hash, number uint64) ligh return nil } -func TestCodeAccessLes2(t *testing.T) { testAccess(t, 2, tfCodeAccess) } func TestCodeAccessLes3(t *testing.T) { testAccess(t, 3, tfCodeAccess) } +func TestCodeAccessLes4(t *testing.T) { testAccess(t, 4, tfCodeAccess) } func tfCodeAccess(db ethdb.Database, bhash common.Hash, num uint64) light.OdrRequest { number := rawdb.ReadHeaderNumber(db, bhash) diff --git a/les/retrieve.go b/les/retrieve.go index c806117902ad..a0d494b6c390 100644 --- a/les/retrieve.go +++ b/les/retrieve.go @@ -345,7 +345,7 @@ func (r *sentReq) tryRequest() { if hrto { pp.Log().Debug("Request timed out hard") if r.rm.peers != nil { - r.rm.peers.Unregister(pp.id) + r.rm.peers.Disconnect(pp.id) } } diff --git a/les/server.go b/les/server.go index e68903dd81da..9d9fba3876ff 100644 --- a/les/server.go +++ b/les/server.go @@ -42,6 +42,7 @@ type LesServer struct { handler *serverHandler lesTopics []discv5.Topic privateKey *ecdsa.PrivateKey + srvr *p2p.Server // Flow control and capacity management fcManager *flowcontrol.ClientManager @@ -49,6 +50,7 @@ type LesServer struct { defParams flowcontrol.ServerParams servingQueue *servingQueue clientPool *clientPool + tokenSale *tokenSale minCapacity, maxCapacity, freeCapacity uint64 threadsIdle int // Request serving threads count when system is idle. @@ -114,8 +116,9 @@ func NewLesServer(e *eth.Ethereum, config *eth.Config) (*LesServer, error) { srv.maxCapacity = totalRecharge } srv.fcManager.SetCapacityLimits(srv.freeCapacity, srv.maxCapacity, srv.freeCapacity*2) - srv.clientPool = newClientPool(srv.chainDb, srv.freeCapacity, mclock.System{}, func(id enode.ID) { go srv.peers.Unregister(peerIdToString(id)) }) + srv.clientPool = newClientPool(srv.chainDb, srv.minCapacity, srv.freeCapacity, mclock.System{}, func(id enode.ID) { go srv.peers.Disconnect(peerIdToString(id)) }) srv.clientPool.setDefaultFactors(priceFactors{0, 1, 1}, priceFactors{0, 1, 1}) + srv.tokenSale = newTokenSale(srv.clientPool, 0.1) checkpoint := srv.latestLocalCheckpoint() if !checkpoint.Empty() { @@ -146,6 +149,12 @@ func (s *LesServer) APIs() []rpc.API { Service: NewPrivateDebugAPI(s), Public: false, }, + { + Namespace: "lespay", + Version: "1.0", + Service: NewPrivateLespayAPI(s.lesCommons.peers, nil, s.srvr.DiscV5, s.tokenSale), + Public: false, + }, } } @@ -165,6 +174,7 @@ func (s *LesServer) Protocols() []p2p.Protocol { // Start starts the LES server func (s *LesServer) Start(srvr *p2p.Server) { + s.srvr = srvr s.privateKey = srvr.PrivateKey s.handler.start() @@ -172,6 +182,7 @@ func (s *LesServer) Start(srvr *p2p.Server) { go s.capacityManagement() if srvr.DiscV5 != nil { + srvr.DiscV5.RegisterTalkHandler("lespay", s.handler.talkRequestHandler) for _, topic := range s.lesTopics { topic := topic go func() { @@ -189,6 +200,11 @@ func (s *LesServer) Start(srvr *p2p.Server) { func (s *LesServer) Stop() { close(s.closeCh) + if s.srvr.DiscV5 != nil { + s.srvr.DiscV5.RemoveTalkHandler("lespay") + } + s.tokenSale.stop() + // Disconnect existing sessions. // This also closes the gate for any new registrations on the peer set. // sessions which are already established but not added to pm.peers yet diff --git a/les/server_handler.go b/les/server_handler.go index 4b505c2bc490..f1a798eb676d 100644 --- a/les/server_handler.go +++ b/les/server_handler.go @@ -20,6 +20,9 @@ import ( "encoding/binary" "encoding/json" "errors" + "fmt" + "net" + "reflect" "sync" "sync/atomic" "time" @@ -35,6 +38,7 @@ import ( "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" ) @@ -54,10 +58,7 @@ const ( MaxTxStatus = 256 // Amount of transactions to queried per request ) -var ( - errTooManyInvalidRequest = errors.New("too many invalid requests made") - errFullClientPool = errors.New("client pool is full") -) +var errTooManyInvalidRequest = errors.New("too many invalid requests made") // serverHandler is responsible for serving light client and process // all incoming light requests. @@ -102,6 +103,9 @@ func (h *serverHandler) stop() { // runPeer is the p2p protocol run function for the given version. func (h *serverHandler) runPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) error { peer := newPeer(int(version), h.server.config.NetworkId, false, p, newMeteredMsgWriter(rw, int(version))) + peer.getBalance = func() posBalance { + return h.server.clientPool.getPosBalance(p.ID()) + } h.wg.Add(1) defer h.wg.Done() return h.handle(peer) @@ -128,32 +132,62 @@ func (h *serverHandler) handle(p *peer) error { } // Reject light clients if server is not synced. if !h.synced() { - return p2p.DiscRequested + //return p2p.DiscRequested } defer p.fcClient.Disconnect() - // Disconnect the inbound peer if it's rejected by clientPool - if !h.server.clientPool.connect(p, 0) { - p.Log().Debug("Light Ethereum peer registration failed", "err", errFullClientPool) - return errFullClientPool + var ( + connectedAt mclock.AbsTime + wg *sync.WaitGroup // Wait group used to track all in-flight task routines. + ) + p.activate = func() { + // Register the peer locally + if err := h.server.peers.Register(p); err != nil { + h.server.clientPool.disconnect(p) + p.Log().Error("Light Ethereum peer registration failed", "err", err) + return + } + clientConnectionGauge.Update(int64(h.server.peers.Len())) + connectedAt = mclock.Now() + wg = new(sync.WaitGroup) + p.active = true } - // Register the peer locally - if err := h.server.peers.Register(p); err != nil { - h.server.clientPool.disconnect(p) - p.Log().Error("Light Ethereum peer registration failed", "err", err) - return err + p.deactivate = func() { + h.server.peers.Unregister(p) + if p.version < lpv4 { + h.server.peers.Disconnect(p.id) + } + clientConnectionGauge.Update(int64(h.server.peers.Len())) + connectionTimer.Update(time.Duration(mclock.Now() - connectedAt)) + p.active = false + } + if p.active { + p.activate() } - clientConnectionGauge.Update(int64(h.server.peers.Len())) - var wg sync.WaitGroup // Wait group used to track all in-flight task routines. + if capacity, err := h.server.clientPool.connect(p, 0); err != nil { + // Disconnect the inbound peer if it's rejected by clientPool + p.Log().Debug("Light Ethereum peer registration failed", "err", err) + return err + } else if capacity != p.fcParams.MinRecharge { + if p.version < lpv4 { + h.server.peers.Disconnect(p.id) + } else { + p.updateCapacity(capacity) + } + } - connectedAt := mclock.Now() defer func() { wg.Wait() // Ensure all background task routines have exited. - h.server.peers.Unregister(p.id) h.server.clientPool.disconnect(p) - clientConnectionGauge.Update(int64(h.server.peers.Len())) - connectionTimer.Update(time.Duration(mclock.Now() - connectedAt)) + p.responseLock.Lock() + if p.active { + p.deactivate() + } + p.activate = nil + p.deactivate = nil + p.responseLock.Unlock() + h.server.peers.Disconnect(p.id) }() // Spawn a main loop to handle all incoming messages. @@ -164,7 +198,7 @@ func (h *serverHandler) handle(p *peer) error { return err default: } - if err := h.handleMsg(p, &wg); err != nil { + if err := h.handleMsg(p, wg); err != nil { p.Log().Debug("Light Ethereum message handling failed", "err", err) return err } @@ -243,22 +277,33 @@ func (h *serverHandler) handleMsg(p *peer, wg *sync.WaitGroup) error { if reply != nil { replySize = reply.size() } - var realCost uint64 + var realCost, balance uint64 if h.server.costTracker.testing { realCost = maxCost // Assign a fake cost for testing purpose } else { realCost = h.server.costTracker.realCost(servingTime, msg.Size, replySize) + if realCost > maxCost { + realCost = maxCost + } } bv := p.fcClient.RequestProcessed(reqID, responseCount, maxCost, realCost) if amount != 0 { // Feed cost tracker request serving statistic. h.server.costTracker.updateStats(msg.Code, amount, servingTime, realCost) // Reduce priority "balance" for the specific peer. - h.server.clientPool.requestCost(p, realCost) + balance = h.server.clientPool.requestCost(p, realCost) + } + sf := stateFeedback{ + protocolVersion: p.version, + stateFeedbackV4: stateFeedbackV4{ + BV: bv, + RealCost: realCost, + TokenBalance: balance, + }, } if reply != nil { p.queueSend(func() { - if err := reply.send(bv); err != nil { + if err := reply.send(sf); err != nil { select { case p.errCh <- err: default: @@ -373,7 +418,7 @@ func (h *serverHandler) handleMsg(p *peer, wg *sync.WaitGroup) error { first = false } reply := p.ReplyBlockHeaders(req.ReqID, headers) - sendResponse(req.ReqID, query.Amount, p.ReplyBlockHeaders(req.ReqID, headers), task.done()) + sendResponse(req.ReqID, query.Amount, reply, task.done()) if metrics.EnabledExpensive { miscOutHeaderPacketsMeter.Mark(1) miscOutHeaderTrafficMeter.Mark(int64(reply.size())) @@ -822,6 +867,38 @@ func (h *serverHandler) handleMsg(p *peer, wg *sync.WaitGroup) error { } }() } + case LespayMsg: + p.Log().Trace("Received transaction status query request") + if metrics.EnabledExpensive { + miscInLespayPacketsMeter.Mark(1) + miscInLespayTrafficMeter.Mark(int64(msg.Size)) + defer func(start time.Time) { miscServingTimeLespayTimer.UpdateSince(start) }(time.Now()) + } + var req struct { + ReqID uint64 + Cmd []byte + } + if err := msg.Decode(&req); err != nil { + clientErrorMeter.Mark(1) + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + if !h.server.tokenSale.queueCommand(p.id, tokenCmd{ + cmd: req.Cmd, + id: p.ID(), + freeID: p.freeClientId(), + send: func(reply []byte) { + if metrics.EnabledExpensive { + miscOutLespayPacketsMeter.Mark(1) + miscOutLespayTrafficMeter.Mark(int64(len(reply))) + } + p.queueSend(func() { + p.ReplyLespay(req.ReqID, reply) + }) + }, + }, p.fcClient.Params().MinRecharge) { + clientErrorMeter.Mark(1) + return errResp(ErrRequestRejected, "") + } default: p.Log().Trace("Received invalid message", "code", msg.Code) @@ -953,3 +1030,41 @@ func (h *serverHandler) broadcastHeaders() { } } } + +func (h *serverHandler) talkRequestHandler(id enode.ID, addr *net.UDPAddr, payload interface{}) (interface{}, bool) { + fmt.Println("talkRequestHandler", id, addr, payload, reflect.TypeOf(payload)) + c, ok := payload.([]interface{}) + if !ok { + return nil, false + } + resultCh := make(chan []byte, len(c)) + results := make([][]byte, len(c)) + for _, c := range c { + cmd, ok := c.([]byte) + if !ok { + fmt.Println("type err", reflect.TypeOf(c)) + return nil, false + } + if !h.server.tokenSale.queueCommand(id.String(), tokenCmd{ + cmd: cmd, + id: id, + freeID: addr.IP.String(), + send: func(reply []byte) { + resultCh <- reply + }, + }, h.server.freeCapacity) { + fmt.Println("failed to queue") + return nil, false + } + } + + for i, _ := range results { + select { + case results[i] = <-resultCh: + case <-h.closeCh: + return nil, false + } + } + fmt.Println("results", results) + return results, true +} diff --git a/les/servingqueue.go b/les/servingqueue.go index 8842cf9e9d55..487e6046a501 100644 --- a/les/servingqueue.go +++ b/les/servingqueue.go @@ -70,7 +70,7 @@ type runToken chan struct{} // start blocks until the task can start and returns true if it is allowed to run. // Returning false means that the task should be cancelled. func (t *servingTask) start() bool { - if t.peer.isFrozen() { + if t.peer != nil && t.peer.isFrozen() { return false } t.tokenCh = make(chan runToken, 1) @@ -289,7 +289,7 @@ func (sq *servingQueue) addTask(task *servingTask) { sq.queuedTime += task.expTime sqServedGauge.Update(int64(sq.recentTime)) sqQueuedGauge.Update(int64(sq.queuedTime)) - if sq.recentTime+sq.queuedTime > sq.burstLimit { + if sq.burstLimit != 0 && sq.recentTime+sq.queuedTime > sq.burstLimit { sq.freezePeers() } } diff --git a/les/test_helper.go b/les/test_helper.go index ee3d7a32e136..3fd0844cbaa4 100644 --- a/les/test_helper.go +++ b/les/test_helper.go @@ -77,10 +77,10 @@ var ( processConfirms = big.NewInt(1) // The token bucket buffer limit for testing purpose. - testBufLimit = uint64(1000000) + testBufLimit = uint64(6000) // The buffer recharging speed for testing purpose. - testBufRecharge = uint64(1000) + testBufRecharge = uint64(1) ) /* @@ -280,7 +280,7 @@ func newTestServerHandler(blocks int, indexers []*core.ChainIndexer, db ethdb.Da } server.costTracker, server.freeCapacity = newCostTracker(db, server.config) server.costTracker.testCostList = testCostList(0) // Disable flow control mechanism. - server.clientPool = newClientPool(db, 1, clock, nil) + server.clientPool = newClientPool(db, 1, 1, clock, nil) server.clientPool.setLimits(10000, 10000) // Assign enough capacity for clientpool server.handler = newServerHandler(server, simulation.Blockchain(), db, txpool, func() bool { return true }) if server.oracle != nil { @@ -393,8 +393,13 @@ func (p *testPeer) handshake(t *testing.T, td *big.Int, head common.Hash, headNu expList = expList.add("serveStateSince", uint64(0)) expList = expList.add("serveRecentState", uint64(core.TriesInMemory-4)) expList = expList.add("txRelay", nil) - expList = expList.add("flowControl/BL", testBufLimit) - expList = expList.add("flowControl/MRR", testBufRecharge) + if p.peer.version >= lpv4 { + expList = expList.add("flowControl/BL", uint64(0)) + expList = expList.add("flowControl/MRR", uint64(0)) + } else { + expList = expList.add("flowControl/BL", testBufLimit) + expList = expList.add("flowControl/MRR", testBufRecharge) + } expList = expList.add("flowControl/MRC", costList) if err := p2p.ExpectMsg(p.app, StatusMsg, expList); err != nil { @@ -403,9 +408,16 @@ func (p *testPeer) handshake(t *testing.T, td *big.Int, head common.Hash, headNu if err := p2p.Send(p.app, StatusMsg, sendList); err != nil { t.Fatalf("status send: %v", err) } - p.peer.fcParams = flowcontrol.ServerParams{ - BufLimit: testBufLimit, - MinRecharge: testBufRecharge, +} + +func (p *testPeer) expectCapUpdate(t *testing.T) { + if p.peer.version >= lpv4 { + var expList keyValueList + expList = expList.add("flowControl/BL", testBufLimit) + expList = expList.add("flowControl/MRR", testBufRecharge) + if err := p2p.ExpectMsg(p.app, AnnounceMsg, announceData{Update: expList}); err != nil { + t.Fatalf("status recv: %v", err) + } } } @@ -436,7 +448,7 @@ type testServer struct { bloomTrieIndexer *core.ChainIndexer } -func newServerEnv(t *testing.T, blocks int, protocol int, callback indexerCallback, simClock bool, newPeer bool, testCost uint64) (*testServer, func()) { +func newServerEnv(t *testing.T, blocks int, protocol int, callback indexerCallback, simClock bool, newPeer bool, testCost uint64, expectCapUpdate bool) (*testServer, func()) { db := rawdb.NewMemoryDatabase() indexers := testIndexers(db, nil, light.TestServerIndexerConfig) @@ -477,6 +489,9 @@ func newServerEnv(t *testing.T, blocks int, protocol int, callback indexerCallba cIndexer.Close() bIndexer.Close() } + if expectCapUpdate { + server.peer.expectCapUpdate(t) + } return server, teardown } diff --git a/les/tokensale.go b/les/tokensale.go new file mode 100644 index 000000000000..310b23e18145 --- /dev/null +++ b/les/tokensale.go @@ -0,0 +1,570 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package les + +import ( + "fmt" + "io" + "math" + "sync" + "time" + + "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/rlp" +) + +const ( + basePriceTC = time.Hour * 10 + tokenQueueTC = time.Hour +) + +type paymentReceiver interface { + info() keyValueList + receivePayment(from enode.ID, proofOfPayment, oldMeta []byte) (value uint64, newMeta []byte, err error) + requestPayment(from enode.ID, value uint64, meta []byte) uint64 +} + +type tokenSale struct { + lock, qlock sync.Mutex + clientPool *clientPool + stopCh chan struct{} + receivers map[string]paymentReceiver + receiverNames []string + basePrice, minBasePrice float64 + + sq *servingQueue + sources map[string]*cmdSource +} + +func newTokenSale(clientPool *clientPool, minBasePrice float64) *tokenSale { + t := &tokenSale{ + clientPool: clientPool, + receivers: make(map[string]paymentReceiver), + basePrice: minBasePrice, + minBasePrice: minBasePrice, + stopCh: make(chan struct{}), + sq: newServingQueue(0, 0), + sources: make(map[string]*cmdSource), + } + t.sq.setThreads(1) + go func() { + cleanupCounter := 0 + for { + select { + case <-time.After(time.Second * 10): + t.lock.Lock() + cost, ok := t.tokenCost(1) + if cost > t.basePrice*10 || !ok { + cost = t.basePrice * 10 + } + t.basePrice += (cost - t.basePrice) * float64(time.Second*10) / float64(basePriceTC) + if t.basePrice < minBasePrice { + t.basePrice = minBasePrice + } + t.lock.Unlock() + + cleanupCounter++ + if cleanupCounter == 100 { + t.sourceMapCleanup() + cleanupCounter = 0 + } + case <-t.stopCh: + return + } + } + }() + return t +} + +type ( + cmdSource struct { + ch chan tokenCmd + recentTime float64 + lastUpdate mclock.AbsTime + } + + tokenCmd struct { + cmd []byte + id enode.ID + freeID string + send func([]byte) + } +) + +func (c *cmdSource) priority(capacity uint64) int64 { + dt := mclock.Now() - c.lastUpdate + rt := c.recentTime + if dt > 0 { + rt *= math.Exp(-float64(dt) / float64(tokenQueueTC)) + } + return -int64(rt / float64(capacity)) +} + +func (c *cmdSource) addTime(time uint64) { + now := mclock.Now() + dt := now - c.lastUpdate + if dt > 0 { + c.recentTime *= math.Exp(-float64(dt) / float64(tokenQueueTC)) + c.lastUpdate = now + } + c.recentTime += float64(time) +} + +func (t *tokenSale) sourceMapCleanup() { + t.qlock.Lock() + defer t.qlock.Unlock() + + for src, s := range t.sources { + s.addTime(0) + if s.recentTime < float64(time.Millisecond*100) { + delete(t.sources, src) + } + } +} + +func (t *tokenSale) queueCommand(src string, cmd tokenCmd, capacity uint64) bool { + t.qlock.Lock() + defer t.qlock.Unlock() + + s := t.sources[src] + if s == nil { + s = &cmdSource{lastUpdate: mclock.Now()} + t.sources[src] = s + } + if s.ch != nil { + select { + case s.ch <- cmd: + return true + default: + return false + } + } + s.ch = make(chan tokenCmd, 16) + s.ch <- cmd + + go func() { + loop: + for { + select { + case cmd := <-s.ch: + task := t.sq.newTask(nil, 0, s.priority(capacity)) + if !task.start() { + break loop + } + start := mclock.Now() + reply := t.runCommand(cmd.cmd, cmd.id, cmd.freeID) + runTime := mclock.Now() - start + cmd.send(reply) + time.Sleep(time.Duration(runTime) * 9) + task.done() + t.qlock.Lock() + s.addTime(uint64(runTime)) + t.qlock.Unlock() + default: + break loop + } + t.qlock.Lock() + s.ch = nil // TODO map cleanup + t.qlock.Unlock() + } + }() + return true +} + +func (t *tokenSale) stop() { + close(t.stopCh) + t.sq.stop() +} + +func (t *tokenSale) tokenCost(buyAmount uint64) (float64, bool) { + tokenLimit := t.clientPool.totalTokenLimit() + tokenAmount := t.clientPool.totalTokenAmount() + if tokenAmount+buyAmount >= tokenLimit { + return 0, false + } + r := float64(tokenAmount) / float64(tokenLimit) + b := float64(buyAmount) / float64(tokenLimit) + var relCost float64 + if r < 0.5 { + if r+b <= 0.5 { + relCost = b * (r + r + b) + b = 0 + } else { + relCost = (0.5 - r) * (r + 0.5) + b = r + b - 0.5 + r = 0.5 + } + } + if b > 0 { + l := 1 - r + if l < 1e-10 { + return 0, false + } + l = -b / l + if l < -1+1e-10 { + return 0, false + } + relCost += -math.Log1p(l) / 2 + + } + return t.basePrice * float64(tokenLimit) * relCost, true +} + +func (t *tokenSale) tokensFor(maxCost uint64) uint64 { + tokenLimit := t.clientPool.totalTokenLimit() + tokenAmount := t.clientPool.totalTokenAmount() + if tokenLimit <= tokenAmount { + return 0 + } + r := float64(tokenAmount) / float64(tokenLimit) + c := float64(maxCost) / (t.basePrice * float64(tokenLimit)) + var relTokens float64 + if r < 0.5 { + relTokens = math.Sqrt(r*r+c) - r + if r+relTokens <= 0.5 { + c = 0 + } else { + relTokens = 0.5 - r + c -= (0.5 - r) * (r + 0.5) + } + } + if c > 0 { + relTokens += -math.Expm1(-2*c) * (1 - r) + } + return uint64(relTokens * float64(tokenLimit)) +} + +func (t *tokenSale) connection(id enode.ID, freeID string, requestedCapacity uint64, stayConnected time.Duration, paymentModule []string, setCap bool) (availableCapacity, tokenBalance, tokensMissing, pcBalance, pcMissing uint64, paymentRequired []uint64, err error) { + t.lock.Lock() + defer t.lock.Unlock() + + tokensMissing, availableCapacity, err = t.clientPool.setCapacityLocked(id, freeID, requestedCapacity, stayConnected, setCap) + pb := t.clientPool.getPosBalance(id) + tokenBalance = pb.value + var meta tokenSaleMeta + if err := rlp.DecodeBytes([]byte(pb.meta), &meta); err == nil { + pcBalance = meta.pcBalance + } + if tokensMissing == 0 { + return + } + tokenLimit := t.clientPool.totalTokenLimit() + tokenAmount := t.clientPool.totalTokenAmount() + if tokenLimit <= tokenAmount || tokenLimit-tokenAmount <= tokensMissing { + pcMissing = math.MaxUint64 + } else { + tokensAvailable := tokenLimit - tokenAmount + pcr := -math.Log(float64(tokensAvailable-tokensMissing)/float64(tokensAvailable)) * t.basePrice + if pcr > 0 { + if pcr > maxBalance { + pcMissing = math.MaxUint64 + } else { + pcMissing = uint64(pcr) + if pcMissing > maxBalance { + pcMissing = math.MaxUint64 + } else { + if pcMissing > pcBalance { + pcMissing -= pcBalance + } else { + pcMissing = 0 + } + } + } + } + } + if pcMissing == 0 { + return + } + paymentRequired = make([]uint64, len(paymentModule)) + for i, recID := range paymentModule { + if rec, ok := t.receivers[recID]; !ok || pcMissing == math.MaxUint64 { + paymentRequired[i] = math.MaxUint64 + } else { + paymentRequired[i] = rec.requestPayment(id, pcMissing, meta.receiverMeta[recID]) + } + } + return +} + +func (t *tokenSale) deposit(id enode.ID, paymentModule string, proofOfPayment []byte) (pcValue, pcBalance uint64, err error) { + t.lock.Lock() + defer t.lock.Unlock() + + pb := t.clientPool.getPosBalance(id) + var meta tokenSaleMeta + if err := rlp.DecodeBytes([]byte(pb.meta), &meta); err == nil { + pcBalance = meta.pcBalance + } + + pm := t.receivers[paymentModule] + if pm == nil { + return 0, pcBalance, fmt.Errorf("Unknown payment receiver '%s'", paymentModule) + } + pcValue, meta.receiverMeta[paymentModule], err = pm.receivePayment(id, proofOfPayment, meta.receiverMeta[paymentModule]) + if err != nil { + return 0, pcBalance, err + } + pcBalance += pcValue + meta.pcBalance = pcBalance + metaEnc, _ := rlp.EncodeToBytes(&meta) + t.clientPool.addBalance(id, 0, string(metaEnc)) + return +} + +func (t *tokenSale) buyTokens(id enode.ID, maxSpend, minReceive uint64, relative, spendAll bool) (pcBalance, tokenBalance, spend, receive uint64, success bool) { + t.lock.Lock() + defer t.lock.Unlock() + + pb := t.clientPool.getPosBalance(id) + tokenBalance = pb.value + var meta tokenSaleMeta + if err := rlp.DecodeBytes([]byte(pb.meta), &meta); err == nil { + pcBalance = meta.pcBalance + } + if relative { + if pcBalance > maxSpend { + maxSpend = pcBalance - maxSpend + } else { + maxSpend = 0 + } + if minReceive > tokenBalance { + minReceive -= tokenBalance + } else { + minReceive = 0 + } + } + + if maxSpend > pcBalance { + maxSpend = pcBalance + } + if spendAll { + spend = maxSpend + receive = t.tokensFor(spend) + success = receive >= minReceive + } else { + receive = minReceive + if cost, ok := t.tokenCost(receive); ok { + spend = uint64(cost) + } else { + spend = math.MaxUint64 + } + success = spend <= maxSpend + } + if success { + pcBalance -= spend + tokenBalance += receive + meta.pcBalance = pcBalance + metaEnc, _ := rlp.EncodeToBytes(&meta) + t.clientPool.addBalance(id, int64(receive), string(metaEnc)) + } + return +} + +func (t *tokenSale) getBalance(id enode.ID) (pcBalance, tokenBalance uint64) { + t.lock.Lock() + defer t.lock.Unlock() + + pb := t.clientPool.getPosBalance(id) + tokenBalance = pb.value + var meta tokenSaleMeta + if err := rlp.DecodeBytes([]byte(pb.meta), &meta); err == nil { + pcBalance = meta.pcBalance + } + return +} + +func (t *tokenSale) info() (version, compatible uint, info keyValueList, receivers []string) { + t.lock.Lock() + defer t.lock.Unlock() + + return 1, 1, keyValueList{}, t.receiverNames +} + +func (t *tokenSale) receiverInfo(receiverIDs []string) []keyValueList { + t.lock.Lock() + defer t.lock.Unlock() + + res := make([]keyValueList, len(receiverIDs)) + for i, id := range receiverIDs { + if rec, ok := t.receivers[id]; ok { + res[i] = rec.info() + } + } + return res +} + +type tokenSaleMeta struct { + pcBalance uint64 + receiverMeta map[string][]byte +} + +type receiverMetaEnc struct { + Id string + Meta []byte +} + +type tokenSaleMetaEnc struct { + Id string + Version uint + PcBalance uint64 + Receivers []receiverMetaEnc +} + +// EncodeRLP implements rlp.Encoder +func (t *tokenSaleMeta) EncodeRLP(w io.Writer) error { + receivers := make([]receiverMetaEnc, len(t.receiverMeta)) + i := 0 + for id, meta := range t.receiverMeta { + receivers[i] = receiverMetaEnc{id, meta} + i++ + } + return rlp.Encode(w, tokenSaleMetaEnc{ + Id: "tokenSale", + Version: 1, + PcBalance: t.pcBalance, + Receivers: receivers, + }) +} + +// DecodeRLP implements rlp.Decoder +func (t *tokenSaleMeta) DecodeRLP(s *rlp.Stream) error { + var e tokenSaleMetaEnc + if err := s.Decode(&e); err != nil { + return err + } + if e.Id != "tokenSale" || e.Version != 1 { + return fmt.Errorf("Unknown balance meta format '%s' version %d", e.Id, e.Version) + } + t.receiverMeta = make(map[string][]byte) + t.pcBalance = e.PcBalance + for _, r := range e.Receivers { + t.receiverMeta[r.Id] = r.Meta + } + return nil +} + +const ( + tsInfo = iota + tsReceiverInfo + tsGetBalance + tsDeposit + tsBuyTokens + tsConnection +) + +type ( + tsInfoResults struct { + Version, Compatible uint + Info keyValueList + Receivers []string + } + tsReceiverInfoParams []string + tsReceiverInfoResults []keyValueList + tsGetBalanceResults struct { + PcBalance, TokenBalance uint64 + } + tsDepositParams struct { + PaymentModule string + ProofOfPayment []byte + } + tsDepositResults struct { + PcValue, PcBalance uint64 + Err string + } + tsBuyTokensParams struct { + MaxSpend, MinReceive uint64 + Relative, SpendAll bool + } + tsBuyTokensResults struct { + PcBalance, TokenBalance, Spend, Receive uint64 + Success bool + } + tsConnectionParams struct { + RequestedCapacity, StayConnected uint64 + PaymentModule []string + SetCap bool + } + tsConnectionResults struct { + AvailableCapacity, TokenBalance, TokensMissing, PcBalance, PcMissing uint64 + PaymentRequired []uint64 + Err string + } +) + +func (t *tokenSale) runCommand(cmd []byte, id enode.ID, freeID string) []byte { + var res []byte + switch cmd[0] { + case tsInfo: + var results tsInfoResults + if len(cmd) == 1 { + results.Version, results.Compatible, results.Info, results.Receivers = t.info() + res, _ = rlp.EncodeToBytes(&results) + } + case tsReceiverInfo: + var ( + params tsReceiverInfoParams + results tsReceiverInfoResults + ) + if err := rlp.DecodeBytes(cmd[1:], ¶ms); err == nil { + results = t.receiverInfo(params) + res, _ = rlp.EncodeToBytes(&results) + } + case tsGetBalance: + var results tsGetBalanceResults + if len(cmd) == 1 { + results.PcBalance, results.TokenBalance = t.getBalance(id) + res, _ = rlp.EncodeToBytes(&results) + } + case tsDeposit: + var ( + params tsDepositParams + results tsDepositResults + ) + if err := rlp.DecodeBytes(cmd[1:], ¶ms); err == nil { + results.PcValue, results.PcBalance, err = t.deposit(id, params.PaymentModule, params.ProofOfPayment) + if err != nil { + results.Err = err.Error() + } + res, _ = rlp.EncodeToBytes(&results) + } + case tsBuyTokens: + var ( + params tsBuyTokensParams + results tsBuyTokensResults + ) + if err := rlp.DecodeBytes(cmd[1:], ¶ms); err == nil { + results.PcBalance, results.TokenBalance, results.Spend, results.Receive, results.Success = + t.buyTokens(id, params.MaxSpend, params.MinReceive, params.Relative, params.SpendAll) + res, _ = rlp.EncodeToBytes(&results) + } + case tsConnection: + var ( + params tsConnectionParams + results tsConnectionResults + ) + if err := rlp.DecodeBytes(cmd[1:], ¶ms); err == nil { + results.AvailableCapacity, results.TokenBalance, results.TokensMissing, results.PcBalance, results.PcMissing, results.PaymentRequired, err = + t.connection(id, freeID, params.RequestedCapacity, time.Duration(params.StayConnected)*time.Second, params.PaymentModule, params.SetCap) + if err != nil { + results.Err = err.Error() + } + res, _ = rlp.EncodeToBytes(&results) + } + } + return res +} diff --git a/les/ulc_test.go b/les/ulc_test.go index 9112bf928c18..9353d115b94b 100644 --- a/les/ulc_test.go +++ b/les/ulc_test.go @@ -28,8 +28,8 @@ import ( "github.com/ethereum/go-ethereum/p2p/enode" ) -func TestULCAnnounceThresholdLes2(t *testing.T) { testULCAnnounceThreshold(t, 2) } func TestULCAnnounceThresholdLes3(t *testing.T) { testULCAnnounceThreshold(t, 3) } +func TestULCAnnounceThresholdLes4(t *testing.T) { testULCAnnounceThreshold(t, 4) } func testULCAnnounceThreshold(t *testing.T, protocol int) { // todo figure out why it takes fetcher so longer to fetcher the announced header. @@ -126,7 +126,7 @@ func connect(server *serverHandler, serverId enode.ID, client *clientHandler, pr // newServerPeer creates server peer. func newServerPeer(t *testing.T, blocks int, protocol int) (*testServer, *enode.Node, func()) { - s, teardown := newServerEnv(t, blocks, protocol, nil, false, false, 0) + s, teardown := newServerEnv(t, blocks, protocol, nil, false, false, 0, false) key, err := crypto.GenerateKey() if err != nil { t.Fatal("generate key err:", err) diff --git a/p2p/discv5/net.go b/p2p/discv5/net.go index dd2ec3e9298f..0a7af53b9e23 100644 --- a/p2p/discv5/net.go +++ b/p2p/discv5/net.go @@ -22,12 +22,14 @@ import ( "errors" "fmt" "net" + "sync" "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/netutil" "github.com/ethereum/go-ethereum/rlp" "golang.org/x/crypto/sha3" @@ -64,12 +66,17 @@ type Network struct { refreshResp chan (<-chan struct{}) // ...and get the channel to block on from this one read chan ingressPacket // ingress packets arrive here timeout chan timeoutEvent - queryReq chan *findnodeQuery // lookups submit findnode queries on this channel + queryReq chan deferredQuery // lookups submit findnode queries on this channel tableOpReq chan func() tableOpResp chan struct{} topicRegisterReq chan topicRegisterReq topicSearchReq chan topicSearchReq + talkRequestSubLock sync.RWMutex + talkResponseSubLock sync.Mutex + talkRequestSubs map[string]TalkRequestHandler + talkResponseSubs map[string]TalkResponseHandler + // State of the main loop. tab *Table topictab *topicTable @@ -79,6 +86,11 @@ type Network struct { timeoutTimers map[timeoutEvent]*time.Timer } +type ( + TalkRequestHandler func(enode.ID, *net.UDPAddr, interface{}) (interface{}, bool) + TalkResponseHandler func(interface{}) bool +) + // transport is implemented by the UDP transport. // it is an interface so we can test without opening lots of UDP // sockets and without generating a private key. @@ -101,6 +113,14 @@ type findnodeQuery struct { reply chan<- []*Node } +type talkQuery struct { + remote *Node + talkID string + payload interface{} + key string + handler TalkResponseHandler +} + type topicRegisterReq struct { add bool topic Topic @@ -151,10 +171,12 @@ func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, dbPath string, netres timeoutTimers: make(map[timeoutEvent]*time.Timer), tableOpReq: make(chan func()), tableOpResp: make(chan struct{}), - queryReq: make(chan *findnodeQuery), + queryReq: make(chan deferredQuery), topicRegisterReq: make(chan topicRegisterReq), topicSearchReq: make(chan topicSearchReq), nodes: make(map[NodeID]*Node), + talkRequestSubs: make(map[string]TalkRequestHandler), + talkResponseSubs: make(map[string]TalkResponseHandler), } go net.loop() return net, nil @@ -410,9 +432,11 @@ loop: // Ingress packet handling. case pkt := <-net.read: - //fmt.Println("read", pkt.ev) log.Trace("<-net.read") n := net.internNode(&pkt) + if pkt.ev == talkRequestPacket { + fmt.Println("read trp", n.state, pkt) + } prestate := n.state status := "ok" if err := net.handle(n, pkt.ev, &pkt); err != nil { @@ -446,7 +470,7 @@ loop: case q := <-net.queryReq: log.Trace("<-net.queryReq") if !q.start(net) { - q.remote.deferQuery(q) + q.deferQuery() } // Interacting with the table. @@ -700,11 +724,17 @@ func (net *Network) refresh(done chan<- struct{}) { func (net *Network) internNode(pkt *ingressPacket) *Node { if n := net.nodes[pkt.remoteID]; n != nil { + if pkt.ev == talkRequestPacket { + fmt.Println("node exists") + } n.IP = pkt.remoteAddr.IP n.UDP = uint16(pkt.remoteAddr.Port) n.TCP = uint16(pkt.remoteAddr.Port) return n } + if pkt.ev == talkRequestPacket { + fmt.Println("node created") + } n := NewNode(pkt.remoteID, pkt.remoteAddr.IP, uint16(pkt.remoteAddr.Port), uint16(pkt.remoteAddr.Port)) n.state = unknown net.nodes[pkt.remoteID] = n @@ -767,14 +797,21 @@ type nodeNetGuts struct { // State machine fields. Access to these fields // is restricted to the Network.loop goroutine. state *nodeState - pingEcho []byte // hash of last ping sent by us - pingTopics []Topic // topic set sent by us in last ping - deferredQueries []*findnodeQuery // queries that can't be sent yet - pendingNeighbours *findnodeQuery // current query, waiting for reply + pingEcho []byte // hash of last ping sent by us + pingTopics []Topic // topic set sent by us in last ping + deferredQueries []deferredQuery // queries that can't be sent yet + pendingNeighbours *findnodeQuery // current query, waiting for reply queryTimeouts int + talkFailures int } -func (n *nodeNetGuts) deferQuery(q *findnodeQuery) { +type deferredQuery interface { + start(net *Network) bool + cancel() + deferQuery() +} + +func (n *nodeNetGuts) deferQuery(q deferredQuery) { n.deferredQueries = append(n.deferredQueries, q) } @@ -810,6 +847,14 @@ func (q *findnodeQuery) start(net *Network) bool { return false } +func (q *findnodeQuery) cancel() { + q.reply <- nil +} + +func (q *findnodeQuery) deferQuery() { + q.remote.deferQuery(q) +} + // Node Events (the input to the state machine). type nodeEvent uint @@ -828,6 +873,8 @@ const ( topicRegisterPacket topicQueryPacket topicNodesPacket + talkRequestPacket + talkResponsePacket // Non-packet events. // Event values in this category are allocated outside @@ -835,6 +882,7 @@ const ( pongTimeout nodeEvent = iota + 256 pingTimeout neighboursTimeout + talkTimeout ) // Node State Machine. @@ -868,7 +916,7 @@ func init() { n.pingEcho = nil // Abort active queries. for _, q := range n.deferredQueries { - q.reply <- nil + q.cancel() } n.deferredQueries = nil if n.pendingNeighbours != nil { @@ -1021,7 +1069,7 @@ func (net *Network) handle(n *Node, ev nodeEvent, pkt *ingressPacket) error { //fmt.Println("handle", n.addr().String(), n.state, ev) if pkt != nil { if err := net.checkPacket(n, ev, pkt); err != nil { - //fmt.Println("check err:", err) + //fmt.Println("check err:", err, pkt) return err } // Start the background expiration goroutine after the first @@ -1036,6 +1084,7 @@ func (net *Network) handle(n *Node, ev nodeEvent, pkt *ingressPacket) error { if n.state == nil { n.state = unknown //??? } + //fmt.Println("old state:", n.state) next, err := n.state.handle(net, n, ev, pkt) net.transition(n, next) //fmt.Println("new state:", n.state) @@ -1196,7 +1245,53 @@ func (net *Network) handleQueryEvent(n *Node, ev nodeEvent, pkt *ingressPacket) } } return n.state, nil - + case talkRequestPacket: + p := pkt.data.(*talkRequest) + net.talkRequestSubLock.RLock() + subFn := net.talkRequestSubs[string(p.TalkID)] + net.talkRequestSubLock.RUnlock() + fmt.Println("trp", p) + if subFn != nil { + resp, ok := subFn(enode.ID(n.sha), n.addr(), p.Payload) + fmt.Println("subFn", ok) + if ok { + net.conn.send(n, talkResponsePacket, talkResponse{ReplyTok: pkt.hash, Payload: resp}) + } else { + n.talkFailures++ + } + } else { + n.talkFailures++ + } + if n.talkFailures > maxTalkFailures && n.state == known { + return contested, errors.New("too many talk failures") + } + return n.state, nil + case talkResponsePacket: + p := pkt.data.(*talkResponse) + net.talkResponseSubLock.Lock() + key := string(n.sha[:]) + string(p.ReplyTok) + subFn := net.talkResponseSubs[key] + if subFn != nil { + delete(net.talkResponseSubs, key) + } + net.talkResponseSubLock.Unlock() + if subFn == nil || !subFn(p.Payload) { + n.talkFailures++ + if n.talkFailures > maxTalkFailures && n.state == known { + return contested, errors.New("too many talk failures") + } + } + return n.state, nil + case talkTimeout: + if n.pendingNeighbours != nil { + n.pendingNeighbours.reply <- nil + n.pendingNeighbours = nil + } + n.queryTimeouts++ + if n.queryTimeouts > maxFindnodeFailures && n.state == known { + return contested, errors.New("too many timeouts") + } + return n.state, nil default: return n.state, errInvalidEvent } @@ -1260,3 +1355,78 @@ func (net *Network) handleNeighboursPacket(n *Node, pkt *ingressPacket) error { n.startNextQuery(net) return nil } + +func (net *Network) RegisterTalkHandler(talkID string, handler TalkRequestHandler) { + net.talkRequestSubLock.Lock() + net.talkRequestSubs[talkID] = handler + net.talkRequestSubLock.Unlock() +} + +func (net *Network) RemoveTalkHandler(talkID string) { + net.talkRequestSubLock.Lock() + delete(net.talkRequestSubs, talkID) + net.talkRequestSubLock.Unlock() +} + +func (q *talkQuery) start(net *Network) bool { + if q.remote == net.tab.self { + return false + } + if q.remote.state == known { + net.talkResponseSubLock.Lock() + hash := net.conn.send(q.remote, talkRequestPacket, talkRequest{TalkID: []byte(q.talkID), Payload: q.payload}) + q.key = string(q.remote.sha[:]) + string(hash[:]) + net.talkResponseSubs[q.key] = q.handler + net.talkResponseSubLock.Unlock() + fmt.Println("sent", q.remote.state) + return true + } + // If the node is not known yet, it won't accept queries. + // Initiate the transition to known. + // The request will be sent later when the node reaches known state. + if q.remote.state == unknown { + net.transition(q.remote, verifyinit) + } + return false +} + +func (q *talkQuery) cancel() { + q.handler(nil) +} + +func (q *talkQuery) deferQuery() { + q.remote.deferQuery(q) +} + +func (net *Network) SendTalkRequest(to *enode.Node, talkID string, payload interface{}, handler TalkResponseHandler) func() bool { + var nodeID NodeID + copy(nodeID[:], crypto.FromECDSAPub(to.Pubkey())[1:]) + node := net.nodes[nodeID] + if node == nil { + node = NewNode(nodeID, to.IP(), uint16(to.UDP()), uint16(to.TCP())) + node.state = unknown + net.nodes[nodeID] = node + } + q := &talkQuery{remote: node, talkID: talkID, payload: payload, handler: handler} + if node.state == nil || !node.state.canQuery { + net.ping(node, node.addr()) + } + select { + case net.queryReq <- q: + case <-net.closed: + return nil + } + + return func() bool { + net.talkResponseSubLock.Lock() + cancel := q.key != "" && net.talkResponseSubs[q.key] != nil + if cancel { + delete(net.talkResponseSubs, q.key) + } + net.talkResponseSubLock.Unlock() + if cancel { + handler(nil) + } + return cancel + } +} diff --git a/p2p/discv5/table.go b/p2p/discv5/table.go index 64c3ecd1c7b2..23b68e56928a 100644 --- a/p2p/discv5/table.go +++ b/p2p/discv5/table.go @@ -35,6 +35,7 @@ const ( nBuckets = hashBits + 1 // Number of buckets maxFindnodeFailures = 5 + maxTalkFailures = 5 ) type Table struct { diff --git a/p2p/discv5/udp.go b/p2p/discv5/udp.go index 088f95cac6a2..0526e6b566a2 100644 --- a/p2p/discv5/udp.go +++ b/p2p/discv5/udp.go @@ -119,6 +119,16 @@ type ( Nodes []rpcNode } + talkRequest struct { + TalkID []byte + Payload interface{} + } + + talkResponse struct { + ReplyTok []byte + Payload interface{} + } + rpcNode struct { IP net.IP // len 4 for IPv4 or 16 for IPv6 UDP uint16 // for discovery protocol @@ -420,6 +430,10 @@ func decodePacket(buffer []byte, pkt *ingressPacket) error { pkt.data = new(topicQuery) case topicNodesPacket: pkt.data = new(topicNodes) + case talkRequestPacket: + pkt.data = new(talkRequest) + case talkResponsePacket: + pkt.data = new(talkResponse) default: return fmt.Errorf("unknown packet type: %d", sigdata[0]) } diff --git a/p2p/enode/node.go b/p2p/enode/node.go index 9eb2544ffe14..3f6cda6d4a21 100644 --- a/p2p/enode/node.go +++ b/p2p/enode/node.go @@ -217,7 +217,7 @@ func (n ID) MarshalText() ([]byte, error) { // UnmarshalText implements the encoding.TextUnmarshaler interface. func (n *ID) UnmarshalText(text []byte) error { - id, err := parseID(string(text)) + id, err := ParseID(string(text)) if err != nil { return err } @@ -229,14 +229,14 @@ func (n *ID) UnmarshalText(text []byte) error { // The string may be prefixed with 0x. // It panics if the string is not a valid ID. func HexID(in string) ID { - id, err := parseID(in) + id, err := ParseID(in) if err != nil { panic(err) } return id } -func parseID(in string) (ID, error) { +func ParseID(in string) (ID, error) { var id ID b, err := hex.DecodeString(strings.TrimPrefix(in, "0x")) if err != nil {