diff --git a/cmd/geth/main.go b/cmd/geth/main.go index 99ef78238feb..70bcf7a96209 100644 --- a/cmd/geth/main.go +++ b/cmd/geth/main.go @@ -101,6 +101,7 @@ var ( utils.UltraLightServersFlag, utils.UltraLightFractionFlag, utils.UltraLightOnlyAnnounceFlag, + utils.LespayTestModuleFlag, utils.WhitelistFlag, utils.CacheFlag, utils.CacheDatabaseFlag, diff --git a/cmd/geth/usage.go b/cmd/geth/usage.go index 6f3197b9c6d3..4d98a8e372af 100644 --- a/cmd/geth/usage.go +++ b/cmd/geth/usage.go @@ -94,6 +94,7 @@ var AppHelpFlagGroups = []flagGroup{ utils.UltraLightServersFlag, utils.UltraLightFractionFlag, utils.UltraLightOnlyAnnounceFlag, + utils.LespayTestModuleFlag, }, }, { diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index bdadebd852f4..100202479221 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -272,6 +272,10 @@ var ( Usage: "Maximum number of light clients to serve, or light servers to attach to", Value: eth.DefaultConfig.LightPeers, } + LespayTestModuleFlag = cli.BoolFlag{ + Name: "lespay.testmodule", + Usage: "Enable dummy payment module (for testing only)", + } UltraLightServersFlag = cli.StringFlag{ Name: "ulc.servers", Usage: "List of trusted ultra-light servers", @@ -1009,6 +1013,9 @@ func setLes(ctx *cli.Context, cfg *eth.Config) { if ctx.GlobalIsSet(UltraLightOnlyAnnounceFlag.Name) { cfg.UltraLightOnlyAnnounce = ctx.GlobalBool(UltraLightOnlyAnnounceFlag.Name) } + if ctx.GlobalIsSet(LespayTestModuleFlag.Name) { + cfg.LespayTestModule = true + } } // makeDatabaseHandles raises out the number of allowed file handles per process diff --git a/eth/config.go b/eth/config.go index 2eaf21fbc30c..4a2ea18ad1a8 100644 --- a/eth/config.go +++ b/eth/config.go @@ -116,6 +116,9 @@ type Config struct { UltraLightFraction int `toml:",omitempty"` // Percentage of trusted servers to accept an announcement UltraLightOnlyAnnounce bool `toml:",omitempty"` // Whether to only announce headers, or also serve them + // Light client payment options + LespayTestModule bool + // Database options SkipBcVersionCheck bool `toml:"-"` DatabaseHandles int `toml:"-"` diff --git a/eth/gen_config.go b/eth/gen_config.go index 1c659c393ca7..eb889e3aa8c3 100644 --- a/eth/gen_config.go +++ b/eth/gen_config.go @@ -32,6 +32,7 @@ func (c Config) MarshalTOML() (interface{}, error) { UltraLightServers []string `toml:",omitempty"` UltraLightFraction int `toml:",omitempty"` UltraLightOnlyAnnounce bool `toml:",omitempty"` + LespayTestModule bool `toml:"-"` SkipBcVersionCheck bool `toml:"-"` DatabaseHandles int `toml:"-"` DatabaseCache int @@ -68,6 +69,7 @@ func (c Config) MarshalTOML() (interface{}, error) { enc.UltraLightServers = c.UltraLightServers enc.UltraLightFraction = c.UltraLightFraction enc.UltraLightOnlyAnnounce = c.UltraLightOnlyAnnounce + enc.LespayTestModule = c.LespayTestModule enc.SkipBcVersionCheck = c.SkipBcVersionCheck enc.DatabaseHandles = c.DatabaseHandles enc.DatabaseCache = c.DatabaseCache @@ -108,6 +110,7 @@ func (c *Config) UnmarshalTOML(unmarshal func(interface{}) error) error { UltraLightServers []string `toml:",omitempty"` UltraLightFraction *int `toml:",omitempty"` UltraLightOnlyAnnounce *bool `toml:",omitempty"` + LespayTestModule *bool `toml:"-"` SkipBcVersionCheck *bool `toml:"-"` DatabaseHandles *int `toml:"-"` DatabaseCache *int @@ -175,6 +178,9 @@ func (c *Config) UnmarshalTOML(unmarshal func(interface{}) error) error { if dec.UltraLightOnlyAnnounce != nil { c.UltraLightOnlyAnnounce = *dec.UltraLightOnlyAnnounce } + if dec.LespayTestModule != nil { + c.LespayTestModule = *dec.LespayTestModule + } if dec.SkipBcVersionCheck != nil { c.SkipBcVersionCheck = *dec.SkipBcVersionCheck } diff --git a/internal/web3ext/web3ext.go b/internal/web3ext/web3ext.go index bc105ef37c28..38cdf42d627e 100644 --- a/internal/web3ext/web3ext.go +++ b/internal/web3ext/web3ext.go @@ -33,6 +33,7 @@ var Modules = map[string]string{ "swarmfs": SwarmfsJs, "txpool": TxpoolJs, "les": LESJs, + "lespay": LESPAYJs, } const ChequebookJs = ` @@ -856,3 +857,50 @@ web3._extend({ ] }); ` + +const LESPAYJs = ` +web3._extend({ + property: 'lespay', + methods: + [ + new web3._extend.Method({ + name: 'connection', + call: 'lespay_connection', + params: 6 + }), + new web3._extend.Method({ + name: 'deposit', + call: 'lespay_deposit', + params: 4 + }), + new web3._extend.Method({ + name: 'buyTokens', + call: 'lespay_buyTokens', + params: 6 + }), + new web3._extend.Method({ + name: 'buyTokens', + call: 'lespay_sellTokens', + params: 6 + }), + new web3._extend.Method({ + name: 'getBalance', + call: 'lespay_getBalance', + params: 2 + }), + new web3._extend.Method({ + name: 'info', + call: 'lespay_info', + params: 2 + }), + new web3._extend.Method({ + name: 'receiverInfo', + call: 'lespay_receiverInfo', + params: 3 + }), + ], + properties: + [ + ] +}); +` diff --git a/les/api.go b/les/api.go index f9b8c34458b4..f60a4fa57db3 100644 --- a/les/api.go +++ b/les/api.go @@ -17,14 +17,16 @@ package les import ( + "context" "errors" "fmt" - "math" "time" "github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/p2p/discv5" "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/rlp" ) var ( @@ -35,8 +37,6 @@ var ( errNoPriority = errors.New("priority too low to raise capacity") ) -const maxBalance = math.MaxInt64 - // PrivateLightServerAPI provides an API to access the LES light server. type PrivateLightServerAPI struct { server *LesServer @@ -104,13 +104,17 @@ func (api *PrivateLightServerAPI) clientInfo(c *clientInfo, id enode.ID) map[str info["capacity"] = c.capacity pb, nb := c.balanceTracker.getBalance(now) info["pricing/balance"], info["pricing/negBalance"] = pb, nb - info["pricing/balanceMeta"] = c.balanceMetaInfo - info["priority"] = pb != 0 + + cb := api.server.clientPool.ndb.getCurrencyBalance(id) + info["pricing/currency"] = cb.amount + info["priority"] = pb.base != 0 } else { info["isConnected"] = false - pb := api.server.clientPool.ndb.getOrNewPB(id) - info["pricing/balance"], info["pricing/balanceMeta"] = pb.value, pb.meta - info["priority"] = pb.value != 0 + pb := api.server.clientPool.ndb.getOrNewBalance(id.Bytes(), false) + + cb := api.server.clientPool.ndb.getCurrencyBalance(id) + info["pricing/balance"], info["pricing/currency"] = pb.value, cb.amount + info["priority"] = pb.value.base != 0 } return info } @@ -150,7 +154,7 @@ func (api *PrivateLightServerAPI) setParams(params map[string]interface{}, clien setFactor(&negFactors.requestFactor) case !defParams && name == "capacity": if capacity, ok := value.(float64); ok && uint64(capacity) >= api.server.minCapacity { - err = api.server.clientPool.setCapacity(client, uint64(capacity)) + _, _, err = api.server.clientPool.setCapacity(client.id, client.address, uint64(capacity), 0, true) // Don't have to call factor update explicitly. It's already done // in setCapacity function. } else { @@ -172,8 +176,8 @@ func (api *PrivateLightServerAPI) setParams(params map[string]interface{}, clien // AddBalance updates the balance of a client (either overwrites it or adds to it). // It also updates the balance meta info string. -func (api *PrivateLightServerAPI) AddBalance(id enode.ID, value int64, meta string) ([2]uint64, error) { - oldBalance, newBalance, err := api.server.clientPool.addBalance(id, value, meta) +func (api *PrivateLightServerAPI) AddBalance(id enode.ID, value int64) ([2]uint64, error) { + oldBalance, newBalance, err := api.server.clientPool.addBalance(id, value) return [2]uint64{oldBalance, newBalance}, err } @@ -184,7 +188,7 @@ func (api *PrivateLightServerAPI) SetClientParams(ids []enode.ID, params map[str if client != nil { update, err := api.setParams(params, client, nil, nil) if update { - client.updatePriceFactors() + updatePriceFactors(&client.balanceTracker, client.posFactors, client.negFactors) } return err } else { @@ -296,7 +300,7 @@ func (api *PrivateDebugAPI) FreezeClient(id enode.ID) error { if c == nil { return fmt.Errorf("client %064x is not connected", id[:]) } - c.peer.freezeClient() + c.peer.freeze() return nil }) } @@ -352,3 +356,222 @@ func (api *PrivateLightAPI) GetCheckpointContractAddress() (string, error) { } return api.backend.oracle.Contract().ContractAddr().Hex(), nil } + +// PrivateLespayAPI provides an API to use the LESpay commands of either the local or a remote server +type PrivateLespayAPI struct { + clientPeerSet *clientPeerSet + serverPeerSet *serverPeerSet + clientHandler *clientHandler + dht *discv5.Network + tokenSale *tokenSale +} + +// NewPrivateLespayAPI creates a new LESPAY API. +func NewPrivateLespayAPI(clientPeerSet *clientPeerSet, serverPeerSet *serverPeerSet, clientHandler *clientHandler, dht *discv5.Network, tokenSale *tokenSale) *PrivateLespayAPI { + return &PrivateLespayAPI{ + clientPeerSet: clientPeerSet, + serverPeerSet: serverPeerSet, + clientHandler: clientHandler, + dht: dht, + tokenSale: tokenSale, + } +} + +// makeCall sends an encoded command to either the local or a remote server and returns the encoded reply +// +// Note: nodeStr can represent either the node ID of a connected node or the full enode of any remote node. +// If remote is true then the command is sent to the specified node. It is sent through LES if it was specified +// with node ID, throush UDP talk otherwise. +// If remote is false then the command is executed locally, with the specified remote node assumed as sender. +func (api *PrivateLespayAPI) makeCall(ctx context.Context, remote bool, nodeStr string, cmd []byte) ([]byte, error) { + var ( + id enode.ID + freeID string + clientPeer *clientPeer + serverPeer *serverPeer + node *enode.Node + err error + ) + if nodeStr != "" { + if id, err = enode.ParseID(nodeStr); err == nil { + if api.clientPeerSet != nil { + if clientPeer = api.clientPeerSet.peer(peerIdToString(id)); clientPeer == nil { + return nil, errors.New("peer not connected") + } + freeID = clientPeer.freeClientId() + } else { + if serverPeer = api.serverPeerSet.peer(peerIdToString(id)); serverPeer == nil { + return nil, errors.New("peer not connected") + } + } + } else { + var err error + if node, err = enode.Parse(enode.ValidSchemes, nodeStr); err == nil { + id = node.ID() + freeID = node.IP().String() + } else { + return nil, err + } + } + } + + if remote { + var ( + reply []byte + cancelFn func() bool + ) + delivered := make(chan struct{}) + if serverPeer != nil { + // remote call to a connected peer through LES + if api.clientHandler == nil { + return nil, errors.New("client handler not available") + } + cancelFn = api.clientHandler.makeLespayCall(serverPeer, cmd, func(r []byte, delay uint) bool { + reply = r + close(delivered) + return reply != nil + }) + } else { + // remote call through UDP TALK + if api.dht == nil { + return nil, errors.New("UDP DHT not available") + } + cancelFn = api.dht.SendTalkRequest(node, "lespay", [][]byte{cmd}, func(payload interface{}, delay uint) bool { + if replies, ok := payload.([]interface{}); ok && len(replies) == 1 { + reply, _ = replies[0].([]byte) + } + close(delivered) + return reply != nil + }) + } + select { + case <-time.After(time.Second * 5): + cancelFn() + return nil, errors.New("timeout") + case <-ctx.Done(): + cancelFn() + return nil, ctx.Err() + case <-delivered: + if len(reply) == 0 { + return nil, errors.New("unknown command") + } + return reply, nil + } + } else { + if api.tokenSale == nil { + return nil, errors.New("token sale module not available") + } + // execute call locally + return api.tokenSale.runCommand(cmd, id, freeID), nil + } + +} + +// Connection checks whether it is possible with the current balance levels to establish +// requested connection or capacity change and then stay connected for the given amount +// of time. If it is possible and setCap is also true then the client is activated of the +// capacity change is performed. If not then returns how many tokens are missing and how +// much that would currently cost using the specified payment module(s). +func (api *PrivateLespayAPI) Connection(ctx context.Context, remote bool, node string, requestedCapacity, stayConnected uint64, paymentModule []string, setCap bool) (results tsConnectionResults, err error) { + params := tsConnectionParams{requestedCapacity, stayConnected, paymentModule, setCap} + enc, _ := rlp.EncodeToBytes(¶ms) + var resEnc []byte + resEnc, err = api.makeCall(ctx, remote, node, append([]byte{tsConnection}, enc...)) + if err != nil { + return + } + err = rlp.DecodeBytes(resEnc, &results) + return +} + +// Deposit credits a payment on the sender's account using the specified payment module +func (api *PrivateLespayAPI) Deposit(ctx context.Context, remote bool, node, paymentModule, proofOfPayment string) (results tsDepositResults, err error) { + var proof []byte + if proof, err = hexutil.Decode(proofOfPayment); err != nil { + return + } + params := tsDepositParams{paymentModule, proof} + enc, _ := rlp.EncodeToBytes(¶ms) + var resEnc []byte + resEnc, err = api.makeCall(ctx, remote, node, append([]byte{tsDeposit}, enc...)) + if err != nil { + return + } + err = rlp.DecodeBytes(resEnc, &results) + return +} + +// BuyTokens tries to convert the permanent balance (nominated in the server's preferred +// currency, PC) to service tokens. If spendAll is true then it sells the maxSpend amount +// of PC coins if the received service token amount is at least minReceive. If spendAll is +// false then is buys minReceive amount of tokens if it does not cost more than maxSpend +// amount of PC coins. +// if relative is true then maxSpend and minReceive are specified relative to their current +// balances. In this case maxSpend represents the amount under which the PC balance should +// not go and minReceive represents the amount the service token balance should reach. +// This mode is useful when actual conversion is intended to happen and the sender has to +// retry the command after not receiving a reply previously. In this case the sender cannot +// be sure whether the conversion has already happened or not. If relative is true then it +// is impossible to do a conversion twice. In exchange the sender needs to know its current +// balances (which it probably does if it has made a previous call to just ask the current price). +func (api *PrivateLespayAPI) BuyTokens(ctx context.Context, remote bool, node string, maxSpend, minReceive uint64, relative, spendAll bool) (results tsBuyTokensResults, err error) { + params := tsBuyTokensParams{maxSpend, minReceive, relative, spendAll} + enc, _ := rlp.EncodeToBytes(¶ms) + var resEnc []byte + resEnc, err = api.makeCall(ctx, remote, node, append([]byte{tsBuyTokens}, enc...)) + if err != nil { + return + } + err = rlp.DecodeBytes(resEnc, &results) + return +} + +// SellTokens tries to convert service tokens to permanent balance (nominated in the server's +// preferred currency, PC). Parameters work similarly to BuyTokens. +func (api *PrivateLespayAPI) SellTokens(ctx context.Context, remote bool, node string, maxSell, minRefund uint64, relative, sellAll bool) (results tsSellTokensResults, err error) { + params := tsSellTokensParams{maxSell, minRefund, relative, sellAll} + enc, _ := rlp.EncodeToBytes(¶ms) + var resEnc []byte + resEnc, err = api.makeCall(ctx, remote, node, append([]byte{tsSellTokens}, enc...)) + if err != nil { + return + } + err = rlp.DecodeBytes(resEnc, &results) + return +} + +// GetBalance returns the current PC balance and service token balance +func (api *PrivateLespayAPI) GetBalance(ctx context.Context, remote bool, node string) (results tsGetBalanceResults, err error) { + var resEnc []byte + resEnc, err = api.makeCall(ctx, remote, node, []byte{tsGetBalance}) + if err != nil { + return + } + err = rlp.DecodeBytes(resEnc, &results) + return +} + +// Info returns general information about the server, including version info of the +// lespay command set, supported payment modules and token expiration time constant +func (api *PrivateLespayAPI) Info(ctx context.Context, remote bool, node string) (results tsInfoApiResults, err error) { + var resEnc []byte + resEnc, err = api.makeCall(ctx, remote, node, []byte{tsInfo}) + if err != nil { + return + } + err = rlp.DecodeBytes(resEnc, &results) + return +} + +// ReceiverInfo returns information about the specified payment receiver(s) if supported +func (api *PrivateLespayAPI) ReceiverInfo(ctx context.Context, remote bool, node string, receiverIDs []string) (results tsReceiverInfoApiResults, err error) { + params := tsReceiverInfoParams(receiverIDs) + enc, _ := rlp.EncodeToBytes(¶ms) + var resEnc []byte + resEnc, err = api.makeCall(ctx, remote, node, append([]byte{tsReceiverInfo}, enc...)) + if err != nil { + return + } + err = rlp.DecodeBytes(resEnc, &results) + return +} diff --git a/les/balance.go b/les/balance.go index 51cef15c803d..97fc0d76eef7 100644 --- a/les/balance.go +++ b/les/balance.go @@ -17,29 +17,56 @@ package les import ( + "math" "sync" "time" "github.com/ethereum/go-ethereum/common/mclock" ) +const maxBalance = math.MaxInt64 + const ( balanceCallbackQueue = iota balanceCallbackZero balanceCallbackCount ) +// expirationController controls the exponential expiration of positive and negative +// balances +type expirationController interface { + posExpiration(mclock.AbsTime) fixed64 + negExpiration(mclock.AbsTime) fixed64 +} + +// priceFactors determine the pricing policy (may apply either to positive or +// negative balances which may have different factors). +// - timeFactor is cost unit per nanosecond of connection time +// - capacityFactor is cost unit per nanosecond of connection time per 1000000 capacity +// - requestFactor is cost unit per request "realCost" unit +type priceFactors struct { + timeFactor, capacityFactor, requestFactor float64 +} + +func (p priceFactors) timePrice(cap uint64) float64 { + return p.timeFactor + float64(cap)*p.capacityFactor/1000000 +} + +func (p priceFactors) reqPrice() float64 { + return p.requestFactor +} + // balanceTracker keeps track of the positive and negative balances of a connected // client and calculates actual and projected future priority values required by // prque.LazyQueue. type balanceTracker struct { lock sync.Mutex clock mclock.Clock + exp expirationController stopped bool capacity uint64 balance balance - timeFactor, requestFactor float64 - negTimeFactor, negRequestFactor float64 + posFactor, negFactor priceFactors sumReqCost uint64 lastUpdate, nextUpdate, initTime mclock.AbsTime updateEvent mclock.Timer @@ -53,7 +80,7 @@ type balanceTracker struct { // balance represents a pair of positive and negative balances type balance struct { - pos, neg uint64 + pos, neg expiredValue } // balanceCallback represents a single callback that is activated when client priority @@ -65,6 +92,7 @@ type balanceCallback struct { } // init initializes balanceTracker +// Note: capacity should never be zero func (bt *balanceTracker) init(clock mclock.Clock, capacity uint64) { bt.clock = clock bt.initTime, bt.lastUpdate = clock.Now(), clock.Now() // Init timestamps @@ -81,10 +109,8 @@ func (bt *balanceTracker) stop(now mclock.AbsTime) { bt.stopped = true bt.addBalance(now) - bt.negTimeFactor = 0 - bt.negRequestFactor = 0 - bt.timeFactor = 0 - bt.requestFactor = 0 + bt.posFactor = priceFactors{0, 0, 0} + bt.negFactor = priceFactors{0, 0, 0} if bt.updateEvent != nil { bt.updateEvent.Stop() bt.updateEvent = nil @@ -95,10 +121,43 @@ func (bt *balanceTracker) stop(now mclock.AbsTime) { // first to disconnect. Positive balance translates to negative priority. If positive // balance is zero then negative balance translates to a positive priority. func (bt *balanceTracker) balanceToPriority(b balance) int64 { - if b.pos > 0 { - return ^int64(b.pos / bt.capacity) + if b.pos.base > 0 { + return -int64(b.pos.value(bt.exp.posExpiration(bt.clock.Now())) / bt.capacity) } - return int64(b.neg) + return int64(b.neg.value(bt.exp.negExpiration(bt.clock.Now()))) +} + +// posBalanceMissing calculates the missing amount of positive balance in order to +// connect at targetCapacity, stay connected for the given amount of time and then +// still have a priority of targetPriority +func (bt *balanceTracker) posBalanceMissing(targetPriority int64, targetCapacity uint64, after time.Duration) uint64 { + now := bt.clock.Now() + if targetPriority > 0 { + timePrice := bt.negFactor.timePrice(targetCapacity) + timeCost := uint64(float64(after) * timePrice) + negBalance := bt.balance.neg.value(bt.exp.negExpiration(now)) + if timeCost+negBalance < uint64(targetPriority) { + return 0 + } + if uint64(targetPriority) > negBalance && timePrice > 1e-100 { + if negTime := time.Duration(float64(uint64(targetPriority)-negBalance) / timePrice); negTime < after { + after -= negTime + } else { + after = 0 + } + } + targetPriority = 0 + } + timePrice := bt.posFactor.timePrice(targetCapacity) + posRequired := uint64(float64(-targetPriority)*float64(targetCapacity)+float64(after)*timePrice) + 1 + if posRequired >= maxBalance { + return math.MaxUint64 // target not reachable + } + posBalance := bt.balance.pos.value(bt.exp.posExpiration(now)) + if posRequired > posBalance { + return posRequired - posBalance + } + return 0 } // reducedBalance estimates the reduced balance at a given time in the fututre based @@ -106,20 +165,19 @@ func (bt *balanceTracker) balanceToPriority(b balance) int64 { func (bt *balanceTracker) reducedBalance(at mclock.AbsTime, avgReqCost float64) balance { dt := float64(at - bt.lastUpdate) b := bt.balance - if b.pos != 0 { - factor := bt.timeFactor + bt.requestFactor*avgReqCost - diff := uint64(dt * factor) - if diff <= b.pos { - b.pos -= diff + if b.pos.base != 0 { + factor := bt.posFactor.timePrice(bt.capacity) + bt.posFactor.reqPrice()*avgReqCost + diff := -int64(dt * factor) + dd := b.pos.add(diff, bt.exp.posExpiration(at)) + if dd == diff { dt = 0 } else { - dt -= float64(b.pos) / factor - b.pos = 0 + dt += float64(dd) / factor } } - if dt != 0 { - factor := bt.negTimeFactor + bt.negRequestFactor*avgReqCost - b.neg += uint64(dt * factor) + if dt > 0 { + factor := bt.negFactor.timePrice(bt.capacity) + bt.negFactor.reqPrice()*avgReqCost + b.neg.add(int64(dt*factor), bt.exp.negExpiration(at)) } return b } @@ -130,20 +188,23 @@ func (bt *balanceTracker) reducedBalance(at mclock.AbsTime, avgReqCost float64) // Note: the function assumes that the balance has been recently updated and // calculates the time starting from the last update. func (bt *balanceTracker) timeUntil(priority int64) (time.Duration, bool) { + now := bt.clock.Now() var dt float64 - if bt.balance.pos != 0 { - if bt.timeFactor < 1e-100 { + if bt.balance.pos.base != 0 { + posBalance := bt.balance.pos.value(bt.exp.posExpiration(now)) + timePrice := bt.posFactor.timePrice(bt.capacity) + if timePrice < 1e-100 { return 0, false } if priority < 0 { - newBalance := uint64(^priority) * bt.capacity - if newBalance > bt.balance.pos { + newBalance := uint64(-priority) * bt.capacity + if newBalance > posBalance { return 0, false } - dt = float64(bt.balance.pos-newBalance) / bt.timeFactor + dt = float64(posBalance-newBalance) / timePrice return time.Duration(dt), true } else { - dt = float64(bt.balance.pos) / bt.timeFactor + dt = float64(posBalance) / timePrice } } else { if priority < 0 { @@ -151,16 +212,19 @@ func (bt *balanceTracker) timeUntil(priority int64) (time.Duration, bool) { } } // if we have a positive balance then dt equals the time needed to get it to zero - if uint64(priority) > bt.balance.neg { - if bt.negTimeFactor < 1e-100 { + negBalance := bt.balance.neg.value(bt.exp.negExpiration(now)) + timePrice := bt.negFactor.timePrice(bt.capacity) + if uint64(priority) > negBalance { + if timePrice < 1e-100 { return 0, false } - dt += float64(uint64(priority)-bt.balance.neg) / bt.negTimeFactor + dt += float64(uint64(priority)-negBalance) / timePrice } return time.Duration(dt), true } // setCapacity updates the capacity value used for priority calculation +// Note: capacity should never be zero func (bt *balanceTracker) setCapacity(capacity uint64) { bt.lock.Lock() defer bt.lock.Unlock() @@ -262,26 +326,26 @@ func (bt *balanceTracker) updateAfter(dt time.Duration) { } // requestCost should be called after serving a request for the given peer -func (bt *balanceTracker) requestCost(cost uint64) { +func (bt *balanceTracker) requestCost(cost uint64) uint64 { bt.lock.Lock() defer bt.lock.Unlock() if bt.stopped { - return + return 0 } now := bt.clock.Now() bt.addBalance(now) fcost := float64(cost) - if bt.balance.pos != 0 { - if bt.requestFactor != 0 { - c := uint64(fcost * bt.requestFactor) - if bt.balance.pos >= c { - bt.balance.pos -= c + posExp := bt.exp.posExpiration(now) + if bt.balance.pos.base != 0 { + if bt.posFactor.reqPrice() != 0 { + c := -int64(fcost * bt.posFactor.reqPrice()) + cc := bt.balance.pos.add(c, posExp) + if c == cc { fcost = 0 } else { - fcost *= 1 - float64(bt.balance.pos)/float64(c) - bt.balance.pos = 0 + fcost *= 1 - float64(cc)/float64(c) } bt.checkCallbacks(now) } else { @@ -289,16 +353,17 @@ func (bt *balanceTracker) requestCost(cost uint64) { } } if fcost > 0 { - if bt.negRequestFactor != 0 { - bt.balance.neg += uint64(fcost * bt.negRequestFactor) + if bt.negFactor.reqPrice() != 0 { + bt.balance.neg.add(int64(fcost*bt.negFactor.reqPrice()), bt.exp.negExpiration(now)) bt.checkCallbacks(now) } } bt.sumReqCost += cost + return bt.balance.pos.value(posExp) } // getBalance returns the current positive and negative balance -func (bt *balanceTracker) getBalance(now mclock.AbsTime) (uint64, uint64) { +func (bt *balanceTracker) getBalance(now mclock.AbsTime) (expiredValue, expiredValue) { bt.lock.Lock() defer bt.lock.Unlock() @@ -307,7 +372,7 @@ func (bt *balanceTracker) getBalance(now mclock.AbsTime) (uint64, uint64) { } // setBalance sets the positive and negative balance to the given values -func (bt *balanceTracker) setBalance(pos, neg uint64) error { +func (bt *balanceTracker) setBalance(pos, neg expiredValue) error { bt.lock.Lock() defer bt.lock.Unlock() @@ -321,7 +386,7 @@ func (bt *balanceTracker) setBalance(pos, neg uint64) error { // setFactors sets the price factors. timeFactor is the price of a nanosecond of // connection while requestFactor is the price of a "realCost" unit. -func (bt *balanceTracker) setFactors(neg bool, timeFactor, requestFactor float64) { +func (bt *balanceTracker) setFactors(posFactor, negFactor priceFactors) { bt.lock.Lock() defer bt.lock.Unlock() @@ -330,13 +395,7 @@ func (bt *balanceTracker) setFactors(neg bool, timeFactor, requestFactor float64 } now := bt.clock.Now() bt.addBalance(now) - if neg { - bt.negTimeFactor = timeFactor - bt.negRequestFactor = requestFactor - } else { - bt.timeFactor = timeFactor - bt.requestFactor = requestFactor - } + bt.posFactor, bt.negFactor = posFactor, negFactor bt.checkCallbacks(now) } diff --git a/les/balance_test.go b/les/balance_test.go index b571c2cc5c2d..69d8caf95b01 100644 --- a/les/balance_test.go +++ b/les/balance_test.go @@ -23,18 +23,31 @@ import ( "github.com/ethereum/go-ethereum/common/mclock" ) +type zeroExpCtrl struct{} + +func (z zeroExpCtrl) posExpiration(mclock.AbsTime) fixed64 { + return 0 +} + +func (z zeroExpCtrl) negExpiration(mclock.AbsTime) fixed64 { + return 0 +} + +func expval(v uint64) expiredValue { + return expiredValue{base: v} +} + func TestSetBalance(t *testing.T) { var clock = &mclock.Simulated{} var inputs = []struct { - pos uint64 - neg uint64 + pos, neg expiredValue }{ - {1000, 0}, - {0, 1000}, - {1000, 1000}, + {expval(1000), expval(0)}, + {expval(0), expval(1000)}, + {expval(1000), expval(1000)}, } - tracker := balanceTracker{} + tracker := balanceTracker{exp: zeroExpCtrl{}} tracker.init(clock, 1000) defer tracker.stop(clock.Now()) @@ -53,14 +66,13 @@ func TestSetBalance(t *testing.T) { func TestBalanceTimeCost(t *testing.T) { var ( clock = &mclock.Simulated{} - tracker = balanceTracker{} + tracker = balanceTracker{exp: zeroExpCtrl{}} ) tracker.init(clock, 1000) defer tracker.stop(clock.Now()) - tracker.setFactors(false, 1, 1) - tracker.setFactors(true, 1, 1) + tracker.setFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) - tracker.setBalance(uint64(time.Minute), 0) // 1 minute time allowance + tracker.setBalance(expval(uint64(time.Minute)), expval(0)) // 1 minute time allowance var inputs = []struct { runTime time.Duration @@ -74,21 +86,21 @@ func TestBalanceTimeCost(t *testing.T) { } for _, i := range inputs { clock.Run(i.runTime) - if pos, _ := tracker.getBalance(clock.Now()); pos != i.expPos { + if pos, _ := tracker.getBalance(clock.Now()); pos != expval(i.expPos) { t.Fatalf("Positive balance mismatch, want %v, got %v", i.expPos, pos) } - if _, neg := tracker.getBalance(clock.Now()); neg != i.expNeg { + if _, neg := tracker.getBalance(clock.Now()); neg != expval(i.expNeg) { t.Fatalf("Negative balance mismatch, want %v, got %v", i.expNeg, neg) } } - tracker.setBalance(uint64(time.Minute), 0) // Refill 1 minute time allowance + tracker.setBalance(expval(uint64(time.Minute)), expval(0)) // Refill 1 minute time allowance for _, i := range inputs { clock.Run(i.runTime) - if pos, _ := tracker.getBalance(clock.Now()); pos != i.expPos { + if pos, _ := tracker.getBalance(clock.Now()); pos != expval(i.expPos) { t.Fatalf("Positive balance mismatch, want %v, got %v", i.expPos, pos) } - if _, neg := tracker.getBalance(clock.Now()); neg != i.expNeg { + if _, neg := tracker.getBalance(clock.Now()); neg != expval(i.expNeg) { t.Fatalf("Negative balance mismatch, want %v, got %v", i.expNeg, neg) } } @@ -97,14 +109,13 @@ func TestBalanceTimeCost(t *testing.T) { func TestBalanceReqCost(t *testing.T) { var ( clock = &mclock.Simulated{} - tracker = balanceTracker{} + tracker = balanceTracker{exp: zeroExpCtrl{}} ) tracker.init(clock, 1000) defer tracker.stop(clock.Now()) - tracker.setFactors(false, 1, 1) - tracker.setFactors(true, 1, 1) + tracker.setFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) - tracker.setBalance(uint64(time.Minute), 0) // 1 minute time serving time allowance + tracker.setBalance(expval(uint64(time.Minute)), expval(0)) // 1 minute time serving time allowance var inputs = []struct { reqCost uint64 expPos uint64 @@ -117,10 +128,10 @@ func TestBalanceReqCost(t *testing.T) { } for _, i := range inputs { tracker.requestCost(i.reqCost) - if pos, _ := tracker.getBalance(clock.Now()); pos != i.expPos { + if pos, _ := tracker.getBalance(clock.Now()); pos != expval(i.expPos) { t.Fatalf("Positive balance mismatch, want %v, got %v", i.expPos, pos) } - if _, neg := tracker.getBalance(clock.Now()); neg != i.expNeg { + if _, neg := tracker.getBalance(clock.Now()); neg != expval(i.expNeg) { t.Fatalf("Negative balance mismatch, want %v, got %v", i.expNeg, neg) } } @@ -129,25 +140,24 @@ func TestBalanceReqCost(t *testing.T) { func TestBalanceToPriority(t *testing.T) { var ( clock = &mclock.Simulated{} - tracker = balanceTracker{} + tracker = balanceTracker{exp: zeroExpCtrl{}} ) tracker.init(clock, 1000) // cap = 1000 defer tracker.stop(clock.Now()) - tracker.setFactors(false, 1, 1) - tracker.setFactors(true, 1, 1) + tracker.setFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) var inputs = []struct { pos uint64 neg uint64 priority int64 }{ - {1000, 0, ^int64(1)}, - {2000, 0, ^int64(2)}, // Higher balance, lower priority value + {1000, 0, -1}, + {2000, 0, -2}, // Higher balance, lower priority value {0, 0, 0}, {0, 1000, 1000}, } for _, i := range inputs { - tracker.setBalance(i.pos, i.neg) + tracker.setBalance(expval(i.pos), expval(i.neg)) priority := tracker.getPriority(clock.Now()) if priority != i.priority { t.Fatalf("Priority mismatch, want %v, got %v", i.priority, priority) @@ -158,30 +168,29 @@ func TestBalanceToPriority(t *testing.T) { func TestEstimatedPriority(t *testing.T) { var ( clock = &mclock.Simulated{} - tracker = balanceTracker{} + tracker = balanceTracker{exp: zeroExpCtrl{}} ) tracker.init(clock, 1000000000) // cap = 1000,000,000 defer tracker.stop(clock.Now()) - tracker.setFactors(false, 1, 1) - tracker.setFactors(true, 1, 1) + tracker.setFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) - tracker.setBalance(uint64(time.Minute), 0) + tracker.setBalance(expval(uint64(time.Minute)), expval(0)) var inputs = []struct { runTime time.Duration // time cost futureTime time.Duration // diff of future time reqCost uint64 // single request cost priority int64 // expected estimated priority }{ - {time.Second, time.Second, 0, ^int64(58)}, - {0, time.Second, 0, ^int64(58)}, + {time.Second, time.Second, 0, -58}, + {0, time.Second, 0, -58}, // 2 seconds time cost, 1 second estimated time cost, 10^9 request cost, // 10^9 estimated request cost per second. - {time.Second, time.Second, 1000000000, ^int64(55)}, + {time.Second, time.Second, 1000000000, -55}, // 3 seconds time cost, 3 second estimated time cost, 10^9*2 request cost, // 4*10^9 estimated request cost. - {time.Second, 3 * time.Second, 1000000000, ^int64(48)}, + {time.Second, 3 * time.Second, 1000000000, -48}, // All positive balance is used up {time.Second * 55, 0, 0, 0}, @@ -202,22 +211,21 @@ func TestEstimatedPriority(t *testing.T) { func TestCallbackChecking(t *testing.T) { var ( clock = &mclock.Simulated{} - tracker = balanceTracker{} + tracker = balanceTracker{exp: zeroExpCtrl{}} ) tracker.init(clock, 1000000) // cap = 1000,000 defer tracker.stop(clock.Now()) - tracker.setFactors(false, 1, 1) - tracker.setFactors(true, 1, 1) + tracker.setFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) var inputs = []struct { priority int64 expDiff time.Duration }{ - {^int64(500), time.Millisecond * 500}, + {-500, time.Millisecond * 500}, {0, time.Second}, {int64(time.Second), 2 * time.Second}, } - tracker.setBalance(uint64(time.Second), 0) + tracker.setBalance(expval(uint64(time.Second)), expval(0)) for _, i := range inputs { diff, _ := tracker.timeUntil(i.priority) if diff != i.expDiff { @@ -229,15 +237,14 @@ func TestCallbackChecking(t *testing.T) { func TestCallback(t *testing.T) { var ( clock = &mclock.Simulated{} - tracker = balanceTracker{} + tracker = balanceTracker{exp: zeroExpCtrl{}} ) tracker.init(clock, 1000) // cap = 1000 defer tracker.stop(clock.Now()) - tracker.setFactors(false, 1, 1) - tracker.setFactors(true, 1, 1) + tracker.setFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) callCh := make(chan struct{}, 1) - tracker.setBalance(uint64(time.Minute), 0) + tracker.setBalance(expval(uint64(time.Minute)), expval(0)) tracker.addCallback(balanceCallbackZero, 0, func() { callCh <- struct{}{} }) clock.Run(time.Minute) @@ -247,7 +254,7 @@ func TestCallback(t *testing.T) { t.Fatalf("Callback hasn't been called yet") } - tracker.setBalance(uint64(time.Minute), 0) + tracker.setBalance(expval(uint64(time.Minute)), expval(0)) tracker.addCallback(balanceCallbackZero, 0, func() { callCh <- struct{}{} }) tracker.removeCallback(balanceCallbackZero) diff --git a/les/client.go b/les/client.go index dfd0909778c8..643aa88dc429 100644 --- a/les/client.go +++ b/les/client.go @@ -50,6 +50,7 @@ type LightEthereum struct { lesCommons peers *serverPeerSet + srvr *p2p.Server reqDist *requestDistributor retriever *retrieveManager odr *LesOdr @@ -208,6 +209,12 @@ func (s *LightEthereum) APIs() []rpc.API { Service: NewPrivateLightAPI(&s.lesCommons), Public: false, }, + { + Namespace: "lespay", + Version: "1.0", + Service: NewPrivateLespayAPI(nil, s.peers, s.handler, s.srvr.DiscV5, nil), + Public: false, + }, }...) } @@ -237,6 +244,7 @@ func (s *LightEthereum) Protocols() []p2p.Protocol { // light ethereum protocol implementation. func (s *LightEthereum) Start(srvr *p2p.Server) error { log.Warn("Light client mode is an experimental feature") + s.srvr = srvr // Start bloom request workers. s.wg.Add(bloomServiceThreads) diff --git a/les/client_handler.go b/les/client_handler.go index d04574c8c7fa..56392632bcf3 100644 --- a/les/client_handler.go +++ b/les/client_handler.go @@ -40,6 +40,9 @@ type clientHandler struct { downloader *downloader.Downloader backend *LightEthereum + lespayReplyHandlers map[uint64]func([]byte, uint) bool + lespayReplyLock sync.Mutex + closeCh chan struct{} wg sync.WaitGroup // WaitGroup used to track all connected peers. syncDone func() // Test hooks when syncing is done. @@ -47,9 +50,10 @@ type clientHandler struct { func newClientHandler(ulcServers []string, ulcFraction int, checkpoint *params.TrustedCheckpoint, backend *LightEthereum) *clientHandler { handler := &clientHandler{ - checkpoint: checkpoint, - backend: backend, - closeCh: make(chan struct{}), + checkpoint: checkpoint, + backend: backend, + closeCh: make(chan struct{}), + lespayReplyHandlers: make(map[uint64]func([]byte, uint) bool), } if ulcServers != nil { ulc, err := newULC(ulcServers, ulcFraction) @@ -112,28 +116,48 @@ func (h *clientHandler) handle(p *serverPeer) error { p.Log().Debug("Light Ethereum handshake failed", "err", err) return err } - // Register the peer locally - if err := h.backend.peers.register(p); err != nil { - p.Log().Error("Light Ethereum peer registration failed", "err", err) - return err - } - serverConnectionGauge.Update(int64(h.backend.peers.len())) - connectedAt := mclock.Now() - defer func() { - h.backend.peers.unregister(p.id) + var ( + connectedAt mclock.AbsTime + lastActive bool + ) + activate := func() { + // Register the peer locally + if err := h.backend.peers.register(p); err != nil { + p.Log().Error("Light Ethereum peer registration failed", "err", err) + return + } + serverConnectionGauge.Update(int64(h.backend.peers.len())) + connectedAt = mclock.Now() + h.fetcher.announce(p, &announceData{Hash: p.headInfo.Hash, Number: p.headInfo.Number, Td: p.headInfo.Td}) + lastActive = true + } + deactivate := func() { + h.backend.peers.unregister(p) connectionTimer.Update(time.Duration(mclock.Now() - connectedAt)) serverConnectionGauge.Update(int64(h.backend.peers.len())) + lastActive = false + } + defer func() { + if lastActive { + deactivate() + } + h.backend.peers.disconnect(p.id) }() - h.fetcher.announce(p, &announceData{Hash: p.headInfo.Hash, Number: p.headInfo.Number, Td: p.headInfo.Td}) - // pool entry can be nil during the unit test. if p.poolEntry != nil { h.backend.serverPool.registered(p.poolEntry) } + // Spawn a main loop to handle all incoming messages. for { + if p.active && !lastActive { + activate() + } + if !p.active && lastActive { + deactivate() + } if err := h.handleMsg(p); err != nil { p.Log().Debug("Light Ethereum message handling failed", "err", err) p.fcServer.DumpLogs() @@ -157,7 +181,10 @@ func (h *clientHandler) handleMsg(p *serverPeer) error { } defer msg.Discard() - var deliverMsg *Msg + var ( + deliverMsg *Msg + responseError bool + ) // Handle the message depending on its contents switch msg.Code { @@ -193,13 +220,15 @@ func (h *clientHandler) handleMsg(p *serverPeer) error { case BlockHeadersMsg: p.Log().Trace("Received block header response message") var resp struct { - ReqID, BV uint64 - Headers []*types.Header + ReqID uint64 + SF stateFeedback + Headers []*types.Header } + resp.SF.protocolVersion = p.version if err := msg.Decode(&resp); err != nil { return errResp(ErrDecode, "msg %v: %v", msg, err) } - p.fcServer.ReceivedReply(resp.ReqID, resp.BV) + p.fcServer.ReceivedReply(resp.ReqID, resp.SF.BV) if h.fetcher.requestedID(resp.ReqID) { h.fetcher.deliverHeaders(p, resp.ReqID, resp.Headers) } else { @@ -210,13 +239,15 @@ func (h *clientHandler) handleMsg(p *serverPeer) error { case BlockBodiesMsg: p.Log().Trace("Received block bodies response") var resp struct { - ReqID, BV uint64 - Data []*types.Body + ReqID uint64 + SF stateFeedback + Data []*types.Body } + resp.SF.protocolVersion = p.version if err := msg.Decode(&resp); err != nil { return errResp(ErrDecode, "msg %v: %v", msg, err) } - p.fcServer.ReceivedReply(resp.ReqID, resp.BV) + p.fcServer.ReceivedReply(resp.ReqID, resp.SF.BV) deliverMsg = &Msg{ MsgType: MsgBlockBodies, ReqID: resp.ReqID, @@ -225,13 +256,15 @@ func (h *clientHandler) handleMsg(p *serverPeer) error { case CodeMsg: p.Log().Trace("Received code response") var resp struct { - ReqID, BV uint64 - Data [][]byte + ReqID uint64 + SF stateFeedback + Data [][]byte } + resp.SF.protocolVersion = p.version if err := msg.Decode(&resp); err != nil { return errResp(ErrDecode, "msg %v: %v", msg, err) } - p.fcServer.ReceivedReply(resp.ReqID, resp.BV) + p.fcServer.ReceivedReply(resp.ReqID, resp.SF.BV) deliverMsg = &Msg{ MsgType: MsgCode, ReqID: resp.ReqID, @@ -240,13 +273,15 @@ func (h *clientHandler) handleMsg(p *serverPeer) error { case ReceiptsMsg: p.Log().Trace("Received receipts response") var resp struct { - ReqID, BV uint64 - Receipts []types.Receipts + ReqID uint64 + SF stateFeedback + Receipts []types.Receipts } + resp.SF.protocolVersion = p.version if err := msg.Decode(&resp); err != nil { return errResp(ErrDecode, "msg %v: %v", msg, err) } - p.fcServer.ReceivedReply(resp.ReqID, resp.BV) + p.fcServer.ReceivedReply(resp.ReqID, resp.SF.BV) deliverMsg = &Msg{ MsgType: MsgReceipts, ReqID: resp.ReqID, @@ -255,13 +290,15 @@ func (h *clientHandler) handleMsg(p *serverPeer) error { case ProofsV2Msg: p.Log().Trace("Received les/2 proofs response") var resp struct { - ReqID, BV uint64 - Data light.NodeList + ReqID uint64 + SF stateFeedback + Data light.NodeList } + resp.SF.protocolVersion = p.version if err := msg.Decode(&resp); err != nil { return errResp(ErrDecode, "msg %v: %v", msg, err) } - p.fcServer.ReceivedReply(resp.ReqID, resp.BV) + p.fcServer.ReceivedReply(resp.ReqID, resp.SF.BV) deliverMsg = &Msg{ MsgType: MsgProofsV2, ReqID: resp.ReqID, @@ -270,13 +307,15 @@ func (h *clientHandler) handleMsg(p *serverPeer) error { case HelperTrieProofsMsg: p.Log().Trace("Received helper trie proof response") var resp struct { - ReqID, BV uint64 - Data HelperTrieResps + ReqID uint64 + SF stateFeedback + Data HelperTrieResps } + resp.SF.protocolVersion = p.version if err := msg.Decode(&resp); err != nil { return errResp(ErrDecode, "msg %v: %v", msg, err) } - p.fcServer.ReceivedReply(resp.ReqID, resp.BV) + p.fcServer.ReceivedReply(resp.ReqID, resp.SF.BV) deliverMsg = &Msg{ MsgType: MsgHelperTrieProofs, ReqID: resp.ReqID, @@ -285,13 +324,15 @@ func (h *clientHandler) handleMsg(p *serverPeer) error { case TxStatusMsg: p.Log().Trace("Received tx status response") var resp struct { - ReqID, BV uint64 - Status []light.TxStatus + ReqID uint64 + SF stateFeedback + Status []light.TxStatus } + resp.SF.protocolVersion = p.version if err := msg.Decode(&resp); err != nil { return errResp(ErrDecode, "msg %v: %v", msg, err) } - p.fcServer.ReceivedReply(resp.ReqID, resp.BV) + p.fcServer.ReceivedReply(resp.ReqID, resp.SF.BV) deliverMsg = &Msg{ MsgType: MsgTxStatus, ReqID: resp.ReqID, @@ -302,13 +343,32 @@ func (h *clientHandler) handleMsg(p *serverPeer) error { h.backend.retriever.frozen(p) p.Log().Debug("Service stopped") case ResumeMsg: - var bv uint64 - if err := msg.Decode(&bv); err != nil { + var sf stateFeedback + sf.protocolVersion = p.version + if err := msg.Decode(&sf); err != nil { return errResp(ErrDecode, "msg %v: %v", msg, err) } - p.fcServer.ResumeFreeze(bv) + p.fcServer.ResumeFreeze(sf.BV) p.unfreeze() p.Log().Debug("Service resumed") + case LespayReplyMsg: + p.Log().Trace("Received tx status response") + var resp struct { + ReqID uint64 + Reply lespayReply + } + if err := msg.Decode(&resp); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + h.lespayReplyLock.Lock() + if handler := h.lespayReplyHandlers[resp.ReqID]; handler != nil { + delete(h.lespayReplyHandlers, resp.ReqID) + responseError = !handler(resp.Reply.Reply, resp.Reply.Delay) + } else { + responseError = true + } + h.lespayReplyLock.Unlock() + default: p.Log().Trace("Received invalid message", "code", msg.Code) return errResp(ErrInvalidMsgCode, "%v", msg.Code) @@ -316,17 +376,48 @@ func (h *clientHandler) handleMsg(p *serverPeer) error { // Deliver the received response to retriever. if deliverMsg != nil { if err := h.backend.retriever.deliver(p, deliverMsg); err != nil { - p.errCount++ - if p.errCount > maxResponseErrors { - return err - } + responseError = true + } + } + if responseError { + p.errCount++ + if p.errCount > maxResponseErrors { + return err } } return nil } +// makeLespayCall sends a lespay command through an LES connection and registers +// a response handler. It returns a cancel function that removes the response +// handler and calls it with a nil parameter if the response has not arrived yet. +func (h *clientHandler) makeLespayCall(p *serverPeer, cmd []byte, handler func([]byte, uint) bool) func() bool { + reqID := genReqID() + h.lespayReplyLock.Lock() + h.lespayReplyHandlers[reqID] = handler + h.lespayReplyLock.Unlock() + if p.sendLespay(reqID, cmd) != nil { + h.lespayReplyLock.Lock() + delete(h.lespayReplyHandlers, reqID) + h.lespayReplyLock.Unlock() + return nil + } + return func() bool { + h.lespayReplyLock.Lock() + cancel := h.lespayReplyHandlers[reqID] != nil + if cancel { + delete(h.lespayReplyHandlers, reqID) + } + h.lespayReplyLock.Unlock() + if cancel { + handler(nil, 0) + } + return cancel + } +} + func (h *clientHandler) removePeer(id string) { - h.backend.peers.unregister(id) + h.backend.peers.disconnect(id) } type peerConnection struct { diff --git a/les/clientdb.go b/les/clientdb.go new file mode 100644 index 000000000000..f24d499479b2 --- /dev/null +++ b/les/clientdb.go @@ -0,0 +1,339 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package les + +import ( + "bytes" + "encoding/binary" + "io" + "sync" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/rlp" + lru "github.com/hashicorp/golang-lru" +) + +const balanceCacheLimit = 8192 // the maximum number of cached items in service token balance queue + +// tokenBalance is a wrapper of expiredValue which represents the service token +// balance of clients. The balance value will decay exponentially over time and +// can be deleted when the amount is small enough. +type tokenBalance struct { + value expiredValue +} + +// EncodeRLP implements rlp.Encoder +func (b *tokenBalance) EncodeRLP(w io.Writer) error { + return rlp.Encode(w, []interface{}{b.value.base, b.value.exp}) +} + +// DecodeRLP implements rlp.Decoder +func (b *tokenBalance) DecodeRLP(s *rlp.Stream) error { + var entry struct { + Base, Exp uint64 + } + if err := s.Decode(&entry); err != nil { + return err + } + b.value = expiredValue{base: entry.Base, exp: entry.Exp} + return nil +} + +// currencyBalance represents the client's currency balance. +type currencyBalance struct { + amount uint64 + typ string +} + +// EncodeRLP implements rlp.Encoder +func (b *currencyBalance) EncodeRLP(w io.Writer) error { + return rlp.Encode(w, []interface{}{b.amount, b.typ}) +} + +// DecodeRLP implements rlp.Decoder +func (b *currencyBalance) DecodeRLP(s *rlp.Stream) error { + var entry struct { + Amount uint64 + Type string + } + if err := s.Decode(&entry); err != nil { + return err + } + b.amount, b.typ = entry.Amount, entry.Type + return nil +} + +const ( + // nodeDBVersion is the version identifier of the node data in db + // + // Changelog: + // * Replace `lastTotal` with `meta` in positive balance: version 0=>1 + // * Rework balance, add currency balance: version 1=>2 + nodeDBVersion = 2 + + // dbCleanupCycle is the cycle of db for useless data cleanup + dbCleanupCycle = time.Hour +) + +var ( + posBalancePrefix = []byte("pb:") // dbVersion(uint16 big endian) + posBalancePrefix + id -> positive balance + negBalancePrefix = []byte("nb:") // dbVersion(uint16 big endian) + negBalancePrefix + ip -> negative balance + curBalancePrefix = []byte("cb:") // dbVersion(uint16 big endian) + curBalancePrefix + id -> currency balance + paymentReceiverPrefix = []byte("pr:") // dbVersion(uint16 big endian) + paymentReceiverPrefix + id + "receiverName:" -> receiver namespace + expirationKey = []byte("expiration:") // dbVersion(uint16 big endian) + expirationKey -> posExp, negExp +) + +type atomicWriteLock struct { + released chan struct{} + batch ethdb.Batch +} + +type nodeDB struct { + db ethdb.Database + cache *lru.Cache + clock mclock.Clock + closeCh chan struct{} + evictCallBack func(mclock.AbsTime, bool, tokenBalance) bool // Callback to determine whether the balance can be evicted. + idLockMutex sync.Mutex + idLocks map[string]atomicWriteLock + cleanupHook func() // Test hook used for testing +} + +func newNodeDB(db ethdb.Database, clock mclock.Clock) *nodeDB { + var buff [2]byte + binary.BigEndian.PutUint16(buff[:], uint16(nodeDBVersion)) + + cache, _ := lru.New(balanceCacheLimit) + ndb := &nodeDB{ + db: rawdb.NewTable(db, string(buff[:])), + cache: cache, + clock: clock, + closeCh: make(chan struct{}), + idLocks: make(map[string]atomicWriteLock), + } + go ndb.expirer() + return ndb +} + +func (db *nodeDB) close() { + close(db.closeCh) +} + +func (db *nodeDB) atomicWriteLock(id []byte) ethdb.KeyValueWriter { + db.idLockMutex.Lock() + for { + ch := db.idLocks[string(id)].released + if ch == nil { + break + } + db.idLockMutex.Unlock() + <-ch + db.idLockMutex.Lock() + } + batch := db.db.NewBatch() + db.idLocks[string(id)] = atomicWriteLock{ + released: make(chan struct{}), + batch: batch, + } + db.idLockMutex.Unlock() + return batch +} + +func (db *nodeDB) atomicWriteUnlock(id []byte) { + db.idLockMutex.Lock() + awl := db.idLocks[string(id)] + awl.batch.Write() + close(awl.released) + delete(db.idLocks, string(id)) + db.idLockMutex.Unlock() +} + +func (db *nodeDB) writer(id []byte) ethdb.KeyValueWriter { + db.idLockMutex.Lock() + batch := db.idLocks[string(id)].batch + db.idLockMutex.Unlock() + if batch == nil { + return db.db + } + return batch +} + +func idKey(id []byte, neg bool) []byte { + prefix := posBalancePrefix + if neg { + prefix = negBalancePrefix + } + return append(prefix, id...) +} + +func receiverPrefix(id enode.ID, receiver string) []byte { + return append(append(paymentReceiverPrefix, id.Bytes()...), []byte(receiver+":")...) +} + +func (db *nodeDB) getExpiration() (fixed64, fixed64) { + blob, err := db.db.Get(expirationKey) + if err != nil || len(blob) != 16 { + return 0, 0 + } + return fixed64(binary.BigEndian.Uint64(blob[:8])), fixed64(binary.BigEndian.Uint64(blob[8:16])) +} + +func (db *nodeDB) setExpiration(pos, neg fixed64) { + var buff [16]byte + binary.BigEndian.PutUint64(buff[:8], uint64(pos)) + binary.BigEndian.PutUint64(buff[8:16], uint64(neg)) + db.db.Put(expirationKey, buff[:16]) +} + +func (db *nodeDB) getCurrencyBalance(id enode.ID) currencyBalance { + var b currencyBalance + enc, err := db.db.Get(append(curBalancePrefix, id.Bytes()...)) + if err != nil || len(enc) == 0 { + return b + } + if err := rlp.DecodeBytes(enc, &b); err != nil { + log.Crit("Failed to decode positive balance", "err", err) + } + return b +} + +func (db *nodeDB) setCurrencyBalance(id enode.ID, b currencyBalance) { + enc, err := rlp.EncodeToBytes(&(b)) + if err != nil { + log.Crit("Failed to encode currency balance", "err", err) + } + db.writer(id.Bytes()).Put(append(curBalancePrefix, id.Bytes()...), enc) +} + +func (db *nodeDB) getOrNewBalance(id []byte, neg bool) tokenBalance { + key := idKey(id, neg) + item, exist := db.cache.Get(string(key)) + if exist { + return item.(tokenBalance) + } + var b tokenBalance + enc, err := db.db.Get(key) + if err != nil || len(enc) == 0 { + return b + } + if err := rlp.DecodeBytes(enc, &b); err != nil { + log.Crit("Failed to decode positive balance", "err", err) + } + db.cache.Add(string(key), b) + return b +} + +func (db *nodeDB) setBalance(id []byte, neg bool, b tokenBalance) { + key := idKey(id, neg) + enc, err := rlp.EncodeToBytes(&(b)) + if err != nil { + log.Crit("Failed to encode positive balance", "err", err) + } + if neg { + db.db.Put(key, enc) + } else { + db.writer(id).Put(key, enc) + } + db.cache.Add(string(key), b) +} + +func (db *nodeDB) delBalance(id []byte, neg bool) { + key := idKey(id, neg) + if neg { + db.db.Delete(key) + } else { + db.writer(id).Delete(key) + } + db.cache.Remove(string(key)) +} + +// getPosBalanceIDs returns a lexicographically ordered list of IDs of accounts +// with a positive balance +func (db *nodeDB) getPosBalanceIDs(start, stop enode.ID, maxCount int) (result []enode.ID) { + if maxCount <= 0 { + return + } + it := db.db.NewIteratorWithStart(idKey(start.Bytes(), false)) + defer it.Release() + for i := len(stop[:]) - 1; i >= 0; i-- { + stop[i]-- + if stop[i] != 255 { + break + } + } + stopKey := idKey(stop.Bytes(), false) + keyLen := len(stopKey) + + for it.Next() { + var id enode.ID + if len(it.Key()) != keyLen || bytes.Compare(it.Key(), stopKey) == 1 { + return + } + copy(id[:], it.Key()[keyLen-len(id):]) + result = append(result, id) + if len(result) == maxCount { + return + } + } + return +} + +func (db *nodeDB) expirer() { + for { + select { + case <-db.clock.After(dbCleanupCycle): + db.expireNodes() + case <-db.closeCh: + return + } + } +} + +// expireNodes iterates the whole node db and checks whether the +// token balances can deleted. +func (db *nodeDB) expireNodes() { + var ( + visited int + deleted int + start = time.Now() + ) + for index, prefix := range [][]byte{posBalancePrefix, negBalancePrefix} { + iter := db.db.NewIteratorWithPrefix(prefix) + for iter.Next() { + visited += 1 + var balance tokenBalance + if err := rlp.DecodeBytes(iter.Value(), &balance); err != nil { + log.Crit("Failed to decode negative balance", "err", err) + } + if db.evictCallBack != nil && db.evictCallBack(db.clock.Now(), index != 0, balance) { + deleted += 1 + db.db.Delete(iter.Key()) + } + } + } + // Invoke testing hook if it's not nil. + if db.cleanupHook != nil { + db.cleanupHook() + } + log.Debug("Expire nodes", "visited", visited, "deleted", deleted, "elapsed", common.PrettyDuration(time.Since(start))) +} diff --git a/les/clientdb_test.go b/les/clientdb_test.go new file mode 100644 index 000000000000..ac2181fb667b --- /dev/null +++ b/les/clientdb_test.go @@ -0,0 +1,139 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package les + +import ( + "reflect" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/p2p/enode" +) + +func TestNodeDB(t *testing.T) { + ndb := newNodeDB(rawdb.NewMemoryDatabase(), mclock.System{}) + defer ndb.close() + + var cases = []struct { + id enode.ID + ip string + balance tokenBalance + positive bool + }{ + {enode.ID{0x00, 0x01, 0x02}, "", tokenBalance{value: expval(100)}, true}, + {enode.ID{0x00, 0x01, 0x02}, "", tokenBalance{value: expval(200)}, true}, + {enode.ID{}, "127.0.0.1", tokenBalance{value: expval(100)}, false}, + {enode.ID{}, "127.0.0.1", tokenBalance{value: expval(200)}, false}, + } + for _, c := range cases { + if c.positive { + ndb.setBalance(c.id.Bytes(), false, c.balance) + if pb := ndb.getOrNewBalance(c.id.Bytes(), false); !reflect.DeepEqual(pb, c.balance) { + t.Fatalf("Positive balance mismatch, want %v, got %v", c.balance, pb) + } + } else { + ndb.setBalance([]byte(c.ip), true, c.balance) + if nb := ndb.getOrNewBalance([]byte(c.ip), true); !reflect.DeepEqual(nb, c.balance) { + t.Fatalf("Negative balance mismatch, want %v, got %v", c.balance, nb) + } + } + } + for _, c := range cases { + if c.positive { + ndb.delBalance(c.id.Bytes(), false) + if pb := ndb.getOrNewBalance(c.id.Bytes(), false); !reflect.DeepEqual(pb, tokenBalance{}) { + t.Fatalf("Positive balance mismatch, want %v, got %v", tokenBalance{}, pb) + } + } else { + ndb.delBalance([]byte(c.ip), true) + if nb := ndb.getOrNewBalance([]byte(c.ip), true); !reflect.DeepEqual(nb, tokenBalance{}) { + t.Fatalf("Negative balance mismatch, want %v, got %v", tokenBalance{}, nb) + } + } + } + posExp, negExp := fixed64(1000), fixed64(2000) + ndb.setExpiration(posExp, negExp) + if pos, neg := ndb.getExpiration(); pos != posExp || neg != negExp { + t.Fatalf("Expiration mismatch, want %v / %v, got %v / %v", posExp, negExp, pos, neg) + } + curBalance := currencyBalance{typ: "ETH", amount: 10000} + ndb.setCurrencyBalance(enode.ID{0x01, 0x02}, curBalance) + if got := ndb.getCurrencyBalance(enode.ID{0x01, 0x02}); !reflect.DeepEqual(got, curBalance) { + t.Fatalf("Currency balance mismatch, want %v, got %v", curBalance, got) + } +} + +func TestNodeDBExpiration(t *testing.T) { + var ( + iterated int + done = make(chan struct{}, 1) + ) + callback := func(now mclock.AbsTime, neg bool, b tokenBalance) bool { + iterated += 1 + return true + } + clock := &mclock.Simulated{} + ndb := newNodeDB(rawdb.NewMemoryDatabase(), clock) + defer ndb.close() + ndb.evictCallBack = callback + ndb.cleanupHook = func() { done <- struct{}{} } + + var cases = []struct { + id []byte + neg bool + balance tokenBalance + }{ + {[]byte{0x01, 0x02}, false, tokenBalance{value: expval(1)}}, + {[]byte{0x03, 0x04}, false, tokenBalance{value: expval(1)}}, + {[]byte{0x05, 0x06}, false, tokenBalance{value: expval(1)}}, + {[]byte{0x07, 0x08}, false, tokenBalance{value: expval(1)}}, + + {[]byte("127.0.0.1"), true, tokenBalance{value: expval(1)}}, + {[]byte("127.0.0.2"), true, tokenBalance{value: expval(1)}}, + {[]byte("127.0.0.3"), true, tokenBalance{value: expval(1)}}, + {[]byte("127.0.0.4"), true, tokenBalance{value: expval(1)}}, + } + for _, c := range cases { + ndb.setBalance(c.id, c.neg, c.balance) + } + clock.WaitForTimers(1) + clock.Run(time.Hour + time.Minute) + select { + case <-done: + case <-time.NewTimer(time.Second).C: + t.Fatalf("timeout") + } + if iterated != 8 { + t.Fatalf("Failed to evict useless balances, want %v, got %d", 8, iterated) + } + + for _, c := range cases { + ndb.setBalance(c.id, c.neg, c.balance) + } + clock.WaitForTimers(1) + clock.Run(time.Hour + time.Minute) + select { + case <-done: + case <-time.NewTimer(time.Second).C: + t.Fatalf("timeout") + } + if iterated != 16 { + t.Fatalf("Failed to evict useless balances, want %v, got %d", 16, iterated) + } +} diff --git a/les/clientpool.go b/les/clientpool.go index b01c825a7a96..035f8f2c2708 100644 --- a/les/clientpool.go +++ b/les/clientpool.go @@ -17,40 +17,35 @@ package les import ( - "bytes" - "encoding/binary" "fmt" - "io" "math" "sync" "time" - "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/common/prque" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p/enode" - "github.com/ethereum/go-ethereum/rlp" - lru "github.com/hashicorp/golang-lru" ) const ( - negBalanceExpTC = time.Hour // time constant for exponentially reducing negative balance - fixedPointMultiplier = 0x1000000 // constant to convert logarithms to fixed point format - lazyQueueRefresh = time.Second * 10 // refresh period of the connected queue - persistCumulativeTimeRefresh = time.Minute * 5 // refresh period of the cumulative running time persistence - posBalanceCacheLimit = 8192 // the maximum number of cached items in positive balance queue - negBalanceCacheLimit = 8192 // the maximum number of cached items in negative balance queue - - // connectedBias is applied to already connected clients So that + defaultPosExpTC = 36000 // default time constant (in seconds) for exponentially reducing positive balance + defaultNegExpTC = 3600 // default time constant (in seconds) for exponentially reducing negative balance + lazyQueueRefresh = time.Second * 10 // refresh period of the connected queue + tryActivatePeriod = time.Second * 5 // periodically check whether inactive clients can be activated + dropInactiveCycles = 2 // number of activation check periods after non-priority inactive peers are dropped + persistExpirationRefresh = time.Minute * 5 // refresh period of the token expiration persistence + freeRatioTC = time.Hour // time constant of token supply control based on free service availability + + // activeBias is applied to already connected clients So that // already connected client won't be kicked out very soon and we // can ensure all connected clients can have enough time to request // or sync some data. // // todo(rjl493456442) make it configurable. It can be the option of // free trial time! - connectedBias = time.Minute * 3 + activeBias = time.Minute * 3 ) // clientPool implements a client database that assigns a priority to each client @@ -60,7 +55,7 @@ const ( // then negative balance is accumulated. // // Balance tracking and priority calculation for connected clients is done by -// balanceTracker. connectedQueue ensures that clients with the lowest positive or +// balanceTracker. activeQueue ensures that clients with the lowest positive or // highest negative balance get evicted when the total capacity allowance is full // and new clients with a better balance want to connect. // @@ -69,11 +64,8 @@ const ( // each client can have several minutes of connection time. // // Balances of disconnected clients are stored in nodeDB including positive balance -// and negative banalce. Negative balance is transformed into a logarithmic form -// with a constantly shifting linear offset in order to implement an exponential -// decrease. Besides nodeDB will have a background thread to check the negative -// balance of disconnected client. If the balance is low enough, then the record -// will be dropped. +// and negative banalce. Boeth positive balance and negative balance will decrease +// exponentially. If the balance is low enough, then the record will be dropped. type clientPool struct { ndb *nodeDB lock sync.Mutex @@ -82,19 +74,32 @@ type clientPool struct { closed bool removePeer func(enode.ID) - connectedMap map[enode.ID]*clientInfo - connectedQueue *prque.LazyQueue + connectedMap map[enode.ID]*clientInfo + activeQueue *prque.LazyQueue + inactiveQueue *prque.Prque + dropInactivePeers map[uint64][]*clientInfo + dropInactiveCounter uint64 + + activeBalances, inactiveBalances expiredValue + lastConnectedBalanceUpdate mclock.AbsTime + freeRatio, averageFreeRatio float64 defaultPosFactors, defaultNegFactors priceFactors - connLimit int // The maximum number of connections that clientpool can support - capLimit uint64 // The maximum cumulative capacity that clientpool can support - connectedCap uint64 // The sum of the capacity of the current clientpool connected - priorityConnected uint64 // The sum of the capacity of currently connected priority clients - freeClientCap uint64 // The capacity value of each free client - startTime mclock.AbsTime // The timestamp at which the clientpool started running - cumulativeTime int64 // The cumulative running time of clientpool at the start point. - disableBias bool // Disable connection bias(used in testing) + activeLimit int // The maximum number of connections that clientpool can support + capLimit uint64 // The maximum cumulative capacity that clientpool can support + activeCap uint64 // The sum of the capacity of the current clientpool connected + priorityActive uint64 // The sum of the capacity of currently connected priority clients + minCap uint64 // The minimal capacity value allowed for any client + freeClientCap uint64 // The capacity value of each free client + disableBias bool // Disable connection bias(used in testing) + + // fields in this group are protected by expLock + expLock sync.RWMutex + posExpTC, negExpTC uint64 + posExp, negExp fixed64 + posExpTCi, negExpTCi float64 // already inverted (logMultiplier/time) + freeRatioLastUpdate mclock.AbsTime } // clientPoolPeer represents a client peer in the pool. @@ -106,87 +111,146 @@ type clientPoolPeer interface { ID() enode.ID freeClientId() string updateCapacity(uint64) - freezeClient() + freeze() } -// clientInfo represents a connected client +// clientInfo defines all information required by clientpool. type clientInfo struct { - address string - id enode.ID - connectedAt mclock.AbsTime - capacity uint64 - priority bool - pool *clientPool - peer clientPoolPeer - queueIndex int // position in connectedQueue - balanceTracker balanceTracker - posFactors, negFactors priceFactors - balanceMetaInfo string -} - -// connSetIndex callback updates clientInfo item index in connectedQueue + id enode.ID + address string + active bool + capacity uint64 + priority bool + pool *clientPool + peer clientPoolPeer + connectedAt mclock.AbsTime + queueIndex int + balanceTracker balanceTracker + posFactors priceFactors + negFactors priceFactors +} + +// connSetIndex callback updates clientInfo item index in activeQueue func connSetIndex(a interface{}, index int) { a.(*clientInfo).queueIndex = index } -// connPriority callback returns actual priority of clientInfo item in connectedQueue +// connPriority callback returns actual priority of clientInfo item in activeQueue func connPriority(a interface{}, now mclock.AbsTime) int64 { c := a.(*clientInfo) return c.balanceTracker.getPriority(now) } -// connMaxPriority callback returns estimated maximum priority of clientInfo item in connectedQueue +// connMaxPriority callback returns estimated maximum priority of clientInfo item in activeQueue func connMaxPriority(a interface{}, until mclock.AbsTime) int64 { c := a.(*clientInfo) pri := c.balanceTracker.estimatedPriority(until, true) c.balanceTracker.addCallback(balanceCallbackQueue, pri+1, func() { c.pool.lock.Lock() - if c.queueIndex != -1 { - c.pool.connectedQueue.Update(c.queueIndex) + if c.active && c.queueIndex != -1 { + c.pool.activeQueue.Update(c.queueIndex) } c.pool.lock.Unlock() }) return pri } -// priceFactors determine the pricing policy (may apply either to positive or -// negative balances which may have different factors). -// - timeFactor is cost unit per nanosecond of connection time -// - capacityFactor is cost unit per nanosecond of connection time per 1000000 capacity -// - requestFactor is cost unit per request "realCost" unit -type priceFactors struct { - timeFactor, capacityFactor, requestFactor float64 -} - // newClientPool creates a new client pool -func newClientPool(db ethdb.Database, freeClientCap uint64, clock mclock.Clock, removePeer func(enode.ID)) *clientPool { +func newClientPool(db ethdb.Database, minCap, freeClientCap uint64, clock mclock.Clock, removePeer func(enode.ID)) *clientPool { ndb := newNodeDB(db, clock) + posExp, negExp := ndb.getExpiration() pool := &clientPool{ - ndb: ndb, - clock: clock, - connectedMap: make(map[enode.ID]*clientInfo), - connectedQueue: prque.NewLazyQueue(connSetIndex, connPriority, connMaxPriority, clock, lazyQueueRefresh), - freeClientCap: freeClientCap, - removePeer: removePeer, - startTime: clock.Now(), - cumulativeTime: ndb.getCumulativeTime(), - stopCh: make(chan struct{}), - } - // If the negative balance of free client is even lower than 1, - // delete this entry. - ndb.nbEvictCallBack = func(now mclock.AbsTime, b negBalance) bool { - balance := math.Exp(float64(b.logValue-pool.logOffset(now)) / fixedPointMultiplier) - return balance <= 1 + ndb: ndb, + clock: clock, + connectedMap: make(map[enode.ID]*clientInfo), + activeQueue: prque.NewLazyQueue(connSetIndex, connPriority, connMaxPriority, clock, lazyQueueRefresh), + inactiveQueue: prque.New(connSetIndex), + dropInactivePeers: make(map[uint64][]*clientInfo), + minCap: minCap, + freeClientCap: freeClientCap, + removePeer: removePeer, + freeRatioLastUpdate: clock.Now(), + posExp: posExp, + negExp: negExp, + freeRatio: 1, + averageFreeRatio: 1, + stopCh: make(chan struct{}), + } + // set default expiration constants used by tests + // Note: server overwrites this if token sale is active + pool.setExpirationTCs(0, defaultNegExpTC) + // calculate total token balance amount + var start enode.ID + for { + ids := pool.ndb.getPosBalanceIDs(start, enode.ID{}, 1000) + var stop bool + l := len(ids) + if l == 1000 { + l-- + start = ids[l] + } else { + stop = true + } + for i := 0; i < l; i++ { + pool.inactiveBalances.addExp(pool.ndb.getOrNewBalance(ids[i].Bytes(), false).value) + } + if stop { + break + } + } + // The positive and negative balances of clients are stored in database + // and both of these decay exponentially over time. Delete them if the + // value is small enough. + ndb.evictCallBack = func(now mclock.AbsTime, neg bool, b tokenBalance) bool { + var expiration fixed64 + if neg { + expiration = pool.negExpiration(now) + } else { + expiration = pool.posExpiration(now) + } + return b.value.value(expiration) <= uint64(time.Second) } go func() { for { select { case <-clock.After(lazyQueueRefresh): pool.lock.Lock() - pool.connectedQueue.Refresh() + pool.activeQueue.Refresh() + pool.lock.Unlock() + case <-pool.stopCh: + return + } + } + }() + go func() { + for { + select { + case <-clock.After(persistExpirationRefresh): + pool.lock.Lock() + now := pool.clock.Now() + posExp := pool.posExpiration(now) + negExp := pool.negExpiration(now) + pool.lock.Unlock() + pool.ndb.setExpiration(posExp, negExp) + case <-pool.stopCh: + return + } + } + }() + go func() { + for { + select { + case <-clock.After(tryActivatePeriod): + pool.lock.Lock() + pool.tryActivateClients() + for _, c := range pool.dropInactivePeers[pool.dropInactiveCounter] { + if _, ok := pool.connectedMap[c.id]; ok && !c.active && !c.priority { + pool.drop(c.peer, true) + } + } + delete(pool.dropInactivePeers, pool.dropInactiveCounter) + pool.dropInactiveCounter++ pool.lock.Unlock() - case <-clock.After(persistCumulativeTimeRefresh): - pool.ndb.setCumulativeTime(pool.logOffset(clock.Now())) case <-pool.stopCh: return } @@ -201,122 +265,212 @@ func (f *clientPool) stop() { f.lock.Lock() f.closed = true f.lock.Unlock() - f.ndb.setCumulativeTime(f.logOffset(f.clock.Now())) + now := f.clock.Now() + f.ndb.setExpiration(f.posExpiration(now), f.negExpiration(now)) f.ndb.close() } +// updateFreeRatio updates freeRatio, averageFreeRatio, posExp and negExp based +// on free service availability. Should be called after capLimit or priorityActive +// is changed. +func (f *clientPool) updateFreeRatio() { + f.freeRatio = 0 + if f.priorityActive < f.capLimit { + freeCap := f.capLimit - f.priorityActive + if freeCap > f.freeClientCap { + freeCapThreshold := f.capLimit / 4 + if freeCap > freeCapThreshold { + f.freeRatio = 1 + } else { + f.freeRatio = float64(freeCap-f.freeClientCap) / float64(freeCapThreshold-f.freeClientCap) + } + } + } + f.expLock.Lock() + now := f.clock.Now() + dt := now - f.freeRatioLastUpdate + if dt < 0 { + dt = 0 + } + f.averageFreeRatio -= (f.freeRatio - f.averageFreeRatio) * math.Expm1(-float64(dt)/float64(freeRatioTC)) + f.freeRatioLastUpdate = now + + f.posExp += fixed64(float64(dt) * f.posExpTCi * f.freeRatio) + f.negExp += fixed64(float64(dt) * f.negExpTCi * f.freeRatio) + f.expLock.Unlock() +} + +// setExpirationTCs sets positive and negative token expiration time constants. +// Specified in seconds, 0 means infinite (no expiration). +func (f *clientPool) setExpirationTCs(pos, neg uint64) { + f.lock.Lock() + f.updateFreeRatio() + f.lock.Unlock() + + f.expLock.Lock() + f.posExpTC, f.negExpTC = pos, neg + if pos > 0 { + f.posExpTCi = fixedFactor / float64(pos*uint64(time.Second)) + } else { + f.posExpTCi = 0 + } + if neg > 0 { + f.negExpTCi = fixedFactor / float64(neg*uint64(time.Second)) + } else { + f.negExpTCi = 0 + } + f.expLock.Unlock() +} + +// getExpirationTCs returns the current positive and negative token expiration +// time constants +func (f *clientPool) getExpirationTCs() (pos, neg uint64) { + f.expLock.Lock() + defer f.expLock.Unlock() + + return f.posExpTC, f.negExpTC +} + +// posExpiration implements expirationController. Expiration happens only when +// free service is available. +func (f *clientPool) posExpiration(now mclock.AbsTime) fixed64 { + f.expLock.RLock() + defer f.expLock.RUnlock() + + if f.posExpTC == 0 { + return f.posExp + } + dt := now - f.freeRatioLastUpdate + if dt < 0 { + dt = 0 + } + return f.posExp + fixed64(float64(dt)*f.posExpTCi*f.freeRatio) +} + +// negExpiration implements expirationController. Expiration happens only when +// free service is available. +func (f *clientPool) negExpiration(now mclock.AbsTime) fixed64 { + f.expLock.RLock() + defer f.expLock.RUnlock() + + if f.negExpTC == 0 { + return f.negExp + } + dt := now - f.freeRatioLastUpdate + if dt < 0 { + dt = 0 + } + return f.negExp + fixed64(float64(dt)*f.negExpTCi*f.freeRatio) +} + +// totalTokenLimit returns the current token supply limit. Token prices are based +// on the ratio of total token amount and supply limit while the limit depends on +// averageFreeRatio, ensuring the availability of free service most of the time. +func (f *clientPool) totalTokenLimit() uint64 { + f.lock.Lock() + defer f.lock.Unlock() + + f.updateFreeRatio() + d := f.averageFreeRatio + if d > 0.5 { + d = -math.Log(0.5/d) * float64(freeRatioTC) + } else { + d = 0 + } + return uint64(d * float64(f.capLimit) * f.defaultPosFactors.capacityFactor) +} + +// totalTokenAmount returns the total amount of currently existing service tokens +func (f *clientPool) totalTokenAmount() uint64 { + f.lock.Lock() + defer f.lock.Unlock() + + now := f.clock.Now() + if now > f.lastConnectedBalanceUpdate+mclock.AbsTime(time.Second) { + f.activeBalances = expiredValue{} + for _, c := range f.connectedMap { + pos, _ := c.balanceTracker.getBalance(now) + f.activeBalances.addExp(pos) + } + f.lastConnectedBalanceUpdate = now + } + sum := f.activeBalances + sum.addExp(f.inactiveBalances) + return sum.value(f.posExpiration(now)) +} + // connect should be called after a successful handshake. If the connection was // rejected, there is no need to call disconnect. -func (f *clientPool) connect(peer clientPoolPeer, capacity uint64) bool { +func (f *clientPool) connect(peer clientPoolPeer, reqCapacity uint64) (uint64, error) { f.lock.Lock() defer f.lock.Unlock() // Short circuit if clientPool is already closed. if f.closed { - return false + return 0, fmt.Errorf("Client pool is already closed") } // Dedup connected peers. id, freeID := peer.ID(), peer.freeClientId() if _, ok := f.connectedMap[id]; ok { clientRejectedMeter.Mark(1) log.Debug("Client already connected", "address", freeID, "id", peerIdToString(id)) - return false - } - // Create a clientInfo but do not add it yet - var ( - posBalance uint64 - negBalance uint64 - now = f.clock.Now() - ) - pb := f.ndb.getOrNewPB(id) - posBalance = pb.value - - nb := f.ndb.getOrNewNB(freeID) - if nb.logValue != 0 { - negBalance = uint64(math.Exp(float64(nb.logValue-f.logOffset(now))/fixedPointMultiplier) * float64(time.Second)) + return 0, fmt.Errorf("Client already connected address=%s id=%s", freeID, peerIdToString(id)) } + pb := f.ndb.getOrNewBalance(id.Bytes(), false) + nb := f.ndb.getOrNewBalance([]byte(freeID), true) e := &clientInfo{ - pool: f, - peer: peer, - address: freeID, - queueIndex: -1, - id: id, - connectedAt: now, - priority: posBalance != 0, - posFactors: f.defaultPosFactors, - negFactors: f.defaultNegFactors, - balanceMetaInfo: pb.meta, - } - // If the client is a free client, assign with a low free capacity, - // Otherwise assign with the given value(priority client) - if !e.priority || capacity == 0 { - capacity = f.freeClientCap - } + id: id, + address: freeID, + capacity: reqCapacity, + pool: f, + peer: peer, + queueIndex: -1, + connectedAt: f.clock.Now(), + priority: pb.value.base != 0, + posFactors: f.defaultPosFactors, + negFactors: f.defaultNegFactors, + } + missing, capacity := f.capAvailable(id, freeID, reqCapacity, 0, true) + f.connectedMap[id] = e + if missing != 0 { + // capacity is not available, add client to inactive queue + f.initBalanceTracker(&e.balanceTracker, pb, nb, capacity, false) + f.inactiveQueue.Push(e, -connPriority(e, f.clock.Now())) + return 0, nil + } + // capacity is available, add client + e.active = true e.capacity = capacity - - // Starts a balance tracker - e.balanceTracker.init(f.clock, capacity) - e.balanceTracker.setBalance(posBalance, negBalance) - e.updatePriceFactors() - - // If the number of clients already connected in the clientpool exceeds its - // capacity, evict some clients with lowest priority. - // - // If the priority of the newly added client is lower than the priority of - // all connected clients, the client is rejected. - newCapacity := f.connectedCap + capacity - newCount := f.connectedQueue.Size() + 1 - if newCapacity > f.capLimit || newCount > f.connLimit { - var ( - kickList []*clientInfo - kickPriority int64 - ) - f.connectedQueue.MultiPop(func(data interface{}, priority int64) bool { - c := data.(*clientInfo) - kickList = append(kickList, c) - kickPriority = priority - newCapacity -= c.capacity - newCount-- - return newCapacity > f.capLimit || newCount > f.connLimit - }) - bias := connectedBias - if f.disableBias { - bias = 0 - } - if newCapacity > f.capLimit || newCount > f.connLimit || (e.balanceTracker.estimatedPriority(now+mclock.AbsTime(bias), false)-kickPriority) > 0 { - for _, c := range kickList { - f.connectedQueue.Push(c) - } - clientRejectedMeter.Mark(1) - log.Debug("Client rejected", "address", freeID, "id", peerIdToString(id)) - return false - } - // accept new client, drop old ones - for _, c := range kickList { - f.dropClient(c, now, true) - } - } - + f.initBalanceTracker(&e.balanceTracker, pb, nb, capacity, true) // Register new client to connection queue. - f.connectedMap[id] = e - f.connectedQueue.Push(e) - f.connectedCap += e.capacity + f.inactiveBalances.subExp(pb.value) + f.activeBalances.addExp(pb.value) + f.activeQueue.Push(e) + f.activeCap += e.capacity // If the current client is a paid client, monitor the status of client, // downgrade it to normal client if positive balance is used up. if e.priority { - f.priorityConnected += capacity + f.priorityActive += capacity + f.updateFreeRatio() e.balanceTracker.addCallback(balanceCallbackZero, 0, func() { f.balanceExhausted(id) }) } - // If the capacity of client is not the default value(free capacity), notify - // it to update capacity. - if e.capacity != f.freeClientCap { - e.peer.updateCapacity(e.capacity) - } - totalConnectedGauge.Update(int64(f.connectedCap)) + totalConnectedGauge.Update(int64(f.activeCap)) clientConnectedMeter.Mark(1) log.Debug("Client accepted", "address", freeID) - return true + return e.capacity, nil +} + +// initBalanceTracker initializes the positive and negative balances and price factors +func (f *clientPool) initBalanceTracker(bt *balanceTracker, pb tokenBalance, nb tokenBalance, capacity uint64, active bool) { + bt.exp = f + bt.init(f.clock, capacity) + bt.setBalance(pb.value, nb.value) + if active { + updatePriceFactors(bt, f.defaultPosFactors, f.defaultNegFactors) + } else { + zeroPriceFactors(bt) + } } // disconnect should be called when a connection is terminated. If the disconnection @@ -326,17 +480,106 @@ func (f *clientPool) disconnect(p clientPoolPeer) { f.lock.Lock() defer f.lock.Unlock() + f.drop(p, false) +} + +// drop deactivates the peer if necessary and drops it from the inactive queue +func (f *clientPool) drop(p clientPoolPeer, kicked bool) { // Short circuit if client pool is already closed. if f.closed { return } - // Short circuit if the peer hasn't been registered. - e := f.connectedMap[p.ID()] - if e == nil { + e, ok := f.connectedMap[p.ID()] + if !ok { log.Debug("Client not connected", "address", p.freeClientId(), "id", peerIdToString(p.ID())) return } - f.dropClient(e, f.clock.Now(), false) + tryActivate := e.active + if e.active { + f.deactivateClient(e, false) + } + f.finalizeBalance(e, f.clock.Now()) + f.inactiveQueue.Remove(e.queueIndex) + delete(f.connectedMap, e.id) + if kicked { + clientKickedMeter.Mark(1) + log.Debug("Client kicked out", "address", e.address) + } else { + clientDisconnectedMeter.Mark(1) + log.Debug("Client disconnected", "address", e.address) + } + if tryActivate { + f.tryActivateClients() + } +} + +// capAvailable checks whether the current priority level of the given client is enough to +// connect or change capacity to the requested level and then stay connected for at least +// the specified duration. If not then the additional required amount of positive balance is returned. +func (f *clientPool) capAvailable(id enode.ID, freeID string, capacity uint64, minConnTime time.Duration, kick bool) (uint64, uint64) { + var missing uint64 + if capacity == 0 { + capacity = f.freeClientCap + } + if capacity < f.minCap { + capacity = f.minCap + } + newCapacity := f.activeCap + capacity + newCount := f.activeQueue.Size() + 1 + client := f.connectedMap[id] + if client != nil && client.active { + newCapacity -= client.capacity + newCount-- + } + if newCapacity > f.capLimit || newCount > f.activeLimit { + var ( + popList []*clientInfo + targetPriority int64 + ) + f.activeQueue.MultiPop(func(data interface{}, priority int64) bool { + c := data.(*clientInfo) + popList = append(popList, c) + if c != client { + targetPriority = priority + newCapacity -= c.capacity + newCount-- + } + return newCapacity > f.capLimit || newCount > f.activeLimit + }) + if newCapacity > f.capLimit || newCount > f.activeLimit { + missing = math.MaxUint64 + } else { + var bt *balanceTracker + if client != nil { + bt = &client.balanceTracker + } else { + bt = &balanceTracker{} + f.initBalanceTracker(bt, f.ndb.getOrNewBalance(id.Bytes(), false), f.ndb.getOrNewBalance([]byte(freeID), true), capacity, true) + } + if capacity != f.freeClientCap && targetPriority >= 0 { + targetPriority = -1 + } + bias := activeBias + if f.disableBias { + bias = 0 + } + if bias < minConnTime { + bias = minConnTime + } + missing = bt.posBalanceMissing(targetPriority, capacity, bias) + } + if missing != 0 { + kick = false + } + for _, c := range popList { + if kick && c != client { + f.deactivateClient(c, true) + } else { + f.activeQueue.Push(c) + } + } + } + return missing, capacity } // forClients iterates through a list of clients, calling the callback for each one. @@ -371,27 +614,61 @@ func (f *clientPool) setDefaultFactors(posFactors, negFactors priceFactors) { f.defaultNegFactors = negFactors } -// dropClient removes a client from the connected queue and finalizes its balance. -// If kick is true then it also initiates the disconnection. -func (f *clientPool) dropClient(e *clientInfo, now mclock.AbsTime, kick bool) { - if _, ok := f.connectedMap[e.id]; !ok { +// deactivateClient puts a client in inactive state +func (f *clientPool) deactivateClient(e *clientInfo, scheduleDrop bool) { + if _, ok := f.connectedMap[e.id]; !ok || !e.active { return } - f.finalizeBalance(e, now) - f.connectedQueue.Remove(e.queueIndex) - delete(f.connectedMap, e.id) - f.connectedCap -= e.capacity + f.activeQueue.Remove(e.queueIndex) + f.activeCap -= e.capacity if e.priority { - f.priorityConnected -= e.capacity - } - totalConnectedGauge.Update(int64(f.connectedCap)) - if kick { - clientKickedMeter.Mark(1) - log.Debug("Client kicked out", "address", e.address) - f.removePeer(e.id) - } else { - clientDisconnectedMeter.Mark(1) - log.Debug("Client disconnected", "address", e.address) + f.priorityActive -= e.capacity + f.updateFreeRatio() + } + e.active = false + e.peer.updateCapacity(0) + totalConnectedGauge.Update(int64(f.activeCap)) + f.inactiveQueue.Push(e, -connPriority(e, f.clock.Now())) + if scheduleDrop { + f.dropInactivePeers[f.dropInactiveCounter+dropInactiveCycles] = append(f.dropInactivePeers[f.dropInactiveCounter+dropInactiveCycles], e) + } +} + +// tryActivateClients checks whether some inactive clients have enough priority now +// and activates them if possible +func (f *clientPool) tryActivateClients() { + now := f.clock.Now() + for f.inactiveQueue.Size() != 0 { + e := f.inactiveQueue.PopItem().(*clientInfo) + missing, capacity := f.capAvailable(e.id, e.address, e.capacity, 0, true) + if missing != 0 { + f.inactiveQueue.Push(e, -connPriority(e, now)) + return + } + // capacity is available, activate client + e.active = true + e.capacity = capacity + e.peer.updateCapacity(capacity) + balance, _ := e.balanceTracker.getBalance(now) + e.balanceTracker.setCapacity(capacity) + updatePriceFactors(&e.balanceTracker, f.defaultPosFactors, f.defaultNegFactors) + // Register activated client to connection queue. + f.inactiveBalances.subExp(balance) + f.activeBalances.addExp(balance) + f.activeQueue.Push(e) + f.activeCap += e.capacity + + // If the current client is a paid client, monitor the status of client, + // downgrade it to normal client if positive balance is used up. + if e.priority { + f.priorityActive += capacity + f.updateFreeRatio() + e.balanceTracker.addCallback(balanceCallbackZero, 0, func() { f.balanceExhausted(e.id) }) + } + e.peer.updateCapacity(e.capacity) + totalConnectedGauge.Update(int64(f.activeCap)) + clientConnectedMeter.Mark(1) + log.Debug("Client activated", "address", e.address) } } @@ -401,7 +678,7 @@ func (f *clientPool) capacityInfo() (uint64, uint64, uint64) { f.lock.Lock() defer f.lock.Unlock() - return f.capLimit, f.connectedCap, f.priorityConnected + return f.capLimit, f.activeCap, f.priorityActive } // finalizeBalance stops the balance tracker, retrieves the final balances and @@ -409,17 +686,27 @@ func (f *clientPool) capacityInfo() (uint64, uint64, uint64) { func (f *clientPool) finalizeBalance(c *clientInfo, now mclock.AbsTime) { c.balanceTracker.stop(now) pos, neg := c.balanceTracker.getBalance(now) + f.inactiveBalances.addExp(pos) + f.activeBalances.subExp(pos) - pb, nb := f.ndb.getOrNewPB(c.id), f.ndb.getOrNewNB(c.address) - pb.value = pos - f.ndb.setPB(c.id, pb) - - neg /= uint64(time.Second) // Convert the expanse to second level. - if neg > 1 { - nb.logValue = int64(math.Log(float64(neg))*fixedPointMultiplier) + f.logOffset(now) - f.ndb.setNB(c.address, nb) - } else { - f.ndb.delNB(c.address) // Negative balance is small enough, drop it directly. + for index, value := range []expiredValue{pos, neg} { + var ( + id []byte + expiration fixed64 + ) + neg := index == 1 + if !neg { + id = c.id.Bytes() + expiration = f.posExpiration(f.clock.Now()) + } else { + id = []byte(c.address) + expiration = f.negExpiration(f.clock.Now()) + } + if value.value(expiration) > uint64(time.Second) { + f.ndb.setBalance(id, neg, tokenBalance{value: value}) + } else { + f.ndb.delBalance(id, neg) // balance is small enough, drop it directly. + } } } @@ -434,429 +721,165 @@ func (f *clientPool) balanceExhausted(id enode.ID) { return } if c.priority { - f.priorityConnected -= c.capacity + f.priorityActive -= c.capacity + f.updateFreeRatio() } c.priority = false if c.capacity != f.freeClientCap { - f.connectedCap += f.freeClientCap - c.capacity - totalConnectedGauge.Update(int64(f.connectedCap)) + f.activeCap += f.freeClientCap - c.capacity + totalConnectedGauge.Update(int64(f.activeCap)) c.capacity = f.freeClientCap c.balanceTracker.setCapacity(c.capacity) c.peer.updateCapacity(c.capacity) } - pb := f.ndb.getOrNewPB(id) - pb.value = 0 - f.ndb.setPB(id, pb) + f.ndb.delBalance(id.Bytes(), false) } -// setConnLimit sets the maximum number and total capacity of connected clients, +// setactiveLimit sets the maximum number and total capacity of connected clients, // dropping some of them if necessary. func (f *clientPool) setLimits(totalConn int, totalCap uint64) { f.lock.Lock() defer f.lock.Unlock() - f.connLimit = totalConn + f.activeLimit = totalConn f.capLimit = totalCap - if f.connectedCap > f.capLimit || f.connectedQueue.Size() > f.connLimit { - f.connectedQueue.MultiPop(func(data interface{}, priority int64) bool { - f.dropClient(data.(*clientInfo), mclock.Now(), true) - return f.connectedCap > f.capLimit || f.connectedQueue.Size() > f.connLimit + if f.activeCap > f.capLimit || f.activeQueue.Size() > f.activeLimit { + f.activeQueue.MultiPop(func(data interface{}, priority int64) bool { + f.deactivateClient(data.(*clientInfo), true) + return f.activeCap > f.capLimit || f.activeQueue.Size() > f.activeLimit }) + } else { + f.tryActivateClients() } + f.updateFreeRatio() } // setCapacity sets the assigned capacity of a connected client -func (f *clientPool) setCapacity(c *clientInfo, capacity uint64) error { - if f.connectedMap[c.id] != c { - return fmt.Errorf("client %064x is not connected", c.id[:]) - } - if c.capacity == capacity { - return nil - } - if !c.priority { - return errNoPriority - } - oldCapacity := c.capacity - c.capacity = capacity - f.connectedCap += capacity - oldCapacity - c.balanceTracker.setCapacity(capacity) - f.connectedQueue.Update(c.queueIndex) - if f.connectedCap > f.capLimit { - var kickList []*clientInfo - kick := true - f.connectedQueue.MultiPop(func(data interface{}, priority int64) bool { - client := data.(*clientInfo) - kickList = append(kickList, client) - f.connectedCap -= client.capacity - if client == c { - kick = false - } - return kick && (f.connectedCap > f.capLimit) - }) - if kick { - now := mclock.Now() - for _, c := range kickList { - f.dropClient(c, now, true) - } - } else { - c.capacity = oldCapacity - c.balanceTracker.setCapacity(oldCapacity) - for _, c := range kickList { - f.connectedCap += c.capacity - f.connectedQueue.Push(c) - } - return errNoPriority +func (f *clientPool) setCapacity(id enode.ID, freeID string, capacity uint64, minConnTime time.Duration, setCap bool) (uint64, uint64, error) { + c := f.connectedMap[id] + if c != nil { + if c.capacity == capacity { + return 0, capacity, nil } } - totalConnectedGauge.Update(int64(f.connectedCap)) - f.priorityConnected += capacity - oldCapacity - c.updatePriceFactors() - c.peer.updateCapacity(c.capacity) - return nil + var missing uint64 + missing, capacity = f.capAvailable(id, freeID, capacity, 0, setCap && c != nil) + if missing != 0 { + return missing, capacity, errNoPriority + } + // capacity update is possible + if setCap { + if c == nil { + return 0, capacity, fmt.Errorf("client %064x is not connected", c.id[:]) + } + f.activeCap += capacity - c.capacity + f.priorityActive += capacity - c.capacity + f.updateFreeRatio() + c.capacity = capacity + c.balanceTracker.setCapacity(capacity) + f.activeQueue.Update(c.queueIndex) + totalConnectedGauge.Update(int64(f.activeCap)) + updatePriceFactors(&c.balanceTracker, c.posFactors, c.negFactors) + c.peer.updateCapacity(c.capacity) + f.tryActivateClients() + } + return 0, capacity, nil } -// requestCost feeds request cost after serving a request from the given peer. -func (f *clientPool) requestCost(p *clientPeer, cost uint64) { +// setCapacityLocked is the equivalent of setCapacity used when f.lock is already locked +func (f *clientPool) setCapacityLocked(id enode.ID, freeID string, capacity uint64, minConnTime time.Duration, setCap bool) (uint64, uint64, error) { f.lock.Lock() defer f.lock.Unlock() - info, exist := f.connectedMap[p.ID()] - if !exist || f.closed { - return + return f.setCapacity(id, freeID, capacity, minConnTime, setCap) +} + +// requestCost feeds request cost after serving a request from the given peer and +// returns the remaining token balance +func (f *clientPool) requestCost(p *clientPeer, cost uint64) uint64 { + f.lock.Lock() + defer f.lock.Unlock() + + c := f.connectedMap[p.ID()] + if c == nil || f.closed { + return 0 } - info.balanceTracker.requestCost(cost) + return c.balanceTracker.requestCost(cost) } -// logOffset calculates the time-dependent offset for the logarithmic -// representation of negative balance -// -// From another point of view, the result returned by the function represents -// the total time that the clientpool is cumulatively running(total_hours/multiplier). -func (f *clientPool) logOffset(now mclock.AbsTime) int64 { - // Note: fixedPointMultiplier acts as a multiplier here; the reason for dividing the divisor - // is to avoid int64 overflow. We assume that int64(negBalanceExpTC) >> fixedPointMultiplier. - cumulativeTime := int64((time.Duration(now - f.startTime)) / (negBalanceExpTC / fixedPointMultiplier)) - return f.cumulativeTime + cumulativeTime +// updatePriceFactors sets the pricing factors for an individual connected client +func updatePriceFactors(bt *balanceTracker, posFactors, negFactors priceFactors) { + bt.setFactors(posFactors, negFactors) } -// setClientPriceFactors sets the pricing factors for an individual connected client -func (c *clientInfo) updatePriceFactors() { - c.balanceTracker.setFactors(true, c.negFactors.timeFactor+float64(c.capacity)*c.negFactors.capacityFactor/1000000, c.negFactors.requestFactor) - c.balanceTracker.setFactors(false, c.posFactors.timeFactor+float64(c.capacity)*c.posFactors.capacityFactor/1000000, c.posFactors.requestFactor) +// zeroPriceFactors sets the pricing factors to zero +func zeroPriceFactors(bt *balanceTracker) { + bt.setFactors(priceFactors{0, 0, 0}, priceFactors{0, 0, 0}) } // getPosBalance retrieves a single positive balance entry from cache or the database -func (f *clientPool) getPosBalance(id enode.ID) posBalance { +func (f *clientPool) getPosBalance(id enode.ID) tokenBalance { f.lock.Lock() defer f.lock.Unlock() - return f.ndb.getOrNewPB(id) + if c := f.connectedMap[id]; c != nil { + value, _ := c.balanceTracker.getBalance(f.clock.Now()) + return tokenBalance{value: value} + } else { + return f.ndb.getOrNewBalance(id.Bytes(), false) + } } // addBalance updates the balance of a client (either overwrites it or adds to it). // It also updates the balance meta info string. -func (f *clientPool) addBalance(id enode.ID, amount int64, meta string) (uint64, uint64, error) { +func (f *clientPool) addBalance(id enode.ID, amount int64) (uint64, uint64, error) { f.lock.Lock() defer f.lock.Unlock() - pb := f.ndb.getOrNewPB(id) - var negBalance uint64 + now := f.clock.Now() + pb := f.ndb.getOrNewBalance(id.Bytes(), false) + var negBalance expiredValue c := f.connectedMap[id] if c != nil { - pb.value, negBalance = c.balanceTracker.getBalance(f.clock.Now()) + pb.value, negBalance = c.balanceTracker.getBalance(now) } oldBalance := pb.value - if amount > 0 { - if amount > maxBalance || pb.value > maxBalance-uint64(amount) { - return oldBalance, oldBalance, errBalanceOverflow - } - pb.value += uint64(amount) - } else { - if uint64(-amount) > pb.value { - pb.value = 0 - } else { - pb.value -= uint64(-amount) - } + posExp := f.posExpiration(now) + oldValue := oldBalance.value(posExp) + if amount > 0 && (amount > maxBalance || oldValue > maxBalance-uint64(amount)) { + return oldValue, oldValue, errBalanceOverflow } - pb.meta = meta - f.ndb.setPB(id, pb) + pb.value.add(amount, posExp) + f.ndb.setBalance(id.Bytes(), false, pb) if c != nil { c.balanceTracker.setBalance(pb.value, negBalance) - if !c.priority && pb.value > 0 { - // The capacity should be adjusted based on the requirement, - // but we have no idea about the new capacity, need a second - // call to udpate it. - c.priority = true - f.priorityConnected += c.capacity - c.balanceTracker.addCallback(balanceCallbackZero, 0, func() { f.balanceExhausted(id) }) - } - // if balance is set to zero then reverting to non-priority status - // is handled by the balanceExhausted callback - c.balanceMetaInfo = meta - } - return oldBalance, pb.value, nil -} - -// posBalance represents a recently accessed positive balance entry -type posBalance struct { - value uint64 - meta string -} - -// EncodeRLP implements rlp.Encoder -func (e *posBalance) EncodeRLP(w io.Writer) error { - return rlp.Encode(w, []interface{}{e.value, e.meta}) -} - -// DecodeRLP implements rlp.Decoder -func (e *posBalance) DecodeRLP(s *rlp.Stream) error { - var entry struct { - Value uint64 - Meta string - } - if err := s.Decode(&entry); err != nil { - return err - } - e.value = entry.Value - e.meta = entry.Meta - return nil -} - -// negBalance represents a negative balance entry of a disconnected client -type negBalance struct{ logValue int64 } - -// EncodeRLP implements rlp.Encoder -func (e *negBalance) EncodeRLP(w io.Writer) error { - return rlp.Encode(w, []interface{}{uint64(e.logValue)}) -} - -// DecodeRLP implements rlp.Decoder -func (e *negBalance) DecodeRLP(s *rlp.Stream) error { - var entry struct { - LogValue uint64 - } - if err := s.Decode(&entry); err != nil { - return err - } - e.logValue = int64(entry.LogValue) - return nil -} - -const ( - // nodeDBVersion is the version identifier of the node data in db - // - // Changelog: - // * Replace `lastTotal` with `meta` in positive balance: version 0=>1 - nodeDBVersion = 1 - - // dbCleanupCycle is the cycle of db for useless data cleanup - dbCleanupCycle = time.Hour -) - -var ( - positiveBalancePrefix = []byte("pb:") // dbVersion(uint16 big endian) + positiveBalancePrefix + id -> balance - negativeBalancePrefix = []byte("nb:") // dbVersion(uint16 big endian) + negativeBalancePrefix + ip -> balance - cumulativeRunningTimeKey = []byte("cumulativeTime:") // dbVersion(uint16 big endian) + cumulativeRunningTimeKey -> cumulativeTime -) - -type nodeDB struct { - db ethdb.Database - pcache *lru.Cache - ncache *lru.Cache - auxbuf []byte // 37-byte auxiliary buffer for key encoding - verbuf [2]byte // 2-byte auxiliary buffer for db version - nbEvictCallBack func(mclock.AbsTime, negBalance) bool // Callback to determine whether the negative balance can be evicted. - clock mclock.Clock - closeCh chan struct{} - cleanupHook func() // Test hook used for testing -} - -func newNodeDB(db ethdb.Database, clock mclock.Clock) *nodeDB { - pcache, _ := lru.New(posBalanceCacheLimit) - ncache, _ := lru.New(negBalanceCacheLimit) - ndb := &nodeDB{ - db: db, - pcache: pcache, - ncache: ncache, - auxbuf: make([]byte, 37), - clock: clock, - closeCh: make(chan struct{}), - } - binary.BigEndian.PutUint16(ndb.verbuf[:], uint16(nodeDBVersion)) - go ndb.expirer() - return ndb -} - -func (db *nodeDB) close() { - close(db.closeCh) -} - -func (db *nodeDB) key(id []byte, neg bool) []byte { - prefix := positiveBalancePrefix - if neg { - prefix = negativeBalancePrefix - } - if len(prefix)+len(db.verbuf)+len(id) > len(db.auxbuf) { - db.auxbuf = append(db.auxbuf, make([]byte, len(prefix)+len(db.verbuf)+len(id)-len(db.auxbuf))...) - } - copy(db.auxbuf[:len(db.verbuf)], db.verbuf[:]) - copy(db.auxbuf[len(db.verbuf):len(db.verbuf)+len(prefix)], prefix) - copy(db.auxbuf[len(prefix)+len(db.verbuf):len(prefix)+len(db.verbuf)+len(id)], id) - return db.auxbuf[:len(prefix)+len(db.verbuf)+len(id)] -} - -func (db *nodeDB) getCumulativeTime() int64 { - blob, err := db.db.Get(append(cumulativeRunningTimeKey, db.verbuf[:]...)) - if err != nil || len(blob) == 0 { - return 0 - } - return int64(binary.BigEndian.Uint64(blob)) -} - -func (db *nodeDB) setCumulativeTime(v int64) { - binary.BigEndian.PutUint64(db.auxbuf[:8], uint64(v)) - db.db.Put(append(cumulativeRunningTimeKey, db.verbuf[:]...), db.auxbuf[:8]) -} - -func (db *nodeDB) getOrNewPB(id enode.ID) posBalance { - key := db.key(id.Bytes(), false) - item, exist := db.pcache.Get(string(key)) - if exist { - return item.(posBalance) - } - var balance posBalance - if enc, err := db.db.Get(key); err == nil { - if err := rlp.DecodeBytes(enc, &balance); err != nil { - log.Error("Failed to decode positive balance", "err", err) - } - } - db.pcache.Add(string(key), balance) - return balance -} - -func (db *nodeDB) setPB(id enode.ID, b posBalance) { - if b.value == 0 && len(b.meta) == 0 { - db.delPB(id) - return - } - key := db.key(id.Bytes(), false) - enc, err := rlp.EncodeToBytes(&(b)) - if err != nil { - log.Error("Failed to encode positive balance", "err", err) - return - } - db.db.Put(key, enc) - db.pcache.Add(string(key), b) -} - -func (db *nodeDB) delPB(id enode.ID) { - key := db.key(id.Bytes(), false) - db.db.Delete(key) - db.pcache.Remove(string(key)) -} - -// getPosBalanceIDs returns a lexicographically ordered list of IDs of accounts -// with a positive balance -func (db *nodeDB) getPosBalanceIDs(start, stop enode.ID, maxCount int) (result []enode.ID) { - if maxCount <= 0 { - return - } - it := db.db.NewIteratorWithStart(db.key(start.Bytes(), false)) - defer it.Release() - for i := len(stop[:]) - 1; i >= 0; i-- { - stop[i]-- - if stop[i] != 255 { - break - } - } - stopKey := db.key(stop.Bytes(), false) - keyLen := len(stopKey) - - for it.Next() { - var id enode.ID - if len(it.Key()) != keyLen || bytes.Compare(it.Key(), stopKey) == 1 { - return - } - copy(id[:], it.Key()[keyLen-len(id):]) - result = append(result, id) - if len(result) == maxCount { - return - } - } - return -} - -func (db *nodeDB) getOrNewNB(id string) negBalance { - key := db.key([]byte(id), true) - item, exist := db.ncache.Get(string(key)) - if exist { - return item.(negBalance) - } - var balance negBalance - if enc, err := db.db.Get(key); err == nil { - if err := rlp.DecodeBytes(enc, &balance); err != nil { - log.Error("Failed to decode negative balance", "err", err) - } - } - db.ncache.Add(string(key), balance) - return balance -} - -func (db *nodeDB) setNB(id string, b negBalance) { - key := db.key([]byte(id), true) - enc, err := rlp.EncodeToBytes(&(b)) - if err != nil { - log.Error("Failed to encode negative balance", "err", err) - return - } - db.db.Put(key, enc) - db.ncache.Add(string(key), b) -} - -func (db *nodeDB) delNB(id string) { - key := db.key([]byte(id), true) - db.db.Delete(key) - db.ncache.Remove(string(key)) -} - -func (db *nodeDB) expirer() { - for { - select { - case <-db.clock.After(dbCleanupCycle): - db.expireNodes() - case <-db.closeCh: - return - } - } -} - -// expireNodes iterates the whole node db and checks whether the negative balance -// entry can deleted. -// -// The rationale behind this is: server doesn't need to keep the negative balance -// records if they are low enough. -func (db *nodeDB) expireNodes() { - var ( - visited int - deleted int - start = time.Now() - ) - iter := db.db.NewIteratorWithPrefix(append(db.verbuf[:], negativeBalancePrefix...)) - for iter.Next() { - visited += 1 - var balance negBalance - if err := rlp.DecodeBytes(iter.Value(), &balance); err != nil { - log.Error("Failed to decode negative balance", "err", err) - continue + if c.active { + f.activeQueue.Update(c.queueIndex) + if !c.priority && pb.value.base > 0 { + // The capacity should be adjusted based on the requirement, + // but we have no idea about the new capacity, need a second + // call to udpate it. + f.priorityActive += c.capacity + f.updateFreeRatio() + c.balanceTracker.addCallback(balanceCallbackZero, 0, func() { f.balanceExhausted(id) }) + } + f.activeBalances.subExp(oldBalance) + f.activeBalances.addExp(pb.value) + } else { + f.inactiveQueue.Remove(c.queueIndex) + f.inactiveQueue.Push(c, -connPriority(c, f.clock.Now())) + f.inactiveBalances.subExp(oldBalance) + f.inactiveBalances.addExp(pb.value) } - if db.nbEvictCallBack != nil && db.nbEvictCallBack(db.clock.Now(), balance) { - deleted += 1 - db.db.Delete(iter.Key()) + if pb.value.base > 0 { + c.priority = true + // if balance is set to zero then reverting to non-priority status + // is handled by the balanceExhausted callback } + } else { + f.inactiveBalances.subExp(oldBalance) + f.inactiveBalances.addExp(pb.value) } - // Invoke testing hook if it's not nil. - if db.cleanupHook != nil { - db.cleanupHook() - } - log.Debug("Expire nodes", "visited", visited, "deleted", deleted, "elapsed", common.PrettyDuration(time.Since(start))) + f.tryActivateClients() + return oldValue, pb.value.value(posExp), nil } diff --git a/les/clientpool_test.go b/les/clientpool_test.go index 6308113fe74c..a34deff1b88d 100644 --- a/les/clientpool_test.go +++ b/les/clientpool_test.go @@ -17,11 +17,8 @@ package les import ( - "bytes" "fmt" - "math" "math/rand" - "reflect" "testing" "time" @@ -56,29 +53,34 @@ func TestClientPoolL100C300P20(t *testing.T) { const testClientPoolTicks = 100000 -type poolTestPeer int - -func (i poolTestPeer) ID() enode.ID { - return enode.ID{byte(i % 256), byte(i >> 8)} +type poolTestPeer struct { + index int + disconnCh chan int + cap uint64 } -func (i poolTestPeer) freeClientId() string { - return fmt.Sprintf("addr #%d", i) +func newPoolTestPeer(i int, disconnCh chan int) *poolTestPeer { + return &poolTestPeer{index: i, disconnCh: disconnCh} } -func (i poolTestPeer) updateCapacity(uint64) {} - -type poolTestPeerWithCap struct { - poolTestPeer +func (i *poolTestPeer) ID() enode.ID { + return enode.ID{byte(i.index % 256), byte(i.index >> 8)} +} - cap uint64 +func (i *poolTestPeer) freeClientId() string { + return fmt.Sprintf("addr #%d", i) } -func (i *poolTestPeerWithCap) updateCapacity(cap uint64) { i.cap = cap } +func (i *poolTestPeer) updateCapacity(cap uint64) { + i.cap = cap + if cap == 0 && i.disconnCh != nil { + i.disconnCh <- i.index + } +} -func (i poolTestPeer) freezeClient() {} +func (i *poolTestPeer) freeze() {} -func testClientPool(t *testing.T, connLimit, clientCount, paidCount int, randomDisconnect bool) { +func testClientPool(t *testing.T, activeLimit, clientCount, paidCount int, randomDisconnect bool) { rand.Seed(time.Now().UnixNano()) var ( clock mclock.Simulated @@ -89,15 +91,16 @@ func testClientPool(t *testing.T, connLimit, clientCount, paidCount int, randomD disconnFn = func(id enode.ID) { disconnCh <- int(id[0]) + int(id[1])<<8 } - pool = newClientPool(db, 1, &clock, disconnFn) + pool = newClientPool(db, 1, 1, &clock, disconnFn) ) + pool.disableBias = true - pool.setLimits(connLimit, uint64(connLimit)) + pool.setLimits(activeLimit, uint64(activeLimit)) pool.setDefaultFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) // pool should accept new peers up to its connected limit - for i := 0; i < connLimit; i++ { - if pool.connect(poolTestPeer(i), 0) { + for i := 0; i < activeLimit; i++ { + if cap, _ := pool.connect(newPoolTestPeer(i, disconnCh), 0); cap != 0 { connected[i] = true } else { t.Fatalf("Test peer #%d rejected", i) @@ -111,28 +114,30 @@ func testClientPool(t *testing.T, connLimit, clientCount, paidCount int, randomD // give a positive balance to some of the peers amount := testClientPoolTicks / 2 * int64(time.Second) // enough for half of the simulation period for i := 0; i < paidCount; i++ { - pool.addBalance(poolTestPeer(i).ID(), amount, "") + pool.addBalance(newPoolTestPeer(i, disconnCh).ID(), amount) } } i := rand.Intn(clientCount) if connected[i] { if randomDisconnect { - pool.disconnect(poolTestPeer(i)) + pool.disconnect(newPoolTestPeer(i, disconnCh)) connected[i] = false connTicks[i] += tickCounter } } else { - if pool.connect(poolTestPeer(i), 0) { + if cap, _ := pool.connect(newPoolTestPeer(i, disconnCh), 0); cap != 0 { connected[i] = true connTicks[i] -= tickCounter + } else { + pool.disconnect(newPoolTestPeer(i, disconnCh)) } } pollDisconnects: for { select { case i := <-disconnCh: - pool.disconnect(poolTestPeer(i)) + pool.disconnect(newPoolTestPeer(i, disconnCh)) if connected[i] { connTicks[i] += tickCounter connected[i] = false @@ -143,10 +148,10 @@ func testClientPool(t *testing.T, connLimit, clientCount, paidCount int, randomD } } - expTicks := testClientPoolTicks/2*connLimit/clientCount + testClientPoolTicks/2*(connLimit-paidCount)/(clientCount-paidCount) + expTicks := testClientPoolTicks/2*activeLimit/clientCount + testClientPoolTicks/2*(activeLimit-paidCount)/(clientCount-paidCount) expMin := expTicks - expTicks/5 expMax := expTicks + expTicks/5 - paidTicks := testClientPoolTicks/2*connLimit/clientCount + testClientPoolTicks/2 + paidTicks := testClientPoolTicks/2*activeLimit/clientCount + testClientPoolTicks/2 paidMin := paidTicks - paidTicks/5 paidMax := paidTicks + paidTicks/5 @@ -172,15 +177,15 @@ func TestConnectPaidClient(t *testing.T) { clock mclock.Simulated db = rawdb.NewMemoryDatabase() ) - pool := newClientPool(db, 1, &clock, nil) + pool := newClientPool(db, 1, 1, &clock, nil) defer pool.stop() pool.setLimits(10, uint64(10)) pool.setDefaultFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) // Add balance for an external client and mark it as paid client - pool.addBalance(poolTestPeer(0).ID(), 1000, "") + pool.addBalance(newPoolTestPeer(0, nil).ID(), 1000) - if !pool.connect(poolTestPeer(0), 10) { + if cap, _ := pool.connect(newPoolTestPeer(0, nil), 10); cap == 0 { t.Fatalf("Failed to connect paid client") } } @@ -190,16 +195,16 @@ func TestConnectPaidClientToSmallPool(t *testing.T) { clock mclock.Simulated db = rawdb.NewMemoryDatabase() ) - pool := newClientPool(db, 1, &clock, nil) + pool := newClientPool(db, 1, 1, &clock, nil) defer pool.stop() pool.setLimits(10, uint64(10)) // Total capacity limit is 10 pool.setDefaultFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) // Add balance for an external client and mark it as paid client - pool.addBalance(poolTestPeer(0).ID(), 1000, "") + pool.addBalance(newPoolTestPeer(0, nil).ID(), 1000) // Connect a fat paid client to pool, should reject it. - if pool.connect(poolTestPeer(0), 100) { + if cap, _ := pool.connect(newPoolTestPeer(0, nil), 100); cap != 0 { t.Fatalf("Connected fat paid client, should reject it") } } @@ -210,23 +215,23 @@ func TestConnectPaidClientToFullPool(t *testing.T) { db = rawdb.NewMemoryDatabase() ) removeFn := func(enode.ID) {} // Noop - pool := newClientPool(db, 1, &clock, removeFn) + pool := newClientPool(db, 1, 1, &clock, removeFn) defer pool.stop() pool.setLimits(10, uint64(10)) // Total capacity limit is 10 pool.setDefaultFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) for i := 0; i < 10; i++ { - pool.addBalance(poolTestPeer(i).ID(), 1000000000, "") - pool.connect(poolTestPeer(i), 1) + pool.addBalance(newPoolTestPeer(i, nil).ID(), int64(time.Second)) + pool.connect(newPoolTestPeer(i, nil), 1) } - pool.addBalance(poolTestPeer(11).ID(), 1000, "") // Add low balance to new paid client - if pool.connect(poolTestPeer(11), 1) { + pool.addBalance(newPoolTestPeer(11, nil).ID(), 1000) // Add low balance to new paid client + if cap, _ := pool.connect(newPoolTestPeer(11, nil), 1); cap != 0 { t.Fatalf("Low balance paid client should be rejected") } clock.Run(time.Second) - pool.addBalance(poolTestPeer(12).ID(), 1000000000*60*3, "") // Add high balance to new paid client - if !pool.connect(poolTestPeer(12), 1) { - t.Fatalf("High balance paid client should be accpected") + pool.addBalance(newPoolTestPeer(12, nil).ID(), int64(time.Minute*5)) // Add high balance to new paid client + if cap, _ := pool.connect(newPoolTestPeer(12, nil), 1); cap == 0 { + t.Fatalf("High balance paid client should be accepted") } } @@ -237,19 +242,19 @@ func TestPaidClientKickedOut(t *testing.T) { kickedCh = make(chan int, 1) ) removeFn := func(id enode.ID) { kickedCh <- int(id[0]) } - pool := newClientPool(db, 1, &clock, removeFn) + pool := newClientPool(db, 1, 1, &clock, removeFn) defer pool.stop() pool.setLimits(10, uint64(10)) // Total capacity limit is 10 pool.setDefaultFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) for i := 0; i < 10; i++ { - pool.addBalance(poolTestPeer(i).ID(), 1000000000, "") // 1 second allowance - pool.connect(poolTestPeer(i), 1) + pool.addBalance(newPoolTestPeer(i, kickedCh).ID(), 1000000000) // 1 second allowance + pool.connect(newPoolTestPeer(i, kickedCh), 1) clock.Run(time.Millisecond) } clock.Run(time.Second) - clock.Run(connectedBias) - if !pool.connect(poolTestPeer(11), 0) { + clock.Run(activeBias) + if cap, _ := pool.connect(newPoolTestPeer(11, kickedCh), 0); cap == 0 { t.Fatalf("Free client should be accectped") } select { @@ -267,11 +272,11 @@ func TestConnectFreeClient(t *testing.T) { clock mclock.Simulated db = rawdb.NewMemoryDatabase() ) - pool := newClientPool(db, 1, &clock, nil) + pool := newClientPool(db, 1, 1, &clock, nil) defer pool.stop() pool.setLimits(10, uint64(10)) pool.setDefaultFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) - if !pool.connect(poolTestPeer(0), 10) { + if cap, _ := pool.connect(newPoolTestPeer(0, nil), 10); cap == 0 { t.Fatalf("Failed to connect free client") } } @@ -282,24 +287,24 @@ func TestConnectFreeClientToFullPool(t *testing.T) { db = rawdb.NewMemoryDatabase() ) removeFn := func(enode.ID) {} // Noop - pool := newClientPool(db, 1, &clock, removeFn) + pool := newClientPool(db, 1, 1, &clock, removeFn) defer pool.stop() pool.setLimits(10, uint64(10)) // Total capacity limit is 10 pool.setDefaultFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) for i := 0; i < 10; i++ { - pool.connect(poolTestPeer(i), 1) + pool.connect(newPoolTestPeer(i, nil), 1) } - if pool.connect(poolTestPeer(11), 1) { + if cap, _ := pool.connect(newPoolTestPeer(11, nil), 1); cap != 0 { t.Fatalf("New free client should be rejected") } clock.Run(time.Minute) - if pool.connect(poolTestPeer(12), 1) { + if cap, _ := pool.connect(newPoolTestPeer(12, nil), 1); cap != 0 { t.Fatalf("New free client should be rejected") } clock.Run(time.Millisecond) clock.Run(4 * time.Minute) - if !pool.connect(poolTestPeer(13), 1) { + if cap, _ := pool.connect(newPoolTestPeer(13, nil), 1); cap == 0 { t.Fatalf("Old client connects more than 5min should be kicked") } } @@ -311,21 +316,22 @@ func TestFreeClientKickedOut(t *testing.T) { kicked = make(chan int, 10) ) removeFn := func(id enode.ID) { kicked <- int(id[0]) } - pool := newClientPool(db, 1, &clock, removeFn) + pool := newClientPool(db, 1, 1, &clock, removeFn) defer pool.stop() pool.setLimits(10, uint64(10)) // Total capacity limit is 10 pool.setDefaultFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) for i := 0; i < 10; i++ { - pool.connect(poolTestPeer(i), 1) + pool.connect(newPoolTestPeer(i, kicked), 1) clock.Run(time.Millisecond) } - if pool.connect(poolTestPeer(10), 1) { + if cap, _ := pool.connect(newPoolTestPeer(10, kicked), 1); cap != 0 { t.Fatalf("New free client should be rejected") } + pool.disconnect(newPoolTestPeer(10, kicked)) clock.Run(5 * time.Minute) for i := 0; i < 10; i++ { - pool.connect(poolTestPeer(i+10), 1) + pool.connect(newPoolTestPeer(i+10, kicked), 1) } for i := 0; i < 10; i++ { select { @@ -346,18 +352,18 @@ func TestPositiveBalanceCalculation(t *testing.T) { kicked = make(chan int, 10) ) removeFn := func(id enode.ID) { kicked <- int(id[0]) } // Noop - pool := newClientPool(db, 1, &clock, removeFn) + pool := newClientPool(db, 1, 1, &clock, removeFn) defer pool.stop() pool.setLimits(10, uint64(10)) // Total capacity limit is 10 pool.setDefaultFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) - pool.addBalance(poolTestPeer(0).ID(), int64(time.Minute*3), "") - pool.connect(poolTestPeer(0), 10) + pool.addBalance(newPoolTestPeer(0, kicked).ID(), int64(time.Minute*3)) + pool.connect(newPoolTestPeer(0, kicked), 10) clock.Run(time.Minute) - pool.disconnect(poolTestPeer(0)) - pb := pool.ndb.getOrNewPB(poolTestPeer(0).ID()) - if pb.value != uint64(time.Minute*2) { + pool.disconnect(newPoolTestPeer(0, kicked)) + pb := pool.ndb.getOrNewBalance(newPoolTestPeer(0, kicked).ID().Bytes(), false) + if pb.value != expval(uint64(time.Minute*2)) { t.Fatalf("Positive balance mismatch, want %v, got %v", uint64(time.Minute*2), pb.value) } } @@ -369,16 +375,14 @@ func TestDowngradePriorityClient(t *testing.T) { kicked = make(chan int, 10) ) removeFn := func(id enode.ID) { kicked <- int(id[0]) } // Noop - pool := newClientPool(db, 1, &clock, removeFn) + pool := newClientPool(db, 1, 1, &clock, removeFn) defer pool.stop() pool.setLimits(10, uint64(10)) // Total capacity limit is 10 pool.setDefaultFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) - p := &poolTestPeerWithCap{ - poolTestPeer: poolTestPeer(0), - } - pool.addBalance(p.ID(), int64(time.Minute), "") - pool.connect(p, 10) + p := newPoolTestPeer(0, kicked) + pool.addBalance(p.ID(), int64(time.Minute)) + p.cap, _ = pool.connect(p, 10) if p.cap != 10 { t.Fatalf("The capcacity of priority peer hasn't been updated, got: %d", p.cap) } @@ -388,156 +392,130 @@ func TestDowngradePriorityClient(t *testing.T) { if p.cap != 1 { t.Fatalf("The capcacity of peer should be downgraded, got: %d", p.cap) } - pb := pool.ndb.getOrNewPB(poolTestPeer(0).ID()) - if pb.value != 0 { + pb := pool.ndb.getOrNewBalance(newPoolTestPeer(0, kicked).ID().Bytes(), false) + if pb.value.base != 0 { t.Fatalf("Positive balance mismatch, want %v, got %v", 0, pb.value) } - pool.addBalance(poolTestPeer(0).ID(), int64(time.Minute), "") - pb = pool.ndb.getOrNewPB(poolTestPeer(0).ID()) - if pb.value != uint64(time.Minute) { + pool.addBalance(newPoolTestPeer(0, kicked).ID(), int64(time.Minute)) + pb = pool.ndb.getOrNewBalance(newPoolTestPeer(0, kicked).ID().Bytes(), false) + if pb.value != expval(uint64(time.Minute)) { t.Fatalf("Positive balance mismatch, want %v, got %v", uint64(time.Minute), pb.value) } } func TestNegativeBalanceCalculation(t *testing.T) { var ( - clock mclock.Simulated - db = rawdb.NewMemoryDatabase() - kicked = make(chan int, 10) + clock mclock.Simulated + db = rawdb.NewMemoryDatabase() ) - removeFn := func(id enode.ID) { kicked <- int(id[0]) } // Noop - pool := newClientPool(db, 1, &clock, removeFn) + pool := newClientPool(db, 1, 1, &clock, nil) defer pool.stop() pool.setLimits(10, uint64(10)) // Total capacity limit is 10 pool.setDefaultFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) for i := 0; i < 10; i++ { - pool.connect(poolTestPeer(i), 1) + pool.connect(newPoolTestPeer(i, nil), 1) } clock.Run(time.Second) for i := 0; i < 10; i++ { - pool.disconnect(poolTestPeer(i)) - nb := pool.ndb.getOrNewNB(poolTestPeer(i).freeClientId()) - if nb.logValue != 0 { + pool.disconnect(newPoolTestPeer(i, nil)) + nb := pool.ndb.getOrNewBalance([]byte(newPoolTestPeer(i, nil).freeClientId()), true) + if nb.value.base != 0 { t.Fatalf("Short connection shouldn't be recorded") } } - for i := 0; i < 10; i++ { - pool.connect(poolTestPeer(i), 1) + pool.connect(newPoolTestPeer(i, nil), 1) } clock.Run(time.Minute) for i := 0; i < 10; i++ { - pool.disconnect(poolTestPeer(i)) - nb := pool.ndb.getOrNewNB(poolTestPeer(i).freeClientId()) - nb.logValue -= pool.logOffset(clock.Now()) - nb.logValue /= fixedPointMultiplier - if nb.logValue != int64(math.Log(float64(time.Minute/time.Second))) { - t.Fatalf("Negative balance mismatch, want %v, got %v", int64(math.Log(float64(time.Minute/time.Second))), nb.logValue) - } - } -} - -func TestNodeDB(t *testing.T) { - ndb := newNodeDB(rawdb.NewMemoryDatabase(), mclock.System{}) - defer ndb.close() - - if !bytes.Equal(ndb.verbuf[:], []byte{0x00, nodeDBVersion}) { - t.Fatalf("version buffer mismatch, want %v, got %v", []byte{0x00, nodeDBVersion}, ndb.verbuf) - } - var cases = []struct { - id enode.ID - ip string - balance interface{} - positive bool - }{ - {enode.ID{0x00, 0x01, 0x02}, "", posBalance{value: 100}, true}, - {enode.ID{0x00, 0x01, 0x02}, "", posBalance{value: 200}, true}, - {enode.ID{}, "127.0.0.1", negBalance{logValue: 10}, false}, - {enode.ID{}, "127.0.0.1", negBalance{logValue: 20}, false}, - } - for _, c := range cases { - if c.positive { - ndb.setPB(c.id, c.balance.(posBalance)) - if pb := ndb.getOrNewPB(c.id); !reflect.DeepEqual(pb, c.balance.(posBalance)) { - t.Fatalf("Positive balance mismatch, want %v, got %v", c.balance.(posBalance), pb) - } - } else { - ndb.setNB(c.ip, c.balance.(negBalance)) - if nb := ndb.getOrNewNB(c.ip); !reflect.DeepEqual(nb, c.balance.(negBalance)) { - t.Fatalf("Negative balance mismatch, want %v, got %v", c.balance.(negBalance), nb) - } + pool.disconnect(newPoolTestPeer(i, nil)) + nb := pool.ndb.getOrNewBalance([]byte(newPoolTestPeer(i, nil).freeClientId()), true) + value := nb.value.value(pool.negExpiration(clock.Now())) + if value != uint64(time.Minute) { + t.Fatalf("Negative balance mismatch, want %v, got %v", time.Minute, value) } } - for _, c := range cases { - if c.positive { - ndb.delPB(c.id) - if pb := ndb.getOrNewPB(c.id); !reflect.DeepEqual(pb, posBalance{}) { - t.Fatalf("Positive balance mismatch, want %v, got %v", posBalance{}, pb) - } - } else { - ndb.delNB(c.ip) - if nb := ndb.getOrNewNB(c.ip); !reflect.DeepEqual(nb, negBalance{}) { - t.Fatalf("Negative balance mismatch, want %v, got %v", negBalance{}, nb) - } - } - } - ndb.setCumulativeTime(100) - if ndb.getCumulativeTime() != 100 { - t.Fatalf("Cumulative time mismatch, want %v, got %v", 100, ndb.getCumulativeTime()) - } } -func TestNodeDBExpiration(t *testing.T) { +func TestInactiveClient(t *testing.T) { var ( - iterated int - done = make(chan struct{}, 1) + clock mclock.Simulated + db = rawdb.NewMemoryDatabase() ) - callback := func(now mclock.AbsTime, b negBalance) bool { - iterated += 1 - return true - } - clock := &mclock.Simulated{} - ndb := newNodeDB(rawdb.NewMemoryDatabase(), clock) - defer ndb.close() - ndb.nbEvictCallBack = callback - ndb.cleanupHook = func() { done <- struct{}{} } - - var cases = []struct { - ip string - balance negBalance - }{ - {"127.0.0.1", negBalance{logValue: 1}}, - {"127.0.0.2", negBalance{logValue: 1}}, - {"127.0.0.3", negBalance{logValue: 1}}, - {"127.0.0.4", negBalance{logValue: 1}}, - } - for _, c := range cases { - ndb.setNB(c.ip, c.balance) - } - clock.WaitForTimers(1) - clock.Run(time.Hour + time.Minute) - select { - case <-done: - case <-time.NewTimer(time.Second).C: - t.Fatalf("timeout") - } - if iterated != 4 { - t.Fatalf("Failed to evict useless negative balances, want %v, got %d", 4, iterated) - } - clock.WaitForTimers(1) - for _, c := range cases { - ndb.setNB(c.ip, c.balance) - } - clock.Run(time.Hour + time.Minute) - select { - case <-done: - case <-time.NewTimer(time.Second).C: - t.Fatalf("timeout") - } - if iterated != 8 { - t.Fatalf("Failed to evict useless negative balances, want %v, got %d", 4, iterated) + pool := newClientPool(db, 1, 1, &clock, nil) + defer pool.stop() + pool.setLimits(2, uint64(2)) // Total capacity limit is 10 + + p1 := newPoolTestPeer(1, nil) + p2 := newPoolTestPeer(2, nil) + p3 := newPoolTestPeer(3, nil) + pool.addBalance(p1.ID(), 1000) + pool.addBalance(p3.ID(), 2000) + // p1: 1000 p2: 0 p3: 2000 + p1.cap, _ = pool.connect(p1, 1) + if p1.cap != 1 { + t.Fatalf("Failed to connect peer #1") + } + p2.cap, _ = pool.connect(p2, 1) + if p2.cap != 1 { + t.Fatalf("Failed to connect peer #2") + } + p3.cap, _ = pool.connect(p3, 1) + if p3.cap != 1 { + t.Fatalf("Failed to connect peer #3") + } + if p2.cap != 0 { + t.Fatalf("Failed to deactivate peer #2") + } + pool.addBalance(p2.ID(), 3000) + // p1: 1000 p2: 3000 p3: 2000 + if p2.cap != 1 { + t.Fatalf("Failed to activate peer #2") + } + if p1.cap != 0 { + t.Fatalf("Failed to deactivate peer #1") + } + pool.addBalance(p2.ID(), -2500) + // p1: 1000 p2: 500 p3: 2000 + if p1.cap != 1 { + t.Fatalf("Failed to activate peer #1") + } + if p2.cap != 0 { + t.Fatalf("Failed to deactivate peer #2") + } + pool.setDefaultFactors(priceFactors{1e-9, 0, 0}, priceFactors{1e-9, 0, 0}) + p4 := newPoolTestPeer(4, nil) + pool.addBalance(p4.ID(), 1500) + // p1: 1000 p2: 500 p3: 2000 p4: 1500 + p4.cap, _ = pool.connect(p4, 1) + if p4.cap != 1 { + t.Fatalf("Failed to activate peer #4") + } + if p1.cap != 0 { + t.Fatalf("Failed to deactivate peer #1") + } + clock.Run(time.Second * 600) + // manually trigger a check to avoid a long real-time wait + pool.lock.Lock() + pool.tryActivateClients() + pool.lock.Unlock() + // p1: 1000 p2: 500 p3: 2000 p4: 900 + if p1.cap != 1 { + t.Fatalf("Failed to activate peer #1") + } + if p4.cap != 0 { + t.Fatalf("Failed to deactivate peer #4") + } + pool.disconnect(p2) + pool.disconnect(p4) + pool.addBalance(p1.ID(), -1000) + if p1.cap != 1 { + t.Fatalf("Should not deactivate peer #1") + } + if p2.cap != 0 { + t.Fatalf("Should not activate peer #2") } } diff --git a/les/costtracker.go b/les/costtracker.go index 81da04566007..abf8618ffde9 100644 --- a/les/costtracker.go +++ b/les/costtracker.go @@ -159,15 +159,9 @@ func newCostTracker(db ethdb.Database, config *eth.Config) (*costTracker, uint64 } ct.gfLoop() costList := ct.makeCostList(ct.globalFactor() * 1.25) - for _, c := range costList { - amount := minBufferReqAmount[c.MsgCode] - cost := c.BaseCost + amount*c.ReqCost - if cost > ct.minBufLimit { - ct.minBufLimit = cost - } - } - ct.minBufLimit *= uint64(minBufferMultiplier) - return ct, (ct.minBufLimit-1)/bufLimitRatio + 1 + var minRecharge uint64 + ct.minBufLimit, minRecharge = costList.decode(ProtocolLengths[ServerProtocolVersions[len(ServerProtocolVersions)-1]]).reqParams() + return ct, minRecharge } // stop stops the cost tracker and saves the cost factor statistics to the database @@ -480,6 +474,22 @@ func (table requestCostTable) getMaxCost(code, amount uint64) uint64 { return costs.baseCost + amount*costs.reqCost } +func (table requestCostTable) reqParams() (minRecharge, minBufLimit uint64) { + for code, c := range table { + amount := minBufferReqAmount[code] + cost := c.baseCost + amount*c.reqCost + if cost > minBufLimit { + minBufLimit = cost + } + } + minBufLimit *= uint64(minBufferMultiplier) + if minBufLimit < 1 { + minBufLimit = 1 + } + minRecharge = (minBufLimit-1)/bufLimitRatio + 1 + return +} + // decode converts a cost list to a cost table func (list RequestCostList) decode(protocolLength uint64) requestCostTable { table := make(requestCostTable) diff --git a/les/expiredvalue.go b/les/expiredvalue.go new file mode 100644 index 000000000000..0bb251a1ac87 --- /dev/null +++ b/les/expiredvalue.go @@ -0,0 +1,129 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package les + +import "math" + +// expiredValue is a scalar value that is continuously expired (decreased +// exponentially) based on the provided logarithmic expiration offset value. +// +// The formula for value calculation is: base*2^(exp-logOffset). In order to +// simplify the calculation of expiredValue, its value is expressed in the form +// of an exponent with a base of 2. +// +// Also here is a trick to reduce a lot of calculations. In theory, when a value X +// decays over time and then a new value Y is added, the final result should be +// X*2^(exp-logOffset)+Y. However it's very hard to represent in memory. +// So the trick is using the idea of inflation instead of exponential decay. At this +// moment the temporary value becomes: X*2^exp+Y*2^logOffset_1, apply the exponential +// decay when we actually want to calculate the value. +// +// e.g. +// t0: V = 100 +// t1: add 30, inflationary value is: 100 + 30/0.3, 0.3 is the decay coefficient +// t2: get value, decay coefficient is 0.2 now, final result is: 200*0.2 = 40 +type expiredValue struct { + base, exp uint64 +} + +// value calculates the value at the given moment. +func (e expiredValue) value(logOffset fixed64) uint64 { + offset := uint64ToFixed64(e.exp) - logOffset + return uint64(float64(e.base) * offset.pow2()) +} + +// add adds a signed value at the given moment +func (e *expiredValue) add(amount int64, logOffset fixed64) int64 { + integer, frac := logOffset.toUint64(), logOffset.fraction() + factor := frac.pow2() + base := factor * float64(amount) + if integer < e.exp { + base /= math.Pow(2, float64(e.exp-integer)) + } + if integer > e.exp { + e.base >>= (integer - e.exp) + e.exp = integer + } + if base >= 0 || uint64(-base) <= e.base { + e.base += uint64(base) + return amount + } + net := int64(-float64(e.base) / factor) + e.base = 0 + return net +} + +// addExp adds another expiredValue +func (e *expiredValue) addExp(a expiredValue) { + if e.exp > a.exp { + a.base >>= (e.exp - a.exp) + } + if e.exp < a.exp { + e.base >>= (a.exp - e.exp) + e.exp = a.exp + } + e.base += a.base +} + +// subExp subtracts another expiredValue +func (e *expiredValue) subExp(a expiredValue) { + if e.exp > a.exp { + a.base >>= (e.exp - a.exp) + } + if e.exp < a.exp { + e.base >>= (a.exp - e.exp) + e.exp = a.exp + } + if e.base > a.base { + e.base -= a.base + } else { + e.base = 0 + } +} + +// fixedFactor is the fixed point multiplier factor used by fixed64. +const fixedFactor = 0x1000000 + +// fixed64 implements 64-bit fixed point arithmetic functions. +type fixed64 int64 + +// uint64ToFixed64 converts uint64 integer to fixed64 format. +func uint64ToFixed64(f uint64) fixed64 { + return fixed64(f * fixedFactor) +} + +// float64ToFixed64 converts float64 to fixed64 format. +func float64ToFixed64(f float64) fixed64 { + return fixed64(f * fixedFactor) +} + +// toUint64 converts fixed64 format to uint64. +func (f64 fixed64) toUint64() uint64 { + return uint64(f64) / fixedFactor +} + +// fraction returns the fractional part of a fixed64 value. +func (f64 fixed64) fraction() fixed64 { + return f64 % fixedFactor +} + +var fixedLogFactor = math.Log(2) / float64(fixedFactor) + +// pow2Fixed returns the base 2 power of the fixed point value. +func (f64 fixed64) pow2() float64 { + return math.Exp(float64(f64) * fixedLogFactor) +} diff --git a/les/expiredvalue_test.go b/les/expiredvalue_test.go new file mode 100644 index 000000000000..a59510a70291 --- /dev/null +++ b/les/expiredvalue_test.go @@ -0,0 +1,116 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package les + +import "testing" + +func TestValueExpiration(t *testing.T) { + var cases = []struct { + input expiredValue + timeOffset fixed64 + expect uint64 + }{ + {expiredValue{base: 128, exp: 0}, uint64ToFixed64(0), 128}, + {expiredValue{base: 128, exp: 0}, uint64ToFixed64(1), 64}, + {expiredValue{base: 128, exp: 0}, uint64ToFixed64(2), 32}, + {expiredValue{base: 128, exp: 2}, uint64ToFixed64(2), 128}, + {expiredValue{base: 128, exp: 2}, uint64ToFixed64(3), 64}, + } + for _, c := range cases { + if got := c.input.value(c.timeOffset); got != c.expect { + t.Fatalf("Value mismatch, want=%d, got=%d", c.expect, got) + } + } +} + +func TestValueAddition(t *testing.T) { + var cases = []struct { + input expiredValue + addend int64 + timeOffset fixed64 + expect uint64 + expectNet int64 + }{ + // Addition + {expiredValue{base: 128, exp: 0}, 128, uint64ToFixed64(0), 256, 128}, + {expiredValue{base: 128, exp: 2}, 128, uint64ToFixed64(0), 640, 128}, + + // Addition with offset + {expiredValue{base: 128, exp: 0}, 128, uint64ToFixed64(1), 192, 128}, + {expiredValue{base: 128, exp: 2}, 128, uint64ToFixed64(1), 384, 128}, + {expiredValue{base: 128, exp: 2}, 128, uint64ToFixed64(3), 192, 128}, + + // Subtraction + {expiredValue{base: 128, exp: 0}, -64, uint64ToFixed64(0), 64, -64}, + {expiredValue{base: 128, exp: 0}, -128, uint64ToFixed64(0), 0, -128}, + {expiredValue{base: 128, exp: 0}, -192, uint64ToFixed64(0), 0, -128}, + + // Subtraction with offset + {expiredValue{base: 128, exp: 0}, -64, uint64ToFixed64(1), 0, -64}, + {expiredValue{base: 128, exp: 0}, -128, uint64ToFixed64(1), 0, -64}, + {expiredValue{base: 128, exp: 2}, -128, uint64ToFixed64(1), 128, -128}, + {expiredValue{base: 128, exp: 2}, -128, uint64ToFixed64(2), 0, -128}, + } + for _, c := range cases { + if net := c.input.add(c.addend, c.timeOffset); net != c.expectNet { + t.Fatalf("Net amount mismatch, want=%d, got=%d", c.expectNet, net) + } + if got := c.input.value(c.timeOffset); got != c.expect { + t.Fatalf("Value mismatch, want=%d, got=%d", c.expect, got) + } + } +} + +func TestExpiredValueAddition(t *testing.T) { + var cases = []struct { + input expiredValue + another expiredValue + timeOffset fixed64 + expect uint64 + }{ + {expiredValue{base: 128, exp: 0}, expiredValue{base: 128, exp: 0}, uint64ToFixed64(0), 256}, + {expiredValue{base: 128, exp: 1}, expiredValue{base: 128, exp: 0}, uint64ToFixed64(0), 384}, + {expiredValue{base: 128, exp: 0}, expiredValue{base: 128, exp: 1}, uint64ToFixed64(0), 384}, + {expiredValue{base: 128, exp: 0}, expiredValue{base: 128, exp: 0}, uint64ToFixed64(1), 128}, + } + for _, c := range cases { + c.input.addExp(c.another) + if got := c.input.value(c.timeOffset); got != c.expect { + t.Fatalf("Value mismatch, want=%d, got=%d", c.expect, got) + } + } +} + +func TestExpiredValueSubtraction(t *testing.T) { + var cases = []struct { + input expiredValue + another expiredValue + timeOffset fixed64 + expect uint64 + }{ + {expiredValue{base: 128, exp: 0}, expiredValue{base: 128, exp: 0}, uint64ToFixed64(0), 0}, + {expiredValue{base: 128, exp: 0}, expiredValue{base: 128, exp: 1}, uint64ToFixed64(0), 0}, + {expiredValue{base: 128, exp: 1}, expiredValue{base: 128, exp: 0}, uint64ToFixed64(0), 128}, + {expiredValue{base: 128, exp: 1}, expiredValue{base: 128, exp: 0}, uint64ToFixed64(1), 64}, + } + for _, c := range cases { + c.input.subExp(c.another) + if got := c.input.value(c.timeOffset); got != c.expect { + t.Fatalf("Value mismatch, want=%d, got=%d", c.expect, got) + } + } +} diff --git a/les/flowcontrol/control.go b/les/flowcontrol/control.go index 490013677c63..1c40882902fc 100644 --- a/les/flowcontrol/control.go +++ b/les/flowcontrol/control.go @@ -185,6 +185,14 @@ func (node *ClientNode) UpdateParams(params ServerParams) { } } +// Params returns the current server parameters +func (node *ClientNode) Params() ServerParams { + node.lock.Lock() + defer node.lock.Unlock() + + return node.params +} + // updateParams updates the flow control parameters of the node func (node *ClientNode) updateParams(params ServerParams, now mclock.AbsTime) { diff := int64(params.BufLimit - node.params.BufLimit) diff --git a/les/handler_test.go b/les/handler_test.go index 1612caf42769..d3e4df1b77bd 100644 --- a/les/handler_test.go +++ b/les/handler_test.go @@ -38,20 +38,29 @@ import ( "github.com/ethereum/go-ethereum/trie" ) -func expectResponse(r p2p.MsgReader, msgcode, reqID, bv uint64, data interface{}) error { +func expectResponse(r p2p.MsgReader, protocol int, msgcode, reqID, bv, cost uint64, data interface{}) error { type resp struct { - ReqID, BV uint64 - Data interface{} + ReqID uint64 + SF stateFeedback + Data interface{} } - return p2p.ExpectMsg(r, msgcode, resp{reqID, bv, data}) + sf := stateFeedback{ + protocolVersion: protocol, + stateFeedbackV4: stateFeedbackV4{ + BV: bv, + RealCost: cost, + TokenBalance: 0, + }, + } + return p2p.ExpectMsg(r, msgcode, resp{reqID, sf, data}) } // Tests that block headers can be retrieved from a remote chain based on user queries. -func TestGetBlockHeadersLes2(t *testing.T) { testGetBlockHeaders(t, 2) } func TestGetBlockHeadersLes3(t *testing.T) { testGetBlockHeaders(t, 3) } +func TestGetBlockHeadersLes4(t *testing.T) { testGetBlockHeaders(t, 4) } func testGetBlockHeaders(t *testing.T, protocol int) { - server, tearDown := newServerEnv(t, downloader.MaxHashFetch+15, protocol, nil, false, true, 0) + server, tearDown := newServerEnv(t, downloader.MaxHashFetch+15, protocol, nil, false, true, 0, true) defer tearDown() bc := server.handler.blockchain @@ -168,19 +177,20 @@ func testGetBlockHeaders(t *testing.T, protocol int) { // Send the hash request and verify the response reqID++ + cost := server.peer.speer.getRequestCost(GetBlockHeadersMsg, int(tt.query.Amount)) sendRequest(server.peer.app, GetBlockHeadersMsg, reqID, tt.query) - if err := expectResponse(server.peer.app, BlockHeadersMsg, reqID, testBufLimit, headers); err != nil { + if err := expectResponse(server.peer.app, protocol, BlockHeadersMsg, reqID, testBufLimit, cost, headers); err != nil { t.Errorf("test %d: headers mismatch: %v", i, err) } } } // Tests that block contents can be retrieved from a remote chain based on their hashes. -func TestGetBlockBodiesLes2(t *testing.T) { testGetBlockBodies(t, 2) } func TestGetBlockBodiesLes3(t *testing.T) { testGetBlockBodies(t, 3) } +func TestGetBlockBodiesLes4(t *testing.T) { testGetBlockBodies(t, 4) } func testGetBlockBodies(t *testing.T, protocol int) { - server, tearDown := newServerEnv(t, downloader.MaxBlockFetch+15, protocol, nil, false, true, 0) + server, tearDown := newServerEnv(t, downloader.MaxBlockFetch+15, protocol, nil, false, true, 0, true) defer tearDown() bc := server.handler.blockchain @@ -245,20 +255,21 @@ func testGetBlockBodies(t *testing.T, protocol int) { reqID++ // Send the hash request and verify the response + cost := server.peer.speer.getRequestCost(GetBlockBodiesMsg, len(hashes)) sendRequest(server.peer.app, GetBlockBodiesMsg, reqID, hashes) - if err := expectResponse(server.peer.app, BlockBodiesMsg, reqID, testBufLimit, bodies); err != nil { + if err := expectResponse(server.peer.app, protocol, BlockBodiesMsg, reqID, testBufLimit, cost, bodies); err != nil { t.Errorf("test %d: bodies mismatch: %v", i, err) } } } // Tests that the contract codes can be retrieved based on account addresses. -func TestGetCodeLes2(t *testing.T) { testGetCode(t, 2) } func TestGetCodeLes3(t *testing.T) { testGetCode(t, 3) } +func TestGetCodeLes4(t *testing.T) { testGetCode(t, 4) } func testGetCode(t *testing.T, protocol int) { // Assemble the test environment - server, tearDown := newServerEnv(t, 4, protocol, nil, false, true, 0) + server, tearDown := newServerEnv(t, 4, protocol, nil, false, true, 0, true) defer tearDown() bc := server.handler.blockchain @@ -276,18 +287,19 @@ func testGetCode(t *testing.T, protocol int) { } } + cost := server.peer.speer.getRequestCost(GetCodeMsg, len(codereqs)) sendRequest(server.peer.app, GetCodeMsg, 42, codereqs) - if err := expectResponse(server.peer.app, CodeMsg, 42, testBufLimit, codes); err != nil { + if err := expectResponse(server.peer.app, protocol, CodeMsg, 42, testBufLimit, cost, codes); err != nil { t.Errorf("codes mismatch: %v", err) } } // Tests that the stale contract codes can't be retrieved based on account addresses. -func TestGetStaleCodeLes2(t *testing.T) { testGetStaleCode(t, 2) } func TestGetStaleCodeLes3(t *testing.T) { testGetStaleCode(t, 3) } +func TestGetStaleCodeLes4(t *testing.T) { testGetStaleCode(t, 4) } func testGetStaleCode(t *testing.T, protocol int) { - server, tearDown := newServerEnv(t, core.TriesInMemory+4, protocol, nil, false, true, 0) + server, tearDown := newServerEnv(t, core.TriesInMemory+4, protocol, nil, false, true, 0, true) defer tearDown() bc := server.handler.blockchain @@ -296,8 +308,9 @@ func testGetStaleCode(t *testing.T, protocol int) { BHash: bc.GetHeaderByNumber(number).Hash(), AccKey: crypto.Keccak256(testContractAddr[:]), } + cost := server.peer.speer.getRequestCost(GetCodeMsg, 1) sendRequest(server.peer.app, GetCodeMsg, 42, []*CodeReq{req}) - if err := expectResponse(server.peer.app, CodeMsg, 42, testBufLimit, expected); err != nil { + if err := expectResponse(server.peer.app, protocol, CodeMsg, 42, testBufLimit, cost, expected); err != nil { t.Errorf("codes mismatch: %v", err) } } @@ -307,12 +320,12 @@ func testGetStaleCode(t *testing.T, protocol int) { } // Tests that the transaction receipts can be retrieved based on hashes. -func TestGetReceiptLes2(t *testing.T) { testGetReceipt(t, 2) } func TestGetReceiptLes3(t *testing.T) { testGetReceipt(t, 3) } +func TestGetReceiptLes4(t *testing.T) { testGetReceipt(t, 4) } func testGetReceipt(t *testing.T, protocol int) { // Assemble the test environment - server, tearDown := newServerEnv(t, 4, protocol, nil, false, true, 0) + server, tearDown := newServerEnv(t, 4, protocol, nil, false, true, 0, true) defer tearDown() bc := server.handler.blockchain @@ -327,19 +340,20 @@ func testGetReceipt(t *testing.T, protocol int) { receipts = append(receipts, rawdb.ReadRawReceipts(server.db, block.Hash(), block.NumberU64())) } // Send the hash request and verify the response + cost := server.peer.speer.getRequestCost(GetReceiptsMsg, len(hashes)) sendRequest(server.peer.app, GetReceiptsMsg, 42, hashes) - if err := expectResponse(server.peer.app, ReceiptsMsg, 42, testBufLimit, receipts); err != nil { + if err := expectResponse(server.peer.app, protocol, ReceiptsMsg, 42, testBufLimit, cost, receipts); err != nil { t.Errorf("receipts mismatch: %v", err) } } // Tests that trie merkle proofs can be retrieved -func TestGetProofsLes2(t *testing.T) { testGetProofs(t, 2) } func TestGetProofsLes3(t *testing.T) { testGetProofs(t, 3) } +func TestGetProofsLes4(t *testing.T) { testGetProofs(t, 4) } func testGetProofs(t *testing.T, protocol int) { // Assemble the test environment - server, tearDown := newServerEnv(t, 4, protocol, nil, false, true, 0) + server, tearDown := newServerEnv(t, 4, protocol, nil, false, true, 0, true) defer tearDown() bc := server.handler.blockchain @@ -362,18 +376,19 @@ func testGetProofs(t *testing.T, protocol int) { } } // Send the proof request and verify the response + cost := server.peer.speer.getRequestCost(GetProofsV2Msg, len(proofreqs)) sendRequest(server.peer.app, GetProofsV2Msg, 42, proofreqs) - if err := expectResponse(server.peer.app, ProofsV2Msg, 42, testBufLimit, proofsV2.NodeList()); err != nil { + if err := expectResponse(server.peer.app, protocol, ProofsV2Msg, 42, testBufLimit, cost, proofsV2.NodeList()); err != nil { t.Errorf("proofs mismatch: %v", err) } } // Tests that the stale contract codes can't be retrieved based on account addresses. -func TestGetStaleProofLes2(t *testing.T) { testGetStaleProof(t, 2) } func TestGetStaleProofLes3(t *testing.T) { testGetStaleProof(t, 3) } +func TestGetStaleProofLes4(t *testing.T) { testGetStaleProof(t, 4) } func testGetStaleProof(t *testing.T, protocol int) { - server, tearDown := newServerEnv(t, core.TriesInMemory+4, protocol, nil, false, true, 0) + server, tearDown := newServerEnv(t, core.TriesInMemory+4, protocol, nil, false, true, 0, true) defer tearDown() bc := server.handler.blockchain @@ -395,7 +410,8 @@ func testGetStaleProof(t *testing.T, protocol int) { t.Prove(account, 0, proofsV2) expected = proofsV2.NodeList() } - if err := expectResponse(server.peer.app, ProofsV2Msg, 42, testBufLimit, expected); err != nil { + cost := server.peer.speer.getRequestCost(GetProofsV2Msg, 1) + if err := expectResponse(server.peer.app, protocol, ProofsV2Msg, 42, testBufLimit, cost, expected); err != nil { t.Errorf("codes mismatch: %v", err) } } @@ -405,8 +421,8 @@ func testGetStaleProof(t *testing.T, protocol int) { } // Tests that CHT proofs can be correctly retrieved. -func TestGetCHTProofsLes2(t *testing.T) { testGetCHTProofs(t, 2) } func TestGetCHTProofsLes3(t *testing.T) { testGetCHTProofs(t, 3) } +func TestGetCHTProofsLes4(t *testing.T) { testGetCHTProofs(t, 4) } func testGetCHTProofs(t *testing.T, protocol int) { config := light.TestServerIndexerConfig @@ -420,7 +436,7 @@ func testGetCHTProofs(t *testing.T, protocol int) { time.Sleep(10 * time.Millisecond) } } - server, tearDown := newServerEnv(t, int(config.ChtSize+config.ChtConfirms), protocol, waitIndexers, false, true, 0) + server, tearDown := newServerEnv(t, int(config.ChtSize+config.ChtConfirms), protocol, waitIndexers, false, true, 0, true) defer tearDown() bc := server.handler.blockchain @@ -446,14 +462,15 @@ func testGetCHTProofs(t *testing.T, protocol int) { AuxReq: auxHeader, }} // Send the proof request and verify the response + cost := server.peer.speer.getRequestCost(GetHelperTrieProofsMsg, len(requestsV2)) sendRequest(server.peer.app, GetHelperTrieProofsMsg, 42, requestsV2) - if err := expectResponse(server.peer.app, HelperTrieProofsMsg, 42, testBufLimit, proofsV2); err != nil { + if err := expectResponse(server.peer.app, protocol, HelperTrieProofsMsg, 42, testBufLimit, cost, proofsV2); err != nil { t.Errorf("proofs mismatch: %v", err) } } -func TestGetBloombitsProofsLes2(t *testing.T) { testGetBloombitsProofs(t, 2) } func TestGetBloombitsProofsLes3(t *testing.T) { testGetBloombitsProofs(t, 3) } +func TestGetBloombitsProofsLes4(t *testing.T) { testGetBloombitsProofs(t, 4) } // Tests that bloombits proofs can be correctly retrieved. func testGetBloombitsProofs(t *testing.T, protocol int) { @@ -468,7 +485,7 @@ func testGetBloombitsProofs(t *testing.T, protocol int) { time.Sleep(10 * time.Millisecond) } } - server, tearDown := newServerEnv(t, int(config.BloomTrieSize+config.BloomTrieConfirms), protocol, waitIndexers, false, true, 0) + server, tearDown := newServerEnv(t, int(config.BloomTrieSize+config.BloomTrieConfirms), protocol, waitIndexers, false, true, 0, true) defer tearDown() bc := server.handler.blockchain @@ -494,18 +511,19 @@ func testGetBloombitsProofs(t *testing.T, protocol int) { trie.Prove(key, 0, &proofs.Proofs) // Send the proof request and verify the response + cost := server.peer.speer.getRequestCost(GetHelperTrieProofsMsg, len(requests)) sendRequest(server.peer.app, GetHelperTrieProofsMsg, 42, requests) - if err := expectResponse(server.peer.app, HelperTrieProofsMsg, 42, testBufLimit, proofs); err != nil { + if err := expectResponse(server.peer.app, protocol, HelperTrieProofsMsg, 42, testBufLimit, cost, proofs); err != nil { t.Errorf("bit %d: proofs mismatch: %v", bit, err) } } } -func TestTransactionStatusLes2(t *testing.T) { testTransactionStatus(t, 2) } func TestTransactionStatusLes3(t *testing.T) { testTransactionStatus(t, 3) } +func TestTransactionStatusLes4(t *testing.T) { testTransactionStatus(t, 4) } func testTransactionStatus(t *testing.T, protocol int) { - server, tearDown := newServerEnv(t, 0, protocol, nil, false, true, 0) + server, tearDown := newServerEnv(t, 0, protocol, nil, false, true, 0, true) defer tearDown() server.handler.addTxsSync = true @@ -515,12 +533,15 @@ func testTransactionStatus(t *testing.T, protocol int) { test := func(tx *types.Transaction, send bool, expStatus light.TxStatus) { reqID++ + var cost uint64 if send { + cost = server.peer.speer.getRequestCost(SendTxV2Msg, 1) sendRequest(server.peer.app, SendTxV2Msg, reqID, types.Transactions{tx}) } else { + cost = server.peer.speer.getRequestCost(GetTxStatusMsg, 1) sendRequest(server.peer.app, GetTxStatusMsg, reqID, []common.Hash{tx.Hash()}) } - if err := expectResponse(server.peer.app, TxStatusMsg, reqID, testBufLimit, []light.TxStatus{expStatus}); err != nil { + if err := expectResponse(server.peer.app, protocol, TxStatusMsg, reqID, testBufLimit, cost, []light.TxStatus{expStatus}); err != nil { t.Errorf("transaction status mismatch") } } @@ -595,8 +616,11 @@ func testTransactionStatus(t *testing.T, protocol int) { test(tx2, false, light.TxStatus{Status: core.TxStatusPending}) } -func TestStopResumeLes3(t *testing.T) { - server, tearDown := newServerEnv(t, 0, 3, nil, true, true, testBufLimit/10) +func TestStopResumeLes3(t *testing.T) { testStopResume(t, 3) } +func TestStopResumeLes4(t *testing.T) { testStopResume(t, 4) } + +func testStopResume(t *testing.T, protocol int) { + server, tearDown := newServerEnv(t, 0, protocol, nil, true, true, testBufLimit/10, true) defer tearDown() server.handler.server.costTracker.testing = true @@ -616,7 +640,7 @@ func TestStopResumeLes3(t *testing.T) { for expBuf >= testCost { req() expBuf -= testCost - if err := expectResponse(server.peer.app, BlockHeadersMsg, reqID, expBuf, []*types.Header{header}); err != nil { + if err := expectResponse(server.peer.app, protocol, BlockHeadersMsg, reqID, expBuf, testCost, []*types.Header{header}); err != nil { t.Errorf("expected response and failed: %v", err) } } @@ -635,7 +659,15 @@ func TestStopResumeLes3(t *testing.T) { // expect a ResumeMsg with the partially recharged buffer value expBuf += testBufRecharge * wait - if err := p2p.ExpectMsg(server.peer.app, ResumeMsg, expBuf); err != nil { + sf := stateFeedback{ + protocolVersion: protocol, + stateFeedbackV4: stateFeedbackV4{ + BV: expBuf, + RealCost: 0, + TokenBalance: 0, + }, + } + if err := p2p.ExpectMsg(server.peer.app, ResumeMsg, sf); err != nil { t.Errorf("expected ResumeMsg and failed: %v", err) } } diff --git a/les/metrics.go b/les/metrics.go index 9ef8c365180c..12780346b65c 100644 --- a/les/metrics.go +++ b/les/metrics.go @@ -40,6 +40,8 @@ var ( miscInTxsTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic/txs", nil) miscInTxStatusPacketsMeter = metrics.NewRegisteredMeter("les/misc/in/packets/txStatus", nil) miscInTxStatusTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic/txStatus", nil) + miscInLespayPacketsMeter = metrics.NewRegisteredMeter("les/misc/in/packets/lespay", nil) + miscInLespayTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic/lespay", nil) miscOutPacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets/total", nil) miscOutTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/total", nil) @@ -59,6 +61,8 @@ var ( miscOutTxsTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/txs", nil) miscOutTxStatusPacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets/txStatus", nil) miscOutTxStatusTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/txStatus", nil) + miscOutLespayPacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets/lespay", nil) + miscOutLespayTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/lespay", nil) miscServingTimeHeaderTimer = metrics.NewRegisteredTimer("les/misc/serve/header", nil) miscServingTimeBodyTimer = metrics.NewRegisteredTimer("les/misc/serve/body", nil) @@ -68,6 +72,7 @@ var ( miscServingTimeHelperTrieTimer = metrics.NewRegisteredTimer("les/misc/serve/helperTrie", nil) miscServingTimeTxTimer = metrics.NewRegisteredTimer("les/misc/serve/txs", nil) miscServingTimeTxStatusTimer = metrics.NewRegisteredTimer("les/misc/serve/txStatus", nil) + miscServingTimeLespayTimer = metrics.NewRegisteredTimer("les/misc/serve/lespay", nil) connectionTimer = metrics.NewRegisteredTimer("les/connection/duration", nil) serverConnectionGauge = metrics.NewRegisteredGauge("les/connection/server", nil) diff --git a/les/odr_test.go b/les/odr_test.go index bbe439dfec82..a56de3176961 100644 --- a/les/odr_test.go +++ b/les/odr_test.go @@ -38,8 +38,8 @@ import ( type odrTestFn func(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte -func TestOdrGetBlockLes2(t *testing.T) { testOdr(t, 2, 1, true, odrGetBlock) } func TestOdrGetBlockLes3(t *testing.T) { testOdr(t, 3, 1, true, odrGetBlock) } +func TestOdrGetBlockLes4(t *testing.T) { testOdr(t, 4, 1, true, odrGetBlock) } func odrGetBlock(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte { var block *types.Block @@ -55,8 +55,8 @@ func odrGetBlock(ctx context.Context, db ethdb.Database, config *params.ChainCon return rlp } -func TestOdrGetReceiptsLes2(t *testing.T) { testOdr(t, 2, 1, true, odrGetReceipts) } func TestOdrGetReceiptsLes3(t *testing.T) { testOdr(t, 3, 1, true, odrGetReceipts) } +func TestOdrGetReceiptsLes4(t *testing.T) { testOdr(t, 4, 1, true, odrGetReceipts) } func odrGetReceipts(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte { var receipts types.Receipts @@ -76,8 +76,8 @@ func odrGetReceipts(ctx context.Context, db ethdb.Database, config *params.Chain return rlp } -func TestOdrAccountsLes2(t *testing.T) { testOdr(t, 2, 1, true, odrAccounts) } func TestOdrAccountsLes3(t *testing.T) { testOdr(t, 3, 1, true, odrAccounts) } +func TestOdrAccountsLes4(t *testing.T) { testOdr(t, 4, 1, true, odrAccounts) } func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte { dummyAddr := common.HexToAddress("1234567812345678123456781234567812345678") @@ -105,8 +105,8 @@ func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainCon return res } -func TestOdrContractCallLes2(t *testing.T) { testOdr(t, 2, 2, true, odrContractCall) } func TestOdrContractCallLes3(t *testing.T) { testOdr(t, 3, 2, true, odrContractCall) } +func TestOdrContractCallLes4(t *testing.T) { testOdr(t, 4, 2, true, odrContractCall) } type callmsg struct { types.Message @@ -155,8 +155,8 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai return res } -func TestOdrTxStatusLes2(t *testing.T) { testOdr(t, 2, 1, false, odrTxStatus) } func TestOdrTxStatusLes3(t *testing.T) { testOdr(t, 3, 1, false, odrTxStatus) } +func TestOdrTxStatusLes4(t *testing.T) { testOdr(t, 4, 1, false, odrTxStatus) } func odrTxStatus(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte { var txs types.Transactions @@ -236,7 +236,7 @@ func testOdr(t *testing.T, protocol int, expFail uint64, checkCached bool, fn od // still expect all retrievals to pass, now data should be cached locally if checkCached { - client.handler.backend.peers.unregister(client.peer.speer.id) + client.handler.backend.peers.disconnect(client.peer.speer.id) time.Sleep(time.Millisecond * 10) // ensure that all peerSetNotify callbacks are executed test(5) } diff --git a/les/peer.go b/les/peer.go index 28ec201bc9a7..40e9d25ea473 100644 --- a/les/peer.go +++ b/les/peer.go @@ -132,6 +132,7 @@ type peerCommons struct { frozen uint32 // Flag whether the peer is frozen. announceType uint64 // New block announcement type. headInfo blockInfo // Latest block information. + active bool // Background task queue for caching peer tasks and executing in order. sendQueue *execQueue @@ -478,12 +479,18 @@ func (p *serverPeer) requestTxStatus(reqID uint64, txHashes []common.Hash) error return sendRequest(p.rw, GetTxStatusMsg, reqID, txHashes) } -// SendTxStatus creates a reply with a batch of transactions to be added to the remote transaction pool. +// sendTxs creates a reply with a batch of transactions to be added to the remote transaction pool. func (p *serverPeer) sendTxs(reqID uint64, txs rlp.RawValue) error { p.Log().Debug("Sending batch of transactions", "size", len(txs)) return sendRequest(p.rw, SendTxV2Msg, reqID, txs) } +// sendLespay sends a set of commands to the service token sale module +func (p *serverPeer) sendLespay(reqID uint64, cmd []byte) error { + p.Log().Debug("Sending batch of lespay commands", "size", len(cmd)) + return sendRequest(p.rw, LespayMsg, reqID, cmd) +} + // waitBefore implements distPeer interface func (p *serverPeer) waitBefore(maxCost uint64) (time.Duration, float64) { return p.fcServer.CanSend(maxCost) @@ -554,10 +561,12 @@ func (p *serverPeer) updateFlowControl(update keyValueMap) { // If any of the flow control params is nil, refuse to update. var params flowcontrol.ServerParams + updated := false if update.get("flowControl/BL", ¶ms.BufLimit) == nil && update.get("flowControl/MRR", ¶ms.MinRecharge) == nil { // todo can light client set a minimal acceptable flow control params? p.fcParams = params p.fcServer.UpdateParams(params) + updated = true } var MRC RequestCostList if update.get("flowControl/MRC", &MRC) == nil { @@ -565,9 +574,20 @@ func (p *serverPeer) updateFlowControl(update keyValueMap) { for code, cost := range costUpdate { p.fcCosts[code] = cost } + updated = true + } + if updated { + p.active = p.paramsUseful() } } +// paramsUseful returns true if the server parameters ensure the minimum required +// buffer limit and recharge +func (p *serverPeer) paramsUseful() bool { + reqRecharge, reqBufLimit := p.fcCosts.reqParams() + return p.fcParams.MinRecharge >= reqRecharge && p.fcParams.BufLimit >= reqBufLimit +} + // Handshake executes the les protocol handshake, negotiating version number, // network IDs, difficulties, head and genesis blocks. func (p *serverPeer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis common.Hash, server *LesServer) error { @@ -614,6 +634,7 @@ func (p *serverPeer) Handshake(td *big.Int, head common.Hash, headNum uint64, ge p.fcParams = sParams p.fcServer = flowcontrol.NewServerNode(sParams, &mclock.System{}) p.fcCosts = MRC.decode(ProtocolLengths[uint(p.version)]) + p.active = p.paramsUseful() recv.get("checkpoint/value", &p.checkpoint) recv.get("checkpoint/registerHeight", &p.checkpointNumber) @@ -634,6 +655,9 @@ func (p *serverPeer) Handshake(td *big.Int, head common.Hash, headNum uint64, ge type clientPeer struct { peerCommons + activate, deactivate func() + getBalance func() uint64 + // responseLock ensures that responses are queued in the same order as // RequestProcessed is called responseLock sync.Mutex @@ -681,8 +705,8 @@ func (p *clientPeer) sendStop() error { } // sendResume notifies the client about getting out of frozen state -func (p *clientPeer) sendResume(bv uint64) error { - return p2p.Send(p.rw, ResumeMsg, bv) +func (p *clientPeer) sendResume(sf stateFeedback) error { + return p2p.Send(p.rw, ResumeMsg, sf) } // freeze temporarily puts the client in a frozen state which means all unprocessed @@ -711,7 +735,19 @@ func (p *clientPeer) freeze() { continue } atomic.StoreUint32(&p.frozen, 0) - p.sendResume(bufValue) + var balance uint64 + if p.getBalance != nil { + balance = p.getBalance() + } + sf := stateFeedback{ + protocolVersion: p.version, + stateFeedbackV4: stateFeedbackV4{ + BV: bufValue, + RealCost: 0, + TokenBalance: balance, + }, + } + p.sendResume(sf) return } }() @@ -728,12 +764,13 @@ type reply struct { } // send sends the reply with the calculated buffer value -func (r *reply) send(bv uint64) error { +func (r *reply) send(sf stateFeedback) error { type resp struct { - ReqID, BV uint64 - Data rlp.RawValue + ReqID uint64 + SF stateFeedback + Data rlp.RawValue } - return p2p.Send(r.w, r.msgcode, resp{r.reqID, bv, r.data}) + return p2p.Send(r.w, r.msgcode, resp{r.reqID, sf, r.data}) } // size returns the RLP encoded size of the message data @@ -786,6 +823,12 @@ func (p *clientPeer) replyTxStatus(reqID uint64, stats []light.TxStatus) *reply return &reply{p.rw, TxStatusMsg, reqID, data} } +// replyLespay sends a set of replies to lespay commands +func (p *clientPeer) replyLespay(reqID uint64, reply []byte, delay uint) error { + p.Log().Debug("Sending batch of lespay replies", "size", len(reply)) + return sendRequest(p.rw, LespayReplyMsg, reqID, lespayReply{reply, delay}) +} + // sendAnnounce announces the availability of a number of blocks through // a hash notification. func (p *clientPeer) sendAnnounce(request announceData) error { @@ -798,46 +841,21 @@ func (p *clientPeer) updateCapacity(cap uint64) { p.lock.Lock() defer p.lock.Unlock() - p.fcParams = flowcontrol.ServerParams{MinRecharge: cap, BufLimit: cap * bufLimitRatio} - p.fcClient.UpdateParams(p.fcParams) - var kvList keyValueList - kvList = kvList.add("flowControl/MRR", cap) - kvList = kvList.add("flowControl/BL", cap*bufLimitRatio) - p.mustQueueSend(func() { p.sendAnnounce(announceData{Update: kvList}) }) -} - -// freezeClient temporarily puts the client in a frozen state which means all -// unprocessed and subsequent requests are dropped. Unfreezing happens automatically -// after a short time if the client's buffer value is at least in the slightly positive -// region. The client is also notified about being frozen/unfrozen with a Stop/Resume -// message. -func (p *clientPeer) freezeClient() { - if p.version < lpv3 { - // if Stop/Resume is not supported then just drop the peer after setting - // its frozen status permanently - atomic.StoreUint32(&p.frozen, 1) - p.Peer.Disconnect(p2p.DiscUselessPeer) - return + if !p.active && cap != 0 && p.activate != nil { + p.activate() } - if atomic.SwapUint32(&p.frozen, 1) == 0 { - go func() { - p.sendStop() - time.Sleep(freezeTimeBase + time.Duration(rand.Int63n(int64(freezeTimeRandom)))) - for { - bufValue, bufLimit := p.fcClient.BufferStatus() - if bufLimit == 0 { - return - } - if bufValue <= bufLimit/8 { - time.Sleep(freezeCheckPeriod) - } else { - atomic.StoreUint32(&p.frozen, 0) - p.sendResume(bufValue) - break - } - } - }() + if cap != 0 || p.version >= lpv4 { + p.fcParams = flowcontrol.ServerParams{MinRecharge: cap, BufLimit: cap * bufLimitRatio} + p.fcClient.UpdateParams(p.fcParams) + var kvList keyValueList + kvList = kvList.add("flowControl/BL", cap*bufLimitRatio) + kvList = kvList.add("flowControl/MRR", cap) + p.mustQueueSend(func() { p.sendAnnounce(announceData{Update: kvList}) }) } + if p.active && cap == 0 && p.deactivate != nil { + p.deactivate() + } + } // Handshake executes the les protocol handshake, negotiating version number, @@ -859,8 +877,14 @@ func (p *clientPeer) Handshake(td *big.Int, head common.Hash, headNum uint64, ge *lists = (*lists).add("serveRecentState", stateRecent) *lists = (*lists).add("txRelay", nil) } - *lists = (*lists).add("flowControl/BL", server.defParams.BufLimit) - *lists = (*lists).add("flowControl/MRR", server.defParams.MinRecharge) + p.active = p.version < lpv4 + if p.active { + p.fcParams = server.defParams + } else { + p.fcParams = flowcontrol.ServerParams{} + } + *lists = (*lists).add("flowControl/BL", p.fcParams.BufLimit) + *lists = (*lists).add("flowControl/MRR", p.fcParams.MinRecharge) var costList RequestCostList if server.costTracker.testCostList != nil { @@ -870,7 +894,6 @@ func (p *clientPeer) Handshake(td *big.Int, head common.Hash, headNum uint64, ge } *lists = (*lists).add("flowControl/MRC", costList) p.fcCosts = costList.decode(ProtocolLengths[uint(p.version)]) - p.fcParams = server.defParams // Add advertised checkpoint and register block height which // client can verify the checkpoint validity. @@ -890,7 +913,7 @@ func (p *clientPeer) Handshake(td *big.Int, head common.Hash, headNum uint64, ge // set default announceType on server side p.announceType = announceTypeSimple } - p.fcClient = flowcontrol.NewClientNode(server.fcManager, server.defParams) + p.fcClient = flowcontrol.NewClientNode(server.fcManager, p.fcParams) } return nil }) @@ -913,7 +936,7 @@ type clientPeerSubscriber interface { // clientPeerSet represents the set of active client peers currently // participating in the Light Ethereum sub-protocol. type clientPeerSet struct { - peers map[string]*clientPeer + active, inactive map[string]*clientPeer // subscribers is a batch of subscribers and peerset will notify // these subscribers when the peerset changes(new client peer is // added or removed) @@ -924,17 +947,23 @@ type clientPeerSet struct { // newClientPeerSet creates a new peer set to track the client peers. func newClientPeerSet() *clientPeerSet { - return &clientPeerSet{peers: make(map[string]*clientPeer)} + return &clientPeerSet{ + active: make(map[string]*clientPeer), + inactive: make(map[string]*clientPeer), + } } // subscribe adds a service to be notified about added or removed // peers and also register all active peers into the given service. func (ps *clientPeerSet) subscribe(sub clientPeerSubscriber) { ps.lock.Lock() - defer ps.lock.Unlock() - ps.subscribers = append(ps.subscribers, sub) - for _, p := range ps.peers { + notify := make([]*clientPeer, 0, len(ps.active)) + for _, p := range ps.active { + notify = append(notify, p) + } + ps.lock.Unlock() + for _, p := range notify { sub.registerPeer(p) } } @@ -954,19 +983,25 @@ func (ps *clientPeerSet) unSubscribe(sub clientPeerSubscriber) { // register adds a new peer into the peer set, or returns an error if the // peer is already known. -func (ps *clientPeerSet) register(peer *clientPeer) error { +func (ps *clientPeerSet) register(p *clientPeer) error { ps.lock.Lock() - defer ps.lock.Unlock() - if ps.closed { + ps.lock.Unlock() return errClosed } - if _, exist := ps.peers[peer.id]; exist { + if _, ok := ps.active[p.id]; ok { + ps.lock.Unlock() return errAlreadyRegistered } - ps.peers[peer.id] = peer - for _, sub := range ps.subscribers { - sub.registerPeer(peer) + delete(ps.inactive, p.id) + ps.active[p.id] = p + + peers := make([]clientPeerSubscriber, len(ps.subscribers)) + copy(peers, ps.subscribers) + ps.lock.Unlock() + + for _, n := range peers { + n.registerPeer(p) } return nil } @@ -974,29 +1009,60 @@ func (ps *clientPeerSet) register(peer *clientPeer) error { // unregister removes a remote peer from the peer set, disabling any further // actions to/from that particular entity. It also initiates disconnection // at the networking layer. -func (ps *clientPeerSet) unregister(id string) error { +func (ps *clientPeerSet) unregister(p *clientPeer) error { ps.lock.Lock() - defer ps.lock.Unlock() + if _, ok := ps.active[p.id]; !ok { + ps.lock.Unlock() + return errNotRegistered + } else { + delete(ps.active, p.id) + ps.inactive[p.id] = p + peers := make([]clientPeerSubscriber, len(ps.subscribers)) + copy(peers, ps.subscribers) + ps.lock.Unlock() + + for _, n := range peers { + n.unregisterPeer(p) + } + return nil + } +} - p, ok := ps.peers[id] - if !ok { +// disconnect removes a remote peer from either the active or inactive set and +// initiates disconnection at the networking layer. +func (ps *clientPeerSet) disconnect(id string) error { + ps.lock.Lock() + + var ( + peers []clientPeerSubscriber + p *clientPeer + ok bool + ) + if p, ok = ps.active[id]; ok { + delete(ps.active, p.id) + peers = make([]clientPeerSubscriber, len(ps.subscribers)) + copy(peers, ps.subscribers) + } else if p, ok = ps.inactive[id]; ok { + delete(ps.inactive, id) + } else { + ps.lock.Unlock() return errNotRegistered } - delete(ps.peers, id) - for _, sub := range ps.subscribers { - sub.unregisterPeer(p) + ps.lock.Unlock() + for _, n := range peers { + n.unregisterPeer(p) } - p.Peer.Disconnect(p2p.DiscRequested) + p.Peer.Disconnect(p2p.DiscUselessPeer) return nil } -// ids returns a list of all registered peer IDs +// ids returns a list of all active peer IDs func (ps *clientPeerSet) ids() []string { ps.lock.RLock() defer ps.lock.RUnlock() var ids []string - for id := range ps.peers { + for id := range ps.active { ids = append(ids, id) } return ids @@ -1007,24 +1073,27 @@ func (ps *clientPeerSet) peer(id string) *clientPeer { ps.lock.RLock() defer ps.lock.RUnlock() - return ps.peers[id] + if p, ok := ps.active[id]; ok { + return p + } + return ps.inactive[id] } -// len returns if the current number of peers in the set. +// len returns if the current number of peers in the active set. func (ps *clientPeerSet) len() int { ps.lock.RLock() defer ps.lock.RUnlock() - return len(ps.peers) + return len(ps.active) } -// allClientPeers returns all client peers in a list. +// allClientPeers returns all active client peers in a list. func (ps *clientPeerSet) allPeers() []*clientPeer { ps.lock.RLock() defer ps.lock.RUnlock() - list := make([]*clientPeer, 0, len(ps.peers)) - for _, p := range ps.peers { + list := make([]*clientPeer, 0, len(ps.active)) + for _, p := range ps.active { list = append(list, p) } return list @@ -1036,7 +1105,10 @@ func (ps *clientPeerSet) close() { ps.lock.Lock() defer ps.lock.Unlock() - for _, p := range ps.peers { + for _, p := range ps.active { + p.Disconnect(p2p.DiscQuitting) + } + for _, p := range ps.inactive { p.Disconnect(p2p.DiscQuitting) } ps.closed = true @@ -1045,7 +1117,7 @@ func (ps *clientPeerSet) close() { // serverPeerSet represents the set of active server peers currently // participating in the Light Ethereum sub-protocol. type serverPeerSet struct { - peers map[string]*serverPeer + active, inactive map[string]*serverPeer // subscribers is a batch of subscribers and peerset will notify // these subscribers when the peerset changes(new server peer is // added or removed) @@ -1056,17 +1128,23 @@ type serverPeerSet struct { // newServerPeerSet creates a new peer set to track the active server peers. func newServerPeerSet() *serverPeerSet { - return &serverPeerSet{peers: make(map[string]*serverPeer)} + return &serverPeerSet{ + active: make(map[string]*serverPeer), + inactive: make(map[string]*serverPeer), + } } // subscribe adds a service to be notified about added or removed // peers and also register all active peers into the given service. func (ps *serverPeerSet) subscribe(sub serverPeerSubscriber) { ps.lock.Lock() - defer ps.lock.Unlock() - ps.subscribers = append(ps.subscribers, sub) - for _, p := range ps.peers { + notify := make([]*serverPeer, 0, len(ps.active)) + for _, p := range ps.active { + notify = append(notify, p) + } + ps.lock.Unlock() + for _, p := range notify { sub.registerPeer(p) } } @@ -1086,19 +1164,25 @@ func (ps *serverPeerSet) unSubscribe(sub serverPeerSubscriber) { // register adds a new server peer into the set, or returns an error if the // peer is already known. -func (ps *serverPeerSet) register(peer *serverPeer) error { +func (ps *serverPeerSet) register(p *serverPeer) error { ps.lock.Lock() - defer ps.lock.Unlock() - if ps.closed { + ps.lock.Unlock() return errClosed } - if _, exist := ps.peers[peer.id]; exist { + if _, ok := ps.active[p.id]; ok { + ps.lock.Unlock() return errAlreadyRegistered } - ps.peers[peer.id] = peer - for _, sub := range ps.subscribers { - sub.registerPeer(peer) + delete(ps.inactive, p.id) + ps.active[p.id] = p + + peers := make([]serverPeerSubscriber, len(ps.subscribers)) + copy(peers, ps.subscribers) + ps.lock.Unlock() + + for _, n := range peers { + n.registerPeer(p) } return nil } @@ -1106,29 +1190,60 @@ func (ps *serverPeerSet) register(peer *serverPeer) error { // unregister removes a remote peer from the active set, disabling any further // actions to/from that particular entity. It also initiates disconnection at // the networking layer. -func (ps *serverPeerSet) unregister(id string) error { +func (ps *serverPeerSet) unregister(p *serverPeer) error { ps.lock.Lock() - defer ps.lock.Unlock() + if _, ok := ps.active[p.id]; !ok { + ps.lock.Unlock() + return errNotRegistered + } else { + delete(ps.active, p.id) + ps.inactive[p.id] = p + peers := make([]serverPeerSubscriber, len(ps.subscribers)) + copy(peers, ps.subscribers) + ps.lock.Unlock() + + for _, n := range peers { + n.unregisterPeer(p) + } + return nil + } +} - p, ok := ps.peers[id] - if !ok { +// disconnect removes a remote peer from either the active or inactive set and +// initiates disconnection at the networking layer. +func (ps *serverPeerSet) disconnect(id string) error { + ps.lock.Lock() + + var ( + peers []serverPeerSubscriber + p *serverPeer + ok bool + ) + if p, ok = ps.active[id]; ok { + delete(ps.active, p.id) + peers = make([]serverPeerSubscriber, len(ps.subscribers)) + copy(peers, ps.subscribers) + } else if p, ok = ps.inactive[id]; ok { + delete(ps.inactive, id) + } else { + ps.lock.Unlock() return errNotRegistered } - delete(ps.peers, id) - for _, sub := range ps.subscribers { - sub.unregisterPeer(p) + ps.lock.Unlock() + for _, n := range peers { + n.unregisterPeer(p) } - p.Peer.Disconnect(p2p.DiscRequested) + p.Peer.Disconnect(p2p.DiscUselessPeer) return nil } -// ids returns a list of all registered peer IDs +// ids returns a list of all active peer IDs func (ps *serverPeerSet) ids() []string { ps.lock.RLock() defer ps.lock.RUnlock() var ids []string - for id := range ps.peers { + for id := range ps.active { ids = append(ids, id) } return ids @@ -1139,15 +1254,18 @@ func (ps *serverPeerSet) peer(id string) *serverPeer { ps.lock.RLock() defer ps.lock.RUnlock() - return ps.peers[id] + if p, ok := ps.active[id]; ok { + return p + } + return ps.inactive[id] } -// len returns if the current number of peers in the set. +// len returns if the current number of peers in the active set. func (ps *serverPeerSet) len() int { ps.lock.RLock() defer ps.lock.RUnlock() - return len(ps.peers) + return len(ps.active) } // bestPeer retrieves the known peer with the currently highest total difficulty. @@ -1161,7 +1279,7 @@ func (ps *serverPeerSet) bestPeer() *serverPeer { bestPeer *serverPeer bestTd *big.Int ) - for _, p := range ps.peers { + for _, p := range ps.active { if td := p.Td(); bestTd == nil || td.Cmp(bestTd) > 0 { bestPeer, bestTd = p, td } @@ -1169,13 +1287,13 @@ func (ps *serverPeerSet) bestPeer() *serverPeer { return bestPeer } -// allServerPeers returns all server peers in a list. +// allPeers returns all active server peers in a list. func (ps *serverPeerSet) allPeers() []*serverPeer { ps.lock.RLock() defer ps.lock.RUnlock() - list := make([]*serverPeer, 0, len(ps.peers)) - for _, p := range ps.peers { + list := make([]*serverPeer, 0, len(ps.active)) + for _, p := range ps.active { list = append(list, p) } return list @@ -1187,7 +1305,10 @@ func (ps *serverPeerSet) close() { ps.lock.Lock() defer ps.lock.Unlock() - for _, p := range ps.peers { + for _, p := range ps.active { + p.Disconnect(p2p.DiscQuitting) + } + for _, p := range ps.inactive { p.Disconnect(p2p.DiscQuitting) } ps.closed = true diff --git a/les/peer_test.go b/les/peer_test.go index 59a2ad700954..a5da9a7a9682 100644 --- a/les/peer_test.go +++ b/les/peer_test.go @@ -85,7 +85,7 @@ func TestPeerSubscription(t *testing.T) { checkIds([]string{peer.id}) checkPeers(sub.regCh) - peers.unregister(peer.id) + peers.unregister(peer) checkIds([]string{}) checkPeers(sub.unregCh) } diff --git a/les/protocol.go b/les/protocol.go index 36af88aea6d0..5140795ecca3 100644 --- a/les/protocol.go +++ b/les/protocol.go @@ -33,17 +33,18 @@ import ( const ( lpv2 = 2 lpv3 = 3 + lpv4 = 4 ) // Supported versions of the les protocol (first is primary) var ( - ClientProtocolVersions = []uint{lpv2, lpv3} - ServerProtocolVersions = []uint{lpv2, lpv3} + ClientProtocolVersions = []uint{lpv2, lpv3, lpv4} + ServerProtocolVersions = []uint{lpv2, lpv3, lpv4} AdvertiseProtocolVersions = []uint{lpv2} // clients are searching for the first advertised protocol in the list ) // Number of implemented message corresponding to different protocol versions. -var ProtocolLengths = map[uint]uint64{lpv2: 22, lpv3: 24} +var ProtocolLengths = map[uint]uint64{lpv2: 22, lpv3: 24, lpv4: 26} const ( NetworkId = 1 @@ -74,6 +75,9 @@ const ( // Protocol messages introduced in LPV3 StopMsg = 0x16 ResumeMsg = 0x17 + // Protocol messages introduced in LPV4 + LespayMsg = 0x18 + LespayReplyMsg = 0x19 ) type requestInfo struct { @@ -201,6 +205,11 @@ type hashOrNumber struct { Number uint64 // Block hash from which to retrieve headers (excludes Hash) } +type lespayReply struct { + Reply []byte + Delay uint +} + // EncodeRLP is a specialized encoder for hashOrNumber to encode only one of the // two contained union fields. func (hn *hashOrNumber) EncodeRLP(w io.Writer) error { @@ -235,3 +244,28 @@ func (hn *hashOrNumber) DecodeRLP(s *rlp.Stream) error { type CodeData []struct { Value []byte } + +type stateFeedbackV4 struct { + BV, RealCost, TokenBalance uint64 +} + +type stateFeedback struct { + protocolVersion int + stateFeedbackV4 +} + +func (sf stateFeedback) EncodeRLP(w io.Writer) error { + if sf.protocolVersion >= lpv4 { + return rlp.Encode(w, sf.stateFeedbackV4) + } else { + return rlp.Encode(w, sf.BV) + } +} + +func (sf *stateFeedback) DecodeRLP(s *rlp.Stream) error { + if sf.protocolVersion >= lpv4 { + return s.Decode(&sf.stateFeedbackV4) + } else { + return s.Decode(&sf.BV) + } +} diff --git a/les/request_test.go b/les/request_test.go index f58ebca9c1d9..f5e370a6ba6d 100644 --- a/les/request_test.go +++ b/les/request_test.go @@ -36,22 +36,22 @@ func secAddr(addr common.Address) []byte { type accessTestFn func(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest -func TestBlockAccessLes2(t *testing.T) { testAccess(t, 2, tfBlockAccess) } func TestBlockAccessLes3(t *testing.T) { testAccess(t, 3, tfBlockAccess) } +func TestBlockAccessLes4(t *testing.T) { testAccess(t, 4, tfBlockAccess) } func tfBlockAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest { return &light.BlockRequest{Hash: bhash, Number: number} } -func TestReceiptsAccessLes2(t *testing.T) { testAccess(t, 2, tfReceiptsAccess) } func TestReceiptsAccessLes3(t *testing.T) { testAccess(t, 3, tfReceiptsAccess) } +func TestReceiptsAccessLes4(t *testing.T) { testAccess(t, 4, tfReceiptsAccess) } func tfReceiptsAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest { return &light.ReceiptsRequest{Hash: bhash, Number: number} } -func TestTrieEntryAccessLes2(t *testing.T) { testAccess(t, 2, tfTrieEntryAccess) } func TestTrieEntryAccessLes3(t *testing.T) { testAccess(t, 3, tfTrieEntryAccess) } +func TestTrieEntryAccessLes4(t *testing.T) { testAccess(t, 4, tfTrieEntryAccess) } func tfTrieEntryAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest { if number := rawdb.ReadHeaderNumber(db, bhash); number != nil { @@ -60,8 +60,8 @@ func tfTrieEntryAccess(db ethdb.Database, bhash common.Hash, number uint64) ligh return nil } -func TestCodeAccessLes2(t *testing.T) { testAccess(t, 2, tfCodeAccess) } func TestCodeAccessLes3(t *testing.T) { testAccess(t, 3, tfCodeAccess) } +func TestCodeAccessLes4(t *testing.T) { testAccess(t, 4, tfCodeAccess) } func tfCodeAccess(db ethdb.Database, bhash common.Hash, num uint64) light.OdrRequest { number := rawdb.ReadHeaderNumber(db, bhash) diff --git a/les/retrieve.go b/les/retrieve.go index 5fa68b745682..45976bc9eca0 100644 --- a/les/retrieve.go +++ b/les/retrieve.go @@ -345,7 +345,7 @@ func (r *sentReq) tryRequest() { if hrto { pp.Log().Debug("Request timed out hard") if r.rm.peers != nil { - r.rm.peers.unregister(pp.id) + r.rm.peers.disconnect(pp.id) } } diff --git a/les/server.go b/les/server.go index f72f31321abf..2430d1914a5f 100644 --- a/les/server.go +++ b/les/server.go @@ -44,6 +44,7 @@ type LesServer struct { handler *serverHandler lesTopics []discv5.Topic privateKey *ecdsa.PrivateKey + srvr *p2p.Server // Flow control and capacity management fcManager *flowcontrol.ClientManager @@ -51,6 +52,7 @@ type LesServer struct { defParams flowcontrol.ServerParams servingQueue *servingQueue clientPool *clientPool + tokenSale *tokenSale minCapacity, maxCapacity, freeCapacity uint64 threadsIdle int // Request serving threads count when system is idle. @@ -116,8 +118,13 @@ func NewLesServer(e *eth.Ethereum, config *eth.Config) (*LesServer, error) { srv.maxCapacity = totalRecharge } srv.fcManager.SetCapacityLimits(srv.freeCapacity, srv.maxCapacity, srv.freeCapacity*2) - srv.clientPool = newClientPool(srv.chainDb, srv.freeCapacity, mclock.System{}, func(id enode.ID) { go srv.peers.unregister(peerIdToString(id)) }) + srv.clientPool = newClientPool(srv.chainDb, srv.minCapacity, srv.freeCapacity, mclock.System{}, func(id enode.ID) { go srv.peers.disconnect(peerIdToString(id)) }) srv.clientPool.setDefaultFactors(priceFactors{0, 1, 1}, priceFactors{0, 1, 1}) + srv.tokenSale = newTokenSale(srv.clientPool, 0.1, 100) + if config.LespayTestModule { + srv.tokenSale.addReceiver("test", testReceiver{}) + srv.clientPool.setExpirationTCs(defaultPosExpTC, defaultNegExpTC) + } checkpoint := srv.latestLocalCheckpoint() if !checkpoint.Empty() { @@ -148,6 +155,12 @@ func (s *LesServer) APIs() []rpc.API { Service: NewPrivateDebugAPI(s), Public: false, }, + { + Namespace: "lespay", + Version: "1.0", + Service: NewPrivateLespayAPI(s.peers, nil, nil, s.srvr.DiscV5, s.tokenSale), + Public: false, + }, } } @@ -167,6 +180,7 @@ func (s *LesServer) Protocols() []p2p.Protocol { // Start starts the LES server func (s *LesServer) Start(srvr *p2p.Server) { + s.srvr = srvr s.privateKey = srvr.PrivateKey s.handler.start() @@ -174,6 +188,7 @@ func (s *LesServer) Start(srvr *p2p.Server) { go s.capacityManagement() if srvr.DiscV5 != nil { + srvr.DiscV5.RegisterTalkHandler("lespay", s.handler.talkRequestHandler) for _, topic := range s.lesTopics { topic := topic go func() { @@ -191,6 +206,11 @@ func (s *LesServer) Start(srvr *p2p.Server) { func (s *LesServer) Stop() { close(s.closeCh) + if s.srvr.DiscV5 != nil { + s.srvr.DiscV5.RemoveTalkHandler("lespay") + } + s.tokenSale.stop() + // Disconnect existing sessions. // This also closes the gate for any new registrations on the peer set. // sessions which are already established but not added to pm.peers yet diff --git a/les/server_handler.go b/les/server_handler.go index 186bdcbb03f4..b928c3fb0b24 100644 --- a/les/server_handler.go +++ b/les/server_handler.go @@ -20,6 +20,7 @@ import ( "encoding/binary" "encoding/json" "errors" + "net" "sync" "sync/atomic" "time" @@ -35,6 +36,7 @@ import ( "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" ) @@ -54,10 +56,7 @@ const ( MaxTxStatus = 256 // Amount of transactions to queried per request ) -var ( - errTooManyInvalidRequest = errors.New("too many invalid requests made") - errFullClientPool = errors.New("client pool is full") -) +var errTooManyInvalidRequest = errors.New("too many invalid requests made") // serverHandler is responsible for serving light client and process // all incoming light requests. @@ -103,6 +102,9 @@ func (h *serverHandler) stop() { func (h *serverHandler) runPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) error { peer := newClientPeer(int(version), h.server.config.NetworkId, p, newMeteredMsgWriter(rw, int(version))) defer peer.close() + peer.getBalance = func() uint64 { + return h.server.clientPool.getPosBalance(p.ID()).value.value(h.server.clientPool.posExpiration(mclock.Now())) + } h.wg.Add(1) defer h.wg.Done() return h.handle(peer) @@ -134,28 +136,58 @@ func (h *serverHandler) handle(p *clientPeer) error { } defer p.fcClient.Disconnect() - // Disconnect the inbound peer if it's rejected by clientPool - if !h.server.clientPool.connect(p, 0) { - p.Log().Debug("Light Ethereum peer registration failed", "err", errFullClientPool) - return errFullClientPool + var ( + connectedAt mclock.AbsTime + wg *sync.WaitGroup // Wait group used to track all in-flight task routines. + ) + p.activate = func() { + // Register the peer locally + if err := h.server.peers.register(p); err != nil { + h.server.clientPool.disconnect(p) + p.Log().Error("Light Ethereum peer registration failed", "err", err) + return + } + clientConnectionGauge.Update(int64(h.server.peers.len())) + connectedAt = mclock.Now() + wg = new(sync.WaitGroup) + p.active = true } - // Register the peer locally - if err := h.server.peers.register(p); err != nil { - h.server.clientPool.disconnect(p) - p.Log().Error("Light Ethereum peer registration failed", "err", err) - return err + p.deactivate = func() { + h.server.peers.unregister(p) + if p.version < lpv4 { + h.server.peers.disconnect(p.id) + } + clientConnectionGauge.Update(int64(h.server.peers.len())) + connectionTimer.Update(time.Duration(mclock.Now() - connectedAt)) + p.active = false + } + if p.active { + p.activate() } - clientConnectionGauge.Update(int64(h.server.peers.len())) - var wg sync.WaitGroup // Wait group used to track all in-flight task routines. + if capacity, err := h.server.clientPool.connect(p, 0); err != nil { + // Disconnect the inbound peer if it's rejected by clientPool + p.Log().Debug("Light Ethereum peer registration failed", "err", err) + return err + } else if capacity != p.fcParams.MinRecharge { + if p.version < lpv4 { + h.server.peers.disconnect(p.id) + } else { + p.updateCapacity(capacity) + } + } - connectedAt := mclock.Now() defer func() { wg.Wait() // Ensure all background task routines have exited. - h.server.peers.unregister(p.id) h.server.clientPool.disconnect(p) - clientConnectionGauge.Update(int64(h.server.peers.len())) - connectionTimer.Update(time.Duration(mclock.Now() - connectedAt)) + p.responseLock.Lock() + if p.active { + p.deactivate() + } + p.activate = nil + p.deactivate = nil + p.responseLock.Unlock() + h.server.peers.disconnect(p.id) }() // Spawn a main loop to handle all incoming messages. @@ -166,7 +198,7 @@ func (h *serverHandler) handle(p *clientPeer) error { return err default: } - if err := h.handleMsg(p, &wg); err != nil { + if err := h.handleMsg(p, wg); err != nil { p.Log().Debug("Light Ethereum message handling failed", "err", err) return err } @@ -245,22 +277,33 @@ func (h *serverHandler) handleMsg(p *clientPeer, wg *sync.WaitGroup) error { if reply != nil { replySize = reply.size() } - var realCost uint64 + var realCost, balance uint64 if h.server.costTracker.testing { realCost = maxCost // Assign a fake cost for testing purpose } else { realCost = h.server.costTracker.realCost(servingTime, msg.Size, replySize) + if realCost > maxCost { + realCost = maxCost + } } bv := p.fcClient.RequestProcessed(reqID, responseCount, maxCost, realCost) if amount != 0 { // Feed cost tracker request serving statistic. h.server.costTracker.updateStats(msg.Code, amount, servingTime, realCost) // Reduce priority "balance" for the specific peer. - h.server.clientPool.requestCost(p, realCost) + balance = h.server.clientPool.requestCost(p, realCost) + } + sf := stateFeedback{ + protocolVersion: p.version, + stateFeedbackV4: stateFeedbackV4{ + BV: bv, + RealCost: realCost, + TokenBalance: balance, + }, } if reply != nil { p.mustQueueSend(func() { - if err := reply.send(bv); err != nil { + if err := reply.send(sf); err != nil { select { case p.errCh <- err: default: @@ -374,7 +417,7 @@ func (h *serverHandler) handleMsg(p *clientPeer, wg *sync.WaitGroup) error { first = false } reply := p.replyBlockHeaders(req.ReqID, headers) - sendResponse(req.ReqID, query.Amount, p.replyBlockHeaders(req.ReqID, headers), task.done()) + sendResponse(req.ReqID, query.Amount, reply, task.done()) if metrics.EnabledExpensive { miscOutHeaderPacketsMeter.Mark(1) miscOutHeaderTrafficMeter.Mark(int64(reply.size())) @@ -824,6 +867,38 @@ func (h *serverHandler) handleMsg(p *clientPeer, wg *sync.WaitGroup) error { } }() } + case LespayMsg: + p.Log().Trace("Received transaction status query request") + if metrics.EnabledExpensive { + miscInLespayPacketsMeter.Mark(1) + miscInLespayTrafficMeter.Mark(int64(msg.Size)) + defer func(start time.Time) { miscServingTimeLespayTimer.UpdateSince(start) }(time.Now()) + } + var req struct { + ReqID uint64 + Cmd []byte + } + if err := msg.Decode(&req); err != nil { + clientErrorMeter.Mark(1) + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + if !h.server.tokenSale.queueCommand(p.id, lespayCmd{ + cmd: req.Cmd, + id: p.ID(), + freeID: p.freeClientId(), + send: func(reply []byte, delay uint) { + if metrics.EnabledExpensive { + miscOutLespayPacketsMeter.Mark(1) + miscOutLespayTrafficMeter.Mark(int64(len(reply))) + } + p.queueSend(func() { + p.replyLespay(req.ReqID, reply, delay) + }) + }, + }) { + clientErrorMeter.Mark(1) + return errResp(ErrRequestRejected, "") + } default: p.Log().Trace("Received invalid message", "code", msg.Code) @@ -959,3 +1034,49 @@ func (h *serverHandler) broadcastHeaders() { } } } + +// talkRequestHandler implements discv5.TalkRequestHandler. It processes a list of +// lespay token sale commands and returns the results and the recommended delay. +// +// Note: the UDP talk format for lespay commands allows multiple commands in a single +// packet because UDP does not guarantee the correct order of messages which might be +// important in some cases (like deposit followed by buyTokens). +func (h *serverHandler) talkRequestHandler(id enode.ID, addr *net.UDPAddr, payload interface{}) (interface{}, uint, bool) { + c, ok := payload.([]interface{}) + if !ok { + return nil, 0, false + } + type result struct { + data []byte + delay uint + } + resultCh := make(chan result, len(c)) + results := make([][]byte, len(c)) + for _, c := range c { + cmd, ok := c.([]byte) + if !ok { + return nil, 0, false + } + if !h.server.tokenSale.queueCommand(id.String(), lespayCmd{ + cmd: cmd, + id: id, + freeID: addr.IP.String(), + send: func(reply []byte, delay uint) { + resultCh <- result{reply, delay} + }, + }) { + return nil, 0, false + } + } + + var lastDelay uint + for i := range results { + select { + case r := <-resultCh: + results[i], lastDelay = r.data, r.delay + case <-h.closeCh: + return nil, 0, false + } + } + return results, lastDelay, true +} diff --git a/les/servingqueue.go b/les/servingqueue.go index 9db84e6159cf..b1a8dda80b3d 100644 --- a/les/servingqueue.go +++ b/les/servingqueue.go @@ -70,7 +70,7 @@ type runToken chan struct{} // start blocks until the task can start and returns true if it is allowed to run. // Returning false means that the task should be cancelled. func (t *servingTask) start() bool { - if t.peer.isFrozen() { + if t.peer != nil && t.peer.isFrozen() { return false } t.tokenCh = make(chan runToken, 1) @@ -289,7 +289,7 @@ func (sq *servingQueue) addTask(task *servingTask) { sq.queuedTime += task.expTime sqServedGauge.Update(int64(sq.recentTime)) sqQueuedGauge.Update(int64(sq.queuedTime)) - if sq.recentTime+sq.queuedTime > sq.burstLimit { + if sq.burstLimit != 0 && sq.recentTime+sq.queuedTime > sq.burstLimit { sq.freezePeers() } } diff --git a/les/test_helper.go b/les/test_helper.go index d9ffe32db205..da8222add4f9 100644 --- a/les/test_helper.go +++ b/les/test_helper.go @@ -78,10 +78,10 @@ var ( processConfirms = big.NewInt(1) // The token bucket buffer limit for testing purpose. - testBufLimit = uint64(1000000) + testBufLimit = uint64(6000) // The buffer recharging speed for testing purpose. - testBufRecharge = uint64(1000) + testBufRecharge = uint64(1) ) /* @@ -281,7 +281,7 @@ func newTestServerHandler(blocks int, indexers []*core.ChainIndexer, db ethdb.Da } server.costTracker, server.freeCapacity = newCostTracker(db, server.config) server.costTracker.testCostList = testCostList(0) // Disable flow control mechanism. - server.clientPool = newClientPool(db, 1, clock, nil) + server.clientPool = newClientPool(db, 1, 1, clock, nil) server.clientPool.setLimits(10000, 10000) // Assign enough capacity for clientpool server.handler = newServerHandler(server, simulation.Blockchain(), db, txpool, func() bool { return true }) if server.oracle != nil { @@ -309,7 +309,8 @@ func newTestPeer(t *testing.T, name string, version int, handler *serverHandler, // Generate a random id and create the peer var id enode.ID rand.Read(id[:]) - peer := newClientPeer(version, NetworkId, p2p.NewPeer(id, name, nil), net) + cpeer := newClientPeer(version, NetworkId, p2p.NewPeer(id, name, nil), net) + speer := newServerPeer(version, NetworkId, false, p2p.NewPeer(id, name, nil), app) // Start the peer on a new thread errCh := make(chan error, 1) @@ -317,13 +318,14 @@ func newTestPeer(t *testing.T, name string, version int, handler *serverHandler, select { case <-handler.closeCh: errCh <- p2p.DiscQuitting - case errCh <- handler.handle(peer): + case errCh <- handler.handle(cpeer): } }() tp := &testPeer{ app: app, net: net, - cpeer: peer, + cpeer: cpeer, + speer: speer, } // Execute any implicitly requested handshakes and return if shake { @@ -395,8 +397,13 @@ func (p *testPeer) handshake(t *testing.T, td *big.Int, head common.Hash, headNu expList = expList.add("serveStateSince", uint64(0)) expList = expList.add("serveRecentState", uint64(core.TriesInMemory-4)) expList = expList.add("txRelay", nil) - expList = expList.add("flowControl/BL", testBufLimit) - expList = expList.add("flowControl/MRR", testBufRecharge) + if p.cpeer.version >= lpv4 { + expList = expList.add("flowControl/BL", uint64(0)) + expList = expList.add("flowControl/MRR", uint64(0)) + } else { + expList = expList.add("flowControl/BL", testBufLimit) + expList = expList.add("flowControl/MRR", testBufRecharge) + } expList = expList.add("flowControl/MRC", costList) if err := p2p.ExpectMsg(p.app, StatusMsg, expList); err != nil { @@ -405,9 +412,16 @@ func (p *testPeer) handshake(t *testing.T, td *big.Int, head common.Hash, headNu if err := p2p.Send(p.app, StatusMsg, sendList); err != nil { t.Fatalf("status send: %v", err) } - p.cpeer.fcParams = flowcontrol.ServerParams{ - BufLimit: testBufLimit, - MinRecharge: testBufRecharge, +} + +func (p *testPeer) expectCapUpdate(t *testing.T) { + if p.cpeer.version >= lpv4 { + var expList keyValueList + expList = expList.add("flowControl/BL", testBufLimit) + expList = expList.add("flowControl/MRR", testBufRecharge) + if err := p2p.ExpectMsg(p.app, AnnounceMsg, announceData{Update: expList}); err != nil { + t.Fatalf("status recv: %v", err) + } } } @@ -438,7 +452,7 @@ type testServer struct { bloomTrieIndexer *core.ChainIndexer } -func newServerEnv(t *testing.T, blocks int, protocol int, callback indexerCallback, simClock bool, newPeer bool, testCost uint64) (*testServer, func()) { +func newServerEnv(t *testing.T, blocks int, protocol int, callback indexerCallback, simClock bool, newPeer bool, testCost uint64, expectCapUpdate bool) (*testServer, func()) { db := rawdb.NewMemoryDatabase() indexers := testIndexers(db, nil, light.TestServerIndexerConfig) @@ -480,6 +494,9 @@ func newServerEnv(t *testing.T, blocks int, protocol int, callback indexerCallba cIndexer.Close() bIndexer.Close() } + if expectCapUpdate { + server.peer.expectCapUpdate(t) + } return server, teardown } diff --git a/les/tokensale.go b/les/tokensale.go new file mode 100644 index 000000000000..099058fa7138 --- /dev/null +++ b/les/tokensale.go @@ -0,0 +1,860 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package les + +import ( + "encoding/binary" + "fmt" + "math" + "strconv" + "sync" + "time" + + "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/rlp" +) + +const ( + basePriceTC = time.Hour * 10 // time constant for controlling the base price + tokenSellMaxRatio = 0.9 // total amount/supply limit ratio over which selling price does not increase further + tsMinDelay = time.Second * 5 // minimum recommended delay for sending the next command + tsMaxBurst = 16 // maximum commands processed in a row before the recommended delay has elapsed +) + +// paymentReceiver processes incoming payments and can be implemented using different +// payment technologies +type paymentReceiver interface { + info() keyValueList + receivePayment(from enode.ID, proofOfPayment []byte, reader ethdb.KeyValueReader, writer ethdb.KeyValueWriter) (value uint64, err error) + requestPayment(from enode.ID, value uint64, reader ethdb.KeyValueReader) uint64 +} + +// tokenSale handles client balance deposits, conversion to and from service tokens +// and granting connections and capacity changes through a set of commands called "lespay". +type tokenSale struct { + lock sync.Mutex + clientPool *clientPool + stopCh chan struct{} + receivers map[string]paymentReceiver + receiverNames []string + basePrice, minBasePrice float64 + totalTokenLimit, totalTokenAmount func() uint64 + + qlock sync.Mutex + sq *servingQueue + sources map[string]*cmdSource + delayFactorZero, delayFactorLast mclock.AbsTime + tsProcessDelay, tsTargetPeriod time.Duration +} + +// newTokenSale creates a new token sale module instance +func newTokenSale(clientPool *clientPool, minBasePrice float64, talkSpeed int) *tokenSale { + t := &tokenSale{ + clientPool: clientPool, + receivers: make(map[string]paymentReceiver), + basePrice: minBasePrice, + minBasePrice: minBasePrice, + totalTokenLimit: clientPool.totalTokenLimit, + totalTokenAmount: clientPool.totalTokenAmount, + stopCh: make(chan struct{}), + sq: newServingQueue(0, 0), + sources: make(map[string]*cmdSource), + delayFactorZero: mclock.Now(), + delayFactorLast: mclock.Now(), + tsProcessDelay: time.Second / time.Duration(talkSpeed), + tsTargetPeriod: 5 * time.Second / time.Duration(talkSpeed), + } + t.sq.setThreads(1) + go func() { + cleanupCounter := 0 + for { + select { + case <-time.After(time.Second * 10): + t.lock.Lock() + cost, ok := t.tokenPrice(1, true) + if cost > t.basePrice*10 || !ok { + cost = t.basePrice * 10 + } + t.basePrice += (cost - t.basePrice) * float64(time.Second*10) / float64(basePriceTC) + if t.basePrice < minBasePrice { + t.basePrice = minBasePrice + } + t.lock.Unlock() + + cleanupCounter++ + if cleanupCounter == 100 { + t.sourceMapCleanup() + cleanupCounter = 0 + } + case <-t.stopCh: + return + } + } + }() + return t +} + +type ( + // cmdSource represents a source where lespay commands can come from. + // It can be either an LES connected peer or a UDP address. + cmdSource struct { + ch chan lespayCmd + delayUntil mclock.AbsTime + burstCounter int + } + // lespayCmd represents a single lespay command, including the source it came + // from and the callback that is going to process the results. + lespayCmd struct { + cmd []byte + id enode.ID + freeID string + send func([]byte, uint) + } +) + +// priority returns the processing priority for the next command coming from the given +// source. Commands sent before the previously recommended delay has elapsed have a +// lower priority. It also checks whether the number of commands consecutively sent +// before the delay has elapsed exceeds maxBurst and rejects the command instantly if +// necessary. +func (c *cmdSource) priority() (int64, bool) { + dt := c.delayUntil - mclock.Now() + if dt <= 0 { + c.burstCounter = 0 + return 0, true + } + if c.burstCounter >= tsMaxBurst { + return 0, false + } + c.burstCounter++ + return -int64(dt), true +} + +// addDelay adds the given amount to the recommended delay +func (c *cmdSource) addDelay(now mclock.AbsTime, delay time.Duration) uint { + dt := time.Duration(c.delayUntil - now) + if dt <= 0 { + dt = 0 + } + dt += delay + if dt < tsMinDelay { + dt = tsMinDelay + } + c.delayUntil = now + mclock.AbsTime(dt) + return uint((dt + time.Second - 1) / time.Second) +} + +// delayFactor calculates the amount added to the recommended delay after processing +// a single command +func (t *tokenSale) delayFactor(now mclock.AbsTime) time.Duration { + if now > t.delayFactorZero { + t.delayFactorZero = now + } + t.delayFactorZero += mclock.AbsTime(t.tsTargetPeriod) + t.delayFactorLast - now + t.delayFactorLast = now + if now >= t.delayFactorZero { + return 0 + } else { + return time.Duration(t.delayFactorZero-now) / 4 + } +} + +// sourceMapCleanup removes unnecessary entries from the command source map +func (t *tokenSale) sourceMapCleanup() { + t.qlock.Lock() + defer t.qlock.Unlock() + + now := mclock.Now() + for src, s := range t.sources { + if s.delayUntil < now { + delete(t.sources, src) + } + } +} + +// queueCommand schedules a lespay command (encapsulated in a lespayCmd) for execution +func (t *tokenSale) queueCommand(src string, cmd lespayCmd) bool { + t.qlock.Lock() + defer t.qlock.Unlock() + + s := t.sources[src] + if s == nil { + s = &cmdSource{} + t.sources[src] = s + } + if s.ch != nil { + select { + case s.ch <- cmd: + return true + default: + return false + } + } + s.ch = make(chan lespayCmd, 16) + s.ch <- cmd + + go func() { + loop: + for { + select { + case cmd := <-s.ch: + t.qlock.Lock() + pri, ok := s.priority() + t.qlock.Unlock() + if ok { + task := t.sq.newTask(nil, 0, pri) + if !task.start() { + break loop + } + reply := t.runCommand(cmd.cmd, cmd.id, cmd.freeID) + t.qlock.Lock() + now := mclock.Now() + delay := s.addDelay(now, t.delayFactor(now)) + t.qlock.Unlock() + cmd.send(reply, delay) + time.Sleep(t.tsProcessDelay) + task.done() + } else { + cmd.send(nil, 0) + } + default: + break loop + } + t.qlock.Lock() + s.ch = nil + t.qlock.Unlock() + } + }() + return true +} + +// stop stops the token sale module +func (t *tokenSale) stop() { + close(t.stopCh) + t.sq.stop() +} + +// addReceiver adds a new payment receiver module +func (t *tokenSale) addReceiver(id string, r paymentReceiver) { + t.lock.Lock() + defer t.lock.Unlock() + + t.receivers[id] = r + t.receiverNames = append(t.receiverNames, id) +} + +// tokenPrice returns the PC units required to buy the specified amount of service +// tokens or the PC units received when selling the given amount of tokens. +// Returns false if not possible. +// +// Note: the price of each token unit depends on the current amount of existing tokens +// and the total token limit, first raising from 0 to basePrice linearly, then tends to +// infinity as tokenAmount approaches tokenLimit. +// +// if 0 <= tokenAmount <= tokenLimit/2: +// tokenPrice = basePrice*tokenAmount/(tokenLimit/2) +// if tokenLimit/2 <= tokenAmount < tokenLimit: +// tokenPrice = basePrice*tokenLimit/2/(tokenLimit-tokenAmount) +// +// The price of multiple tokens is calculated as an integral based on the above formula. +func (t *tokenSale) tokenPrice(buySellAmount uint64, buy bool) (float64, bool) { + tokenLimit := t.totalTokenLimit() + tokenAmount := t.totalTokenAmount() + if buy { + if tokenAmount+buySellAmount >= tokenLimit { + return 0, false + } + } else { + maxAmount := uint64(float64(tokenLimit) * tokenSellMaxRatio) + if tokenAmount > maxAmount { + tokenAmount = maxAmount + } + if tokenAmount < buySellAmount { + buySellAmount = tokenAmount + } + tokenAmount -= buySellAmount + } + r := float64(tokenAmount) / float64(tokenLimit) + b := float64(buySellAmount) / float64(tokenLimit) + var relPrice float64 + if r < 0.5 { + // first purchased token is in the linear range + if r+b <= 0.5 { + // all purchased tokens are in the linear range + relPrice = b * (r + r + b) + b = 0 + } else { + // some purchased tokens are in the 1/x range, calculate linear price + // update starting point and amount left to buy in the 1/x range + relPrice = (0.5 - r) * (r + 0.5) + b = r + b - 0.5 + r = 0.5 + } + } + if b > 0 { + // some purchased tokens are in the 1/x range + l := 1 - r + if l < 1e-10 { + return 0, false + } + l = -b / l + if l < -1+1e-10 { + return 0, false + } + relPrice += -math.Log1p(l) / 2 + } + return t.basePrice * float64(tokenLimit) * relPrice, true +} + +// tokenBuyAmount returns the service token amount currently available for the given +// sum of PC units +func (t *tokenSale) tokenBuyAmount(price float64) uint64 { + tokenLimit := t.totalTokenLimit() + tokenAmount := t.totalTokenAmount() + if tokenLimit <= tokenAmount { + return 0 + } + r := float64(tokenAmount) / float64(tokenLimit) + c := price / (t.basePrice * float64(tokenLimit)) + var relTokens float64 + if r < 0.5 { + // first purchased token is in the linear range + relTokens = math.Sqrt(r*r+c) - r + if r+relTokens <= 0.5 { + // all purchased tokens are in the linear range, no more to spend + c = 0 + } else { + // some purchased tokens are in the 1/x range, calculate linear amount + // update starting point and available funds left to buy in the 1/x range + relTokens = 0.5 - r + c -= (0.5 - r) * (r + 0.5) + r = 0.5 + } + } + if c > 0 { + relTokens -= math.Expm1(-2*c) * (1 - r) + } + return uint64(relTokens * float64(tokenLimit)) +} + +// tokenSellAmount returns the service token amount that needs to be sold in order +// to receive the given sum of PC units. Returns false if not possible. +func (t *tokenSale) tokenSellAmount(price float64) (uint64, bool) { + tokenLimit := t.totalTokenLimit() + tokenAmount := t.totalTokenAmount() + r := float64(tokenAmount) / float64(tokenLimit) + if r > tokenSellMaxRatio { + r = tokenSellMaxRatio + } + c := price / (t.basePrice * float64(tokenLimit)) + var relTokens float64 + if r > 0.5 { + // first sold token is in the 1/x range + relTokens = math.Expm1(2*c) * (1 - r) + if r-relTokens >= 0.5 || 1-r < 1e-10 { + // all sold tokens are in the 1/x range, no more to sell + c = 0 + } else { + // some sold tokens are in the linear range, calculate price in 1/x range + // update starting point and remaining price to sell for in the linear range + relTokens = r - 0.5 + c -= math.Log1p(relTokens/(1-r)) / 2 + r = 0.5 + } + } + if c > 0 { + // some sold tokens are in the linear range + if x := r*r - c; x >= 0 { + relTokens += r - math.Sqrt(x) + } else { + return 0, false + } + } + return uint64(relTokens * float64(tokenLimit)), true +} + +// connection checks whether it is possible with the current balance levels to establish +// requested connection or capacity change and then stay connected for the given amount +// of time. If it is possible and setCap is also true then the client is activated of the +// capacity change is performed. If not then returns how many tokens are missing and how +// much that would currently cost using the specified payment module(s). +func (t *tokenSale) connection(id enode.ID, freeID string, requestedCapacity uint64, stayConnected time.Duration, paymentModule []string, setCap bool) (availableCapacity, tokenBalance, tokensMissing, pcBalance, pcMissing uint64, paymentRequired []uint64, err error) { + t.lock.Lock() + defer t.lock.Unlock() + + tokensMissing, availableCapacity, err = t.clientPool.setCapacityLocked(id, freeID, requestedCapacity, stayConnected, setCap) + pb := t.clientPool.getPosBalance(id) + tokenBalance = pb.value.value(t.clientPool.posExpiration(mclock.Now())) + cb := t.clientPool.ndb.getCurrencyBalance(id) + pcBalance = cb.amount + if tokensMissing == 0 { + return + } + tokenLimit := t.clientPool.totalTokenLimit() + tokenAmount := t.clientPool.totalTokenAmount() + if tokenLimit <= tokenAmount || tokenLimit-tokenAmount <= tokensMissing { + pcMissing = math.MaxUint64 + } else { + tokensAvailable := tokenLimit - tokenAmount + pcr := -math.Log(float64(tokensAvailable-tokensMissing)/float64(tokensAvailable)) * t.basePrice + if pcr > 0 { + if pcr > maxBalance { + pcMissing = math.MaxUint64 + } else { + pcMissing = uint64(pcr) + if pcMissing > maxBalance { + pcMissing = math.MaxUint64 + } else { + if pcMissing > pcBalance { + pcMissing -= pcBalance + } else { + pcMissing = 0 + } + } + } + } + } + if pcMissing == 0 { + return + } + paymentRequired = make([]uint64, len(paymentModule)) + for i, recID := range paymentModule { + if rec, ok := t.receivers[recID]; !ok || pcMissing == math.MaxUint64 { + paymentRequired[i] = math.MaxUint64 + } else { + paymentRequired[i] = rec.requestPayment(id, pcMissing, newReaderTable(t.clientPool.ndb.db, receiverPrefix(id, recID))) + } + } + return +} + +// deposit credits a payment on the sender's account using the specified payment module +func (t *tokenSale) deposit(id enode.ID, paymentModule string, proofOfPayment []byte) (pcValue, pcBalance uint64, err error) { + writer := t.clientPool.ndb.atomicWriteLock(id.Bytes()) + t.lock.Lock() + defer func() { + t.lock.Unlock() + t.clientPool.ndb.atomicWriteUnlock(id.Bytes()) + }() + + cb := t.clientPool.ndb.getCurrencyBalance(id) + pcBalance = cb.amount + pm := t.receivers[paymentModule] + if pm == nil { + return 0, pcBalance, fmt.Errorf("Unknown payment receiver '%s'", paymentModule) + } + prefix := receiverPrefix(id, paymentModule) + pcValue, err = pm.receivePayment(id, proofOfPayment, newReaderTable(t.clientPool.ndb.db, prefix), newWriterTable(writer, prefix)) + if err != nil { + return 0, pcBalance, err + } + pcBalance += pcValue + cb.amount = pcBalance + t.clientPool.ndb.setCurrencyBalance(id, cb) + return +} + +// buyTokens tries to convert the permanent balance (nominated in the server's preferred +// currency, PC) to service tokens. If spendAll is true then it sells the maxSpend amount +// of PC coins if the received service token amount is at least minReceive. If spendAll is +// false then is buys minReceive amount of tokens if it does not cost more than maxSpend +// amount of PC coins. +// if relative is true then maxSpend and minReceive are specified relative to their current +// balances. In this case maxSpend represents the amount under which the PC balance should +// not go and minReceive represents the amount the service token balance should reach. +// This mode is useful when actual conversion is intended to happen and the sender has to +// retry the command after not receiving a reply previously. In this case the sender cannot +// be sure whether the conversion has already happened or not. If relative is true then it +// is impossible to do a conversion twice. In exchange the sender needs to know its current +// balances (which it probably does if it has made a previous call to just ask the current price). +func (t *tokenSale) buyTokens(id enode.ID, maxSpend, minReceive uint64, relative, spendAll bool) (pcBalance, tokenBalance, spend, receive uint64, success bool) { + t.clientPool.ndb.atomicWriteLock(id.Bytes()) + t.lock.Lock() + defer func() { + t.lock.Unlock() + t.clientPool.ndb.atomicWriteUnlock(id.Bytes()) + }() + + pb := t.clientPool.getPosBalance(id) + tokenBalance = pb.value.value(t.clientPool.posExpiration(mclock.Now())) + cb := t.clientPool.ndb.getCurrencyBalance(id) + pcBalance = cb.amount + if relative { + if pcBalance > maxSpend { + maxSpend = pcBalance - maxSpend + } else { + maxSpend = 0 + } + if minReceive > tokenBalance { + minReceive -= tokenBalance + } else { + minReceive = 0 + } + } + + if maxSpend > pcBalance { + maxSpend = pcBalance + } + if spendAll { + spend = maxSpend + receive = t.tokenBuyAmount(float64(spend)) + success = receive >= minReceive + } else { + receive = minReceive + if cost, ok := t.tokenPrice(receive, true); ok { + spend = uint64(cost) + 1 // ensure that we don't sell small amounts for free + } else { + spend = math.MaxUint64 + } + success = spend <= maxSpend + } + if success { + pcBalance -= spend + cb.amount = pcBalance + tokenBalance += receive + t.clientPool.ndb.setCurrencyBalance(id, cb) + t.clientPool.addBalance(id, int64(receive)) + } + return +} + +// sellTokens tries to convert service tokens to permanent balance (nominated in the server's +// preferred currency, PC). Parameters work similarly to buyTokens. +func (t *tokenSale) sellTokens(id enode.ID, maxSell, minRefund uint64, relative, sellAll bool) (pcBalance, tokenBalance, sell, refund uint64, success bool) { + t.clientPool.ndb.atomicWriteLock(id.Bytes()) + t.lock.Lock() + defer func() { + t.lock.Unlock() + t.clientPool.ndb.atomicWriteUnlock(id.Bytes()) + }() + + pb := t.clientPool.getPosBalance(id) + tokenBalance = pb.value.value(t.clientPool.posExpiration(mclock.Now())) + cb := t.clientPool.ndb.getCurrencyBalance(id) + pcBalance = cb.amount + if relative { + if pcBalance < minRefund { + minRefund -= pcBalance + } else { + minRefund = 0 + } + if maxSell < tokenBalance { + maxSell = tokenBalance - maxSell + } else { + maxSell = 0 + } + } + + if maxSell > tokenBalance { + maxSell = tokenBalance + } + if sellAll { + sell = maxSell + if r, ok := t.tokenPrice(sell, false); ok { + refund = uint64(r) + success = refund >= minRefund + } + } else { + refund = minRefund + if s, ok := t.tokenSellAmount(float64(refund)); ok { + sell = s + 1 // ensure that we don't sell small amounts for free + } else { + sell = math.MaxUint64 + } + success = sell <= maxSell + } + if success { + pcBalance += refund + tokenBalance -= sell + cb.amount = pcBalance + t.clientPool.ndb.setCurrencyBalance(id, cb) + t.clientPool.addBalance(id, -int64(sell)) + } + return +} + +// getBalance returns the current PC balance and service token balance +func (t *tokenSale) getBalance(id enode.ID) (pcBalance, tokenBalance uint64) { + t.lock.Lock() + defer t.lock.Unlock() + + pb := t.clientPool.getPosBalance(id) + cb := t.clientPool.ndb.getCurrencyBalance(id) + return pb.value.value(t.clientPool.posExpiration(mclock.Now())), cb.amount +} + +// info returns general information about the server, including version info of the +// lespay command set, supported payment modules and token expiration time constant +func (t *tokenSale) info() (version, compatible uint, info keyValueList, receivers []string) { + t.lock.Lock() + defer t.lock.Unlock() + + exp, _ := t.clientPool.getExpirationTCs() + info = info.add("tokenExpiration", strconv.FormatUint(exp, 10)) + return 1, 1, info, t.receiverNames +} + +// receiverInfo returns information about the specified payment receiver(s) if supported +func (t *tokenSale) receiverInfo(receiverIDs []string) []keyValueList { + t.lock.Lock() + defer t.lock.Unlock() + + res := make([]keyValueList, len(receiverIDs)) + for i, id := range receiverIDs { + if rec, ok := t.receivers[id]; ok { + res[i] = rec.info() + } + } + return res +} + +const ( + tsInfo = iota + tsReceiverInfo + tsGetBalance + tsDeposit + tsBuyTokens + tsSellTokens + tsConnection +) + +type ( + tsInfoResults struct { + Version, Compatible uint + Info keyValueList + Receivers []string + } + tsInfoApiResults struct { + Version, Compatible uint + Info keyValueMapDecoded + Receivers []string + } + tsReceiverInfoParams []string + tsReceiverInfoResults []keyValueList + tsReceiverInfoApiResults []keyValueMapDecoded + tsGetBalanceResults struct { + PcBalance, TokenBalance uint64 + } + tsDepositParams struct { + PaymentModule string + ProofOfPayment []byte + } + tsDepositResults struct { + PcValue, PcBalance uint64 + Err string + } + tsBuyTokensParams struct { + MaxSpend, MinReceive uint64 + Relative, SpendAll bool + } + tsBuyTokensResults struct { + PcBalance, TokenBalance, Spend, Receive uint64 + Success bool + } + tsSellTokensParams struct { + MaxSell, MinRefund uint64 + Relative, SellAll bool + } + tsSellTokensResults struct { + PcBalance, TokenBalance, Sell, Refund uint64 + Success bool + } + tsConnectionParams struct { + RequestedCapacity, StayConnected uint64 + PaymentModule []string + SetCap bool + } + tsConnectionResults struct { + AvailableCapacity, TokenBalance, TokensMissing, PcBalance, PcMissing uint64 + PaymentRequired []uint64 + Err string + } +) + +// runCommand runs an encoded lespay command and returns the encoded results +func (t *tokenSale) runCommand(cmd []byte, id enode.ID, freeID string) []byte { + var res []byte + switch cmd[0] { + case tsInfo: + var results tsInfoResults + if len(cmd) == 1 { + results.Version, results.Compatible, results.Info, results.Receivers = t.info() + res, _ = rlp.EncodeToBytes(&results) + } + case tsReceiverInfo: + var ( + params tsReceiverInfoParams + results tsReceiverInfoResults + ) + if err := rlp.DecodeBytes(cmd[1:], ¶ms); err == nil { + results = t.receiverInfo(params) + res, _ = rlp.EncodeToBytes(&results) + } + case tsGetBalance: + var results tsGetBalanceResults + if len(cmd) == 1 { + results.PcBalance, results.TokenBalance = t.getBalance(id) + res, _ = rlp.EncodeToBytes(&results) + } + case tsDeposit: + var ( + params tsDepositParams + results tsDepositResults + ) + if err := rlp.DecodeBytes(cmd[1:], ¶ms); err == nil { + results.PcValue, results.PcBalance, err = t.deposit(id, params.PaymentModule, params.ProofOfPayment) + if err != nil { + results.Err = err.Error() + } + res, _ = rlp.EncodeToBytes(&results) + } + case tsBuyTokens: + var ( + params tsBuyTokensParams + results tsBuyTokensResults + ) + if err := rlp.DecodeBytes(cmd[1:], ¶ms); err == nil { + results.PcBalance, results.TokenBalance, results.Spend, results.Receive, results.Success = + t.buyTokens(id, params.MaxSpend, params.MinReceive, params.Relative, params.SpendAll) + res, _ = rlp.EncodeToBytes(&results) + } + case tsSellTokens: + var ( + params tsSellTokensParams + results tsSellTokensResults + ) + if err := rlp.DecodeBytes(cmd[1:], ¶ms); err == nil { + results.PcBalance, results.TokenBalance, results.Sell, results.Refund, results.Success = + t.sellTokens(id, params.MaxSell, params.MinRefund, params.Relative, params.SellAll) + res, _ = rlp.EncodeToBytes(&results) + } + case tsConnection: + var ( + params tsConnectionParams + results tsConnectionResults + ) + if err := rlp.DecodeBytes(cmd[1:], ¶ms); err == nil { + results.AvailableCapacity, results.TokenBalance, results.TokensMissing, results.PcBalance, results.PcMissing, results.PaymentRequired, err = + t.connection(id, freeID, params.RequestedCapacity, time.Duration(params.StayConnected)*time.Second, params.PaymentModule, params.SetCap) + if err != nil { + results.Err = err.Error() + } + res, _ = rlp.EncodeToBytes(&results) + } + } + return res +} + +type keyValueMapDecoded map[string]interface{} + +// DecodeRLP implements rlp.Decoder +func (k *keyValueMapDecoded) DecodeRLP(s *rlp.Stream) error { + var list keyValueList + if err := s.Decode(&list); err != nil { + return err + } + *k = make(keyValueMapDecoded) + for _, item := range list { + var s string + if err := rlp.DecodeBytes(item.Value, &s); err != nil { + return err + } + (*k)[item.Key] = s + } + return nil +} + +// testReceiver implements paymentReceiver. It should only be used for testing. +type testReceiver struct{} + +func (t testReceiver) info() keyValueList { + var info keyValueList + info = info.add("description", "Test payment receiver") + info = info.add("version", "1.0.0") + return info +} + +// receivePayment implements paymentReceiver. proofOfPayment is a base 10 ascii number +// which is credited to the sender's account without any further conditions. +func (t testReceiver) receivePayment(from enode.ID, proofOfPayment []byte, reader ethdb.KeyValueReader, writer ethdb.KeyValueWriter) (value uint64, err error) { + if len(proofOfPayment) > 8 { + err = fmt.Errorf("proof of payment is too long; max 8 bytes long big endian integer expected") + return + } + var b [8]byte + copy(b[8-len(proofOfPayment):], proofOfPayment) + value = binary.BigEndian.Uint64(b[:]) + return +} + +// requestPayment implements paymentReceiver +func (t testReceiver) requestPayment(from enode.ID, value uint64, reader ethdb.KeyValueReader) uint64 { + return value +} + +// readerTable is a wrapper around a database that prefixes each key access with a pre- +// configured string. +type readerTable struct { + db ethdb.KeyValueReader + prefix []byte +} + +// newReaderTable returns a database object that prefixes all keys with a given string. +func newReaderTable(db ethdb.KeyValueReader, prefix []byte) ethdb.KeyValueReader { + return &readerTable{ + db: db, + prefix: prefix, + } +} + +// Has retrieves if a prefixed version of a key is present in the database. +func (t *readerTable) Has(key []byte) (bool, error) { + return t.db.Has(append(t.prefix, key...)) +} + +// Get retrieves the given prefixed key if it's present in the database. +func (t *readerTable) Get(key []byte) ([]byte, error) { + return t.db.Get(append(t.prefix, key...)) +} + +// writerTable is a wrapper around a database that prefixes each key access with a pre- +// configured string. +type writerTable struct { + db ethdb.KeyValueWriter + prefix []byte +} + +// newReaderTable returns a database object that prefixes all keys with a given string. +func newWriterTable(db ethdb.KeyValueWriter, prefix []byte) ethdb.KeyValueWriter { + return &writerTable{ + db: db, + prefix: prefix, + } +} + +// Put inserts the given value into the database at a prefixed version of the +// provided key. +func (t *writerTable) Put(key []byte, value []byte) error { + return t.db.Put(append(t.prefix, key...), value) +} + +// Delete removes the given prefixed key from the database. +func (t writerTable) Delete(key []byte) error { + return t.db.Delete(append(t.prefix, key...)) +} diff --git a/les/tokensale_test.go b/les/tokensale_test.go new file mode 100644 index 000000000000..5c5db9d9b3d6 --- /dev/null +++ b/les/tokensale_test.go @@ -0,0 +1,157 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package les + +import ( + "math/rand" + "testing" +) + +func TestTokenPriceCalculation(t *testing.T) { + var totalLimit, totalAmount uint64 + ts := &tokenSale{ + basePrice: 1, + totalTokenLimit: func() uint64 { return totalLimit }, + totalTokenAmount: func() uint64 { return totalAmount }, + } + totalLimit = 1000000000000 + maxDiff := int64(totalLimit / 1000000) + // inaccuracy increases around both ends of the allowed token range + min := totalLimit / 100 + max := uint64(float64(totalLimit) * tokenSellMaxRatio) + for count := 0; count < 100000; count++ { + start := min + uint64(rand.Int63n(int64(max-min))) + stop := min + uint64(rand.Int63n(int64(max-min))) + if start > stop { + start, stop = stop, start + } + // buy (start-stop) tokens in two steps + mid := start + uint64(rand.Int63n(int64(stop-start+1))) + totalAmount = start + cost, ok := ts.tokenPrice(mid-start, true) + if !ok { + t.Fatalf("Failed to buy tokens") + } + totalAmount = mid + cost2, ok := ts.tokenPrice(stop-mid, true) + if !ok { + t.Fatalf("Failed to buy tokens") + } + cost += cost2 + + // sell the same amount of tokens in two steps + mid = start + uint64(rand.Int63n(int64(stop-start+1))) + totalAmount = stop + refund, ok := ts.tokenPrice(stop-mid, false) + if !ok { + t.Fatalf("Failed to sell tokens") + } + totalAmount = mid + refund2, ok := ts.tokenPrice(mid-start, false) + if !ok { + t.Fatalf("Failed to sell tokens") + } + refund += refund2 + ratio := (refund + 1) / (cost + 1) + if ratio < 0.999999 || ratio > 1.000001 { + t.Fatalf("Token selling price does not match buy cost") + } + + // buy tokens for the previously calculated price in two steps + pcost := cost * rand.Float64() + totalAmount = start + totalAmount += ts.tokenBuyAmount(pcost) + totalAmount += ts.tokenBuyAmount(cost - pcost) + + diff := int64(totalAmount - stop) + if diff > maxDiff || diff < -maxDiff { + t.Fatalf("Bought token amount mismatch") + } + + // sell tokens for the previously calculated price in two steps + pcost = cost * rand.Float64() + totalAmount = stop + soldAmount, ok := ts.tokenSellAmount(pcost) + if !ok { + t.Fatalf("Failed to sell tokens") + } + totalAmount -= soldAmount + soldAmount, ok = ts.tokenSellAmount(cost - pcost) + if !ok { + t.Fatalf("Failed to sell tokens") + } + totalAmount -= soldAmount + + diff = int64(totalAmount - start) + if diff > maxDiff || diff < -maxDiff { + t.Fatalf("Sold token amount mismatch") + } + } +} + +func TestSingleTokenPrice(t *testing.T) { + var totalLimit, totalAmount uint64 + ts := &tokenSale{ + basePrice: 1, + totalTokenLimit: func() uint64 { return totalLimit }, + totalTokenAmount: func() uint64 { return totalAmount }, + } + totalLimit = 1000000000000 + buyLimit := uint64(float64(totalLimit) * tokenSellMaxRatio) + for count := 0; count < 10000; count++ { + totalAmount = uint64(rand.Int63n(int64(buyLimit))) + relAmount := float64(totalAmount) / float64(totalLimit) + var expPrice, maxDiff float64 + if relAmount < 0.5 { + expPrice = relAmount * 2 + maxDiff = 0.001 + } else { + expPrice = 0.5 / (1 - relAmount) + maxDiff = 0.001 * expPrice + } + price, ok := ts.tokenPrice(1, true) + if !ok { + t.Fatalf("Failed to buy tokens") + } + if price < expPrice-maxDiff || price > expPrice+maxDiff { + t.Fatalf("Token price mismatch") + } + + price, ok = ts.tokenPrice(1, false) + if !ok { + t.Fatalf("Failed to sell tokens") + } + if price < expPrice-maxDiff || price > expPrice+maxDiff { + t.Fatalf("Token price mismatch") + } + + if relAmount > 0.01 { + amount := ts.tokenBuyAmount(expPrice * 100) + if amount < 99 || amount > 101 { + t.Fatalf("Bought token amount mismatch") + } + + amount, ok = ts.tokenSellAmount(expPrice * 100) + if !ok { + t.Fatalf("Failed to sell tokens") + } + if amount < 99 || amount > 101 { + t.Fatalf("Sold token amount mismatch") + } + } + } +} diff --git a/les/ulc_test.go b/les/ulc_test.go index 273c63e4bd5f..d26a782e2312 100644 --- a/les/ulc_test.go +++ b/les/ulc_test.go @@ -28,8 +28,8 @@ import ( "github.com/ethereum/go-ethereum/p2p/enode" ) -func TestULCAnnounceThresholdLes2(t *testing.T) { testULCAnnounceThreshold(t, 2) } func TestULCAnnounceThresholdLes3(t *testing.T) { testULCAnnounceThreshold(t, 3) } +func TestULCAnnounceThresholdLes4(t *testing.T) { testULCAnnounceThreshold(t, 4) } func testULCAnnounceThreshold(t *testing.T, protocol int) { // todo figure out why it takes fetcher so longer to fetcher the announced header. @@ -124,9 +124,9 @@ func connect(server *serverHandler, serverId enode.ID, client *clientHandler, pr return peer1, peer2, nil } -// newTestServerPeer creates server peer. +// newServerPeer creates server peer. func newTestServerPeer(t *testing.T, blocks int, protocol int) (*testServer, *enode.Node, func()) { - s, teardown := newServerEnv(t, blocks, protocol, nil, false, false, 0) + s, teardown := newServerEnv(t, blocks, protocol, nil, false, false, 0, false) key, err := crypto.GenerateKey() if err != nil { t.Fatal("generate key err:", err) diff --git a/p2p/discv5/net.go b/p2p/discv5/net.go index dd2ec3e9298f..2c27f61b14a7 100644 --- a/p2p/discv5/net.go +++ b/p2p/discv5/net.go @@ -22,12 +22,14 @@ import ( "errors" "fmt" "net" + "sync" "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/netutil" "github.com/ethereum/go-ethereum/rlp" "golang.org/x/crypto/sha3" @@ -64,12 +66,17 @@ type Network struct { refreshResp chan (<-chan struct{}) // ...and get the channel to block on from this one read chan ingressPacket // ingress packets arrive here timeout chan timeoutEvent - queryReq chan *findnodeQuery // lookups submit findnode queries on this channel + queryReq chan deferredQuery // lookups submit findnode queries on this channel tableOpReq chan func() tableOpResp chan struct{} topicRegisterReq chan topicRegisterReq topicSearchReq chan topicSearchReq + talkRequestSubLock sync.RWMutex + talkResponseSubLock sync.Mutex + talkRequestSubs map[string]TalkRequestHandler + talkResponseSubs map[string]TalkResponseHandler + // State of the main loop. tab *Table topictab *topicTable @@ -79,6 +86,21 @@ type Network struct { timeoutTimers map[timeoutEvent]*time.Timer } +type ( + // TalkRequestHandler processes an incoming talkRequest. If ok is true then response is + // sent back, along with the recommended delay feedback. If there is no urgent need to + // communicate (for example, in case of regular polling) then the sender should wait the + // given amount of seconds before sending the next request. If delay recommendation is + // disobeyed too many times by a sender then the handler can stop responding until the + // last recommended delay has elapsed. + TalkRequestHandler func(id enode.ID, addr *net.UDPAddr, request interface{}) (response interface{}, delay uint, ok bool) + // TalkResponseHandler is registered by the sender of each talkRequest. If the request + // is canceled the handler is still called with a nil parameter. If too many responses + // from a peer are considered invalid by their handlers then the peer goes into + // contested state. + TalkResponseHandler func(response interface{}, delay uint) (valid bool) +) + // transport is implemented by the UDP transport. // it is an interface so we can test without opening lots of UDP // sockets and without generating a private key. @@ -101,6 +123,14 @@ type findnodeQuery struct { reply chan<- []*Node } +type talkQuery struct { + remote *Node + talkID string + payload interface{} + key string + handler TalkResponseHandler +} + type topicRegisterReq struct { add bool topic Topic @@ -151,10 +181,12 @@ func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, dbPath string, netres timeoutTimers: make(map[timeoutEvent]*time.Timer), tableOpReq: make(chan func()), tableOpResp: make(chan struct{}), - queryReq: make(chan *findnodeQuery), + queryReq: make(chan deferredQuery), topicRegisterReq: make(chan topicRegisterReq), topicSearchReq: make(chan topicSearchReq), nodes: make(map[NodeID]*Node), + talkRequestSubs: make(map[string]TalkRequestHandler), + talkResponseSubs: make(map[string]TalkResponseHandler), } go net.loop() return net, nil @@ -410,7 +442,6 @@ loop: // Ingress packet handling. case pkt := <-net.read: - //fmt.Println("read", pkt.ev) log.Trace("<-net.read") n := net.internNode(&pkt) prestate := n.state @@ -446,7 +477,7 @@ loop: case q := <-net.queryReq: log.Trace("<-net.queryReq") if !q.start(net) { - q.remote.deferQuery(q) + q.deferQuery() } // Interacting with the table. @@ -767,14 +798,21 @@ type nodeNetGuts struct { // State machine fields. Access to these fields // is restricted to the Network.loop goroutine. state *nodeState - pingEcho []byte // hash of last ping sent by us - pingTopics []Topic // topic set sent by us in last ping - deferredQueries []*findnodeQuery // queries that can't be sent yet - pendingNeighbours *findnodeQuery // current query, waiting for reply + pingEcho []byte // hash of last ping sent by us + pingTopics []Topic // topic set sent by us in last ping + deferredQueries []deferredQuery // queries that can't be sent yet + pendingNeighbours *findnodeQuery // current query, waiting for reply queryTimeouts int + talkFailures int } -func (n *nodeNetGuts) deferQuery(q *findnodeQuery) { +type deferredQuery interface { + start(net *Network) bool + cancel() + deferQuery() +} + +func (n *nodeNetGuts) deferQuery(q deferredQuery) { n.deferredQueries = append(n.deferredQueries, q) } @@ -810,6 +848,14 @@ func (q *findnodeQuery) start(net *Network) bool { return false } +func (q *findnodeQuery) cancel() { + q.reply <- nil +} + +func (q *findnodeQuery) deferQuery() { + q.remote.deferQuery(q) +} + // Node Events (the input to the state machine). type nodeEvent uint @@ -828,6 +874,8 @@ const ( topicRegisterPacket topicQueryPacket topicNodesPacket + talkRequestPacket + talkResponsePacket // Non-packet events. // Event values in this category are allocated outside @@ -835,6 +883,7 @@ const ( pongTimeout nodeEvent = iota + 256 pingTimeout neighboursTimeout + talkTimeout ) // Node State Machine. @@ -868,7 +917,7 @@ func init() { n.pingEcho = nil // Abort active queries. for _, q := range n.deferredQueries { - q.reply <- nil + q.cancel() } n.deferredQueries = nil if n.pendingNeighbours != nil { @@ -1021,7 +1070,7 @@ func (net *Network) handle(n *Node, ev nodeEvent, pkt *ingressPacket) error { //fmt.Println("handle", n.addr().String(), n.state, ev) if pkt != nil { if err := net.checkPacket(n, ev, pkt); err != nil { - //fmt.Println("check err:", err) + //fmt.Println("check err:", err, pkt) return err } // Start the background expiration goroutine after the first @@ -1036,6 +1085,7 @@ func (net *Network) handle(n *Node, ev nodeEvent, pkt *ingressPacket) error { if n.state == nil { n.state = unknown //??? } + //fmt.Println("old state:", n.state) next, err := n.state.handle(net, n, ev, pkt) net.transition(n, next) //fmt.Println("new state:", n.state) @@ -1196,7 +1246,51 @@ func (net *Network) handleQueryEvent(n *Node, ev nodeEvent, pkt *ingressPacket) } } return n.state, nil - + case talkRequestPacket: + p := pkt.data.(*talkRequest) + net.talkRequestSubLock.RLock() + subFn := net.talkRequestSubs[string(p.TalkID)] + net.talkRequestSubLock.RUnlock() + if subFn != nil { + resp, delay, ok := subFn(enode.ID(n.sha), n.addr(), p.Payload) + if ok { + net.conn.send(n, talkResponsePacket, talkResponse{ReplyTok: pkt.hash, Delay: delay, Payload: resp}) + } else { + n.talkFailures++ + } + } else { + n.talkFailures++ + } + if n.talkFailures > maxTalkFailures && n.state == known { + return contested, errors.New("too many talk failures") + } + return n.state, nil + case talkResponsePacket: + p := pkt.data.(*talkResponse) + net.talkResponseSubLock.Lock() + key := string(n.sha[:]) + string(p.ReplyTok) + subFn := net.talkResponseSubs[key] + if subFn != nil { + delete(net.talkResponseSubs, key) + } + net.talkResponseSubLock.Unlock() + if subFn == nil || !subFn(p.Payload, p.Delay) { + n.talkFailures++ + if n.talkFailures > maxTalkFailures && n.state == known { + return contested, errors.New("too many talk failures") + } + } + return n.state, nil + case talkTimeout: + if n.pendingNeighbours != nil { + n.pendingNeighbours.reply <- nil + n.pendingNeighbours = nil + } + n.queryTimeouts++ + if n.queryTimeouts > maxFindnodeFailures && n.state == known { + return contested, errors.New("too many timeouts") + } + return n.state, nil default: return n.state, errInvalidEvent } @@ -1260,3 +1354,82 @@ func (net *Network) handleNeighboursPacket(n *Node, pkt *ingressPacket) error { n.startNextQuery(net) return nil } + +// RegisterTalkHandler assigns a handler callback to the given talk ID +func (net *Network) RegisterTalkHandler(talkID string, handler TalkRequestHandler) { + net.talkRequestSubLock.Lock() + net.talkRequestSubs[talkID] = handler + net.talkRequestSubLock.Unlock() +} + +// RemoveTalkHandler removes the handler assigned to the given talk ID +func (net *Network) RemoveTalkHandler(talkID string) { + net.talkRequestSubLock.Lock() + delete(net.talkRequestSubs, talkID) + net.talkRequestSubLock.Unlock() +} + +func (q *talkQuery) start(net *Network) bool { + if q.remote == net.tab.self { + return false + } + if q.remote.state == known { + net.talkResponseSubLock.Lock() + hash := net.conn.send(q.remote, talkRequestPacket, talkRequest{TalkID: []byte(q.talkID), Payload: q.payload}) + q.key = string(q.remote.sha[:]) + string(hash[:]) + net.talkResponseSubs[q.key] = q.handler + net.talkResponseSubLock.Unlock() + return true + } + // If the node is not known yet, it won't accept queries. + // Initiate the transition to known. + // The request will be sent later when the node reaches known state. + if q.remote.state == unknown { + net.transition(q.remote, verifyinit) + } + return false +} + +func (q *talkQuery) cancel() { + q.handler(nil, 0) +} + +func (q *talkQuery) deferQuery() { + q.remote.deferQuery(q) +} + +// SendTalkRequest sends a talRequest and registers a response handler. It returns +// a cancel function that removes the response handler and calls it with a nil parameter +// if the response has not arrived yet. +func (net *Network) SendTalkRequest(to *enode.Node, talkID string, payload interface{}, handler TalkResponseHandler) (cancel func() bool) { + var nodeID NodeID + copy(nodeID[:], crypto.FromECDSAPub(to.Pubkey())[1:]) + node := net.nodes[nodeID] + if node == nil { + node = NewNode(nodeID, to.IP(), uint16(to.UDP()), uint16(to.TCP())) + node.state = unknown + net.nodes[nodeID] = node + } + q := &talkQuery{remote: node, talkID: talkID, payload: payload, handler: handler} + if node.state == nil || !node.state.canQuery { + net.ping(node, node.addr()) + } + select { + case net.queryReq <- q: + case <-net.closed: + return nil + } + + return func() bool { + net.talkResponseSubLock.Lock() + cancel := q.key != "" && net.talkResponseSubs[q.key] != nil + if cancel { + delete(net.talkResponseSubs, q.key) + } + net.talkResponseSubLock.Unlock() + if cancel { + handler(nil, 0) + } + return cancel + } +} diff --git a/p2p/discv5/table.go b/p2p/discv5/table.go index 64c3ecd1c7b2..23b68e56928a 100644 --- a/p2p/discv5/table.go +++ b/p2p/discv5/table.go @@ -35,6 +35,7 @@ const ( nBuckets = hashBits + 1 // Number of buckets maxFindnodeFailures = 5 + maxTalkFailures = 5 ) type Table struct { diff --git a/p2p/discv5/udp.go b/p2p/discv5/udp.go index 088f95cac6a2..a90c70b312c9 100644 --- a/p2p/discv5/udp.go +++ b/p2p/discv5/udp.go @@ -119,6 +119,17 @@ type ( Nodes []rpcNode } + talkRequest struct { + TalkID []byte + Payload interface{} + } + + talkResponse struct { + ReplyTok []byte + Delay uint + Payload interface{} + } + rpcNode struct { IP net.IP // len 4 for IPv4 or 16 for IPv6 UDP uint16 // for discovery protocol @@ -420,6 +431,10 @@ func decodePacket(buffer []byte, pkt *ingressPacket) error { pkt.data = new(topicQuery) case topicNodesPacket: pkt.data = new(topicNodes) + case talkRequestPacket: + pkt.data = new(talkRequest) + case talkResponsePacket: + pkt.data = new(talkResponse) default: return fmt.Errorf("unknown packet type: %d", sigdata[0]) } diff --git a/p2p/enode/node.go b/p2p/enode/node.go index 9eb2544ffe14..3f6cda6d4a21 100644 --- a/p2p/enode/node.go +++ b/p2p/enode/node.go @@ -217,7 +217,7 @@ func (n ID) MarshalText() ([]byte, error) { // UnmarshalText implements the encoding.TextUnmarshaler interface. func (n *ID) UnmarshalText(text []byte) error { - id, err := parseID(string(text)) + id, err := ParseID(string(text)) if err != nil { return err } @@ -229,14 +229,14 @@ func (n *ID) UnmarshalText(text []byte) error { // The string may be prefixed with 0x. // It panics if the string is not a valid ID. func HexID(in string) ID { - id, err := parseID(in) + id, err := ParseID(in) if err != nil { panic(err) } return id } -func parseID(in string) (ID, error) { +func ParseID(in string) (ID, error) { var id ID b, err := hex.DecodeString(strings.TrimPrefix(in, "0x")) if err != nil {