diff --git a/cmd/geth/main.go b/cmd/geth/main.go
index 99ef78238feb..70bcf7a96209 100644
--- a/cmd/geth/main.go
+++ b/cmd/geth/main.go
@@ -101,6 +101,7 @@ var (
utils.UltraLightServersFlag,
utils.UltraLightFractionFlag,
utils.UltraLightOnlyAnnounceFlag,
+ utils.LespayTestModuleFlag,
utils.WhitelistFlag,
utils.CacheFlag,
utils.CacheDatabaseFlag,
diff --git a/cmd/geth/usage.go b/cmd/geth/usage.go
index 6f3197b9c6d3..4d98a8e372af 100644
--- a/cmd/geth/usage.go
+++ b/cmd/geth/usage.go
@@ -94,6 +94,7 @@ var AppHelpFlagGroups = []flagGroup{
utils.UltraLightServersFlag,
utils.UltraLightFractionFlag,
utils.UltraLightOnlyAnnounceFlag,
+ utils.LespayTestModuleFlag,
},
},
{
diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go
index bdadebd852f4..100202479221 100644
--- a/cmd/utils/flags.go
+++ b/cmd/utils/flags.go
@@ -272,6 +272,10 @@ var (
Usage: "Maximum number of light clients to serve, or light servers to attach to",
Value: eth.DefaultConfig.LightPeers,
}
+ LespayTestModuleFlag = cli.BoolFlag{
+ Name: "lespay.testmodule",
+ Usage: "Enable dummy payment module (for testing only)",
+ }
UltraLightServersFlag = cli.StringFlag{
Name: "ulc.servers",
Usage: "List of trusted ultra-light servers",
@@ -1009,6 +1013,9 @@ func setLes(ctx *cli.Context, cfg *eth.Config) {
if ctx.GlobalIsSet(UltraLightOnlyAnnounceFlag.Name) {
cfg.UltraLightOnlyAnnounce = ctx.GlobalBool(UltraLightOnlyAnnounceFlag.Name)
}
+ if ctx.GlobalIsSet(LespayTestModuleFlag.Name) {
+ cfg.LespayTestModule = true
+ }
}
// makeDatabaseHandles raises out the number of allowed file handles per process
diff --git a/eth/config.go b/eth/config.go
index 2eaf21fbc30c..4a2ea18ad1a8 100644
--- a/eth/config.go
+++ b/eth/config.go
@@ -116,6 +116,9 @@ type Config struct {
UltraLightFraction int `toml:",omitempty"` // Percentage of trusted servers to accept an announcement
UltraLightOnlyAnnounce bool `toml:",omitempty"` // Whether to only announce headers, or also serve them
+ // Light client payment options
+ LespayTestModule bool
+
// Database options
SkipBcVersionCheck bool `toml:"-"`
DatabaseHandles int `toml:"-"`
diff --git a/eth/gen_config.go b/eth/gen_config.go
index 1c659c393ca7..eb889e3aa8c3 100644
--- a/eth/gen_config.go
+++ b/eth/gen_config.go
@@ -32,6 +32,7 @@ func (c Config) MarshalTOML() (interface{}, error) {
UltraLightServers []string `toml:",omitempty"`
UltraLightFraction int `toml:",omitempty"`
UltraLightOnlyAnnounce bool `toml:",omitempty"`
+ LespayTestModule bool `toml:"-"`
SkipBcVersionCheck bool `toml:"-"`
DatabaseHandles int `toml:"-"`
DatabaseCache int
@@ -68,6 +69,7 @@ func (c Config) MarshalTOML() (interface{}, error) {
enc.UltraLightServers = c.UltraLightServers
enc.UltraLightFraction = c.UltraLightFraction
enc.UltraLightOnlyAnnounce = c.UltraLightOnlyAnnounce
+ enc.LespayTestModule = c.LespayTestModule
enc.SkipBcVersionCheck = c.SkipBcVersionCheck
enc.DatabaseHandles = c.DatabaseHandles
enc.DatabaseCache = c.DatabaseCache
@@ -108,6 +110,7 @@ func (c *Config) UnmarshalTOML(unmarshal func(interface{}) error) error {
UltraLightServers []string `toml:",omitempty"`
UltraLightFraction *int `toml:",omitempty"`
UltraLightOnlyAnnounce *bool `toml:",omitempty"`
+ LespayTestModule *bool `toml:"-"`
SkipBcVersionCheck *bool `toml:"-"`
DatabaseHandles *int `toml:"-"`
DatabaseCache *int
@@ -175,6 +178,9 @@ func (c *Config) UnmarshalTOML(unmarshal func(interface{}) error) error {
if dec.UltraLightOnlyAnnounce != nil {
c.UltraLightOnlyAnnounce = *dec.UltraLightOnlyAnnounce
}
+ if dec.LespayTestModule != nil {
+ c.LespayTestModule = *dec.LespayTestModule
+ }
if dec.SkipBcVersionCheck != nil {
c.SkipBcVersionCheck = *dec.SkipBcVersionCheck
}
diff --git a/internal/web3ext/web3ext.go b/internal/web3ext/web3ext.go
index bc105ef37c28..38cdf42d627e 100644
--- a/internal/web3ext/web3ext.go
+++ b/internal/web3ext/web3ext.go
@@ -33,6 +33,7 @@ var Modules = map[string]string{
"swarmfs": SwarmfsJs,
"txpool": TxpoolJs,
"les": LESJs,
+ "lespay": LESPAYJs,
}
const ChequebookJs = `
@@ -856,3 +857,50 @@ web3._extend({
]
});
`
+
+const LESPAYJs = `
+web3._extend({
+ property: 'lespay',
+ methods:
+ [
+ new web3._extend.Method({
+ name: 'connection',
+ call: 'lespay_connection',
+ params: 6
+ }),
+ new web3._extend.Method({
+ name: 'deposit',
+ call: 'lespay_deposit',
+ params: 4
+ }),
+ new web3._extend.Method({
+ name: 'buyTokens',
+ call: 'lespay_buyTokens',
+ params: 6
+ }),
+ new web3._extend.Method({
+ name: 'buyTokens',
+ call: 'lespay_sellTokens',
+ params: 6
+ }),
+ new web3._extend.Method({
+ name: 'getBalance',
+ call: 'lespay_getBalance',
+ params: 2
+ }),
+ new web3._extend.Method({
+ name: 'info',
+ call: 'lespay_info',
+ params: 2
+ }),
+ new web3._extend.Method({
+ name: 'receiverInfo',
+ call: 'lespay_receiverInfo',
+ params: 3
+ }),
+ ],
+ properties:
+ [
+ ]
+});
+`
diff --git a/les/api.go b/les/api.go
index f9b8c34458b4..f60a4fa57db3 100644
--- a/les/api.go
+++ b/les/api.go
@@ -17,14 +17,16 @@
package les
import (
+ "context"
"errors"
"fmt"
- "math"
"time"
"github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/common/mclock"
+ "github.com/ethereum/go-ethereum/p2p/discv5"
"github.com/ethereum/go-ethereum/p2p/enode"
+ "github.com/ethereum/go-ethereum/rlp"
)
var (
@@ -35,8 +37,6 @@ var (
errNoPriority = errors.New("priority too low to raise capacity")
)
-const maxBalance = math.MaxInt64
-
// PrivateLightServerAPI provides an API to access the LES light server.
type PrivateLightServerAPI struct {
server *LesServer
@@ -104,13 +104,17 @@ func (api *PrivateLightServerAPI) clientInfo(c *clientInfo, id enode.ID) map[str
info["capacity"] = c.capacity
pb, nb := c.balanceTracker.getBalance(now)
info["pricing/balance"], info["pricing/negBalance"] = pb, nb
- info["pricing/balanceMeta"] = c.balanceMetaInfo
- info["priority"] = pb != 0
+
+ cb := api.server.clientPool.ndb.getCurrencyBalance(id)
+ info["pricing/currency"] = cb.amount
+ info["priority"] = pb.base != 0
} else {
info["isConnected"] = false
- pb := api.server.clientPool.ndb.getOrNewPB(id)
- info["pricing/balance"], info["pricing/balanceMeta"] = pb.value, pb.meta
- info["priority"] = pb.value != 0
+ pb := api.server.clientPool.ndb.getOrNewBalance(id.Bytes(), false)
+
+ cb := api.server.clientPool.ndb.getCurrencyBalance(id)
+ info["pricing/balance"], info["pricing/currency"] = pb.value, cb.amount
+ info["priority"] = pb.value.base != 0
}
return info
}
@@ -150,7 +154,7 @@ func (api *PrivateLightServerAPI) setParams(params map[string]interface{}, clien
setFactor(&negFactors.requestFactor)
case !defParams && name == "capacity":
if capacity, ok := value.(float64); ok && uint64(capacity) >= api.server.minCapacity {
- err = api.server.clientPool.setCapacity(client, uint64(capacity))
+ _, _, err = api.server.clientPool.setCapacity(client.id, client.address, uint64(capacity), 0, true)
// Don't have to call factor update explicitly. It's already done
// in setCapacity function.
} else {
@@ -172,8 +176,8 @@ func (api *PrivateLightServerAPI) setParams(params map[string]interface{}, clien
// AddBalance updates the balance of a client (either overwrites it or adds to it).
// It also updates the balance meta info string.
-func (api *PrivateLightServerAPI) AddBalance(id enode.ID, value int64, meta string) ([2]uint64, error) {
- oldBalance, newBalance, err := api.server.clientPool.addBalance(id, value, meta)
+func (api *PrivateLightServerAPI) AddBalance(id enode.ID, value int64) ([2]uint64, error) {
+ oldBalance, newBalance, err := api.server.clientPool.addBalance(id, value)
return [2]uint64{oldBalance, newBalance}, err
}
@@ -184,7 +188,7 @@ func (api *PrivateLightServerAPI) SetClientParams(ids []enode.ID, params map[str
if client != nil {
update, err := api.setParams(params, client, nil, nil)
if update {
- client.updatePriceFactors()
+ updatePriceFactors(&client.balanceTracker, client.posFactors, client.negFactors)
}
return err
} else {
@@ -296,7 +300,7 @@ func (api *PrivateDebugAPI) FreezeClient(id enode.ID) error {
if c == nil {
return fmt.Errorf("client %064x is not connected", id[:])
}
- c.peer.freezeClient()
+ c.peer.freeze()
return nil
})
}
@@ -352,3 +356,222 @@ func (api *PrivateLightAPI) GetCheckpointContractAddress() (string, error) {
}
return api.backend.oracle.Contract().ContractAddr().Hex(), nil
}
+
+// PrivateLespayAPI provides an API to use the LESpay commands of either the local or a remote server
+type PrivateLespayAPI struct {
+ clientPeerSet *clientPeerSet
+ serverPeerSet *serverPeerSet
+ clientHandler *clientHandler
+ dht *discv5.Network
+ tokenSale *tokenSale
+}
+
+// NewPrivateLespayAPI creates a new LESPAY API.
+func NewPrivateLespayAPI(clientPeerSet *clientPeerSet, serverPeerSet *serverPeerSet, clientHandler *clientHandler, dht *discv5.Network, tokenSale *tokenSale) *PrivateLespayAPI {
+ return &PrivateLespayAPI{
+ clientPeerSet: clientPeerSet,
+ serverPeerSet: serverPeerSet,
+ clientHandler: clientHandler,
+ dht: dht,
+ tokenSale: tokenSale,
+ }
+}
+
+// makeCall sends an encoded command to either the local or a remote server and returns the encoded reply
+//
+// Note: nodeStr can represent either the node ID of a connected node or the full enode of any remote node.
+// If remote is true then the command is sent to the specified node. It is sent through LES if it was specified
+// with node ID, throush UDP talk otherwise.
+// If remote is false then the command is executed locally, with the specified remote node assumed as sender.
+func (api *PrivateLespayAPI) makeCall(ctx context.Context, remote bool, nodeStr string, cmd []byte) ([]byte, error) {
+ var (
+ id enode.ID
+ freeID string
+ clientPeer *clientPeer
+ serverPeer *serverPeer
+ node *enode.Node
+ err error
+ )
+ if nodeStr != "" {
+ if id, err = enode.ParseID(nodeStr); err == nil {
+ if api.clientPeerSet != nil {
+ if clientPeer = api.clientPeerSet.peer(peerIdToString(id)); clientPeer == nil {
+ return nil, errors.New("peer not connected")
+ }
+ freeID = clientPeer.freeClientId()
+ } else {
+ if serverPeer = api.serverPeerSet.peer(peerIdToString(id)); serverPeer == nil {
+ return nil, errors.New("peer not connected")
+ }
+ }
+ } else {
+ var err error
+ if node, err = enode.Parse(enode.ValidSchemes, nodeStr); err == nil {
+ id = node.ID()
+ freeID = node.IP().String()
+ } else {
+ return nil, err
+ }
+ }
+ }
+
+ if remote {
+ var (
+ reply []byte
+ cancelFn func() bool
+ )
+ delivered := make(chan struct{})
+ if serverPeer != nil {
+ // remote call to a connected peer through LES
+ if api.clientHandler == nil {
+ return nil, errors.New("client handler not available")
+ }
+ cancelFn = api.clientHandler.makeLespayCall(serverPeer, cmd, func(r []byte, delay uint) bool {
+ reply = r
+ close(delivered)
+ return reply != nil
+ })
+ } else {
+ // remote call through UDP TALK
+ if api.dht == nil {
+ return nil, errors.New("UDP DHT not available")
+ }
+ cancelFn = api.dht.SendTalkRequest(node, "lespay", [][]byte{cmd}, func(payload interface{}, delay uint) bool {
+ if replies, ok := payload.([]interface{}); ok && len(replies) == 1 {
+ reply, _ = replies[0].([]byte)
+ }
+ close(delivered)
+ return reply != nil
+ })
+ }
+ select {
+ case <-time.After(time.Second * 5):
+ cancelFn()
+ return nil, errors.New("timeout")
+ case <-ctx.Done():
+ cancelFn()
+ return nil, ctx.Err()
+ case <-delivered:
+ if len(reply) == 0 {
+ return nil, errors.New("unknown command")
+ }
+ return reply, nil
+ }
+ } else {
+ if api.tokenSale == nil {
+ return nil, errors.New("token sale module not available")
+ }
+ // execute call locally
+ return api.tokenSale.runCommand(cmd, id, freeID), nil
+ }
+
+}
+
+// Connection checks whether it is possible with the current balance levels to establish
+// requested connection or capacity change and then stay connected for the given amount
+// of time. If it is possible and setCap is also true then the client is activated of the
+// capacity change is performed. If not then returns how many tokens are missing and how
+// much that would currently cost using the specified payment module(s).
+func (api *PrivateLespayAPI) Connection(ctx context.Context, remote bool, node string, requestedCapacity, stayConnected uint64, paymentModule []string, setCap bool) (results tsConnectionResults, err error) {
+ params := tsConnectionParams{requestedCapacity, stayConnected, paymentModule, setCap}
+ enc, _ := rlp.EncodeToBytes(¶ms)
+ var resEnc []byte
+ resEnc, err = api.makeCall(ctx, remote, node, append([]byte{tsConnection}, enc...))
+ if err != nil {
+ return
+ }
+ err = rlp.DecodeBytes(resEnc, &results)
+ return
+}
+
+// Deposit credits a payment on the sender's account using the specified payment module
+func (api *PrivateLespayAPI) Deposit(ctx context.Context, remote bool, node, paymentModule, proofOfPayment string) (results tsDepositResults, err error) {
+ var proof []byte
+ if proof, err = hexutil.Decode(proofOfPayment); err != nil {
+ return
+ }
+ params := tsDepositParams{paymentModule, proof}
+ enc, _ := rlp.EncodeToBytes(¶ms)
+ var resEnc []byte
+ resEnc, err = api.makeCall(ctx, remote, node, append([]byte{tsDeposit}, enc...))
+ if err != nil {
+ return
+ }
+ err = rlp.DecodeBytes(resEnc, &results)
+ return
+}
+
+// BuyTokens tries to convert the permanent balance (nominated in the server's preferred
+// currency, PC) to service tokens. If spendAll is true then it sells the maxSpend amount
+// of PC coins if the received service token amount is at least minReceive. If spendAll is
+// false then is buys minReceive amount of tokens if it does not cost more than maxSpend
+// amount of PC coins.
+// if relative is true then maxSpend and minReceive are specified relative to their current
+// balances. In this case maxSpend represents the amount under which the PC balance should
+// not go and minReceive represents the amount the service token balance should reach.
+// This mode is useful when actual conversion is intended to happen and the sender has to
+// retry the command after not receiving a reply previously. In this case the sender cannot
+// be sure whether the conversion has already happened or not. If relative is true then it
+// is impossible to do a conversion twice. In exchange the sender needs to know its current
+// balances (which it probably does if it has made a previous call to just ask the current price).
+func (api *PrivateLespayAPI) BuyTokens(ctx context.Context, remote bool, node string, maxSpend, minReceive uint64, relative, spendAll bool) (results tsBuyTokensResults, err error) {
+ params := tsBuyTokensParams{maxSpend, minReceive, relative, spendAll}
+ enc, _ := rlp.EncodeToBytes(¶ms)
+ var resEnc []byte
+ resEnc, err = api.makeCall(ctx, remote, node, append([]byte{tsBuyTokens}, enc...))
+ if err != nil {
+ return
+ }
+ err = rlp.DecodeBytes(resEnc, &results)
+ return
+}
+
+// SellTokens tries to convert service tokens to permanent balance (nominated in the server's
+// preferred currency, PC). Parameters work similarly to BuyTokens.
+func (api *PrivateLespayAPI) SellTokens(ctx context.Context, remote bool, node string, maxSell, minRefund uint64, relative, sellAll bool) (results tsSellTokensResults, err error) {
+ params := tsSellTokensParams{maxSell, minRefund, relative, sellAll}
+ enc, _ := rlp.EncodeToBytes(¶ms)
+ var resEnc []byte
+ resEnc, err = api.makeCall(ctx, remote, node, append([]byte{tsSellTokens}, enc...))
+ if err != nil {
+ return
+ }
+ err = rlp.DecodeBytes(resEnc, &results)
+ return
+}
+
+// GetBalance returns the current PC balance and service token balance
+func (api *PrivateLespayAPI) GetBalance(ctx context.Context, remote bool, node string) (results tsGetBalanceResults, err error) {
+ var resEnc []byte
+ resEnc, err = api.makeCall(ctx, remote, node, []byte{tsGetBalance})
+ if err != nil {
+ return
+ }
+ err = rlp.DecodeBytes(resEnc, &results)
+ return
+}
+
+// Info returns general information about the server, including version info of the
+// lespay command set, supported payment modules and token expiration time constant
+func (api *PrivateLespayAPI) Info(ctx context.Context, remote bool, node string) (results tsInfoApiResults, err error) {
+ var resEnc []byte
+ resEnc, err = api.makeCall(ctx, remote, node, []byte{tsInfo})
+ if err != nil {
+ return
+ }
+ err = rlp.DecodeBytes(resEnc, &results)
+ return
+}
+
+// ReceiverInfo returns information about the specified payment receiver(s) if supported
+func (api *PrivateLespayAPI) ReceiverInfo(ctx context.Context, remote bool, node string, receiverIDs []string) (results tsReceiverInfoApiResults, err error) {
+ params := tsReceiverInfoParams(receiverIDs)
+ enc, _ := rlp.EncodeToBytes(¶ms)
+ var resEnc []byte
+ resEnc, err = api.makeCall(ctx, remote, node, append([]byte{tsReceiverInfo}, enc...))
+ if err != nil {
+ return
+ }
+ err = rlp.DecodeBytes(resEnc, &results)
+ return
+}
diff --git a/les/balance.go b/les/balance.go
index 51cef15c803d..97fc0d76eef7 100644
--- a/les/balance.go
+++ b/les/balance.go
@@ -17,29 +17,56 @@
package les
import (
+ "math"
"sync"
"time"
"github.com/ethereum/go-ethereum/common/mclock"
)
+const maxBalance = math.MaxInt64
+
const (
balanceCallbackQueue = iota
balanceCallbackZero
balanceCallbackCount
)
+// expirationController controls the exponential expiration of positive and negative
+// balances
+type expirationController interface {
+ posExpiration(mclock.AbsTime) fixed64
+ negExpiration(mclock.AbsTime) fixed64
+}
+
+// priceFactors determine the pricing policy (may apply either to positive or
+// negative balances which may have different factors).
+// - timeFactor is cost unit per nanosecond of connection time
+// - capacityFactor is cost unit per nanosecond of connection time per 1000000 capacity
+// - requestFactor is cost unit per request "realCost" unit
+type priceFactors struct {
+ timeFactor, capacityFactor, requestFactor float64
+}
+
+func (p priceFactors) timePrice(cap uint64) float64 {
+ return p.timeFactor + float64(cap)*p.capacityFactor/1000000
+}
+
+func (p priceFactors) reqPrice() float64 {
+ return p.requestFactor
+}
+
// balanceTracker keeps track of the positive and negative balances of a connected
// client and calculates actual and projected future priority values required by
// prque.LazyQueue.
type balanceTracker struct {
lock sync.Mutex
clock mclock.Clock
+ exp expirationController
stopped bool
capacity uint64
balance balance
- timeFactor, requestFactor float64
- negTimeFactor, negRequestFactor float64
+ posFactor, negFactor priceFactors
sumReqCost uint64
lastUpdate, nextUpdate, initTime mclock.AbsTime
updateEvent mclock.Timer
@@ -53,7 +80,7 @@ type balanceTracker struct {
// balance represents a pair of positive and negative balances
type balance struct {
- pos, neg uint64
+ pos, neg expiredValue
}
// balanceCallback represents a single callback that is activated when client priority
@@ -65,6 +92,7 @@ type balanceCallback struct {
}
// init initializes balanceTracker
+// Note: capacity should never be zero
func (bt *balanceTracker) init(clock mclock.Clock, capacity uint64) {
bt.clock = clock
bt.initTime, bt.lastUpdate = clock.Now(), clock.Now() // Init timestamps
@@ -81,10 +109,8 @@ func (bt *balanceTracker) stop(now mclock.AbsTime) {
bt.stopped = true
bt.addBalance(now)
- bt.negTimeFactor = 0
- bt.negRequestFactor = 0
- bt.timeFactor = 0
- bt.requestFactor = 0
+ bt.posFactor = priceFactors{0, 0, 0}
+ bt.negFactor = priceFactors{0, 0, 0}
if bt.updateEvent != nil {
bt.updateEvent.Stop()
bt.updateEvent = nil
@@ -95,10 +121,43 @@ func (bt *balanceTracker) stop(now mclock.AbsTime) {
// first to disconnect. Positive balance translates to negative priority. If positive
// balance is zero then negative balance translates to a positive priority.
func (bt *balanceTracker) balanceToPriority(b balance) int64 {
- if b.pos > 0 {
- return ^int64(b.pos / bt.capacity)
+ if b.pos.base > 0 {
+ return -int64(b.pos.value(bt.exp.posExpiration(bt.clock.Now())) / bt.capacity)
}
- return int64(b.neg)
+ return int64(b.neg.value(bt.exp.negExpiration(bt.clock.Now())))
+}
+
+// posBalanceMissing calculates the missing amount of positive balance in order to
+// connect at targetCapacity, stay connected for the given amount of time and then
+// still have a priority of targetPriority
+func (bt *balanceTracker) posBalanceMissing(targetPriority int64, targetCapacity uint64, after time.Duration) uint64 {
+ now := bt.clock.Now()
+ if targetPriority > 0 {
+ timePrice := bt.negFactor.timePrice(targetCapacity)
+ timeCost := uint64(float64(after) * timePrice)
+ negBalance := bt.balance.neg.value(bt.exp.negExpiration(now))
+ if timeCost+negBalance < uint64(targetPriority) {
+ return 0
+ }
+ if uint64(targetPriority) > negBalance && timePrice > 1e-100 {
+ if negTime := time.Duration(float64(uint64(targetPriority)-negBalance) / timePrice); negTime < after {
+ after -= negTime
+ } else {
+ after = 0
+ }
+ }
+ targetPriority = 0
+ }
+ timePrice := bt.posFactor.timePrice(targetCapacity)
+ posRequired := uint64(float64(-targetPriority)*float64(targetCapacity)+float64(after)*timePrice) + 1
+ if posRequired >= maxBalance {
+ return math.MaxUint64 // target not reachable
+ }
+ posBalance := bt.balance.pos.value(bt.exp.posExpiration(now))
+ if posRequired > posBalance {
+ return posRequired - posBalance
+ }
+ return 0
}
// reducedBalance estimates the reduced balance at a given time in the fututre based
@@ -106,20 +165,19 @@ func (bt *balanceTracker) balanceToPriority(b balance) int64 {
func (bt *balanceTracker) reducedBalance(at mclock.AbsTime, avgReqCost float64) balance {
dt := float64(at - bt.lastUpdate)
b := bt.balance
- if b.pos != 0 {
- factor := bt.timeFactor + bt.requestFactor*avgReqCost
- diff := uint64(dt * factor)
- if diff <= b.pos {
- b.pos -= diff
+ if b.pos.base != 0 {
+ factor := bt.posFactor.timePrice(bt.capacity) + bt.posFactor.reqPrice()*avgReqCost
+ diff := -int64(dt * factor)
+ dd := b.pos.add(diff, bt.exp.posExpiration(at))
+ if dd == diff {
dt = 0
} else {
- dt -= float64(b.pos) / factor
- b.pos = 0
+ dt += float64(dd) / factor
}
}
- if dt != 0 {
- factor := bt.negTimeFactor + bt.negRequestFactor*avgReqCost
- b.neg += uint64(dt * factor)
+ if dt > 0 {
+ factor := bt.negFactor.timePrice(bt.capacity) + bt.negFactor.reqPrice()*avgReqCost
+ b.neg.add(int64(dt*factor), bt.exp.negExpiration(at))
}
return b
}
@@ -130,20 +188,23 @@ func (bt *balanceTracker) reducedBalance(at mclock.AbsTime, avgReqCost float64)
// Note: the function assumes that the balance has been recently updated and
// calculates the time starting from the last update.
func (bt *balanceTracker) timeUntil(priority int64) (time.Duration, bool) {
+ now := bt.clock.Now()
var dt float64
- if bt.balance.pos != 0 {
- if bt.timeFactor < 1e-100 {
+ if bt.balance.pos.base != 0 {
+ posBalance := bt.balance.pos.value(bt.exp.posExpiration(now))
+ timePrice := bt.posFactor.timePrice(bt.capacity)
+ if timePrice < 1e-100 {
return 0, false
}
if priority < 0 {
- newBalance := uint64(^priority) * bt.capacity
- if newBalance > bt.balance.pos {
+ newBalance := uint64(-priority) * bt.capacity
+ if newBalance > posBalance {
return 0, false
}
- dt = float64(bt.balance.pos-newBalance) / bt.timeFactor
+ dt = float64(posBalance-newBalance) / timePrice
return time.Duration(dt), true
} else {
- dt = float64(bt.balance.pos) / bt.timeFactor
+ dt = float64(posBalance) / timePrice
}
} else {
if priority < 0 {
@@ -151,16 +212,19 @@ func (bt *balanceTracker) timeUntil(priority int64) (time.Duration, bool) {
}
}
// if we have a positive balance then dt equals the time needed to get it to zero
- if uint64(priority) > bt.balance.neg {
- if bt.negTimeFactor < 1e-100 {
+ negBalance := bt.balance.neg.value(bt.exp.negExpiration(now))
+ timePrice := bt.negFactor.timePrice(bt.capacity)
+ if uint64(priority) > negBalance {
+ if timePrice < 1e-100 {
return 0, false
}
- dt += float64(uint64(priority)-bt.balance.neg) / bt.negTimeFactor
+ dt += float64(uint64(priority)-negBalance) / timePrice
}
return time.Duration(dt), true
}
// setCapacity updates the capacity value used for priority calculation
+// Note: capacity should never be zero
func (bt *balanceTracker) setCapacity(capacity uint64) {
bt.lock.Lock()
defer bt.lock.Unlock()
@@ -262,26 +326,26 @@ func (bt *balanceTracker) updateAfter(dt time.Duration) {
}
// requestCost should be called after serving a request for the given peer
-func (bt *balanceTracker) requestCost(cost uint64) {
+func (bt *balanceTracker) requestCost(cost uint64) uint64 {
bt.lock.Lock()
defer bt.lock.Unlock()
if bt.stopped {
- return
+ return 0
}
now := bt.clock.Now()
bt.addBalance(now)
fcost := float64(cost)
- if bt.balance.pos != 0 {
- if bt.requestFactor != 0 {
- c := uint64(fcost * bt.requestFactor)
- if bt.balance.pos >= c {
- bt.balance.pos -= c
+ posExp := bt.exp.posExpiration(now)
+ if bt.balance.pos.base != 0 {
+ if bt.posFactor.reqPrice() != 0 {
+ c := -int64(fcost * bt.posFactor.reqPrice())
+ cc := bt.balance.pos.add(c, posExp)
+ if c == cc {
fcost = 0
} else {
- fcost *= 1 - float64(bt.balance.pos)/float64(c)
- bt.balance.pos = 0
+ fcost *= 1 - float64(cc)/float64(c)
}
bt.checkCallbacks(now)
} else {
@@ -289,16 +353,17 @@ func (bt *balanceTracker) requestCost(cost uint64) {
}
}
if fcost > 0 {
- if bt.negRequestFactor != 0 {
- bt.balance.neg += uint64(fcost * bt.negRequestFactor)
+ if bt.negFactor.reqPrice() != 0 {
+ bt.balance.neg.add(int64(fcost*bt.negFactor.reqPrice()), bt.exp.negExpiration(now))
bt.checkCallbacks(now)
}
}
bt.sumReqCost += cost
+ return bt.balance.pos.value(posExp)
}
// getBalance returns the current positive and negative balance
-func (bt *balanceTracker) getBalance(now mclock.AbsTime) (uint64, uint64) {
+func (bt *balanceTracker) getBalance(now mclock.AbsTime) (expiredValue, expiredValue) {
bt.lock.Lock()
defer bt.lock.Unlock()
@@ -307,7 +372,7 @@ func (bt *balanceTracker) getBalance(now mclock.AbsTime) (uint64, uint64) {
}
// setBalance sets the positive and negative balance to the given values
-func (bt *balanceTracker) setBalance(pos, neg uint64) error {
+func (bt *balanceTracker) setBalance(pos, neg expiredValue) error {
bt.lock.Lock()
defer bt.lock.Unlock()
@@ -321,7 +386,7 @@ func (bt *balanceTracker) setBalance(pos, neg uint64) error {
// setFactors sets the price factors. timeFactor is the price of a nanosecond of
// connection while requestFactor is the price of a "realCost" unit.
-func (bt *balanceTracker) setFactors(neg bool, timeFactor, requestFactor float64) {
+func (bt *balanceTracker) setFactors(posFactor, negFactor priceFactors) {
bt.lock.Lock()
defer bt.lock.Unlock()
@@ -330,13 +395,7 @@ func (bt *balanceTracker) setFactors(neg bool, timeFactor, requestFactor float64
}
now := bt.clock.Now()
bt.addBalance(now)
- if neg {
- bt.negTimeFactor = timeFactor
- bt.negRequestFactor = requestFactor
- } else {
- bt.timeFactor = timeFactor
- bt.requestFactor = requestFactor
- }
+ bt.posFactor, bt.negFactor = posFactor, negFactor
bt.checkCallbacks(now)
}
diff --git a/les/balance_test.go b/les/balance_test.go
index b571c2cc5c2d..69d8caf95b01 100644
--- a/les/balance_test.go
+++ b/les/balance_test.go
@@ -23,18 +23,31 @@ import (
"github.com/ethereum/go-ethereum/common/mclock"
)
+type zeroExpCtrl struct{}
+
+func (z zeroExpCtrl) posExpiration(mclock.AbsTime) fixed64 {
+ return 0
+}
+
+func (z zeroExpCtrl) negExpiration(mclock.AbsTime) fixed64 {
+ return 0
+}
+
+func expval(v uint64) expiredValue {
+ return expiredValue{base: v}
+}
+
func TestSetBalance(t *testing.T) {
var clock = &mclock.Simulated{}
var inputs = []struct {
- pos uint64
- neg uint64
+ pos, neg expiredValue
}{
- {1000, 0},
- {0, 1000},
- {1000, 1000},
+ {expval(1000), expval(0)},
+ {expval(0), expval(1000)},
+ {expval(1000), expval(1000)},
}
- tracker := balanceTracker{}
+ tracker := balanceTracker{exp: zeroExpCtrl{}}
tracker.init(clock, 1000)
defer tracker.stop(clock.Now())
@@ -53,14 +66,13 @@ func TestSetBalance(t *testing.T) {
func TestBalanceTimeCost(t *testing.T) {
var (
clock = &mclock.Simulated{}
- tracker = balanceTracker{}
+ tracker = balanceTracker{exp: zeroExpCtrl{}}
)
tracker.init(clock, 1000)
defer tracker.stop(clock.Now())
- tracker.setFactors(false, 1, 1)
- tracker.setFactors(true, 1, 1)
+ tracker.setFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1})
- tracker.setBalance(uint64(time.Minute), 0) // 1 minute time allowance
+ tracker.setBalance(expval(uint64(time.Minute)), expval(0)) // 1 minute time allowance
var inputs = []struct {
runTime time.Duration
@@ -74,21 +86,21 @@ func TestBalanceTimeCost(t *testing.T) {
}
for _, i := range inputs {
clock.Run(i.runTime)
- if pos, _ := tracker.getBalance(clock.Now()); pos != i.expPos {
+ if pos, _ := tracker.getBalance(clock.Now()); pos != expval(i.expPos) {
t.Fatalf("Positive balance mismatch, want %v, got %v", i.expPos, pos)
}
- if _, neg := tracker.getBalance(clock.Now()); neg != i.expNeg {
+ if _, neg := tracker.getBalance(clock.Now()); neg != expval(i.expNeg) {
t.Fatalf("Negative balance mismatch, want %v, got %v", i.expNeg, neg)
}
}
- tracker.setBalance(uint64(time.Minute), 0) // Refill 1 minute time allowance
+ tracker.setBalance(expval(uint64(time.Minute)), expval(0)) // Refill 1 minute time allowance
for _, i := range inputs {
clock.Run(i.runTime)
- if pos, _ := tracker.getBalance(clock.Now()); pos != i.expPos {
+ if pos, _ := tracker.getBalance(clock.Now()); pos != expval(i.expPos) {
t.Fatalf("Positive balance mismatch, want %v, got %v", i.expPos, pos)
}
- if _, neg := tracker.getBalance(clock.Now()); neg != i.expNeg {
+ if _, neg := tracker.getBalance(clock.Now()); neg != expval(i.expNeg) {
t.Fatalf("Negative balance mismatch, want %v, got %v", i.expNeg, neg)
}
}
@@ -97,14 +109,13 @@ func TestBalanceTimeCost(t *testing.T) {
func TestBalanceReqCost(t *testing.T) {
var (
clock = &mclock.Simulated{}
- tracker = balanceTracker{}
+ tracker = balanceTracker{exp: zeroExpCtrl{}}
)
tracker.init(clock, 1000)
defer tracker.stop(clock.Now())
- tracker.setFactors(false, 1, 1)
- tracker.setFactors(true, 1, 1)
+ tracker.setFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1})
- tracker.setBalance(uint64(time.Minute), 0) // 1 minute time serving time allowance
+ tracker.setBalance(expval(uint64(time.Minute)), expval(0)) // 1 minute time serving time allowance
var inputs = []struct {
reqCost uint64
expPos uint64
@@ -117,10 +128,10 @@ func TestBalanceReqCost(t *testing.T) {
}
for _, i := range inputs {
tracker.requestCost(i.reqCost)
- if pos, _ := tracker.getBalance(clock.Now()); pos != i.expPos {
+ if pos, _ := tracker.getBalance(clock.Now()); pos != expval(i.expPos) {
t.Fatalf("Positive balance mismatch, want %v, got %v", i.expPos, pos)
}
- if _, neg := tracker.getBalance(clock.Now()); neg != i.expNeg {
+ if _, neg := tracker.getBalance(clock.Now()); neg != expval(i.expNeg) {
t.Fatalf("Negative balance mismatch, want %v, got %v", i.expNeg, neg)
}
}
@@ -129,25 +140,24 @@ func TestBalanceReqCost(t *testing.T) {
func TestBalanceToPriority(t *testing.T) {
var (
clock = &mclock.Simulated{}
- tracker = balanceTracker{}
+ tracker = balanceTracker{exp: zeroExpCtrl{}}
)
tracker.init(clock, 1000) // cap = 1000
defer tracker.stop(clock.Now())
- tracker.setFactors(false, 1, 1)
- tracker.setFactors(true, 1, 1)
+ tracker.setFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1})
var inputs = []struct {
pos uint64
neg uint64
priority int64
}{
- {1000, 0, ^int64(1)},
- {2000, 0, ^int64(2)}, // Higher balance, lower priority value
+ {1000, 0, -1},
+ {2000, 0, -2}, // Higher balance, lower priority value
{0, 0, 0},
{0, 1000, 1000},
}
for _, i := range inputs {
- tracker.setBalance(i.pos, i.neg)
+ tracker.setBalance(expval(i.pos), expval(i.neg))
priority := tracker.getPriority(clock.Now())
if priority != i.priority {
t.Fatalf("Priority mismatch, want %v, got %v", i.priority, priority)
@@ -158,30 +168,29 @@ func TestBalanceToPriority(t *testing.T) {
func TestEstimatedPriority(t *testing.T) {
var (
clock = &mclock.Simulated{}
- tracker = balanceTracker{}
+ tracker = balanceTracker{exp: zeroExpCtrl{}}
)
tracker.init(clock, 1000000000) // cap = 1000,000,000
defer tracker.stop(clock.Now())
- tracker.setFactors(false, 1, 1)
- tracker.setFactors(true, 1, 1)
+ tracker.setFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1})
- tracker.setBalance(uint64(time.Minute), 0)
+ tracker.setBalance(expval(uint64(time.Minute)), expval(0))
var inputs = []struct {
runTime time.Duration // time cost
futureTime time.Duration // diff of future time
reqCost uint64 // single request cost
priority int64 // expected estimated priority
}{
- {time.Second, time.Second, 0, ^int64(58)},
- {0, time.Second, 0, ^int64(58)},
+ {time.Second, time.Second, 0, -58},
+ {0, time.Second, 0, -58},
// 2 seconds time cost, 1 second estimated time cost, 10^9 request cost,
// 10^9 estimated request cost per second.
- {time.Second, time.Second, 1000000000, ^int64(55)},
+ {time.Second, time.Second, 1000000000, -55},
// 3 seconds time cost, 3 second estimated time cost, 10^9*2 request cost,
// 4*10^9 estimated request cost.
- {time.Second, 3 * time.Second, 1000000000, ^int64(48)},
+ {time.Second, 3 * time.Second, 1000000000, -48},
// All positive balance is used up
{time.Second * 55, 0, 0, 0},
@@ -202,22 +211,21 @@ func TestEstimatedPriority(t *testing.T) {
func TestCallbackChecking(t *testing.T) {
var (
clock = &mclock.Simulated{}
- tracker = balanceTracker{}
+ tracker = balanceTracker{exp: zeroExpCtrl{}}
)
tracker.init(clock, 1000000) // cap = 1000,000
defer tracker.stop(clock.Now())
- tracker.setFactors(false, 1, 1)
- tracker.setFactors(true, 1, 1)
+ tracker.setFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1})
var inputs = []struct {
priority int64
expDiff time.Duration
}{
- {^int64(500), time.Millisecond * 500},
+ {-500, time.Millisecond * 500},
{0, time.Second},
{int64(time.Second), 2 * time.Second},
}
- tracker.setBalance(uint64(time.Second), 0)
+ tracker.setBalance(expval(uint64(time.Second)), expval(0))
for _, i := range inputs {
diff, _ := tracker.timeUntil(i.priority)
if diff != i.expDiff {
@@ -229,15 +237,14 @@ func TestCallbackChecking(t *testing.T) {
func TestCallback(t *testing.T) {
var (
clock = &mclock.Simulated{}
- tracker = balanceTracker{}
+ tracker = balanceTracker{exp: zeroExpCtrl{}}
)
tracker.init(clock, 1000) // cap = 1000
defer tracker.stop(clock.Now())
- tracker.setFactors(false, 1, 1)
- tracker.setFactors(true, 1, 1)
+ tracker.setFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1})
callCh := make(chan struct{}, 1)
- tracker.setBalance(uint64(time.Minute), 0)
+ tracker.setBalance(expval(uint64(time.Minute)), expval(0))
tracker.addCallback(balanceCallbackZero, 0, func() { callCh <- struct{}{} })
clock.Run(time.Minute)
@@ -247,7 +254,7 @@ func TestCallback(t *testing.T) {
t.Fatalf("Callback hasn't been called yet")
}
- tracker.setBalance(uint64(time.Minute), 0)
+ tracker.setBalance(expval(uint64(time.Minute)), expval(0))
tracker.addCallback(balanceCallbackZero, 0, func() { callCh <- struct{}{} })
tracker.removeCallback(balanceCallbackZero)
diff --git a/les/client.go b/les/client.go
index dfd0909778c8..643aa88dc429 100644
--- a/les/client.go
+++ b/les/client.go
@@ -50,6 +50,7 @@ type LightEthereum struct {
lesCommons
peers *serverPeerSet
+ srvr *p2p.Server
reqDist *requestDistributor
retriever *retrieveManager
odr *LesOdr
@@ -208,6 +209,12 @@ func (s *LightEthereum) APIs() []rpc.API {
Service: NewPrivateLightAPI(&s.lesCommons),
Public: false,
},
+ {
+ Namespace: "lespay",
+ Version: "1.0",
+ Service: NewPrivateLespayAPI(nil, s.peers, s.handler, s.srvr.DiscV5, nil),
+ Public: false,
+ },
}...)
}
@@ -237,6 +244,7 @@ func (s *LightEthereum) Protocols() []p2p.Protocol {
// light ethereum protocol implementation.
func (s *LightEthereum) Start(srvr *p2p.Server) error {
log.Warn("Light client mode is an experimental feature")
+ s.srvr = srvr
// Start bloom request workers.
s.wg.Add(bloomServiceThreads)
diff --git a/les/client_handler.go b/les/client_handler.go
index d04574c8c7fa..56392632bcf3 100644
--- a/les/client_handler.go
+++ b/les/client_handler.go
@@ -40,6 +40,9 @@ type clientHandler struct {
downloader *downloader.Downloader
backend *LightEthereum
+ lespayReplyHandlers map[uint64]func([]byte, uint) bool
+ lespayReplyLock sync.Mutex
+
closeCh chan struct{}
wg sync.WaitGroup // WaitGroup used to track all connected peers.
syncDone func() // Test hooks when syncing is done.
@@ -47,9 +50,10 @@ type clientHandler struct {
func newClientHandler(ulcServers []string, ulcFraction int, checkpoint *params.TrustedCheckpoint, backend *LightEthereum) *clientHandler {
handler := &clientHandler{
- checkpoint: checkpoint,
- backend: backend,
- closeCh: make(chan struct{}),
+ checkpoint: checkpoint,
+ backend: backend,
+ closeCh: make(chan struct{}),
+ lespayReplyHandlers: make(map[uint64]func([]byte, uint) bool),
}
if ulcServers != nil {
ulc, err := newULC(ulcServers, ulcFraction)
@@ -112,28 +116,48 @@ func (h *clientHandler) handle(p *serverPeer) error {
p.Log().Debug("Light Ethereum handshake failed", "err", err)
return err
}
- // Register the peer locally
- if err := h.backend.peers.register(p); err != nil {
- p.Log().Error("Light Ethereum peer registration failed", "err", err)
- return err
- }
- serverConnectionGauge.Update(int64(h.backend.peers.len()))
- connectedAt := mclock.Now()
- defer func() {
- h.backend.peers.unregister(p.id)
+ var (
+ connectedAt mclock.AbsTime
+ lastActive bool
+ )
+ activate := func() {
+ // Register the peer locally
+ if err := h.backend.peers.register(p); err != nil {
+ p.Log().Error("Light Ethereum peer registration failed", "err", err)
+ return
+ }
+ serverConnectionGauge.Update(int64(h.backend.peers.len()))
+ connectedAt = mclock.Now()
+ h.fetcher.announce(p, &announceData{Hash: p.headInfo.Hash, Number: p.headInfo.Number, Td: p.headInfo.Td})
+ lastActive = true
+ }
+ deactivate := func() {
+ h.backend.peers.unregister(p)
connectionTimer.Update(time.Duration(mclock.Now() - connectedAt))
serverConnectionGauge.Update(int64(h.backend.peers.len()))
+ lastActive = false
+ }
+ defer func() {
+ if lastActive {
+ deactivate()
+ }
+ h.backend.peers.disconnect(p.id)
}()
- h.fetcher.announce(p, &announceData{Hash: p.headInfo.Hash, Number: p.headInfo.Number, Td: p.headInfo.Td})
-
// pool entry can be nil during the unit test.
if p.poolEntry != nil {
h.backend.serverPool.registered(p.poolEntry)
}
+
// Spawn a main loop to handle all incoming messages.
for {
+ if p.active && !lastActive {
+ activate()
+ }
+ if !p.active && lastActive {
+ deactivate()
+ }
if err := h.handleMsg(p); err != nil {
p.Log().Debug("Light Ethereum message handling failed", "err", err)
p.fcServer.DumpLogs()
@@ -157,7 +181,10 @@ func (h *clientHandler) handleMsg(p *serverPeer) error {
}
defer msg.Discard()
- var deliverMsg *Msg
+ var (
+ deliverMsg *Msg
+ responseError bool
+ )
// Handle the message depending on its contents
switch msg.Code {
@@ -193,13 +220,15 @@ func (h *clientHandler) handleMsg(p *serverPeer) error {
case BlockHeadersMsg:
p.Log().Trace("Received block header response message")
var resp struct {
- ReqID, BV uint64
- Headers []*types.Header
+ ReqID uint64
+ SF stateFeedback
+ Headers []*types.Header
}
+ resp.SF.protocolVersion = p.version
if err := msg.Decode(&resp); err != nil {
return errResp(ErrDecode, "msg %v: %v", msg, err)
}
- p.fcServer.ReceivedReply(resp.ReqID, resp.BV)
+ p.fcServer.ReceivedReply(resp.ReqID, resp.SF.BV)
if h.fetcher.requestedID(resp.ReqID) {
h.fetcher.deliverHeaders(p, resp.ReqID, resp.Headers)
} else {
@@ -210,13 +239,15 @@ func (h *clientHandler) handleMsg(p *serverPeer) error {
case BlockBodiesMsg:
p.Log().Trace("Received block bodies response")
var resp struct {
- ReqID, BV uint64
- Data []*types.Body
+ ReqID uint64
+ SF stateFeedback
+ Data []*types.Body
}
+ resp.SF.protocolVersion = p.version
if err := msg.Decode(&resp); err != nil {
return errResp(ErrDecode, "msg %v: %v", msg, err)
}
- p.fcServer.ReceivedReply(resp.ReqID, resp.BV)
+ p.fcServer.ReceivedReply(resp.ReqID, resp.SF.BV)
deliverMsg = &Msg{
MsgType: MsgBlockBodies,
ReqID: resp.ReqID,
@@ -225,13 +256,15 @@ func (h *clientHandler) handleMsg(p *serverPeer) error {
case CodeMsg:
p.Log().Trace("Received code response")
var resp struct {
- ReqID, BV uint64
- Data [][]byte
+ ReqID uint64
+ SF stateFeedback
+ Data [][]byte
}
+ resp.SF.protocolVersion = p.version
if err := msg.Decode(&resp); err != nil {
return errResp(ErrDecode, "msg %v: %v", msg, err)
}
- p.fcServer.ReceivedReply(resp.ReqID, resp.BV)
+ p.fcServer.ReceivedReply(resp.ReqID, resp.SF.BV)
deliverMsg = &Msg{
MsgType: MsgCode,
ReqID: resp.ReqID,
@@ -240,13 +273,15 @@ func (h *clientHandler) handleMsg(p *serverPeer) error {
case ReceiptsMsg:
p.Log().Trace("Received receipts response")
var resp struct {
- ReqID, BV uint64
- Receipts []types.Receipts
+ ReqID uint64
+ SF stateFeedback
+ Receipts []types.Receipts
}
+ resp.SF.protocolVersion = p.version
if err := msg.Decode(&resp); err != nil {
return errResp(ErrDecode, "msg %v: %v", msg, err)
}
- p.fcServer.ReceivedReply(resp.ReqID, resp.BV)
+ p.fcServer.ReceivedReply(resp.ReqID, resp.SF.BV)
deliverMsg = &Msg{
MsgType: MsgReceipts,
ReqID: resp.ReqID,
@@ -255,13 +290,15 @@ func (h *clientHandler) handleMsg(p *serverPeer) error {
case ProofsV2Msg:
p.Log().Trace("Received les/2 proofs response")
var resp struct {
- ReqID, BV uint64
- Data light.NodeList
+ ReqID uint64
+ SF stateFeedback
+ Data light.NodeList
}
+ resp.SF.protocolVersion = p.version
if err := msg.Decode(&resp); err != nil {
return errResp(ErrDecode, "msg %v: %v", msg, err)
}
- p.fcServer.ReceivedReply(resp.ReqID, resp.BV)
+ p.fcServer.ReceivedReply(resp.ReqID, resp.SF.BV)
deliverMsg = &Msg{
MsgType: MsgProofsV2,
ReqID: resp.ReqID,
@@ -270,13 +307,15 @@ func (h *clientHandler) handleMsg(p *serverPeer) error {
case HelperTrieProofsMsg:
p.Log().Trace("Received helper trie proof response")
var resp struct {
- ReqID, BV uint64
- Data HelperTrieResps
+ ReqID uint64
+ SF stateFeedback
+ Data HelperTrieResps
}
+ resp.SF.protocolVersion = p.version
if err := msg.Decode(&resp); err != nil {
return errResp(ErrDecode, "msg %v: %v", msg, err)
}
- p.fcServer.ReceivedReply(resp.ReqID, resp.BV)
+ p.fcServer.ReceivedReply(resp.ReqID, resp.SF.BV)
deliverMsg = &Msg{
MsgType: MsgHelperTrieProofs,
ReqID: resp.ReqID,
@@ -285,13 +324,15 @@ func (h *clientHandler) handleMsg(p *serverPeer) error {
case TxStatusMsg:
p.Log().Trace("Received tx status response")
var resp struct {
- ReqID, BV uint64
- Status []light.TxStatus
+ ReqID uint64
+ SF stateFeedback
+ Status []light.TxStatus
}
+ resp.SF.protocolVersion = p.version
if err := msg.Decode(&resp); err != nil {
return errResp(ErrDecode, "msg %v: %v", msg, err)
}
- p.fcServer.ReceivedReply(resp.ReqID, resp.BV)
+ p.fcServer.ReceivedReply(resp.ReqID, resp.SF.BV)
deliverMsg = &Msg{
MsgType: MsgTxStatus,
ReqID: resp.ReqID,
@@ -302,13 +343,32 @@ func (h *clientHandler) handleMsg(p *serverPeer) error {
h.backend.retriever.frozen(p)
p.Log().Debug("Service stopped")
case ResumeMsg:
- var bv uint64
- if err := msg.Decode(&bv); err != nil {
+ var sf stateFeedback
+ sf.protocolVersion = p.version
+ if err := msg.Decode(&sf); err != nil {
return errResp(ErrDecode, "msg %v: %v", msg, err)
}
- p.fcServer.ResumeFreeze(bv)
+ p.fcServer.ResumeFreeze(sf.BV)
p.unfreeze()
p.Log().Debug("Service resumed")
+ case LespayReplyMsg:
+ p.Log().Trace("Received tx status response")
+ var resp struct {
+ ReqID uint64
+ Reply lespayReply
+ }
+ if err := msg.Decode(&resp); err != nil {
+ return errResp(ErrDecode, "msg %v: %v", msg, err)
+ }
+ h.lespayReplyLock.Lock()
+ if handler := h.lespayReplyHandlers[resp.ReqID]; handler != nil {
+ delete(h.lespayReplyHandlers, resp.ReqID)
+ responseError = !handler(resp.Reply.Reply, resp.Reply.Delay)
+ } else {
+ responseError = true
+ }
+ h.lespayReplyLock.Unlock()
+
default:
p.Log().Trace("Received invalid message", "code", msg.Code)
return errResp(ErrInvalidMsgCode, "%v", msg.Code)
@@ -316,17 +376,48 @@ func (h *clientHandler) handleMsg(p *serverPeer) error {
// Deliver the received response to retriever.
if deliverMsg != nil {
if err := h.backend.retriever.deliver(p, deliverMsg); err != nil {
- p.errCount++
- if p.errCount > maxResponseErrors {
- return err
- }
+ responseError = true
+ }
+ }
+ if responseError {
+ p.errCount++
+ if p.errCount > maxResponseErrors {
+ return err
}
}
return nil
}
+// makeLespayCall sends a lespay command through an LES connection and registers
+// a response handler. It returns a cancel function that removes the response
+// handler and calls it with a nil parameter if the response has not arrived yet.
+func (h *clientHandler) makeLespayCall(p *serverPeer, cmd []byte, handler func([]byte, uint) bool) func() bool {
+ reqID := genReqID()
+ h.lespayReplyLock.Lock()
+ h.lespayReplyHandlers[reqID] = handler
+ h.lespayReplyLock.Unlock()
+ if p.sendLespay(reqID, cmd) != nil {
+ h.lespayReplyLock.Lock()
+ delete(h.lespayReplyHandlers, reqID)
+ h.lespayReplyLock.Unlock()
+ return nil
+ }
+ return func() bool {
+ h.lespayReplyLock.Lock()
+ cancel := h.lespayReplyHandlers[reqID] != nil
+ if cancel {
+ delete(h.lespayReplyHandlers, reqID)
+ }
+ h.lespayReplyLock.Unlock()
+ if cancel {
+ handler(nil, 0)
+ }
+ return cancel
+ }
+}
+
func (h *clientHandler) removePeer(id string) {
- h.backend.peers.unregister(id)
+ h.backend.peers.disconnect(id)
}
type peerConnection struct {
diff --git a/les/clientdb.go b/les/clientdb.go
new file mode 100644
index 000000000000..f24d499479b2
--- /dev/null
+++ b/les/clientdb.go
@@ -0,0 +1,339 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package les
+
+import (
+ "bytes"
+ "encoding/binary"
+ "io"
+ "sync"
+ "time"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/common/mclock"
+ "github.com/ethereum/go-ethereum/core/rawdb"
+ "github.com/ethereum/go-ethereum/ethdb"
+ "github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/p2p/enode"
+ "github.com/ethereum/go-ethereum/rlp"
+ lru "github.com/hashicorp/golang-lru"
+)
+
+const balanceCacheLimit = 8192 // the maximum number of cached items in service token balance queue
+
+// tokenBalance is a wrapper of expiredValue which represents the service token
+// balance of clients. The balance value will decay exponentially over time and
+// can be deleted when the amount is small enough.
+type tokenBalance struct {
+ value expiredValue
+}
+
+// EncodeRLP implements rlp.Encoder
+func (b *tokenBalance) EncodeRLP(w io.Writer) error {
+ return rlp.Encode(w, []interface{}{b.value.base, b.value.exp})
+}
+
+// DecodeRLP implements rlp.Decoder
+func (b *tokenBalance) DecodeRLP(s *rlp.Stream) error {
+ var entry struct {
+ Base, Exp uint64
+ }
+ if err := s.Decode(&entry); err != nil {
+ return err
+ }
+ b.value = expiredValue{base: entry.Base, exp: entry.Exp}
+ return nil
+}
+
+// currencyBalance represents the client's currency balance.
+type currencyBalance struct {
+ amount uint64
+ typ string
+}
+
+// EncodeRLP implements rlp.Encoder
+func (b *currencyBalance) EncodeRLP(w io.Writer) error {
+ return rlp.Encode(w, []interface{}{b.amount, b.typ})
+}
+
+// DecodeRLP implements rlp.Decoder
+func (b *currencyBalance) DecodeRLP(s *rlp.Stream) error {
+ var entry struct {
+ Amount uint64
+ Type string
+ }
+ if err := s.Decode(&entry); err != nil {
+ return err
+ }
+ b.amount, b.typ = entry.Amount, entry.Type
+ return nil
+}
+
+const (
+ // nodeDBVersion is the version identifier of the node data in db
+ //
+ // Changelog:
+ // * Replace `lastTotal` with `meta` in positive balance: version 0=>1
+ // * Rework balance, add currency balance: version 1=>2
+ nodeDBVersion = 2
+
+ // dbCleanupCycle is the cycle of db for useless data cleanup
+ dbCleanupCycle = time.Hour
+)
+
+var (
+ posBalancePrefix = []byte("pb:") // dbVersion(uint16 big endian) + posBalancePrefix + id -> positive balance
+ negBalancePrefix = []byte("nb:") // dbVersion(uint16 big endian) + negBalancePrefix + ip -> negative balance
+ curBalancePrefix = []byte("cb:") // dbVersion(uint16 big endian) + curBalancePrefix + id -> currency balance
+ paymentReceiverPrefix = []byte("pr:") // dbVersion(uint16 big endian) + paymentReceiverPrefix + id + "receiverName:" -> receiver namespace
+ expirationKey = []byte("expiration:") // dbVersion(uint16 big endian) + expirationKey -> posExp, negExp
+)
+
+type atomicWriteLock struct {
+ released chan struct{}
+ batch ethdb.Batch
+}
+
+type nodeDB struct {
+ db ethdb.Database
+ cache *lru.Cache
+ clock mclock.Clock
+ closeCh chan struct{}
+ evictCallBack func(mclock.AbsTime, bool, tokenBalance) bool // Callback to determine whether the balance can be evicted.
+ idLockMutex sync.Mutex
+ idLocks map[string]atomicWriteLock
+ cleanupHook func() // Test hook used for testing
+}
+
+func newNodeDB(db ethdb.Database, clock mclock.Clock) *nodeDB {
+ var buff [2]byte
+ binary.BigEndian.PutUint16(buff[:], uint16(nodeDBVersion))
+
+ cache, _ := lru.New(balanceCacheLimit)
+ ndb := &nodeDB{
+ db: rawdb.NewTable(db, string(buff[:])),
+ cache: cache,
+ clock: clock,
+ closeCh: make(chan struct{}),
+ idLocks: make(map[string]atomicWriteLock),
+ }
+ go ndb.expirer()
+ return ndb
+}
+
+func (db *nodeDB) close() {
+ close(db.closeCh)
+}
+
+func (db *nodeDB) atomicWriteLock(id []byte) ethdb.KeyValueWriter {
+ db.idLockMutex.Lock()
+ for {
+ ch := db.idLocks[string(id)].released
+ if ch == nil {
+ break
+ }
+ db.idLockMutex.Unlock()
+ <-ch
+ db.idLockMutex.Lock()
+ }
+ batch := db.db.NewBatch()
+ db.idLocks[string(id)] = atomicWriteLock{
+ released: make(chan struct{}),
+ batch: batch,
+ }
+ db.idLockMutex.Unlock()
+ return batch
+}
+
+func (db *nodeDB) atomicWriteUnlock(id []byte) {
+ db.idLockMutex.Lock()
+ awl := db.idLocks[string(id)]
+ awl.batch.Write()
+ close(awl.released)
+ delete(db.idLocks, string(id))
+ db.idLockMutex.Unlock()
+}
+
+func (db *nodeDB) writer(id []byte) ethdb.KeyValueWriter {
+ db.idLockMutex.Lock()
+ batch := db.idLocks[string(id)].batch
+ db.idLockMutex.Unlock()
+ if batch == nil {
+ return db.db
+ }
+ return batch
+}
+
+func idKey(id []byte, neg bool) []byte {
+ prefix := posBalancePrefix
+ if neg {
+ prefix = negBalancePrefix
+ }
+ return append(prefix, id...)
+}
+
+func receiverPrefix(id enode.ID, receiver string) []byte {
+ return append(append(paymentReceiverPrefix, id.Bytes()...), []byte(receiver+":")...)
+}
+
+func (db *nodeDB) getExpiration() (fixed64, fixed64) {
+ blob, err := db.db.Get(expirationKey)
+ if err != nil || len(blob) != 16 {
+ return 0, 0
+ }
+ return fixed64(binary.BigEndian.Uint64(blob[:8])), fixed64(binary.BigEndian.Uint64(blob[8:16]))
+}
+
+func (db *nodeDB) setExpiration(pos, neg fixed64) {
+ var buff [16]byte
+ binary.BigEndian.PutUint64(buff[:8], uint64(pos))
+ binary.BigEndian.PutUint64(buff[8:16], uint64(neg))
+ db.db.Put(expirationKey, buff[:16])
+}
+
+func (db *nodeDB) getCurrencyBalance(id enode.ID) currencyBalance {
+ var b currencyBalance
+ enc, err := db.db.Get(append(curBalancePrefix, id.Bytes()...))
+ if err != nil || len(enc) == 0 {
+ return b
+ }
+ if err := rlp.DecodeBytes(enc, &b); err != nil {
+ log.Crit("Failed to decode positive balance", "err", err)
+ }
+ return b
+}
+
+func (db *nodeDB) setCurrencyBalance(id enode.ID, b currencyBalance) {
+ enc, err := rlp.EncodeToBytes(&(b))
+ if err != nil {
+ log.Crit("Failed to encode currency balance", "err", err)
+ }
+ db.writer(id.Bytes()).Put(append(curBalancePrefix, id.Bytes()...), enc)
+}
+
+func (db *nodeDB) getOrNewBalance(id []byte, neg bool) tokenBalance {
+ key := idKey(id, neg)
+ item, exist := db.cache.Get(string(key))
+ if exist {
+ return item.(tokenBalance)
+ }
+ var b tokenBalance
+ enc, err := db.db.Get(key)
+ if err != nil || len(enc) == 0 {
+ return b
+ }
+ if err := rlp.DecodeBytes(enc, &b); err != nil {
+ log.Crit("Failed to decode positive balance", "err", err)
+ }
+ db.cache.Add(string(key), b)
+ return b
+}
+
+func (db *nodeDB) setBalance(id []byte, neg bool, b tokenBalance) {
+ key := idKey(id, neg)
+ enc, err := rlp.EncodeToBytes(&(b))
+ if err != nil {
+ log.Crit("Failed to encode positive balance", "err", err)
+ }
+ if neg {
+ db.db.Put(key, enc)
+ } else {
+ db.writer(id).Put(key, enc)
+ }
+ db.cache.Add(string(key), b)
+}
+
+func (db *nodeDB) delBalance(id []byte, neg bool) {
+ key := idKey(id, neg)
+ if neg {
+ db.db.Delete(key)
+ } else {
+ db.writer(id).Delete(key)
+ }
+ db.cache.Remove(string(key))
+}
+
+// getPosBalanceIDs returns a lexicographically ordered list of IDs of accounts
+// with a positive balance
+func (db *nodeDB) getPosBalanceIDs(start, stop enode.ID, maxCount int) (result []enode.ID) {
+ if maxCount <= 0 {
+ return
+ }
+ it := db.db.NewIteratorWithStart(idKey(start.Bytes(), false))
+ defer it.Release()
+ for i := len(stop[:]) - 1; i >= 0; i-- {
+ stop[i]--
+ if stop[i] != 255 {
+ break
+ }
+ }
+ stopKey := idKey(stop.Bytes(), false)
+ keyLen := len(stopKey)
+
+ for it.Next() {
+ var id enode.ID
+ if len(it.Key()) != keyLen || bytes.Compare(it.Key(), stopKey) == 1 {
+ return
+ }
+ copy(id[:], it.Key()[keyLen-len(id):])
+ result = append(result, id)
+ if len(result) == maxCount {
+ return
+ }
+ }
+ return
+}
+
+func (db *nodeDB) expirer() {
+ for {
+ select {
+ case <-db.clock.After(dbCleanupCycle):
+ db.expireNodes()
+ case <-db.closeCh:
+ return
+ }
+ }
+}
+
+// expireNodes iterates the whole node db and checks whether the
+// token balances can deleted.
+func (db *nodeDB) expireNodes() {
+ var (
+ visited int
+ deleted int
+ start = time.Now()
+ )
+ for index, prefix := range [][]byte{posBalancePrefix, negBalancePrefix} {
+ iter := db.db.NewIteratorWithPrefix(prefix)
+ for iter.Next() {
+ visited += 1
+ var balance tokenBalance
+ if err := rlp.DecodeBytes(iter.Value(), &balance); err != nil {
+ log.Crit("Failed to decode negative balance", "err", err)
+ }
+ if db.evictCallBack != nil && db.evictCallBack(db.clock.Now(), index != 0, balance) {
+ deleted += 1
+ db.db.Delete(iter.Key())
+ }
+ }
+ }
+ // Invoke testing hook if it's not nil.
+ if db.cleanupHook != nil {
+ db.cleanupHook()
+ }
+ log.Debug("Expire nodes", "visited", visited, "deleted", deleted, "elapsed", common.PrettyDuration(time.Since(start)))
+}
diff --git a/les/clientdb_test.go b/les/clientdb_test.go
new file mode 100644
index 000000000000..ac2181fb667b
--- /dev/null
+++ b/les/clientdb_test.go
@@ -0,0 +1,139 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package les
+
+import (
+ "reflect"
+ "testing"
+ "time"
+
+ "github.com/ethereum/go-ethereum/common/mclock"
+ "github.com/ethereum/go-ethereum/core/rawdb"
+ "github.com/ethereum/go-ethereum/p2p/enode"
+)
+
+func TestNodeDB(t *testing.T) {
+ ndb := newNodeDB(rawdb.NewMemoryDatabase(), mclock.System{})
+ defer ndb.close()
+
+ var cases = []struct {
+ id enode.ID
+ ip string
+ balance tokenBalance
+ positive bool
+ }{
+ {enode.ID{0x00, 0x01, 0x02}, "", tokenBalance{value: expval(100)}, true},
+ {enode.ID{0x00, 0x01, 0x02}, "", tokenBalance{value: expval(200)}, true},
+ {enode.ID{}, "127.0.0.1", tokenBalance{value: expval(100)}, false},
+ {enode.ID{}, "127.0.0.1", tokenBalance{value: expval(200)}, false},
+ }
+ for _, c := range cases {
+ if c.positive {
+ ndb.setBalance(c.id.Bytes(), false, c.balance)
+ if pb := ndb.getOrNewBalance(c.id.Bytes(), false); !reflect.DeepEqual(pb, c.balance) {
+ t.Fatalf("Positive balance mismatch, want %v, got %v", c.balance, pb)
+ }
+ } else {
+ ndb.setBalance([]byte(c.ip), true, c.balance)
+ if nb := ndb.getOrNewBalance([]byte(c.ip), true); !reflect.DeepEqual(nb, c.balance) {
+ t.Fatalf("Negative balance mismatch, want %v, got %v", c.balance, nb)
+ }
+ }
+ }
+ for _, c := range cases {
+ if c.positive {
+ ndb.delBalance(c.id.Bytes(), false)
+ if pb := ndb.getOrNewBalance(c.id.Bytes(), false); !reflect.DeepEqual(pb, tokenBalance{}) {
+ t.Fatalf("Positive balance mismatch, want %v, got %v", tokenBalance{}, pb)
+ }
+ } else {
+ ndb.delBalance([]byte(c.ip), true)
+ if nb := ndb.getOrNewBalance([]byte(c.ip), true); !reflect.DeepEqual(nb, tokenBalance{}) {
+ t.Fatalf("Negative balance mismatch, want %v, got %v", tokenBalance{}, nb)
+ }
+ }
+ }
+ posExp, negExp := fixed64(1000), fixed64(2000)
+ ndb.setExpiration(posExp, negExp)
+ if pos, neg := ndb.getExpiration(); pos != posExp || neg != negExp {
+ t.Fatalf("Expiration mismatch, want %v / %v, got %v / %v", posExp, negExp, pos, neg)
+ }
+ curBalance := currencyBalance{typ: "ETH", amount: 10000}
+ ndb.setCurrencyBalance(enode.ID{0x01, 0x02}, curBalance)
+ if got := ndb.getCurrencyBalance(enode.ID{0x01, 0x02}); !reflect.DeepEqual(got, curBalance) {
+ t.Fatalf("Currency balance mismatch, want %v, got %v", curBalance, got)
+ }
+}
+
+func TestNodeDBExpiration(t *testing.T) {
+ var (
+ iterated int
+ done = make(chan struct{}, 1)
+ )
+ callback := func(now mclock.AbsTime, neg bool, b tokenBalance) bool {
+ iterated += 1
+ return true
+ }
+ clock := &mclock.Simulated{}
+ ndb := newNodeDB(rawdb.NewMemoryDatabase(), clock)
+ defer ndb.close()
+ ndb.evictCallBack = callback
+ ndb.cleanupHook = func() { done <- struct{}{} }
+
+ var cases = []struct {
+ id []byte
+ neg bool
+ balance tokenBalance
+ }{
+ {[]byte{0x01, 0x02}, false, tokenBalance{value: expval(1)}},
+ {[]byte{0x03, 0x04}, false, tokenBalance{value: expval(1)}},
+ {[]byte{0x05, 0x06}, false, tokenBalance{value: expval(1)}},
+ {[]byte{0x07, 0x08}, false, tokenBalance{value: expval(1)}},
+
+ {[]byte("127.0.0.1"), true, tokenBalance{value: expval(1)}},
+ {[]byte("127.0.0.2"), true, tokenBalance{value: expval(1)}},
+ {[]byte("127.0.0.3"), true, tokenBalance{value: expval(1)}},
+ {[]byte("127.0.0.4"), true, tokenBalance{value: expval(1)}},
+ }
+ for _, c := range cases {
+ ndb.setBalance(c.id, c.neg, c.balance)
+ }
+ clock.WaitForTimers(1)
+ clock.Run(time.Hour + time.Minute)
+ select {
+ case <-done:
+ case <-time.NewTimer(time.Second).C:
+ t.Fatalf("timeout")
+ }
+ if iterated != 8 {
+ t.Fatalf("Failed to evict useless balances, want %v, got %d", 8, iterated)
+ }
+
+ for _, c := range cases {
+ ndb.setBalance(c.id, c.neg, c.balance)
+ }
+ clock.WaitForTimers(1)
+ clock.Run(time.Hour + time.Minute)
+ select {
+ case <-done:
+ case <-time.NewTimer(time.Second).C:
+ t.Fatalf("timeout")
+ }
+ if iterated != 16 {
+ t.Fatalf("Failed to evict useless balances, want %v, got %d", 16, iterated)
+ }
+}
diff --git a/les/clientpool.go b/les/clientpool.go
index b01c825a7a96..035f8f2c2708 100644
--- a/les/clientpool.go
+++ b/les/clientpool.go
@@ -17,40 +17,35 @@
package les
import (
- "bytes"
- "encoding/binary"
"fmt"
- "io"
"math"
"sync"
"time"
- "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/common/prque"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p/enode"
- "github.com/ethereum/go-ethereum/rlp"
- lru "github.com/hashicorp/golang-lru"
)
const (
- negBalanceExpTC = time.Hour // time constant for exponentially reducing negative balance
- fixedPointMultiplier = 0x1000000 // constant to convert logarithms to fixed point format
- lazyQueueRefresh = time.Second * 10 // refresh period of the connected queue
- persistCumulativeTimeRefresh = time.Minute * 5 // refresh period of the cumulative running time persistence
- posBalanceCacheLimit = 8192 // the maximum number of cached items in positive balance queue
- negBalanceCacheLimit = 8192 // the maximum number of cached items in negative balance queue
-
- // connectedBias is applied to already connected clients So that
+ defaultPosExpTC = 36000 // default time constant (in seconds) for exponentially reducing positive balance
+ defaultNegExpTC = 3600 // default time constant (in seconds) for exponentially reducing negative balance
+ lazyQueueRefresh = time.Second * 10 // refresh period of the connected queue
+ tryActivatePeriod = time.Second * 5 // periodically check whether inactive clients can be activated
+ dropInactiveCycles = 2 // number of activation check periods after non-priority inactive peers are dropped
+ persistExpirationRefresh = time.Minute * 5 // refresh period of the token expiration persistence
+ freeRatioTC = time.Hour // time constant of token supply control based on free service availability
+
+ // activeBias is applied to already connected clients So that
// already connected client won't be kicked out very soon and we
// can ensure all connected clients can have enough time to request
// or sync some data.
//
// todo(rjl493456442) make it configurable. It can be the option of
// free trial time!
- connectedBias = time.Minute * 3
+ activeBias = time.Minute * 3
)
// clientPool implements a client database that assigns a priority to each client
@@ -60,7 +55,7 @@ const (
// then negative balance is accumulated.
//
// Balance tracking and priority calculation for connected clients is done by
-// balanceTracker. connectedQueue ensures that clients with the lowest positive or
+// balanceTracker. activeQueue ensures that clients with the lowest positive or
// highest negative balance get evicted when the total capacity allowance is full
// and new clients with a better balance want to connect.
//
@@ -69,11 +64,8 @@ const (
// each client can have several minutes of connection time.
//
// Balances of disconnected clients are stored in nodeDB including positive balance
-// and negative banalce. Negative balance is transformed into a logarithmic form
-// with a constantly shifting linear offset in order to implement an exponential
-// decrease. Besides nodeDB will have a background thread to check the negative
-// balance of disconnected client. If the balance is low enough, then the record
-// will be dropped.
+// and negative banalce. Boeth positive balance and negative balance will decrease
+// exponentially. If the balance is low enough, then the record will be dropped.
type clientPool struct {
ndb *nodeDB
lock sync.Mutex
@@ -82,19 +74,32 @@ type clientPool struct {
closed bool
removePeer func(enode.ID)
- connectedMap map[enode.ID]*clientInfo
- connectedQueue *prque.LazyQueue
+ connectedMap map[enode.ID]*clientInfo
+ activeQueue *prque.LazyQueue
+ inactiveQueue *prque.Prque
+ dropInactivePeers map[uint64][]*clientInfo
+ dropInactiveCounter uint64
+
+ activeBalances, inactiveBalances expiredValue
+ lastConnectedBalanceUpdate mclock.AbsTime
+ freeRatio, averageFreeRatio float64
defaultPosFactors, defaultNegFactors priceFactors
- connLimit int // The maximum number of connections that clientpool can support
- capLimit uint64 // The maximum cumulative capacity that clientpool can support
- connectedCap uint64 // The sum of the capacity of the current clientpool connected
- priorityConnected uint64 // The sum of the capacity of currently connected priority clients
- freeClientCap uint64 // The capacity value of each free client
- startTime mclock.AbsTime // The timestamp at which the clientpool started running
- cumulativeTime int64 // The cumulative running time of clientpool at the start point.
- disableBias bool // Disable connection bias(used in testing)
+ activeLimit int // The maximum number of connections that clientpool can support
+ capLimit uint64 // The maximum cumulative capacity that clientpool can support
+ activeCap uint64 // The sum of the capacity of the current clientpool connected
+ priorityActive uint64 // The sum of the capacity of currently connected priority clients
+ minCap uint64 // The minimal capacity value allowed for any client
+ freeClientCap uint64 // The capacity value of each free client
+ disableBias bool // Disable connection bias(used in testing)
+
+ // fields in this group are protected by expLock
+ expLock sync.RWMutex
+ posExpTC, negExpTC uint64
+ posExp, negExp fixed64
+ posExpTCi, negExpTCi float64 // already inverted (logMultiplier/time)
+ freeRatioLastUpdate mclock.AbsTime
}
// clientPoolPeer represents a client peer in the pool.
@@ -106,87 +111,146 @@ type clientPoolPeer interface {
ID() enode.ID
freeClientId() string
updateCapacity(uint64)
- freezeClient()
+ freeze()
}
-// clientInfo represents a connected client
+// clientInfo defines all information required by clientpool.
type clientInfo struct {
- address string
- id enode.ID
- connectedAt mclock.AbsTime
- capacity uint64
- priority bool
- pool *clientPool
- peer clientPoolPeer
- queueIndex int // position in connectedQueue
- balanceTracker balanceTracker
- posFactors, negFactors priceFactors
- balanceMetaInfo string
-}
-
-// connSetIndex callback updates clientInfo item index in connectedQueue
+ id enode.ID
+ address string
+ active bool
+ capacity uint64
+ priority bool
+ pool *clientPool
+ peer clientPoolPeer
+ connectedAt mclock.AbsTime
+ queueIndex int
+ balanceTracker balanceTracker
+ posFactors priceFactors
+ negFactors priceFactors
+}
+
+// connSetIndex callback updates clientInfo item index in activeQueue
func connSetIndex(a interface{}, index int) {
a.(*clientInfo).queueIndex = index
}
-// connPriority callback returns actual priority of clientInfo item in connectedQueue
+// connPriority callback returns actual priority of clientInfo item in activeQueue
func connPriority(a interface{}, now mclock.AbsTime) int64 {
c := a.(*clientInfo)
return c.balanceTracker.getPriority(now)
}
-// connMaxPriority callback returns estimated maximum priority of clientInfo item in connectedQueue
+// connMaxPriority callback returns estimated maximum priority of clientInfo item in activeQueue
func connMaxPriority(a interface{}, until mclock.AbsTime) int64 {
c := a.(*clientInfo)
pri := c.balanceTracker.estimatedPriority(until, true)
c.balanceTracker.addCallback(balanceCallbackQueue, pri+1, func() {
c.pool.lock.Lock()
- if c.queueIndex != -1 {
- c.pool.connectedQueue.Update(c.queueIndex)
+ if c.active && c.queueIndex != -1 {
+ c.pool.activeQueue.Update(c.queueIndex)
}
c.pool.lock.Unlock()
})
return pri
}
-// priceFactors determine the pricing policy (may apply either to positive or
-// negative balances which may have different factors).
-// - timeFactor is cost unit per nanosecond of connection time
-// - capacityFactor is cost unit per nanosecond of connection time per 1000000 capacity
-// - requestFactor is cost unit per request "realCost" unit
-type priceFactors struct {
- timeFactor, capacityFactor, requestFactor float64
-}
-
// newClientPool creates a new client pool
-func newClientPool(db ethdb.Database, freeClientCap uint64, clock mclock.Clock, removePeer func(enode.ID)) *clientPool {
+func newClientPool(db ethdb.Database, minCap, freeClientCap uint64, clock mclock.Clock, removePeer func(enode.ID)) *clientPool {
ndb := newNodeDB(db, clock)
+ posExp, negExp := ndb.getExpiration()
pool := &clientPool{
- ndb: ndb,
- clock: clock,
- connectedMap: make(map[enode.ID]*clientInfo),
- connectedQueue: prque.NewLazyQueue(connSetIndex, connPriority, connMaxPriority, clock, lazyQueueRefresh),
- freeClientCap: freeClientCap,
- removePeer: removePeer,
- startTime: clock.Now(),
- cumulativeTime: ndb.getCumulativeTime(),
- stopCh: make(chan struct{}),
- }
- // If the negative balance of free client is even lower than 1,
- // delete this entry.
- ndb.nbEvictCallBack = func(now mclock.AbsTime, b negBalance) bool {
- balance := math.Exp(float64(b.logValue-pool.logOffset(now)) / fixedPointMultiplier)
- return balance <= 1
+ ndb: ndb,
+ clock: clock,
+ connectedMap: make(map[enode.ID]*clientInfo),
+ activeQueue: prque.NewLazyQueue(connSetIndex, connPriority, connMaxPriority, clock, lazyQueueRefresh),
+ inactiveQueue: prque.New(connSetIndex),
+ dropInactivePeers: make(map[uint64][]*clientInfo),
+ minCap: minCap,
+ freeClientCap: freeClientCap,
+ removePeer: removePeer,
+ freeRatioLastUpdate: clock.Now(),
+ posExp: posExp,
+ negExp: negExp,
+ freeRatio: 1,
+ averageFreeRatio: 1,
+ stopCh: make(chan struct{}),
+ }
+ // set default expiration constants used by tests
+ // Note: server overwrites this if token sale is active
+ pool.setExpirationTCs(0, defaultNegExpTC)
+ // calculate total token balance amount
+ var start enode.ID
+ for {
+ ids := pool.ndb.getPosBalanceIDs(start, enode.ID{}, 1000)
+ var stop bool
+ l := len(ids)
+ if l == 1000 {
+ l--
+ start = ids[l]
+ } else {
+ stop = true
+ }
+ for i := 0; i < l; i++ {
+ pool.inactiveBalances.addExp(pool.ndb.getOrNewBalance(ids[i].Bytes(), false).value)
+ }
+ if stop {
+ break
+ }
+ }
+ // The positive and negative balances of clients are stored in database
+ // and both of these decay exponentially over time. Delete them if the
+ // value is small enough.
+ ndb.evictCallBack = func(now mclock.AbsTime, neg bool, b tokenBalance) bool {
+ var expiration fixed64
+ if neg {
+ expiration = pool.negExpiration(now)
+ } else {
+ expiration = pool.posExpiration(now)
+ }
+ return b.value.value(expiration) <= uint64(time.Second)
}
go func() {
for {
select {
case <-clock.After(lazyQueueRefresh):
pool.lock.Lock()
- pool.connectedQueue.Refresh()
+ pool.activeQueue.Refresh()
+ pool.lock.Unlock()
+ case <-pool.stopCh:
+ return
+ }
+ }
+ }()
+ go func() {
+ for {
+ select {
+ case <-clock.After(persistExpirationRefresh):
+ pool.lock.Lock()
+ now := pool.clock.Now()
+ posExp := pool.posExpiration(now)
+ negExp := pool.negExpiration(now)
+ pool.lock.Unlock()
+ pool.ndb.setExpiration(posExp, negExp)
+ case <-pool.stopCh:
+ return
+ }
+ }
+ }()
+ go func() {
+ for {
+ select {
+ case <-clock.After(tryActivatePeriod):
+ pool.lock.Lock()
+ pool.tryActivateClients()
+ for _, c := range pool.dropInactivePeers[pool.dropInactiveCounter] {
+ if _, ok := pool.connectedMap[c.id]; ok && !c.active && !c.priority {
+ pool.drop(c.peer, true)
+ }
+ }
+ delete(pool.dropInactivePeers, pool.dropInactiveCounter)
+ pool.dropInactiveCounter++
pool.lock.Unlock()
- case <-clock.After(persistCumulativeTimeRefresh):
- pool.ndb.setCumulativeTime(pool.logOffset(clock.Now()))
case <-pool.stopCh:
return
}
@@ -201,122 +265,212 @@ func (f *clientPool) stop() {
f.lock.Lock()
f.closed = true
f.lock.Unlock()
- f.ndb.setCumulativeTime(f.logOffset(f.clock.Now()))
+ now := f.clock.Now()
+ f.ndb.setExpiration(f.posExpiration(now), f.negExpiration(now))
f.ndb.close()
}
+// updateFreeRatio updates freeRatio, averageFreeRatio, posExp and negExp based
+// on free service availability. Should be called after capLimit or priorityActive
+// is changed.
+func (f *clientPool) updateFreeRatio() {
+ f.freeRatio = 0
+ if f.priorityActive < f.capLimit {
+ freeCap := f.capLimit - f.priorityActive
+ if freeCap > f.freeClientCap {
+ freeCapThreshold := f.capLimit / 4
+ if freeCap > freeCapThreshold {
+ f.freeRatio = 1
+ } else {
+ f.freeRatio = float64(freeCap-f.freeClientCap) / float64(freeCapThreshold-f.freeClientCap)
+ }
+ }
+ }
+ f.expLock.Lock()
+ now := f.clock.Now()
+ dt := now - f.freeRatioLastUpdate
+ if dt < 0 {
+ dt = 0
+ }
+ f.averageFreeRatio -= (f.freeRatio - f.averageFreeRatio) * math.Expm1(-float64(dt)/float64(freeRatioTC))
+ f.freeRatioLastUpdate = now
+
+ f.posExp += fixed64(float64(dt) * f.posExpTCi * f.freeRatio)
+ f.negExp += fixed64(float64(dt) * f.negExpTCi * f.freeRatio)
+ f.expLock.Unlock()
+}
+
+// setExpirationTCs sets positive and negative token expiration time constants.
+// Specified in seconds, 0 means infinite (no expiration).
+func (f *clientPool) setExpirationTCs(pos, neg uint64) {
+ f.lock.Lock()
+ f.updateFreeRatio()
+ f.lock.Unlock()
+
+ f.expLock.Lock()
+ f.posExpTC, f.negExpTC = pos, neg
+ if pos > 0 {
+ f.posExpTCi = fixedFactor / float64(pos*uint64(time.Second))
+ } else {
+ f.posExpTCi = 0
+ }
+ if neg > 0 {
+ f.negExpTCi = fixedFactor / float64(neg*uint64(time.Second))
+ } else {
+ f.negExpTCi = 0
+ }
+ f.expLock.Unlock()
+}
+
+// getExpirationTCs returns the current positive and negative token expiration
+// time constants
+func (f *clientPool) getExpirationTCs() (pos, neg uint64) {
+ f.expLock.Lock()
+ defer f.expLock.Unlock()
+
+ return f.posExpTC, f.negExpTC
+}
+
+// posExpiration implements expirationController. Expiration happens only when
+// free service is available.
+func (f *clientPool) posExpiration(now mclock.AbsTime) fixed64 {
+ f.expLock.RLock()
+ defer f.expLock.RUnlock()
+
+ if f.posExpTC == 0 {
+ return f.posExp
+ }
+ dt := now - f.freeRatioLastUpdate
+ if dt < 0 {
+ dt = 0
+ }
+ return f.posExp + fixed64(float64(dt)*f.posExpTCi*f.freeRatio)
+}
+
+// negExpiration implements expirationController. Expiration happens only when
+// free service is available.
+func (f *clientPool) negExpiration(now mclock.AbsTime) fixed64 {
+ f.expLock.RLock()
+ defer f.expLock.RUnlock()
+
+ if f.negExpTC == 0 {
+ return f.negExp
+ }
+ dt := now - f.freeRatioLastUpdate
+ if dt < 0 {
+ dt = 0
+ }
+ return f.negExp + fixed64(float64(dt)*f.negExpTCi*f.freeRatio)
+}
+
+// totalTokenLimit returns the current token supply limit. Token prices are based
+// on the ratio of total token amount and supply limit while the limit depends on
+// averageFreeRatio, ensuring the availability of free service most of the time.
+func (f *clientPool) totalTokenLimit() uint64 {
+ f.lock.Lock()
+ defer f.lock.Unlock()
+
+ f.updateFreeRatio()
+ d := f.averageFreeRatio
+ if d > 0.5 {
+ d = -math.Log(0.5/d) * float64(freeRatioTC)
+ } else {
+ d = 0
+ }
+ return uint64(d * float64(f.capLimit) * f.defaultPosFactors.capacityFactor)
+}
+
+// totalTokenAmount returns the total amount of currently existing service tokens
+func (f *clientPool) totalTokenAmount() uint64 {
+ f.lock.Lock()
+ defer f.lock.Unlock()
+
+ now := f.clock.Now()
+ if now > f.lastConnectedBalanceUpdate+mclock.AbsTime(time.Second) {
+ f.activeBalances = expiredValue{}
+ for _, c := range f.connectedMap {
+ pos, _ := c.balanceTracker.getBalance(now)
+ f.activeBalances.addExp(pos)
+ }
+ f.lastConnectedBalanceUpdate = now
+ }
+ sum := f.activeBalances
+ sum.addExp(f.inactiveBalances)
+ return sum.value(f.posExpiration(now))
+}
+
// connect should be called after a successful handshake. If the connection was
// rejected, there is no need to call disconnect.
-func (f *clientPool) connect(peer clientPoolPeer, capacity uint64) bool {
+func (f *clientPool) connect(peer clientPoolPeer, reqCapacity uint64) (uint64, error) {
f.lock.Lock()
defer f.lock.Unlock()
// Short circuit if clientPool is already closed.
if f.closed {
- return false
+ return 0, fmt.Errorf("Client pool is already closed")
}
// Dedup connected peers.
id, freeID := peer.ID(), peer.freeClientId()
if _, ok := f.connectedMap[id]; ok {
clientRejectedMeter.Mark(1)
log.Debug("Client already connected", "address", freeID, "id", peerIdToString(id))
- return false
- }
- // Create a clientInfo but do not add it yet
- var (
- posBalance uint64
- negBalance uint64
- now = f.clock.Now()
- )
- pb := f.ndb.getOrNewPB(id)
- posBalance = pb.value
-
- nb := f.ndb.getOrNewNB(freeID)
- if nb.logValue != 0 {
- negBalance = uint64(math.Exp(float64(nb.logValue-f.logOffset(now))/fixedPointMultiplier) * float64(time.Second))
+ return 0, fmt.Errorf("Client already connected address=%s id=%s", freeID, peerIdToString(id))
}
+ pb := f.ndb.getOrNewBalance(id.Bytes(), false)
+ nb := f.ndb.getOrNewBalance([]byte(freeID), true)
e := &clientInfo{
- pool: f,
- peer: peer,
- address: freeID,
- queueIndex: -1,
- id: id,
- connectedAt: now,
- priority: posBalance != 0,
- posFactors: f.defaultPosFactors,
- negFactors: f.defaultNegFactors,
- balanceMetaInfo: pb.meta,
- }
- // If the client is a free client, assign with a low free capacity,
- // Otherwise assign with the given value(priority client)
- if !e.priority || capacity == 0 {
- capacity = f.freeClientCap
- }
+ id: id,
+ address: freeID,
+ capacity: reqCapacity,
+ pool: f,
+ peer: peer,
+ queueIndex: -1,
+ connectedAt: f.clock.Now(),
+ priority: pb.value.base != 0,
+ posFactors: f.defaultPosFactors,
+ negFactors: f.defaultNegFactors,
+ }
+ missing, capacity := f.capAvailable(id, freeID, reqCapacity, 0, true)
+ f.connectedMap[id] = e
+ if missing != 0 {
+ // capacity is not available, add client to inactive queue
+ f.initBalanceTracker(&e.balanceTracker, pb, nb, capacity, false)
+ f.inactiveQueue.Push(e, -connPriority(e, f.clock.Now()))
+ return 0, nil
+ }
+ // capacity is available, add client
+ e.active = true
e.capacity = capacity
-
- // Starts a balance tracker
- e.balanceTracker.init(f.clock, capacity)
- e.balanceTracker.setBalance(posBalance, negBalance)
- e.updatePriceFactors()
-
- // If the number of clients already connected in the clientpool exceeds its
- // capacity, evict some clients with lowest priority.
- //
- // If the priority of the newly added client is lower than the priority of
- // all connected clients, the client is rejected.
- newCapacity := f.connectedCap + capacity
- newCount := f.connectedQueue.Size() + 1
- if newCapacity > f.capLimit || newCount > f.connLimit {
- var (
- kickList []*clientInfo
- kickPriority int64
- )
- f.connectedQueue.MultiPop(func(data interface{}, priority int64) bool {
- c := data.(*clientInfo)
- kickList = append(kickList, c)
- kickPriority = priority
- newCapacity -= c.capacity
- newCount--
- return newCapacity > f.capLimit || newCount > f.connLimit
- })
- bias := connectedBias
- if f.disableBias {
- bias = 0
- }
- if newCapacity > f.capLimit || newCount > f.connLimit || (e.balanceTracker.estimatedPriority(now+mclock.AbsTime(bias), false)-kickPriority) > 0 {
- for _, c := range kickList {
- f.connectedQueue.Push(c)
- }
- clientRejectedMeter.Mark(1)
- log.Debug("Client rejected", "address", freeID, "id", peerIdToString(id))
- return false
- }
- // accept new client, drop old ones
- for _, c := range kickList {
- f.dropClient(c, now, true)
- }
- }
-
+ f.initBalanceTracker(&e.balanceTracker, pb, nb, capacity, true)
// Register new client to connection queue.
- f.connectedMap[id] = e
- f.connectedQueue.Push(e)
- f.connectedCap += e.capacity
+ f.inactiveBalances.subExp(pb.value)
+ f.activeBalances.addExp(pb.value)
+ f.activeQueue.Push(e)
+ f.activeCap += e.capacity
// If the current client is a paid client, monitor the status of client,
// downgrade it to normal client if positive balance is used up.
if e.priority {
- f.priorityConnected += capacity
+ f.priorityActive += capacity
+ f.updateFreeRatio()
e.balanceTracker.addCallback(balanceCallbackZero, 0, func() { f.balanceExhausted(id) })
}
- // If the capacity of client is not the default value(free capacity), notify
- // it to update capacity.
- if e.capacity != f.freeClientCap {
- e.peer.updateCapacity(e.capacity)
- }
- totalConnectedGauge.Update(int64(f.connectedCap))
+ totalConnectedGauge.Update(int64(f.activeCap))
clientConnectedMeter.Mark(1)
log.Debug("Client accepted", "address", freeID)
- return true
+ return e.capacity, nil
+}
+
+// initBalanceTracker initializes the positive and negative balances and price factors
+func (f *clientPool) initBalanceTracker(bt *balanceTracker, pb tokenBalance, nb tokenBalance, capacity uint64, active bool) {
+ bt.exp = f
+ bt.init(f.clock, capacity)
+ bt.setBalance(pb.value, nb.value)
+ if active {
+ updatePriceFactors(bt, f.defaultPosFactors, f.defaultNegFactors)
+ } else {
+ zeroPriceFactors(bt)
+ }
}
// disconnect should be called when a connection is terminated. If the disconnection
@@ -326,17 +480,106 @@ func (f *clientPool) disconnect(p clientPoolPeer) {
f.lock.Lock()
defer f.lock.Unlock()
+ f.drop(p, false)
+}
+
+// drop deactivates the peer if necessary and drops it from the inactive queue
+func (f *clientPool) drop(p clientPoolPeer, kicked bool) {
// Short circuit if client pool is already closed.
if f.closed {
return
}
- // Short circuit if the peer hasn't been registered.
- e := f.connectedMap[p.ID()]
- if e == nil {
+ e, ok := f.connectedMap[p.ID()]
+ if !ok {
log.Debug("Client not connected", "address", p.freeClientId(), "id", peerIdToString(p.ID()))
return
}
- f.dropClient(e, f.clock.Now(), false)
+ tryActivate := e.active
+ if e.active {
+ f.deactivateClient(e, false)
+ }
+ f.finalizeBalance(e, f.clock.Now())
+ f.inactiveQueue.Remove(e.queueIndex)
+ delete(f.connectedMap, e.id)
+ if kicked {
+ clientKickedMeter.Mark(1)
+ log.Debug("Client kicked out", "address", e.address)
+ } else {
+ clientDisconnectedMeter.Mark(1)
+ log.Debug("Client disconnected", "address", e.address)
+ }
+ if tryActivate {
+ f.tryActivateClients()
+ }
+}
+
+// capAvailable checks whether the current priority level of the given client is enough to
+// connect or change capacity to the requested level and then stay connected for at least
+// the specified duration. If not then the additional required amount of positive balance is returned.
+func (f *clientPool) capAvailable(id enode.ID, freeID string, capacity uint64, minConnTime time.Duration, kick bool) (uint64, uint64) {
+ var missing uint64
+ if capacity == 0 {
+ capacity = f.freeClientCap
+ }
+ if capacity < f.minCap {
+ capacity = f.minCap
+ }
+ newCapacity := f.activeCap + capacity
+ newCount := f.activeQueue.Size() + 1
+ client := f.connectedMap[id]
+ if client != nil && client.active {
+ newCapacity -= client.capacity
+ newCount--
+ }
+ if newCapacity > f.capLimit || newCount > f.activeLimit {
+ var (
+ popList []*clientInfo
+ targetPriority int64
+ )
+ f.activeQueue.MultiPop(func(data interface{}, priority int64) bool {
+ c := data.(*clientInfo)
+ popList = append(popList, c)
+ if c != client {
+ targetPriority = priority
+ newCapacity -= c.capacity
+ newCount--
+ }
+ return newCapacity > f.capLimit || newCount > f.activeLimit
+ })
+ if newCapacity > f.capLimit || newCount > f.activeLimit {
+ missing = math.MaxUint64
+ } else {
+ var bt *balanceTracker
+ if client != nil {
+ bt = &client.balanceTracker
+ } else {
+ bt = &balanceTracker{}
+ f.initBalanceTracker(bt, f.ndb.getOrNewBalance(id.Bytes(), false), f.ndb.getOrNewBalance([]byte(freeID), true), capacity, true)
+ }
+ if capacity != f.freeClientCap && targetPriority >= 0 {
+ targetPriority = -1
+ }
+ bias := activeBias
+ if f.disableBias {
+ bias = 0
+ }
+ if bias < minConnTime {
+ bias = minConnTime
+ }
+ missing = bt.posBalanceMissing(targetPriority, capacity, bias)
+ }
+ if missing != 0 {
+ kick = false
+ }
+ for _, c := range popList {
+ if kick && c != client {
+ f.deactivateClient(c, true)
+ } else {
+ f.activeQueue.Push(c)
+ }
+ }
+ }
+ return missing, capacity
}
// forClients iterates through a list of clients, calling the callback for each one.
@@ -371,27 +614,61 @@ func (f *clientPool) setDefaultFactors(posFactors, negFactors priceFactors) {
f.defaultNegFactors = negFactors
}
-// dropClient removes a client from the connected queue and finalizes its balance.
-// If kick is true then it also initiates the disconnection.
-func (f *clientPool) dropClient(e *clientInfo, now mclock.AbsTime, kick bool) {
- if _, ok := f.connectedMap[e.id]; !ok {
+// deactivateClient puts a client in inactive state
+func (f *clientPool) deactivateClient(e *clientInfo, scheduleDrop bool) {
+ if _, ok := f.connectedMap[e.id]; !ok || !e.active {
return
}
- f.finalizeBalance(e, now)
- f.connectedQueue.Remove(e.queueIndex)
- delete(f.connectedMap, e.id)
- f.connectedCap -= e.capacity
+ f.activeQueue.Remove(e.queueIndex)
+ f.activeCap -= e.capacity
if e.priority {
- f.priorityConnected -= e.capacity
- }
- totalConnectedGauge.Update(int64(f.connectedCap))
- if kick {
- clientKickedMeter.Mark(1)
- log.Debug("Client kicked out", "address", e.address)
- f.removePeer(e.id)
- } else {
- clientDisconnectedMeter.Mark(1)
- log.Debug("Client disconnected", "address", e.address)
+ f.priorityActive -= e.capacity
+ f.updateFreeRatio()
+ }
+ e.active = false
+ e.peer.updateCapacity(0)
+ totalConnectedGauge.Update(int64(f.activeCap))
+ f.inactiveQueue.Push(e, -connPriority(e, f.clock.Now()))
+ if scheduleDrop {
+ f.dropInactivePeers[f.dropInactiveCounter+dropInactiveCycles] = append(f.dropInactivePeers[f.dropInactiveCounter+dropInactiveCycles], e)
+ }
+}
+
+// tryActivateClients checks whether some inactive clients have enough priority now
+// and activates them if possible
+func (f *clientPool) tryActivateClients() {
+ now := f.clock.Now()
+ for f.inactiveQueue.Size() != 0 {
+ e := f.inactiveQueue.PopItem().(*clientInfo)
+ missing, capacity := f.capAvailable(e.id, e.address, e.capacity, 0, true)
+ if missing != 0 {
+ f.inactiveQueue.Push(e, -connPriority(e, now))
+ return
+ }
+ // capacity is available, activate client
+ e.active = true
+ e.capacity = capacity
+ e.peer.updateCapacity(capacity)
+ balance, _ := e.balanceTracker.getBalance(now)
+ e.balanceTracker.setCapacity(capacity)
+ updatePriceFactors(&e.balanceTracker, f.defaultPosFactors, f.defaultNegFactors)
+ // Register activated client to connection queue.
+ f.inactiveBalances.subExp(balance)
+ f.activeBalances.addExp(balance)
+ f.activeQueue.Push(e)
+ f.activeCap += e.capacity
+
+ // If the current client is a paid client, monitor the status of client,
+ // downgrade it to normal client if positive balance is used up.
+ if e.priority {
+ f.priorityActive += capacity
+ f.updateFreeRatio()
+ e.balanceTracker.addCallback(balanceCallbackZero, 0, func() { f.balanceExhausted(e.id) })
+ }
+ e.peer.updateCapacity(e.capacity)
+ totalConnectedGauge.Update(int64(f.activeCap))
+ clientConnectedMeter.Mark(1)
+ log.Debug("Client activated", "address", e.address)
}
}
@@ -401,7 +678,7 @@ func (f *clientPool) capacityInfo() (uint64, uint64, uint64) {
f.lock.Lock()
defer f.lock.Unlock()
- return f.capLimit, f.connectedCap, f.priorityConnected
+ return f.capLimit, f.activeCap, f.priorityActive
}
// finalizeBalance stops the balance tracker, retrieves the final balances and
@@ -409,17 +686,27 @@ func (f *clientPool) capacityInfo() (uint64, uint64, uint64) {
func (f *clientPool) finalizeBalance(c *clientInfo, now mclock.AbsTime) {
c.balanceTracker.stop(now)
pos, neg := c.balanceTracker.getBalance(now)
+ f.inactiveBalances.addExp(pos)
+ f.activeBalances.subExp(pos)
- pb, nb := f.ndb.getOrNewPB(c.id), f.ndb.getOrNewNB(c.address)
- pb.value = pos
- f.ndb.setPB(c.id, pb)
-
- neg /= uint64(time.Second) // Convert the expanse to second level.
- if neg > 1 {
- nb.logValue = int64(math.Log(float64(neg))*fixedPointMultiplier) + f.logOffset(now)
- f.ndb.setNB(c.address, nb)
- } else {
- f.ndb.delNB(c.address) // Negative balance is small enough, drop it directly.
+ for index, value := range []expiredValue{pos, neg} {
+ var (
+ id []byte
+ expiration fixed64
+ )
+ neg := index == 1
+ if !neg {
+ id = c.id.Bytes()
+ expiration = f.posExpiration(f.clock.Now())
+ } else {
+ id = []byte(c.address)
+ expiration = f.negExpiration(f.clock.Now())
+ }
+ if value.value(expiration) > uint64(time.Second) {
+ f.ndb.setBalance(id, neg, tokenBalance{value: value})
+ } else {
+ f.ndb.delBalance(id, neg) // balance is small enough, drop it directly.
+ }
}
}
@@ -434,429 +721,165 @@ func (f *clientPool) balanceExhausted(id enode.ID) {
return
}
if c.priority {
- f.priorityConnected -= c.capacity
+ f.priorityActive -= c.capacity
+ f.updateFreeRatio()
}
c.priority = false
if c.capacity != f.freeClientCap {
- f.connectedCap += f.freeClientCap - c.capacity
- totalConnectedGauge.Update(int64(f.connectedCap))
+ f.activeCap += f.freeClientCap - c.capacity
+ totalConnectedGauge.Update(int64(f.activeCap))
c.capacity = f.freeClientCap
c.balanceTracker.setCapacity(c.capacity)
c.peer.updateCapacity(c.capacity)
}
- pb := f.ndb.getOrNewPB(id)
- pb.value = 0
- f.ndb.setPB(id, pb)
+ f.ndb.delBalance(id.Bytes(), false)
}
-// setConnLimit sets the maximum number and total capacity of connected clients,
+// setactiveLimit sets the maximum number and total capacity of connected clients,
// dropping some of them if necessary.
func (f *clientPool) setLimits(totalConn int, totalCap uint64) {
f.lock.Lock()
defer f.lock.Unlock()
- f.connLimit = totalConn
+ f.activeLimit = totalConn
f.capLimit = totalCap
- if f.connectedCap > f.capLimit || f.connectedQueue.Size() > f.connLimit {
- f.connectedQueue.MultiPop(func(data interface{}, priority int64) bool {
- f.dropClient(data.(*clientInfo), mclock.Now(), true)
- return f.connectedCap > f.capLimit || f.connectedQueue.Size() > f.connLimit
+ if f.activeCap > f.capLimit || f.activeQueue.Size() > f.activeLimit {
+ f.activeQueue.MultiPop(func(data interface{}, priority int64) bool {
+ f.deactivateClient(data.(*clientInfo), true)
+ return f.activeCap > f.capLimit || f.activeQueue.Size() > f.activeLimit
})
+ } else {
+ f.tryActivateClients()
}
+ f.updateFreeRatio()
}
// setCapacity sets the assigned capacity of a connected client
-func (f *clientPool) setCapacity(c *clientInfo, capacity uint64) error {
- if f.connectedMap[c.id] != c {
- return fmt.Errorf("client %064x is not connected", c.id[:])
- }
- if c.capacity == capacity {
- return nil
- }
- if !c.priority {
- return errNoPriority
- }
- oldCapacity := c.capacity
- c.capacity = capacity
- f.connectedCap += capacity - oldCapacity
- c.balanceTracker.setCapacity(capacity)
- f.connectedQueue.Update(c.queueIndex)
- if f.connectedCap > f.capLimit {
- var kickList []*clientInfo
- kick := true
- f.connectedQueue.MultiPop(func(data interface{}, priority int64) bool {
- client := data.(*clientInfo)
- kickList = append(kickList, client)
- f.connectedCap -= client.capacity
- if client == c {
- kick = false
- }
- return kick && (f.connectedCap > f.capLimit)
- })
- if kick {
- now := mclock.Now()
- for _, c := range kickList {
- f.dropClient(c, now, true)
- }
- } else {
- c.capacity = oldCapacity
- c.balanceTracker.setCapacity(oldCapacity)
- for _, c := range kickList {
- f.connectedCap += c.capacity
- f.connectedQueue.Push(c)
- }
- return errNoPriority
+func (f *clientPool) setCapacity(id enode.ID, freeID string, capacity uint64, minConnTime time.Duration, setCap bool) (uint64, uint64, error) {
+ c := f.connectedMap[id]
+ if c != nil {
+ if c.capacity == capacity {
+ return 0, capacity, nil
}
}
- totalConnectedGauge.Update(int64(f.connectedCap))
- f.priorityConnected += capacity - oldCapacity
- c.updatePriceFactors()
- c.peer.updateCapacity(c.capacity)
- return nil
+ var missing uint64
+ missing, capacity = f.capAvailable(id, freeID, capacity, 0, setCap && c != nil)
+ if missing != 0 {
+ return missing, capacity, errNoPriority
+ }
+ // capacity update is possible
+ if setCap {
+ if c == nil {
+ return 0, capacity, fmt.Errorf("client %064x is not connected", c.id[:])
+ }
+ f.activeCap += capacity - c.capacity
+ f.priorityActive += capacity - c.capacity
+ f.updateFreeRatio()
+ c.capacity = capacity
+ c.balanceTracker.setCapacity(capacity)
+ f.activeQueue.Update(c.queueIndex)
+ totalConnectedGauge.Update(int64(f.activeCap))
+ updatePriceFactors(&c.balanceTracker, c.posFactors, c.negFactors)
+ c.peer.updateCapacity(c.capacity)
+ f.tryActivateClients()
+ }
+ return 0, capacity, nil
}
-// requestCost feeds request cost after serving a request from the given peer.
-func (f *clientPool) requestCost(p *clientPeer, cost uint64) {
+// setCapacityLocked is the equivalent of setCapacity used when f.lock is already locked
+func (f *clientPool) setCapacityLocked(id enode.ID, freeID string, capacity uint64, minConnTime time.Duration, setCap bool) (uint64, uint64, error) {
f.lock.Lock()
defer f.lock.Unlock()
- info, exist := f.connectedMap[p.ID()]
- if !exist || f.closed {
- return
+ return f.setCapacity(id, freeID, capacity, minConnTime, setCap)
+}
+
+// requestCost feeds request cost after serving a request from the given peer and
+// returns the remaining token balance
+func (f *clientPool) requestCost(p *clientPeer, cost uint64) uint64 {
+ f.lock.Lock()
+ defer f.lock.Unlock()
+
+ c := f.connectedMap[p.ID()]
+ if c == nil || f.closed {
+ return 0
}
- info.balanceTracker.requestCost(cost)
+ return c.balanceTracker.requestCost(cost)
}
-// logOffset calculates the time-dependent offset for the logarithmic
-// representation of negative balance
-//
-// From another point of view, the result returned by the function represents
-// the total time that the clientpool is cumulatively running(total_hours/multiplier).
-func (f *clientPool) logOffset(now mclock.AbsTime) int64 {
- // Note: fixedPointMultiplier acts as a multiplier here; the reason for dividing the divisor
- // is to avoid int64 overflow. We assume that int64(negBalanceExpTC) >> fixedPointMultiplier.
- cumulativeTime := int64((time.Duration(now - f.startTime)) / (negBalanceExpTC / fixedPointMultiplier))
- return f.cumulativeTime + cumulativeTime
+// updatePriceFactors sets the pricing factors for an individual connected client
+func updatePriceFactors(bt *balanceTracker, posFactors, negFactors priceFactors) {
+ bt.setFactors(posFactors, negFactors)
}
-// setClientPriceFactors sets the pricing factors for an individual connected client
-func (c *clientInfo) updatePriceFactors() {
- c.balanceTracker.setFactors(true, c.negFactors.timeFactor+float64(c.capacity)*c.negFactors.capacityFactor/1000000, c.negFactors.requestFactor)
- c.balanceTracker.setFactors(false, c.posFactors.timeFactor+float64(c.capacity)*c.posFactors.capacityFactor/1000000, c.posFactors.requestFactor)
+// zeroPriceFactors sets the pricing factors to zero
+func zeroPriceFactors(bt *balanceTracker) {
+ bt.setFactors(priceFactors{0, 0, 0}, priceFactors{0, 0, 0})
}
// getPosBalance retrieves a single positive balance entry from cache or the database
-func (f *clientPool) getPosBalance(id enode.ID) posBalance {
+func (f *clientPool) getPosBalance(id enode.ID) tokenBalance {
f.lock.Lock()
defer f.lock.Unlock()
- return f.ndb.getOrNewPB(id)
+ if c := f.connectedMap[id]; c != nil {
+ value, _ := c.balanceTracker.getBalance(f.clock.Now())
+ return tokenBalance{value: value}
+ } else {
+ return f.ndb.getOrNewBalance(id.Bytes(), false)
+ }
}
// addBalance updates the balance of a client (either overwrites it or adds to it).
// It also updates the balance meta info string.
-func (f *clientPool) addBalance(id enode.ID, amount int64, meta string) (uint64, uint64, error) {
+func (f *clientPool) addBalance(id enode.ID, amount int64) (uint64, uint64, error) {
f.lock.Lock()
defer f.lock.Unlock()
- pb := f.ndb.getOrNewPB(id)
- var negBalance uint64
+ now := f.clock.Now()
+ pb := f.ndb.getOrNewBalance(id.Bytes(), false)
+ var negBalance expiredValue
c := f.connectedMap[id]
if c != nil {
- pb.value, negBalance = c.balanceTracker.getBalance(f.clock.Now())
+ pb.value, negBalance = c.balanceTracker.getBalance(now)
}
oldBalance := pb.value
- if amount > 0 {
- if amount > maxBalance || pb.value > maxBalance-uint64(amount) {
- return oldBalance, oldBalance, errBalanceOverflow
- }
- pb.value += uint64(amount)
- } else {
- if uint64(-amount) > pb.value {
- pb.value = 0
- } else {
- pb.value -= uint64(-amount)
- }
+ posExp := f.posExpiration(now)
+ oldValue := oldBalance.value(posExp)
+ if amount > 0 && (amount > maxBalance || oldValue > maxBalance-uint64(amount)) {
+ return oldValue, oldValue, errBalanceOverflow
}
- pb.meta = meta
- f.ndb.setPB(id, pb)
+ pb.value.add(amount, posExp)
+ f.ndb.setBalance(id.Bytes(), false, pb)
if c != nil {
c.balanceTracker.setBalance(pb.value, negBalance)
- if !c.priority && pb.value > 0 {
- // The capacity should be adjusted based on the requirement,
- // but we have no idea about the new capacity, need a second
- // call to udpate it.
- c.priority = true
- f.priorityConnected += c.capacity
- c.balanceTracker.addCallback(balanceCallbackZero, 0, func() { f.balanceExhausted(id) })
- }
- // if balance is set to zero then reverting to non-priority status
- // is handled by the balanceExhausted callback
- c.balanceMetaInfo = meta
- }
- return oldBalance, pb.value, nil
-}
-
-// posBalance represents a recently accessed positive balance entry
-type posBalance struct {
- value uint64
- meta string
-}
-
-// EncodeRLP implements rlp.Encoder
-func (e *posBalance) EncodeRLP(w io.Writer) error {
- return rlp.Encode(w, []interface{}{e.value, e.meta})
-}
-
-// DecodeRLP implements rlp.Decoder
-func (e *posBalance) DecodeRLP(s *rlp.Stream) error {
- var entry struct {
- Value uint64
- Meta string
- }
- if err := s.Decode(&entry); err != nil {
- return err
- }
- e.value = entry.Value
- e.meta = entry.Meta
- return nil
-}
-
-// negBalance represents a negative balance entry of a disconnected client
-type negBalance struct{ logValue int64 }
-
-// EncodeRLP implements rlp.Encoder
-func (e *negBalance) EncodeRLP(w io.Writer) error {
- return rlp.Encode(w, []interface{}{uint64(e.logValue)})
-}
-
-// DecodeRLP implements rlp.Decoder
-func (e *negBalance) DecodeRLP(s *rlp.Stream) error {
- var entry struct {
- LogValue uint64
- }
- if err := s.Decode(&entry); err != nil {
- return err
- }
- e.logValue = int64(entry.LogValue)
- return nil
-}
-
-const (
- // nodeDBVersion is the version identifier of the node data in db
- //
- // Changelog:
- // * Replace `lastTotal` with `meta` in positive balance: version 0=>1
- nodeDBVersion = 1
-
- // dbCleanupCycle is the cycle of db for useless data cleanup
- dbCleanupCycle = time.Hour
-)
-
-var (
- positiveBalancePrefix = []byte("pb:") // dbVersion(uint16 big endian) + positiveBalancePrefix + id -> balance
- negativeBalancePrefix = []byte("nb:") // dbVersion(uint16 big endian) + negativeBalancePrefix + ip -> balance
- cumulativeRunningTimeKey = []byte("cumulativeTime:") // dbVersion(uint16 big endian) + cumulativeRunningTimeKey -> cumulativeTime
-)
-
-type nodeDB struct {
- db ethdb.Database
- pcache *lru.Cache
- ncache *lru.Cache
- auxbuf []byte // 37-byte auxiliary buffer for key encoding
- verbuf [2]byte // 2-byte auxiliary buffer for db version
- nbEvictCallBack func(mclock.AbsTime, negBalance) bool // Callback to determine whether the negative balance can be evicted.
- clock mclock.Clock
- closeCh chan struct{}
- cleanupHook func() // Test hook used for testing
-}
-
-func newNodeDB(db ethdb.Database, clock mclock.Clock) *nodeDB {
- pcache, _ := lru.New(posBalanceCacheLimit)
- ncache, _ := lru.New(negBalanceCacheLimit)
- ndb := &nodeDB{
- db: db,
- pcache: pcache,
- ncache: ncache,
- auxbuf: make([]byte, 37),
- clock: clock,
- closeCh: make(chan struct{}),
- }
- binary.BigEndian.PutUint16(ndb.verbuf[:], uint16(nodeDBVersion))
- go ndb.expirer()
- return ndb
-}
-
-func (db *nodeDB) close() {
- close(db.closeCh)
-}
-
-func (db *nodeDB) key(id []byte, neg bool) []byte {
- prefix := positiveBalancePrefix
- if neg {
- prefix = negativeBalancePrefix
- }
- if len(prefix)+len(db.verbuf)+len(id) > len(db.auxbuf) {
- db.auxbuf = append(db.auxbuf, make([]byte, len(prefix)+len(db.verbuf)+len(id)-len(db.auxbuf))...)
- }
- copy(db.auxbuf[:len(db.verbuf)], db.verbuf[:])
- copy(db.auxbuf[len(db.verbuf):len(db.verbuf)+len(prefix)], prefix)
- copy(db.auxbuf[len(prefix)+len(db.verbuf):len(prefix)+len(db.verbuf)+len(id)], id)
- return db.auxbuf[:len(prefix)+len(db.verbuf)+len(id)]
-}
-
-func (db *nodeDB) getCumulativeTime() int64 {
- blob, err := db.db.Get(append(cumulativeRunningTimeKey, db.verbuf[:]...))
- if err != nil || len(blob) == 0 {
- return 0
- }
- return int64(binary.BigEndian.Uint64(blob))
-}
-
-func (db *nodeDB) setCumulativeTime(v int64) {
- binary.BigEndian.PutUint64(db.auxbuf[:8], uint64(v))
- db.db.Put(append(cumulativeRunningTimeKey, db.verbuf[:]...), db.auxbuf[:8])
-}
-
-func (db *nodeDB) getOrNewPB(id enode.ID) posBalance {
- key := db.key(id.Bytes(), false)
- item, exist := db.pcache.Get(string(key))
- if exist {
- return item.(posBalance)
- }
- var balance posBalance
- if enc, err := db.db.Get(key); err == nil {
- if err := rlp.DecodeBytes(enc, &balance); err != nil {
- log.Error("Failed to decode positive balance", "err", err)
- }
- }
- db.pcache.Add(string(key), balance)
- return balance
-}
-
-func (db *nodeDB) setPB(id enode.ID, b posBalance) {
- if b.value == 0 && len(b.meta) == 0 {
- db.delPB(id)
- return
- }
- key := db.key(id.Bytes(), false)
- enc, err := rlp.EncodeToBytes(&(b))
- if err != nil {
- log.Error("Failed to encode positive balance", "err", err)
- return
- }
- db.db.Put(key, enc)
- db.pcache.Add(string(key), b)
-}
-
-func (db *nodeDB) delPB(id enode.ID) {
- key := db.key(id.Bytes(), false)
- db.db.Delete(key)
- db.pcache.Remove(string(key))
-}
-
-// getPosBalanceIDs returns a lexicographically ordered list of IDs of accounts
-// with a positive balance
-func (db *nodeDB) getPosBalanceIDs(start, stop enode.ID, maxCount int) (result []enode.ID) {
- if maxCount <= 0 {
- return
- }
- it := db.db.NewIteratorWithStart(db.key(start.Bytes(), false))
- defer it.Release()
- for i := len(stop[:]) - 1; i >= 0; i-- {
- stop[i]--
- if stop[i] != 255 {
- break
- }
- }
- stopKey := db.key(stop.Bytes(), false)
- keyLen := len(stopKey)
-
- for it.Next() {
- var id enode.ID
- if len(it.Key()) != keyLen || bytes.Compare(it.Key(), stopKey) == 1 {
- return
- }
- copy(id[:], it.Key()[keyLen-len(id):])
- result = append(result, id)
- if len(result) == maxCount {
- return
- }
- }
- return
-}
-
-func (db *nodeDB) getOrNewNB(id string) negBalance {
- key := db.key([]byte(id), true)
- item, exist := db.ncache.Get(string(key))
- if exist {
- return item.(negBalance)
- }
- var balance negBalance
- if enc, err := db.db.Get(key); err == nil {
- if err := rlp.DecodeBytes(enc, &balance); err != nil {
- log.Error("Failed to decode negative balance", "err", err)
- }
- }
- db.ncache.Add(string(key), balance)
- return balance
-}
-
-func (db *nodeDB) setNB(id string, b negBalance) {
- key := db.key([]byte(id), true)
- enc, err := rlp.EncodeToBytes(&(b))
- if err != nil {
- log.Error("Failed to encode negative balance", "err", err)
- return
- }
- db.db.Put(key, enc)
- db.ncache.Add(string(key), b)
-}
-
-func (db *nodeDB) delNB(id string) {
- key := db.key([]byte(id), true)
- db.db.Delete(key)
- db.ncache.Remove(string(key))
-}
-
-func (db *nodeDB) expirer() {
- for {
- select {
- case <-db.clock.After(dbCleanupCycle):
- db.expireNodes()
- case <-db.closeCh:
- return
- }
- }
-}
-
-// expireNodes iterates the whole node db and checks whether the negative balance
-// entry can deleted.
-//
-// The rationale behind this is: server doesn't need to keep the negative balance
-// records if they are low enough.
-func (db *nodeDB) expireNodes() {
- var (
- visited int
- deleted int
- start = time.Now()
- )
- iter := db.db.NewIteratorWithPrefix(append(db.verbuf[:], negativeBalancePrefix...))
- for iter.Next() {
- visited += 1
- var balance negBalance
- if err := rlp.DecodeBytes(iter.Value(), &balance); err != nil {
- log.Error("Failed to decode negative balance", "err", err)
- continue
+ if c.active {
+ f.activeQueue.Update(c.queueIndex)
+ if !c.priority && pb.value.base > 0 {
+ // The capacity should be adjusted based on the requirement,
+ // but we have no idea about the new capacity, need a second
+ // call to udpate it.
+ f.priorityActive += c.capacity
+ f.updateFreeRatio()
+ c.balanceTracker.addCallback(balanceCallbackZero, 0, func() { f.balanceExhausted(id) })
+ }
+ f.activeBalances.subExp(oldBalance)
+ f.activeBalances.addExp(pb.value)
+ } else {
+ f.inactiveQueue.Remove(c.queueIndex)
+ f.inactiveQueue.Push(c, -connPriority(c, f.clock.Now()))
+ f.inactiveBalances.subExp(oldBalance)
+ f.inactiveBalances.addExp(pb.value)
}
- if db.nbEvictCallBack != nil && db.nbEvictCallBack(db.clock.Now(), balance) {
- deleted += 1
- db.db.Delete(iter.Key())
+ if pb.value.base > 0 {
+ c.priority = true
+ // if balance is set to zero then reverting to non-priority status
+ // is handled by the balanceExhausted callback
}
+ } else {
+ f.inactiveBalances.subExp(oldBalance)
+ f.inactiveBalances.addExp(pb.value)
}
- // Invoke testing hook if it's not nil.
- if db.cleanupHook != nil {
- db.cleanupHook()
- }
- log.Debug("Expire nodes", "visited", visited, "deleted", deleted, "elapsed", common.PrettyDuration(time.Since(start)))
+ f.tryActivateClients()
+ return oldValue, pb.value.value(posExp), nil
}
diff --git a/les/clientpool_test.go b/les/clientpool_test.go
index 6308113fe74c..a34deff1b88d 100644
--- a/les/clientpool_test.go
+++ b/les/clientpool_test.go
@@ -17,11 +17,8 @@
package les
import (
- "bytes"
"fmt"
- "math"
"math/rand"
- "reflect"
"testing"
"time"
@@ -56,29 +53,34 @@ func TestClientPoolL100C300P20(t *testing.T) {
const testClientPoolTicks = 100000
-type poolTestPeer int
-
-func (i poolTestPeer) ID() enode.ID {
- return enode.ID{byte(i % 256), byte(i >> 8)}
+type poolTestPeer struct {
+ index int
+ disconnCh chan int
+ cap uint64
}
-func (i poolTestPeer) freeClientId() string {
- return fmt.Sprintf("addr #%d", i)
+func newPoolTestPeer(i int, disconnCh chan int) *poolTestPeer {
+ return &poolTestPeer{index: i, disconnCh: disconnCh}
}
-func (i poolTestPeer) updateCapacity(uint64) {}
-
-type poolTestPeerWithCap struct {
- poolTestPeer
+func (i *poolTestPeer) ID() enode.ID {
+ return enode.ID{byte(i.index % 256), byte(i.index >> 8)}
+}
- cap uint64
+func (i *poolTestPeer) freeClientId() string {
+ return fmt.Sprintf("addr #%d", i)
}
-func (i *poolTestPeerWithCap) updateCapacity(cap uint64) { i.cap = cap }
+func (i *poolTestPeer) updateCapacity(cap uint64) {
+ i.cap = cap
+ if cap == 0 && i.disconnCh != nil {
+ i.disconnCh <- i.index
+ }
+}
-func (i poolTestPeer) freezeClient() {}
+func (i *poolTestPeer) freeze() {}
-func testClientPool(t *testing.T, connLimit, clientCount, paidCount int, randomDisconnect bool) {
+func testClientPool(t *testing.T, activeLimit, clientCount, paidCount int, randomDisconnect bool) {
rand.Seed(time.Now().UnixNano())
var (
clock mclock.Simulated
@@ -89,15 +91,16 @@ func testClientPool(t *testing.T, connLimit, clientCount, paidCount int, randomD
disconnFn = func(id enode.ID) {
disconnCh <- int(id[0]) + int(id[1])<<8
}
- pool = newClientPool(db, 1, &clock, disconnFn)
+ pool = newClientPool(db, 1, 1, &clock, disconnFn)
)
+
pool.disableBias = true
- pool.setLimits(connLimit, uint64(connLimit))
+ pool.setLimits(activeLimit, uint64(activeLimit))
pool.setDefaultFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1})
// pool should accept new peers up to its connected limit
- for i := 0; i < connLimit; i++ {
- if pool.connect(poolTestPeer(i), 0) {
+ for i := 0; i < activeLimit; i++ {
+ if cap, _ := pool.connect(newPoolTestPeer(i, disconnCh), 0); cap != 0 {
connected[i] = true
} else {
t.Fatalf("Test peer #%d rejected", i)
@@ -111,28 +114,30 @@ func testClientPool(t *testing.T, connLimit, clientCount, paidCount int, randomD
// give a positive balance to some of the peers
amount := testClientPoolTicks / 2 * int64(time.Second) // enough for half of the simulation period
for i := 0; i < paidCount; i++ {
- pool.addBalance(poolTestPeer(i).ID(), amount, "")
+ pool.addBalance(newPoolTestPeer(i, disconnCh).ID(), amount)
}
}
i := rand.Intn(clientCount)
if connected[i] {
if randomDisconnect {
- pool.disconnect(poolTestPeer(i))
+ pool.disconnect(newPoolTestPeer(i, disconnCh))
connected[i] = false
connTicks[i] += tickCounter
}
} else {
- if pool.connect(poolTestPeer(i), 0) {
+ if cap, _ := pool.connect(newPoolTestPeer(i, disconnCh), 0); cap != 0 {
connected[i] = true
connTicks[i] -= tickCounter
+ } else {
+ pool.disconnect(newPoolTestPeer(i, disconnCh))
}
}
pollDisconnects:
for {
select {
case i := <-disconnCh:
- pool.disconnect(poolTestPeer(i))
+ pool.disconnect(newPoolTestPeer(i, disconnCh))
if connected[i] {
connTicks[i] += tickCounter
connected[i] = false
@@ -143,10 +148,10 @@ func testClientPool(t *testing.T, connLimit, clientCount, paidCount int, randomD
}
}
- expTicks := testClientPoolTicks/2*connLimit/clientCount + testClientPoolTicks/2*(connLimit-paidCount)/(clientCount-paidCount)
+ expTicks := testClientPoolTicks/2*activeLimit/clientCount + testClientPoolTicks/2*(activeLimit-paidCount)/(clientCount-paidCount)
expMin := expTicks - expTicks/5
expMax := expTicks + expTicks/5
- paidTicks := testClientPoolTicks/2*connLimit/clientCount + testClientPoolTicks/2
+ paidTicks := testClientPoolTicks/2*activeLimit/clientCount + testClientPoolTicks/2
paidMin := paidTicks - paidTicks/5
paidMax := paidTicks + paidTicks/5
@@ -172,15 +177,15 @@ func TestConnectPaidClient(t *testing.T) {
clock mclock.Simulated
db = rawdb.NewMemoryDatabase()
)
- pool := newClientPool(db, 1, &clock, nil)
+ pool := newClientPool(db, 1, 1, &clock, nil)
defer pool.stop()
pool.setLimits(10, uint64(10))
pool.setDefaultFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1})
// Add balance for an external client and mark it as paid client
- pool.addBalance(poolTestPeer(0).ID(), 1000, "")
+ pool.addBalance(newPoolTestPeer(0, nil).ID(), 1000)
- if !pool.connect(poolTestPeer(0), 10) {
+ if cap, _ := pool.connect(newPoolTestPeer(0, nil), 10); cap == 0 {
t.Fatalf("Failed to connect paid client")
}
}
@@ -190,16 +195,16 @@ func TestConnectPaidClientToSmallPool(t *testing.T) {
clock mclock.Simulated
db = rawdb.NewMemoryDatabase()
)
- pool := newClientPool(db, 1, &clock, nil)
+ pool := newClientPool(db, 1, 1, &clock, nil)
defer pool.stop()
pool.setLimits(10, uint64(10)) // Total capacity limit is 10
pool.setDefaultFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1})
// Add balance for an external client and mark it as paid client
- pool.addBalance(poolTestPeer(0).ID(), 1000, "")
+ pool.addBalance(newPoolTestPeer(0, nil).ID(), 1000)
// Connect a fat paid client to pool, should reject it.
- if pool.connect(poolTestPeer(0), 100) {
+ if cap, _ := pool.connect(newPoolTestPeer(0, nil), 100); cap != 0 {
t.Fatalf("Connected fat paid client, should reject it")
}
}
@@ -210,23 +215,23 @@ func TestConnectPaidClientToFullPool(t *testing.T) {
db = rawdb.NewMemoryDatabase()
)
removeFn := func(enode.ID) {} // Noop
- pool := newClientPool(db, 1, &clock, removeFn)
+ pool := newClientPool(db, 1, 1, &clock, removeFn)
defer pool.stop()
pool.setLimits(10, uint64(10)) // Total capacity limit is 10
pool.setDefaultFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1})
for i := 0; i < 10; i++ {
- pool.addBalance(poolTestPeer(i).ID(), 1000000000, "")
- pool.connect(poolTestPeer(i), 1)
+ pool.addBalance(newPoolTestPeer(i, nil).ID(), int64(time.Second))
+ pool.connect(newPoolTestPeer(i, nil), 1)
}
- pool.addBalance(poolTestPeer(11).ID(), 1000, "") // Add low balance to new paid client
- if pool.connect(poolTestPeer(11), 1) {
+ pool.addBalance(newPoolTestPeer(11, nil).ID(), 1000) // Add low balance to new paid client
+ if cap, _ := pool.connect(newPoolTestPeer(11, nil), 1); cap != 0 {
t.Fatalf("Low balance paid client should be rejected")
}
clock.Run(time.Second)
- pool.addBalance(poolTestPeer(12).ID(), 1000000000*60*3, "") // Add high balance to new paid client
- if !pool.connect(poolTestPeer(12), 1) {
- t.Fatalf("High balance paid client should be accpected")
+ pool.addBalance(newPoolTestPeer(12, nil).ID(), int64(time.Minute*5)) // Add high balance to new paid client
+ if cap, _ := pool.connect(newPoolTestPeer(12, nil), 1); cap == 0 {
+ t.Fatalf("High balance paid client should be accepted")
}
}
@@ -237,19 +242,19 @@ func TestPaidClientKickedOut(t *testing.T) {
kickedCh = make(chan int, 1)
)
removeFn := func(id enode.ID) { kickedCh <- int(id[0]) }
- pool := newClientPool(db, 1, &clock, removeFn)
+ pool := newClientPool(db, 1, 1, &clock, removeFn)
defer pool.stop()
pool.setLimits(10, uint64(10)) // Total capacity limit is 10
pool.setDefaultFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1})
for i := 0; i < 10; i++ {
- pool.addBalance(poolTestPeer(i).ID(), 1000000000, "") // 1 second allowance
- pool.connect(poolTestPeer(i), 1)
+ pool.addBalance(newPoolTestPeer(i, kickedCh).ID(), 1000000000) // 1 second allowance
+ pool.connect(newPoolTestPeer(i, kickedCh), 1)
clock.Run(time.Millisecond)
}
clock.Run(time.Second)
- clock.Run(connectedBias)
- if !pool.connect(poolTestPeer(11), 0) {
+ clock.Run(activeBias)
+ if cap, _ := pool.connect(newPoolTestPeer(11, kickedCh), 0); cap == 0 {
t.Fatalf("Free client should be accectped")
}
select {
@@ -267,11 +272,11 @@ func TestConnectFreeClient(t *testing.T) {
clock mclock.Simulated
db = rawdb.NewMemoryDatabase()
)
- pool := newClientPool(db, 1, &clock, nil)
+ pool := newClientPool(db, 1, 1, &clock, nil)
defer pool.stop()
pool.setLimits(10, uint64(10))
pool.setDefaultFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1})
- if !pool.connect(poolTestPeer(0), 10) {
+ if cap, _ := pool.connect(newPoolTestPeer(0, nil), 10); cap == 0 {
t.Fatalf("Failed to connect free client")
}
}
@@ -282,24 +287,24 @@ func TestConnectFreeClientToFullPool(t *testing.T) {
db = rawdb.NewMemoryDatabase()
)
removeFn := func(enode.ID) {} // Noop
- pool := newClientPool(db, 1, &clock, removeFn)
+ pool := newClientPool(db, 1, 1, &clock, removeFn)
defer pool.stop()
pool.setLimits(10, uint64(10)) // Total capacity limit is 10
pool.setDefaultFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1})
for i := 0; i < 10; i++ {
- pool.connect(poolTestPeer(i), 1)
+ pool.connect(newPoolTestPeer(i, nil), 1)
}
- if pool.connect(poolTestPeer(11), 1) {
+ if cap, _ := pool.connect(newPoolTestPeer(11, nil), 1); cap != 0 {
t.Fatalf("New free client should be rejected")
}
clock.Run(time.Minute)
- if pool.connect(poolTestPeer(12), 1) {
+ if cap, _ := pool.connect(newPoolTestPeer(12, nil), 1); cap != 0 {
t.Fatalf("New free client should be rejected")
}
clock.Run(time.Millisecond)
clock.Run(4 * time.Minute)
- if !pool.connect(poolTestPeer(13), 1) {
+ if cap, _ := pool.connect(newPoolTestPeer(13, nil), 1); cap == 0 {
t.Fatalf("Old client connects more than 5min should be kicked")
}
}
@@ -311,21 +316,22 @@ func TestFreeClientKickedOut(t *testing.T) {
kicked = make(chan int, 10)
)
removeFn := func(id enode.ID) { kicked <- int(id[0]) }
- pool := newClientPool(db, 1, &clock, removeFn)
+ pool := newClientPool(db, 1, 1, &clock, removeFn)
defer pool.stop()
pool.setLimits(10, uint64(10)) // Total capacity limit is 10
pool.setDefaultFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1})
for i := 0; i < 10; i++ {
- pool.connect(poolTestPeer(i), 1)
+ pool.connect(newPoolTestPeer(i, kicked), 1)
clock.Run(time.Millisecond)
}
- if pool.connect(poolTestPeer(10), 1) {
+ if cap, _ := pool.connect(newPoolTestPeer(10, kicked), 1); cap != 0 {
t.Fatalf("New free client should be rejected")
}
+ pool.disconnect(newPoolTestPeer(10, kicked))
clock.Run(5 * time.Minute)
for i := 0; i < 10; i++ {
- pool.connect(poolTestPeer(i+10), 1)
+ pool.connect(newPoolTestPeer(i+10, kicked), 1)
}
for i := 0; i < 10; i++ {
select {
@@ -346,18 +352,18 @@ func TestPositiveBalanceCalculation(t *testing.T) {
kicked = make(chan int, 10)
)
removeFn := func(id enode.ID) { kicked <- int(id[0]) } // Noop
- pool := newClientPool(db, 1, &clock, removeFn)
+ pool := newClientPool(db, 1, 1, &clock, removeFn)
defer pool.stop()
pool.setLimits(10, uint64(10)) // Total capacity limit is 10
pool.setDefaultFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1})
- pool.addBalance(poolTestPeer(0).ID(), int64(time.Minute*3), "")
- pool.connect(poolTestPeer(0), 10)
+ pool.addBalance(newPoolTestPeer(0, kicked).ID(), int64(time.Minute*3))
+ pool.connect(newPoolTestPeer(0, kicked), 10)
clock.Run(time.Minute)
- pool.disconnect(poolTestPeer(0))
- pb := pool.ndb.getOrNewPB(poolTestPeer(0).ID())
- if pb.value != uint64(time.Minute*2) {
+ pool.disconnect(newPoolTestPeer(0, kicked))
+ pb := pool.ndb.getOrNewBalance(newPoolTestPeer(0, kicked).ID().Bytes(), false)
+ if pb.value != expval(uint64(time.Minute*2)) {
t.Fatalf("Positive balance mismatch, want %v, got %v", uint64(time.Minute*2), pb.value)
}
}
@@ -369,16 +375,14 @@ func TestDowngradePriorityClient(t *testing.T) {
kicked = make(chan int, 10)
)
removeFn := func(id enode.ID) { kicked <- int(id[0]) } // Noop
- pool := newClientPool(db, 1, &clock, removeFn)
+ pool := newClientPool(db, 1, 1, &clock, removeFn)
defer pool.stop()
pool.setLimits(10, uint64(10)) // Total capacity limit is 10
pool.setDefaultFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1})
- p := &poolTestPeerWithCap{
- poolTestPeer: poolTestPeer(0),
- }
- pool.addBalance(p.ID(), int64(time.Minute), "")
- pool.connect(p, 10)
+ p := newPoolTestPeer(0, kicked)
+ pool.addBalance(p.ID(), int64(time.Minute))
+ p.cap, _ = pool.connect(p, 10)
if p.cap != 10 {
t.Fatalf("The capcacity of priority peer hasn't been updated, got: %d", p.cap)
}
@@ -388,156 +392,130 @@ func TestDowngradePriorityClient(t *testing.T) {
if p.cap != 1 {
t.Fatalf("The capcacity of peer should be downgraded, got: %d", p.cap)
}
- pb := pool.ndb.getOrNewPB(poolTestPeer(0).ID())
- if pb.value != 0 {
+ pb := pool.ndb.getOrNewBalance(newPoolTestPeer(0, kicked).ID().Bytes(), false)
+ if pb.value.base != 0 {
t.Fatalf("Positive balance mismatch, want %v, got %v", 0, pb.value)
}
- pool.addBalance(poolTestPeer(0).ID(), int64(time.Minute), "")
- pb = pool.ndb.getOrNewPB(poolTestPeer(0).ID())
- if pb.value != uint64(time.Minute) {
+ pool.addBalance(newPoolTestPeer(0, kicked).ID(), int64(time.Minute))
+ pb = pool.ndb.getOrNewBalance(newPoolTestPeer(0, kicked).ID().Bytes(), false)
+ if pb.value != expval(uint64(time.Minute)) {
t.Fatalf("Positive balance mismatch, want %v, got %v", uint64(time.Minute), pb.value)
}
}
func TestNegativeBalanceCalculation(t *testing.T) {
var (
- clock mclock.Simulated
- db = rawdb.NewMemoryDatabase()
- kicked = make(chan int, 10)
+ clock mclock.Simulated
+ db = rawdb.NewMemoryDatabase()
)
- removeFn := func(id enode.ID) { kicked <- int(id[0]) } // Noop
- pool := newClientPool(db, 1, &clock, removeFn)
+ pool := newClientPool(db, 1, 1, &clock, nil)
defer pool.stop()
pool.setLimits(10, uint64(10)) // Total capacity limit is 10
pool.setDefaultFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1})
for i := 0; i < 10; i++ {
- pool.connect(poolTestPeer(i), 1)
+ pool.connect(newPoolTestPeer(i, nil), 1)
}
clock.Run(time.Second)
for i := 0; i < 10; i++ {
- pool.disconnect(poolTestPeer(i))
- nb := pool.ndb.getOrNewNB(poolTestPeer(i).freeClientId())
- if nb.logValue != 0 {
+ pool.disconnect(newPoolTestPeer(i, nil))
+ nb := pool.ndb.getOrNewBalance([]byte(newPoolTestPeer(i, nil).freeClientId()), true)
+ if nb.value.base != 0 {
t.Fatalf("Short connection shouldn't be recorded")
}
}
-
for i := 0; i < 10; i++ {
- pool.connect(poolTestPeer(i), 1)
+ pool.connect(newPoolTestPeer(i, nil), 1)
}
clock.Run(time.Minute)
for i := 0; i < 10; i++ {
- pool.disconnect(poolTestPeer(i))
- nb := pool.ndb.getOrNewNB(poolTestPeer(i).freeClientId())
- nb.logValue -= pool.logOffset(clock.Now())
- nb.logValue /= fixedPointMultiplier
- if nb.logValue != int64(math.Log(float64(time.Minute/time.Second))) {
- t.Fatalf("Negative balance mismatch, want %v, got %v", int64(math.Log(float64(time.Minute/time.Second))), nb.logValue)
- }
- }
-}
-
-func TestNodeDB(t *testing.T) {
- ndb := newNodeDB(rawdb.NewMemoryDatabase(), mclock.System{})
- defer ndb.close()
-
- if !bytes.Equal(ndb.verbuf[:], []byte{0x00, nodeDBVersion}) {
- t.Fatalf("version buffer mismatch, want %v, got %v", []byte{0x00, nodeDBVersion}, ndb.verbuf)
- }
- var cases = []struct {
- id enode.ID
- ip string
- balance interface{}
- positive bool
- }{
- {enode.ID{0x00, 0x01, 0x02}, "", posBalance{value: 100}, true},
- {enode.ID{0x00, 0x01, 0x02}, "", posBalance{value: 200}, true},
- {enode.ID{}, "127.0.0.1", negBalance{logValue: 10}, false},
- {enode.ID{}, "127.0.0.1", negBalance{logValue: 20}, false},
- }
- for _, c := range cases {
- if c.positive {
- ndb.setPB(c.id, c.balance.(posBalance))
- if pb := ndb.getOrNewPB(c.id); !reflect.DeepEqual(pb, c.balance.(posBalance)) {
- t.Fatalf("Positive balance mismatch, want %v, got %v", c.balance.(posBalance), pb)
- }
- } else {
- ndb.setNB(c.ip, c.balance.(negBalance))
- if nb := ndb.getOrNewNB(c.ip); !reflect.DeepEqual(nb, c.balance.(negBalance)) {
- t.Fatalf("Negative balance mismatch, want %v, got %v", c.balance.(negBalance), nb)
- }
+ pool.disconnect(newPoolTestPeer(i, nil))
+ nb := pool.ndb.getOrNewBalance([]byte(newPoolTestPeer(i, nil).freeClientId()), true)
+ value := nb.value.value(pool.negExpiration(clock.Now()))
+ if value != uint64(time.Minute) {
+ t.Fatalf("Negative balance mismatch, want %v, got %v", time.Minute, value)
}
}
- for _, c := range cases {
- if c.positive {
- ndb.delPB(c.id)
- if pb := ndb.getOrNewPB(c.id); !reflect.DeepEqual(pb, posBalance{}) {
- t.Fatalf("Positive balance mismatch, want %v, got %v", posBalance{}, pb)
- }
- } else {
- ndb.delNB(c.ip)
- if nb := ndb.getOrNewNB(c.ip); !reflect.DeepEqual(nb, negBalance{}) {
- t.Fatalf("Negative balance mismatch, want %v, got %v", negBalance{}, nb)
- }
- }
- }
- ndb.setCumulativeTime(100)
- if ndb.getCumulativeTime() != 100 {
- t.Fatalf("Cumulative time mismatch, want %v, got %v", 100, ndb.getCumulativeTime())
- }
}
-func TestNodeDBExpiration(t *testing.T) {
+func TestInactiveClient(t *testing.T) {
var (
- iterated int
- done = make(chan struct{}, 1)
+ clock mclock.Simulated
+ db = rawdb.NewMemoryDatabase()
)
- callback := func(now mclock.AbsTime, b negBalance) bool {
- iterated += 1
- return true
- }
- clock := &mclock.Simulated{}
- ndb := newNodeDB(rawdb.NewMemoryDatabase(), clock)
- defer ndb.close()
- ndb.nbEvictCallBack = callback
- ndb.cleanupHook = func() { done <- struct{}{} }
-
- var cases = []struct {
- ip string
- balance negBalance
- }{
- {"127.0.0.1", negBalance{logValue: 1}},
- {"127.0.0.2", negBalance{logValue: 1}},
- {"127.0.0.3", negBalance{logValue: 1}},
- {"127.0.0.4", negBalance{logValue: 1}},
- }
- for _, c := range cases {
- ndb.setNB(c.ip, c.balance)
- }
- clock.WaitForTimers(1)
- clock.Run(time.Hour + time.Minute)
- select {
- case <-done:
- case <-time.NewTimer(time.Second).C:
- t.Fatalf("timeout")
- }
- if iterated != 4 {
- t.Fatalf("Failed to evict useless negative balances, want %v, got %d", 4, iterated)
- }
- clock.WaitForTimers(1)
- for _, c := range cases {
- ndb.setNB(c.ip, c.balance)
- }
- clock.Run(time.Hour + time.Minute)
- select {
- case <-done:
- case <-time.NewTimer(time.Second).C:
- t.Fatalf("timeout")
- }
- if iterated != 8 {
- t.Fatalf("Failed to evict useless negative balances, want %v, got %d", 4, iterated)
+ pool := newClientPool(db, 1, 1, &clock, nil)
+ defer pool.stop()
+ pool.setLimits(2, uint64(2)) // Total capacity limit is 10
+
+ p1 := newPoolTestPeer(1, nil)
+ p2 := newPoolTestPeer(2, nil)
+ p3 := newPoolTestPeer(3, nil)
+ pool.addBalance(p1.ID(), 1000)
+ pool.addBalance(p3.ID(), 2000)
+ // p1: 1000 p2: 0 p3: 2000
+ p1.cap, _ = pool.connect(p1, 1)
+ if p1.cap != 1 {
+ t.Fatalf("Failed to connect peer #1")
+ }
+ p2.cap, _ = pool.connect(p2, 1)
+ if p2.cap != 1 {
+ t.Fatalf("Failed to connect peer #2")
+ }
+ p3.cap, _ = pool.connect(p3, 1)
+ if p3.cap != 1 {
+ t.Fatalf("Failed to connect peer #3")
+ }
+ if p2.cap != 0 {
+ t.Fatalf("Failed to deactivate peer #2")
+ }
+ pool.addBalance(p2.ID(), 3000)
+ // p1: 1000 p2: 3000 p3: 2000
+ if p2.cap != 1 {
+ t.Fatalf("Failed to activate peer #2")
+ }
+ if p1.cap != 0 {
+ t.Fatalf("Failed to deactivate peer #1")
+ }
+ pool.addBalance(p2.ID(), -2500)
+ // p1: 1000 p2: 500 p3: 2000
+ if p1.cap != 1 {
+ t.Fatalf("Failed to activate peer #1")
+ }
+ if p2.cap != 0 {
+ t.Fatalf("Failed to deactivate peer #2")
+ }
+ pool.setDefaultFactors(priceFactors{1e-9, 0, 0}, priceFactors{1e-9, 0, 0})
+ p4 := newPoolTestPeer(4, nil)
+ pool.addBalance(p4.ID(), 1500)
+ // p1: 1000 p2: 500 p3: 2000 p4: 1500
+ p4.cap, _ = pool.connect(p4, 1)
+ if p4.cap != 1 {
+ t.Fatalf("Failed to activate peer #4")
+ }
+ if p1.cap != 0 {
+ t.Fatalf("Failed to deactivate peer #1")
+ }
+ clock.Run(time.Second * 600)
+ // manually trigger a check to avoid a long real-time wait
+ pool.lock.Lock()
+ pool.tryActivateClients()
+ pool.lock.Unlock()
+ // p1: 1000 p2: 500 p3: 2000 p4: 900
+ if p1.cap != 1 {
+ t.Fatalf("Failed to activate peer #1")
+ }
+ if p4.cap != 0 {
+ t.Fatalf("Failed to deactivate peer #4")
+ }
+ pool.disconnect(p2)
+ pool.disconnect(p4)
+ pool.addBalance(p1.ID(), -1000)
+ if p1.cap != 1 {
+ t.Fatalf("Should not deactivate peer #1")
+ }
+ if p2.cap != 0 {
+ t.Fatalf("Should not activate peer #2")
}
}
diff --git a/les/costtracker.go b/les/costtracker.go
index 81da04566007..abf8618ffde9 100644
--- a/les/costtracker.go
+++ b/les/costtracker.go
@@ -159,15 +159,9 @@ func newCostTracker(db ethdb.Database, config *eth.Config) (*costTracker, uint64
}
ct.gfLoop()
costList := ct.makeCostList(ct.globalFactor() * 1.25)
- for _, c := range costList {
- amount := minBufferReqAmount[c.MsgCode]
- cost := c.BaseCost + amount*c.ReqCost
- if cost > ct.minBufLimit {
- ct.minBufLimit = cost
- }
- }
- ct.minBufLimit *= uint64(minBufferMultiplier)
- return ct, (ct.minBufLimit-1)/bufLimitRatio + 1
+ var minRecharge uint64
+ ct.minBufLimit, minRecharge = costList.decode(ProtocolLengths[ServerProtocolVersions[len(ServerProtocolVersions)-1]]).reqParams()
+ return ct, minRecharge
}
// stop stops the cost tracker and saves the cost factor statistics to the database
@@ -480,6 +474,22 @@ func (table requestCostTable) getMaxCost(code, amount uint64) uint64 {
return costs.baseCost + amount*costs.reqCost
}
+func (table requestCostTable) reqParams() (minRecharge, minBufLimit uint64) {
+ for code, c := range table {
+ amount := minBufferReqAmount[code]
+ cost := c.baseCost + amount*c.reqCost
+ if cost > minBufLimit {
+ minBufLimit = cost
+ }
+ }
+ minBufLimit *= uint64(minBufferMultiplier)
+ if minBufLimit < 1 {
+ minBufLimit = 1
+ }
+ minRecharge = (minBufLimit-1)/bufLimitRatio + 1
+ return
+}
+
// decode converts a cost list to a cost table
func (list RequestCostList) decode(protocolLength uint64) requestCostTable {
table := make(requestCostTable)
diff --git a/les/expiredvalue.go b/les/expiredvalue.go
new file mode 100644
index 000000000000..0bb251a1ac87
--- /dev/null
+++ b/les/expiredvalue.go
@@ -0,0 +1,129 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package les
+
+import "math"
+
+// expiredValue is a scalar value that is continuously expired (decreased
+// exponentially) based on the provided logarithmic expiration offset value.
+//
+// The formula for value calculation is: base*2^(exp-logOffset). In order to
+// simplify the calculation of expiredValue, its value is expressed in the form
+// of an exponent with a base of 2.
+//
+// Also here is a trick to reduce a lot of calculations. In theory, when a value X
+// decays over time and then a new value Y is added, the final result should be
+// X*2^(exp-logOffset)+Y. However it's very hard to represent in memory.
+// So the trick is using the idea of inflation instead of exponential decay. At this
+// moment the temporary value becomes: X*2^exp+Y*2^logOffset_1, apply the exponential
+// decay when we actually want to calculate the value.
+//
+// e.g.
+// t0: V = 100
+// t1: add 30, inflationary value is: 100 + 30/0.3, 0.3 is the decay coefficient
+// t2: get value, decay coefficient is 0.2 now, final result is: 200*0.2 = 40
+type expiredValue struct {
+ base, exp uint64
+}
+
+// value calculates the value at the given moment.
+func (e expiredValue) value(logOffset fixed64) uint64 {
+ offset := uint64ToFixed64(e.exp) - logOffset
+ return uint64(float64(e.base) * offset.pow2())
+}
+
+// add adds a signed value at the given moment
+func (e *expiredValue) add(amount int64, logOffset fixed64) int64 {
+ integer, frac := logOffset.toUint64(), logOffset.fraction()
+ factor := frac.pow2()
+ base := factor * float64(amount)
+ if integer < e.exp {
+ base /= math.Pow(2, float64(e.exp-integer))
+ }
+ if integer > e.exp {
+ e.base >>= (integer - e.exp)
+ e.exp = integer
+ }
+ if base >= 0 || uint64(-base) <= e.base {
+ e.base += uint64(base)
+ return amount
+ }
+ net := int64(-float64(e.base) / factor)
+ e.base = 0
+ return net
+}
+
+// addExp adds another expiredValue
+func (e *expiredValue) addExp(a expiredValue) {
+ if e.exp > a.exp {
+ a.base >>= (e.exp - a.exp)
+ }
+ if e.exp < a.exp {
+ e.base >>= (a.exp - e.exp)
+ e.exp = a.exp
+ }
+ e.base += a.base
+}
+
+// subExp subtracts another expiredValue
+func (e *expiredValue) subExp(a expiredValue) {
+ if e.exp > a.exp {
+ a.base >>= (e.exp - a.exp)
+ }
+ if e.exp < a.exp {
+ e.base >>= (a.exp - e.exp)
+ e.exp = a.exp
+ }
+ if e.base > a.base {
+ e.base -= a.base
+ } else {
+ e.base = 0
+ }
+}
+
+// fixedFactor is the fixed point multiplier factor used by fixed64.
+const fixedFactor = 0x1000000
+
+// fixed64 implements 64-bit fixed point arithmetic functions.
+type fixed64 int64
+
+// uint64ToFixed64 converts uint64 integer to fixed64 format.
+func uint64ToFixed64(f uint64) fixed64 {
+ return fixed64(f * fixedFactor)
+}
+
+// float64ToFixed64 converts float64 to fixed64 format.
+func float64ToFixed64(f float64) fixed64 {
+ return fixed64(f * fixedFactor)
+}
+
+// toUint64 converts fixed64 format to uint64.
+func (f64 fixed64) toUint64() uint64 {
+ return uint64(f64) / fixedFactor
+}
+
+// fraction returns the fractional part of a fixed64 value.
+func (f64 fixed64) fraction() fixed64 {
+ return f64 % fixedFactor
+}
+
+var fixedLogFactor = math.Log(2) / float64(fixedFactor)
+
+// pow2Fixed returns the base 2 power of the fixed point value.
+func (f64 fixed64) pow2() float64 {
+ return math.Exp(float64(f64) * fixedLogFactor)
+}
diff --git a/les/expiredvalue_test.go b/les/expiredvalue_test.go
new file mode 100644
index 000000000000..a59510a70291
--- /dev/null
+++ b/les/expiredvalue_test.go
@@ -0,0 +1,116 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package les
+
+import "testing"
+
+func TestValueExpiration(t *testing.T) {
+ var cases = []struct {
+ input expiredValue
+ timeOffset fixed64
+ expect uint64
+ }{
+ {expiredValue{base: 128, exp: 0}, uint64ToFixed64(0), 128},
+ {expiredValue{base: 128, exp: 0}, uint64ToFixed64(1), 64},
+ {expiredValue{base: 128, exp: 0}, uint64ToFixed64(2), 32},
+ {expiredValue{base: 128, exp: 2}, uint64ToFixed64(2), 128},
+ {expiredValue{base: 128, exp: 2}, uint64ToFixed64(3), 64},
+ }
+ for _, c := range cases {
+ if got := c.input.value(c.timeOffset); got != c.expect {
+ t.Fatalf("Value mismatch, want=%d, got=%d", c.expect, got)
+ }
+ }
+}
+
+func TestValueAddition(t *testing.T) {
+ var cases = []struct {
+ input expiredValue
+ addend int64
+ timeOffset fixed64
+ expect uint64
+ expectNet int64
+ }{
+ // Addition
+ {expiredValue{base: 128, exp: 0}, 128, uint64ToFixed64(0), 256, 128},
+ {expiredValue{base: 128, exp: 2}, 128, uint64ToFixed64(0), 640, 128},
+
+ // Addition with offset
+ {expiredValue{base: 128, exp: 0}, 128, uint64ToFixed64(1), 192, 128},
+ {expiredValue{base: 128, exp: 2}, 128, uint64ToFixed64(1), 384, 128},
+ {expiredValue{base: 128, exp: 2}, 128, uint64ToFixed64(3), 192, 128},
+
+ // Subtraction
+ {expiredValue{base: 128, exp: 0}, -64, uint64ToFixed64(0), 64, -64},
+ {expiredValue{base: 128, exp: 0}, -128, uint64ToFixed64(0), 0, -128},
+ {expiredValue{base: 128, exp: 0}, -192, uint64ToFixed64(0), 0, -128},
+
+ // Subtraction with offset
+ {expiredValue{base: 128, exp: 0}, -64, uint64ToFixed64(1), 0, -64},
+ {expiredValue{base: 128, exp: 0}, -128, uint64ToFixed64(1), 0, -64},
+ {expiredValue{base: 128, exp: 2}, -128, uint64ToFixed64(1), 128, -128},
+ {expiredValue{base: 128, exp: 2}, -128, uint64ToFixed64(2), 0, -128},
+ }
+ for _, c := range cases {
+ if net := c.input.add(c.addend, c.timeOffset); net != c.expectNet {
+ t.Fatalf("Net amount mismatch, want=%d, got=%d", c.expectNet, net)
+ }
+ if got := c.input.value(c.timeOffset); got != c.expect {
+ t.Fatalf("Value mismatch, want=%d, got=%d", c.expect, got)
+ }
+ }
+}
+
+func TestExpiredValueAddition(t *testing.T) {
+ var cases = []struct {
+ input expiredValue
+ another expiredValue
+ timeOffset fixed64
+ expect uint64
+ }{
+ {expiredValue{base: 128, exp: 0}, expiredValue{base: 128, exp: 0}, uint64ToFixed64(0), 256},
+ {expiredValue{base: 128, exp: 1}, expiredValue{base: 128, exp: 0}, uint64ToFixed64(0), 384},
+ {expiredValue{base: 128, exp: 0}, expiredValue{base: 128, exp: 1}, uint64ToFixed64(0), 384},
+ {expiredValue{base: 128, exp: 0}, expiredValue{base: 128, exp: 0}, uint64ToFixed64(1), 128},
+ }
+ for _, c := range cases {
+ c.input.addExp(c.another)
+ if got := c.input.value(c.timeOffset); got != c.expect {
+ t.Fatalf("Value mismatch, want=%d, got=%d", c.expect, got)
+ }
+ }
+}
+
+func TestExpiredValueSubtraction(t *testing.T) {
+ var cases = []struct {
+ input expiredValue
+ another expiredValue
+ timeOffset fixed64
+ expect uint64
+ }{
+ {expiredValue{base: 128, exp: 0}, expiredValue{base: 128, exp: 0}, uint64ToFixed64(0), 0},
+ {expiredValue{base: 128, exp: 0}, expiredValue{base: 128, exp: 1}, uint64ToFixed64(0), 0},
+ {expiredValue{base: 128, exp: 1}, expiredValue{base: 128, exp: 0}, uint64ToFixed64(0), 128},
+ {expiredValue{base: 128, exp: 1}, expiredValue{base: 128, exp: 0}, uint64ToFixed64(1), 64},
+ }
+ for _, c := range cases {
+ c.input.subExp(c.another)
+ if got := c.input.value(c.timeOffset); got != c.expect {
+ t.Fatalf("Value mismatch, want=%d, got=%d", c.expect, got)
+ }
+ }
+}
diff --git a/les/flowcontrol/control.go b/les/flowcontrol/control.go
index 490013677c63..1c40882902fc 100644
--- a/les/flowcontrol/control.go
+++ b/les/flowcontrol/control.go
@@ -185,6 +185,14 @@ func (node *ClientNode) UpdateParams(params ServerParams) {
}
}
+// Params returns the current server parameters
+func (node *ClientNode) Params() ServerParams {
+ node.lock.Lock()
+ defer node.lock.Unlock()
+
+ return node.params
+}
+
// updateParams updates the flow control parameters of the node
func (node *ClientNode) updateParams(params ServerParams, now mclock.AbsTime) {
diff := int64(params.BufLimit - node.params.BufLimit)
diff --git a/les/handler_test.go b/les/handler_test.go
index 1612caf42769..d3e4df1b77bd 100644
--- a/les/handler_test.go
+++ b/les/handler_test.go
@@ -38,20 +38,29 @@ import (
"github.com/ethereum/go-ethereum/trie"
)
-func expectResponse(r p2p.MsgReader, msgcode, reqID, bv uint64, data interface{}) error {
+func expectResponse(r p2p.MsgReader, protocol int, msgcode, reqID, bv, cost uint64, data interface{}) error {
type resp struct {
- ReqID, BV uint64
- Data interface{}
+ ReqID uint64
+ SF stateFeedback
+ Data interface{}
}
- return p2p.ExpectMsg(r, msgcode, resp{reqID, bv, data})
+ sf := stateFeedback{
+ protocolVersion: protocol,
+ stateFeedbackV4: stateFeedbackV4{
+ BV: bv,
+ RealCost: cost,
+ TokenBalance: 0,
+ },
+ }
+ return p2p.ExpectMsg(r, msgcode, resp{reqID, sf, data})
}
// Tests that block headers can be retrieved from a remote chain based on user queries.
-func TestGetBlockHeadersLes2(t *testing.T) { testGetBlockHeaders(t, 2) }
func TestGetBlockHeadersLes3(t *testing.T) { testGetBlockHeaders(t, 3) }
+func TestGetBlockHeadersLes4(t *testing.T) { testGetBlockHeaders(t, 4) }
func testGetBlockHeaders(t *testing.T, protocol int) {
- server, tearDown := newServerEnv(t, downloader.MaxHashFetch+15, protocol, nil, false, true, 0)
+ server, tearDown := newServerEnv(t, downloader.MaxHashFetch+15, protocol, nil, false, true, 0, true)
defer tearDown()
bc := server.handler.blockchain
@@ -168,19 +177,20 @@ func testGetBlockHeaders(t *testing.T, protocol int) {
// Send the hash request and verify the response
reqID++
+ cost := server.peer.speer.getRequestCost(GetBlockHeadersMsg, int(tt.query.Amount))
sendRequest(server.peer.app, GetBlockHeadersMsg, reqID, tt.query)
- if err := expectResponse(server.peer.app, BlockHeadersMsg, reqID, testBufLimit, headers); err != nil {
+ if err := expectResponse(server.peer.app, protocol, BlockHeadersMsg, reqID, testBufLimit, cost, headers); err != nil {
t.Errorf("test %d: headers mismatch: %v", i, err)
}
}
}
// Tests that block contents can be retrieved from a remote chain based on their hashes.
-func TestGetBlockBodiesLes2(t *testing.T) { testGetBlockBodies(t, 2) }
func TestGetBlockBodiesLes3(t *testing.T) { testGetBlockBodies(t, 3) }
+func TestGetBlockBodiesLes4(t *testing.T) { testGetBlockBodies(t, 4) }
func testGetBlockBodies(t *testing.T, protocol int) {
- server, tearDown := newServerEnv(t, downloader.MaxBlockFetch+15, protocol, nil, false, true, 0)
+ server, tearDown := newServerEnv(t, downloader.MaxBlockFetch+15, protocol, nil, false, true, 0, true)
defer tearDown()
bc := server.handler.blockchain
@@ -245,20 +255,21 @@ func testGetBlockBodies(t *testing.T, protocol int) {
reqID++
// Send the hash request and verify the response
+ cost := server.peer.speer.getRequestCost(GetBlockBodiesMsg, len(hashes))
sendRequest(server.peer.app, GetBlockBodiesMsg, reqID, hashes)
- if err := expectResponse(server.peer.app, BlockBodiesMsg, reqID, testBufLimit, bodies); err != nil {
+ if err := expectResponse(server.peer.app, protocol, BlockBodiesMsg, reqID, testBufLimit, cost, bodies); err != nil {
t.Errorf("test %d: bodies mismatch: %v", i, err)
}
}
}
// Tests that the contract codes can be retrieved based on account addresses.
-func TestGetCodeLes2(t *testing.T) { testGetCode(t, 2) }
func TestGetCodeLes3(t *testing.T) { testGetCode(t, 3) }
+func TestGetCodeLes4(t *testing.T) { testGetCode(t, 4) }
func testGetCode(t *testing.T, protocol int) {
// Assemble the test environment
- server, tearDown := newServerEnv(t, 4, protocol, nil, false, true, 0)
+ server, tearDown := newServerEnv(t, 4, protocol, nil, false, true, 0, true)
defer tearDown()
bc := server.handler.blockchain
@@ -276,18 +287,19 @@ func testGetCode(t *testing.T, protocol int) {
}
}
+ cost := server.peer.speer.getRequestCost(GetCodeMsg, len(codereqs))
sendRequest(server.peer.app, GetCodeMsg, 42, codereqs)
- if err := expectResponse(server.peer.app, CodeMsg, 42, testBufLimit, codes); err != nil {
+ if err := expectResponse(server.peer.app, protocol, CodeMsg, 42, testBufLimit, cost, codes); err != nil {
t.Errorf("codes mismatch: %v", err)
}
}
// Tests that the stale contract codes can't be retrieved based on account addresses.
-func TestGetStaleCodeLes2(t *testing.T) { testGetStaleCode(t, 2) }
func TestGetStaleCodeLes3(t *testing.T) { testGetStaleCode(t, 3) }
+func TestGetStaleCodeLes4(t *testing.T) { testGetStaleCode(t, 4) }
func testGetStaleCode(t *testing.T, protocol int) {
- server, tearDown := newServerEnv(t, core.TriesInMemory+4, protocol, nil, false, true, 0)
+ server, tearDown := newServerEnv(t, core.TriesInMemory+4, protocol, nil, false, true, 0, true)
defer tearDown()
bc := server.handler.blockchain
@@ -296,8 +308,9 @@ func testGetStaleCode(t *testing.T, protocol int) {
BHash: bc.GetHeaderByNumber(number).Hash(),
AccKey: crypto.Keccak256(testContractAddr[:]),
}
+ cost := server.peer.speer.getRequestCost(GetCodeMsg, 1)
sendRequest(server.peer.app, GetCodeMsg, 42, []*CodeReq{req})
- if err := expectResponse(server.peer.app, CodeMsg, 42, testBufLimit, expected); err != nil {
+ if err := expectResponse(server.peer.app, protocol, CodeMsg, 42, testBufLimit, cost, expected); err != nil {
t.Errorf("codes mismatch: %v", err)
}
}
@@ -307,12 +320,12 @@ func testGetStaleCode(t *testing.T, protocol int) {
}
// Tests that the transaction receipts can be retrieved based on hashes.
-func TestGetReceiptLes2(t *testing.T) { testGetReceipt(t, 2) }
func TestGetReceiptLes3(t *testing.T) { testGetReceipt(t, 3) }
+func TestGetReceiptLes4(t *testing.T) { testGetReceipt(t, 4) }
func testGetReceipt(t *testing.T, protocol int) {
// Assemble the test environment
- server, tearDown := newServerEnv(t, 4, protocol, nil, false, true, 0)
+ server, tearDown := newServerEnv(t, 4, protocol, nil, false, true, 0, true)
defer tearDown()
bc := server.handler.blockchain
@@ -327,19 +340,20 @@ func testGetReceipt(t *testing.T, protocol int) {
receipts = append(receipts, rawdb.ReadRawReceipts(server.db, block.Hash(), block.NumberU64()))
}
// Send the hash request and verify the response
+ cost := server.peer.speer.getRequestCost(GetReceiptsMsg, len(hashes))
sendRequest(server.peer.app, GetReceiptsMsg, 42, hashes)
- if err := expectResponse(server.peer.app, ReceiptsMsg, 42, testBufLimit, receipts); err != nil {
+ if err := expectResponse(server.peer.app, protocol, ReceiptsMsg, 42, testBufLimit, cost, receipts); err != nil {
t.Errorf("receipts mismatch: %v", err)
}
}
// Tests that trie merkle proofs can be retrieved
-func TestGetProofsLes2(t *testing.T) { testGetProofs(t, 2) }
func TestGetProofsLes3(t *testing.T) { testGetProofs(t, 3) }
+func TestGetProofsLes4(t *testing.T) { testGetProofs(t, 4) }
func testGetProofs(t *testing.T, protocol int) {
// Assemble the test environment
- server, tearDown := newServerEnv(t, 4, protocol, nil, false, true, 0)
+ server, tearDown := newServerEnv(t, 4, protocol, nil, false, true, 0, true)
defer tearDown()
bc := server.handler.blockchain
@@ -362,18 +376,19 @@ func testGetProofs(t *testing.T, protocol int) {
}
}
// Send the proof request and verify the response
+ cost := server.peer.speer.getRequestCost(GetProofsV2Msg, len(proofreqs))
sendRequest(server.peer.app, GetProofsV2Msg, 42, proofreqs)
- if err := expectResponse(server.peer.app, ProofsV2Msg, 42, testBufLimit, proofsV2.NodeList()); err != nil {
+ if err := expectResponse(server.peer.app, protocol, ProofsV2Msg, 42, testBufLimit, cost, proofsV2.NodeList()); err != nil {
t.Errorf("proofs mismatch: %v", err)
}
}
// Tests that the stale contract codes can't be retrieved based on account addresses.
-func TestGetStaleProofLes2(t *testing.T) { testGetStaleProof(t, 2) }
func TestGetStaleProofLes3(t *testing.T) { testGetStaleProof(t, 3) }
+func TestGetStaleProofLes4(t *testing.T) { testGetStaleProof(t, 4) }
func testGetStaleProof(t *testing.T, protocol int) {
- server, tearDown := newServerEnv(t, core.TriesInMemory+4, protocol, nil, false, true, 0)
+ server, tearDown := newServerEnv(t, core.TriesInMemory+4, protocol, nil, false, true, 0, true)
defer tearDown()
bc := server.handler.blockchain
@@ -395,7 +410,8 @@ func testGetStaleProof(t *testing.T, protocol int) {
t.Prove(account, 0, proofsV2)
expected = proofsV2.NodeList()
}
- if err := expectResponse(server.peer.app, ProofsV2Msg, 42, testBufLimit, expected); err != nil {
+ cost := server.peer.speer.getRequestCost(GetProofsV2Msg, 1)
+ if err := expectResponse(server.peer.app, protocol, ProofsV2Msg, 42, testBufLimit, cost, expected); err != nil {
t.Errorf("codes mismatch: %v", err)
}
}
@@ -405,8 +421,8 @@ func testGetStaleProof(t *testing.T, protocol int) {
}
// Tests that CHT proofs can be correctly retrieved.
-func TestGetCHTProofsLes2(t *testing.T) { testGetCHTProofs(t, 2) }
func TestGetCHTProofsLes3(t *testing.T) { testGetCHTProofs(t, 3) }
+func TestGetCHTProofsLes4(t *testing.T) { testGetCHTProofs(t, 4) }
func testGetCHTProofs(t *testing.T, protocol int) {
config := light.TestServerIndexerConfig
@@ -420,7 +436,7 @@ func testGetCHTProofs(t *testing.T, protocol int) {
time.Sleep(10 * time.Millisecond)
}
}
- server, tearDown := newServerEnv(t, int(config.ChtSize+config.ChtConfirms), protocol, waitIndexers, false, true, 0)
+ server, tearDown := newServerEnv(t, int(config.ChtSize+config.ChtConfirms), protocol, waitIndexers, false, true, 0, true)
defer tearDown()
bc := server.handler.blockchain
@@ -446,14 +462,15 @@ func testGetCHTProofs(t *testing.T, protocol int) {
AuxReq: auxHeader,
}}
// Send the proof request and verify the response
+ cost := server.peer.speer.getRequestCost(GetHelperTrieProofsMsg, len(requestsV2))
sendRequest(server.peer.app, GetHelperTrieProofsMsg, 42, requestsV2)
- if err := expectResponse(server.peer.app, HelperTrieProofsMsg, 42, testBufLimit, proofsV2); err != nil {
+ if err := expectResponse(server.peer.app, protocol, HelperTrieProofsMsg, 42, testBufLimit, cost, proofsV2); err != nil {
t.Errorf("proofs mismatch: %v", err)
}
}
-func TestGetBloombitsProofsLes2(t *testing.T) { testGetBloombitsProofs(t, 2) }
func TestGetBloombitsProofsLes3(t *testing.T) { testGetBloombitsProofs(t, 3) }
+func TestGetBloombitsProofsLes4(t *testing.T) { testGetBloombitsProofs(t, 4) }
// Tests that bloombits proofs can be correctly retrieved.
func testGetBloombitsProofs(t *testing.T, protocol int) {
@@ -468,7 +485,7 @@ func testGetBloombitsProofs(t *testing.T, protocol int) {
time.Sleep(10 * time.Millisecond)
}
}
- server, tearDown := newServerEnv(t, int(config.BloomTrieSize+config.BloomTrieConfirms), protocol, waitIndexers, false, true, 0)
+ server, tearDown := newServerEnv(t, int(config.BloomTrieSize+config.BloomTrieConfirms), protocol, waitIndexers, false, true, 0, true)
defer tearDown()
bc := server.handler.blockchain
@@ -494,18 +511,19 @@ func testGetBloombitsProofs(t *testing.T, protocol int) {
trie.Prove(key, 0, &proofs.Proofs)
// Send the proof request and verify the response
+ cost := server.peer.speer.getRequestCost(GetHelperTrieProofsMsg, len(requests))
sendRequest(server.peer.app, GetHelperTrieProofsMsg, 42, requests)
- if err := expectResponse(server.peer.app, HelperTrieProofsMsg, 42, testBufLimit, proofs); err != nil {
+ if err := expectResponse(server.peer.app, protocol, HelperTrieProofsMsg, 42, testBufLimit, cost, proofs); err != nil {
t.Errorf("bit %d: proofs mismatch: %v", bit, err)
}
}
}
-func TestTransactionStatusLes2(t *testing.T) { testTransactionStatus(t, 2) }
func TestTransactionStatusLes3(t *testing.T) { testTransactionStatus(t, 3) }
+func TestTransactionStatusLes4(t *testing.T) { testTransactionStatus(t, 4) }
func testTransactionStatus(t *testing.T, protocol int) {
- server, tearDown := newServerEnv(t, 0, protocol, nil, false, true, 0)
+ server, tearDown := newServerEnv(t, 0, protocol, nil, false, true, 0, true)
defer tearDown()
server.handler.addTxsSync = true
@@ -515,12 +533,15 @@ func testTransactionStatus(t *testing.T, protocol int) {
test := func(tx *types.Transaction, send bool, expStatus light.TxStatus) {
reqID++
+ var cost uint64
if send {
+ cost = server.peer.speer.getRequestCost(SendTxV2Msg, 1)
sendRequest(server.peer.app, SendTxV2Msg, reqID, types.Transactions{tx})
} else {
+ cost = server.peer.speer.getRequestCost(GetTxStatusMsg, 1)
sendRequest(server.peer.app, GetTxStatusMsg, reqID, []common.Hash{tx.Hash()})
}
- if err := expectResponse(server.peer.app, TxStatusMsg, reqID, testBufLimit, []light.TxStatus{expStatus}); err != nil {
+ if err := expectResponse(server.peer.app, protocol, TxStatusMsg, reqID, testBufLimit, cost, []light.TxStatus{expStatus}); err != nil {
t.Errorf("transaction status mismatch")
}
}
@@ -595,8 +616,11 @@ func testTransactionStatus(t *testing.T, protocol int) {
test(tx2, false, light.TxStatus{Status: core.TxStatusPending})
}
-func TestStopResumeLes3(t *testing.T) {
- server, tearDown := newServerEnv(t, 0, 3, nil, true, true, testBufLimit/10)
+func TestStopResumeLes3(t *testing.T) { testStopResume(t, 3) }
+func TestStopResumeLes4(t *testing.T) { testStopResume(t, 4) }
+
+func testStopResume(t *testing.T, protocol int) {
+ server, tearDown := newServerEnv(t, 0, protocol, nil, true, true, testBufLimit/10, true)
defer tearDown()
server.handler.server.costTracker.testing = true
@@ -616,7 +640,7 @@ func TestStopResumeLes3(t *testing.T) {
for expBuf >= testCost {
req()
expBuf -= testCost
- if err := expectResponse(server.peer.app, BlockHeadersMsg, reqID, expBuf, []*types.Header{header}); err != nil {
+ if err := expectResponse(server.peer.app, protocol, BlockHeadersMsg, reqID, expBuf, testCost, []*types.Header{header}); err != nil {
t.Errorf("expected response and failed: %v", err)
}
}
@@ -635,7 +659,15 @@ func TestStopResumeLes3(t *testing.T) {
// expect a ResumeMsg with the partially recharged buffer value
expBuf += testBufRecharge * wait
- if err := p2p.ExpectMsg(server.peer.app, ResumeMsg, expBuf); err != nil {
+ sf := stateFeedback{
+ protocolVersion: protocol,
+ stateFeedbackV4: stateFeedbackV4{
+ BV: expBuf,
+ RealCost: 0,
+ TokenBalance: 0,
+ },
+ }
+ if err := p2p.ExpectMsg(server.peer.app, ResumeMsg, sf); err != nil {
t.Errorf("expected ResumeMsg and failed: %v", err)
}
}
diff --git a/les/metrics.go b/les/metrics.go
index 9ef8c365180c..12780346b65c 100644
--- a/les/metrics.go
+++ b/les/metrics.go
@@ -40,6 +40,8 @@ var (
miscInTxsTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic/txs", nil)
miscInTxStatusPacketsMeter = metrics.NewRegisteredMeter("les/misc/in/packets/txStatus", nil)
miscInTxStatusTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic/txStatus", nil)
+ miscInLespayPacketsMeter = metrics.NewRegisteredMeter("les/misc/in/packets/lespay", nil)
+ miscInLespayTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic/lespay", nil)
miscOutPacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets/total", nil)
miscOutTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/total", nil)
@@ -59,6 +61,8 @@ var (
miscOutTxsTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/txs", nil)
miscOutTxStatusPacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets/txStatus", nil)
miscOutTxStatusTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/txStatus", nil)
+ miscOutLespayPacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets/lespay", nil)
+ miscOutLespayTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/lespay", nil)
miscServingTimeHeaderTimer = metrics.NewRegisteredTimer("les/misc/serve/header", nil)
miscServingTimeBodyTimer = metrics.NewRegisteredTimer("les/misc/serve/body", nil)
@@ -68,6 +72,7 @@ var (
miscServingTimeHelperTrieTimer = metrics.NewRegisteredTimer("les/misc/serve/helperTrie", nil)
miscServingTimeTxTimer = metrics.NewRegisteredTimer("les/misc/serve/txs", nil)
miscServingTimeTxStatusTimer = metrics.NewRegisteredTimer("les/misc/serve/txStatus", nil)
+ miscServingTimeLespayTimer = metrics.NewRegisteredTimer("les/misc/serve/lespay", nil)
connectionTimer = metrics.NewRegisteredTimer("les/connection/duration", nil)
serverConnectionGauge = metrics.NewRegisteredGauge("les/connection/server", nil)
diff --git a/les/odr_test.go b/les/odr_test.go
index bbe439dfec82..a56de3176961 100644
--- a/les/odr_test.go
+++ b/les/odr_test.go
@@ -38,8 +38,8 @@ import (
type odrTestFn func(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte
-func TestOdrGetBlockLes2(t *testing.T) { testOdr(t, 2, 1, true, odrGetBlock) }
func TestOdrGetBlockLes3(t *testing.T) { testOdr(t, 3, 1, true, odrGetBlock) }
+func TestOdrGetBlockLes4(t *testing.T) { testOdr(t, 4, 1, true, odrGetBlock) }
func odrGetBlock(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte {
var block *types.Block
@@ -55,8 +55,8 @@ func odrGetBlock(ctx context.Context, db ethdb.Database, config *params.ChainCon
return rlp
}
-func TestOdrGetReceiptsLes2(t *testing.T) { testOdr(t, 2, 1, true, odrGetReceipts) }
func TestOdrGetReceiptsLes3(t *testing.T) { testOdr(t, 3, 1, true, odrGetReceipts) }
+func TestOdrGetReceiptsLes4(t *testing.T) { testOdr(t, 4, 1, true, odrGetReceipts) }
func odrGetReceipts(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte {
var receipts types.Receipts
@@ -76,8 +76,8 @@ func odrGetReceipts(ctx context.Context, db ethdb.Database, config *params.Chain
return rlp
}
-func TestOdrAccountsLes2(t *testing.T) { testOdr(t, 2, 1, true, odrAccounts) }
func TestOdrAccountsLes3(t *testing.T) { testOdr(t, 3, 1, true, odrAccounts) }
+func TestOdrAccountsLes4(t *testing.T) { testOdr(t, 4, 1, true, odrAccounts) }
func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte {
dummyAddr := common.HexToAddress("1234567812345678123456781234567812345678")
@@ -105,8 +105,8 @@ func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainCon
return res
}
-func TestOdrContractCallLes2(t *testing.T) { testOdr(t, 2, 2, true, odrContractCall) }
func TestOdrContractCallLes3(t *testing.T) { testOdr(t, 3, 2, true, odrContractCall) }
+func TestOdrContractCallLes4(t *testing.T) { testOdr(t, 4, 2, true, odrContractCall) }
type callmsg struct {
types.Message
@@ -155,8 +155,8 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai
return res
}
-func TestOdrTxStatusLes2(t *testing.T) { testOdr(t, 2, 1, false, odrTxStatus) }
func TestOdrTxStatusLes3(t *testing.T) { testOdr(t, 3, 1, false, odrTxStatus) }
+func TestOdrTxStatusLes4(t *testing.T) { testOdr(t, 4, 1, false, odrTxStatus) }
func odrTxStatus(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte {
var txs types.Transactions
@@ -236,7 +236,7 @@ func testOdr(t *testing.T, protocol int, expFail uint64, checkCached bool, fn od
// still expect all retrievals to pass, now data should be cached locally
if checkCached {
- client.handler.backend.peers.unregister(client.peer.speer.id)
+ client.handler.backend.peers.disconnect(client.peer.speer.id)
time.Sleep(time.Millisecond * 10) // ensure that all peerSetNotify callbacks are executed
test(5)
}
diff --git a/les/peer.go b/les/peer.go
index 28ec201bc9a7..40e9d25ea473 100644
--- a/les/peer.go
+++ b/les/peer.go
@@ -132,6 +132,7 @@ type peerCommons struct {
frozen uint32 // Flag whether the peer is frozen.
announceType uint64 // New block announcement type.
headInfo blockInfo // Latest block information.
+ active bool
// Background task queue for caching peer tasks and executing in order.
sendQueue *execQueue
@@ -478,12 +479,18 @@ func (p *serverPeer) requestTxStatus(reqID uint64, txHashes []common.Hash) error
return sendRequest(p.rw, GetTxStatusMsg, reqID, txHashes)
}
-// SendTxStatus creates a reply with a batch of transactions to be added to the remote transaction pool.
+// sendTxs creates a reply with a batch of transactions to be added to the remote transaction pool.
func (p *serverPeer) sendTxs(reqID uint64, txs rlp.RawValue) error {
p.Log().Debug("Sending batch of transactions", "size", len(txs))
return sendRequest(p.rw, SendTxV2Msg, reqID, txs)
}
+// sendLespay sends a set of commands to the service token sale module
+func (p *serverPeer) sendLespay(reqID uint64, cmd []byte) error {
+ p.Log().Debug("Sending batch of lespay commands", "size", len(cmd))
+ return sendRequest(p.rw, LespayMsg, reqID, cmd)
+}
+
// waitBefore implements distPeer interface
func (p *serverPeer) waitBefore(maxCost uint64) (time.Duration, float64) {
return p.fcServer.CanSend(maxCost)
@@ -554,10 +561,12 @@ func (p *serverPeer) updateFlowControl(update keyValueMap) {
// If any of the flow control params is nil, refuse to update.
var params flowcontrol.ServerParams
+ updated := false
if update.get("flowControl/BL", ¶ms.BufLimit) == nil && update.get("flowControl/MRR", ¶ms.MinRecharge) == nil {
// todo can light client set a minimal acceptable flow control params?
p.fcParams = params
p.fcServer.UpdateParams(params)
+ updated = true
}
var MRC RequestCostList
if update.get("flowControl/MRC", &MRC) == nil {
@@ -565,9 +574,20 @@ func (p *serverPeer) updateFlowControl(update keyValueMap) {
for code, cost := range costUpdate {
p.fcCosts[code] = cost
}
+ updated = true
+ }
+ if updated {
+ p.active = p.paramsUseful()
}
}
+// paramsUseful returns true if the server parameters ensure the minimum required
+// buffer limit and recharge
+func (p *serverPeer) paramsUseful() bool {
+ reqRecharge, reqBufLimit := p.fcCosts.reqParams()
+ return p.fcParams.MinRecharge >= reqRecharge && p.fcParams.BufLimit >= reqBufLimit
+}
+
// Handshake executes the les protocol handshake, negotiating version number,
// network IDs, difficulties, head and genesis blocks.
func (p *serverPeer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis common.Hash, server *LesServer) error {
@@ -614,6 +634,7 @@ func (p *serverPeer) Handshake(td *big.Int, head common.Hash, headNum uint64, ge
p.fcParams = sParams
p.fcServer = flowcontrol.NewServerNode(sParams, &mclock.System{})
p.fcCosts = MRC.decode(ProtocolLengths[uint(p.version)])
+ p.active = p.paramsUseful()
recv.get("checkpoint/value", &p.checkpoint)
recv.get("checkpoint/registerHeight", &p.checkpointNumber)
@@ -634,6 +655,9 @@ func (p *serverPeer) Handshake(td *big.Int, head common.Hash, headNum uint64, ge
type clientPeer struct {
peerCommons
+ activate, deactivate func()
+ getBalance func() uint64
+
// responseLock ensures that responses are queued in the same order as
// RequestProcessed is called
responseLock sync.Mutex
@@ -681,8 +705,8 @@ func (p *clientPeer) sendStop() error {
}
// sendResume notifies the client about getting out of frozen state
-func (p *clientPeer) sendResume(bv uint64) error {
- return p2p.Send(p.rw, ResumeMsg, bv)
+func (p *clientPeer) sendResume(sf stateFeedback) error {
+ return p2p.Send(p.rw, ResumeMsg, sf)
}
// freeze temporarily puts the client in a frozen state which means all unprocessed
@@ -711,7 +735,19 @@ func (p *clientPeer) freeze() {
continue
}
atomic.StoreUint32(&p.frozen, 0)
- p.sendResume(bufValue)
+ var balance uint64
+ if p.getBalance != nil {
+ balance = p.getBalance()
+ }
+ sf := stateFeedback{
+ protocolVersion: p.version,
+ stateFeedbackV4: stateFeedbackV4{
+ BV: bufValue,
+ RealCost: 0,
+ TokenBalance: balance,
+ },
+ }
+ p.sendResume(sf)
return
}
}()
@@ -728,12 +764,13 @@ type reply struct {
}
// send sends the reply with the calculated buffer value
-func (r *reply) send(bv uint64) error {
+func (r *reply) send(sf stateFeedback) error {
type resp struct {
- ReqID, BV uint64
- Data rlp.RawValue
+ ReqID uint64
+ SF stateFeedback
+ Data rlp.RawValue
}
- return p2p.Send(r.w, r.msgcode, resp{r.reqID, bv, r.data})
+ return p2p.Send(r.w, r.msgcode, resp{r.reqID, sf, r.data})
}
// size returns the RLP encoded size of the message data
@@ -786,6 +823,12 @@ func (p *clientPeer) replyTxStatus(reqID uint64, stats []light.TxStatus) *reply
return &reply{p.rw, TxStatusMsg, reqID, data}
}
+// replyLespay sends a set of replies to lespay commands
+func (p *clientPeer) replyLespay(reqID uint64, reply []byte, delay uint) error {
+ p.Log().Debug("Sending batch of lespay replies", "size", len(reply))
+ return sendRequest(p.rw, LespayReplyMsg, reqID, lespayReply{reply, delay})
+}
+
// sendAnnounce announces the availability of a number of blocks through
// a hash notification.
func (p *clientPeer) sendAnnounce(request announceData) error {
@@ -798,46 +841,21 @@ func (p *clientPeer) updateCapacity(cap uint64) {
p.lock.Lock()
defer p.lock.Unlock()
- p.fcParams = flowcontrol.ServerParams{MinRecharge: cap, BufLimit: cap * bufLimitRatio}
- p.fcClient.UpdateParams(p.fcParams)
- var kvList keyValueList
- kvList = kvList.add("flowControl/MRR", cap)
- kvList = kvList.add("flowControl/BL", cap*bufLimitRatio)
- p.mustQueueSend(func() { p.sendAnnounce(announceData{Update: kvList}) })
-}
-
-// freezeClient temporarily puts the client in a frozen state which means all
-// unprocessed and subsequent requests are dropped. Unfreezing happens automatically
-// after a short time if the client's buffer value is at least in the slightly positive
-// region. The client is also notified about being frozen/unfrozen with a Stop/Resume
-// message.
-func (p *clientPeer) freezeClient() {
- if p.version < lpv3 {
- // if Stop/Resume is not supported then just drop the peer after setting
- // its frozen status permanently
- atomic.StoreUint32(&p.frozen, 1)
- p.Peer.Disconnect(p2p.DiscUselessPeer)
- return
+ if !p.active && cap != 0 && p.activate != nil {
+ p.activate()
}
- if atomic.SwapUint32(&p.frozen, 1) == 0 {
- go func() {
- p.sendStop()
- time.Sleep(freezeTimeBase + time.Duration(rand.Int63n(int64(freezeTimeRandom))))
- for {
- bufValue, bufLimit := p.fcClient.BufferStatus()
- if bufLimit == 0 {
- return
- }
- if bufValue <= bufLimit/8 {
- time.Sleep(freezeCheckPeriod)
- } else {
- atomic.StoreUint32(&p.frozen, 0)
- p.sendResume(bufValue)
- break
- }
- }
- }()
+ if cap != 0 || p.version >= lpv4 {
+ p.fcParams = flowcontrol.ServerParams{MinRecharge: cap, BufLimit: cap * bufLimitRatio}
+ p.fcClient.UpdateParams(p.fcParams)
+ var kvList keyValueList
+ kvList = kvList.add("flowControl/BL", cap*bufLimitRatio)
+ kvList = kvList.add("flowControl/MRR", cap)
+ p.mustQueueSend(func() { p.sendAnnounce(announceData{Update: kvList}) })
}
+ if p.active && cap == 0 && p.deactivate != nil {
+ p.deactivate()
+ }
+
}
// Handshake executes the les protocol handshake, negotiating version number,
@@ -859,8 +877,14 @@ func (p *clientPeer) Handshake(td *big.Int, head common.Hash, headNum uint64, ge
*lists = (*lists).add("serveRecentState", stateRecent)
*lists = (*lists).add("txRelay", nil)
}
- *lists = (*lists).add("flowControl/BL", server.defParams.BufLimit)
- *lists = (*lists).add("flowControl/MRR", server.defParams.MinRecharge)
+ p.active = p.version < lpv4
+ if p.active {
+ p.fcParams = server.defParams
+ } else {
+ p.fcParams = flowcontrol.ServerParams{}
+ }
+ *lists = (*lists).add("flowControl/BL", p.fcParams.BufLimit)
+ *lists = (*lists).add("flowControl/MRR", p.fcParams.MinRecharge)
var costList RequestCostList
if server.costTracker.testCostList != nil {
@@ -870,7 +894,6 @@ func (p *clientPeer) Handshake(td *big.Int, head common.Hash, headNum uint64, ge
}
*lists = (*lists).add("flowControl/MRC", costList)
p.fcCosts = costList.decode(ProtocolLengths[uint(p.version)])
- p.fcParams = server.defParams
// Add advertised checkpoint and register block height which
// client can verify the checkpoint validity.
@@ -890,7 +913,7 @@ func (p *clientPeer) Handshake(td *big.Int, head common.Hash, headNum uint64, ge
// set default announceType on server side
p.announceType = announceTypeSimple
}
- p.fcClient = flowcontrol.NewClientNode(server.fcManager, server.defParams)
+ p.fcClient = flowcontrol.NewClientNode(server.fcManager, p.fcParams)
}
return nil
})
@@ -913,7 +936,7 @@ type clientPeerSubscriber interface {
// clientPeerSet represents the set of active client peers currently
// participating in the Light Ethereum sub-protocol.
type clientPeerSet struct {
- peers map[string]*clientPeer
+ active, inactive map[string]*clientPeer
// subscribers is a batch of subscribers and peerset will notify
// these subscribers when the peerset changes(new client peer is
// added or removed)
@@ -924,17 +947,23 @@ type clientPeerSet struct {
// newClientPeerSet creates a new peer set to track the client peers.
func newClientPeerSet() *clientPeerSet {
- return &clientPeerSet{peers: make(map[string]*clientPeer)}
+ return &clientPeerSet{
+ active: make(map[string]*clientPeer),
+ inactive: make(map[string]*clientPeer),
+ }
}
// subscribe adds a service to be notified about added or removed
// peers and also register all active peers into the given service.
func (ps *clientPeerSet) subscribe(sub clientPeerSubscriber) {
ps.lock.Lock()
- defer ps.lock.Unlock()
-
ps.subscribers = append(ps.subscribers, sub)
- for _, p := range ps.peers {
+ notify := make([]*clientPeer, 0, len(ps.active))
+ for _, p := range ps.active {
+ notify = append(notify, p)
+ }
+ ps.lock.Unlock()
+ for _, p := range notify {
sub.registerPeer(p)
}
}
@@ -954,19 +983,25 @@ func (ps *clientPeerSet) unSubscribe(sub clientPeerSubscriber) {
// register adds a new peer into the peer set, or returns an error if the
// peer is already known.
-func (ps *clientPeerSet) register(peer *clientPeer) error {
+func (ps *clientPeerSet) register(p *clientPeer) error {
ps.lock.Lock()
- defer ps.lock.Unlock()
-
if ps.closed {
+ ps.lock.Unlock()
return errClosed
}
- if _, exist := ps.peers[peer.id]; exist {
+ if _, ok := ps.active[p.id]; ok {
+ ps.lock.Unlock()
return errAlreadyRegistered
}
- ps.peers[peer.id] = peer
- for _, sub := range ps.subscribers {
- sub.registerPeer(peer)
+ delete(ps.inactive, p.id)
+ ps.active[p.id] = p
+
+ peers := make([]clientPeerSubscriber, len(ps.subscribers))
+ copy(peers, ps.subscribers)
+ ps.lock.Unlock()
+
+ for _, n := range peers {
+ n.registerPeer(p)
}
return nil
}
@@ -974,29 +1009,60 @@ func (ps *clientPeerSet) register(peer *clientPeer) error {
// unregister removes a remote peer from the peer set, disabling any further
// actions to/from that particular entity. It also initiates disconnection
// at the networking layer.
-func (ps *clientPeerSet) unregister(id string) error {
+func (ps *clientPeerSet) unregister(p *clientPeer) error {
ps.lock.Lock()
- defer ps.lock.Unlock()
+ if _, ok := ps.active[p.id]; !ok {
+ ps.lock.Unlock()
+ return errNotRegistered
+ } else {
+ delete(ps.active, p.id)
+ ps.inactive[p.id] = p
+ peers := make([]clientPeerSubscriber, len(ps.subscribers))
+ copy(peers, ps.subscribers)
+ ps.lock.Unlock()
+
+ for _, n := range peers {
+ n.unregisterPeer(p)
+ }
+ return nil
+ }
+}
- p, ok := ps.peers[id]
- if !ok {
+// disconnect removes a remote peer from either the active or inactive set and
+// initiates disconnection at the networking layer.
+func (ps *clientPeerSet) disconnect(id string) error {
+ ps.lock.Lock()
+
+ var (
+ peers []clientPeerSubscriber
+ p *clientPeer
+ ok bool
+ )
+ if p, ok = ps.active[id]; ok {
+ delete(ps.active, p.id)
+ peers = make([]clientPeerSubscriber, len(ps.subscribers))
+ copy(peers, ps.subscribers)
+ } else if p, ok = ps.inactive[id]; ok {
+ delete(ps.inactive, id)
+ } else {
+ ps.lock.Unlock()
return errNotRegistered
}
- delete(ps.peers, id)
- for _, sub := range ps.subscribers {
- sub.unregisterPeer(p)
+ ps.lock.Unlock()
+ for _, n := range peers {
+ n.unregisterPeer(p)
}
- p.Peer.Disconnect(p2p.DiscRequested)
+ p.Peer.Disconnect(p2p.DiscUselessPeer)
return nil
}
-// ids returns a list of all registered peer IDs
+// ids returns a list of all active peer IDs
func (ps *clientPeerSet) ids() []string {
ps.lock.RLock()
defer ps.lock.RUnlock()
var ids []string
- for id := range ps.peers {
+ for id := range ps.active {
ids = append(ids, id)
}
return ids
@@ -1007,24 +1073,27 @@ func (ps *clientPeerSet) peer(id string) *clientPeer {
ps.lock.RLock()
defer ps.lock.RUnlock()
- return ps.peers[id]
+ if p, ok := ps.active[id]; ok {
+ return p
+ }
+ return ps.inactive[id]
}
-// len returns if the current number of peers in the set.
+// len returns if the current number of peers in the active set.
func (ps *clientPeerSet) len() int {
ps.lock.RLock()
defer ps.lock.RUnlock()
- return len(ps.peers)
+ return len(ps.active)
}
-// allClientPeers returns all client peers in a list.
+// allClientPeers returns all active client peers in a list.
func (ps *clientPeerSet) allPeers() []*clientPeer {
ps.lock.RLock()
defer ps.lock.RUnlock()
- list := make([]*clientPeer, 0, len(ps.peers))
- for _, p := range ps.peers {
+ list := make([]*clientPeer, 0, len(ps.active))
+ for _, p := range ps.active {
list = append(list, p)
}
return list
@@ -1036,7 +1105,10 @@ func (ps *clientPeerSet) close() {
ps.lock.Lock()
defer ps.lock.Unlock()
- for _, p := range ps.peers {
+ for _, p := range ps.active {
+ p.Disconnect(p2p.DiscQuitting)
+ }
+ for _, p := range ps.inactive {
p.Disconnect(p2p.DiscQuitting)
}
ps.closed = true
@@ -1045,7 +1117,7 @@ func (ps *clientPeerSet) close() {
// serverPeerSet represents the set of active server peers currently
// participating in the Light Ethereum sub-protocol.
type serverPeerSet struct {
- peers map[string]*serverPeer
+ active, inactive map[string]*serverPeer
// subscribers is a batch of subscribers and peerset will notify
// these subscribers when the peerset changes(new server peer is
// added or removed)
@@ -1056,17 +1128,23 @@ type serverPeerSet struct {
// newServerPeerSet creates a new peer set to track the active server peers.
func newServerPeerSet() *serverPeerSet {
- return &serverPeerSet{peers: make(map[string]*serverPeer)}
+ return &serverPeerSet{
+ active: make(map[string]*serverPeer),
+ inactive: make(map[string]*serverPeer),
+ }
}
// subscribe adds a service to be notified about added or removed
// peers and also register all active peers into the given service.
func (ps *serverPeerSet) subscribe(sub serverPeerSubscriber) {
ps.lock.Lock()
- defer ps.lock.Unlock()
-
ps.subscribers = append(ps.subscribers, sub)
- for _, p := range ps.peers {
+ notify := make([]*serverPeer, 0, len(ps.active))
+ for _, p := range ps.active {
+ notify = append(notify, p)
+ }
+ ps.lock.Unlock()
+ for _, p := range notify {
sub.registerPeer(p)
}
}
@@ -1086,19 +1164,25 @@ func (ps *serverPeerSet) unSubscribe(sub serverPeerSubscriber) {
// register adds a new server peer into the set, or returns an error if the
// peer is already known.
-func (ps *serverPeerSet) register(peer *serverPeer) error {
+func (ps *serverPeerSet) register(p *serverPeer) error {
ps.lock.Lock()
- defer ps.lock.Unlock()
-
if ps.closed {
+ ps.lock.Unlock()
return errClosed
}
- if _, exist := ps.peers[peer.id]; exist {
+ if _, ok := ps.active[p.id]; ok {
+ ps.lock.Unlock()
return errAlreadyRegistered
}
- ps.peers[peer.id] = peer
- for _, sub := range ps.subscribers {
- sub.registerPeer(peer)
+ delete(ps.inactive, p.id)
+ ps.active[p.id] = p
+
+ peers := make([]serverPeerSubscriber, len(ps.subscribers))
+ copy(peers, ps.subscribers)
+ ps.lock.Unlock()
+
+ for _, n := range peers {
+ n.registerPeer(p)
}
return nil
}
@@ -1106,29 +1190,60 @@ func (ps *serverPeerSet) register(peer *serverPeer) error {
// unregister removes a remote peer from the active set, disabling any further
// actions to/from that particular entity. It also initiates disconnection at
// the networking layer.
-func (ps *serverPeerSet) unregister(id string) error {
+func (ps *serverPeerSet) unregister(p *serverPeer) error {
ps.lock.Lock()
- defer ps.lock.Unlock()
+ if _, ok := ps.active[p.id]; !ok {
+ ps.lock.Unlock()
+ return errNotRegistered
+ } else {
+ delete(ps.active, p.id)
+ ps.inactive[p.id] = p
+ peers := make([]serverPeerSubscriber, len(ps.subscribers))
+ copy(peers, ps.subscribers)
+ ps.lock.Unlock()
+
+ for _, n := range peers {
+ n.unregisterPeer(p)
+ }
+ return nil
+ }
+}
- p, ok := ps.peers[id]
- if !ok {
+// disconnect removes a remote peer from either the active or inactive set and
+// initiates disconnection at the networking layer.
+func (ps *serverPeerSet) disconnect(id string) error {
+ ps.lock.Lock()
+
+ var (
+ peers []serverPeerSubscriber
+ p *serverPeer
+ ok bool
+ )
+ if p, ok = ps.active[id]; ok {
+ delete(ps.active, p.id)
+ peers = make([]serverPeerSubscriber, len(ps.subscribers))
+ copy(peers, ps.subscribers)
+ } else if p, ok = ps.inactive[id]; ok {
+ delete(ps.inactive, id)
+ } else {
+ ps.lock.Unlock()
return errNotRegistered
}
- delete(ps.peers, id)
- for _, sub := range ps.subscribers {
- sub.unregisterPeer(p)
+ ps.lock.Unlock()
+ for _, n := range peers {
+ n.unregisterPeer(p)
}
- p.Peer.Disconnect(p2p.DiscRequested)
+ p.Peer.Disconnect(p2p.DiscUselessPeer)
return nil
}
-// ids returns a list of all registered peer IDs
+// ids returns a list of all active peer IDs
func (ps *serverPeerSet) ids() []string {
ps.lock.RLock()
defer ps.lock.RUnlock()
var ids []string
- for id := range ps.peers {
+ for id := range ps.active {
ids = append(ids, id)
}
return ids
@@ -1139,15 +1254,18 @@ func (ps *serverPeerSet) peer(id string) *serverPeer {
ps.lock.RLock()
defer ps.lock.RUnlock()
- return ps.peers[id]
+ if p, ok := ps.active[id]; ok {
+ return p
+ }
+ return ps.inactive[id]
}
-// len returns if the current number of peers in the set.
+// len returns if the current number of peers in the active set.
func (ps *serverPeerSet) len() int {
ps.lock.RLock()
defer ps.lock.RUnlock()
- return len(ps.peers)
+ return len(ps.active)
}
// bestPeer retrieves the known peer with the currently highest total difficulty.
@@ -1161,7 +1279,7 @@ func (ps *serverPeerSet) bestPeer() *serverPeer {
bestPeer *serverPeer
bestTd *big.Int
)
- for _, p := range ps.peers {
+ for _, p := range ps.active {
if td := p.Td(); bestTd == nil || td.Cmp(bestTd) > 0 {
bestPeer, bestTd = p, td
}
@@ -1169,13 +1287,13 @@ func (ps *serverPeerSet) bestPeer() *serverPeer {
return bestPeer
}
-// allServerPeers returns all server peers in a list.
+// allPeers returns all active server peers in a list.
func (ps *serverPeerSet) allPeers() []*serverPeer {
ps.lock.RLock()
defer ps.lock.RUnlock()
- list := make([]*serverPeer, 0, len(ps.peers))
- for _, p := range ps.peers {
+ list := make([]*serverPeer, 0, len(ps.active))
+ for _, p := range ps.active {
list = append(list, p)
}
return list
@@ -1187,7 +1305,10 @@ func (ps *serverPeerSet) close() {
ps.lock.Lock()
defer ps.lock.Unlock()
- for _, p := range ps.peers {
+ for _, p := range ps.active {
+ p.Disconnect(p2p.DiscQuitting)
+ }
+ for _, p := range ps.inactive {
p.Disconnect(p2p.DiscQuitting)
}
ps.closed = true
diff --git a/les/peer_test.go b/les/peer_test.go
index 59a2ad700954..a5da9a7a9682 100644
--- a/les/peer_test.go
+++ b/les/peer_test.go
@@ -85,7 +85,7 @@ func TestPeerSubscription(t *testing.T) {
checkIds([]string{peer.id})
checkPeers(sub.regCh)
- peers.unregister(peer.id)
+ peers.unregister(peer)
checkIds([]string{})
checkPeers(sub.unregCh)
}
diff --git a/les/protocol.go b/les/protocol.go
index 36af88aea6d0..5140795ecca3 100644
--- a/les/protocol.go
+++ b/les/protocol.go
@@ -33,17 +33,18 @@ import (
const (
lpv2 = 2
lpv3 = 3
+ lpv4 = 4
)
// Supported versions of the les protocol (first is primary)
var (
- ClientProtocolVersions = []uint{lpv2, lpv3}
- ServerProtocolVersions = []uint{lpv2, lpv3}
+ ClientProtocolVersions = []uint{lpv2, lpv3, lpv4}
+ ServerProtocolVersions = []uint{lpv2, lpv3, lpv4}
AdvertiseProtocolVersions = []uint{lpv2} // clients are searching for the first advertised protocol in the list
)
// Number of implemented message corresponding to different protocol versions.
-var ProtocolLengths = map[uint]uint64{lpv2: 22, lpv3: 24}
+var ProtocolLengths = map[uint]uint64{lpv2: 22, lpv3: 24, lpv4: 26}
const (
NetworkId = 1
@@ -74,6 +75,9 @@ const (
// Protocol messages introduced in LPV3
StopMsg = 0x16
ResumeMsg = 0x17
+ // Protocol messages introduced in LPV4
+ LespayMsg = 0x18
+ LespayReplyMsg = 0x19
)
type requestInfo struct {
@@ -201,6 +205,11 @@ type hashOrNumber struct {
Number uint64 // Block hash from which to retrieve headers (excludes Hash)
}
+type lespayReply struct {
+ Reply []byte
+ Delay uint
+}
+
// EncodeRLP is a specialized encoder for hashOrNumber to encode only one of the
// two contained union fields.
func (hn *hashOrNumber) EncodeRLP(w io.Writer) error {
@@ -235,3 +244,28 @@ func (hn *hashOrNumber) DecodeRLP(s *rlp.Stream) error {
type CodeData []struct {
Value []byte
}
+
+type stateFeedbackV4 struct {
+ BV, RealCost, TokenBalance uint64
+}
+
+type stateFeedback struct {
+ protocolVersion int
+ stateFeedbackV4
+}
+
+func (sf stateFeedback) EncodeRLP(w io.Writer) error {
+ if sf.protocolVersion >= lpv4 {
+ return rlp.Encode(w, sf.stateFeedbackV4)
+ } else {
+ return rlp.Encode(w, sf.BV)
+ }
+}
+
+func (sf *stateFeedback) DecodeRLP(s *rlp.Stream) error {
+ if sf.protocolVersion >= lpv4 {
+ return s.Decode(&sf.stateFeedbackV4)
+ } else {
+ return s.Decode(&sf.BV)
+ }
+}
diff --git a/les/request_test.go b/les/request_test.go
index f58ebca9c1d9..f5e370a6ba6d 100644
--- a/les/request_test.go
+++ b/les/request_test.go
@@ -36,22 +36,22 @@ func secAddr(addr common.Address) []byte {
type accessTestFn func(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest
-func TestBlockAccessLes2(t *testing.T) { testAccess(t, 2, tfBlockAccess) }
func TestBlockAccessLes3(t *testing.T) { testAccess(t, 3, tfBlockAccess) }
+func TestBlockAccessLes4(t *testing.T) { testAccess(t, 4, tfBlockAccess) }
func tfBlockAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest {
return &light.BlockRequest{Hash: bhash, Number: number}
}
-func TestReceiptsAccessLes2(t *testing.T) { testAccess(t, 2, tfReceiptsAccess) }
func TestReceiptsAccessLes3(t *testing.T) { testAccess(t, 3, tfReceiptsAccess) }
+func TestReceiptsAccessLes4(t *testing.T) { testAccess(t, 4, tfReceiptsAccess) }
func tfReceiptsAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest {
return &light.ReceiptsRequest{Hash: bhash, Number: number}
}
-func TestTrieEntryAccessLes2(t *testing.T) { testAccess(t, 2, tfTrieEntryAccess) }
func TestTrieEntryAccessLes3(t *testing.T) { testAccess(t, 3, tfTrieEntryAccess) }
+func TestTrieEntryAccessLes4(t *testing.T) { testAccess(t, 4, tfTrieEntryAccess) }
func tfTrieEntryAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest {
if number := rawdb.ReadHeaderNumber(db, bhash); number != nil {
@@ -60,8 +60,8 @@ func tfTrieEntryAccess(db ethdb.Database, bhash common.Hash, number uint64) ligh
return nil
}
-func TestCodeAccessLes2(t *testing.T) { testAccess(t, 2, tfCodeAccess) }
func TestCodeAccessLes3(t *testing.T) { testAccess(t, 3, tfCodeAccess) }
+func TestCodeAccessLes4(t *testing.T) { testAccess(t, 4, tfCodeAccess) }
func tfCodeAccess(db ethdb.Database, bhash common.Hash, num uint64) light.OdrRequest {
number := rawdb.ReadHeaderNumber(db, bhash)
diff --git a/les/retrieve.go b/les/retrieve.go
index 5fa68b745682..45976bc9eca0 100644
--- a/les/retrieve.go
+++ b/les/retrieve.go
@@ -345,7 +345,7 @@ func (r *sentReq) tryRequest() {
if hrto {
pp.Log().Debug("Request timed out hard")
if r.rm.peers != nil {
- r.rm.peers.unregister(pp.id)
+ r.rm.peers.disconnect(pp.id)
}
}
diff --git a/les/server.go b/les/server.go
index f72f31321abf..2430d1914a5f 100644
--- a/les/server.go
+++ b/les/server.go
@@ -44,6 +44,7 @@ type LesServer struct {
handler *serverHandler
lesTopics []discv5.Topic
privateKey *ecdsa.PrivateKey
+ srvr *p2p.Server
// Flow control and capacity management
fcManager *flowcontrol.ClientManager
@@ -51,6 +52,7 @@ type LesServer struct {
defParams flowcontrol.ServerParams
servingQueue *servingQueue
clientPool *clientPool
+ tokenSale *tokenSale
minCapacity, maxCapacity, freeCapacity uint64
threadsIdle int // Request serving threads count when system is idle.
@@ -116,8 +118,13 @@ func NewLesServer(e *eth.Ethereum, config *eth.Config) (*LesServer, error) {
srv.maxCapacity = totalRecharge
}
srv.fcManager.SetCapacityLimits(srv.freeCapacity, srv.maxCapacity, srv.freeCapacity*2)
- srv.clientPool = newClientPool(srv.chainDb, srv.freeCapacity, mclock.System{}, func(id enode.ID) { go srv.peers.unregister(peerIdToString(id)) })
+ srv.clientPool = newClientPool(srv.chainDb, srv.minCapacity, srv.freeCapacity, mclock.System{}, func(id enode.ID) { go srv.peers.disconnect(peerIdToString(id)) })
srv.clientPool.setDefaultFactors(priceFactors{0, 1, 1}, priceFactors{0, 1, 1})
+ srv.tokenSale = newTokenSale(srv.clientPool, 0.1, 100)
+ if config.LespayTestModule {
+ srv.tokenSale.addReceiver("test", testReceiver{})
+ srv.clientPool.setExpirationTCs(defaultPosExpTC, defaultNegExpTC)
+ }
checkpoint := srv.latestLocalCheckpoint()
if !checkpoint.Empty() {
@@ -148,6 +155,12 @@ func (s *LesServer) APIs() []rpc.API {
Service: NewPrivateDebugAPI(s),
Public: false,
},
+ {
+ Namespace: "lespay",
+ Version: "1.0",
+ Service: NewPrivateLespayAPI(s.peers, nil, nil, s.srvr.DiscV5, s.tokenSale),
+ Public: false,
+ },
}
}
@@ -167,6 +180,7 @@ func (s *LesServer) Protocols() []p2p.Protocol {
// Start starts the LES server
func (s *LesServer) Start(srvr *p2p.Server) {
+ s.srvr = srvr
s.privateKey = srvr.PrivateKey
s.handler.start()
@@ -174,6 +188,7 @@ func (s *LesServer) Start(srvr *p2p.Server) {
go s.capacityManagement()
if srvr.DiscV5 != nil {
+ srvr.DiscV5.RegisterTalkHandler("lespay", s.handler.talkRequestHandler)
for _, topic := range s.lesTopics {
topic := topic
go func() {
@@ -191,6 +206,11 @@ func (s *LesServer) Start(srvr *p2p.Server) {
func (s *LesServer) Stop() {
close(s.closeCh)
+ if s.srvr.DiscV5 != nil {
+ s.srvr.DiscV5.RemoveTalkHandler("lespay")
+ }
+ s.tokenSale.stop()
+
// Disconnect existing sessions.
// This also closes the gate for any new registrations on the peer set.
// sessions which are already established but not added to pm.peers yet
diff --git a/les/server_handler.go b/les/server_handler.go
index 186bdcbb03f4..b928c3fb0b24 100644
--- a/les/server_handler.go
+++ b/les/server_handler.go
@@ -20,6 +20,7 @@ import (
"encoding/binary"
"encoding/json"
"errors"
+ "net"
"sync"
"sync/atomic"
"time"
@@ -35,6 +36,7 @@ import (
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/metrics"
"github.com/ethereum/go-ethereum/p2p"
+ "github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie"
)
@@ -54,10 +56,7 @@ const (
MaxTxStatus = 256 // Amount of transactions to queried per request
)
-var (
- errTooManyInvalidRequest = errors.New("too many invalid requests made")
- errFullClientPool = errors.New("client pool is full")
-)
+var errTooManyInvalidRequest = errors.New("too many invalid requests made")
// serverHandler is responsible for serving light client and process
// all incoming light requests.
@@ -103,6 +102,9 @@ func (h *serverHandler) stop() {
func (h *serverHandler) runPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) error {
peer := newClientPeer(int(version), h.server.config.NetworkId, p, newMeteredMsgWriter(rw, int(version)))
defer peer.close()
+ peer.getBalance = func() uint64 {
+ return h.server.clientPool.getPosBalance(p.ID()).value.value(h.server.clientPool.posExpiration(mclock.Now()))
+ }
h.wg.Add(1)
defer h.wg.Done()
return h.handle(peer)
@@ -134,28 +136,58 @@ func (h *serverHandler) handle(p *clientPeer) error {
}
defer p.fcClient.Disconnect()
- // Disconnect the inbound peer if it's rejected by clientPool
- if !h.server.clientPool.connect(p, 0) {
- p.Log().Debug("Light Ethereum peer registration failed", "err", errFullClientPool)
- return errFullClientPool
+ var (
+ connectedAt mclock.AbsTime
+ wg *sync.WaitGroup // Wait group used to track all in-flight task routines.
+ )
+ p.activate = func() {
+ // Register the peer locally
+ if err := h.server.peers.register(p); err != nil {
+ h.server.clientPool.disconnect(p)
+ p.Log().Error("Light Ethereum peer registration failed", "err", err)
+ return
+ }
+ clientConnectionGauge.Update(int64(h.server.peers.len()))
+ connectedAt = mclock.Now()
+ wg = new(sync.WaitGroup)
+ p.active = true
}
- // Register the peer locally
- if err := h.server.peers.register(p); err != nil {
- h.server.clientPool.disconnect(p)
- p.Log().Error("Light Ethereum peer registration failed", "err", err)
- return err
+ p.deactivate = func() {
+ h.server.peers.unregister(p)
+ if p.version < lpv4 {
+ h.server.peers.disconnect(p.id)
+ }
+ clientConnectionGauge.Update(int64(h.server.peers.len()))
+ connectionTimer.Update(time.Duration(mclock.Now() - connectedAt))
+ p.active = false
+ }
+ if p.active {
+ p.activate()
}
- clientConnectionGauge.Update(int64(h.server.peers.len()))
- var wg sync.WaitGroup // Wait group used to track all in-flight task routines.
+ if capacity, err := h.server.clientPool.connect(p, 0); err != nil {
+ // Disconnect the inbound peer if it's rejected by clientPool
+ p.Log().Debug("Light Ethereum peer registration failed", "err", err)
+ return err
+ } else if capacity != p.fcParams.MinRecharge {
+ if p.version < lpv4 {
+ h.server.peers.disconnect(p.id)
+ } else {
+ p.updateCapacity(capacity)
+ }
+ }
- connectedAt := mclock.Now()
defer func() {
wg.Wait() // Ensure all background task routines have exited.
- h.server.peers.unregister(p.id)
h.server.clientPool.disconnect(p)
- clientConnectionGauge.Update(int64(h.server.peers.len()))
- connectionTimer.Update(time.Duration(mclock.Now() - connectedAt))
+ p.responseLock.Lock()
+ if p.active {
+ p.deactivate()
+ }
+ p.activate = nil
+ p.deactivate = nil
+ p.responseLock.Unlock()
+ h.server.peers.disconnect(p.id)
}()
// Spawn a main loop to handle all incoming messages.
@@ -166,7 +198,7 @@ func (h *serverHandler) handle(p *clientPeer) error {
return err
default:
}
- if err := h.handleMsg(p, &wg); err != nil {
+ if err := h.handleMsg(p, wg); err != nil {
p.Log().Debug("Light Ethereum message handling failed", "err", err)
return err
}
@@ -245,22 +277,33 @@ func (h *serverHandler) handleMsg(p *clientPeer, wg *sync.WaitGroup) error {
if reply != nil {
replySize = reply.size()
}
- var realCost uint64
+ var realCost, balance uint64
if h.server.costTracker.testing {
realCost = maxCost // Assign a fake cost for testing purpose
} else {
realCost = h.server.costTracker.realCost(servingTime, msg.Size, replySize)
+ if realCost > maxCost {
+ realCost = maxCost
+ }
}
bv := p.fcClient.RequestProcessed(reqID, responseCount, maxCost, realCost)
if amount != 0 {
// Feed cost tracker request serving statistic.
h.server.costTracker.updateStats(msg.Code, amount, servingTime, realCost)
// Reduce priority "balance" for the specific peer.
- h.server.clientPool.requestCost(p, realCost)
+ balance = h.server.clientPool.requestCost(p, realCost)
+ }
+ sf := stateFeedback{
+ protocolVersion: p.version,
+ stateFeedbackV4: stateFeedbackV4{
+ BV: bv,
+ RealCost: realCost,
+ TokenBalance: balance,
+ },
}
if reply != nil {
p.mustQueueSend(func() {
- if err := reply.send(bv); err != nil {
+ if err := reply.send(sf); err != nil {
select {
case p.errCh <- err:
default:
@@ -374,7 +417,7 @@ func (h *serverHandler) handleMsg(p *clientPeer, wg *sync.WaitGroup) error {
first = false
}
reply := p.replyBlockHeaders(req.ReqID, headers)
- sendResponse(req.ReqID, query.Amount, p.replyBlockHeaders(req.ReqID, headers), task.done())
+ sendResponse(req.ReqID, query.Amount, reply, task.done())
if metrics.EnabledExpensive {
miscOutHeaderPacketsMeter.Mark(1)
miscOutHeaderTrafficMeter.Mark(int64(reply.size()))
@@ -824,6 +867,38 @@ func (h *serverHandler) handleMsg(p *clientPeer, wg *sync.WaitGroup) error {
}
}()
}
+ case LespayMsg:
+ p.Log().Trace("Received transaction status query request")
+ if metrics.EnabledExpensive {
+ miscInLespayPacketsMeter.Mark(1)
+ miscInLespayTrafficMeter.Mark(int64(msg.Size))
+ defer func(start time.Time) { miscServingTimeLespayTimer.UpdateSince(start) }(time.Now())
+ }
+ var req struct {
+ ReqID uint64
+ Cmd []byte
+ }
+ if err := msg.Decode(&req); err != nil {
+ clientErrorMeter.Mark(1)
+ return errResp(ErrDecode, "msg %v: %v", msg, err)
+ }
+ if !h.server.tokenSale.queueCommand(p.id, lespayCmd{
+ cmd: req.Cmd,
+ id: p.ID(),
+ freeID: p.freeClientId(),
+ send: func(reply []byte, delay uint) {
+ if metrics.EnabledExpensive {
+ miscOutLespayPacketsMeter.Mark(1)
+ miscOutLespayTrafficMeter.Mark(int64(len(reply)))
+ }
+ p.queueSend(func() {
+ p.replyLespay(req.ReqID, reply, delay)
+ })
+ },
+ }) {
+ clientErrorMeter.Mark(1)
+ return errResp(ErrRequestRejected, "")
+ }
default:
p.Log().Trace("Received invalid message", "code", msg.Code)
@@ -959,3 +1034,49 @@ func (h *serverHandler) broadcastHeaders() {
}
}
}
+
+// talkRequestHandler implements discv5.TalkRequestHandler. It processes a list of
+// lespay token sale commands and returns the results and the recommended delay.
+//
+// Note: the UDP talk format for lespay commands allows multiple commands in a single
+// packet because UDP does not guarantee the correct order of messages which might be
+// important in some cases (like deposit followed by buyTokens).
+func (h *serverHandler) talkRequestHandler(id enode.ID, addr *net.UDPAddr, payload interface{}) (interface{}, uint, bool) {
+ c, ok := payload.([]interface{})
+ if !ok {
+ return nil, 0, false
+ }
+ type result struct {
+ data []byte
+ delay uint
+ }
+ resultCh := make(chan result, len(c))
+ results := make([][]byte, len(c))
+ for _, c := range c {
+ cmd, ok := c.([]byte)
+ if !ok {
+ return nil, 0, false
+ }
+ if !h.server.tokenSale.queueCommand(id.String(), lespayCmd{
+ cmd: cmd,
+ id: id,
+ freeID: addr.IP.String(),
+ send: func(reply []byte, delay uint) {
+ resultCh <- result{reply, delay}
+ },
+ }) {
+ return nil, 0, false
+ }
+ }
+
+ var lastDelay uint
+ for i := range results {
+ select {
+ case r := <-resultCh:
+ results[i], lastDelay = r.data, r.delay
+ case <-h.closeCh:
+ return nil, 0, false
+ }
+ }
+ return results, lastDelay, true
+}
diff --git a/les/servingqueue.go b/les/servingqueue.go
index 9db84e6159cf..b1a8dda80b3d 100644
--- a/les/servingqueue.go
+++ b/les/servingqueue.go
@@ -70,7 +70,7 @@ type runToken chan struct{}
// start blocks until the task can start and returns true if it is allowed to run.
// Returning false means that the task should be cancelled.
func (t *servingTask) start() bool {
- if t.peer.isFrozen() {
+ if t.peer != nil && t.peer.isFrozen() {
return false
}
t.tokenCh = make(chan runToken, 1)
@@ -289,7 +289,7 @@ func (sq *servingQueue) addTask(task *servingTask) {
sq.queuedTime += task.expTime
sqServedGauge.Update(int64(sq.recentTime))
sqQueuedGauge.Update(int64(sq.queuedTime))
- if sq.recentTime+sq.queuedTime > sq.burstLimit {
+ if sq.burstLimit != 0 && sq.recentTime+sq.queuedTime > sq.burstLimit {
sq.freezePeers()
}
}
diff --git a/les/test_helper.go b/les/test_helper.go
index d9ffe32db205..da8222add4f9 100644
--- a/les/test_helper.go
+++ b/les/test_helper.go
@@ -78,10 +78,10 @@ var (
processConfirms = big.NewInt(1)
// The token bucket buffer limit for testing purpose.
- testBufLimit = uint64(1000000)
+ testBufLimit = uint64(6000)
// The buffer recharging speed for testing purpose.
- testBufRecharge = uint64(1000)
+ testBufRecharge = uint64(1)
)
/*
@@ -281,7 +281,7 @@ func newTestServerHandler(blocks int, indexers []*core.ChainIndexer, db ethdb.Da
}
server.costTracker, server.freeCapacity = newCostTracker(db, server.config)
server.costTracker.testCostList = testCostList(0) // Disable flow control mechanism.
- server.clientPool = newClientPool(db, 1, clock, nil)
+ server.clientPool = newClientPool(db, 1, 1, clock, nil)
server.clientPool.setLimits(10000, 10000) // Assign enough capacity for clientpool
server.handler = newServerHandler(server, simulation.Blockchain(), db, txpool, func() bool { return true })
if server.oracle != nil {
@@ -309,7 +309,8 @@ func newTestPeer(t *testing.T, name string, version int, handler *serverHandler,
// Generate a random id and create the peer
var id enode.ID
rand.Read(id[:])
- peer := newClientPeer(version, NetworkId, p2p.NewPeer(id, name, nil), net)
+ cpeer := newClientPeer(version, NetworkId, p2p.NewPeer(id, name, nil), net)
+ speer := newServerPeer(version, NetworkId, false, p2p.NewPeer(id, name, nil), app)
// Start the peer on a new thread
errCh := make(chan error, 1)
@@ -317,13 +318,14 @@ func newTestPeer(t *testing.T, name string, version int, handler *serverHandler,
select {
case <-handler.closeCh:
errCh <- p2p.DiscQuitting
- case errCh <- handler.handle(peer):
+ case errCh <- handler.handle(cpeer):
}
}()
tp := &testPeer{
app: app,
net: net,
- cpeer: peer,
+ cpeer: cpeer,
+ speer: speer,
}
// Execute any implicitly requested handshakes and return
if shake {
@@ -395,8 +397,13 @@ func (p *testPeer) handshake(t *testing.T, td *big.Int, head common.Hash, headNu
expList = expList.add("serveStateSince", uint64(0))
expList = expList.add("serveRecentState", uint64(core.TriesInMemory-4))
expList = expList.add("txRelay", nil)
- expList = expList.add("flowControl/BL", testBufLimit)
- expList = expList.add("flowControl/MRR", testBufRecharge)
+ if p.cpeer.version >= lpv4 {
+ expList = expList.add("flowControl/BL", uint64(0))
+ expList = expList.add("flowControl/MRR", uint64(0))
+ } else {
+ expList = expList.add("flowControl/BL", testBufLimit)
+ expList = expList.add("flowControl/MRR", testBufRecharge)
+ }
expList = expList.add("flowControl/MRC", costList)
if err := p2p.ExpectMsg(p.app, StatusMsg, expList); err != nil {
@@ -405,9 +412,16 @@ func (p *testPeer) handshake(t *testing.T, td *big.Int, head common.Hash, headNu
if err := p2p.Send(p.app, StatusMsg, sendList); err != nil {
t.Fatalf("status send: %v", err)
}
- p.cpeer.fcParams = flowcontrol.ServerParams{
- BufLimit: testBufLimit,
- MinRecharge: testBufRecharge,
+}
+
+func (p *testPeer) expectCapUpdate(t *testing.T) {
+ if p.cpeer.version >= lpv4 {
+ var expList keyValueList
+ expList = expList.add("flowControl/BL", testBufLimit)
+ expList = expList.add("flowControl/MRR", testBufRecharge)
+ if err := p2p.ExpectMsg(p.app, AnnounceMsg, announceData{Update: expList}); err != nil {
+ t.Fatalf("status recv: %v", err)
+ }
}
}
@@ -438,7 +452,7 @@ type testServer struct {
bloomTrieIndexer *core.ChainIndexer
}
-func newServerEnv(t *testing.T, blocks int, protocol int, callback indexerCallback, simClock bool, newPeer bool, testCost uint64) (*testServer, func()) {
+func newServerEnv(t *testing.T, blocks int, protocol int, callback indexerCallback, simClock bool, newPeer bool, testCost uint64, expectCapUpdate bool) (*testServer, func()) {
db := rawdb.NewMemoryDatabase()
indexers := testIndexers(db, nil, light.TestServerIndexerConfig)
@@ -480,6 +494,9 @@ func newServerEnv(t *testing.T, blocks int, protocol int, callback indexerCallba
cIndexer.Close()
bIndexer.Close()
}
+ if expectCapUpdate {
+ server.peer.expectCapUpdate(t)
+ }
return server, teardown
}
diff --git a/les/tokensale.go b/les/tokensale.go
new file mode 100644
index 000000000000..099058fa7138
--- /dev/null
+++ b/les/tokensale.go
@@ -0,0 +1,860 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package les
+
+import (
+ "encoding/binary"
+ "fmt"
+ "math"
+ "strconv"
+ "sync"
+ "time"
+
+ "github.com/ethereum/go-ethereum/common/mclock"
+ "github.com/ethereum/go-ethereum/ethdb"
+ "github.com/ethereum/go-ethereum/p2p/enode"
+ "github.com/ethereum/go-ethereum/rlp"
+)
+
+const (
+ basePriceTC = time.Hour * 10 // time constant for controlling the base price
+ tokenSellMaxRatio = 0.9 // total amount/supply limit ratio over which selling price does not increase further
+ tsMinDelay = time.Second * 5 // minimum recommended delay for sending the next command
+ tsMaxBurst = 16 // maximum commands processed in a row before the recommended delay has elapsed
+)
+
+// paymentReceiver processes incoming payments and can be implemented using different
+// payment technologies
+type paymentReceiver interface {
+ info() keyValueList
+ receivePayment(from enode.ID, proofOfPayment []byte, reader ethdb.KeyValueReader, writer ethdb.KeyValueWriter) (value uint64, err error)
+ requestPayment(from enode.ID, value uint64, reader ethdb.KeyValueReader) uint64
+}
+
+// tokenSale handles client balance deposits, conversion to and from service tokens
+// and granting connections and capacity changes through a set of commands called "lespay".
+type tokenSale struct {
+ lock sync.Mutex
+ clientPool *clientPool
+ stopCh chan struct{}
+ receivers map[string]paymentReceiver
+ receiverNames []string
+ basePrice, minBasePrice float64
+ totalTokenLimit, totalTokenAmount func() uint64
+
+ qlock sync.Mutex
+ sq *servingQueue
+ sources map[string]*cmdSource
+ delayFactorZero, delayFactorLast mclock.AbsTime
+ tsProcessDelay, tsTargetPeriod time.Duration
+}
+
+// newTokenSale creates a new token sale module instance
+func newTokenSale(clientPool *clientPool, minBasePrice float64, talkSpeed int) *tokenSale {
+ t := &tokenSale{
+ clientPool: clientPool,
+ receivers: make(map[string]paymentReceiver),
+ basePrice: minBasePrice,
+ minBasePrice: minBasePrice,
+ totalTokenLimit: clientPool.totalTokenLimit,
+ totalTokenAmount: clientPool.totalTokenAmount,
+ stopCh: make(chan struct{}),
+ sq: newServingQueue(0, 0),
+ sources: make(map[string]*cmdSource),
+ delayFactorZero: mclock.Now(),
+ delayFactorLast: mclock.Now(),
+ tsProcessDelay: time.Second / time.Duration(talkSpeed),
+ tsTargetPeriod: 5 * time.Second / time.Duration(talkSpeed),
+ }
+ t.sq.setThreads(1)
+ go func() {
+ cleanupCounter := 0
+ for {
+ select {
+ case <-time.After(time.Second * 10):
+ t.lock.Lock()
+ cost, ok := t.tokenPrice(1, true)
+ if cost > t.basePrice*10 || !ok {
+ cost = t.basePrice * 10
+ }
+ t.basePrice += (cost - t.basePrice) * float64(time.Second*10) / float64(basePriceTC)
+ if t.basePrice < minBasePrice {
+ t.basePrice = minBasePrice
+ }
+ t.lock.Unlock()
+
+ cleanupCounter++
+ if cleanupCounter == 100 {
+ t.sourceMapCleanup()
+ cleanupCounter = 0
+ }
+ case <-t.stopCh:
+ return
+ }
+ }
+ }()
+ return t
+}
+
+type (
+ // cmdSource represents a source where lespay commands can come from.
+ // It can be either an LES connected peer or a UDP address.
+ cmdSource struct {
+ ch chan lespayCmd
+ delayUntil mclock.AbsTime
+ burstCounter int
+ }
+ // lespayCmd represents a single lespay command, including the source it came
+ // from and the callback that is going to process the results.
+ lespayCmd struct {
+ cmd []byte
+ id enode.ID
+ freeID string
+ send func([]byte, uint)
+ }
+)
+
+// priority returns the processing priority for the next command coming from the given
+// source. Commands sent before the previously recommended delay has elapsed have a
+// lower priority. It also checks whether the number of commands consecutively sent
+// before the delay has elapsed exceeds maxBurst and rejects the command instantly if
+// necessary.
+func (c *cmdSource) priority() (int64, bool) {
+ dt := c.delayUntil - mclock.Now()
+ if dt <= 0 {
+ c.burstCounter = 0
+ return 0, true
+ }
+ if c.burstCounter >= tsMaxBurst {
+ return 0, false
+ }
+ c.burstCounter++
+ return -int64(dt), true
+}
+
+// addDelay adds the given amount to the recommended delay
+func (c *cmdSource) addDelay(now mclock.AbsTime, delay time.Duration) uint {
+ dt := time.Duration(c.delayUntil - now)
+ if dt <= 0 {
+ dt = 0
+ }
+ dt += delay
+ if dt < tsMinDelay {
+ dt = tsMinDelay
+ }
+ c.delayUntil = now + mclock.AbsTime(dt)
+ return uint((dt + time.Second - 1) / time.Second)
+}
+
+// delayFactor calculates the amount added to the recommended delay after processing
+// a single command
+func (t *tokenSale) delayFactor(now mclock.AbsTime) time.Duration {
+ if now > t.delayFactorZero {
+ t.delayFactorZero = now
+ }
+ t.delayFactorZero += mclock.AbsTime(t.tsTargetPeriod) + t.delayFactorLast - now
+ t.delayFactorLast = now
+ if now >= t.delayFactorZero {
+ return 0
+ } else {
+ return time.Duration(t.delayFactorZero-now) / 4
+ }
+}
+
+// sourceMapCleanup removes unnecessary entries from the command source map
+func (t *tokenSale) sourceMapCleanup() {
+ t.qlock.Lock()
+ defer t.qlock.Unlock()
+
+ now := mclock.Now()
+ for src, s := range t.sources {
+ if s.delayUntil < now {
+ delete(t.sources, src)
+ }
+ }
+}
+
+// queueCommand schedules a lespay command (encapsulated in a lespayCmd) for execution
+func (t *tokenSale) queueCommand(src string, cmd lespayCmd) bool {
+ t.qlock.Lock()
+ defer t.qlock.Unlock()
+
+ s := t.sources[src]
+ if s == nil {
+ s = &cmdSource{}
+ t.sources[src] = s
+ }
+ if s.ch != nil {
+ select {
+ case s.ch <- cmd:
+ return true
+ default:
+ return false
+ }
+ }
+ s.ch = make(chan lespayCmd, 16)
+ s.ch <- cmd
+
+ go func() {
+ loop:
+ for {
+ select {
+ case cmd := <-s.ch:
+ t.qlock.Lock()
+ pri, ok := s.priority()
+ t.qlock.Unlock()
+ if ok {
+ task := t.sq.newTask(nil, 0, pri)
+ if !task.start() {
+ break loop
+ }
+ reply := t.runCommand(cmd.cmd, cmd.id, cmd.freeID)
+ t.qlock.Lock()
+ now := mclock.Now()
+ delay := s.addDelay(now, t.delayFactor(now))
+ t.qlock.Unlock()
+ cmd.send(reply, delay)
+ time.Sleep(t.tsProcessDelay)
+ task.done()
+ } else {
+ cmd.send(nil, 0)
+ }
+ default:
+ break loop
+ }
+ t.qlock.Lock()
+ s.ch = nil
+ t.qlock.Unlock()
+ }
+ }()
+ return true
+}
+
+// stop stops the token sale module
+func (t *tokenSale) stop() {
+ close(t.stopCh)
+ t.sq.stop()
+}
+
+// addReceiver adds a new payment receiver module
+func (t *tokenSale) addReceiver(id string, r paymentReceiver) {
+ t.lock.Lock()
+ defer t.lock.Unlock()
+
+ t.receivers[id] = r
+ t.receiverNames = append(t.receiverNames, id)
+}
+
+// tokenPrice returns the PC units required to buy the specified amount of service
+// tokens or the PC units received when selling the given amount of tokens.
+// Returns false if not possible.
+//
+// Note: the price of each token unit depends on the current amount of existing tokens
+// and the total token limit, first raising from 0 to basePrice linearly, then tends to
+// infinity as tokenAmount approaches tokenLimit.
+//
+// if 0 <= tokenAmount <= tokenLimit/2:
+// tokenPrice = basePrice*tokenAmount/(tokenLimit/2)
+// if tokenLimit/2 <= tokenAmount < tokenLimit:
+// tokenPrice = basePrice*tokenLimit/2/(tokenLimit-tokenAmount)
+//
+// The price of multiple tokens is calculated as an integral based on the above formula.
+func (t *tokenSale) tokenPrice(buySellAmount uint64, buy bool) (float64, bool) {
+ tokenLimit := t.totalTokenLimit()
+ tokenAmount := t.totalTokenAmount()
+ if buy {
+ if tokenAmount+buySellAmount >= tokenLimit {
+ return 0, false
+ }
+ } else {
+ maxAmount := uint64(float64(tokenLimit) * tokenSellMaxRatio)
+ if tokenAmount > maxAmount {
+ tokenAmount = maxAmount
+ }
+ if tokenAmount < buySellAmount {
+ buySellAmount = tokenAmount
+ }
+ tokenAmount -= buySellAmount
+ }
+ r := float64(tokenAmount) / float64(tokenLimit)
+ b := float64(buySellAmount) / float64(tokenLimit)
+ var relPrice float64
+ if r < 0.5 {
+ // first purchased token is in the linear range
+ if r+b <= 0.5 {
+ // all purchased tokens are in the linear range
+ relPrice = b * (r + r + b)
+ b = 0
+ } else {
+ // some purchased tokens are in the 1/x range, calculate linear price
+ // update starting point and amount left to buy in the 1/x range
+ relPrice = (0.5 - r) * (r + 0.5)
+ b = r + b - 0.5
+ r = 0.5
+ }
+ }
+ if b > 0 {
+ // some purchased tokens are in the 1/x range
+ l := 1 - r
+ if l < 1e-10 {
+ return 0, false
+ }
+ l = -b / l
+ if l < -1+1e-10 {
+ return 0, false
+ }
+ relPrice += -math.Log1p(l) / 2
+ }
+ return t.basePrice * float64(tokenLimit) * relPrice, true
+}
+
+// tokenBuyAmount returns the service token amount currently available for the given
+// sum of PC units
+func (t *tokenSale) tokenBuyAmount(price float64) uint64 {
+ tokenLimit := t.totalTokenLimit()
+ tokenAmount := t.totalTokenAmount()
+ if tokenLimit <= tokenAmount {
+ return 0
+ }
+ r := float64(tokenAmount) / float64(tokenLimit)
+ c := price / (t.basePrice * float64(tokenLimit))
+ var relTokens float64
+ if r < 0.5 {
+ // first purchased token is in the linear range
+ relTokens = math.Sqrt(r*r+c) - r
+ if r+relTokens <= 0.5 {
+ // all purchased tokens are in the linear range, no more to spend
+ c = 0
+ } else {
+ // some purchased tokens are in the 1/x range, calculate linear amount
+ // update starting point and available funds left to buy in the 1/x range
+ relTokens = 0.5 - r
+ c -= (0.5 - r) * (r + 0.5)
+ r = 0.5
+ }
+ }
+ if c > 0 {
+ relTokens -= math.Expm1(-2*c) * (1 - r)
+ }
+ return uint64(relTokens * float64(tokenLimit))
+}
+
+// tokenSellAmount returns the service token amount that needs to be sold in order
+// to receive the given sum of PC units. Returns false if not possible.
+func (t *tokenSale) tokenSellAmount(price float64) (uint64, bool) {
+ tokenLimit := t.totalTokenLimit()
+ tokenAmount := t.totalTokenAmount()
+ r := float64(tokenAmount) / float64(tokenLimit)
+ if r > tokenSellMaxRatio {
+ r = tokenSellMaxRatio
+ }
+ c := price / (t.basePrice * float64(tokenLimit))
+ var relTokens float64
+ if r > 0.5 {
+ // first sold token is in the 1/x range
+ relTokens = math.Expm1(2*c) * (1 - r)
+ if r-relTokens >= 0.5 || 1-r < 1e-10 {
+ // all sold tokens are in the 1/x range, no more to sell
+ c = 0
+ } else {
+ // some sold tokens are in the linear range, calculate price in 1/x range
+ // update starting point and remaining price to sell for in the linear range
+ relTokens = r - 0.5
+ c -= math.Log1p(relTokens/(1-r)) / 2
+ r = 0.5
+ }
+ }
+ if c > 0 {
+ // some sold tokens are in the linear range
+ if x := r*r - c; x >= 0 {
+ relTokens += r - math.Sqrt(x)
+ } else {
+ return 0, false
+ }
+ }
+ return uint64(relTokens * float64(tokenLimit)), true
+}
+
+// connection checks whether it is possible with the current balance levels to establish
+// requested connection or capacity change and then stay connected for the given amount
+// of time. If it is possible and setCap is also true then the client is activated of the
+// capacity change is performed. If not then returns how many tokens are missing and how
+// much that would currently cost using the specified payment module(s).
+func (t *tokenSale) connection(id enode.ID, freeID string, requestedCapacity uint64, stayConnected time.Duration, paymentModule []string, setCap bool) (availableCapacity, tokenBalance, tokensMissing, pcBalance, pcMissing uint64, paymentRequired []uint64, err error) {
+ t.lock.Lock()
+ defer t.lock.Unlock()
+
+ tokensMissing, availableCapacity, err = t.clientPool.setCapacityLocked(id, freeID, requestedCapacity, stayConnected, setCap)
+ pb := t.clientPool.getPosBalance(id)
+ tokenBalance = pb.value.value(t.clientPool.posExpiration(mclock.Now()))
+ cb := t.clientPool.ndb.getCurrencyBalance(id)
+ pcBalance = cb.amount
+ if tokensMissing == 0 {
+ return
+ }
+ tokenLimit := t.clientPool.totalTokenLimit()
+ tokenAmount := t.clientPool.totalTokenAmount()
+ if tokenLimit <= tokenAmount || tokenLimit-tokenAmount <= tokensMissing {
+ pcMissing = math.MaxUint64
+ } else {
+ tokensAvailable := tokenLimit - tokenAmount
+ pcr := -math.Log(float64(tokensAvailable-tokensMissing)/float64(tokensAvailable)) * t.basePrice
+ if pcr > 0 {
+ if pcr > maxBalance {
+ pcMissing = math.MaxUint64
+ } else {
+ pcMissing = uint64(pcr)
+ if pcMissing > maxBalance {
+ pcMissing = math.MaxUint64
+ } else {
+ if pcMissing > pcBalance {
+ pcMissing -= pcBalance
+ } else {
+ pcMissing = 0
+ }
+ }
+ }
+ }
+ }
+ if pcMissing == 0 {
+ return
+ }
+ paymentRequired = make([]uint64, len(paymentModule))
+ for i, recID := range paymentModule {
+ if rec, ok := t.receivers[recID]; !ok || pcMissing == math.MaxUint64 {
+ paymentRequired[i] = math.MaxUint64
+ } else {
+ paymentRequired[i] = rec.requestPayment(id, pcMissing, newReaderTable(t.clientPool.ndb.db, receiverPrefix(id, recID)))
+ }
+ }
+ return
+}
+
+// deposit credits a payment on the sender's account using the specified payment module
+func (t *tokenSale) deposit(id enode.ID, paymentModule string, proofOfPayment []byte) (pcValue, pcBalance uint64, err error) {
+ writer := t.clientPool.ndb.atomicWriteLock(id.Bytes())
+ t.lock.Lock()
+ defer func() {
+ t.lock.Unlock()
+ t.clientPool.ndb.atomicWriteUnlock(id.Bytes())
+ }()
+
+ cb := t.clientPool.ndb.getCurrencyBalance(id)
+ pcBalance = cb.amount
+ pm := t.receivers[paymentModule]
+ if pm == nil {
+ return 0, pcBalance, fmt.Errorf("Unknown payment receiver '%s'", paymentModule)
+ }
+ prefix := receiverPrefix(id, paymentModule)
+ pcValue, err = pm.receivePayment(id, proofOfPayment, newReaderTable(t.clientPool.ndb.db, prefix), newWriterTable(writer, prefix))
+ if err != nil {
+ return 0, pcBalance, err
+ }
+ pcBalance += pcValue
+ cb.amount = pcBalance
+ t.clientPool.ndb.setCurrencyBalance(id, cb)
+ return
+}
+
+// buyTokens tries to convert the permanent balance (nominated in the server's preferred
+// currency, PC) to service tokens. If spendAll is true then it sells the maxSpend amount
+// of PC coins if the received service token amount is at least minReceive. If spendAll is
+// false then is buys minReceive amount of tokens if it does not cost more than maxSpend
+// amount of PC coins.
+// if relative is true then maxSpend and minReceive are specified relative to their current
+// balances. In this case maxSpend represents the amount under which the PC balance should
+// not go and minReceive represents the amount the service token balance should reach.
+// This mode is useful when actual conversion is intended to happen and the sender has to
+// retry the command after not receiving a reply previously. In this case the sender cannot
+// be sure whether the conversion has already happened or not. If relative is true then it
+// is impossible to do a conversion twice. In exchange the sender needs to know its current
+// balances (which it probably does if it has made a previous call to just ask the current price).
+func (t *tokenSale) buyTokens(id enode.ID, maxSpend, minReceive uint64, relative, spendAll bool) (pcBalance, tokenBalance, spend, receive uint64, success bool) {
+ t.clientPool.ndb.atomicWriteLock(id.Bytes())
+ t.lock.Lock()
+ defer func() {
+ t.lock.Unlock()
+ t.clientPool.ndb.atomicWriteUnlock(id.Bytes())
+ }()
+
+ pb := t.clientPool.getPosBalance(id)
+ tokenBalance = pb.value.value(t.clientPool.posExpiration(mclock.Now()))
+ cb := t.clientPool.ndb.getCurrencyBalance(id)
+ pcBalance = cb.amount
+ if relative {
+ if pcBalance > maxSpend {
+ maxSpend = pcBalance - maxSpend
+ } else {
+ maxSpend = 0
+ }
+ if minReceive > tokenBalance {
+ minReceive -= tokenBalance
+ } else {
+ minReceive = 0
+ }
+ }
+
+ if maxSpend > pcBalance {
+ maxSpend = pcBalance
+ }
+ if spendAll {
+ spend = maxSpend
+ receive = t.tokenBuyAmount(float64(spend))
+ success = receive >= minReceive
+ } else {
+ receive = minReceive
+ if cost, ok := t.tokenPrice(receive, true); ok {
+ spend = uint64(cost) + 1 // ensure that we don't sell small amounts for free
+ } else {
+ spend = math.MaxUint64
+ }
+ success = spend <= maxSpend
+ }
+ if success {
+ pcBalance -= spend
+ cb.amount = pcBalance
+ tokenBalance += receive
+ t.clientPool.ndb.setCurrencyBalance(id, cb)
+ t.clientPool.addBalance(id, int64(receive))
+ }
+ return
+}
+
+// sellTokens tries to convert service tokens to permanent balance (nominated in the server's
+// preferred currency, PC). Parameters work similarly to buyTokens.
+func (t *tokenSale) sellTokens(id enode.ID, maxSell, minRefund uint64, relative, sellAll bool) (pcBalance, tokenBalance, sell, refund uint64, success bool) {
+ t.clientPool.ndb.atomicWriteLock(id.Bytes())
+ t.lock.Lock()
+ defer func() {
+ t.lock.Unlock()
+ t.clientPool.ndb.atomicWriteUnlock(id.Bytes())
+ }()
+
+ pb := t.clientPool.getPosBalance(id)
+ tokenBalance = pb.value.value(t.clientPool.posExpiration(mclock.Now()))
+ cb := t.clientPool.ndb.getCurrencyBalance(id)
+ pcBalance = cb.amount
+ if relative {
+ if pcBalance < minRefund {
+ minRefund -= pcBalance
+ } else {
+ minRefund = 0
+ }
+ if maxSell < tokenBalance {
+ maxSell = tokenBalance - maxSell
+ } else {
+ maxSell = 0
+ }
+ }
+
+ if maxSell > tokenBalance {
+ maxSell = tokenBalance
+ }
+ if sellAll {
+ sell = maxSell
+ if r, ok := t.tokenPrice(sell, false); ok {
+ refund = uint64(r)
+ success = refund >= minRefund
+ }
+ } else {
+ refund = minRefund
+ if s, ok := t.tokenSellAmount(float64(refund)); ok {
+ sell = s + 1 // ensure that we don't sell small amounts for free
+ } else {
+ sell = math.MaxUint64
+ }
+ success = sell <= maxSell
+ }
+ if success {
+ pcBalance += refund
+ tokenBalance -= sell
+ cb.amount = pcBalance
+ t.clientPool.ndb.setCurrencyBalance(id, cb)
+ t.clientPool.addBalance(id, -int64(sell))
+ }
+ return
+}
+
+// getBalance returns the current PC balance and service token balance
+func (t *tokenSale) getBalance(id enode.ID) (pcBalance, tokenBalance uint64) {
+ t.lock.Lock()
+ defer t.lock.Unlock()
+
+ pb := t.clientPool.getPosBalance(id)
+ cb := t.clientPool.ndb.getCurrencyBalance(id)
+ return pb.value.value(t.clientPool.posExpiration(mclock.Now())), cb.amount
+}
+
+// info returns general information about the server, including version info of the
+// lespay command set, supported payment modules and token expiration time constant
+func (t *tokenSale) info() (version, compatible uint, info keyValueList, receivers []string) {
+ t.lock.Lock()
+ defer t.lock.Unlock()
+
+ exp, _ := t.clientPool.getExpirationTCs()
+ info = info.add("tokenExpiration", strconv.FormatUint(exp, 10))
+ return 1, 1, info, t.receiverNames
+}
+
+// receiverInfo returns information about the specified payment receiver(s) if supported
+func (t *tokenSale) receiverInfo(receiverIDs []string) []keyValueList {
+ t.lock.Lock()
+ defer t.lock.Unlock()
+
+ res := make([]keyValueList, len(receiverIDs))
+ for i, id := range receiverIDs {
+ if rec, ok := t.receivers[id]; ok {
+ res[i] = rec.info()
+ }
+ }
+ return res
+}
+
+const (
+ tsInfo = iota
+ tsReceiverInfo
+ tsGetBalance
+ tsDeposit
+ tsBuyTokens
+ tsSellTokens
+ tsConnection
+)
+
+type (
+ tsInfoResults struct {
+ Version, Compatible uint
+ Info keyValueList
+ Receivers []string
+ }
+ tsInfoApiResults struct {
+ Version, Compatible uint
+ Info keyValueMapDecoded
+ Receivers []string
+ }
+ tsReceiverInfoParams []string
+ tsReceiverInfoResults []keyValueList
+ tsReceiverInfoApiResults []keyValueMapDecoded
+ tsGetBalanceResults struct {
+ PcBalance, TokenBalance uint64
+ }
+ tsDepositParams struct {
+ PaymentModule string
+ ProofOfPayment []byte
+ }
+ tsDepositResults struct {
+ PcValue, PcBalance uint64
+ Err string
+ }
+ tsBuyTokensParams struct {
+ MaxSpend, MinReceive uint64
+ Relative, SpendAll bool
+ }
+ tsBuyTokensResults struct {
+ PcBalance, TokenBalance, Spend, Receive uint64
+ Success bool
+ }
+ tsSellTokensParams struct {
+ MaxSell, MinRefund uint64
+ Relative, SellAll bool
+ }
+ tsSellTokensResults struct {
+ PcBalance, TokenBalance, Sell, Refund uint64
+ Success bool
+ }
+ tsConnectionParams struct {
+ RequestedCapacity, StayConnected uint64
+ PaymentModule []string
+ SetCap bool
+ }
+ tsConnectionResults struct {
+ AvailableCapacity, TokenBalance, TokensMissing, PcBalance, PcMissing uint64
+ PaymentRequired []uint64
+ Err string
+ }
+)
+
+// runCommand runs an encoded lespay command and returns the encoded results
+func (t *tokenSale) runCommand(cmd []byte, id enode.ID, freeID string) []byte {
+ var res []byte
+ switch cmd[0] {
+ case tsInfo:
+ var results tsInfoResults
+ if len(cmd) == 1 {
+ results.Version, results.Compatible, results.Info, results.Receivers = t.info()
+ res, _ = rlp.EncodeToBytes(&results)
+ }
+ case tsReceiverInfo:
+ var (
+ params tsReceiverInfoParams
+ results tsReceiverInfoResults
+ )
+ if err := rlp.DecodeBytes(cmd[1:], ¶ms); err == nil {
+ results = t.receiverInfo(params)
+ res, _ = rlp.EncodeToBytes(&results)
+ }
+ case tsGetBalance:
+ var results tsGetBalanceResults
+ if len(cmd) == 1 {
+ results.PcBalance, results.TokenBalance = t.getBalance(id)
+ res, _ = rlp.EncodeToBytes(&results)
+ }
+ case tsDeposit:
+ var (
+ params tsDepositParams
+ results tsDepositResults
+ )
+ if err := rlp.DecodeBytes(cmd[1:], ¶ms); err == nil {
+ results.PcValue, results.PcBalance, err = t.deposit(id, params.PaymentModule, params.ProofOfPayment)
+ if err != nil {
+ results.Err = err.Error()
+ }
+ res, _ = rlp.EncodeToBytes(&results)
+ }
+ case tsBuyTokens:
+ var (
+ params tsBuyTokensParams
+ results tsBuyTokensResults
+ )
+ if err := rlp.DecodeBytes(cmd[1:], ¶ms); err == nil {
+ results.PcBalance, results.TokenBalance, results.Spend, results.Receive, results.Success =
+ t.buyTokens(id, params.MaxSpend, params.MinReceive, params.Relative, params.SpendAll)
+ res, _ = rlp.EncodeToBytes(&results)
+ }
+ case tsSellTokens:
+ var (
+ params tsSellTokensParams
+ results tsSellTokensResults
+ )
+ if err := rlp.DecodeBytes(cmd[1:], ¶ms); err == nil {
+ results.PcBalance, results.TokenBalance, results.Sell, results.Refund, results.Success =
+ t.sellTokens(id, params.MaxSell, params.MinRefund, params.Relative, params.SellAll)
+ res, _ = rlp.EncodeToBytes(&results)
+ }
+ case tsConnection:
+ var (
+ params tsConnectionParams
+ results tsConnectionResults
+ )
+ if err := rlp.DecodeBytes(cmd[1:], ¶ms); err == nil {
+ results.AvailableCapacity, results.TokenBalance, results.TokensMissing, results.PcBalance, results.PcMissing, results.PaymentRequired, err =
+ t.connection(id, freeID, params.RequestedCapacity, time.Duration(params.StayConnected)*time.Second, params.PaymentModule, params.SetCap)
+ if err != nil {
+ results.Err = err.Error()
+ }
+ res, _ = rlp.EncodeToBytes(&results)
+ }
+ }
+ return res
+}
+
+type keyValueMapDecoded map[string]interface{}
+
+// DecodeRLP implements rlp.Decoder
+func (k *keyValueMapDecoded) DecodeRLP(s *rlp.Stream) error {
+ var list keyValueList
+ if err := s.Decode(&list); err != nil {
+ return err
+ }
+ *k = make(keyValueMapDecoded)
+ for _, item := range list {
+ var s string
+ if err := rlp.DecodeBytes(item.Value, &s); err != nil {
+ return err
+ }
+ (*k)[item.Key] = s
+ }
+ return nil
+}
+
+// testReceiver implements paymentReceiver. It should only be used for testing.
+type testReceiver struct{}
+
+func (t testReceiver) info() keyValueList {
+ var info keyValueList
+ info = info.add("description", "Test payment receiver")
+ info = info.add("version", "1.0.0")
+ return info
+}
+
+// receivePayment implements paymentReceiver. proofOfPayment is a base 10 ascii number
+// which is credited to the sender's account without any further conditions.
+func (t testReceiver) receivePayment(from enode.ID, proofOfPayment []byte, reader ethdb.KeyValueReader, writer ethdb.KeyValueWriter) (value uint64, err error) {
+ if len(proofOfPayment) > 8 {
+ err = fmt.Errorf("proof of payment is too long; max 8 bytes long big endian integer expected")
+ return
+ }
+ var b [8]byte
+ copy(b[8-len(proofOfPayment):], proofOfPayment)
+ value = binary.BigEndian.Uint64(b[:])
+ return
+}
+
+// requestPayment implements paymentReceiver
+func (t testReceiver) requestPayment(from enode.ID, value uint64, reader ethdb.KeyValueReader) uint64 {
+ return value
+}
+
+// readerTable is a wrapper around a database that prefixes each key access with a pre-
+// configured string.
+type readerTable struct {
+ db ethdb.KeyValueReader
+ prefix []byte
+}
+
+// newReaderTable returns a database object that prefixes all keys with a given string.
+func newReaderTable(db ethdb.KeyValueReader, prefix []byte) ethdb.KeyValueReader {
+ return &readerTable{
+ db: db,
+ prefix: prefix,
+ }
+}
+
+// Has retrieves if a prefixed version of a key is present in the database.
+func (t *readerTable) Has(key []byte) (bool, error) {
+ return t.db.Has(append(t.prefix, key...))
+}
+
+// Get retrieves the given prefixed key if it's present in the database.
+func (t *readerTable) Get(key []byte) ([]byte, error) {
+ return t.db.Get(append(t.prefix, key...))
+}
+
+// writerTable is a wrapper around a database that prefixes each key access with a pre-
+// configured string.
+type writerTable struct {
+ db ethdb.KeyValueWriter
+ prefix []byte
+}
+
+// newReaderTable returns a database object that prefixes all keys with a given string.
+func newWriterTable(db ethdb.KeyValueWriter, prefix []byte) ethdb.KeyValueWriter {
+ return &writerTable{
+ db: db,
+ prefix: prefix,
+ }
+}
+
+// Put inserts the given value into the database at a prefixed version of the
+// provided key.
+func (t *writerTable) Put(key []byte, value []byte) error {
+ return t.db.Put(append(t.prefix, key...), value)
+}
+
+// Delete removes the given prefixed key from the database.
+func (t writerTable) Delete(key []byte) error {
+ return t.db.Delete(append(t.prefix, key...))
+}
diff --git a/les/tokensale_test.go b/les/tokensale_test.go
new file mode 100644
index 000000000000..5c5db9d9b3d6
--- /dev/null
+++ b/les/tokensale_test.go
@@ -0,0 +1,157 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package les
+
+import (
+ "math/rand"
+ "testing"
+)
+
+func TestTokenPriceCalculation(t *testing.T) {
+ var totalLimit, totalAmount uint64
+ ts := &tokenSale{
+ basePrice: 1,
+ totalTokenLimit: func() uint64 { return totalLimit },
+ totalTokenAmount: func() uint64 { return totalAmount },
+ }
+ totalLimit = 1000000000000
+ maxDiff := int64(totalLimit / 1000000)
+ // inaccuracy increases around both ends of the allowed token range
+ min := totalLimit / 100
+ max := uint64(float64(totalLimit) * tokenSellMaxRatio)
+ for count := 0; count < 100000; count++ {
+ start := min + uint64(rand.Int63n(int64(max-min)))
+ stop := min + uint64(rand.Int63n(int64(max-min)))
+ if start > stop {
+ start, stop = stop, start
+ }
+ // buy (start-stop) tokens in two steps
+ mid := start + uint64(rand.Int63n(int64(stop-start+1)))
+ totalAmount = start
+ cost, ok := ts.tokenPrice(mid-start, true)
+ if !ok {
+ t.Fatalf("Failed to buy tokens")
+ }
+ totalAmount = mid
+ cost2, ok := ts.tokenPrice(stop-mid, true)
+ if !ok {
+ t.Fatalf("Failed to buy tokens")
+ }
+ cost += cost2
+
+ // sell the same amount of tokens in two steps
+ mid = start + uint64(rand.Int63n(int64(stop-start+1)))
+ totalAmount = stop
+ refund, ok := ts.tokenPrice(stop-mid, false)
+ if !ok {
+ t.Fatalf("Failed to sell tokens")
+ }
+ totalAmount = mid
+ refund2, ok := ts.tokenPrice(mid-start, false)
+ if !ok {
+ t.Fatalf("Failed to sell tokens")
+ }
+ refund += refund2
+ ratio := (refund + 1) / (cost + 1)
+ if ratio < 0.999999 || ratio > 1.000001 {
+ t.Fatalf("Token selling price does not match buy cost")
+ }
+
+ // buy tokens for the previously calculated price in two steps
+ pcost := cost * rand.Float64()
+ totalAmount = start
+ totalAmount += ts.tokenBuyAmount(pcost)
+ totalAmount += ts.tokenBuyAmount(cost - pcost)
+
+ diff := int64(totalAmount - stop)
+ if diff > maxDiff || diff < -maxDiff {
+ t.Fatalf("Bought token amount mismatch")
+ }
+
+ // sell tokens for the previously calculated price in two steps
+ pcost = cost * rand.Float64()
+ totalAmount = stop
+ soldAmount, ok := ts.tokenSellAmount(pcost)
+ if !ok {
+ t.Fatalf("Failed to sell tokens")
+ }
+ totalAmount -= soldAmount
+ soldAmount, ok = ts.tokenSellAmount(cost - pcost)
+ if !ok {
+ t.Fatalf("Failed to sell tokens")
+ }
+ totalAmount -= soldAmount
+
+ diff = int64(totalAmount - start)
+ if diff > maxDiff || diff < -maxDiff {
+ t.Fatalf("Sold token amount mismatch")
+ }
+ }
+}
+
+func TestSingleTokenPrice(t *testing.T) {
+ var totalLimit, totalAmount uint64
+ ts := &tokenSale{
+ basePrice: 1,
+ totalTokenLimit: func() uint64 { return totalLimit },
+ totalTokenAmount: func() uint64 { return totalAmount },
+ }
+ totalLimit = 1000000000000
+ buyLimit := uint64(float64(totalLimit) * tokenSellMaxRatio)
+ for count := 0; count < 10000; count++ {
+ totalAmount = uint64(rand.Int63n(int64(buyLimit)))
+ relAmount := float64(totalAmount) / float64(totalLimit)
+ var expPrice, maxDiff float64
+ if relAmount < 0.5 {
+ expPrice = relAmount * 2
+ maxDiff = 0.001
+ } else {
+ expPrice = 0.5 / (1 - relAmount)
+ maxDiff = 0.001 * expPrice
+ }
+ price, ok := ts.tokenPrice(1, true)
+ if !ok {
+ t.Fatalf("Failed to buy tokens")
+ }
+ if price < expPrice-maxDiff || price > expPrice+maxDiff {
+ t.Fatalf("Token price mismatch")
+ }
+
+ price, ok = ts.tokenPrice(1, false)
+ if !ok {
+ t.Fatalf("Failed to sell tokens")
+ }
+ if price < expPrice-maxDiff || price > expPrice+maxDiff {
+ t.Fatalf("Token price mismatch")
+ }
+
+ if relAmount > 0.01 {
+ amount := ts.tokenBuyAmount(expPrice * 100)
+ if amount < 99 || amount > 101 {
+ t.Fatalf("Bought token amount mismatch")
+ }
+
+ amount, ok = ts.tokenSellAmount(expPrice * 100)
+ if !ok {
+ t.Fatalf("Failed to sell tokens")
+ }
+ if amount < 99 || amount > 101 {
+ t.Fatalf("Sold token amount mismatch")
+ }
+ }
+ }
+}
diff --git a/les/ulc_test.go b/les/ulc_test.go
index 273c63e4bd5f..d26a782e2312 100644
--- a/les/ulc_test.go
+++ b/les/ulc_test.go
@@ -28,8 +28,8 @@ import (
"github.com/ethereum/go-ethereum/p2p/enode"
)
-func TestULCAnnounceThresholdLes2(t *testing.T) { testULCAnnounceThreshold(t, 2) }
func TestULCAnnounceThresholdLes3(t *testing.T) { testULCAnnounceThreshold(t, 3) }
+func TestULCAnnounceThresholdLes4(t *testing.T) { testULCAnnounceThreshold(t, 4) }
func testULCAnnounceThreshold(t *testing.T, protocol int) {
// todo figure out why it takes fetcher so longer to fetcher the announced header.
@@ -124,9 +124,9 @@ func connect(server *serverHandler, serverId enode.ID, client *clientHandler, pr
return peer1, peer2, nil
}
-// newTestServerPeer creates server peer.
+// newServerPeer creates server peer.
func newTestServerPeer(t *testing.T, blocks int, protocol int) (*testServer, *enode.Node, func()) {
- s, teardown := newServerEnv(t, blocks, protocol, nil, false, false, 0)
+ s, teardown := newServerEnv(t, blocks, protocol, nil, false, false, 0, false)
key, err := crypto.GenerateKey()
if err != nil {
t.Fatal("generate key err:", err)
diff --git a/p2p/discv5/net.go b/p2p/discv5/net.go
index dd2ec3e9298f..2c27f61b14a7 100644
--- a/p2p/discv5/net.go
+++ b/p2p/discv5/net.go
@@ -22,12 +22,14 @@ import (
"errors"
"fmt"
"net"
+ "sync"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/netutil"
"github.com/ethereum/go-ethereum/rlp"
"golang.org/x/crypto/sha3"
@@ -64,12 +66,17 @@ type Network struct {
refreshResp chan (<-chan struct{}) // ...and get the channel to block on from this one
read chan ingressPacket // ingress packets arrive here
timeout chan timeoutEvent
- queryReq chan *findnodeQuery // lookups submit findnode queries on this channel
+ queryReq chan deferredQuery // lookups submit findnode queries on this channel
tableOpReq chan func()
tableOpResp chan struct{}
topicRegisterReq chan topicRegisterReq
topicSearchReq chan topicSearchReq
+ talkRequestSubLock sync.RWMutex
+ talkResponseSubLock sync.Mutex
+ talkRequestSubs map[string]TalkRequestHandler
+ talkResponseSubs map[string]TalkResponseHandler
+
// State of the main loop.
tab *Table
topictab *topicTable
@@ -79,6 +86,21 @@ type Network struct {
timeoutTimers map[timeoutEvent]*time.Timer
}
+type (
+ // TalkRequestHandler processes an incoming talkRequest. If ok is true then response is
+ // sent back, along with the recommended delay feedback. If there is no urgent need to
+ // communicate (for example, in case of regular polling) then the sender should wait the
+ // given amount of seconds before sending the next request. If delay recommendation is
+ // disobeyed too many times by a sender then the handler can stop responding until the
+ // last recommended delay has elapsed.
+ TalkRequestHandler func(id enode.ID, addr *net.UDPAddr, request interface{}) (response interface{}, delay uint, ok bool)
+ // TalkResponseHandler is registered by the sender of each talkRequest. If the request
+ // is canceled the handler is still called with a nil parameter. If too many responses
+ // from a peer are considered invalid by their handlers then the peer goes into
+ // contested state.
+ TalkResponseHandler func(response interface{}, delay uint) (valid bool)
+)
+
// transport is implemented by the UDP transport.
// it is an interface so we can test without opening lots of UDP
// sockets and without generating a private key.
@@ -101,6 +123,14 @@ type findnodeQuery struct {
reply chan<- []*Node
}
+type talkQuery struct {
+ remote *Node
+ talkID string
+ payload interface{}
+ key string
+ handler TalkResponseHandler
+}
+
type topicRegisterReq struct {
add bool
topic Topic
@@ -151,10 +181,12 @@ func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, dbPath string, netres
timeoutTimers: make(map[timeoutEvent]*time.Timer),
tableOpReq: make(chan func()),
tableOpResp: make(chan struct{}),
- queryReq: make(chan *findnodeQuery),
+ queryReq: make(chan deferredQuery),
topicRegisterReq: make(chan topicRegisterReq),
topicSearchReq: make(chan topicSearchReq),
nodes: make(map[NodeID]*Node),
+ talkRequestSubs: make(map[string]TalkRequestHandler),
+ talkResponseSubs: make(map[string]TalkResponseHandler),
}
go net.loop()
return net, nil
@@ -410,7 +442,6 @@ loop:
// Ingress packet handling.
case pkt := <-net.read:
- //fmt.Println("read", pkt.ev)
log.Trace("<-net.read")
n := net.internNode(&pkt)
prestate := n.state
@@ -446,7 +477,7 @@ loop:
case q := <-net.queryReq:
log.Trace("<-net.queryReq")
if !q.start(net) {
- q.remote.deferQuery(q)
+ q.deferQuery()
}
// Interacting with the table.
@@ -767,14 +798,21 @@ type nodeNetGuts struct {
// State machine fields. Access to these fields
// is restricted to the Network.loop goroutine.
state *nodeState
- pingEcho []byte // hash of last ping sent by us
- pingTopics []Topic // topic set sent by us in last ping
- deferredQueries []*findnodeQuery // queries that can't be sent yet
- pendingNeighbours *findnodeQuery // current query, waiting for reply
+ pingEcho []byte // hash of last ping sent by us
+ pingTopics []Topic // topic set sent by us in last ping
+ deferredQueries []deferredQuery // queries that can't be sent yet
+ pendingNeighbours *findnodeQuery // current query, waiting for reply
queryTimeouts int
+ talkFailures int
}
-func (n *nodeNetGuts) deferQuery(q *findnodeQuery) {
+type deferredQuery interface {
+ start(net *Network) bool
+ cancel()
+ deferQuery()
+}
+
+func (n *nodeNetGuts) deferQuery(q deferredQuery) {
n.deferredQueries = append(n.deferredQueries, q)
}
@@ -810,6 +848,14 @@ func (q *findnodeQuery) start(net *Network) bool {
return false
}
+func (q *findnodeQuery) cancel() {
+ q.reply <- nil
+}
+
+func (q *findnodeQuery) deferQuery() {
+ q.remote.deferQuery(q)
+}
+
// Node Events (the input to the state machine).
type nodeEvent uint
@@ -828,6 +874,8 @@ const (
topicRegisterPacket
topicQueryPacket
topicNodesPacket
+ talkRequestPacket
+ talkResponsePacket
// Non-packet events.
// Event values in this category are allocated outside
@@ -835,6 +883,7 @@ const (
pongTimeout nodeEvent = iota + 256
pingTimeout
neighboursTimeout
+ talkTimeout
)
// Node State Machine.
@@ -868,7 +917,7 @@ func init() {
n.pingEcho = nil
// Abort active queries.
for _, q := range n.deferredQueries {
- q.reply <- nil
+ q.cancel()
}
n.deferredQueries = nil
if n.pendingNeighbours != nil {
@@ -1021,7 +1070,7 @@ func (net *Network) handle(n *Node, ev nodeEvent, pkt *ingressPacket) error {
//fmt.Println("handle", n.addr().String(), n.state, ev)
if pkt != nil {
if err := net.checkPacket(n, ev, pkt); err != nil {
- //fmt.Println("check err:", err)
+ //fmt.Println("check err:", err, pkt)
return err
}
// Start the background expiration goroutine after the first
@@ -1036,6 +1085,7 @@ func (net *Network) handle(n *Node, ev nodeEvent, pkt *ingressPacket) error {
if n.state == nil {
n.state = unknown //???
}
+ //fmt.Println("old state:", n.state)
next, err := n.state.handle(net, n, ev, pkt)
net.transition(n, next)
//fmt.Println("new state:", n.state)
@@ -1196,7 +1246,51 @@ func (net *Network) handleQueryEvent(n *Node, ev nodeEvent, pkt *ingressPacket)
}
}
return n.state, nil
-
+ case talkRequestPacket:
+ p := pkt.data.(*talkRequest)
+ net.talkRequestSubLock.RLock()
+ subFn := net.talkRequestSubs[string(p.TalkID)]
+ net.talkRequestSubLock.RUnlock()
+ if subFn != nil {
+ resp, delay, ok := subFn(enode.ID(n.sha), n.addr(), p.Payload)
+ if ok {
+ net.conn.send(n, talkResponsePacket, talkResponse{ReplyTok: pkt.hash, Delay: delay, Payload: resp})
+ } else {
+ n.talkFailures++
+ }
+ } else {
+ n.talkFailures++
+ }
+ if n.talkFailures > maxTalkFailures && n.state == known {
+ return contested, errors.New("too many talk failures")
+ }
+ return n.state, nil
+ case talkResponsePacket:
+ p := pkt.data.(*talkResponse)
+ net.talkResponseSubLock.Lock()
+ key := string(n.sha[:]) + string(p.ReplyTok)
+ subFn := net.talkResponseSubs[key]
+ if subFn != nil {
+ delete(net.talkResponseSubs, key)
+ }
+ net.talkResponseSubLock.Unlock()
+ if subFn == nil || !subFn(p.Payload, p.Delay) {
+ n.talkFailures++
+ if n.talkFailures > maxTalkFailures && n.state == known {
+ return contested, errors.New("too many talk failures")
+ }
+ }
+ return n.state, nil
+ case talkTimeout:
+ if n.pendingNeighbours != nil {
+ n.pendingNeighbours.reply <- nil
+ n.pendingNeighbours = nil
+ }
+ n.queryTimeouts++
+ if n.queryTimeouts > maxFindnodeFailures && n.state == known {
+ return contested, errors.New("too many timeouts")
+ }
+ return n.state, nil
default:
return n.state, errInvalidEvent
}
@@ -1260,3 +1354,82 @@ func (net *Network) handleNeighboursPacket(n *Node, pkt *ingressPacket) error {
n.startNextQuery(net)
return nil
}
+
+// RegisterTalkHandler assigns a handler callback to the given talk ID
+func (net *Network) RegisterTalkHandler(talkID string, handler TalkRequestHandler) {
+ net.talkRequestSubLock.Lock()
+ net.talkRequestSubs[talkID] = handler
+ net.talkRequestSubLock.Unlock()
+}
+
+// RemoveTalkHandler removes the handler assigned to the given talk ID
+func (net *Network) RemoveTalkHandler(talkID string) {
+ net.talkRequestSubLock.Lock()
+ delete(net.talkRequestSubs, talkID)
+ net.talkRequestSubLock.Unlock()
+}
+
+func (q *talkQuery) start(net *Network) bool {
+ if q.remote == net.tab.self {
+ return false
+ }
+ if q.remote.state == known {
+ net.talkResponseSubLock.Lock()
+ hash := net.conn.send(q.remote, talkRequestPacket, talkRequest{TalkID: []byte(q.talkID), Payload: q.payload})
+ q.key = string(q.remote.sha[:]) + string(hash[:])
+ net.talkResponseSubs[q.key] = q.handler
+ net.talkResponseSubLock.Unlock()
+ return true
+ }
+ // If the node is not known yet, it won't accept queries.
+ // Initiate the transition to known.
+ // The request will be sent later when the node reaches known state.
+ if q.remote.state == unknown {
+ net.transition(q.remote, verifyinit)
+ }
+ return false
+}
+
+func (q *talkQuery) cancel() {
+ q.handler(nil, 0)
+}
+
+func (q *talkQuery) deferQuery() {
+ q.remote.deferQuery(q)
+}
+
+// SendTalkRequest sends a talRequest and registers a response handler. It returns
+// a cancel function that removes the response handler and calls it with a nil parameter
+// if the response has not arrived yet.
+func (net *Network) SendTalkRequest(to *enode.Node, talkID string, payload interface{}, handler TalkResponseHandler) (cancel func() bool) {
+ var nodeID NodeID
+ copy(nodeID[:], crypto.FromECDSAPub(to.Pubkey())[1:])
+ node := net.nodes[nodeID]
+ if node == nil {
+ node = NewNode(nodeID, to.IP(), uint16(to.UDP()), uint16(to.TCP()))
+ node.state = unknown
+ net.nodes[nodeID] = node
+ }
+ q := &talkQuery{remote: node, talkID: talkID, payload: payload, handler: handler}
+ if node.state == nil || !node.state.canQuery {
+ net.ping(node, node.addr())
+ }
+ select {
+ case net.queryReq <- q:
+ case <-net.closed:
+ return nil
+ }
+
+ return func() bool {
+ net.talkResponseSubLock.Lock()
+ cancel := q.key != "" && net.talkResponseSubs[q.key] != nil
+ if cancel {
+ delete(net.talkResponseSubs, q.key)
+ }
+ net.talkResponseSubLock.Unlock()
+ if cancel {
+ handler(nil, 0)
+ }
+ return cancel
+ }
+}
diff --git a/p2p/discv5/table.go b/p2p/discv5/table.go
index 64c3ecd1c7b2..23b68e56928a 100644
--- a/p2p/discv5/table.go
+++ b/p2p/discv5/table.go
@@ -35,6 +35,7 @@ const (
nBuckets = hashBits + 1 // Number of buckets
maxFindnodeFailures = 5
+ maxTalkFailures = 5
)
type Table struct {
diff --git a/p2p/discv5/udp.go b/p2p/discv5/udp.go
index 088f95cac6a2..a90c70b312c9 100644
--- a/p2p/discv5/udp.go
+++ b/p2p/discv5/udp.go
@@ -119,6 +119,17 @@ type (
Nodes []rpcNode
}
+ talkRequest struct {
+ TalkID []byte
+ Payload interface{}
+ }
+
+ talkResponse struct {
+ ReplyTok []byte
+ Delay uint
+ Payload interface{}
+ }
+
rpcNode struct {
IP net.IP // len 4 for IPv4 or 16 for IPv6
UDP uint16 // for discovery protocol
@@ -420,6 +431,10 @@ func decodePacket(buffer []byte, pkt *ingressPacket) error {
pkt.data = new(topicQuery)
case topicNodesPacket:
pkt.data = new(topicNodes)
+ case talkRequestPacket:
+ pkt.data = new(talkRequest)
+ case talkResponsePacket:
+ pkt.data = new(talkResponse)
default:
return fmt.Errorf("unknown packet type: %d", sigdata[0])
}
diff --git a/p2p/enode/node.go b/p2p/enode/node.go
index 9eb2544ffe14..3f6cda6d4a21 100644
--- a/p2p/enode/node.go
+++ b/p2p/enode/node.go
@@ -217,7 +217,7 @@ func (n ID) MarshalText() ([]byte, error) {
// UnmarshalText implements the encoding.TextUnmarshaler interface.
func (n *ID) UnmarshalText(text []byte) error {
- id, err := parseID(string(text))
+ id, err := ParseID(string(text))
if err != nil {
return err
}
@@ -229,14 +229,14 @@ func (n *ID) UnmarshalText(text []byte) error {
// The string may be prefixed with 0x.
// It panics if the string is not a valid ID.
func HexID(in string) ID {
- id, err := parseID(in)
+ id, err := ParseID(in)
if err != nil {
panic(err)
}
return id
}
-func parseID(in string) (ID, error) {
+func ParseID(in string) (ID, error) {
var id ID
b, err := hex.DecodeString(strings.TrimPrefix(in, "0x"))
if err != nil {